Improvements to Lens OCR: image compression and finding URLs in replies

This commit is contained in:
artie 2024-09-26 23:15:10 +02:00
parent 6e350099d3
commit 348ef8af3d
3 changed files with 117 additions and 48 deletions

View File

@ -17,12 +17,11 @@ import pendulum
import yt_dlp import yt_dlp
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from discord.ext import commands from discord.ext import commands
from PIL import Image
from pycaption import SRTWriter, WebVTTReader from pycaption import SRTWriter, WebVTTReader
from yt_dlp.utils import parse_duration from yt_dlp.utils import parse_duration
from .. import utils from .. import utils
from ..utils.common import ArtemisError from ..utils.common import ArtemisError, compress_image
from ..utils.constants import MAX_DISCORD_SIZE, MAX_LITTERBOX_SIZE, TEMP_DIR from ..utils.constants import MAX_DISCORD_SIZE, MAX_LITTERBOX_SIZE, TEMP_DIR
from ..utils.catbox import CatboxError from ..utils.catbox import CatboxError
from ..utils.flags import DLFlags from ..utils.flags import DLFlags
@ -327,14 +326,6 @@ class Media(commands.Cog):
utils.check_for_ssrf(url) utils.check_for_ssrf(url)
ytdl_opts = {**DEFAULT_OPTS, "format": "bv*/b"} ytdl_opts = {**DEFAULT_OPTS, "format": "bv*/b"}
@utils.in_executor
def to_jpeg(image):
    """Re-encode ``image`` as a quality-90 JPEG.

    Args:
        image: A file-like object (or path) that ``PIL.Image.open`` accepts.

    Returns:
        A ``BytesIO`` holding the JPEG data, rewound to position 0.
    """
    im = Image.open(image)
    # JPEG has no alpha channel or palette; saving e.g. an RGBA/P/LA frame
    # raises "cannot write mode ... as JPEG", so normalise to RGB first.
    if im.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        im = im.convert("RGB")
    buff = BytesIO()
    im.save(buff, "JPEG", quality=90)
    buff.seek(0)
    return buff
if not (re.fullmatch(TIMESTAMP_RE, timestamp) or re.fullmatch(SECONDS_RE, timestamp)): if not (re.fullmatch(TIMESTAMP_RE, timestamp) or re.fullmatch(SECONDS_RE, timestamp)):
return await ctx.reply("Invalid timestamp format, check out `$help screencap`.") return await ctx.reply("Invalid timestamp format, check out `$help screencap`.")
@ -359,7 +350,7 @@ class Media(commands.Cog):
buff = BytesIO(stdout) buff = BytesIO(stdout)
if len(stdout) > MAX_DISCORD_SIZE: if len(stdout) > MAX_DISCORD_SIZE:
buff = await to_jpeg(buff) buff = await compress_image(buff, fmt="JPEG", quality=90)
msg += "\nThe image was too big for me to upload so I converted it to JPEG Q90." msg += "\nThe image was too big for me to upload so I converted it to JPEG Q90."
dfile = discord.File(buff, f"{title}.png") dfile = discord.File(buff, f"{title}.png")
return await ctx.reply(content=msg, file=dfile) return await ctx.reply(content=msg, file=dfile)

View File

