main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import io
4from pathlib import Path
5
6from glom import glom
7from loguru import logger
8from pyrogram.client import Client
9from pyrogram.types import Message
10
11from asr.ali import ali_asr
12from asr.cloudflare import cloudflare_asr
13from asr.corrector import asr_corrector
14from asr.deepgram import deepgram_asr
15from asr.gemini import gemini_asr
16from asr.groq import groq_asr
17from asr.tecent import tencent_asr
18from asr.utils import audio_duration, auto_choose_asr_engine
19from config import ASR, CAPTION_LENGTH, PREFIX, TEXT_LENGTH
20from messages.parser import parse_msg
21from messages.progress import modify_progress
22from messages.sender import send2tg
23from messages.utils import blockquote, count_without_entities, delete_message, equal_prefix, get_reply_to, startswith_prefix
24from publish import publish_telegraph
25from utils import readable_time, to_int
26
# Usage help for the ASR command. voice_to_text sends it when the message is exactly the
# ASR prefix and does not reply to another message. PREFIX.ASR is interpolated once, at import time.
# The language codes listed here are the ones voice_to_text routes to Tencent ASR
# (presumably matching the Tencent engine model types documented at the link below — verify there).
# https://cloud.tencent.com/document/product/1093/52097
HELP = f"""🗣**语音转文字**
使用说明: 以 `{PREFIX.ASR}` 回复包含音频的消息 (如语音, 视频, 音乐)
默认可以识别普通话、粤语、英语三种语言。
识别其他语种可在`{PREFIX.ASR}`后加上语种代码, 如:
以`{PREFIX.ASR} ja`回复音频消息识别日语
以`{PREFIX.ASR} fr`回复音频消息识别法语

**目前支持以下语种:**
fy: 多种方言, 上海话、四川话、武汉话、贵阳话、昆明话、西安话、郑州话、太原话、兰州话、银川话、西宁话、南京话、合肥话、南昌话、长沙话、苏州话、杭州话、济南话、天津话、石家庄话、黑龙江话、吉林话、辽宁话
ja: 日语
ko: 韩语
vi: 越南语
ms: 马来语
id: 印度尼西亚语
fil: 菲律宾语
th: 泰语
pt: 葡萄牙语
tr: 土耳其语
ar: 阿拉伯语
es: 西班牙语
hi: 印地语
fr: 法语
de: 德语
"""
52
# Description of the Chinese dialects covered by Tencent's dialect model.
_ZH_DIALECTS = "多种方言, 上海话、四川话、武汉话、贵阳话、昆明话、西安话、郑州话、太原话、兰州话、银川话、西宁话、南京话、合肥话、南昌话、长沙话、苏州话、杭州话、济南话、天津话、石家庄话、黑龙江话、吉林话、辽宁话"

# Tencent ASR model type -> human-readable language description.
LANG_MAP = {
    "16k_zh-PY": "中英粤",
    # Legacy key, kept so existing lookups keep working.
    "16k_fy": _ZH_DIALECTS,
    # The model type actually sent for the "fy" option
    # (voice_to_text maps "fy" to "16k_zh_dialect"), so it must resolve too.
    "16k_zh_dialect": _ZH_DIALECTS,
    "16k_ja": "日语",
    "16k_ko": "韩语",
    "16k_vi": "越南语",
    "16k_ms": "马来语",
    "16k_id": "印度尼西亚语",
    "16k_fil": "菲律宾语",
    "16k_th": "泰语",
    "16k_pt": "葡萄牙语",
    "16k_tr": "土耳其语",
    "16k_ar": "阿拉伯语",
    "16k_es": "西班牙语",
    "16k_hi": "印地语",
    "16k_fr": "法语",
    "16k_de": "德语",
}
71
72
def get_msg_to_asr(message: Message, *, asr_need_prefix: bool = True) -> Message | None:
    """Pick the message whose media should be transcribed.

    When the message starts with the "/asr" prefix, the message itself is picked
    if it carries voice/audio/video media. Otherwise its replied-to message is
    picked if that one carries such media.
    Without the prefix, only a voice message qualifies, and only when
    ``asr_need_prefix`` is False.

    Returns:
        The message to transcribe, or None when nothing qualifies.
    """
    has_prefix = startswith_prefix(message.content, prefix=PREFIX.ASR)
    if not has_prefix and asr_need_prefix:
        return None

    def media_kind(msg: Message) -> str:
        # Messages without media (or with an empty media value) count as "text".
        return glom(msg, "media.value", default="text") or "text"

    own_kind = media_kind(message)
    if not has_prefix:
        # Prefix-less trigger: only voice messages are accepted.
        return message if own_kind == "voice" else None

    transcribable = ("voice", "audio", "video")
    if own_kind in transcribable:
        return message
    replied = message.reply_to_message
    if replied and media_kind(replied) in transcribable:
        return replied
    return None
