From 348ef8af3d935e63b00a356bae5ab5ac9ff9b0d1 Mon Sep 17 00:00:00 2001 From: artie Date: Thu, 26 Sep 2024 23:15:10 +0200 Subject: [PATCH] improvements to lens ocr - image compression, finding url in replies --- artemis/cogs/media.py | 13 +---- artemis/cogs/ocr.py | 36 ++++++++++--- artemis/utils/common.py | 116 +++++++++++++++++++++++++++++----------- 3 files changed, 117 insertions(+), 48 deletions(-) diff --git a/artemis/cogs/media.py b/artemis/cogs/media.py index 93bb764..6803445 100644 --- a/artemis/cogs/media.py +++ b/artemis/cogs/media.py @@ -17,12 +17,11 @@ import pendulum import yt_dlp from bs4 import BeautifulSoup from discord.ext import commands -from PIL import Image from pycaption import SRTWriter, WebVTTReader from yt_dlp.utils import parse_duration from .. import utils -from ..utils.common import ArtemisError +from ..utils.common import ArtemisError, compress_image from ..utils.constants import MAX_DISCORD_SIZE, MAX_LITTERBOX_SIZE, TEMP_DIR from ..utils.catbox import CatboxError from ..utils.flags import DLFlags @@ -327,14 +326,6 @@ class Media(commands.Cog): utils.check_for_ssrf(url) ytdl_opts = {**DEFAULT_OPTS, "format": "bv*/b"} - @utils.in_executor - def to_jpeg(image): - im = Image.open(image) - buff = BytesIO() - im.save(buff, "JPEG", quality=90) - buff.seek(0) - return buff - if not (re.fullmatch(TIMESTAMP_RE, timestamp) or re.fullmatch(SECONDS_RE, timestamp)): return await ctx.reply("Invalid timestamp format, check out `$help screencap`.") @@ -359,7 +350,7 @@ class Media(commands.Cog): buff = BytesIO(stdout) if len(stdout) > MAX_DISCORD_SIZE: - buff = await to_jpeg(buff) + buff = await compress_image(buff, fmt="JPEG", quality=90) msg += "\nThe image was too big for me to upload so I converted it to JPEG Q90." dfile = discord.File(buff, f"{title}.png") return await ctx.reply(content=msg, file=dfile) diff --git a/artemis/cogs/ocr.py b/artemis/cogs/ocr.py index cd9a5bf..b3319bf 100644 --- a/artemis/cogs/ocr.py +++ b/artemis/cogs/ocr.py @@ -11,7 +11,7 @@ import magic from discord.ext import commands from .. import utils -from ..utils.common import ArtemisError +from ..utils.common import ArtemisError, compress_image, get_reply from ..utils.constants import TESSERACT_LANGUAGES from ..utils.flags import Flags, OCRFlags, OCRTranslateFlags from ..utils.iso_639 import get_language_name @@ -48,8 +48,17 @@ class OCR(commands.Cog): embed = discord.Embed(description=msg, color=discord.Color.red()) return await ctx.reply(embed=embed) - message = await utils.get_message_or_reference(ctx) - image = await utils.get_attachment_or_url(ctx, message, url, ["image/jpeg", "image/png"]) + if url or ctx.message.attachments: + message = ctx.message + else: + message = await get_reply(ctx) + + if not message: + raise ArtemisError("Could not find any images.") + + image = await utils.get_file_from_attachment_or_url( + ctx, message, url, ["image/jpeg", "image/png"] + ) args = f"tesseract stdin stdout -l {lang}" result = await utils.run_cmd(args, input=image) @@ -89,8 +98,23 @@ class OCR(commands.Cog): await ctx.typing() - message = await utils.get_message_or_reference(ctx) - image = await utils.get_attachment_or_url(ctx, message, url, ["image/jpeg", "image/png"]) + if url or ctx.message.attachments: + message = ctx.message + else: + message = await get_reply(ctx) + + if not message: + raise ArtemisError("Could not find any images.") + + image = await utils.get_file_from_attachment_or_url( + ctx, message, url, ["image/jpeg", "image/png"] + ) + + try: + image = await compress_image(image, size=1000) + image = image.getvalue() + except Exception as e: + raise ArtemisError(f"Could not compress image: {e}") from e content_type = magic.from_buffer(image, mime=True) ext = mimetypes.guess_extension(content_type) @@ -111,7 +135,7 @@ class OCR(commands.Cog): match = re.search(final_data_re, html) if not match: if ctx.author.id == self.bot.owner.id: - await ctx.send(file=utils.File(html, "lens.html")) + await ctx.send(file=utils.file(html, "lens.html")) raise ArtemisError("No text detected.") _lang, lines = match.groups() diff --git a/artemis/utils/common.py b/artemis/utils/common.py index 3b328cc..fccc91a 100644 --- a/artemis/utils/common.py +++ b/artemis/utils/common.py @@ -11,18 +11,30 @@ from ipaddress import ip_address from subprocess import PIPE from time import perf_counter, time_ns from time import time as _time -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Dict, + List, + Literal, + Optional, + Sequence, + TypeVar, +) from urllib.parse import quote_plus, urlparse +import tomllib import discord import humanize import pendulum import pykakasi -import tomllib from aiohttp.helpers import is_ip_address from discord.ext import commands import feedparser from rapidfuzz import process +from PIL import Image from .. import utils @@ -55,7 +67,7 @@ class InvalidColour(commands.BadArgument): class BetterColour(commands.Converter): - async def convert(self, ctx, argument: str): + async def convert(self, _ctx, argument: str): colour_name = fuzzy_search_one(argument, list(utils.COMMON_COLOURS), cutoff=60) if not colour_name: raise InvalidColour("Invalid colour code/name.") @@ -67,7 +79,7 @@ class BetterColour(commands.Converter): class URL(commands.Converter): """URL converter.""" - async def convert(self, ctx, argument: str): + async def convert(self, _ctx, argument: str): argument = argument.strip("<>") match = is_valid_url(argument) if match: @@ -172,7 +184,7 @@ class ProgressBarMessage: if self._delete_on_finished: return await self._msg.delete() self._msg = await self._msg.edit(**self._finished_kwargs) - except Exception as err: + except BaseException as err: print(str(err)) if self._msg: await self._msg.edit(content="Error while rendering progress bar.") @@ -222,7 +234,7 @@ class ProgressBarMessage: self.finish() -def File(fp: bytes | str | list | dict, filename: str): +def file(fp: bytes | str | list | dict, filename: str): """Dirty discord.File helper for fast debugging.""" if isinstance(fp, bytes): buf = BytesIO(fp) @@ -257,11 +269,14 @@ def read_toml(path: str) -> Any: return tomllib.load(f) -def in_executor(func: Callable): +F = TypeVar("F", bound=Callable) + + +def in_executor(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]: """A decorator for running a function in a thread.""" @functools.wraps(func) - def decorator(*args: Any, **kwargs: Any): + def decorator(*args: Any, **kwargs: Any) -> Coroutine[Any, Any, T]: return asyncio.to_thread(func, *args, **kwargs) return decorator @@ -284,11 +299,11 @@ def time(resolution: Literal["s", "ms", "ns"] = "s") -> int: return time_ns() -def trim(text: Optional[str], max: int) -> Optional[str]: +def trim(text: Optional[str], max_len: int) -> Optional[str]: """Trims text to specified max length.""" if text is None: return None - return f"{text[:max - 3]}..." if len(text) > max else text + return f"{text[:max_len - 3]}..." if len(text) > max_len else text def romajify(text: str, strict: bool = True) -> str: @@ -391,20 +406,20 @@ async def run_cmd(args: str, shell=False, input=None) -> CommandResult: stdin = PIPE if input else None try: if shell: - process = await asyncio.create_subprocess_shell( + subprocess = await asyncio.create_subprocess_shell( args, stdout=PIPE, stderr=PIPE, stdin=stdin ) else: split_args = shlex.split(args) - process = await asyncio.create_subprocess_exec( + subprocess = await asyncio.create_subprocess_exec( *split_args, stdout=PIPE, stderr=PIPE, stdin=stdin ) - stdout, stderr = await process.communicate(input=input) + stdout, stderr = await subprocess.communicate(input=input) except Exception as err: - raise CommandExecutionError(err) + raise CommandExecutionError(err) from err - return CommandResult(stdout, stderr, process.returncode) + return CommandResult(stdout, stderr, subprocess.returncode) async def run_cmd_to_file(args: str, filename: str, shell=False) -> discord.File | str: @@ -424,6 +439,26 @@ async def run_cmd_to_file(args: str, filename: str, shell=False) -> discord.File raise CommandExecutionError("The file is too big to upload.") +@in_executor +def compress_image( + image: bytes | BytesIO, fmt: str = "JPEG", quality: int = 90, size=None +) -> BytesIO: + if isinstance(image, bytes): + image = BytesIO(image) + im = Image.open(image) + + if im.mode in ("RGBA", "P"): + im = im.convert("RGB") + + if size: + im.thumbnail((size, size), Image.Resampling.LANCZOS) + + buff = BytesIO() + im.save(buff, fmt, quality=quality) + buff.seek(0) + return buff + + def parse_short_time(time_string: str, as_duration: bool = False): compiled = re.compile( """(?:(?P[0-9])(?:years?|y))? # e.g. 2y @@ -447,10 +482,35 @@ def parse_short_time(time_string: str, as_duration: bool = False): return pendulum.now("UTC") + pendulum.duration(**data) -async def get_attachment_or_url( +def extract_urls(s: str | None): + if not s: + return [] + return re.findall(URL_RE, s) + + +def extract_first_url(s: str | None) -> Optional[str]: + url = next(iter(extract_urls(s)), None) + if not url: + return None + return url[0] + + +async def get_reply(ctx: commands.Context[Artemis]) -> discord.Message | None: + reference = ctx.message.reference + if not reference: + return None + try: + return reference.cached_message or await ctx.channel.fetch_message(reference.message_id) + except discord.DiscordException: + return None + + +async def get_file_from_attachment_or_url( ctx: commands.Context[Artemis], message: discord.Message, url: Optional[str], types: list = None ) -> bytes: - if not message.attachments and not url: + is_replied_to = ctx.message is not message + + if not message.attachments and not url and (is_replied_to and not message.content): raise ArtemisError("Please send me an attachment or URL first!") elif message.attachments: attachment = message.attachments[0] @@ -462,7 +522,12 @@ async def get_attachment_or_url( f"Unsupported file type, should be one of: `{', '.join(types)}`." ) return await attachment.read() - elif url: + elif url or (is_replied_to and message.content): + if is_replied_to: + url = extract_first_url(message.content) + if not url: + raise ArtemisError("No URL found in message.") + url = url.strip("<>") if not is_valid_url(url): raise ArtemisError("URL is not valid.") @@ -480,19 +545,8 @@ async def get_attachment_or_url( elif r.content_type not in types: raise ArtemisError("Unsupported file type, should be an image.") return await r.read() - except Exception: - raise ArtemisError("An error occured when trying to connect to the given URL.") - - -async def get_message_or_reference(ctx: commands.Context[Artemis]) -> discord.Message: - reference = ctx.message.reference - if reference: - try: - return reference.cached_message or await ctx.channel.fetch_message(reference.message_id) - except Exception: - return ctx.message - else: - return ctx.message + except BaseException as err: + raise ArtemisError("An error occured when trying to connect to the given URL.") from err T = TypeVar("T")