Commit 87890e0

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-04-21 10:27:03
feat(asr): support all Tencent ASR methods
1 parent ffb7ae4
src/asr/tecent_asr.py
@@ -3,28 +3,34 @@
 import base64
 import hashlib
 import hmac
-import time
 from pathlib import Path
 
 import anyio
 
 from config import PROXY, TOKEN
 from networking import hx_req
+from utils import nowdt
 
 
-async def flash_asr(path: str | Path, engine: str, voice_format: str):
+def sign(key, msg):
+    """Return the raw HMAC-SHA256 digest of ``msg`` keyed with ``key``.
+
+    ``msg`` is UTF-8 encoded before hashing; ``key`` is passed to ``hmac`` as is.
+    """
+    mac = hmac.new(key=key, msg=msg.encode("utf-8"), digestmod=hashlib.sha256)
+    return mac.digest()
+
+
+async def flash_asr(path: str | Path, engine: str, voice_format: str) -> dict:
     """Tencent Flash ASR.
 
+    录音文件识别极速版
     https://cloud.tencent.com/document/product/1093/52097
     """
+    now = nowdt()
     params = {
         "secretid": TOKEN.TENCENT_ASR_SECRET_ID,
         "engine_type": engine,
         "voice_format": voice_format,
-        "timestamp": str(int(time.time())),
+        "timestamp": str(int(now.timestamp())),
     }
     signstr = f"POSTasr.cloud.tencent.com/asr/flash/v1/{TOKEN.TENCENT_ASR_APPID}?"
-    for k, v in dict(sorted(params.items())).items():
+    for k, v in dict(sorted(params.items())).items():  # type: ignore
         signstr += f"{k}={v}&"
     signstr = signstr[:-1]  # strip last "&"
 
@@ -34,3 +40,112 @@ async def flash_asr(path: str | Path, engine: str, voice_format: str):
     url = f"https://{signstr.removeprefix('POST')}"
     async with await anyio.open_file(path, "rb") as f:
         return await hx_req(url, method="POST", headers=headers, post_content=await f.read(), timeout=60, proxy=PROXY.TENCENT, check_kv={"code": 0}, check_keys=["flash_result"])
