Compare commits

..

No commits in common. "46b25d9a679dd77f921e77da9370a1a05dea47fd" and "4d941fe226b3852e34a4165de8a2639d3726e255" have entirely different histories.

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from pathlib import Path
import re import re
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -11,7 +12,6 @@ from discord.ext import commands
from huggingface_hub import AsyncInferenceClient from huggingface_hub import AsyncInferenceClient
from artemis.utils import config from artemis.utils import config
from artemis.utils.constants import TEMP_DIR
if TYPE_CHECKING: if TYPE_CHECKING:
@ -107,12 +107,10 @@ class Chat(commands.Cog):
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
def read_prompt(self): def read_prompt(self):
path = TEMP_DIR / "prompt" return Path("temp/prompt").read_text() if Path("temp/prompt").exists() else ""
return path.read_text() if path.exists() else ""
def write_prompt(self, prompt: str): def write_prompt(self, prompt: str):
path = TEMP_DIR / "prompt" Path("temp/prompt").write_text(prompt)
path.write_text(prompt)
def replace_emojis(self, text: str) -> str: def replace_emojis(self, text: str) -> str:
return EMOJI_RE.sub(lambda match: emoji_map[match.group(0)], text) return EMOJI_RE.sub(lambda match: emoji_map[match.group(0)], text)
@ -126,17 +124,18 @@ class Chat(commands.Cog):
return content return content
def add_memory(self, role: str, message: str): def add_memory(self, role: str, message: str):
prompt = self.prompt + "\n\n" prompt = (
self.prompt
+ "The following is a user chat message directed at you, the format will be the same for subsequent messages, respond with only the message content, without specyfing actions."
+ "\n\n"
)
if len(self.memory) == 0: if len(self.memory) == 0:
message = prompt + message message = prompt + message
if len(self.memory) >= 20: if len(self.memory) >= 15:
del self.memory[0] del self.memory[0]
del self.memory[0] del self.memory[0]
self.memory[0] = { self.memory[0] = {"role": "user", "content": prompt + self.memory[0]["content"]}
"role": "user", self.memory.append({"role": role, "content": message})
"content": (prompt + self.memory[0]["content"]).strip(),
}
self.memory.append({"role": role, "content": message.strip()})
def add_user_memory(self, message: str): def add_user_memory(self, message: str):
self.add_memory("user", message) self.add_memory("user", message)
@ -188,6 +187,8 @@ class Chat(commands.Cog):
if not content: if not content:
return return
content = f"[USERNAME]: {message.author.display_name}\n[MESSAGE]: {content}"
try: try:
async with message.channel.typing(): async with message.channel.typing():
async with self.lock: async with self.lock:
@ -212,16 +213,12 @@ class Chat(commands.Cog):
async def _prompt(self, ctx: commands.Context, *, prompt: str = None): async def _prompt(self, ctx: commands.Context, *, prompt: str = None):
"""Get or set system prompt and reset chat memory.""" """Get or set system prompt and reset chat memory."""
if not prompt: if not prompt:
await ctx.send(repr(self.prompt) if len(self.prompt) > 0 else "No prompt set.") await ctx.send(self.prompt)
elif prompt == "reset": return
self.prompt = "" self.prompt = prompt
self.write_prompt(self.prompt) self.write_prompt(prompt)
await ctx.send("Prompt reset.") self.memory = []
else: await ctx.send("Prompt updated.")
self.prompt = prompt
self.write_prompt(prompt)
self.memory = []
await ctx.send("Prompt updated.")
async def setup(bot: Artemis): async def setup(bot: Artemis):