From b38e5cc7ed976cd8f29c4e969b9e1a104093042f Mon Sep 17 00:00:00 2001 From: artie Date: Sat, 1 Feb 2025 17:07:33 +0100 Subject: [PATCH] fixes --- artemis/cogs/chat.py | 57 +++++++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/artemis/cogs/chat.py b/artemis/cogs/chat.py index 8333228..3cbddbd 100644 --- a/artemis/cogs/chat.py +++ b/artemis/cogs/chat.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import logging import re from typing import TYPE_CHECKING @@ -42,7 +43,6 @@ emoji_map = { "😛": ":P", "🤑": "$-$", "🤗": "(hug)", - "🤔": ":/", "😎": "8)", "😏": "^_^", "😒": "-_-", @@ -99,9 +99,10 @@ ARTEMIS_RE = re.compile(r"\bar(i)?temis\b", flags=re.IGNORECASE) class Chat(commands.Cog): def __init__(self, bot: Artemis): self.bot: Artemis = bot - self.memory: dict[int, list[dict]] = {} + self.memory: list[dict] = [] self.model = "google/gemma-2-2b-it" self.client = AsyncInferenceClient(api_key=self.bot.secrets.huggingface) + self.lock = asyncio.Lock() def replace_emojis(self, text: str) -> str: return EMOJI_RE.sub(lambda match: emoji_map[match.group(0)], text) @@ -114,40 +115,44 @@ class Chat(commands.Cog): content = content.replace(user.mention, "@" + user.display_name) return content - def get_memory(self, user_id: int): - if user_id not in self.memory: - self.memory[user_id] = [] - return self.memory[user_id] + def add_memory(self, role: str, message: str): + prompt = "You're Artemis, a friendly AI hanging out in this Discord server, following is a user chat message directed at you.\n\n" + if len(self.memory) == 0: + message = prompt + message + if len(self.memory) >= 15: + del self.memory[0] + del self.memory[0] + self.memory[0] = {"role": "user", "content": prompt + self.memory[0]["content"]} + self.memory.append({"role": role, "content": message}) - def add_memory(self, user_id: int, role: str, message: str): - memory = self.get_memory(user_id) - if len(memory) == 0: - message = f"You're Artemis, a friendly AI hanging out in this Discord server, following is a user chat message directed at you.\n\n{message}" - if len(memory) >= 10: - del memory[1] - memory.append({"role": role, "content": message}) - self.memory[user_id] = memory + def add_user_memory(self, message: str): + self.add_memory("user", message) - def add_user_memory(self, user_id: int, message: str): - self.add_memory(user_id, "user", message) + def add_assistant_memory(self, message: str): + self.add_memory("assistant", message) - def add_assistant_memory(self, user_id: int, message: str): - self.add_memory(user_id, "assistant", message) - - async def chat(self, user_id: int, message: str): - self.add_user_memory(user_id, message) + async def chat(self, message: str): + self.add_user_memory(message) response = await self.client.chat.completions.create( model=self.model, - messages=self.get_memory(user_id), + messages=self.memory, max_tokens=500, stream=False, ) chat_response = response.choices[0].message.content - self.add_assistant_memory(user_id, chat_response) + + chat_response = self.replace_emojis(chat_response) + chat_response = self.strip_emojis(chat_response) + chat_response = re.sub(r"[ ]{2,}", " ", chat_response) + + self.add_assistant_memory(chat_response) return chat_response @commands.Cog.listener() async def on_message(self, message: discord.Message): + if message.author.bot: + return + reference = message.reference parent = reference.cached_message if reference else None @@ -173,10 +178,8 @@ class Chat(commands.Cog): try: async with message.channel.typing(): - response = await self.chat(message.author.id, content) - response = self.replace_emojis(response) - response = self.strip_emojis(response) - response = re.sub(r"[ ]{2,}", " ", response) + async with self.lock: + response = await self.chat(content) await message.reply(response) except aiohttp.client_exceptions.ClientResponseError as err: await message.reply(f"{self.model} error: {err.status} {err.message}")