main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import asyncio
  4import base64
  5import hashlib
  6import hmac
  7from collections.abc import Coroutine
  8from decimal import Decimal
  9from pathlib import Path
 10from typing import Any
 11
 12import anyio
 13from glom import Coalesce, flatten, glom
 14from loguru import logger
 15
 16from asr.groq import merge_transcripts
 17from asr.utils import audio_chunk_to_bytes, audio_duration, convert_single_channel, downsampe_audio, get_file_bytes, is_english_word, load_audio
 18from config import ASR, FILE_SERVER, PROXY
 19from database.alist import delete_alist, upload_alist
 20from database.uguu import upload_uguu
 21from networking import hx_req
 22from utils import nowdt, seconds_to_time
 23
 24
 25def sign(key, msg):
 26    return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
 27
 28
 29def generate_tencent_cloud_headers(
 30    action: str,
 31    payload: str,
 32    service: str = "asr",
 33    host: str = "asr.tencentcloudapi.com",
 34    version: str = "2019-06-14",
 35    secret_id: str = ASR.TENCENT_SECRET_ID,
 36    secret_key: str = ASR.TENCENT_SECRET_KEY,
 37) -> dict:
 38    """Generate TencentCloudAPI Headers (TC3-HMAC-SHA256)."""
 39    algorithm = "TC3-HMAC-SHA256"
 40    now = nowdt()
 41    timestamp = str(int(now.timestamp()))
 42    date = f"{now:%Y-%m-%d}"
 43
 44    # ************* 步骤 1: 拼接规范请求串 *************
 45    http_request_method = "POST"
 46    canonical_uri = "/"
 47    canonical_querystring = ""
 48    canonical_headers = f"content-type:application/json; charset=utf-8\nhost:{host}\nx-tc-action:{action.lower()}\n"
 49    signed_headers = "content-type;host;x-tc-action"
 50    hashed_request_payload = hashlib.sha256(payload.encode("utf-8")).hexdigest()
 51    canonical_request = f"{http_request_method}\n{canonical_uri}\n{canonical_querystring}\n{canonical_headers}\n{signed_headers}\n{hashed_request_payload}"
 52
 53    # ************* 步骤 2: 拼接待签名字符串 *************
 54    credential_scope = f"{date}/{service}/tc3_request"
 55    hashed_canonical_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()
 56    string_to_sign = f"{algorithm}\n{timestamp}\n{credential_scope}\n{hashed_canonical_request}"
 57
 58    # ************* 步骤 3: 计算签名 *************
 59    secret_date = sign(("TC3" + secret_key).encode("utf-8"), date)
 60    secret_service = sign(secret_date, service)
 61    secret_signing = sign(secret_service, "tc3_request")
 62    signature = hmac.new(secret_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
 63
 64    # ************* 步骤 4: 拼接 Authorization *************
 65    authorization = f"{algorithm} Credential={secret_id}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}"
 66
 67    # ************* 步骤 5: 构造 Headers *************
 68    return {
 69        "Authorization": authorization,
 70        "Content-Type": "application/json; charset=utf-8",
 71        "Host": host,
 72        "X-TC-Action": action,
 73        "X-TC-Timestamp": timestamp,
 74        "X-TC-Version": version,
 75    }
 76
 77
 78async def tencent_asr(path: str | Path, language: str, duration: float) -> dict:
 79    """Tencent ASR.
 80
 81    由于 `录音文件识别` 和 `录音文件识别极速版`免费额度太少
 82    所以现在我们只使用 `一句话识别` 来处理所有ASR请求
 83
 84    Returns:
 85        {"texts": str, "raw_texts": str, "segments": list[dict], "error": str}
 86    """
 87    path = Path(path).expanduser().resolve()
 88    if not path.is_file():
 89        return {"texts": "", "error": "File not found."}
 90    supported_ext = [".wav", ".pcm", ".ogg", ".opus", ".oga", ".speex", ".silk", ".mp3", ".m4a", ".aac", ".amr"]
 91    audio_path = path if path.suffix.lower() in supported_ext else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
 92    audio_path = await convert_single_channel(audio_path, ext="wav", codec="pcm_s16le")
 93    if duration < 1:  # some thing error in detecting duration
 94        audio_path = await downsampe_audio(path, ext="wav", codec="pcm_s16le")
 95        duration = audio_duration(audio_path)
 96
 97    # max allowed duration is 60s
 98    if duration < 60:
 99        return await tencent_single_asr(audio_path, language=language)
100    return await tencent_file_chunks(audio_path, language=language, duration=duration)
101
102
async def tencent_single_asr(path_or_bytes: Path | bytes, language: str, *, offset_seconds: int = 0) -> dict:
    """Transcribe one short clip with Tencent "Sentence Recognition".

    Free quota: 5000 calls per month.
    https://cloud.tencent.com/document/product/1093/35646

    Args:
        path_or_bytes: Audio file path, or raw audio bytes (sent with VoiceFormat "wav").
        language: Value sent as ``EngSerViceType``.
        offset_seconds: Seconds added to every segment timestamp, e.g. the
            position of this clip inside a longer recording.

    Returns:
        {"texts": str, "raw_texts": str, "segments": list[dict], "error": str}
        ("error" is only present on failure)

        example item of segments:
        {
            "start": Decimal,
            "end": Decimal,
            "text": str,
        }
    """
    final = {"texts": "", "raw_texts": "", "segments": []}
    if isinstance(path_or_bytes, Path):
        # Files of 3 MB or more are downsampled before upload.
        file_size = path_or_bytes.stat().st_size
        audio_path = path_or_bytes if file_size < 3 * 1024 * 1024 else await downsampe_audio(path_or_bytes)
        voice_format = Path(audio_path).suffix.lower().lstrip(".")
        if voice_format in ["ogg", "opus", "oga"]:  # Tencent only supports ogg-opus
            voice_format = "ogg-opus"
        audio_bytes = await get_file_bytes(audio_path)
    elif isinstance(path_or_bytes, bytes):
        voice_format = "wav"
        audio_bytes = path_or_bytes
    else:
        # Previously fell through to an unbound `audio_bytes`; report it like other failures.
        return final | {"error": f"Unsupported audio input type: {type(path_or_bytes).__name__}"}

    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    payload = f'{{"EngSerViceType":"{language}","SourceType":1,"WordInfo":2,"VoiceFormat":"{voice_format}","Data":"{audio_base64}"}}'
    headers = generate_tencent_cloud_headers(action="SentenceRecognition", payload=payload)
    res = await hx_req(
        "https://asr.tencentcloudapi.com",
        method="POST",
        headers=headers,
        content_data=payload.encode("utf-8"),
        timeout=60,
        proxy=PROXY.TENCENT,
        check_keys=["Response.WordList"],
    )
    if res.get("hx_error"):
        return final | {"error": res["hx_error"]}
    words = glom(res, "Response.WordList", default=None)
    if words is None:
        return final | {"error": "⚠️该音频未识别到文字"}
    final["raw_texts"] = glom(res, "Response.Result", default="") or ""

    # ASCII punctuation that should be followed by a space when words are joined.
    space_after = (",", ".", "?", "!")
    # Sentence terminators, ASCII and full-width. An empty string must never appear here:
    # str.endswith("") is always True and would end a sentence after every word.
    sentence_end = (".", "。", "?", "？", "!", "！")  # noqa: RUF001

    sentences = []  # list of sentences
    sentence = []  # word dicts of the sentence being built
    for item in words:
        word = item.get("Word", "")
        if is_english_word(word) or word.endswith(space_after):
            item["Word"] = word + " "
        sentence.append(item)
        if word.endswith(sentence_end):
            sentences.append(sentence)
            sentence = []
    if sentence:
        sentences.append(sentence)

    segments = []
    offset = Decimal(offset_seconds).quantize(Decimal(".01"))
    for sentence in sentences:
        # Word times are divided by 1000 to get seconds.
        start = offset + Decimal(sentence[0].get("StartTime", 0)) / 1000
        end = offset + Decimal(sentence[-1].get("EndTime", 0)) / 1000
        text = "".join(x.get("Word", "") for x in sentence)
        # Drop the space that an English word leaves in front of a comma.
        text = text.replace(" ,", ",").replace(" ，", "，")  # noqa: RUF001
        segments.append({"start": start, "end": end, "text": text})
    final["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in segments)  # with timestamp
    final["segments"] = segments
    return final
176
177
async def tencent_file_chunks(
    path: Path,
    language: str,
    duration: float,
    chunk_seconds: float = 60,
    overlap_seconds: float = 5,
) -> dict:
    """Transcribe audio in overlapping chunks and merge the results.

    Most of this code is copied from `gemini_file_chunks` in `asr/gemini.py`

    Args:
        path: Audio file; anything whose suffix is not ".wav" is converted first.
        language: Passed to every per-chunk request.
        duration: Total audio length in seconds.
        chunk_seconds: Length of each chunk in seconds
        overlap_seconds: Overlap between chunks in seconds

    Returns:
        dict: {"texts": str, "raw_texts": str, "segments": list[dict]}
        On failure the dict carries an "error" key instead of a transcript.
    """
    # only support wav file
    wav_path = path if path.suffix == ".wav" else await downsampe_audio(path)
    audio, _, sr = load_audio(wav_path)
    if sr == 0:
        return {"error": "Failed to load audio."}

    semaphore = asyncio.Semaphore(30)  # max concurrent requests

    async def run_with_semaphore(task: Coroutine[Any, Any, dict]) -> dict:
        async with semaphore:
            return await task

    try:
        # Consecutive chunks start `step_seconds` apart so they overlap.
        step_seconds = chunk_seconds - overlap_seconds
        total_chunks = int(duration // step_seconds) + 1
        encode_tasks = []
        offset_list = []

        # Loop through each chunk, extract current chunk from audio
        for i in range(total_chunks):
            start = int(i * step_seconds * sr)
            end = int(min(start + chunk_seconds * sr, duration * sr))
            logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
            chunk = audio[start:end]
            if chunk.shape[0] == 0:  # empty chunk
                continue
            encode_tasks.append(audio_chunk_to_bytes(chunk, sr))
            offset_list.append(int(start / sr))
        bytes_list = await asyncio.gather(*encode_tasks)  # convert chunks to bytes

        asr_tasks = [
            run_with_semaphore(tencent_single_asr(audio_bytes, language=language, offset_seconds=offset_seconds))
            for audio_bytes, offset_seconds in zip(bytes_list, offset_list, strict=True)
        ]
        results = await asyncio.gather(*asr_tasks)

        transcribed = [r for r in results if r.get("segments")]
        if not transcribed:
            # Nothing was recognized. Report the first per-chunk error instead of
            # silently discarding it and merging an empty list.
            first_error = next((r["error"] for r in results if r.get("error")), "⚠️该音频未识别到文字")
            return {"texts": "", "raw_texts": "", "segments": [], "error": first_error}
        return merge_transcripts(sorted(transcribed, key=lambda x: x["segments"][0]["start"]))
    except Exception as e:
        logger.error(e)
        return {"error": str(e)}
237
238
async def tencent_flash_asr(path: str | Path, engine: str, voice_format: str) -> dict:
    """(Deprecated) Tencent Flash ASR.

    Deprecated, use `tencent_single_asr` instead.
    Flash recording-file recognition (free monthly quota: 5 hours)
    https://cloud.tencent.com/document/product/1093/52097

    Args:
        path: Audio file whose bytes are posted as the request body.
        engine: Sent as the ``engine_type`` query parameter.
        voice_format: Sent as the ``voice_format`` query parameter.

    Returns:
        {"texts": str} on success, {"error": str} on failure.
    """
    now = nowdt()
    query_params = {
        "secretid": ASR.TENCENT_SECRET_ID,
        "engine_type": engine,
        "voice_format": voice_format,
        "timestamp": str(int(now.timestamp())),
        "word_info": 2,
    }
    # The signed string is "POST" + host + path + "?" + the parameters sorted by name.
    query = "&".join(f"{name}={value}" for name, value in sorted(query_params.items()))
    endpoint = f"asr.cloud.tencent.com/asr/flash/v1/{ASR.TENCENT_APPID}?{query}"
    signstr = "POST" + endpoint

    digest = hmac.new(ASR.TENCENT_SECRET_KEY.encode("utf-8"), signstr.encode("utf-8"), hashlib.sha1).digest()
    headers = {"Host": "asr.cloud.tencent.com", "authorization": base64.b64encode(digest).decode("utf-8")}
    url = "https://" + endpoint

    async with await anyio.open_file(path, "rb") as f:
        audio_bytes = await f.read()
        res = await hx_req(
            url,
            method="POST",
            headers=headers,
            content_data=audio_bytes,
            timeout=60,
            proxy=PROXY.TENCENT,
            check_kv={"code": 0},
            check_keys=["flash_result.0.sentence_list.0.word_list"],
        )
        error = res.get("hx_error", "")
        if error:
            if "audio data empty" in error:
                return {"error": "⚠️该音频未识别到文字"}
            return {"error": error}
        # Collapse the per-channel nesting into one flat list of sentences.
        sentence_start_ms = flatten(glom(res, "flash_result.*.sentence_list.*.start_time"), levels=1)
        words = flatten(glom(res, "flash_result.*.sentence_list.*.word_list"), levels=1)
        return generate_tencent_transcription(sentence_start_ms, words)
281
282
async def tencent_async_asr(path: str | Path, engine: str) -> dict:
    """(Deprecated) Create a Tencent ASR task and wait for its result.

    Deprecated, use `tencent_single_asr` instead.
    Note: this API does not support Chinese file names.
    Recording-file recognition request (free monthly quota: 10 hours)
    https://cloud.tencent.com/document/api/1093/37823

    Args:
        path: Audio file; it is exposed through the file server selected by
            ``ASR.TENCENT_FS_ENGINE`` ("local", "uguu" or "alist").
        engine: Sent as ``EngineModelType``.

    Returns:
        The result of `tencent_query_asr`, or {"error": str} on failure.
    """
    path = Path(path).expanduser().resolve()
    fs_engine = ASR.TENCENT_FS_ENGINE.lower()
    if fs_engine == "local":
        url = FILE_SERVER.removesuffix("/") + "/" + path.name
    elif fs_engine == "uguu":
        size_limit = 100 * 1024 * 1024  # max 100 MB for Uguu
        if path.stat().st_size > size_limit:
            path = await downsampe_audio(path)
        url = await upload_uguu(path)
    elif fs_engine == "alist":
        url = await upload_alist(path)
    else:
        return {"error": f"Unsupported file server engine: {ASR.TENCENT_FS_ENGINE}"}

    payload = f'{{"EngineModelType":"{engine}","ChannelNum":1,"ResTextFormat":2,"SourceType":0,"Url":"{url}"}}'
    resp = await hx_req(
        "https://asr.tencentcloudapi.com",
        method="POST",
        headers=generate_tencent_cloud_headers(action="CreateRecTask", payload=payload),
        content_data=payload.encode("utf-8"),
        timeout=600,
        proxy=PROXY.TENCENT,
        check_keys=["Response.Data.TaskId"],
    )
    if resp.get("hx_error"):
        return {"error": resp["hx_error"]}

    task_id = resp["Response"]["Data"]["TaskId"]
    logger.success(f"ASR任务提交成功, TaskID: {task_id}")
    return await tencent_query_asr(task_id, file_name=path.name)
319
320
async def tencent_query_asr(task_id: int, file_name: str, query_times: int = 0) -> dict:
    """Poll a Tencent ASR task until it finishes, then build the transcript.

    Recording-file recognition result query
    https://cloud.tencent.com/document/api/1093/37822

    Args:
        task_id: Task id returned when the recognition task was created.
        file_name: Name of the uploaded file; it is deleted from alist once
            polling ends, when ``ASR.TENCENT_FS_ENGINE`` is "alist".
        query_times: Number of polls already spent (counts towards the 600 limit).

    Returns:
        {"texts": str} on success, {"error": str} otherwise.
    """
    payload = f'{{"TaskId":{task_id}}}'
    max_queries = 600  # one poll per second -> max 10 minutes

    # Poll in a loop rather than by recursion. The recursive version nested up to 600
    # awaits and misread a falsy inner result as an API response.
    while True:
        # Headers carry a fresh timestamp, so they are rebuilt for every poll.
        headers = generate_tencent_cloud_headers(action="DescribeTaskStatus", payload=payload)
        result = await hx_req(
            "https://asr.tencentcloudapi.com",
            method="POST",
            headers=headers,
            content_data=payload.encode("utf-8"),
            timeout=600,
            proxy=PROXY.TENCENT,
            check_keys=["Response.Data.StatusStr"],
        )
        if result.get("hx_error"):
            return {"error": result["hx_error"]}
        status = glom(result, "Response.Data.StatusStr")
        if status not in ("waiting", "doing") or query_times >= max_queries:
            break
        await asyncio.sleep(1)
        query_times += 1
        logger.trace(f"Status: [{status} ({query_times}/600)], Wating TaskID: {task_id}")

    if ASR.TENCENT_FS_ENGINE.lower() == "alist":
        await delete_alist(file_name)
    if status == "success":
        if glom(result, "Response.Data.ResultDetail", default=None) is None:
            return {"error": "⚠️该音频未识别到文字"}
        sentence_start_ms = glom(result, "Response.Data.ResultDetail.*.StartMs")
        words = glom(result, "Response.Data.ResultDetail.*.Words")
        return generate_tencent_transcription(sentence_start_ms, words)
    # A missing, null or empty ErrorMsg (e.g. on timeout) falls back to a generic message.
    return {"error": glom(result, "Response.Data.ErrorMsg", default="") or "语音识别失败"}
358
359
def generate_tencent_transcription(sentence_start_ms: list[int], words: list[list[dict]]) -> dict:
    """Build a timestamped transcript from per-sentence word lists.

    A new "[MM:SS]" line starts at the first word of every sentence and after
    any text that ends with a sentence terminator; other words are appended to
    the current line.

    Args:
        sentence_start_ms: Start time of each sentence, in milliseconds.
        words: Word dicts for each sentence, parallel to ``sentence_start_ms``.
            The text is read from "Word" or "word"; the start time from
            "StartTime", "OffsetStartMs" or "start_time".

    Returns:
        {"texts": str} on success, {"error": str} if anything raises
        (including a length mismatch between the two lists).
    """
    # Sentence terminators, ASCII and full-width. An empty string must never appear here:
    # str.endswith("") is always True and would start a new line before every word.
    sentence_end = (".", "。", "?", "？")  # noqa: RUF001
    parts: list[str] = []
    try:
        for start_offset, items in zip(sentence_start_ms, words, strict=True):
            for idx, item in enumerate(items):
                word = glom(item, Coalesce("Word", "word"), default="")
                if not word:
                    continue
                if is_english_word(word):
                    word = word + " "
                # Every part is non-empty, so the last part ends the text built so far.
                after_terminator = bool(parts) and parts[-1].endswith(sentence_end)
                if idx == 0 or after_terminator:
                    start_seconds = float(glom(item, Coalesce("StartTime", "OffsetStartMs", "start_time"), default=0) + float(start_offset)) // 1000
                    minutes = int(start_seconds // 60)
                    seconds = int(start_seconds % 60)
                    parts.append(f"\n[{minutes:02d}:{seconds:02d}] {word}")
                else:
                    parts.append(word)
    except Exception as e:
        logger.error(e)
        return {"error": str(e)}
    return {"texts": "".join(parts).strip()}
380    return {"texts": res.strip()}