mirror of
https://github.com/artiemis/artemis.git
synced 2026-02-14 08:31:55 +00:00
improvements to lens ocr - image compression, finding url in replies
This commit is contained in:
parent
6e350099d3
commit
348ef8af3d
@ -17,12 +17,11 @@ import pendulum
|
|||||||
import yt_dlp
|
import yt_dlp
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from PIL import Image
|
|
||||||
from pycaption import SRTWriter, WebVTTReader
|
from pycaption import SRTWriter, WebVTTReader
|
||||||
from yt_dlp.utils import parse_duration
|
from yt_dlp.utils import parse_duration
|
||||||
|
|
||||||
from .. import utils
|
from .. import utils
|
||||||
from ..utils.common import ArtemisError
|
from ..utils.common import ArtemisError, compress_image
|
||||||
from ..utils.constants import MAX_DISCORD_SIZE, MAX_LITTERBOX_SIZE, TEMP_DIR
|
from ..utils.constants import MAX_DISCORD_SIZE, MAX_LITTERBOX_SIZE, TEMP_DIR
|
||||||
from ..utils.catbox import CatboxError
|
from ..utils.catbox import CatboxError
|
||||||
from ..utils.flags import DLFlags
|
from ..utils.flags import DLFlags
|
||||||
@ -327,14 +326,6 @@ class Media(commands.Cog):
|
|||||||
utils.check_for_ssrf(url)
|
utils.check_for_ssrf(url)
|
||||||
ytdl_opts = {**DEFAULT_OPTS, "format": "bv*/b"}
|
ytdl_opts = {**DEFAULT_OPTS, "format": "bv*/b"}
|
||||||
|
|
||||||
@utils.in_executor
|
|
||||||
def to_jpeg(image):
|
|
||||||
im = Image.open(image)
|
|
||||||
buff = BytesIO()
|
|
||||||
im.save(buff, "JPEG", quality=90)
|
|
||||||
buff.seek(0)
|
|
||||||
return buff
|
|
||||||
|
|
||||||
if not (re.fullmatch(TIMESTAMP_RE, timestamp) or re.fullmatch(SECONDS_RE, timestamp)):
|
if not (re.fullmatch(TIMESTAMP_RE, timestamp) or re.fullmatch(SECONDS_RE, timestamp)):
|
||||||
return await ctx.reply("Invalid timestamp format, check out `$help screencap`.")
|
return await ctx.reply("Invalid timestamp format, check out `$help screencap`.")
|
||||||
|
|
||||||
@ -359,7 +350,7 @@ class Media(commands.Cog):
|
|||||||
buff = BytesIO(stdout)
|
buff = BytesIO(stdout)
|
||||||
|
|
||||||
if len(stdout) > MAX_DISCORD_SIZE:
|
if len(stdout) > MAX_DISCORD_SIZE:
|
||||||
buff = await to_jpeg(buff)
|
buff = await compress_image(buff, fmt="JPEG", quality=90)
|
||||||
msg += "\nThe image was too big for me to upload so I converted it to JPEG Q90."
|
msg += "\nThe image was too big for me to upload so I converted it to JPEG Q90."
|
||||||
dfile = discord.File(buff, f"{title}.png")
|
dfile = discord.File(buff, f"{title}.png")
|
||||||
return await ctx.reply(content=msg, file=dfile)
|
return await ctx.reply(content=msg, file=dfile)
|
||||||
|
|||||||
@ -11,7 +11,7 @@ import magic
|
|||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from .. import utils
|
from .. import utils
|
||||||
from ..utils.common import ArtemisError
|
from ..utils.common import ArtemisError, compress_image, get_reply
|
||||||
from ..utils.constants import TESSERACT_LANGUAGES
|
from ..utils.constants import TESSERACT_LANGUAGES
|
||||||
from ..utils.flags import Flags, OCRFlags, OCRTranslateFlags
|
from ..utils.flags import Flags, OCRFlags, OCRTranslateFlags
|
||||||
from ..utils.iso_639 import get_language_name
|
from ..utils.iso_639 import get_language_name
|
||||||
@ -48,8 +48,17 @@ class OCR(commands.Cog):
|
|||||||
embed = discord.Embed(description=msg, color=discord.Color.red())
|
embed = discord.Embed(description=msg, color=discord.Color.red())
|
||||||
return await ctx.reply(embed=embed)
|
return await ctx.reply(embed=embed)
|
||||||
|
|
||||||
message = await utils.get_message_or_reference(ctx)
|
if url or ctx.message.attachments:
|
||||||
image = await utils.get_attachment_or_url(ctx, message, url, ["image/jpeg", "image/png"])
|
message = ctx.message
|
||||||
|
else:
|
||||||
|
message = await get_reply(ctx)
|
||||||
|
|
||||||
|
if not message:
|
||||||
|
raise ArtemisError("Could not find any images.")
|
||||||
|
|
||||||
|
image = await utils.get_file_from_attachment_or_url(
|
||||||
|
ctx, message, url, ["image/jpeg", "image/png"]
|
||||||
|
)
|
||||||
|
|
||||||
args = f"tesseract stdin stdout -l {lang}"
|
args = f"tesseract stdin stdout -l {lang}"
|
||||||
result = await utils.run_cmd(args, input=image)
|
result = await utils.run_cmd(args, input=image)
|
||||||
@ -89,8 +98,23 @@ class OCR(commands.Cog):
|
|||||||
|
|
||||||
await ctx.typing()
|
await ctx.typing()
|
||||||
|
|
||||||
message = await utils.get_message_or_reference(ctx)
|
if url or ctx.message.attachments:
|
||||||
image = await utils.get_attachment_or_url(ctx, message, url, ["image/jpeg", "image/png"])
|
message = ctx.message
|
||||||
|
else:
|
||||||
|
message = await get_reply(ctx)
|
||||||
|
|
||||||
|
if not message:
|
||||||
|
raise ArtemisError("Could not find any images.")
|
||||||
|
|
||||||
|
image = await utils.get_file_from_attachment_or_url(
|
||||||
|
ctx, message, url, ["image/jpeg", "image/png"]
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
image = await compress_image(image, size=1000)
|
||||||
|
image = image.getvalue()
|
||||||
|
except Exception as e:
|
||||||
|
raise ArtemisError(f"Could not compress image: {e}") from e
|
||||||
|
|
||||||
content_type = magic.from_buffer(image, mime=True)
|
content_type = magic.from_buffer(image, mime=True)
|
||||||
ext = mimetypes.guess_extension(content_type)
|
ext = mimetypes.guess_extension(content_type)
|
||||||
@ -111,7 +135,7 @@ class OCR(commands.Cog):
|
|||||||
match = re.search(final_data_re, html)
|
match = re.search(final_data_re, html)
|
||||||
if not match:
|
if not match:
|
||||||
if ctx.author.id == self.bot.owner.id:
|
if ctx.author.id == self.bot.owner.id:
|
||||||
await ctx.send(file=utils.File(html, "lens.html"))
|
await ctx.send(file=utils.file(html, "lens.html"))
|
||||||
raise ArtemisError("No text detected.")
|
raise ArtemisError("No text detected.")
|
||||||
_lang, lines = match.groups()
|
_lang, lines = match.groups()
|
||||||
|
|
||||||
|
|||||||
@ -11,18 +11,30 @@ from ipaddress import ip_address
|
|||||||
from subprocess import PIPE
|
from subprocess import PIPE
|
||||||
from time import perf_counter, time_ns
|
from time import perf_counter, time_ns
|
||||||
from time import time as _time
|
from time import time as _time
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, TypeVar
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Coroutine,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
from urllib.parse import quote_plus, urlparse
|
from urllib.parse import quote_plus, urlparse
|
||||||
|
import tomllib
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
import humanize
|
import humanize
|
||||||
import pendulum
|
import pendulum
|
||||||
import pykakasi
|
import pykakasi
|
||||||
import tomllib
|
|
||||||
from aiohttp.helpers import is_ip_address
|
from aiohttp.helpers import is_ip_address
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
import feedparser
|
import feedparser
|
||||||
from rapidfuzz import process
|
from rapidfuzz import process
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from .. import utils
|
from .. import utils
|
||||||
|
|
||||||
@ -55,7 +67,7 @@ class InvalidColour(commands.BadArgument):
|
|||||||
|
|
||||||
|
|
||||||
class BetterColour(commands.Converter):
|
class BetterColour(commands.Converter):
|
||||||
async def convert(self, ctx, argument: str):
|
async def convert(self, _ctx, argument: str):
|
||||||
colour_name = fuzzy_search_one(argument, list(utils.COMMON_COLOURS), cutoff=60)
|
colour_name = fuzzy_search_one(argument, list(utils.COMMON_COLOURS), cutoff=60)
|
||||||
if not colour_name:
|
if not colour_name:
|
||||||
raise InvalidColour("Invalid colour code/name.")
|
raise InvalidColour("Invalid colour code/name.")
|
||||||
@ -67,7 +79,7 @@ class BetterColour(commands.Converter):
|
|||||||
class URL(commands.Converter):
|
class URL(commands.Converter):
|
||||||
"""URL converter."""
|
"""URL converter."""
|
||||||
|
|
||||||
async def convert(self, ctx, argument: str):
|
async def convert(self, _ctx, argument: str):
|
||||||
argument = argument.strip("<>")
|
argument = argument.strip("<>")
|
||||||
match = is_valid_url(argument)
|
match = is_valid_url(argument)
|
||||||
if match:
|
if match:
|
||||||
@ -172,7 +184,7 @@ class ProgressBarMessage:
|
|||||||
if self._delete_on_finished:
|
if self._delete_on_finished:
|
||||||
return await self._msg.delete()
|
return await self._msg.delete()
|
||||||
self._msg = await self._msg.edit(**self._finished_kwargs)
|
self._msg = await self._msg.edit(**self._finished_kwargs)
|
||||||
except Exception as err:
|
except BaseException as err:
|
||||||
print(str(err))
|
print(str(err))
|
||||||
if self._msg:
|
if self._msg:
|
||||||
await self._msg.edit(content="Error while rendering progress bar.")
|
await self._msg.edit(content="Error while rendering progress bar.")
|
||||||
@ -222,7 +234,7 @@ class ProgressBarMessage:
|
|||||||
self.finish()
|
self.finish()
|
||||||
|
|
||||||
|
|
||||||
def File(fp: bytes | str | list | dict, filename: str):
|
def file(fp: bytes | str | list | dict, filename: str):
|
||||||
"""Dirty discord.File helper for fast debugging."""
|
"""Dirty discord.File helper for fast debugging."""
|
||||||
if isinstance(fp, bytes):
|
if isinstance(fp, bytes):
|
||||||
buf = BytesIO(fp)
|
buf = BytesIO(fp)
|
||||||
@ -257,11 +269,14 @@ def read_toml(path: str) -> Any:
|
|||||||
return tomllib.load(f)
|
return tomllib.load(f)
|
||||||
|
|
||||||
|
|
||||||
def in_executor(func: Callable):
|
F = TypeVar("F", bound=Callable)
|
||||||
|
|
||||||
|
|
||||||
|
def in_executor(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]:
|
||||||
"""A decorator for running a function in a thread."""
|
"""A decorator for running a function in a thread."""
|
||||||
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def decorator(*args: Any, **kwargs: Any):
|
def decorator(*args: Any, **kwargs: Any) -> Coroutine[Any, Any, T]:
|
||||||
return asyncio.to_thread(func, *args, **kwargs)
|
return asyncio.to_thread(func, *args, **kwargs)
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
@ -284,11 +299,11 @@ def time(resolution: Literal["s", "ms", "ns"] = "s") -> int:
|
|||||||
return time_ns()
|
return time_ns()
|
||||||
|
|
||||||
|
|
||||||
def trim(text: Optional[str], max: int) -> Optional[str]:
|
def trim(text: Optional[str], max_len: int) -> Optional[str]:
|
||||||
"""Trims text to specified max length."""
|
"""Trims text to specified max length."""
|
||||||
if text is None:
|
if text is None:
|
||||||
return None
|
return None
|
||||||
return f"{text[:max - 3]}..." if len(text) > max else text
|
return f"{text[:max_len - 3]}..." if len(text) > max_len else text
|
||||||
|
|
||||||
|
|
||||||
def romajify(text: str, strict: bool = True) -> str:
|
def romajify(text: str, strict: bool = True) -> str:
|
||||||
@ -391,20 +406,20 @@ async def run_cmd(args: str, shell=False, input=None) -> CommandResult:
|
|||||||
stdin = PIPE if input else None
|
stdin = PIPE if input else None
|
||||||
try:
|
try:
|
||||||
if shell:
|
if shell:
|
||||||
process = await asyncio.create_subprocess_shell(
|
subprocess = await asyncio.create_subprocess_shell(
|
||||||
args, stdout=PIPE, stderr=PIPE, stdin=stdin
|
args, stdout=PIPE, stderr=PIPE, stdin=stdin
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
split_args = shlex.split(args)
|
split_args = shlex.split(args)
|
||||||
process = await asyncio.create_subprocess_exec(
|
subprocess = await asyncio.create_subprocess_exec(
|
||||||
*split_args, stdout=PIPE, stderr=PIPE, stdin=stdin
|
*split_args, stdout=PIPE, stderr=PIPE, stdin=stdin
|
||||||
)
|
)
|
||||||
|
|
||||||
stdout, stderr = await process.communicate(input=input)
|
stdout, stderr = await subprocess.communicate(input=input)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise CommandExecutionError(err)
|
raise CommandExecutionError(err) from err
|
||||||
|
|
||||||
return CommandResult(stdout, stderr, process.returncode)
|
return CommandResult(stdout, stderr, subprocess.returncode)
|
||||||
|
|
||||||
|
|
||||||
async def run_cmd_to_file(args: str, filename: str, shell=False) -> discord.File | str:
|
async def run_cmd_to_file(args: str, filename: str, shell=False) -> discord.File | str:
|
||||||
@ -424,6 +439,26 @@ async def run_cmd_to_file(args: str, filename: str, shell=False) -> discord.File
|
|||||||
raise CommandExecutionError("The file is too big to upload.")
|
raise CommandExecutionError("The file is too big to upload.")
|
||||||
|
|
||||||
|
|
||||||
|
@in_executor
|
||||||
|
def compress_image(
|
||||||
|
image: bytes | BytesIO, fmt: str = "JPEG", quality: int = 90, size=None
|
||||||
|
) -> BytesIO:
|
||||||
|
if isinstance(image, bytes):
|
||||||
|
image = BytesIO(image)
|
||||||
|
im = Image.open(image)
|
||||||
|
|
||||||
|
if im.mode in ("RGBA", "P"):
|
||||||
|
im = im.convert("RGB")
|
||||||
|
|
||||||
|
if size:
|
||||||
|
im.thumbnail((size, size), Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
|
buff = BytesIO()
|
||||||
|
im.save(buff, fmt, quality=quality)
|
||||||
|
buff.seek(0)
|
||||||
|
return buff
|
||||||
|
|
||||||
|
|
||||||
def parse_short_time(time_string: str, as_duration: bool = False):
|
def parse_short_time(time_string: str, as_duration: bool = False):
|
||||||
compiled = re.compile(
|
compiled = re.compile(
|
||||||
"""(?:(?P<years>[0-9])(?:years?|y))? # e.g. 2y
|
"""(?:(?P<years>[0-9])(?:years?|y))? # e.g. 2y
|
||||||
@ -447,10 +482,35 @@ def parse_short_time(time_string: str, as_duration: bool = False):
|
|||||||
return pendulum.now("UTC") + pendulum.duration(**data)
|
return pendulum.now("UTC") + pendulum.duration(**data)
|
||||||
|
|
||||||
|
|
||||||
async def get_attachment_or_url(
|
def extract_urls(s: str | None):
|
||||||
|
if not s:
|
||||||
|
return []
|
||||||
|
return re.findall(URL_RE, s)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_first_url(s: str | None) -> Optional[str]:
|
||||||
|
url = next(iter(extract_urls(s)), None)
|
||||||
|
if not url:
|
||||||
|
return None
|
||||||
|
return url[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_reply(ctx: commands.Context[Artemis]) -> discord.Message | None:
|
||||||
|
reference = ctx.message.reference
|
||||||
|
if not reference:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return reference.cached_message or await ctx.channel.fetch_message(reference.message_id)
|
||||||
|
except discord.DiscordException:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_file_from_attachment_or_url(
|
||||||
ctx: commands.Context[Artemis], message: discord.Message, url: Optional[str], types: list = None
|
ctx: commands.Context[Artemis], message: discord.Message, url: Optional[str], types: list = None
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
if not message.attachments and not url:
|
is_replied_to = ctx.message is not message
|
||||||
|
|
||||||
|
if not message.attachments and not url and (is_replied_to and not message.content):
|
||||||
raise ArtemisError("Please send me an attachment or URL first!")
|
raise ArtemisError("Please send me an attachment or URL first!")
|
||||||
elif message.attachments:
|
elif message.attachments:
|
||||||
attachment = message.attachments[0]
|
attachment = message.attachments[0]
|
||||||
@ -462,7 +522,12 @@ async def get_attachment_or_url(
|
|||||||
f"Unsupported file type, should be one of: `{', '.join(types)}`."
|
f"Unsupported file type, should be one of: `{', '.join(types)}`."
|
||||||
)
|
)
|
||||||
return await attachment.read()
|
return await attachment.read()
|
||||||
elif url:
|
elif url or (is_replied_to and message.content):
|
||||||
|
if is_replied_to:
|
||||||
|
url = extract_first_url(message.content)
|
||||||
|
if not url:
|
||||||
|
raise ArtemisError("No URL found in message.")
|
||||||
|
|
||||||
url = url.strip("<>")
|
url = url.strip("<>")
|
||||||
if not is_valid_url(url):
|
if not is_valid_url(url):
|
||||||
raise ArtemisError("URL is not valid.")
|
raise ArtemisError("URL is not valid.")
|
||||||
@ -480,19 +545,8 @@ async def get_attachment_or_url(
|
|||||||
elif r.content_type not in types:
|
elif r.content_type not in types:
|
||||||
raise ArtemisError("Unsupported file type, should be an image.")
|
raise ArtemisError("Unsupported file type, should be an image.")
|
||||||
return await r.read()
|
return await r.read()
|
||||||
except Exception:
|
except BaseException as err:
|
||||||
raise ArtemisError("An error occured when trying to connect to the given URL.")
|
raise ArtemisError("An error occured when trying to connect to the given URL.") from err
|
||||||
|
|
||||||
|
|
||||||
async def get_message_or_reference(ctx: commands.Context[Artemis]) -> discord.Message:
|
|
||||||
reference = ctx.message.reference
|
|
||||||
if reference:
|
|
||||||
try:
|
|
||||||
return reference.cached_message or await ctx.channel.fetch_message(reference.message_id)
|
|
||||||
except Exception:
|
|
||||||
return ctx.message
|
|
||||||
else:
|
|
||||||
return ctx.message
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user