improvements to lens ocr - image compression, finding url in replies

This commit is contained in:
artie 2024-09-26 23:15:10 +02:00
parent 6e350099d3
commit 348ef8af3d
3 changed files with 117 additions and 48 deletions

View File

@ -17,12 +17,11 @@ import pendulum
import yt_dlp
from bs4 import BeautifulSoup
from discord.ext import commands
from PIL import Image
from pycaption import SRTWriter, WebVTTReader
from yt_dlp.utils import parse_duration
from .. import utils
from ..utils.common import ArtemisError
from ..utils.common import ArtemisError, compress_image
from ..utils.constants import MAX_DISCORD_SIZE, MAX_LITTERBOX_SIZE, TEMP_DIR
from ..utils.catbox import CatboxError
from ..utils.flags import DLFlags
@ -327,14 +326,6 @@ class Media(commands.Cog):
utils.check_for_ssrf(url)
ytdl_opts = {**DEFAULT_OPTS, "format": "bv*/b"}
@utils.in_executor
def to_jpeg(image):
im = Image.open(image)
buff = BytesIO()
im.save(buff, "JPEG", quality=90)
buff.seek(0)
return buff
if not (re.fullmatch(TIMESTAMP_RE, timestamp) or re.fullmatch(SECONDS_RE, timestamp)):
return await ctx.reply("Invalid timestamp format, check out `$help screencap`.")
@ -359,7 +350,7 @@ class Media(commands.Cog):
buff = BytesIO(stdout)
if len(stdout) > MAX_DISCORD_SIZE:
buff = await to_jpeg(buff)
buff = await compress_image(buff, fmt="JPEG", quality=90)
msg += "\nThe image was too big for me to upload so I converted it to JPEG Q90."
dfile = discord.File(buff, f"{title}.png")
return await ctx.reply(content=msg, file=dfile)

View File

@ -11,7 +11,7 @@ import magic
from discord.ext import commands
from .. import utils
from ..utils.common import ArtemisError
from ..utils.common import ArtemisError, compress_image, get_reply
from ..utils.constants import TESSERACT_LANGUAGES
from ..utils.flags import Flags, OCRFlags, OCRTranslateFlags
from ..utils.iso_639 import get_language_name
@ -48,8 +48,17 @@ class OCR(commands.Cog):
embed = discord.Embed(description=msg, color=discord.Color.red())
return await ctx.reply(embed=embed)
message = await utils.get_message_or_reference(ctx)
image = await utils.get_attachment_or_url(ctx, message, url, ["image/jpeg", "image/png"])
if url or ctx.message.attachments:
message = ctx.message
else:
message = await get_reply(ctx)
if not message:
raise ArtemisError("Could not find any images.")
image = await utils.get_file_from_attachment_or_url(
ctx, message, url, ["image/jpeg", "image/png"]
)
args = f"tesseract stdin stdout -l {lang}"
result = await utils.run_cmd(args, input=image)
@ -89,8 +98,23 @@ class OCR(commands.Cog):
await ctx.typing()
message = await utils.get_message_or_reference(ctx)
image = await utils.get_attachment_or_url(ctx, message, url, ["image/jpeg", "image/png"])
if url or ctx.message.attachments:
message = ctx.message
else:
message = await get_reply(ctx)
if not message:
raise ArtemisError("Could not find any images.")
image = await utils.get_file_from_attachment_or_url(
ctx, message, url, ["image/jpeg", "image/png"]
)
try:
image = await compress_image(image, size=1000)
image = image.getvalue()
except Exception as e:
raise ArtemisError(f"Could not compress image: {e}") from e
content_type = magic.from_buffer(image, mime=True)
ext = mimetypes.guess_extension(content_type)
@ -111,7 +135,7 @@ class OCR(commands.Cog):
match = re.search(final_data_re, html)
if not match:
if ctx.author.id == self.bot.owner.id:
await ctx.send(file=utils.File(html, "lens.html"))
await ctx.send(file=utils.file(html, "lens.html"))
raise ArtemisError("No text detected.")
_lang, lines = match.groups()

View File

