This commit is contained in:
artie 2025-02-01 17:07:33 +01:00
parent 794bf48256
commit b38e5cc7ed

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import re import re
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -42,7 +43,6 @@ emoji_map = {
"😛": ":P", "😛": ":P",
"🤑": "$-$", "🤑": "$-$",
"🤗": "(hug)", "🤗": "(hug)",
"🤔": ":/",
"😎": "8)", "😎": "8)",
"😏": "^_^", "😏": "^_^",
"😒": "-_-", "😒": "-_-",
@ -99,9 +99,10 @@ ARTEMIS_RE = re.compile(r"\bar(i)?temis\b", flags=re.IGNORECASE)
class Chat(commands.Cog): class Chat(commands.Cog):
def __init__(self, bot: Artemis): def __init__(self, bot: Artemis):
self.bot: Artemis = bot self.bot: Artemis = bot
self.memory: dict[int, list[dict]] = {} self.memory: list[dict] = []
self.model = "google/gemma-2-2b-it" self.model = "google/gemma-2-2b-it"
self.client = AsyncInferenceClient(api_key=self.bot.secrets.huggingface) self.client = AsyncInferenceClient(api_key=self.bot.secrets.huggingface)
self.lock = asyncio.Lock()
def replace_emojis(self, text: str) -> str: def replace_emojis(self, text: str) -> str:
return EMOJI_RE.sub(lambda match: emoji_map[match.group(0)], text) return EMOJI_RE.sub(lambda match: emoji_map[match.group(0)], text)
@ -114,40 +115,44 @@ class Chat(commands.Cog):
content = content.replace(user.mention, "@" + user.display_name) content = content.replace(user.mention, "@" + user.display_name)
return content return content
def get_memory(self, user_id: int): def add_memory(self, role: str, message: str):
if user_id not in self.memory: prompt = "You're Artemis, a friendly AI hanging out in this Discord server, following is a user chat message directed at you.\n\n"
self.memory[user_id] = [] if len(self.memory) == 0:
return self.memory[user_id] message = prompt + message
if len(self.memory) >= 15:
del self.memory[0]
del self.memory[0]
self.memory[0] = {"role": "user", "content": prompt + self.memory[0]["content"]}
self.memory.append({"role": role, "content": message})
def add_memory(self, user_id: int, role: str, message: str): def add_user_memory(self, message: str):
memory = self.get_memory(user_id) self.add_memory("user", message)
if len(memory) == 0:
message = f"You're Artemis, a friendly AI hanging out in this Discord server, following is a user chat message directed at you.\n\n{message}"
if len(memory) >= 10:
del memory[1]
memory.append({"role": role, "content": message})
self.memory[user_id] = memory
def add_user_memory(self, user_id: int, message: str): def add_assistant_memory(self, message: str):
self.add_memory(user_id, "user", message) self.add_memory("assistant", message)
def add_assistant_memory(self, user_id: int, message: str): async def chat(self, message: str):
self.add_memory(user_id, "assistant", message) self.add_user_memory(message)
async def chat(self, user_id: int, message: str):
self.add_user_memory(user_id, message)
response = await self.client.chat.completions.create( response = await self.client.chat.completions.create(
model=self.model, model=self.model,
messages=self.get_memory(user_id), messages=self.memory,
max_tokens=500, max_tokens=500,
stream=False, stream=False,
) )
chat_response = response.choices[0].message.content chat_response = response.choices[0].message.content
self.add_assistant_memory(user_id, chat_response)
chat_response = self.replace_emojis(chat_response)
chat_response = self.strip_emojis(chat_response)
chat_response = re.sub(r"[ ]{2,}", " ", chat_response)
self.add_assistant_memory(chat_response)
return chat_response return chat_response
@commands.Cog.listener() @commands.Cog.listener()
async def on_message(self, message: discord.Message): async def on_message(self, message: discord.Message):
if message.author.bot:
return
reference = message.reference reference = message.reference
parent = reference.cached_message if reference else None parent = reference.cached_message if reference else None
@ -173,10 +178,8 @@ class Chat(commands.Cog):
try: try:
async with message.channel.typing(): async with message.channel.typing():
response = await self.chat(message.author.id, content) async with self.lock:
response = self.replace_emojis(response) response = await self.chat(content)
response = self.strip_emojis(response)
response = re.sub(r"[ ]{2,}", " ", response)
await message.reply(response) await message.reply(response)
except aiohttp.client_exceptions.ClientResponseError as err: except aiohttp.client_exceptions.ClientResponseError as err:
await message.reply(f"{self.model} error: {err.status} {err.message}") await message.reply(f"{self.model} error: {err.status} {err.message}")