Commit ffb7ae4

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-04-21 04:44:38
style(asr): improve asr code style and refactor
1 parent 205c40d
src/asr/tecent_asr.py
@@ -1,160 +1,36 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-
-# From: https://github.com/TencentCloud/tencentcloud-speech-sdk-python
 import base64
 import hashlib
 import hmac
 import time
+from pathlib import Path
 
-from loguru import logger
+import anyio
 
-from config import PROXY
+from config import PROXY, TOKEN
 from networking import hx_req
 
 
-# 录音识别极速版
-class FlashRecognitionRequest:
-    def __init__(self, engine_type):
-        self.engine_type = engine_type
-        self.speaker_diarization = 0
-        self.hotword_id = ""
-        self.hotword_list = ""
-        self.input_sample_rate = 0
-        self.customization_id = ""
-        self.filter_dirty = 0
-        self.filter_modal = 0
-        self.filter_punc = 0
-        self.convert_num_mode = 1
-        self.word_info = 0
-        self.voice_format = ""
-        self.first_channel_only = 1
-        self.reinforce_hotword = 0
-        self.sentence_max_length = 0
-
-    def set_first_channel_only(self, first_channel_only):
-        self.first_channel_only = first_channel_only
-
-    def set_speaker_diarization(self, speaker_diarization):
-        self.speaker_diarization = speaker_diarization
-
-    def set_filter_dirty(self, filter_dirty):
-        self.filter_dirty = filter_dirty
-
-    def set_filter_modal(self, filter_modal):
-        self.filter_modal = filter_modal
-
-    def set_filter_punc(self, filter_punc):
-        self.filter_punc = filter_punc
-
-    def set_convert_num_mode(self, convert_num_mode):
-        self.convert_num_mode = convert_num_mode
-
-    def set_word_info(self, word_info):
-        self.word_info = word_info
-
-    def set_hotword_id(self, hotword_id):
-        self.hotword_id = hotword_id
-
-    def set_hotword_list(self, hotword_list):
-        self.hotword_list = hotword_list
-
-    def set_input_sample_rate(self, input_sample_rate):
-        self.input_sample_rate = input_sample_rate
-
-    def set_customization_id(self, customization_id):
-        self.customization_id = customization_id
-
-    def set_voice_format(self, voice_format):
-        self.voice_format = voice_format
-
-    def set_sentence_max_length(self, sentence_max_length):
-        self.sentence_max_length = sentence_max_length
-
-    def set_reinforce_hotword(self, reinforce_hotword):
-        self.reinforce_hotword = reinforce_hotword
-
-
-class FlashRecognizer:
-    def __init__(self, appid, credential):
-        self.credential = credential
-        self.appid = appid
-
-    def _format_sign_string(self, param):
-        signstr = "POSTasr.cloud.tencent.com/asr/flash/v1/"
-        for t in param:
-            if "appid" in t:
-                signstr += str(t[1])
-                break
-        signstr += "?"
-        for x in param:
-            tmp = x
-            if "appid" in x:
-                continue
-            for t in tmp:
-                signstr += str(t)
-                signstr += "="
-            signstr = signstr[:-1]
-            signstr += "&"
-        return signstr[:-1]
-
-    def _build_header(self):
-        header = {}
-        header["Host"] = "asr.cloud.tencent.com"
-        return header
-
-    def _sign(self, signstr, secret_key):
-        hmacstr = hmac.new(secret_key.encode("utf-8"), signstr.encode("utf-8"), hashlib.sha1).digest()
-        s = base64.b64encode(hmacstr)
-        return s.decode("utf-8")
-
-    def _build_req_with_signature(self, secret_key, params, header):
-        query = sorted(params.items(), key=lambda d: d[0])
-        signstr = self._format_sign_string(query)
-        signature = self._sign(signstr, secret_key)
-        header["authorization"] = signature
-        requrl = "https://"
-        requrl += signstr[4::]
-        return requrl
-
-    def _create_query_arr(self, req):
-        query_arr = {}
-        query_arr["appid"] = self.appid
-        query_arr["secretid"] = self.credential.secret_id
-        query_arr["timestamp"] = str(int(time.time()))
-        query_arr["engine_type"] = req.engine_type
-        query_arr["voice_format"] = req.voice_format
-        query_arr["speaker_diarization"] = req.speaker_diarization
-        if req.hotword_id != "":
-            query_arr["hotword_id"] = req.hotword_id
-        if req.hotword_list != "":
-            query_arr["hotword_list"] = req.hotword_list
-        if req.input_sample_rate != 0:
-            query_arr["input_sample_rate"] = req.input_sample_rate
-        query_arr["customization_id"] = req.customization_id
-        query_arr["filter_dirty"] = req.filter_dirty
-        query_arr["filter_modal"] = req.filter_modal
-        query_arr["filter_punc"] = req.filter_punc
-        query_arr["convert_num_mode"] = req.convert_num_mode
-        query_arr["word_info"] = req.word_info
-        query_arr["first_channel_only"] = req.first_channel_only
-        query_arr["reinforce_hotword"] = req.reinforce_hotword
-        query_arr["sentence_max_length"] = req.sentence_max_length
-        return query_arr
-
-    async def recognize(self, req, data) -> dict:
-        header = self._build_header()
-        query_arr = self._create_query_arr(req)
-        req_url = self._build_req_with_signature(self.credential.secret_key, query_arr, header)
-        resp = await hx_req(req_url, method="POST", headers=header, post_data=data, timeout=30, proxy=PROXY.TENCENT, check_kv={"code": 0})
-        if resp.get("hx_error"):
-            logger.error(f"ASR failed: {resp.get('hx_error')}")
-            return {}
-        return resp
-
-
-class Credential:
-    def __init__(self, secret_id, secret_key, token=""):
-        self.secret_id = secret_id
-        self.secret_key = secret_key
-        self.token = token
+async def flash_asr(path: str | Path, engine: str, voice_format: str):
+    """Transcribe an audio file with the Tencent Cloud Flash ASR endpoint.
+
+    Signs the sorted query string with HMAC-SHA1 and POSTs the raw file bytes.
+    API reference: https://cloud.tencent.com/document/product/1093/52097
+    """
+    host = "asr.cloud.tencent.com"
+    query = {
+        "secretid": TOKEN.TENCENT_ASR_SECRET_ID,
+        "engine_type": engine,
+        "voice_format": voice_format,
+        "timestamp": str(int(time.time())),
+    }
+    # The signed text is "POST" + host + path + "?" + params in key order.
+    pairs = "&".join(f"{key}={value}" for key, value in sorted(query.items()))
+    endpoint = f"{host}/asr/flash/v1/{TOKEN.TENCENT_ASR_APPID}?{pairs}"
+    secret = TOKEN.TENCENT_ASR_SECRET_KEY.encode("utf-8")
+    digest = hmac.new(secret, f"POST{endpoint}".encode("utf-8"), hashlib.sha1).digest()
+    headers = {"Host": host, "authorization": base64.b64encode(digest).decode("utf-8")}
+    async with await anyio.open_file(path, "rb") as audio:
+        payload = await audio.read()
+        return await hx_req(f"https://{endpoint}", method="POST", headers=headers, post_content=payload, timeout=60, proxy=PROXY.TENCENT, check_kv={"code": 0}, check_keys=["flash_result"])
src/asr/voice_recognition.py
@@ -4,12 +4,13 @@ import contextlib
 import re
 from pathlib import Path
 