@ -11,18 +11,30 @@ from ipaddress import ip_address
from subprocess import PIPE
from time import perf_counter, time_ns
from time import time as _time
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Dict,
List,
Literal,
Optional,
Sequence,
TypeVar,
)
from urllib.parse import quote_plus, urlparse
import tomllib
import discord
import humanize
import pendulum
import pykakasi
import tomllib
from aiohttp.helpers import is_ip_address
from discord.ext import commands
import feedparser
from rapidfuzz import process
from PIL import Image
from .. import utils
@ -55,7 +67,7 @@ class InvalidColour(commands.BadArgument):
class BetterColour(commands.Converter):
async def convert(self, ctx, argument: str):
async def convert(self, _ctx, argument: str):
colour_name = fuzzy_search_one(argument, list(utils.COMMON_COLOURS), cutoff=60)
if not colour_name:
raise InvalidColour("Invalid colour code/name.")
@ -67,7 +79,7 @@ class BetterColour(commands.Converter):
class URL(commands.Converter):
"""URL converter."""
async def convert(self, ctx, argument: str):
async def convert(self, _ctx, argument: str):
argument = argument.strip("<>")
match = is_valid_url(argument)
if match:
@ -172,7 +184,7 @@ class ProgressBarMessage:
if self._delete_on_finished:
return await self._msg.delete()
self._msg = await self._msg.edit(**self._finished_kwargs)
except Exception as err:
except BaseException as err:
print(str(err))
if self._msg:
await self._msg.edit(content="Error while rendering progress bar.")
@ -222,7 +234,7 @@ class ProgressBarMessage:
self.finish()
def File(fp: bytes | str | list | dict, filename: str):
def file(fp: bytes | str | list | dict, filename: str):
"""Dirty discord.File helper for fast debugging."""
if isinstance(fp, bytes):
buf = BytesIO(fp)
@ -257,11 +269,14 @@ def read_toml(path: str) -> Any:
return tomllib.load(f)
def in_executor(func: Callable):
F = TypeVar("F", bound=Callable)
def in_executor(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]:
"""A decorator for running a function in a thread."""
@functools.wraps(func)
def decorator(*args: Any, **kwargs: Any):
def decorator(*args: Any, **kwargs: Any) -> Coroutine[Any, Any, T]:
return asyncio.to_thread(func, *args, **kwargs)
return decorator
@ -284,11 +299,11 @@ def time(resolution: Literal["s", "ms", "ns"] = "s") -> int:
return time_ns()
def trim(text: Optional[str], max: int) -> Optional[str]:
def trim(text: Optional[str], max_len: int) -> Optional[str]:
"""Trims text to specified max length."""
if text is None:
return None
return f"{text[:max - 3]}..." if len(text) > max else text
return f"{text[:max_len - 3]}..." if len(text) > max_len else text
def romajify(text: str, strict: bool = True) -> str:
@ -391,20 +406,20 @@ async def run_cmd(args: str, shell=False, input=None) -> CommandResult:
stdin = PIPE if input else None
try:
if shell:
process = await asyncio.create_subprocess_shell(
subprocess = await asyncio.create_subprocess_shell(
args, stdout=PIPE, stderr=PIPE, stdin=stdin
)
else:
split_args = shlex.split(args)
process = await asyncio.create_subprocess_exec(
subprocess = await asyncio.create_subprocess_exec(
*split_args, stdout=PIPE, stderr=PIPE, stdin=stdin
)
stdout, stderr = await process.communicate(input=input)
stdout, stderr = await subprocess.communicate(input=input)
except Exception as err:
raise CommandExecutionError(err)
raise CommandExecutionError(err) from err
return CommandResult(stdout, stderr, process.returncode)
return CommandResult(stdout, stderr, subprocess.returncode)
async def run_cmd_to_file(args: str, filename: str, shell=False) -> discord.File | str:
@ -424,6 +439,26 @@ async def run_cmd_to_file(args: str, filename: str, shell=False) -> discord.File
raise CommandExecutionError("The file is too big to upload.")
@in_executor
def compress_image(
image: bytes | BytesIO, fmt: str = "JPEG", quality: int = 90, size=None
) -> BytesIO:
if isinstance(image, bytes):
image = BytesIO(image)
im = Image.open(image)
if im.mode in ("RGBA", "P"):
im = im.convert("RGB")
if size:
im.thumbnail((size, size), Image.Resampling.LANCZOS)
buff = BytesIO()
im.save(buff, fmt, quality=quality)
buff.seek(0)
return buff
def parse_short_time(time_string: str, as_duration: bool = False):
compiled = re.compile(
"""(?:(?P<years>[0-9])(?:years?|y))? # e.g. 2y
@ -447,10 +482,35 @@ def parse_short_time(time_string: str, as_duration: bool = False):
return pendulum.now("UTC") + pendulum.duration(**data)
async def get_attachment_or_url(
def extract_urls(s: str | None):
if not s:
return []
return re.findall(URL_RE, s)
def extract_first_url(s: str | None) -> Optional[str]:
url = next(iter(extract_urls(s)), None)
if not url:
return None
return url[0]
async def get_reply(ctx: commands.Context[Artemis]) -> discord.Message | None:
reference = ctx.message.reference
if not reference:
return None
try:
return reference.cached_message or await ctx.channel.fetch_message(reference.message_id)
except discord.DiscordException:
return None
async def get_file_from_attachment_or_url(
ctx: commands.Context[Artemis], message: discord.Message, url: Optional[str], types: list = None
) -> bytes:
if not message.attachments and not url:
is_replied_to = ctx.message is not message
if not message.attachments and not url and (is_replied_to and not message.content):
raise ArtemisError("Please send me an attachment or URL first!")
elif message.attachments:
attachment = message.attachments[0]
@ -462,7 +522,12 @@ async def get_attachment_or_url(
f"Unsupported file type, should be one of: `{', '.join(types)}`."
)
return await attachment.read()
elif url:
elif url or (is_replied_to and message.content):
if is_replied_to:
url = extract_first_url(message.content)
if not url:
raise ArtemisError("No URL found in message.")
url = url.strip("<>")
if not is_valid_url(url):
raise ArtemisError("URL is not valid.")
@ -480,19 +545,8 @@ async def get_attachment_or_url(
elif r.content_type not in types:
raise ArtemisError("Unsupported file type, should be an image.")
return await r.read()
except Exception:
raise ArtemisError("An error occured when trying to connect to the given URL.")
async def get_message_or_reference(ctx: commands.Context[Artemis]) -> discord.Message:
reference = ctx.message.reference
if reference:
try:
return reference.cached_message or await ctx.channel.fetch_message(reference.message_id)
except Exception:
return ctx.message
else:
return ctx.message
except BaseException as err:
raise ArtemisError("An error occured when trying to connect to the given URL.") from err
T = TypeVar("T")