main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import io
  4from pathlib import Path
  5
  6from glom import glom
  7from loguru import logger
  8from pyrogram.client import Client
  9from pyrogram.enums import ParseMode
 10from pyrogram.types import Message
 11
 12from asr.ali import ali_asr
 13from asr.cloudflare import cloudflare_asr
 14from asr.corrector import asr_corrector
 15from asr.deepgram import deepgram_asr
 16from asr.gemini import gemini_asr
 17from asr.groq import groq_asr
 18from asr.tecent import tencent_asr
 19from asr.utils import audio_duration, auto_choose_asr_engine
 20from config import ASR, CAPTION_LENGTH, PREFIX, TEXT_LENGTH
 21from messages.parser import parse_msg
 22from messages.progress import modify_progress
 23from messages.sender import send2tg
 24from messages.utils import blockquote, count_without_entities, delete_message, equal_prefix, get_reply_to, startswith_prefix
 25from publish import publish_telegraph
 26from utils import readable_time, to_int
 27
 28# https://cloud.tencent.com/document/product/1093/52097
 29HELP = f"""🗣**语音转文字**
 30使用说明: 以 `{PREFIX.ASR}` 回复包含音频的消息 (如语音, 视频, 音乐)
 31默认可以识别普通话、粤语、英语三种语言。
 32识别其他语种可在`{PREFIX.ASR}`后加上语种代码, 如:
 33以`{PREFIX.ASR} ja`回复音频消息识别日语
 34以`{PREFIX.ASR} fr`回复音频消息识别法语
 35
 36**目前支持以下语种:**
 37fy: 多种方言, 上海话、四川话、武汉话、贵阳话、昆明话、西安话、郑州话、太原话、兰州话、银川话、西宁话、南京话、合肥话、南昌话、长沙话、苏州话、杭州话、济南话、天津话、石家庄话、黑龙江话、吉林话、辽宁话
 38ja: 日语
 39ko: 韩语
 40vi: 越南语
 41ms: 马来语
 42id: 印度尼西亚语
 43fil: 菲律宾语
 44th: 泰语
 45pt: 葡萄牙语
 46tr: 土耳其语
 47ar: 阿拉伯语
 48es: 西班牙语
 49hi: 印地语
 50fr: 法语
 51de: 德语
 52"""
 53
 54LANG_MAP = {
 55    "16k_zh-PY": "中英粤",
 56    "16k_fy": "多种方言, 上海话、四川话、武汉话、贵阳话、昆明话、西安话、郑州话、太原话、兰州话、银川话、西宁话、南京话、合肥话、南昌话、长沙话、苏州话、杭州话、济南话、天津话、石家庄话、黑龙江话、吉林话、辽宁话",
 57    "16k_ja": "日语",
 58    "16k_ko": "韩语",
 59    "16k_vi": "越南语",
 60    "16k_ms": "马来语",
 61    "16k_id": "印度尼西亚语",
 62    "16k_fil": "菲律宾语",
 63    "16k_th": "泰语",
 64    "16k_pt": "葡萄牙语",
 65    "16k_tr": "土耳其语",
 66    "16k_ar": "阿拉伯语",
 67    "16k_es": "西班牙语",
 68    "16k_hi": "印地语",
 69    "16k_fr": "法语",
 70    "16k_de": "德语",
 71}
 72
 73
 74def get_msg_to_asr(message: Message, *, asr_need_prefix: bool = True) -> Message | None:
 75    """Get the message to be recognized by ASR.
 76
 77    By default, "/asr" prefix is needed to trigger ASR function.
 78    """
 79    # skip no "/asr" prefix message if asr_need_prefix
 80    if asr_need_prefix and not startswith_prefix(message.content, prefix=PREFIX.ASR):
 81        return None
 82    # no need prefix or has "/asr" prefix
 83
 84    mtype = glom(message, "media.value", default="text") or "text"
 85    # has "/asr" prefix
 86    if startswith_prefix(message.content, prefix=PREFIX.ASR):
 87        if mtype in ["voice", "audio", "video"]:
 88            return message
 89        if reply_msg := message.reply_to_message:
 90            reply_mtype = glom(reply_msg, "media.value", default="text") or "text"
 91            if reply_mtype in ["voice", "audio", "video"]:
 92                return reply_msg
 93    elif mtype == "voice":  # no need "/asr" prefix
 94        return message
 95    return None
 96
 97
 98async def voice_to_text(
 99    client: Client,
100    message: Message,
101    asr_engine: str = ASR.DEFAULT_ENGINE,
102    *,
103    asr_need_prefix: bool = True,
104    to_telegraph: bool = True,
105    **kwargs,
106) -> None:
107    """Voice, audio, video message to text.
108
109    By default, "/asr" prefix is needed in in Group & Channel & Bot chats to trigger this function.
110    In private chat, no need to add "/asr" prefix for voice message, but the video & audio message still need it.
111
112    Args:
113        client (Client): The Pyrogram client.
114        message (Message): The trigger message object.
115        asr_need_prefix (bool, optional): If True, must prepend "/asr" prefix to call ASR function.
116        to_telegraph (bool, optional): If True, publish the result to Telegraph.
117
118    """
119    # send docs if message == "/asr", without reply
120    if equal_prefix(message.text, prefix=PREFIX.ASR) and not message.reply_to_message:
121        await send2tg(client, message, texts=HELP, **kwargs)
122        return
123
124    msg_to_asr = get_msg_to_asr(message, asr_need_prefix=asr_need_prefix)
125    if not msg_to_asr:
126        return
127    this_info = parse_msg(message, silent=True)
128    asr_msg_info = parse_msg(msg_to_asr, silent=True)
129
130    remain_text = this_info["text"].removeprefix(PREFIX.ASR).strip().lower()
131    tencent_language = "16k_zh-PY"  # default: 中英粤
132    if remain_text in ["fy", "ja", "ko", "vi", "ms", "id", "fil", "th", "pt", "tr", "ar", "es", "hi", "fr", "de"]:
133        # tencent asr
134        asr_engine = "tencent"
135        tencent_language = f"16k_{remain_text}".replace("fy", "zh_dialect")
136
137    elif remain_text:
138        asr_engine = remain_text
139    msg = f"[ASR] 收到消息: {asr_msg_info['mtype']}, 开始下载..."
140    logger.info(msg)
141    if kwargs.get("show_progress"):
142        res = await send2tg(client, msg_to_asr, texts=msg, **kwargs)
143        kwargs["progress"] = res[0]
144
145    path: str | Path = await msg_to_asr.download()  # type: ignore
146    path = Path(path).expanduser().resolve()
147    if not path.is_file():
148        msg = f"❌下载 {asr_msg_info['mtype']} 文件失败, 无法识别"
149        logger.error(msg)
150        await modify_progress(text=msg, force_update=True, **kwargs)
151        return
152
153    res = await asr_file(path, engine=asr_engine, tencent_language=tencent_language, message=msg_to_asr, **kwargs)
154    if error := res.get("error"):
155        await modify_progress(kwargs.get("progress"), text=error, force_update=True)
156        return
157    if texts := res.get("texts"):
158        final = blockquote(texts) if len(texts) > 300 else texts
159        # send results
160        target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else this_info["cid"]
161        reply_parameters = get_reply_to(asr_msg_info["mid"], kwargs.get("reply_msg_id", 0))
162        length = await count_without_entities(final)
163        if length < CAPTION_LENGTH:  # short
164            await client.copy_message(chat_id=to_int(target_chat), from_chat_id=asr_msg_info["cid"], message_id=asr_msg_info["mid"], caption=final, reply_parameters=reply_parameters)
165        elif length < TEXT_LENGTH:  # middle
166            await client.send_message(to_int(target_chat), final, reply_parameters=reply_parameters)
167        else:  # long
168            caption = asr_msg_info["html"]
169            if to_telegraph:
170                html = "\n".join([f"<p>{s}</p>" for s in texts.split("\n")])
171                if telegraph_url := await publish_telegraph(title=asr_msg_info["text"] or "语音识别结果", html=html, author=asr_msg_info["full_name"], url=asr_msg_info["message_url"]):
172                    caption += f"\n<a href={telegraph_url}>⚡️即时预览</a>"
173            with io.BytesIO(texts.encode("utf-8")) as f:
174                await client.send_document(
175                    to_int(target_chat),
176                    f,
177                    parse_mode=ParseMode.HTML,
178                    file_name="语音识别结果.txt",
179                    caption=caption.strip(),
180                    reply_parameters=reply_parameters,
181                )
182        await modify_progress(del_status=True, **kwargs)
183
184    [await delete_message(msg) for msg in res.get("sent_messages", [])]
185    if this_info["mtype"] == "text":
186        await delete_message(message)
187
188
189async def asr_file(
190    path: str | Path,
191    engine: str = "",
192    prompt: str = "",
193    *,
194    tencent_language: str = "16k_zh-PY",
195    enable_corrector: bool = False,
196    corrector_model: str = "asr-corrector",
197    corrector_reference: str | None = None,
198    delete_local_file: bool = True,
199    delete_gemini_file: bool = True,
200    **kwargs,
201) -> dict:
202    """Get ASR results of an audio file."""
203    path = Path(path).expanduser().resolve()
204    if not path.is_file():
205        return {"error": f"{path} is not exist"}
206    duration = audio_duration(path)
207    engine = auto_choose_asr_engine(duration=duration, engine=engine)
208    log = f"{engine.capitalize()} ASR, 时长: {readable_time(duration)} {path.name}"
209    logger.debug(log)
210    await modify_progress(message=kwargs.get("progress"), text=log, force_update=True)
211    res = {}
212    try:
213        if engine == "tencent":
214            res = await tencent_asr(path, tencent_language, duration)
215        elif engine == "ali":
216            res = await ali_asr(path)
217        elif engine == "deepgram":
218            res = await deepgram_asr(path)
219        elif engine == "gemini":
220            res = await gemini_asr(path=path, prompt=prompt, delete_gemini_file=delete_gemini_file)
221        elif engine == "cloudflare":
222            res = await cloudflare_asr(path, duration, model=kwargs.get("cf_asr_model"))
223        elif engine == "groq":
224            res = await groq_asr(path=path, model=kwargs.get("groq_asr_model", ""))
225        else:
226            return {"error": "ASR method not supported"}
227        if res.get("texts"):
228            logger.success(f"{res['texts']!r}")
229    except Exception as e:
230        error = f"Failed to recognize audio: {e}"
231        logger.error(error)
232        res["error"] = res.get("error", error)
233    finally:
234        if delete_local_file:
235            path.unlink(missing_ok=True)
236        elif path.is_file():
237            res["audio_file"] = path
238    if enable_corrector or corrector_reference:
239        res["texts"] = await asr_corrector(res["texts"], corrector_reference, corrector_model)
240    return res