main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import io
  4from pathlib import Path
  5
  6from glom import glom
  7from loguru import logger
  8from pyrogram.client import Client
  9from pyrogram.types import Message
 10
 11from asr.ali import ali_asr
 12from asr.cloudflare import cloudflare_asr
 13from asr.corrector import asr_corrector
 14from asr.deepgram import deepgram_asr
 15from asr.gemini import gemini_asr
 16from asr.groq import groq_asr
 17from asr.tecent import tencent_asr
 18from asr.utils import audio_duration, auto_choose_asr_engine
 19from config import ASR, CAPTION_LENGTH, PREFIX, TEXT_LENGTH
 20from messages.parser import parse_msg
 21from messages.progress import modify_progress
 22from messages.sender import send2tg
 23from messages.utils import blockquote, count_without_entities, delete_message, equal_prefix, get_reply_to, startswith_prefix
 24from publish import publish_telegraph
 25from utils import readable_time, to_int
 26
# Help text sent by voice_to_text when the ASR prefix is used on its own, without
# replying to any message. It uses Markdown-style ** / ` markup. The language codes
# listed below should stay in sync with the codes voice_to_text routes to Tencent ASR.
# Tencent ASR language reference:
# https://cloud.tencent.com/document/product/1093/52097
HELP = f"""🗣**语音转文字**
使用说明: 以 `{PREFIX.ASR}` 回复包含音频的消息 (如语音, 视频, 音乐)
默认可以识别普通话、粤语、英语三种语言。
识别其他语种可在`{PREFIX.ASR}`后加上语种代码, 如:
以`{PREFIX.ASR} ja`回复音频消息识别日语
以`{PREFIX.ASR} fr`回复音频消息识别法语

**目前支持以下语种:**
fy: 多种方言, 上海话、四川话、武汉话、贵阳话、昆明话、西安话、郑州话、太原话、兰州话、银川话、西宁话、南京话、合肥话、南昌话、长沙话、苏州话、杭州话、济南话、天津话、石家庄话、黑龙江话、吉林话、辽宁话
ja: 日语
ko: 韩语
vi: 越南语
ms: 马来语
id: 印度尼西亚语
fil: 菲律宾语
th: 泰语
pt: 葡萄牙语
tr: 土耳其语
ar: 阿拉伯语
es: 西班牙语
hi: 印地语
fr: 法语
de: 德语
"""
 52
 53LANG_MAP = {
 54    "16k_zh-PY": "中英粤",
 55    "16k_fy": "多种方言, 上海话、四川话、武汉话、贵阳话、昆明话、西安话、郑州话、太原话、兰州话、银川话、西宁话、南京话、合肥话、南昌话、长沙话、苏州话、杭州话、济南话、天津话、石家庄话、黑龙江话、吉林话、辽宁话",
 56    "16k_ja": "日语",
 57    "16k_ko": "韩语",
 58    "16k_vi": "越南语",
 59    "16k_ms": "马来语",
 60    "16k_id": "印度尼西亚语",
 61    "16k_fil": "菲律宾语",
 62    "16k_th": "泰语",
 63    "16k_pt": "葡萄牙语",
 64    "16k_tr": "土耳其语",
 65    "16k_ar": "阿拉伯语",
 66    "16k_es": "西班牙语",
 67    "16k_hi": "印地语",
 68    "16k_fr": "法语",
 69    "16k_de": "德语",
 70}
 71
 72
 73def get_msg_to_asr(message: Message, *, asr_need_prefix: bool = True) -> Message | None:
 74    """Get the message to be recognized by ASR.
 75
 76    By default, "/asr" prefix is needed to trigger ASR function.
 77    """
 78    # skip no "/asr" prefix message if asr_need_prefix
 79    if asr_need_prefix and not startswith_prefix(message.content, prefix=PREFIX.ASR):
 80        return None
 81    # no need prefix or has "/asr" prefix
 82
 83    mtype = glom(message, "media.value", default="text") or "text"
 84    # has "/asr" prefix
 85    if startswith_prefix(message.content, prefix=PREFIX.ASR):
 86        if mtype in ["voice", "audio", "video"]:
 87            return message
 88        if reply_msg := message.reply_to_message:
 89            reply_mtype = glom(reply_msg, "media.value", default="text") or "text"
 90            if reply_mtype in ["voice", "audio", "video"]:
 91                return reply_msg
 92    elif mtype == "voice":  # no need "/asr" prefix
 93        return message
 94    return None
 95
 96
 97async def voice_to_text(
 98    client: Client,
 99    message: Message,
100    asr_engine: str = ASR.DEFAULT_ENGINE,
101    *,
102    asr_need_prefix: bool = True,
103    **kwargs,
104) -> None:
105    """Voice, audio, video message to text.
106
107    By default, "/asr" prefix is needed in in Group & Channel & Bot chats to trigger this function.
108    In private chat, no need to add "/asr" prefix for voice message, but the video & audio message still need it.
109
110    Args:
111        client (Client): The Pyrogram client.
112        message (Message): The trigger message object.
113        asr_need_prefix (bool, optional): If True, must prepend "/asr" prefix to call ASR function.
114        to_telegraph (bool, optional): If True, publish the result to Telegraph.
115
116    """
117    # send docs if message == "/asr", without reply
118    if equal_prefix(message.text, prefix=PREFIX.ASR) and not message.reply_to_message:
119        await send2tg(client, message, texts=HELP, **kwargs)
120        return
121
122    msg_to_asr = get_msg_to_asr(message, asr_need_prefix=asr_need_prefix)
123    if not msg_to_asr:
124        return
125    this_info = parse_msg(message, silent=True)
126    asr_msg_info = parse_msg(msg_to_asr, silent=True)
127
128    remain_text = this_info["text"].removeprefix(PREFIX.ASR).strip().lower()
129    tencent_language = "16k_zh-PY"  # default: 中英粤
130    if remain_text in ["fy", "ja", "ko", "vi", "ms", "id", "fil", "th", "pt", "tr", "ar", "es", "hi", "fr", "de"]:
131        # tencent asr
132        asr_engine = "tencent"
133        tencent_language = f"16k_{remain_text}".replace("fy", "zh_dialect")
134
135    elif remain_text:
136        asr_engine = remain_text
137    msg = f"[ASR] 收到消息: {asr_msg_info['mtype']}, 开始下载..."
138    logger.info(msg)
139    if kwargs.get("show_progress"):
140        res = await send2tg(client, msg_to_asr, texts=msg, **kwargs)
141        kwargs["progress"] = res[0]
142
143    path: str | Path = await msg_to_asr.download()  # type: ignore
144    path = Path(path).expanduser().resolve()
145    if not path.is_file():
146        msg = f"❌下载 {asr_msg_info['mtype']} 文件失败, 无法识别"
147        logger.error(msg)
148        await modify_progress(text=msg, force_update=True, **kwargs)
149        return
150
151    res = await asr_file(path, engine=asr_engine, tencent_language=tencent_language, message=msg_to_asr, **kwargs)
152    if error := res.get("error"):
153        await modify_progress(kwargs.get("progress"), text=error, force_update=True)
154        return
155    if texts := res.get("texts"):
156        final = blockquote(texts) if len(texts) > 300 else texts
157        # send results
158        target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else this_info["cid"]
159        reply_parameters = get_reply_to(asr_msg_info["mid"], kwargs.get("reply_msg_id", 0))
160        length = await count_without_entities(final)
161        if length < CAPTION_LENGTH:  # short
162            await client.copy_message(chat_id=to_int(target_chat), from_chat_id=asr_msg_info["cid"], message_id=asr_msg_info["mid"], caption=final, reply_parameters=reply_parameters)
163        elif length < TEXT_LENGTH:  # middle
164            await client.send_message(to_int(target_chat), final, reply_parameters=reply_parameters)
165        else:  # long
166            caption = ""
167            html = "\n".join([f"<p>{s}</p>" for s in texts.split("\n")])
168            if telegraph_url := await publish_telegraph(title=asr_msg_info["text"], html=html, author=asr_msg_info["full_name"], url=asr_msg_info["message_url"]):
169                caption = f"[⚡️即时预览]({telegraph_url})"
170            with io.BytesIO(texts.encode("utf-8")) as f:
171                await client.send_document(to_int(target_chat), f, file_name="语音识别结果.txt", caption=caption, reply_parameters=reply_parameters)
172        await modify_progress(del_status=True, **kwargs)
173
174    [await delete_message(msg) for msg in res.get("sent_messages", [])]
175    if this_info["mtype"] == "text":
176        await delete_message(message)
177
178
179async def asr_file(
180    path: str | Path,
181    engine: str = "",
182    prompt: str = "",
183    *,
184    tencent_language: str = "16k_zh-PY",
185    enable_corrector: bool = False,
186    corrector_model: str = "asr-corrector",
187    corrector_reference: str | None = None,
188    delete_local_file: bool = True,
189    delete_gemini_file: bool = True,
190    **kwargs,
191) -> dict:
192    """Get ASR results of an audio file."""
193    path = Path(path).expanduser().resolve()
194    if not path.is_file():
195        return {"error": f"{path} is not exist"}
196    duration = audio_duration(path)
197    engine = auto_choose_asr_engine(duration=duration, engine=engine)
198    log = f"{engine.capitalize()} ASR, 时长: {readable_time(duration)} {path.name}"
199    logger.debug(log)
200    await modify_progress(message=kwargs.get("progress"), text=log, force_update=True)
201    res = {}
202    try:
203        if engine == "tencent":
204            res = await tencent_asr(path, tencent_language, duration)
205        elif engine == "ali":
206            res = await ali_asr(path)
207        elif engine == "deepgram":
208            res = await deepgram_asr(path)
209        elif engine == "gemini":
210            res = await gemini_asr(path=path, prompt=prompt, delete_gemini_file=delete_gemini_file)
211        elif engine == "cloudflare":
212            res = await cloudflare_asr(path, duration, model=kwargs.get("cf_asr_model"))
213        elif engine == "groq":
214            res = await groq_asr(path=path, model=kwargs.get("groq_asr_model", ""))
215        else:
216            return {"error": "ASR method not supported"}
217        if res.get("texts"):
218            logger.success(f"{res['texts']!r}")
219    except Exception as e:
220        error = f"Failed to recognize audio: {e}"
221        logger.error(error)
222        res["error"] = res.get("error", error)
223    finally:
224        if delete_local_file:
225            path.unlink(missing_ok=True)
226        elif path.is_file():
227            res["audio_file"] = path
228    if enable_corrector or corrector_reference:
229        res["texts"] = await asr_corrector(res["texts"], corrector_reference, corrector_model)
230    return res