From c8f5d986a1cc39f625105ab395e0d982e177104e Mon Sep 17 00:00:00 2001 From: artie Date: Sat, 1 Feb 2025 19:28:36 +0100 Subject: [PATCH] pp --- artemis/cogs/chat.py | 35 +++++++++++++++++++++++++++++++++-- data/prompt | 1 + 2 files changed, 34 insertions(+), 2 deletions(-) create mode 100644 data/prompt diff --git a/artemis/cogs/chat.py b/artemis/cogs/chat.py index 0890f74..f787a83 100644 --- a/artemis/cogs/chat.py +++ b/artemis/cogs/chat.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import logging +from pathlib import Path import re from typing import TYPE_CHECKING @@ -101,9 +102,16 @@ class Chat(commands.Cog): self.bot: Artemis = bot self.memory: list[dict] = [] self.model = "google/gemma-2-2b-it" + self.prompt = self.read_prompt() self.client = AsyncInferenceClient(api_key=self.bot.secrets.huggingface) self.lock = asyncio.Lock() + def read_prompt(self): + return Path("data/prompt").read_text() + + def write_prompt(self, prompt: str): + Path("data/prompt").write_text(prompt) + def replace_emojis(self, text: str) -> str: return EMOJI_RE.sub(lambda match: emoji_map[match.group(0)], text) @@ -116,7 +124,7 @@ class Chat(commands.Cog): return content def add_memory(self, role: str, message: str): - prompt = "You're Artemis, a bot hanging out in this Discord server, you're friendly and can answer anything, following is a user chat message directed at you.\n\n" + prompt = self.prompt + "Following is a user chat message directed at you." + "\n\n" if len(self.memory) == 0: message = prompt + message if len(self.memory) >= 15: @@ -150,7 +158,7 @@ class Chat(commands.Cog): @commands.Cog.listener() async def on_message(self, message: discord.Message): - if message.author.bot: + if message.author.bot or message.content.startswith(config.prefix): return reference = message.reference @@ -184,6 +192,29 @@ class Chat(commands.Cog): except aiohttp.client_exceptions.ClientResponseError as err: await message.reply(f"{self.model} error: {err.status} {err.message}") + @commands.group(name="chat") + async def _chat(self, ctx: commands.Context): + """LLM chat management.""" + if ctx.invoked_subcommand is None: + await ctx.send("Invalid subcommand passed.") + + @_chat.command() + async def reset(self, ctx: commands.Context): + """Reset chat memory.""" + self.memory = [] + await ctx.send("Chat memory reset.") + + @_chat.command(name="prompt") + async def _prompt(self, ctx: commands.Context, *, prompt: str = None): + """Get or set system prompt and reset chat memory.""" + if not prompt: + await ctx.send(self.prompt) + return + self.prompt = prompt + self.write_prompt(prompt) + self.memory = [] + await ctx.send("Prompt updated.") + async def setup(bot: Artemis): await bot.add_cog(Chat(bot)) diff --git a/data/prompt b/data/prompt new file mode 100644 index 0000000..539ba55 --- /dev/null +++ b/data/prompt @@ -0,0 +1 @@ +You're Artemis, a bot hanging out in this Discord server, you're friendly and can answer anything. \ No newline at end of file