mirror of
https://github.com/artiemis/artemis.git
synced 2026-02-14 08:31:55 +00:00
fixes
This commit is contained in:
parent
794bf48256
commit
b38e5cc7ed
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
@ -42,7 +43,6 @@ emoji_map = {
|
|||||||
"😛": ":P",
|
"😛": ":P",
|
||||||
"🤑": "$-$",
|
"🤑": "$-$",
|
||||||
"🤗": "(hug)",
|
"🤗": "(hug)",
|
||||||
"🤔": ":/",
|
|
||||||
"😎": "8)",
|
"😎": "8)",
|
||||||
"😏": "^_^",
|
"😏": "^_^",
|
||||||
"😒": "-_-",
|
"😒": "-_-",
|
||||||
@ -99,9 +99,10 @@ ARTEMIS_RE = re.compile(r"\bar(i)?temis\b", flags=re.IGNORECASE)
|
|||||||
class Chat(commands.Cog):
|
class Chat(commands.Cog):
|
||||||
def __init__(self, bot: Artemis):
|
def __init__(self, bot: Artemis):
|
||||||
self.bot: Artemis = bot
|
self.bot: Artemis = bot
|
||||||
self.memory: dict[int, list[dict]] = {}
|
self.memory: list[dict] = []
|
||||||
self.model = "google/gemma-2-2b-it"
|
self.model = "google/gemma-2-2b-it"
|
||||||
self.client = AsyncInferenceClient(api_key=self.bot.secrets.huggingface)
|
self.client = AsyncInferenceClient(api_key=self.bot.secrets.huggingface)
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
|
||||||
def replace_emojis(self, text: str) -> str:
|
def replace_emojis(self, text: str) -> str:
|
||||||
return EMOJI_RE.sub(lambda match: emoji_map[match.group(0)], text)
|
return EMOJI_RE.sub(lambda match: emoji_map[match.group(0)], text)
|
||||||
@ -114,40 +115,44 @@ class Chat(commands.Cog):
|
|||||||
content = content.replace(user.mention, "@" + user.display_name)
|
content = content.replace(user.mention, "@" + user.display_name)
|
||||||
return content
|
return content
|
||||||
|
|
||||||
def get_memory(self, user_id: int):
|
def add_memory(self, role: str, message: str):
|
||||||
if user_id not in self.memory:
|
prompt = "You're Artemis, a friendly AI hanging out in this Discord server, following is a user chat message directed at you.\n\n"
|
||||||
self.memory[user_id] = []
|
if len(self.memory) == 0:
|
||||||
return self.memory[user_id]
|
message = prompt + message
|
||||||
|
if len(self.memory) >= 15:
|
||||||
|
del self.memory[0]
|
||||||
|
del self.memory[0]
|
||||||
|
self.memory[0] = {"role": "user", "content": prompt + self.memory[0]["content"]}
|
||||||
|
self.memory.append({"role": role, "content": message})
|
||||||
|
|
||||||
def add_memory(self, user_id: int, role: str, message: str):
|
def add_user_memory(self, message: str):
|
||||||
memory = self.get_memory(user_id)
|
self.add_memory("user", message)
|
||||||
if len(memory) == 0:
|
|
||||||
message = f"You're Artemis, a friendly AI hanging out in this Discord server, following is a user chat message directed at you.\n\n{message}"
|
|
||||||
if len(memory) >= 10:
|
|
||||||
del memory[1]
|
|
||||||
memory.append({"role": role, "content": message})
|
|
||||||
self.memory[user_id] = memory
|
|
||||||
|
|
||||||
def add_user_memory(self, user_id: int, message: str):
|
def add_assistant_memory(self, message: str):
|
||||||
self.add_memory(user_id, "user", message)
|
self.add_memory("assistant", message)
|
||||||
|
|
||||||
def add_assistant_memory(self, user_id: int, message: str):
|
async def chat(self, message: str):
|
||||||
self.add_memory(user_id, "assistant", message)
|
self.add_user_memory(message)
|
||||||
|
|
||||||
async def chat(self, user_id: int, message: str):
|
|
||||||
self.add_user_memory(user_id, message)
|
|
||||||
response = await self.client.chat.completions.create(
|
response = await self.client.chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=self.get_memory(user_id),
|
messages=self.memory,
|
||||||
max_tokens=500,
|
max_tokens=500,
|
||||||
stream=False,
|
stream=False,
|
||||||
)
|
)
|
||||||
chat_response = response.choices[0].message.content
|
chat_response = response.choices[0].message.content
|
||||||
self.add_assistant_memory(user_id, chat_response)
|
|
||||||
|
chat_response = self.replace_emojis(chat_response)
|
||||||
|
chat_response = self.strip_emojis(chat_response)
|
||||||
|
chat_response = re.sub(r"[ ]{2,}", " ", chat_response)
|
||||||
|
|
||||||
|
self.add_assistant_memory(chat_response)
|
||||||
return chat_response
|
return chat_response
|
||||||
|
|
||||||
@commands.Cog.listener()
|
@commands.Cog.listener()
|
||||||
async def on_message(self, message: discord.Message):
|
async def on_message(self, message: discord.Message):
|
||||||
|
if message.author.bot:
|
||||||
|
return
|
||||||
|
|
||||||
reference = message.reference
|
reference = message.reference
|
||||||
parent = reference.cached_message if reference else None
|
parent = reference.cached_message if reference else None
|
||||||
|
|
||||||
@ -173,10 +178,8 @@ class Chat(commands.Cog):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
async with message.channel.typing():
|
async with message.channel.typing():
|
||||||
response = await self.chat(message.author.id, content)
|
async with self.lock:
|
||||||
response = self.replace_emojis(response)
|
response = await self.chat(content)
|
||||||
response = self.strip_emojis(response)
|
|
||||||
response = re.sub(r"[ ]{2,}", " ", response)
|
|
||||||
await message.reply(response)
|
await message.reply(response)
|
||||||
except aiohttp.client_exceptions.ClientResponseError as err:
|
except aiohttp.client_exceptions.ClientResponseError as err:
|
||||||
await message.reply(f"{self.model} error: {err.status} {err.message}")
|
await message.reply(f"{self.model} error: {err.status} {err.message}")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user