main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import io
4from pathlib import Path
5
6from glom import glom
7from loguru import logger
8from pyrogram.client import Client
9from pyrogram.enums import ParseMode
10from pyrogram.types import Message
11
12from asr.ali import ali_asr
13from asr.cloudflare import cloudflare_asr
14from asr.corrector import asr_corrector
15from asr.deepgram import deepgram_asr
16from asr.gemini import gemini_asr
17from asr.groq import groq_asr
18from asr.tecent import tencent_asr
19from asr.utils import audio_duration, auto_choose_asr_engine
20from config import ASR, CAPTION_LENGTH, PREFIX, TEXT_LENGTH
21from messages.parser import parse_msg
22from messages.progress import modify_progress
23from messages.sender import send2tg
24from messages.utils import blockquote, count_without_entities, delete_message, equal_prefix, get_reply_to, startswith_prefix
25from publish import publish_telegraph
26from utils import readable_time, to_int
27
# User-facing help text (Chinese). voice_to_text sends it when the ASR command
# is issued on its own, without replying to another message.
# The language codes listed in the text are the same codes voice_to_text
# routes to the Tencent engine.
# Tencent ASR reference:
# https://cloud.tencent.com/document/product/1093/52097
HELP = f"""🗣**语音转文字**
使用说明: 以 `{PREFIX.ASR}` 回复包含音频的消息 (如语音, 视频, 音乐)
默认可以识别普通话、粤语、英语三种语言。
识别其他语种可在`{PREFIX.ASR}`后加上语种代码, 如:
以`{PREFIX.ASR} ja`回复音频消息识别日语
以`{PREFIX.ASR} fr`回复音频消息识别法语

**目前支持以下语种:**
fy: 多种方言, 上海话、四川话、武汉话、贵阳话、昆明话、西安话、郑州话、太原话、兰州话、银川话、西宁话、南京话、合肥话、南昌话、长沙话、苏州话、杭州话、济南话、天津话、石家庄话、黑龙江话、吉林话、辽宁话
ja: 日语
ko: 韩语
vi: 越南语
ms: 马来语
id: 印度尼西亚语
fil: 菲律宾语
th: 泰语
pt: 葡萄牙语
tr: 土耳其语
ar: 阿拉伯语
es: 西班牙语
hi: 印地语
fr: 法语
de: 德语
"""
53
# Human-readable (Chinese) language names keyed by "16k_<code>" identifiers.
# "16k_zh-PY" is the default language value used by voice_to_text and asr_file.
# NOTE(review): voice_to_text rewrites the "fy" code to "16k_zh_dialect" before
# calling the Tencent engine, which has no entry here — confirm whether the
# "16k_fy" key is still the intended lookup key for dialects.
LANG_MAP = {
    "16k_zh-PY": "中英粤",
    "16k_fy": "多种方言, 上海话、四川话、武汉话、贵阳话、昆明话、西安话、郑州话、太原话、兰州话、银川话、西宁话、南京话、合肥话、南昌话、长沙话、苏州话、杭州话、济南话、天津话、石家庄话、黑龙江话、吉林话、辽宁话",
    "16k_ja": "日语",
    "16k_ko": "韩语",
    "16k_vi": "越南语",
    "16k_ms": "马来语",
    "16k_id": "印度尼西亚语",
    "16k_fil": "菲律宾语",
    "16k_th": "泰语",
    "16k_pt": "葡萄牙语",
    "16k_tr": "土耳其语",
    "16k_ar": "阿拉伯语",
    "16k_es": "西班牙语",
    "16k_hi": "印地语",
    "16k_fr": "法语",
    "16k_de": "德语",
}
72
73
def get_msg_to_asr(message: Message, *, asr_need_prefix: bool = True) -> Message | None:
    """Select the message whose media should be sent to ASR.

    Resolution rules:
      * With the ASR prefix: the trigger message itself if it carries voice,
        audio or video; otherwise the replied-to message if that one does.
      * Without the prefix (only reachable when ``asr_need_prefix`` is False):
        the trigger message, and only if it is a voice message.

    Args:
        message (Message): The trigger message object.
        asr_need_prefix (bool, optional): If True, messages lacking the ASR
            prefix are ignored.

    Returns:
        The message to transcribe, or None when nothing qualifies.
    """
    has_prefix = startswith_prefix(message.content, prefix=PREFIX.ASR)
    if asr_need_prefix and not has_prefix:
        return None

    own_type = glom(message, "media.value", default="text") or "text"

    # Prefix-less trigger: bare voice messages only.
    if not has_prefix:
        return message if own_type == "voice" else None

    recognizable = ("voice", "audio", "video")
    if own_type in recognizable:
        return message

    # Prefixed command without its own media: fall back to the replied message.
    replied = message.reply_to_message
    if replied:
        replied_type = glom(replied, "media.value", default="text") or "text"
        if replied_type in recognizable:
            return replied
    return None
