Commit 6b8dc05

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-04 08:58:12
feat(asr): add timestamp for all ASR method
1 parent c34d936
src/asr/tecent_asr.py
@@ -1,11 +1,14 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import asyncio
 import base64
 import hashlib
 import hmac
 from pathlib import Path
 
 import anyio
+from glom import Coalesce, flatten, glom
+from loguru import logger
 
 from config import ASR
 from networking import hx_req
@@ -16,32 +19,6 @@ def sign(key, msg):
     return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
 
 
-async def flash_asr(path: str | Path, engine: str, voice_format: str) -> dict:
-    """Tencent Flash ASR.
-
-    录音文件识别极速版
-    https://cloud.tencent.com/document/product/1093/52097
-    """
-    now = nowdt()
-    params = {
-        "secretid": ASR.TENCENT_SECRET_ID,
-        "engine_type": engine,
-        "voice_format": voice_format,
-        "timestamp": str(int(now.timestamp())),
-    }
-    signstr = f"POSTasr.cloud.tencent.com/asr/flash/v1/{ASR.TENCENT_APPID}?"
-    for k, v in dict(sorted(params.items())).items():  # type: ignore
-        signstr += f"{k}={v}&"
-    signstr = signstr[:-1]  # strip last "&"
-
-    hmacstr = hmac.new(ASR.TENCENT_SECRET_KEY.encode("utf-8"), signstr.encode("utf-8"), hashlib.sha1).digest()
-    signature = base64.b64encode(hmacstr).decode("utf-8")
-    headers = {"Host": "asr.cloud.tencent.com", "authorization": signature}
-    url = f"https://{signstr.removeprefix('POST')}"
-    async with await anyio.open_file(path, "rb") as f:
-        return await hx_req(url, method="POST", headers=headers, post_content=await f.read(), timeout=60, proxy=ASR.TENCENT_PROXY, check_kv={"code": 0}, check_keys=["flash_result"])
-
-
 def generate_tencent_cloud_headers(
     action: str,
     payload: str,
@@ -91,7 +68,7 @@ def generate_tencent_cloud_headers(
     }
 
 
-async def single_sentence_asr(path: str | Path, engine: str, voice_format: str) -> dict:
+async def single_sentence_asr(path: str | Path, engine: str, voice_format: str) -> str:
     """Tencent Single Sentence ASR.
 
     一句话识别
@@ -100,17 +77,57 @@ async def single_sentence_asr(path: str | Path, engine: str, voice_format: str)
     async with await anyio.open_file(path, "rb") as f:
         content = await f.read()
         data = base64.b64encode(content).decode("utf-8")
-    payload = f'{{"EngSerViceType":"{engine}","SourceType":1,"VoiceFormat":"{voice_format}","Data":"{data}"}}'
+    payload = f'{{"EngSerViceType":"{engine}","SourceType":1,"WordInfo":2,"VoiceFormat":"{voice_format}","Data":"{data}"}}'
     headers = generate_tencent_cloud_headers(action="SentenceRecognition", payload=payload)
-    return await hx_req(
+    res = await hx_req(
         "https://asr.tencentcloudapi.com",
         method="POST",
         headers=headers,
         post_content=payload.encode("utf-8"),
         timeout=60,
         proxy=ASR.TENCENT_PROXY,
-        check_keys=["Response.Result"],
+        check_keys=["Response.WordList"],
     )
+    return generate_tencent_transcription(sentence_start_ms=[0], words=[res["Response"]["WordList"]])
+
+
+async def flash_asr(path: str | Path, engine: str, voice_format: str) -> str:
+    """Tencent Flash ASR.
+
+    录音文件识别极速版
+    https://cloud.tencent.com/document/product/1093/52097
+    """
+    now = nowdt()
+    params = {
+        "secretid": ASR.TENCENT_SECRET_ID,
+        "engine_type": engine,
+        "voice_format": voice_format,
+        "timestamp": str(int(now.timestamp())),
+        "word_info": 2,
+    }
+    signstr = f"POSTasr.cloud.tencent.com/asr/flash/v1/{ASR.TENCENT_APPID}?"
+    for k, v in dict(sorted(params.items())).items():  # type: ignore
+        signstr += f"{k}={v}&"
+    signstr = signstr[:-1]  # strip last "&"
+
+    hmacstr = hmac.new(ASR.TENCENT_SECRET_KEY.encode("utf-8"), signstr.encode("utf-8"), hashlib.sha1).digest()
+    signature = base64.b64encode(hmacstr).decode("utf-8")
+    headers = {"Host": "asr.cloud.tencent.com", "authorization": signature}
+    url = f"https://{signstr.removeprefix('POST')}"
+    async with await anyio.open_file(path, "rb") as f:
+        res = await hx_req(
+            url,
+            method="POST",
+            headers=headers,
+            post_content=await f.read(),
+            timeout=60,
+            proxy=ASR.TENCENT_PROXY,
+            check_kv={"code": 0},
+            check_keys=["flash_result.0.sentence_list.0.word_list"],
+        )
+        sentence_start_ms = flatten(glom(res, "flash_result.*.sentence_list.*.start_time"), levels=1)
+        words = flatten(glom(res, "flash_result.*.sentence_list.*.word_list"), levels=1)
+        return generate_tencent_transcription(sentence_start_ms, words)
 
 
 async def create_async_asr(url: str, engine: str) -> dict:
@@ -119,7 +136,7 @@ async def create_async_asr(url: str, engine: str) -> dict:
     录音文件识别请求
     https://cloud.tencent.com/document/api/1093/37823
     """
-    payload = f'{{"EngineModelType":"{engine}","ChannelNum":1,"ResTextFormat":0,"SourceType":0,"Url":"{url}"}}'
+    payload = f'{{"EngineModelType":"{engine}","ChannelNum":1,"ResTextFormat":2,"SourceType":0,"Url":"{url}"}}'
     headers = generate_tencent_cloud_headers(action="CreateRecTask", payload=payload)
     return await hx_req(
         "https://asr.tencentcloudapi.com",
@@ -132,7 +149,7 @@ async def create_async_asr(url: str, engine: str) -> dict:
     )
 
 
-async def query_async_asr(task_id: int) -> dict:
+async def query_async_asr(task_id: int) -> str:
     """Query Tencent ASR Task.
 
     录音文件识别结果查询
@@ -140,7 +157,7 @@ async def query_async_asr(task_id: int) -> dict:
     """
     payload = f'{{"TaskId":{task_id}}}'
     headers = generate_tencent_cloud_headers(action="DescribeTaskStatus", payload=payload)
-    return await hx_req(
+    result = await hx_req(
         "https://asr.tencentcloudapi.com",
         method="POST",
         headers=headers,
@@ -149,3 +166,37 @@ async def query_async_asr(task_id: int) -> dict:
         proxy=ASR.TENCENT_PROXY,
         check_keys=["Response.Data.StatusStr"],
     )
+    status = glom(result, "Response.Data.StatusStr")
+    while status in ["waiting", "doing"]:
+        await asyncio.sleep(1)
+        logger.trace(f"Status:[{status}], Wating TaskID: {task_id}")
+        result = await query_async_asr(task_id)
+        if isinstance(result, str):
+            return result
+        status = glom(result, "Response.Data.StatusStr")
+    if status == "success":
+        sentence_start_ms = glom(result, "Response.Data.ResultDetail.*.StartMs")
+        words = glom(result, "Response.Data.ResultDetail.*.Words")
+        return generate_tencent_transcription(sentence_start_ms, words)
+    return glom(result, "Response.Data.ErrorMsg")
+
+
+def generate_tencent_transcription(sentence_start_ms: list[int], words: list[list[dict]]) -> str:
+    res = ""
+    show_timestamp = False
+    for start_offset, items in zip(sentence_start_ms, words, strict=True):
+        for idx, item in enumerate(items):
+            sentence = glom(item, Coalesce("Word", "word"), default="")
+            if not sentence:
+                continue
+            if idx == 0 or res.endswith((".", "。")):
+                show_timestamp = True
+            if show_timestamp:
+                start_seconds = float(glom(item, Coalesce("StartTime", "OffsetStartMs", "start_time"), default=0) + float(start_offset)) // 1000
+                minutes = int(start_seconds // 60)
+                seconds = int(start_seconds % 60)
+                res += f"\n[{minutes:02d}:{seconds:02d}] {sentence}"
+                show_timestamp = False
+            else:
+                res += sentence
+    return res.strip()
src/asr/utils.py
@@ -13,10 +13,11 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
         asr_engine = ASR.MIDDLE_ENGINE
     else:
         asr_engine = ASR.LONG_ENGINE
-
-    if asr_engine == "tencent" or force_engine == "tencent":
+    if force_engine:
+        return get_tencent_asr_method(duration, file_size) if force_engine == "tencent" else get_gemini_asr_method(duration)
+    if asr_engine == "tencent":
         return get_tencent_asr_method(duration, file_size)
-    if asr_engine.lower() == "gemini" or force_engine == "gemini":
+    if asr_engine.lower() == "gemini":
         return get_gemini_asr_method(duration)
     return f"ASR Engine: {asr_engine} is not support for duration: {duration}, filesize: {file_size}", []
 
src/asr/voice_recognition.py
@@ -1,12 +1,10 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-import asyncio
 import contextlib
 import io
 import re
 from pathlib import Path
 
-from glom import glom
 from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
@@ -136,7 +134,7 @@ async def voice_to_text(
         await modify_progress(text=error, force_update=True, **kwargs)
         return
     if texts := res.get("texts"):
-        final = f"{BEGINNING}\n{blockquote(texts)}"
+        final = blockquote(texts) if len(texts) > 300 else texts
         logger.success(f"{final!r}")
         # send results
         target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else this_info["cid"]
@@ -201,28 +199,14 @@ async def asr_file(
     logger.debug(f"Recognizing {voice_format} audio [{duration}s] by {language}: {path.as_posix()}")
     try:
         if asr_method == "single_sentence_asr":
-            resp = await single_sentence_asr(path, language, voice_format)
-            texts = glom(resp, "Response.Result").replace("。", "。\n")
+            texts = await single_sentence_asr(path, language, voice_format)
         elif asr_method == "flash_asr":
-            resp = await flash_asr(path, language, voice_format)
-            texts = glom(resp, "flash_result.0.text").replace("。", "。\n")
+            texts = await flash_asr(path, language, voice_format)
         elif asr_method == "async_asr":
             resp = await create_async_asr(f"{FILE_SERVER}/{path.name}", language)
             task_id = resp["Response"]["Data"]["TaskId"]
             logger.success(f"ASR任务提交成功, TaskID: {task_id}")
-            result = await query_async_asr(task_id)
-            status = result["Response"]["Data"]["StatusStr"]
-            while status in ["waiting", "doing"]:
-                await asyncio.sleep(1)
-                logger.trace(f"Status:[{status}], Wating TaskID: {task_id}")
-                result = await query_async_asr(task_id)
-                status = result["Response"]["Data"]["StatusStr"]
-            if status == "success":
-                texts = glom(result, "Response.Data.Result")
-                texts = re.sub(r"\[.*?\]\s*", "", texts)
-            else:
-                texts = glom(result, "Response.Data.ErrorMsg")
-                res["error"] = texts
+            texts = await query_async_asr(task_id)
         elif asr_method == "gemini":
             return await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
         res["texts"] = texts