Compare commits

..

No commits in common. "4d941fe226b3852e34a4165de8a2639d3726e255" and "ec436bd63afb53f507ae57cc8f12cc25a8f2bcec" have entirely different histories.

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from pathlib import Path
import re import re
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -102,16 +101,9 @@ class Chat(commands.Cog):
self.bot: Artemis = bot self.bot: Artemis = bot
self.memory: list[dict] = [] self.memory: list[dict] = []
self.model = "google/gemma-2-2b-it" self.model = "google/gemma-2-2b-it"
self.prompt = self.read_prompt()
self.client = AsyncInferenceClient(api_key=self.bot.secrets.huggingface) self.client = AsyncInferenceClient(api_key=self.bot.secrets.huggingface)
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
def read_prompt(self):
return Path("temp/prompt").read_text() if Path("temp/prompt").exists() else ""
def write_prompt(self, prompt: str):
Path("temp/prompt").write_text(prompt)
def replace_emojis(self, text: str) -> str: def replace_emojis(self, text: str) -> str:
return EMOJI_RE.sub(lambda match: emoji_map[match.group(0)], text) return EMOJI_RE.sub(lambda match: emoji_map[match.group(0)], text)
@ -124,11 +116,7 @@ class Chat(commands.Cog):
return content return content
def add_memory(self, role: str, message: str): def add_memory(self, role: str, message: str):
prompt = ( prompt = "You're Artemis, a bot hanging out in this Discord server, you're friendly and can answer anything, following is a user chat message directed at you.\n\n"
self.prompt
+ "The following is a user chat message directed at you, the format will be the same for subsequent messages, respond with only the message content, without specyfing actions."
+ "\n\n"
)
if len(self.memory) == 0: if len(self.memory) == 0:
message = prompt + message message = prompt + message
if len(self.memory) >= 15: if len(self.memory) >= 15:
@ -156,14 +144,13 @@ class Chat(commands.Cog):
chat_response = self.replace_emojis(chat_response) chat_response = self.replace_emojis(chat_response)
chat_response = self.strip_emojis(chat_response) chat_response = self.strip_emojis(chat_response)
chat_response = re.sub(r"[ ]{2,}", " ", chat_response) chat_response = re.sub(r"[ ]{2,}", " ", chat_response)
chat_response = re.sub(r"[\n]{2,}", "\n", chat_response)
self.add_assistant_memory(chat_response) self.add_assistant_memory(chat_response)
return chat_response return chat_response
@commands.Cog.listener() @commands.Cog.listener()
async def on_message(self, message: discord.Message): async def on_message(self, message: discord.Message):
if message.author.bot or message.content.startswith(config.prefix): if message.author.bot:
return return
reference = message.reference reference = message.reference
@ -187,7 +174,7 @@ class Chat(commands.Cog):
if not content: if not content:
return return
content = f"[USERNAME]: {message.author.display_name}\n[MESSAGE]: {content}" content = message.author.display_name + ": " + content
try: try:
async with message.channel.typing(): async with message.channel.typing():
@ -197,29 +184,6 @@ class Chat(commands.Cog):
except aiohttp.client_exceptions.ClientResponseError as err: except aiohttp.client_exceptions.ClientResponseError as err:
await message.reply(f"{self.model} error: {err.status} {err.message}") await message.reply(f"{self.model} error: {err.status} {err.message}")
@commands.group(name="chat")
async def _chat(self, ctx: commands.Context):
"""LLM chat management."""
if ctx.invoked_subcommand is None:
await ctx.send("Invalid subcommand passed.")
@_chat.command()
async def reset(self, ctx: commands.Context):
"""Reset chat memory."""
self.memory = []
await ctx.send("Chat memory reset.")
@_chat.command(name="prompt")
async def _prompt(self, ctx: commands.Context, *, prompt: str = None):
"""Get or set system prompt and reset chat memory."""
if not prompt:
await ctx.send(self.prompt)
return
self.prompt = prompt
self.write_prompt(prompt)
self.memory = []
await ctx.send("Prompt updated.")
async def setup(bot: Artemis): async def setup(bot: Artemis):
await bot.add_cog(Chat(bot)) await bot.add_cog(Chat(bot))