96
97
async def voice_to_text(
    client: Client,
    message: Message,
    asr_engine: str = ASR.DEFAULT_ENGINE,
    *,
    asr_need_prefix: bool = True,
    to_telegraph: bool = True,
    **kwargs,
) -> None:
    """Voice, audio, video message to text.

    By default, "/asr" prefix is needed in Group & Channel & Bot chats to trigger this function.
    In private chat, no need to add "/asr" prefix for voice message, but the video & audio message still need it.

    The text after the prefix selects the recognizer: one of the language codes
    listed in HELP forces the Tencent engine with that language, any other
    non-empty text is used verbatim as the engine name.

    Args:
        client (Client): The Pyrogram client.
        message (Message): The trigger message object.
        asr_engine (str, optional): Engine name used when the command carries no argument.
        asr_need_prefix (bool, optional): If True, must prepend "/asr" prefix to call ASR function.
        to_telegraph (bool, optional): If True, publish the result to Telegraph.
        **kwargs: Forwarded to send2tg, modify_progress and asr_file. Keys read
            here: ``show_progress``, ``progress``, ``target_chat``, ``reply_msg_id``.

    """
    # send docs if message == "/asr", without reply
    if equal_prefix(message.text, prefix=PREFIX.ASR) and not message.reply_to_message:
        await send2tg(client, message, texts=HELP, **kwargs)
        return

    msg_to_asr = get_msg_to_asr(message, asr_need_prefix=asr_need_prefix)
    if not msg_to_asr:
        return
    this_info = parse_msg(message, silent=True)
    asr_msg_info = parse_msg(msg_to_asr, silent=True)

    # Whatever follows the prefix is either a language code or an engine name.
    tencent_codes = ("fy", "ja", "ko", "vi", "ms", "id", "fil", "th", "pt", "tr", "ar", "es", "hi", "fr", "de")
    remain_text = this_info["text"].removeprefix(PREFIX.ASR).strip().lower()
    tencent_language = "16k_zh-PY"  # default: Mandarin / English / Cantonese
    if remain_text in tencent_codes:
        # Language codes are only handled by the Tencent engine.
        asr_engine = "tencent"
        # The "fy" (dialect) code is sent to Tencent as "16k_zh_dialect".
        tencent_language = "16k_zh_dialect" if remain_text == "fy" else f"16k_{remain_text}"
    elif remain_text:
        asr_engine = remain_text

    msg = f"[ASR] 收到消息: {asr_msg_info['mtype']}, 开始下载..."
    logger.info(msg)
    if kwargs.get("show_progress"):
        status_msgs = await send2tg(client, msg_to_asr, texts=msg, **kwargs)
        # Guard against an empty result so a failed status send cannot abort the ASR.
        if status_msgs:
            kwargs["progress"] = status_msgs[0]

    downloaded = await msg_to_asr.download()  # type: ignore
    # Pyrogram documents download() as returning None on failure; Path(None)
    # would raise TypeError before the failure could be reported.
    local_path = Path(downloaded).expanduser().resolve() if downloaded else None
    if local_path is None or not local_path.is_file():
        msg = f"❌下载 {asr_msg_info['mtype']} 文件失败, 无法识别"
        logger.error(msg)
        await modify_progress(text=msg, force_update=True, **kwargs)
        return

    res = await asr_file(local_path, engine=asr_engine, tencent_language=tencent_language, message=msg_to_asr, **kwargs)
    if error := res.get("error"):
        await modify_progress(kwargs.get("progress"), text=error, force_update=True)
        return

    if texts := res.get("texts"):
        final = blockquote(texts) if len(texts) > 300 else texts
        # send results
        target_chat = kwargs.get("target_chat") or this_info["cid"]
        reply_parameters = get_reply_to(asr_msg_info["mid"], kwargs.get("reply_msg_id", 0))
        length = await count_without_entities(final)
        if length < CAPTION_LENGTH:  # short: re-send the media with the text as caption
            await client.copy_message(
                chat_id=to_int(target_chat),
                from_chat_id=asr_msg_info["cid"],
                message_id=asr_msg_info["mid"],
                caption=final,
                reply_parameters=reply_parameters,
            )
        elif length < TEXT_LENGTH:  # middle: plain text message
            await client.send_message(to_int(target_chat), final, reply_parameters=reply_parameters)
        else:  # long: attach as a .txt document, optionally with a Telegraph link
            caption = asr_msg_info["html"]
            if to_telegraph:
                html = "\n".join([f"<p>{s}</p>" for s in texts.split("\n")])
                if telegraph_url := await publish_telegraph(
                    title=asr_msg_info["text"] or "语音识别结果",
                    html=html,
                    author=asr_msg_info["full_name"],
                    url=asr_msg_info["message_url"],
                ):
                    caption += f"\n<a href={telegraph_url}>⚡️即时预览</a>"
            with io.BytesIO(texts.encode("utf-8")) as f:
                await client.send_document(
                    to_int(target_chat),
                    f,
                    parse_mode=ParseMode.HTML,
                    file_name="语音识别结果.txt",
                    caption=caption.strip(),
                    reply_parameters=reply_parameters,
                )
    await modify_progress(del_status=True, **kwargs)

    # Remove helper messages reported by the ASR backend, then the text-only trigger command.
    for sent in res.get("sent_messages", []):
        await delete_message(sent)
    if this_info["mtype"] == "text":
        await delete_message(message)