@ -11,7 +11,7 @@ import magic
from discord.ext import commands from discord.ext import commands
from .. import utils from .. import utils
from ..utils.common import ArtemisError from ..utils.common import ArtemisError, compress_image, get_reply
from ..utils.constants import TESSERACT_LANGUAGES from ..utils.constants import TESSERACT_LANGUAGES
from ..utils.flags import Flags, OCRFlags, OCRTranslateFlags from ..utils.flags import Flags, OCRFlags, OCRTranslateFlags
from ..utils.iso_639 import get_language_name from ..utils.iso_639 import get_language_name
@ -48,8 +48,17 @@ class OCR(commands.Cog):
embed = discord.Embed(description=msg, color=discord.Color.red()) embed = discord.Embed(description=msg, color=discord.Color.red())
return await ctx.reply(embed=embed) return await ctx.reply(embed=embed)
message = await utils.get_message_or_reference(ctx) if url or ctx.message.attachments:
image = await utils.get_attachment_or_url(ctx, message, url, ["image/jpeg", "image/png"]) message = ctx.message
else:
message = await get_reply(ctx)
if not message:
raise ArtemisError("Could not find any images.")
image = await utils.get_file_from_attachment_or_url(
ctx, message, url, ["image/jpeg", "image/png"]
)
args = f"tesseract stdin stdout -l {lang}" args = f"tesseract stdin stdout -l {lang}"
result = await utils.run_cmd(args, input=image) result = await utils.run_cmd(args, input=image)
@ -89,8 +98,23 @@ class OCR(commands.Cog):
await ctx.typing() await ctx.typing()
message = await utils.get_message_or_reference(ctx) if url or ctx.message.attachments:
image = await utils.get_attachment_or_url(ctx, message, url, ["image/jpeg", "image/png"]) message = ctx.message
else:
message = await get_reply(ctx)
if not message:
raise ArtemisError("Could not find any images.")
image = await utils.get_file_from_attachment_or_url(
ctx, message, url, ["image/jpeg", "image/png"]
)
try:
image = await compress_image(image, size=1000)
image = image.getvalue()
except Exception as e:
raise ArtemisError(f"Could not compress image: {e}") from e
content_type = magic.from_buffer(image, mime=True) content_type = magic.from_buffer(image, mime=True)
ext = mimetypes.guess_extension(content_type) ext = mimetypes.guess_extension(content_type)
@ -111,7 +135,7 @@ class OCR(commands.Cog):
match = re.search(final_data_re, html) match = re.search(final_data_re, html)
if not match: if not match:
if ctx.author.id == self.bot.owner.id: if ctx.author.id == self.bot.owner.id:
await ctx.send(file=utils.File(html, "lens.html")) await ctx.send(file=utils.file(html, "lens.html"))
raise ArtemisError("No text detected.") raise ArtemisError("No text detected.")
_lang, lines = match.groups() _lang, lines = match.groups()

View File

@ -11,18 +11,30 @@ from ipaddress import ip_address
from subprocess import PIPE from subprocess import PIPE
from time import perf_counter, time_ns from time import perf_counter, time_ns
from time import time as _time from time import time as _time
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, TypeVar from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Dict,
List,
Literal,
Optional,
Sequence,
TypeVar,
)
from urllib.parse import quote_plus, urlparse from urllib.parse import quote_plus, urlparse
import tomllib
import discord import discord
import humanize import humanize
import pendulum import pendulum
import pykakasi import pykakasi
import tomllib
from aiohttp.helpers import is_ip_address from aiohttp.helpers import is_ip_address
from discord.ext import commands from discord.ext import commands
import feedparser import feedparser
from rapidfuzz import process from rapidfuzz import process
from PIL import Image
from .. import utils from .. import utils
@ -55,7 +67,7 @@ class InvalidColour(commands.BadArgument):
class BetterColour(commands.Converter): class BetterColour(commands.Converter):
async def convert(self, ctx, argument: str): async def convert(self, _ctx, argument: str):
colour_name = fuzzy_search_one(argument, list(utils.COMMON_COLOURS), cutoff=60) colour_name = fuzzy_search_one(argument, list(utils.COMMON_COLOURS), cutoff=60)
if not colour_name: if not colour_name:
raise InvalidColour("Invalid colour code/name.") raise InvalidColour("Invalid colour code/name.")
@ -67,7 +79,7 @@ class BetterColour(commands.Converter):
class URL(commands.Converter): class URL(commands.Converter):
"""URL converter.""" """URL converter."""
async def convert(self, ctx, argument: str): async def convert(self, _ctx, argument: str):
argument = argument.strip("<>") argument = argument.strip("<>")
match = is_valid_url(argument) match = is_valid_url(argument)
if match: if match:
@ -172,7 +184,7 @@ class ProgressBarMessage:
if self._delete_on_finished: if self._delete_on_finished:
return await self._msg.delete() return await self._msg.delete()
self._msg = await self._msg.edit(**self._finished_kwargs) self._msg = await self._msg.edit(**self._finished_kwargs)
except Exception as err: except BaseException as err:
print(str(err)) print(str(err))
if self._msg: if self._msg:
await self._msg.edit(content="Error while rendering progress bar.") await self._msg.edit(content="Error while rendering progress bar.")
@ -222,7 +234,7 @@ class ProgressBarMessage:
self.finish() self.finish()
def File(fp: bytes | str | list | dict, filename: str): def file(fp: bytes | str | list | dict, filename: str):
"""Dirty discord.File helper for fast debugging.""" """Dirty discord.File helper for fast debugging."""
if isinstance(fp, bytes): if isinstance(fp, bytes):
buf = BytesIO(fp) buf = BytesIO(fp)
@ -257,11 +269,14 @@ def read_toml(path: str) -> Any:
return tomllib.load(f) return tomllib.load(f)
F = TypeVar("F", bound=Callable)