+from glom import glom
 from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
 
-from asr.tecent_asr import Credential, FlashRecognitionRequest, FlashRecognizer
-from config import ASR_MAX_DURATION, CAPTION_LENGTH, PREFIX, TOKEN
+from asr.tecent_asr import flash_asr
+from config import ASR_MAX_DURATION, CAPTION_LENGTH, PREFIX
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg, send_texts
@@ -149,23 +150,10 @@ async def voice_to_text(
         return
 
     logger.debug(f"Recognizing {voice_format} audio by {asr_engine}: {path.as_posix()}")
-    credential_var = Credential(TOKEN.TENCENT_ASR_SECRET_ID, TOKEN.TENCENT_ASR_SECRET_KEY)
-    recognizer = FlashRecognizer(TOKEN.TENCENT_ASR_APPID, credential_var)
-    req = FlashRecognitionRequest(engine_type=asr_engine)
-    req.set_voice_format(voice_format)
-    final = ""
     try:
-        with path.open("rb") as f:
-            resp = await recognizer.recognize(req, f.read())
-            logger.trace(resp)
-        texts = [channel.get("text", "") for channel in resp.get("flash_result", [])]
-        if len(set(texts)) == 1:  # single channel
-            final = texts[0]
-        else:
-            for cid, text in enumerate(texts):
-                final += f"通道{cid + 1}: {text}\n"
-        if final:
-            final = f"{BEGINNING}\n{final}".replace("。", "。\n")
+        resp = await flash_asr(path, asr_engine, voice_format)
+        texts = glom(resp, "flash_result.0.text") or "❌无法识别"
+        final = f"{BEGINNING}\n{texts}".replace("。", "。\n")
         logger.success(f"{final!r}")
 
         # send results
src/networking.py
@@ -10,6 +10,7 @@ from typing import Any
 from urllib.parse import parse_qs, quote_plus, urlparse
 
 import anyio
+import httpx
 from httpx import AsyncClient, HTTPStatusError, Request, RequestError, Response
 from httpx_curl_cffi import AsyncCurlTransport, CurlOpt
 from loguru import logger
@@ -39,7 +40,7 @@ async def hx_req(
     cookies: dict | None = None,
     params: dict | None = None,
     post_json: dict | None = None,
-    post_data: dict | None = None,
+    post_content: httpx._types.RequestContent | None = None,
     proxy: str | None = None,
     follow_redirects: bool = True,
     check_keys: list[str] | None = None,
@@ -60,7 +61,7 @@ async def hx_req(
         cookies (dict, optional): The cookies to use for the request.
         params (dict, optional): The parameters to use for the request.
         post_json (dict, optional): The JSON data to use for the request.
-        post_data (dict, optional): The form data to use for the request.
+        post_content (RequestContent, optional): The raw request body (e.g. bytes) to send as-is.
         proxy (str, optional): The proxy to use for the request.
         follow_redirects (bool, optional): Whether to follow redirects.
         check_keys (list[str], optional): The keys to check in the response.
@@ -94,7 +95,7 @@ async def hx_req(
             if method == "GET":
                 response = await client.get(url, cookies=cookies, headers=headers, params=params)
             else:
-                response = await client.post(url, cookies=cookies, headers=headers, json=post_json, data=post_data, params=params)
+                response = await client.post(url, cookies=cookies, headers=headers, json=post_json, content=post_content, params=params)
             response.raise_for_status()
             data = response.text
             check_data(data, check_keys=check_keys, check_kv=check_kv)