187
188
async def asr_file(
    path: str | Path,
    engine: str = "",
    prompt: str = "",
    *,
    tencent_language: str = "16k_zh-PY",
    enable_corrector: bool = False,
    corrector_model: str = "asr-corrector",
    corrector_reference: str | None = None,
    delete_local_file: bool = True,
    delete_gemini_file: bool = True,
    **kwargs,
) -> dict:
    """Get ASR results of an audio file.

    Args:
        path: Local audio file to recognize.
        engine: Requested engine name; the final engine is resolved by
            ``auto_choose_asr_engine`` from this name and the audio duration.
        prompt: Prompt forwarded to the Gemini engine.
        tencent_language: Language value forwarded to the Tencent engine.
        enable_corrector: If True, post-process the recognized text with ``asr_corrector``.
        corrector_model: Model name passed to ``asr_corrector``.
        corrector_reference: Reference text passed to ``asr_corrector``;
            a non-empty value also enables the corrector.
        delete_local_file: If True, remove ``path`` once recognition finished (success or not).
        delete_gemini_file: Forwarded to the Gemini engine.
        **kwargs: Optional keys read here: ``progress`` (status message to update),
            ``cf_asr_model`` and ``groq_asr_model`` (per-engine model names).

    Returns:
        The engine's result dict. ``"texts"`` holds the recognized text when
        available, ``"error"`` is set when recognition failed, and
        ``"audio_file"`` holds the path when the local file was kept.
    """
    path = Path(path).expanduser().resolve()
    if not path.is_file():
        return {"error": f"{path} is not exist"}
    duration = audio_duration(path)
    engine = auto_choose_asr_engine(duration=duration, engine=engine)
    log = f"{engine.capitalize()} ASR, 时长: {readable_time(duration)} {path.name}"
    logger.debug(log)
    await modify_progress(message=kwargs.get("progress"), text=log, force_update=True)
    res = {}
    try:
        if engine == "tencent":
            res = await tencent_asr(path, tencent_language, duration)
        elif engine == "ali":
            res = await ali_asr(path)
        elif engine == "deepgram":
            res = await deepgram_asr(path)
        elif engine == "gemini":
            res = await gemini_asr(path=path, prompt=prompt, delete_gemini_file=delete_gemini_file)
        elif engine == "cloudflare":
            res = await cloudflare_asr(path, duration, model=kwargs.get("cf_asr_model"))
        elif engine == "groq":
            res = await groq_asr(path=path, model=kwargs.get("groq_asr_model", ""))
        else:
            # The finally clause below still cleans up the local file on this path.
            return {"error": "ASR method not supported"}
        if res.get("texts"):
            logger.success(f"{res['texts']!r}")
    except Exception as e:
        # Best-effort boundary: report the failure in the result instead of raising.
        error = f"Failed to recognize audio: {e}"
        logger.error(error)
        res["error"] = res.get("error", error)
    finally:
        if delete_local_file:
            path.unlink(missing_ok=True)
        elif path.is_file():
            res["audio_file"] = path
    # Only correct when there is text: after a failed recognition "texts" is
    # absent and indexing it would raise KeyError, hiding the real error.
    if (enable_corrector or corrector_reference) and res.get("texts"):
        res["texts"] = await asr_corrector(res["texts"], corrector_reference, corrector_model)
    return res