main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import asyncio
4import base64
5import hashlib
6import hmac
7from collections.abc import Coroutine
8from decimal import Decimal
9from pathlib import Path
10from typing import Any
11
12import anyio
13from glom import Coalesce, flatten, glom
14from loguru import logger
15
16from asr.groq import merge_transcripts
17from asr.utils import audio_chunk_to_bytes, audio_duration, convert_single_channel, downsampe_audio, get_file_bytes, is_english_word, load_audio
18from config import ASR, FILE_SERVER, PROXY
19from database.alist import delete_alist, upload_alist
20from database.uguu import upload_uguu
21from networking import hx_req
22from utils import nowdt, seconds_to_time
23
24
def sign(key, msg):
    """Return the HMAC-SHA256 digest of *msg* (UTF-8 encoded) keyed with *key*."""
    mac = hmac.new(key, msg.encode("utf-8"), digestmod=hashlib.sha256)
    return mac.digest()
27
28
def generate_tencent_cloud_headers(
    action: str,
    payload: str,
    service: str = "asr",
    host: str = "asr.tencentcloudapi.com",
    version: str = "2019-06-14",
    secret_id: str = ASR.TENCENT_SECRET_ID,
    secret_key: str = ASR.TENCENT_SECRET_KEY,
) -> dict:
    """Build request headers signed with Tencent Cloud's TC3-HMAC-SHA256 scheme.

    Follows the four documented signing steps: canonical request, string to
    sign, signature derivation, and Authorization assembly.
    """
    algorithm = "TC3-HMAC-SHA256"
    now = nowdt()
    ts = str(int(now.timestamp()))
    day = f"{now:%Y-%m-%d}"

    # Step 1: canonical request (header block carries its own trailing newline,
    # so the join yields the required blank separator before the signed headers)
    signed_headers = "content-type;host;x-tc-action"
    payload_hash = hashlib.sha256(payload.encode("utf-8")).hexdigest()
    canonical_request = "\n".join(
        (
            "POST",
            "/",
            "",
            f"content-type:application/json; charset=utf-8\nhost:{host}\nx-tc-action:{action.lower()}\n",
            signed_headers,
            payload_hash,
        )
    )

    # Step 2: string to sign
    credential_scope = f"{day}/{service}/tc3_request"
    string_to_sign = "\n".join(
        (
            algorithm,
            ts,
            credential_scope,
            hashlib.sha256(canonical_request.encode("utf-8")).hexdigest(),
        )
    )

    # Step 3: derive the signing key (date -> service -> request) and sign
    k_date = sign(("TC3" + secret_key).encode("utf-8"), day)
    k_service = sign(k_date, service)
    k_signing = sign(k_service, "tc3_request")
    signature = hmac.new(k_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()

    # Steps 4 & 5: Authorization header plus the remaining request headers
    authorization = f"{algorithm} Credential={secret_id}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}"
    return {
        "Authorization": authorization,
        "Content-Type": "application/json; charset=utf-8",
        "Host": host,
        "X-TC-Action": action,
        "X-TC-Timestamp": ts,
        "X-TC-Version": version,
    }
76
77
async def tencent_asr(path: str | Path, language: str, duration: float) -> dict:
    """Tencent ASR entry point: dispatch to single-sentence or chunked recognition.

    The "recording file recognition" and "flash recognition" APIs have too
    small a free quota, so all ASR requests now go through the "sentence
    recognition" API — directly for short audio, chunked for longer audio.

    Args:
        path: Audio file path (expanded and resolved below).
        language: Engine/language identifier forwarded to the Tencent API.
        duration: Audio duration in seconds as detected by the caller.

    Returns:
        {"texts": str, "raw_texts": str, "segments": list[dict], "error": str}
    """
    path = Path(path).expanduser().resolve()
    if not path.is_file():
        return {"texts": "", "error": "File not found."}
    supported_ext = [".wav", ".pcm", ".ogg", ".opus", ".oga", ".speex", ".silk", ".mp3", ".m4a", ".aac", ".amr"]
    # Convert unsupported containers to 16-bit PCM WAV, then force mono.
    audio_path = path if path.suffix.lower() in supported_ext else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
    audio_path = await convert_single_channel(audio_path, ext="wav", codec="pcm_s16le")
    if duration < 1:  # caller-side duration detection failed; re-encode and re-measure
        # NOTE(review): this re-encodes from the original `path`, discarding the
        # mono conversion above — confirm whether `audio_path` was intended here.
        audio_path = await downsampe_audio(path, ext="wav", codec="pcm_s16le")
        duration = audio_duration(audio_path)

    # The sentence-recognition API caps input at 60 seconds; longer audio is
    # split into overlapping chunks.
    if duration < 60:
        return await tencent_single_asr(audio_path, language=language)
    return await tencent_file_chunks(audio_path, language=language, duration=duration)
101
102
async def tencent_single_asr(path_or_bytes: str | Path | bytes, language: str, *, offset_seconds: int = 0) -> dict:
    """Tencent Single Sentence ASR.

    Sentence recognition (free quota: 5000 calls/month)
    https://cloud.tencent.com/document/product/1093/35646

    Args:
        path_or_bytes: Audio file path (str or Path), or raw WAV bytes
            (as produced by `audio_chunk_to_bytes`).
        language: `EngSerViceType` value forwarded to the API.
        offset_seconds: Added to every segment timestamp; used when the audio
            is a chunk cut out of a longer recording.

    Returns:
        {"texts": str, "raw_texts": str, "segments": list[dict], "error": str}

        example item of segments:
        {
            "start": Decimal,
            "end": Decimal,
            "text": str,
        }

    Raises:
        TypeError: If `path_or_bytes` is neither a path nor bytes.
    """
    final = {"texts": "", "raw_texts": "", "segments": []}
    if isinstance(path_or_bytes, (str, Path)):
        path_or_bytes = Path(path_or_bytes)
        # API limit: 3 MB per request; downsample larger files first.
        file_size = path_or_bytes.stat().st_size
        audio_path = path_or_bytes if file_size < 3 * 1024 * 1024 else await downsampe_audio(path_or_bytes)
        voice_format = Path(audio_path).suffix.lower().lstrip(".")
        if voice_format in ["ogg", "opus", "oga"]:  # tencent only supports ogg-opus
            voice_format = "ogg-opus"
        audio_bytes = await get_file_bytes(audio_path)
    elif isinstance(path_or_bytes, bytes):
        voice_format = "wav"
        audio_bytes = path_or_bytes
    else:
        # Previously any other type fell through to a NameError on `audio_bytes`.
        raise TypeError(f"path_or_bytes must be str, Path or bytes, got {type(path_or_bytes).__name__}")
    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    # base64 output never contains quotes/backslashes, so inlining into JSON is safe.
    payload = f'{{"EngSerViceType":"{language}","SourceType":1,"WordInfo":2,"VoiceFormat":"{voice_format}","Data":"{audio_base64}"}}'
    headers = generate_tencent_cloud_headers(action="SentenceRecognition", payload=payload)
    res = await hx_req(
        "https://asr.tencentcloudapi.com",
        method="POST",
        headers=headers,
        content_data=payload.encode("utf-8"),
        timeout=60,
        proxy=PROXY.TENCENT,
        check_keys=["Response.WordList"],
    )
    if res.get("hx_error"):
        return final | {"error": res["hx_error"]}
    words = glom(res, "Response.WordList", default=None)
    if words is None:
        return final | {"error": "⚠️该音频未识别到文字"}
    final["raw_texts"] = glom(res, "Response.Result", default="") or ""

    # Group word items into sentences at end-of-sentence punctuation.
    sentences = []  # list of sentence
    sentence = []  # list of dict
    for item in words:
        word = item.get("Word", "")
        # Pad English words / clause punctuation with a trailing space.
        if is_english_word(word) or word.endswith((",", ".", "?", "!")):
            item["Word"] = word + " "
        if word.endswith((".", "。", "?", "?", "!", "!")):  # noqa: RUF001
            sentence.append(item)
            sentences.append(sentence)
            sentence = []
            continue
        sentence.append(item)
    if sentence:  # trailing words without closing punctuation
        sentences.append(sentence)

    # Convert sentences to timestamped segments (ms -> s, shifted by offset).
    segments = []
    offset = Decimal(offset_seconds).quantize(Decimal(".01"))
    for sentence in sentences:
        start = offset + Decimal(sentence[0].get("StartTime", 0)) / 1000
        end = offset + Decimal(sentence[-1].get("EndTime", 0)) / 1000
        text = "".join(x.get("Word", "") for x in sentence)
        text = text.replace(" ,", ",").replace(" ,", ",")  # noqa: RUF001
        segments.append({"start": start, "end": end, "text": text})
    final["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in segments)  # with timestamp
    final["segments"] = segments
    return final
176
177
async def tencent_file_chunks(
    path: Path,
    language: str,
    duration: float,
    chunk_seconds: float = 60,
    overlap_seconds: float = 5,
) -> dict:
    """Transcribe audio in chunks with overlap.

    Most of this code is copied from `gemini_file_chunks` in `asr/gemini.py`

    Args:
        path: Audio file path; non-WAV input is converted first.
        language: Engine/language identifier passed to `tencent_single_asr`.
        duration: Total audio duration in seconds.
        chunk_seconds: Length of each chunk in seconds
        overlap_seconds: Overlap between chunks in seconds

    Returns:
        dict: {"texts": str, "raw_texts": str, "segments": list[dict]}
        or {"error": str} when loading or transcription fails.
    """
    # only support wav file
    # NOTE(review): variable is named `aac_path` but a WAV file seems expected
    # here — confirm `downsampe_audio`'s default output format.
    aac_path = path if path.suffix == ".wav" else await downsampe_audio(path)
    audio, _, sr = load_audio(aac_path)
    if sr == 0:
        return {"error": "Failed to load audio."}

    transcription = {}
    semaphore = asyncio.Semaphore(30)  # max concurrent requests

    async def run_with_semaphore(task: Coroutine[Any, Any, dict]) -> dict:
        # Gate the already-created coroutine so at most 30 requests run at once.
        async with semaphore:
            return await task

    try:
        # Calculate # of chunks; each step advances (chunk - overlap) seconds
        total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
        tasks = []
        offset_list = []

        # Loop through each chunk, extract current chunk from audio
        for i in range(total_chunks):
            start = int(i * (chunk_seconds - overlap_seconds) * sr)
            end = int(min(start + chunk_seconds * sr, duration * sr))
            logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
            chunk = audio[start:end]
            if chunk.shape[0] == 0:  # empty chunk
                continue
            tasks.append(audio_chunk_to_bytes(chunk, sr))
            offset_list.append(int(start / sr))  # chunk start time in seconds, shifts segment timestamps
        bytes_list = await asyncio.gather(*tasks)  # convert chunks to bytes
        tasks = []
        for audio_bytes, offset_seconds in zip(bytes_list, offset_list, strict=True):
            task = tencent_single_asr(audio_bytes, language=language, offset_seconds=offset_seconds)
            tasks.append(run_with_semaphore(task))
        results = await asyncio.gather(*tasks)
        # Drop chunks with no recognized text, then merge the overlapping
        # transcripts in chronological order.
        results = [r for r in results if r.get("segments")]
        transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
    except Exception as e:
        logger.error(e)
        return {"error": str(e)}
    return transcription
237
238
async def tencent_flash_asr(path: str | Path, engine: str, voice_format: str) -> dict:
    """(Deprecated) Tencent Flash ASR.

    Deprecated — use `tencent_single_asr` instead.
    Flash recording-file recognition (free quota: 5 hours/month)
    https://cloud.tencent.com/document/product/1093/52097
    """
    now = nowdt()
    params = {
        "secretid": ASR.TENCENT_SECRET_ID,
        "engine_type": engine,
        "voice_format": voice_format,
        "timestamp": str(int(now.timestamp())),
        "word_info": 2,
    }
    # The sign string is "POST" + host + path + "?" + params sorted by key.
    query = "&".join(f"{k}={params[k]}" for k in sorted(params))
    signstr = f"POSTasr.cloud.tencent.com/asr/flash/v1/{ASR.TENCENT_APPID}?{query}"

    digest = hmac.new(ASR.TENCENT_SECRET_KEY.encode("utf-8"), signstr.encode("utf-8"), hashlib.sha1).digest()
    headers = {"Host": "asr.cloud.tencent.com", "authorization": base64.b64encode(digest).decode("utf-8")}
    # The request URL is the sign string minus the leading HTTP method.
    url = f"https://{signstr.removeprefix('POST')}"
    async with await anyio.open_file(path, "rb") as f:
        body = await f.read()
    res = await hx_req(
        url,
        method="POST",
        headers=headers,
        content_data=body,
        timeout=60,
        proxy=PROXY.TENCENT,
        check_kv={"code": 0},
        check_keys=["flash_result.0.sentence_list.0.word_list"],
    )
    if err := res.get("hx_error", ""):
        if "audio data empty" in err:
            return {"error": "⚠️该音频未识别到文字"}
        return {"error": err}
    starts_ms = flatten(glom(res, "flash_result.*.sentence_list.*.start_time"), levels=1)
    word_lists = flatten(glom(res, "flash_result.*.sentence_list.*.word_list"), levels=1)
    return generate_tencent_transcription(starts_ms, word_lists)
281
282
async def tencent_async_asr(path: str | Path, engine: str) -> dict:
    """(Deprecated) Create a Tencent recording-file recognition task.

    Deprecated — use `tencent_single_asr` instead.
    Note: this endpoint does not accept Chinese file names.
    Recording file recognition request (free quota: 10 hours/month)
    https://cloud.tencent.com/document/api/1093/37823
    """
    path = Path(path).expanduser().resolve()
    # The file must be reachable by URL; upload it via the configured backend.
    fs_engine = ASR.TENCENT_FS_ENGINE.lower()
    if fs_engine == "local":
        url = FILE_SERVER.removesuffix("/") + "/" + path.name
    elif fs_engine == "uguu":
        # Uguu rejects uploads above 100 MB, so shrink larger files first.
        if path.stat().st_size > 100 * 1024 * 1024:  # 100 MB
            path = await downsampe_audio(path)
        url = await upload_uguu(path)  # max 100 MB for Uguu
    elif fs_engine == "alist":
        url = await upload_alist(path)
    else:
        return {"error": f"Unsupported file server engine: {ASR.TENCENT_FS_ENGINE}"}

    payload = f'{{"EngineModelType":"{engine}","ChannelNum":1,"ResTextFormat":2,"SourceType":0,"Url":"{url}"}}'
    headers = generate_tencent_cloud_headers(action="CreateRecTask", payload=payload)
    resp = await hx_req(
        "https://asr.tencentcloudapi.com",
        method="POST",
        headers=headers,
        content_data=payload.encode("utf-8"),
        timeout=600,
        proxy=PROXY.TENCENT,
        check_keys=["Response.Data.TaskId"],
    )
    if err := resp.get("hx_error"):
        return {"error": err}
    task_id = resp["Response"]["Data"]["TaskId"]
    logger.success(f"ASR任务提交成功, TaskID: {task_id}")
    # Poll until the task finishes; pass the file name so the upload can be cleaned up.
    return await tencent_query_asr(task_id, file_name=path.name)
319
320
async def tencent_query_asr(task_id: int, file_name: str, query_times: int = 0) -> dict:
    """Query a Tencent ASR task and poll until it completes or times out.

    Recording file recognition result query
    https://cloud.tencent.com/document/api/1093/37822

    Args:
        task_id: Task id returned by the `CreateRecTask` action.
        file_name: Name of the uploaded file; used to delete the alist copy.
        query_times: Poll counter threaded through the recursive re-queries
            below so the combined attempts are bounded.

    Returns:
        {"texts": str} on success, {"error": str} on failure/timeout.
    """
    payload = f'{{"TaskId":{task_id}}}'
    headers = generate_tencent_cloud_headers(action="DescribeTaskStatus", payload=payload)
    result = await hx_req(
        "https://asr.tencentcloudapi.com",
        method="POST",
        headers=headers,
        content_data=payload.encode("utf-8"),
        timeout=600,
        proxy=PROXY.TENCENT,
        check_keys=["Response.Data.StatusStr"],
    )
    if result.get("hx_error"):
        return {"error": result["hx_error"]}
    status = glom(result, "Response.Data.StatusStr")
    # Poll while the task is queued/running. Each iteration re-enters this
    # function, which runs its own polling loop; `query_times` is shared so
    # the total number of polls stops at 600 (~10 minutes at 1s per poll).
    while status in ["waiting", "doing"] and query_times < 600:  # max 10 minutes
        await asyncio.sleep(1)
        query_times += 1
        logger.trace(f"Status: [{status} ({query_times}/600)], Wating TaskID: {task_id}")
        result = await tencent_query_asr(task_id, file_name, query_times)
        # A recursive call that finished returns a final texts/error dict;
        # propagate it up without touching cleanup again.
        if result.get("texts") or result.get("error"):
            return result
        status = glom(result, "Response.Data.StatusStr")
    # Only the innermost call that exits the loop reaches this cleanup,
    # so the uploaded alist copy is deleted at most once.
    if ASR.TENCENT_FS_ENGINE.lower() == "alist":
        await delete_alist(file_name)
    if status == "success":
        if glom(result, "Response.Data.ResultDetail") is None:
            return {"error": "⚠️该音频未识别到文字"}
        sentence_start_ms = glom(result, "Response.Data.ResultDetail.*.StartMs")
        words = glom(result, "Response.Data.ResultDetail.*.Words")
        return generate_tencent_transcription(sentence_start_ms, words)
    return {"error": "❌" + glom(result, "Response.Data.ErrorMsg", default="语音识别失败")}
358
359
def generate_tencent_transcription(sentence_start_ms: list[int], words: list[list[dict]]) -> dict:
    """Render Tencent word-level ASR output as timestamped text.

    Each sentence start offset (milliseconds) is paired with its word list;
    a "\\n[MM:SS]" marker is emitted at the first word of every sentence and
    after end-of-sentence punctuation.

    Returns:
        {"texts": str} on success, {"error": str} if the input shape is unexpected.
    """
    text = ""
    try:
        for offset_ms, word_items in zip(sentence_start_ms, words, strict=True):
            for position, word_item in enumerate(word_items):
                token = glom(word_item, Coalesce("Word", "word"), default="")
                if not token:
                    continue
                if is_english_word(token):
                    token = f"{token} "
                # Mid-sentence words are appended as-is (De Morgan of the
                # original "first word or after closing punctuation" check).
                if position != 0 and not text.endswith((".", "。", "?", "?")):  # noqa: RUF001
                    text += token
                    continue
                word_start = glom(word_item, Coalesce("StartTime", "OffsetStartMs", "start_time"), default=0)
                total_seconds = float(word_start + float(offset_ms)) // 1000
                mm = int(total_seconds // 60)
                ss = int(total_seconds % 60)
                text += f"\n[{mm:02d}:{ss:02d}] {token}"
    except Exception as e:
        logger.error(e)
        return {"error": str(e)}
    return {"texts": text.strip()}