Commit 269ebc2

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-01-31 03:26:20
feat(asr): add original voice to message
1 parent fe984a8
Changed files (4)
src/asr/voice_recognition.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import contextlib
 import re
 from pathlib import Path
 
@@ -8,12 +9,13 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 
 from asr.tecent_asr import Credential, FlashRecognitionRequest, FlashRecognizer
-from config import ASR_MAX_DURATION, ENABLE, PREFIX, TOKEN, cache
+from config import ASR_MAX_DURATION, CAPTION_LENGTH, ENABLE, PREFIX, TOKEN, cache
 from messages.parser import parse_msg
 from messages.progress import modify_progress
-from messages.sender import send2tg
-from messages.utils import equal_prefix, startswith_prefix
-from multimedia import convert_to_audio, parse_media_info
+from messages.sender import send2tg, send_texts
+from messages.utils import equal_prefix, get_reply_to, smart_split, startswith_prefix
+from multimedia import convert_to_audio
+from utils import to_int
 
 # ruff: noqa: RUF001
 
@@ -93,10 +95,11 @@ async def voice_to_text(
         return
     if not (trigger_message := get_trigger_message(message, asr_need_prefix, asr_skip_voice, asr_skip_audio, asr_skip_video)):
         return
-    trigger_info = parse_msg(trigger_message)
+    this_info = parse_msg(message, silent=True)
+    trigger_info = parse_msg(trigger_message, silent=True)
 
     asr_engine = "16k_zh-PY"  # default: 中英粤
-    if matched := re.match(r"/asr\s+([^.。,,/\s]+)", str(message.text)):  # /asr yue
+    if matched := re.match(r"/asr\s+([^.。,,/\s]+)", this_info["text"]):  # /asr yue
         asr_engine = f"16k_{matched.group(1)}"
     asr_engine = asr_engine.replace("16k_fy", "16k_zh_dialect")  # fix dialect engine code
 
@@ -111,7 +114,7 @@ async def voice_to_text(
     voice_format = ""
     path: str | Path = await trigger_message.download()  # type: ignore
     if trigger_info["mtype"] == "voice":  # audio/ogg
-        voice_format = str(trigger_message.voice.mime_type).split("/")[-1]  # set voice format
+        voice_format = str(trigger_info["mime_type"]).split("/")[-1]  # set voice format
     elif trigger_info["mtype"] in ["audio", "video"]:
         path = convert_to_audio(path, ext="m4a")
         voice_format = "m4a"
@@ -137,10 +140,9 @@ async def voice_to_text(
         await modify_progress(text=msg, force_update=True)
         return
 
-    # 音频长度
-    duration = parse_media_info(path).get("duration", 0)
-    if duration > ASR_MAX_DURATION:
-        msg = f"无法识别时长超过{ASR_MAX_DURATION}秒的音频, 当前音频时长: {duration}秒"
+    # Skip audio longer than the ASR duration limit
+    if trigger_info["duration"] > ASR_MAX_DURATION:
+        msg = f"无法识别时长超过{ASR_MAX_DURATION}秒的音频, 当前音频时长: {trigger_info['duration']}秒"
         logger.error(msg)
         await modify_progress(text=msg, force_update=True, **kwargs)
         return
@@ -150,7 +152,6 @@ async def voice_to_text(
     recognizer = FlashRecognizer(TOKEN.TENCENT_ASR_APPID, credential_var)
     req = FlashRecognitionRequest(engine_type=asr_engine)
     req.set_voice_format(voice_format)
-
     final = ""
     try:
         with path.open("rb") as f:
@@ -163,14 +164,24 @@ async def voice_to_text(
             for cid, text in enumerate(texts):
                 final += f"通道{cid + 1}: {text}\n"
         if final:
-            final = f"🗣语音转文字:\n{final}"
-        logger.success(f"Recognized text: {final}")
-        await send2tg(client, trigger_message, texts=final.replace("。", "。\n"), **kwargs)
+            final = f"🗣语音转文字:\n{final}".replace("。", "。\n")
+        logger.success(f"{final!r}")
+
+        # send results
+        caption = smart_split(final, CAPTION_LENGTH)[0]
+        remaining_texts = final.removeprefix(caption)
+        reply_parameters = get_reply_to(this_info["mid"], trigger_info["mid"])
+        target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else this_info["cid"]
+        await client.copy_message(chat_id=to_int(target_chat), from_chat_id=trigger_info["cid"], message_id=trigger_info["mid"], caption=caption, reply_parameters=reply_parameters)
+        await send_texts(client, target_chat, reply_parameters, texts=remaining_texts)
         await modify_progress(del_status=True, **kwargs)
     except Exception as e:
         logger.error(f"Failed to recognize audio: {e}")
     finally:
         path.unlink(missing_ok=True)
+    with contextlib.suppress(Exception):
+        if this_info["mtype"] == "text":
+            await message.delete()
 
 
 @cache.memoize(ttl=10)
src/messages/sender.py
@@ -11,8 +11,8 @@ from pyrogram.types import Message, ReplyParameters
 from config import CAPTION_LENGTH
 from messages.preprocess import preprocess_media, warp_media_group
 from messages.progress import modify_progress, telegram_uploading
-from messages.utils import get_reply_to, summay_media
-from utils import smart_split, to_int
+from messages.utils import get_reply_to, smart_split, summay_media
+from utils import to_int
 
 
 async def send2tg(
src/messages/utils.py
@@ -5,7 +5,7 @@ import re
 
 from pyrogram.types import ReplyParameters
 
-from config import cache
+from config import TEXT_LENGTH, cache
 from utils import readable_size, to_int
 
 
@@ -77,3 +77,28 @@ def sender_markdown_to_html(sender: str) -> str:
     if not sender:
         return ""
     return re.sub(r"^👤\[@(.*?)\]\(tg://user\?id=(\d+)\)", r'👤<a href="tg://user?id=\2">@\1</a>', sender)
+
+
def smart_split(text: str, chars_per_string: int = TEXT_LENGTH) -> list[str]:
    """Split *text* into chunks of at most ``chars_per_string`` characters.

    Chunks are cut at the last natural break point that fits inside the
    limit — newline first, then sentence end (``". "``), then space —
    falling back to a hard cut when the chunk contains no separator.

    Args:
        text: The string to split. An empty string yields ``[""]``.
        chars_per_string: Maximum length of each returned chunk.
            Defaults to the project-wide ``TEXT_LENGTH`` limit.

    Returns:
        A list of substrings whose concatenation equals ``text``. Unlike
        the naive ``<`` comparison, text whose length is an exact multiple
        of the limit does not produce a spurious trailing empty chunk.

    Raises:
        ValueError: If ``chars_per_string`` is not positive (the loop
            would otherwise never make progress).
    """
    if chars_per_string <= 0:
        raise ValueError("chars_per_string must be a positive integer")

    def _cut_before_last(chunk: str, sep: str) -> str:
        # Keep everything up to and including the last occurrence of `sep`.
        return chunk[: chunk.rindex(sep) + len(sep)]

    parts: list[str] = []
    while len(text) > chars_per_string:
        chunk = text[:chars_per_string]
        # Prefer the strongest break point available inside the window.
        for sep in ("\n", ". ", " "):
            if sep in chunk:
                chunk = _cut_before_last(chunk, sep)
                break
        parts.append(chunk)
        text = text[len(chunk):]
    # Remainder fits within the limit; empty only when `text` started empty.
    parts.append(text)
    return parts
src/utils.py
@@ -15,7 +15,7 @@ from loguru import logger
 from pyrogram.client import Client
 from yt_dlp.extractor import gen_extractors
 
-from config import DOWNLOAD_DIR, TEXT_LENGTH, TZ, cache
+from config import DOWNLOAD_DIR, TZ, cache
 
 # ruff: noqa: RUF001
 
@@ -24,31 +24,6 @@ def nowdt(tz: str = "UTC") -> datetime:
     return datetime.now(ZoneInfo(tz))
 
 
-def smart_split(text: str, chars_per_string: int = TEXT_LENGTH) -> list[str]:
-    """Splits one string into multiple strings, with a maximum amount of `chars_per_string` characters per string."""
-
-    def _text_before_last(substr: str) -> str:
-        return substr.join(part.split(substr)[:-1]) + substr
-
-    parts = []
-    while True:
-        if len(text) < chars_per_string:
-            parts.append(text)
-            return parts
-
-        part = text[:chars_per_string]
-
-        if "\n" in part:
-            part = _text_before_last("\n")
-        elif ". " in part:
-            part = _text_before_last(". ")
-        elif " " in part:
-            part = _text_before_last(" ")
-
-        parts.append(part)
-        text = text[len(part) :]
-
-
 def split_parts(first: int = 0, middle: int = 0, last: int = 0) -> dict:
     """Split a list of items into three parts: first, middle, and last.