def in_executor(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]:
    """Wrap ``func`` so that every call is executed in a worker thread.

    The wrapper is a plain function: calling it returns the coroutine created
    by ``asyncio.to_thread``, and awaiting that coroutine yields whatever
    ``func`` returns. Metadata of ``func`` is preserved via ``functools.wraps``.
    """

    @functools.wraps(func)
    def run_in_thread(*call_args: Any, **call_kwargs: Any) -> Coroutine[Any, Any, T]:
        pending = asyncio.to_thread(func, *call_args, **call_kwargs)
        return pending

    return run_in_thread
@ -284,11 +299,11 @@ def time(resolution: Literal["s", "ms", "ns"] = "s") -> int:
return time_ns() return time_ns()
def trim(text: Optional[str], max_len: int) -> Optional[str]:
    """Trim ``text`` so that it is at most ``max_len`` characters long.

    Text that already fits is returned unchanged. Longer text is cut and
    suffixed with ``"..."`` so the total length equals ``max_len``.

    Args:
        text: The string to shorten, or ``None``.
        max_len: Maximum length of the returned string.

    Returns:
        ``None`` when ``text`` is ``None``, otherwise the (possibly shortened)
        string.
    """
    if text is None:
        return None
    if len(text) <= max_len:
        return text
    if max_len < 3:
        # No room for the ellipsis: a negative slice end would otherwise keep
        # almost the whole string and exceed the limit, so hard-cut instead.
        return text[: max(max_len, 0)]
    return f"{text[:max_len - 3]}..."
def romajify(text: str, strict: bool = True) -> str: def romajify(text: str, strict: bool = True) -> str:
@ -391,20 +406,20 @@ async def run_cmd(args: str, shell=False, input=None) -> CommandResult:
stdin = PIPE if input else None stdin = PIPE if input else None
try: try:
if shell: if shell:
process = await asyncio.create_subprocess_shell( subprocess = await asyncio.create_subprocess_shell(
args, stdout=PIPE, stderr=PIPE, stdin=stdin args, stdout=PIPE, stderr=PIPE, stdin=stdin
) )
else: else:
split_args = shlex.split(args) split_args = shlex.split(args)
process = await asyncio.create_subprocess_exec( subprocess = await asyncio.create_subprocess_exec(
*split_args, stdout=PIPE, stderr=PIPE, stdin=stdin *split_args, stdout=PIPE, stderr=PIPE, stdin=stdin
) )
stdout, stderr = await process.communicate(input=input) stdout, stderr = await subprocess.communicate(input=input)
except Exception as err: except Exception as err:
raise CommandExecutionError(err) raise CommandExecutionError(err) from err
return CommandResult(stdout, stderr, process.returncode) return CommandResult(stdout, stderr, subprocess.returncode)
async def run_cmd_to_file(args: str, filename: str, shell=False) -> discord.File | str: async def run_cmd_to_file(args: str, filename: str, shell=False) -> discord.File | str:
@ -424,6 +439,26 @@ async def run_cmd_to_file(args: str, filename: str, shell=False) -> discord.File
raise CommandExecutionError("The file is too big to upload.") raise CommandExecutionError("The file is too big to upload.")
@in_executor
def compress_image(
    image: bytes | BytesIO, fmt: str = "JPEG", quality: int = 90, size=None
) -> BytesIO:
    """Re-encode an image, optionally shrinking it first.

    Decorated with ``in_executor``, so callers ``await`` the result.

    Args:
        image: Raw image bytes or a ``BytesIO`` containing them.
        fmt: Pillow format name used for saving.
        quality: Quality setting forwarded to ``Image.save``.
        size: When truthy, the image is passed through ``Image.thumbnail``
            with a ``(size, size)`` bounding box before saving.

    Returns:
        A ``BytesIO`` with the encoded image, rewound to position 0.
    """
    if isinstance(image, bytes):
        image = BytesIO(image)

    im = Image.open(image)

    # RGBA and palette images are always flattened to RGB. For JPEG output
    # every other mode the encoder cannot write (LA, PA, I, F, I;16, ...)
    # is flattened as well, instead of failing with
    # "cannot write mode ... as JPEG".
    jpeg_unwritable = fmt.upper() == "JPEG" and im.mode not in (
        "1",
        "L",
        "RGB",
        "RGBX",
        "CMYK",
        "YCbCr",
    )
    if im.mode in ("RGBA", "P") or jpeg_unwritable:
        im = im.convert("RGB")

    if size:
        im.thumbnail((size, size), Image.Resampling.LANCZOS)

    buff = BytesIO()
    im.save(buff, fmt, quality=quality)
    buff.seek(0)
    return buff