95
96
# Options accepted after the ASR prefix that switch the request to Tencent ASR with that language.
_TENCENT_LANG_CODES = frozenset({"fy", "ja", "ko", "vi", "ms", "id", "fil", "th", "pt", "tr", "ar", "es", "hi", "fr", "de"})


def _parse_asr_option(option: str, default_engine: str) -> tuple[str, str]:
    """Turn the text typed after the ASR prefix into ``(engine, tencent_language)``.

    - A code in ``_TENCENT_LANG_CODES`` selects ``"tencent"`` with model ``16k_<code>``
      (``"fy"`` becomes ``"16k_zh_dialect"``).
    - Any other non-empty text is returned as the requested engine name.
    - Empty text keeps ``default_engine``.

    The language defaults to ``"16k_zh-PY"``.
    """
    option = option.strip().lower()
    if option in _TENCENT_LANG_CODES:
        return "tencent", ("16k_zh_dialect" if option == "fy" else f"16k_{option}")
    return (option or default_engine), "16k_zh-PY"


async def _send_asr_result(client: Client, texts: str, asr_msg_info: dict, *, target_chat, reply_msg_id=0) -> None:
    """Send the transcript in a form chosen by its length.

    - Shorter than CAPTION_LENGTH: copy the source media with the transcript as its caption.
    - Shorter than TEXT_LENGTH: send it as a text message.
    - Otherwise: send a ``.txt`` document, captioned with a Telegraph preview link when publishing succeeds.
    """
    from html import escape  # transcript text must not be interpreted as HTML markup

    final = blockquote(texts) if len(texts) > 300 else texts
    chat_id = to_int(target_chat)
    reply_parameters = get_reply_to(asr_msg_info["mid"], reply_msg_id)
    length = await count_without_entities(final)
    if length < CAPTION_LENGTH:  # short
        await client.copy_message(chat_id=chat_id, from_chat_id=asr_msg_info["cid"], message_id=asr_msg_info["mid"], caption=final, reply_parameters=reply_parameters)
    elif length < TEXT_LENGTH:  # middle
        await client.send_message(chat_id, final, reply_parameters=reply_parameters)
    else:  # long
        caption = ""
        # Escape each line so "<", ">" or "&" in the transcript cannot break the page markup.
        page_html = "\n".join(f"<p>{escape(line, quote=False)}</p>" for line in texts.split("\n"))
        if telegraph_url := await publish_telegraph(title=asr_msg_info["text"], html=page_html, author=asr_msg_info["full_name"], url=asr_msg_info["message_url"]):
            caption = f"[⚡️即时预览]({telegraph_url})"
        with io.BytesIO(texts.encode("utf-8")) as f:
            await client.send_document(chat_id, f, file_name="语音识别结果.txt", caption=caption, reply_parameters=reply_parameters)


async def voice_to_text(
    client: Client,
    message: Message,
    asr_engine: str = ASR.DEFAULT_ENGINE,
    *,
    asr_need_prefix: bool = True,
    **kwargs,
) -> None:
    """Transcribe a voice, audio or video message and send back the text.

    By default, the "/asr" prefix is needed in Group & Channel & Bot chats to trigger this function.
    With ``asr_need_prefix=False``, a voice message is accepted without the prefix, but video and
    audio messages still need it. The bare prefix, sent without a replied-to message, answers with
    the usage help (``HELP``).

    Text after the prefix selects the backend. A language code listed in ``HELP`` switches to
    the Tencent engine with that language. Any other word is passed on as the requested engine name.

    Args:
        client (Client): The Pyrogram client.
        message (Message): The trigger message object.
        asr_engine (str, optional): Engine requested when no option follows the prefix.
        asr_need_prefix (bool, optional): If True, the "/asr" prefix is required to call the ASR function.
        **kwargs: Forwarded to ``send2tg``, ``modify_progress`` and ``asr_file``. Keys read here:
            ``show_progress``, ``progress`` (set here when a progress message is sent),
            ``target_chat`` and ``reply_msg_id``.
    """
    # The bare prefix without a reply is a request for the usage docs.
    if equal_prefix(message.text, prefix=PREFIX.ASR) and not message.reply_to_message:
        await send2tg(client, message, texts=HELP, **kwargs)
        return

    msg_to_asr = get_msg_to_asr(message, asr_need_prefix=asr_need_prefix)
    if not msg_to_asr:
        return
    this_info = parse_msg(message, silent=True)
    asr_msg_info = parse_msg(msg_to_asr, silent=True)

    tencent_language = "16k_zh-PY"  # default: 中英粤
    # Only text following an explicit prefix is an option; a caption on a prefix-less
    # voice message must not be mistaken for an engine name.
    if startswith_prefix(message.content, prefix=PREFIX.ASR):
        asr_engine, tencent_language = _parse_asr_option(this_info["text"].removeprefix(PREFIX.ASR), asr_engine)

    msg = f"[ASR] 收到消息: {asr_msg_info['mtype']}, 开始下载..."
    logger.info(msg)
    if kwargs.get("show_progress") and (progress_msgs := await send2tg(client, msg_to_asr, texts=msg, **kwargs)):
        kwargs["progress"] = progress_msgs[0]

    downloaded = await msg_to_asr.download()  # type: ignore
    # download() returns None when the download fails; Path(None) would raise TypeError.
    path = Path(downloaded).expanduser().resolve() if downloaded else None
    if path is None or not path.is_file():
        msg = f"❌下载 {asr_msg_info['mtype']} 文件失败, 无法识别"
        logger.error(msg)
        await modify_progress(text=msg, force_update=True, **kwargs)
        return

    res = await asr_file(path, engine=asr_engine, tencent_language=tencent_language, message=msg_to_asr, **kwargs)
    if error := res.get("error"):
        await modify_progress(kwargs.get("progress"), text=error, force_update=True)
        return
    if texts := res.get("texts"):
        await _send_asr_result(
            client,
            texts,
            asr_msg_info,
            target_chat=kwargs.get("target_chat") or this_info["cid"],
            reply_msg_id=kwargs.get("reply_msg_id", 0),
        )
        await modify_progress(del_status=True, **kwargs)

    # Clean up temporary messages reported by the engine, and the "/asr" command itself.
    for temp_msg in res.get("sent_messages", []):
        await delete_message(temp_msg)
    if this_info["mtype"] == "text":
        await delete_message(message)