+
+
+def generate_tencent_cloud_headers(
+    action: str,
+    payload: str,
+    service: str = "asr",
+    host: str = "asr.tencentcloudapi.com",
+    version: str = "2019-06-14",
+    secret_id: str = TOKEN.TENCENT_ASR_SECRET_ID,
+    secret_key: str = TOKEN.TENCENT_ASR_SECRET_KEY,
+) -> dict:
+    """Generate TencentCloudAPI headers signed with TC3-HMAC-SHA256.
+
+    Args:
+        action: API action name; sent as ``X-TC-Action`` and signed in lowercase.
+        payload: Exact request body text. Its SHA-256 hex digest is part of the
+            signature, so the very same string must be sent as the body.
+        service: Service name used in the credential scope and key derivation.
+        host: Value of the signed ``Host`` header.
+        version: Value of the ``X-TC-Version`` header.
+        secret_id: Credential id placed in the ``Authorization`` header.
+        secret_key: Secret from which the signing key is derived.
+
+    Returns:
+        A dict with the ``Authorization``, ``Content-Type``, ``Host``,
+        ``X-TC-Action``, ``X-TC-Timestamp`` and ``X-TC-Version`` headers.
+    """
+    from datetime import datetime, timezone
+
+    algorithm = "TC3-HMAC-SHA256"
+    epoch = int(nowdt().timestamp())
+    timestamp = str(epoch)
+    # The credential-scope date has to be the UTC date of X-TC-Timestamp.
+    # Formatting the datetime in its own (possibly non-UTC) timezone gives a
+    # different date for part of every day and makes the signature invalid.
+    date = datetime.fromtimestamp(epoch, tz=timezone.utc).strftime("%Y-%m-%d")
+
+    # ************* Step 1: build the canonical request *************
+    http_request_method = "POST"
+    canonical_uri = "/"
+    canonical_querystring = ""
+    canonical_headers = f"content-type:application/json; charset=utf-8\nhost:{host}\nx-tc-action:{action.lower()}\n"
+    signed_headers = "content-type;host;x-tc-action"
+    hashed_request_payload = hashlib.sha256(payload.encode("utf-8")).hexdigest()
+    canonical_request = f"{http_request_method}\n{canonical_uri}\n{canonical_querystring}\n{canonical_headers}\n{signed_headers}\n{hashed_request_payload}"
+
+    # ************* Step 2: build the string to sign *************
+    credential_scope = f"{date}/{service}/tc3_request"
+    hashed_canonical_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()
+    string_to_sign = f"{algorithm}\n{timestamp}\n{credential_scope}\n{hashed_canonical_request}"
+
+    # ************* Step 3: compute the signature *************
+    # Key derivation chain: secret key -> date -> service -> "tc3_request".
+    secret_date = sign(("TC3" + secret_key).encode("utf-8"), date)
+    secret_service = sign(secret_date, service)
+    secret_signing = sign(secret_service, "tc3_request")
+    signature = hmac.new(secret_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
+
+    # ************* Step 4: assemble the Authorization value *************
+    authorization = f"{algorithm} Credential={secret_id}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}"
+
+    # ************* Step 5: build the headers *************
+    return {
+        "Authorization": authorization,
+        "Content-Type": "application/json; charset=utf-8",
+        "Host": host,
+        "X-TC-Action": action,
+        "X-TC-Timestamp": timestamp,
+        "X-TC-Version": version,
+    }
+
+
+async def single_sentence_asr(path: str | Path, engine: str, voice_format: str) -> dict:
+    """Tencent Single Sentence ASR (``SentenceRecognition`` action).
+
+    Reads the file at ``path``, base64-encodes its bytes and posts them in the
+    ``Data`` field of the request body together with ``SourceType`` 1.
+
+    NOTE(review): the link below is the same one used by ``flash_asr``; the
+    SentenceRecognition reference is presumably product/1093/35646 -- confirm.
+    https://cloud.tencent.com/document/product/1093/52097
+
+    Args:
+        path: Audio file to recognize.
+        engine: Sent as ``EngSerViceType``.
+        voice_format: Sent as ``VoiceFormat``.
+
+    Returns:
+        Whatever ``hx_req`` returns; it is called with ``check_keys=["Response.Result"]``.
+    """
+    import json
+
+    async with await anyio.open_file(path, "rb") as f:
+        content = await f.read()
+        data = base64.b64encode(content).decode("utf-8")
+    # Serialize with json.dumps so quotes or backslashes in an argument cannot
+    # produce an invalid body. Compact separators keep the text identical to
+    # the former hand-written form for plain values; the same string is signed
+    # and sent.
+    payload = json.dumps(
+        {"EngSerViceType": engine, "SourceType": 1, "VoiceFormat": voice_format, "Data": data},
+        ensure_ascii=False,
+        separators=(",", ":"),
+    )
+    headers = generate_tencent_cloud_headers(action="SentenceRecognition", payload=payload)
+    return await hx_req(
+        "https://asr.tencentcloudapi.com",
+        method="POST",
+        headers=headers,
+        post_content=payload.encode("utf-8"),
+        timeout=60,
+        proxy=PROXY.TENCENT,
+        check_keys=["Response.Result"],
+    )
+
+
+async def create_async_asr(url: str, engine: str) -> dict:
+    """Create Tencent ASR Task (``CreateRecTask`` action).
+
+    https://cloud.tencent.com/document/api/1093/37823
+
+    Args:
+        url: Sent as ``Url`` together with ``SourceType`` 0.
+        engine: Sent as ``EngineModelType``.
+
+    Returns:
+        Whatever ``hx_req`` returns; it is called with
+        ``check_keys=["Response.Data.TaskId"]``.
+    """
+    import json
+
+    # ``url`` ends with a file name the caller does not control, so it may hold
+    # quotes or backslashes. json.dumps escapes them, where raw interpolation
+    # would emit invalid JSON. Compact separators keep the text identical to
+    # the former hand-written form for plain values.
+    payload = json.dumps(
+        {"EngineModelType": engine, "ChannelNum": 1, "ResTextFormat": 0, "SourceType": 0, "Url": url},
+        ensure_ascii=False,
+        separators=(",", ":"),
+    )
+    headers = generate_tencent_cloud_headers(action="CreateRecTask", payload=payload)
+    return await hx_req(
+        "https://asr.tencentcloudapi.com",
+        method="POST",
+        headers=headers,
+        post_content=payload.encode("utf-8"),
+        timeout=600,
+        proxy=PROXY.TENCENT,
+        check_keys=["Response.Data.TaskId"],
+    )
+
+
+async def query_async_asr(task_id: int) -> dict:
+    """Query Tencent ASR Task (``DescribeTaskStatus`` action).
+
+    https://cloud.tencent.com/document/api/1093/37822
+
+    Posts ``task_id`` as ``TaskId`` and returns whatever ``hx_req`` returns;
+    ``hx_req`` is called with ``check_keys=["Response.Data.StatusStr"]``.
+    """
+    body = f'{{"TaskId":{task_id}}}'
+    # The signature covers the body text, so the same string is signed and sent.
+    request_kwargs = {
+        "method": "POST",
+        "headers": generate_tencent_cloud_headers(action="DescribeTaskStatus", payload=body),
+        "post_content": body.encode("utf-8"),
+        "timeout": 600,
+        "proxy": PROXY.TENCENT,
+        "check_keys": ["Response.Data.StatusStr"],
+    }
+    return await hx_req("https://asr.tencentcloudapi.com", **request_kwargs)
src/asr/voice_recognition.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import asyncio
 import contextlib
 import re
 from pathlib import Path
@@ -9,8 +10,8 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from asr.tecent_asr import flash_asr
-from config import ASR_MAX_DURATION, CAPTION_LENGTH, PREFIX
+from asr.tecent_asr import create_async_asr, flash_asr, query_async_asr, single_sentence_asr
+from config import ASR_MAX_DURATION, CAPTION_LENGTH, FILE_SERVER, PREFIX
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg, send_texts
@@ -143,17 +144,42 @@ async def voice_to_text(
         return
 
     # Skip very long
-    if trigger_info["duration"] > ASR_MAX_DURATION:
+    duration = trigger_info["duration"]
+    if duration > ASR_MAX_DURATION:
         msg = f"无法识别时长超过{ASR_MAX_DURATION}秒的音频, 当前音频时长: {trigger_info['duration']}秒"
         logger.error(msg)
         await modify_progress(text=msg, force_update=True, **kwargs)
         return
 
-    logger.debug(f"Recognizing {voice_format} audio by {asr_engine}: {path.as_posix()}")
+    logger.debug(f"Recognizing {voice_format} audio [{duration}s] by {asr_engine}: {path.as_posix()}")
     try:
-        resp = await flash_asr(path, asr_engine, voice_format)
-        texts = glom(resp, "flash_result.0.text") or "❌无法识别"
-        final = f"{BEGINNING}\n{texts}".replace("。", "。\n")
+        if duration < 60 and trigger_info["file_size"] < 3 * 1024 * 1024:  # 时长不能超过60s, 文件大小不能超过3MB
+            resp = await single_sentence_asr(path, asr_engine, voice_format)
+            texts = glom(resp, "Response.Result", default="❌无法识别").replace("。", "。\n")
+        elif 60 <= duration <= 300:
+            resp = await flash_asr(path, asr_engine, voice_format)
+            texts = glom(resp, "flash_result.0.text", default="❌无法识别").replace("。", "。\n")
+        elif FILE_SERVER:
+            resp = await create_async_asr(f"{FILE_SERVER}/{path.name}", asr_engine)
+            task_id = resp["Response"]["Data"]["TaskId"]
+            logger.success(task_id)
+            # task_id = 11982077075
+            result = await query_async_asr(task_id)
+            status = result["Response"]["Data"]["StatusStr"]
+            while status in ["waiting", "doing"]:
+                await asyncio.sleep(1)
+                logger.trace(f"Status:[{status}], Wating TaskID: {task_id}")
+                result = await query_async_asr(task_id)
+                status = result["Response"]["Data"]["StatusStr"]
+            if status == "success":
+                texts = glom(result, "Response.Data.Result", default="❌无法识别")
+                texts = re.sub(r"\[.*?\]\s*", "", texts)
+            else:
+                texts = glom(result, "Response.Data.ErrorMsg", default="❌无法识别")
+        else:
+            texts = "音频过长,请联系管理员配置`FILE_SERVER`变量"
+
+        final = f"{BEGINNING}\n{texts}"
         logger.success(f"{final!r}")
 
         # send results
src/llm/contexts.py
@@ -10,7 +10,7 @@ from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from config import GPT, PREFIX
+from config import FILE_SERVER, GPT, PREFIX
 from llm.utils import BOT_TIPS
 from messages.parser import parse_msg
 
@@ -106,9 +106,9 @@ async def single_context(client: Client, message: Message) -> dict:
                 path: str = await client.download_media(msg)  # type: ignore
                 logger.debug(f"Downloaded GPT media: {path}")
                 if info["mtype"] == "photo":
-                    contexts.append({"type": "image_url", "image_url": {"url": f"{GPT.MEDIA_SERVER}/{Path(path).name}"}})
+                    contexts.append({"type": "image_url", "image_url": {"url": f"{FILE_SERVER}/{Path(path).name}"}})
                 # elif info["mtype"] == "video":
-                #     media.append({"type": "video_url", "video_url": {"url": f"{GPT.MEDIA_SERVER}/{Path(path).name}"}})
+                #     media.append({"type": "video_url", "video_url": {"url": f"{FILE_SERVER}/{Path(path).name}"}})
                 elif info["mtype"] == "document" and info["mime_type"] in ["text/plain", "text/markdown"]:
                     contexts.append(
                         {
src/config.py
@@ -12,13 +12,14 @@ cache = Cache(ttl=0, maxsize=2048)
 semaphore = asyncio.Semaphore(8)  # max 8 concurrent downloads
 
 DOWNLOAD_DIR = os.getenv("DOWNLOAD_DIR", Path(__file__).parent.joinpath("downloads").as_posix())
+FILE_SERVER = os.getenv("FILE_SERVER", "")  # expose the download dir to internet (optional). for example: https://server.com/dir
 TZ = os.getenv("TZ", "Asia/Shanghai")
 DEVICE_NAME = os.getenv("DEVICE_NAME", "BennyBot")
 REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "60"))  # seconds
 TEXT_LENGTH = int(os.getenv("TEXT_LENGTH", "4096"))  # Maximum length of text message
 CAPTION_LENGTH = int(os.getenv("CAPTION_LENGTH", "1024"))  # 4096 for Premium user
 MAX_FILE_BYTES = int(os.getenv("MAX_FILE_BYTES", "2000")) * 1024 * 1024  # 4000 MB for Premium user
-ASR_MAX_DURATION = int(os.getenv("ASR_MAX_DURATION", "600"))
+ASR_MAX_DURATION = int(os.getenv("ASR_MAX_DURATION", "3600"))
 MAX_MESSAGE_RETRIEVED = int(os.getenv("MAX_MESSAGE_RETRIEVED", "1000000"))  # Maximum number of messages to retrieve
 MAX_MESSAGE_SUMMARY = int(os.getenv("MAX_MESSAGE_SUMMARY", "9999"))  # Maximum number of messages to summay
 READING_SPEED = int(os.getenv("READING_SPEED", "600"))  # words per minute
@@ -164,7 +165,6 @@ class GPT:  # see `llm/README.md`
     TEMPERATURE = os.getenv("GPT_TEMPERATURE", "1.0")
     HISTORY_CONTEXT = os.getenv("GPT_HISTORY_CONTEXT", "20")  # 最多携带多少条历史消息
     MEDIA_FORMAT = os.getenv("GPT_MEDIA_FORMAT", "base64")  # base64 or http
-    MEDIA_SERVER = os.getenv("GPT_MEDIA_SERVER", "https://server.com/dir")  # only when MEDIA_FORMAT is http
     TEXT_API_KEY = os.getenv("GPT_TEXT_API_KEY", "")
     TEXT_BASE_URL = os.getenv("GPT_TEXT_BASE_URL", "https://api.openai.com/v1")
     IMAGE_API_KEY = os.getenv("GPT_IMAGE_API_KEY", "")