Commit 269ebc2
Changed files (4)
src
src/asr/voice_recognition.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import contextlib
import re
from pathlib import Path
@@ -8,12 +9,13 @@ from pyrogram.client import Client
from pyrogram.types import Message
from asr.tecent_asr import Credential, FlashRecognitionRequest, FlashRecognizer
-from config import ASR_MAX_DURATION, ENABLE, PREFIX, TOKEN, cache
+from config import ASR_MAX_DURATION, CAPTION_LENGTH, ENABLE, PREFIX, TOKEN, cache
from messages.parser import parse_msg
from messages.progress import modify_progress
-from messages.sender import send2tg
-from messages.utils import equal_prefix, startswith_prefix
-from multimedia import convert_to_audio, parse_media_info
+from messages.sender import send2tg, send_texts
+from messages.utils import equal_prefix, get_reply_to, smart_split, startswith_prefix
+from multimedia import convert_to_audio
+from utils import to_int
# ruff: noqa: RUF001
@@ -93,10 +95,11 @@ async def voice_to_text(
return
if not (trigger_message := get_trigger_message(message, asr_need_prefix, asr_skip_voice, asr_skip_audio, asr_skip_video)):
return
- trigger_info = parse_msg(trigger_message)
+ this_info = parse_msg(message, silent=True)
+ trigger_info = parse_msg(trigger_message, silent=True)
asr_engine = "16k_zh-PY" # default: 中英粤
- if matched := re.match(r"/asr\s+([^.。,,/\s]+)", str(message.text)): # /asr yue
+ if matched := re.match(r"/asr\s+([^.。,,/\s]+)", this_info["text"]): # /asr yue
asr_engine = f"16k_{matched.group(1)}"
asr_engine = asr_engine.replace("16k_fy", "16k_zh_dialect") # fix dialect engine code
@@ -111,7 +114,7 @@ async def voice_to_text(
voice_format = ""
path: str | Path = await trigger_message.download() # type: ignore
if trigger_info["mtype"] == "voice": # audio/ogg
- voice_format = str(trigger_message.voice.mime_type).split("/")[-1] # set voice format
+ voice_format = str(trigger_info["mime_type"]).split("/")[-1] # set voice format
elif trigger_info["mtype"] in ["audio", "video"]:
path = convert_to_audio(path, ext="m4a")
voice_format = "m4a"
@@ -137,10 +140,9 @@ async def voice_to_text(
await modify_progress(text=msg, force_update=True)
return
- # 音频长度
- duration = parse_media_info(path).get("duration", 0)
- if duration > ASR_MAX_DURATION:
- msg = f"无法识别时长超过{ASR_MAX_DURATION}秒的音频, 当前音频时长: {duration}秒"
+ # Skip very long
+ if trigger_info["duration"] > ASR_MAX_DURATION:
+ msg = f"无法识别时长超过{ASR_MAX_DURATION}秒的音频, 当前音频时长: {trigger_info['duration']}秒"
logger.error(msg)
await modify_progress(text=msg, force_update=True, **kwargs)
return
@@ -150,7 +152,6 @@ async def voice_to_text(
recognizer = FlashRecognizer(TOKEN.TENCENT_ASR_APPID, credential_var)
req = FlashRecognitionRequest(engine_type=asr_engine)
req.set_voice_format(voice_format)
-
final = ""
try:
with path.open("rb") as f:
@@ -163,14 +164,24 @@ async def voice_to_text(
for cid, text in enumerate(texts):
final += f"通道{cid + 1}: {text}\n"
if final:
- final = f"🗣语音转文字:\n{final}"
- logger.success(f"Recognized text: {final}")
- await send2tg(client, trigger_message, texts=final.replace("。", "。\n"), **kwargs)
+ final = f"🗣语音转文字:\n{final}".replace("。", "。\n")
+ logger.success(f"{final!r}")
+
+ # send results
+ caption = smart_split(final, CAPTION_LENGTH)[0]
+ remaining_texts = final.removeprefix(caption)
+ reply_parameters = get_reply_to(this_info["mid"], trigger_info["mid"])
+ target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else this_info["cid"]
+ await client.copy_message(chat_id=to_int(target_chat), from_chat_id=trigger_info["cid"], message_id=trigger_info["mid"], caption=caption, reply_parameters=reply_parameters)
+ await send_texts(client, target_chat, reply_parameters, texts=remaining_texts)
await modify_progress(del_status=True, **kwargs)
except Exception as e:
logger.error(f"Failed to recognize audio: {e}")
finally:
path.unlink(missing_ok=True)
+ with contextlib.suppress(Exception):
+ if this_info["mtype"] == "text":
+ await message.delete()
@cache.memoize(ttl=10)
src/messages/sender.py
@@ -11,8 +11,8 @@ from pyrogram.types import Message, ReplyParameters
from config import CAPTION_LENGTH
from messages.preprocess import preprocess_media, warp_media_group
from messages.progress import modify_progress, telegram_uploading
-from messages.utils import get_reply_to, summay_media
-from utils import smart_split, to_int
+from messages.utils import get_reply_to, smart_split, summay_media
+from utils import to_int
async def send2tg(
src/messages/utils.py
@@ -5,7 +5,7 @@ import re
from pyrogram.types import ReplyParameters
-from config import cache
+from config import TEXT_LENGTH, cache
from utils import readable_size, to_int
@@ -77,3 +77,28 @@ def sender_markdown_to_html(sender: str) -> str:
if not sender:
return ""
return re.sub(r"^👤\[@(.*?)\]\(tg://user\?id=(\d+)\)", r'👤<a href="tg://user?id=\2">@\1</a>', sender)
+
+
def smart_split(text: str, chars_per_string: int = TEXT_LENGTH) -> list[str]:
    """Split *text* into chunks of at most *chars_per_string* characters each.

    Prefers breaking at the last newline in the window, then at the last
    sentence boundary (". "), then at the last space, and falls back to a
    hard cut when the window contains no separator at all.

    Args:
        text: The string to split. An empty string yields ``[""]``.
        chars_per_string: Maximum length of each returned chunk; must be > 0.

    Returns:
        A list of non-overlapping substrings whose concatenation equals *text*.

    Raises:
        ValueError: If ``chars_per_string`` is not positive (would otherwise
            loop forever, since ``text[:0]`` consumes nothing).
    """

    def _text_before_last(part: str, sep: str) -> str:
        # Everything up to and including the LAST occurrence of `sep` in `part`.
        return sep.join(part.split(sep)[:-1]) + sep

    if chars_per_string <= 0:
        raise ValueError("chars_per_string must be positive")

    parts: list[str] = []
    while True:
        # `<=` (not `<`): a remainder that fits exactly needs no further split;
        # with `<` it was either split needlessly at a separator or followed by
        # a spurious trailing "" chunk when no separator was present.
        if len(text) <= chars_per_string:
            parts.append(text)
            return parts

        part = text[:chars_per_string]

        if "\n" in part:
            part = _text_before_last(part, "\n")
        elif ". " in part:
            part = _text_before_last(part, ". ")
        elif " " in part:
            part = _text_before_last(part, " ")

        parts.append(part)
        text = text[len(part):]
src/utils.py
@@ -15,7 +15,7 @@ from loguru import logger
from pyrogram.client import Client
from yt_dlp.extractor import gen_extractors
-from config import DOWNLOAD_DIR, TEXT_LENGTH, TZ, cache
+from config import DOWNLOAD_DIR, TZ, cache
# ruff: noqa: RUF001
@@ -24,31 +24,6 @@ def nowdt(tz: str = "UTC") -> datetime:
return datetime.now(ZoneInfo(tz))
-def smart_split(text: str, chars_per_string: int = TEXT_LENGTH) -> list[str]:
- """Splits one string into multiple strings, with a maximum amount of `chars_per_string` characters per string."""
-
- def _text_before_last(substr: str) -> str:
- return substr.join(part.split(substr)[:-1]) + substr
-
- parts = []
- while True:
- if len(text) < chars_per_string:
- parts.append(text)
- return parts
-
- part = text[:chars_per_string]
-
- if "\n" in part:
- part = _text_before_last("\n")
- elif ". " in part:
- part = _text_before_last(". ")
- elif " " in part:
- part = _text_before_last(" ")
-
- parts.append(part)
- text = text[len(part) :]
-
-
def split_parts(first: int = 0, middle: int = 0, last: int = 0) -> dict:
"""Split a list of items into three parts: first, middle, and last.