177
178
async def _run_asr_engine(
    engine: str,
    path: Path,
    duration,
    *,
    prompt: str,
    tencent_language: str,
    delete_gemini_file: bool,
    **kwargs,
) -> dict:
    """Call the ASR backend named ``engine``; an unknown name yields an error dict."""
    if engine == "tencent":
        return await tencent_asr(path, tencent_language, duration)
    if engine == "ali":
        return await ali_asr(path)
    if engine == "deepgram":
        return await deepgram_asr(path)
    if engine == "gemini":
        return await gemini_asr(path=path, prompt=prompt, delete_gemini_file=delete_gemini_file)
    if engine == "cloudflare":
        return await cloudflare_asr(path, duration, model=kwargs.get("cf_asr_model"))
    if engine == "groq":
        return await groq_asr(path=path, model=kwargs.get("groq_asr_model", ""))
    return {"error": "ASR method not supported"}


async def asr_file(
    path: str | Path,
    engine: str = "",
    prompt: str = "",
    *,
    tencent_language: str = "16k_zh-PY",
    enable_corrector: bool = False,
    corrector_model: str = "asr-corrector",
    corrector_reference: str | None = None,
    delete_local_file: bool = True,
    delete_gemini_file: bool = True,
    **kwargs,
) -> dict:
    """Get ASR results of an audio file.

    Args:
        path: Audio file to transcribe.
        engine: Requested engine. The final choice is made by ``auto_choose_asr_engine``
            from this value and the audio duration.
        prompt: Prompt passed to the Gemini engine.
        tencent_language: Tencent model type, e.g. ``"16k_zh-PY"``.
        enable_corrector: Post-process the transcript with ``asr_corrector``.
        corrector_model: Model name passed to ``asr_corrector``.
        corrector_reference: Reference text for ``asr_corrector``. When given, the corrector
            runs even if ``enable_corrector`` is False.
        delete_local_file: Delete ``path`` afterwards, including when recognition fails.
        delete_gemini_file: Passed to the Gemini engine.
        **kwargs: ``progress`` (progress message to update), ``cf_asr_model``, ``groq_asr_model``.

    Returns:
        The engine's result dict. ``"texts"`` holds the transcript on success, and ``"error"``
        holds a message on failure. ``"audio_file"`` is added when the file is kept on disk.
        Exceptions from the duration probe or the progress update still propagate.
    """
    path = Path(path).expanduser().resolve()
    if not path.is_file():
        return {"error": f"{path} is not exist"}
    res: dict = {}
    # The outer try guarantees local-file cleanup even if probing the duration fails.
    try:
        duration = audio_duration(path)
        engine = auto_choose_asr_engine(duration=duration, engine=engine)
        log = f"{engine.capitalize()} ASR, 时长: {readable_time(duration)} {path.name}"
        logger.debug(log)
        await modify_progress(message=kwargs.get("progress"), text=log, force_update=True)
        try:
            res = await _run_asr_engine(
                engine,
                path,
                duration,
                prompt=prompt,
                tencent_language=tencent_language,
                delete_gemini_file=delete_gemini_file,
                **kwargs,
            )
            if res.get("texts"):
                logger.success(f"{res['texts']!r}")
        except Exception as e:
            error = f"Failed to recognize audio: {e}"
            logger.error(error)
            res["error"] = res.get("error", error)
    finally:
        if delete_local_file:
            path.unlink(missing_ok=True)
        elif path.is_file():
            res["audio_file"] = path
    # Correct only a real transcript; failed results carry "error" and no "texts".
    if (enable_corrector or corrector_reference) and res.get("texts"):
        res["texts"] = await asr_corrector(res["texts"], corrector_reference, corrector_model)
    return res