def parse_short_time(time_string: str, as_duration: bool = False): def parse_short_time(time_string: str, as_duration: bool = False):
compiled = re.compile( compiled = re.compile(
"""(?:(?P<years>[0-9])(?:years?|y))? # e.g. 2y """(?:(?P<years>[0-9])(?:years?|y))? # e.g. 2y
@ -447,10 +482,35 @@ def parse_short_time(time_string: str, as_duration: bool = False):
return pendulum.now("UTC") + pendulum.duration(**data) return pendulum.now("UTC") + pendulum.duration(**data)
def extract_urls(s: str | None):
    """Return all ``URL_RE`` matches found in ``s``.

    ``None`` and the empty string yield an empty list; otherwise the result
    is whatever ``re.findall`` produces for ``URL_RE``.
    """
    return re.findall(URL_RE, s) if s else []
def extract_first_url(s: str | None) -> Optional[str]:
    """Return the first URL found in ``s``, or ``None`` when there is none.

    Args:
        s: Text to scan; ``None`` or an empty string yields ``None``.

    Returns:
        The first element of the first ``extract_urls`` match when matches
        are tuples, or the match itself when they are plain strings.
    """
    matches = extract_urls(s)
    if not matches:
        return None

    first = matches[0]
    # re.findall yields tuples only when the pattern has two or more groups;
    # with zero or one group it yields strings, and indexing [0] would then
    # return just the first character.
    # NOTE(review): URL_RE is not visible here - confirm its group layout.
    if isinstance(first, str):
        return first or None
    return first[0]
async def get_reply(ctx: commands.Context[Artemis]) -> discord.Message | None:
    """Return the message that the invoking message replies to, if any.

    The cached copy on the reference is preferred; otherwise the message is
    fetched from ``ctx.channel``.

    Returns:
        The referenced message, or ``None`` when the invoking message has no
        reference, the reference carries no message id, or fetching raises a
        ``discord.DiscordException``.
    """
    reference = ctx.message.reference
    if not reference:
        return None

    if reference.cached_message:
        return reference.cached_message

    # Without an id there is nothing to fetch; skip the doomed API call.
    if reference.message_id is None:
        return None

    try:
        return await ctx.channel.fetch_message(reference.message_id)
    except discord.DiscordException:
        return None
async def get_file_from_attachment_or_url(
ctx: commands.Context[Artemis], message: discord.Message, url: Optional[str], types: list = None ctx: commands.Context[Artemis], message: discord.Message, url: Optional[str], types: list = None
) -> bytes: ) -> bytes:
if not message.attachments and not url: is_replied_to = ctx.message is not message
if not message.attachments and not url and (is_replied_to and not message.content):
raise ArtemisError("Please send me an attachment or URL first!") raise ArtemisError("Please send me an attachment or URL first!")
elif message.attachments: elif message.attachments:
attachment = message.attachments[0] attachment = message.attachments[0]
@ -462,7 +522,12 @@ async def get_attachment_or_url(
f"Unsupported file type, should be one of: `{', '.join(types)}`." f"Unsupported file type, should be one of: `{', '.join(types)}`."
) )
return await attachment.read() return await attachment.read()
elif url: elif url or (is_replied_to and message.content):
if is_replied_to:
url = extract_first_url(message.content)
if not url:
raise ArtemisError("No URL found in message.")
url = url.strip("<>") url = url.strip("<>")
if not is_valid_url(url): if not is_valid_url(url):
raise ArtemisError("URL is not valid.") raise ArtemisError("URL is not valid.")
@ -480,19 +545,8 @@ async def get_attachment_or_url(
elif r.content_type not in types: elif r.content_type not in types:
raise ArtemisError("Unsupported file type, should be an image.") raise ArtemisError("Unsupported file type, should be an image.")
return await r.read() return await r.read()
except Exception: except BaseException as err:
raise ArtemisError("An error occured when trying to connect to the given URL.") raise ArtemisError("An error occured when trying to connect to the given URL.") from err
async def get_message_or_reference(ctx: commands.Context[Artemis]) -> discord.Message:
    """Return the replied-to message, falling back to the invoking message.

    When the invoking message has a reference, its cached message is used if
    present, otherwise the message is fetched from ``ctx.channel``. Any
    exception raised while resolving the reference results in the invoking
    message being returned instead.
    """
    ref = ctx.message.reference
    if not ref:
        return ctx.message

    try:
        resolved = ref.cached_message
        if not resolved:
            resolved = await ctx.channel.fetch_message(ref.message_id)
        return resolved
    except Exception:
        # Best-effort lookup: fall back to the invoking message on any failure.
        return ctx.message
T = TypeVar("T") T = TypeVar("T")