Commit ffb7ae4
Changed files (3)
src/asr/tecent_asr.py
@@ -1,160 +1,36 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-
-# From: https://github.com/TencentCloud/tencentcloud-speech-sdk-python
import base64
import hashlib
import hmac
import time
+from pathlib import Path
-from loguru import logger
+import anyio
-from config import PROXY
+from config import PROXY, TOKEN
from networking import hx_req
-# 录音识别极速版
-class FlashRecognitionRequest:
- def __init__(self, engine_type):
- self.engine_type = engine_type
- self.speaker_diarization = 0
- self.hotword_id = ""
- self.hotword_list = ""
- self.input_sample_rate = 0
- self.customization_id = ""
- self.filter_dirty = 0
- self.filter_modal = 0
- self.filter_punc = 0
- self.convert_num_mode = 1
- self.word_info = 0
- self.voice_format = ""
- self.first_channel_only = 1
- self.reinforce_hotword = 0
- self.sentence_max_length = 0
-
- def set_first_channel_only(self, first_channel_only):
- self.first_channel_only = first_channel_only
-
- def set_speaker_diarization(self, speaker_diarization):
- self.speaker_diarization = speaker_diarization
-
- def set_filter_dirty(self, filter_dirty):
- self.filter_dirty = filter_dirty
-
- def set_filter_modal(self, filter_modal):
- self.filter_modal = filter_modal
-
- def set_filter_punc(self, filter_punc):
- self.filter_punc = filter_punc
-
- def set_convert_num_mode(self, convert_num_mode):
- self.convert_num_mode = convert_num_mode
-
- def set_word_info(self, word_info):
- self.word_info = word_info
-
- def set_hotword_id(self, hotword_id):
- self.hotword_id = hotword_id
-
- def set_hotword_list(self, hotword_list):
- self.hotword_list = hotword_list
-
- def set_input_sample_rate(self, input_sample_rate):
- self.input_sample_rate = input_sample_rate
-
- def set_customization_id(self, customization_id):
- self.customization_id = customization_id
-
- def set_voice_format(self, voice_format):
- self.voice_format = voice_format
-
- def set_sentence_max_length(self, sentence_max_length):
- self.sentence_max_length = sentence_max_length
-
- def set_reinforce_hotword(self, reinforce_hotword):
- self.reinforce_hotword = reinforce_hotword
-
-
-class FlashRecognizer:
- def __init__(self, appid, credential):
- self.credential = credential
- self.appid = appid
-
- def _format_sign_string(self, param):
- signstr = "POSTasr.cloud.tencent.com/asr/flash/v1/"
- for t in param:
- if "appid" in t:
- signstr += str(t[1])
- break
- signstr += "?"
- for x in param:
- tmp = x
- if "appid" in x:
- continue
- for t in tmp:
- signstr += str(t)
- signstr += "="
- signstr = signstr[:-1]
- signstr += "&"
- return signstr[:-1]
-
- def _build_header(self):
- header = {}
- header["Host"] = "asr.cloud.tencent.com"
- return header
-
- def _sign(self, signstr, secret_key):
- hmacstr = hmac.new(secret_key.encode("utf-8"), signstr.encode("utf-8"), hashlib.sha1).digest()
- s = base64.b64encode(hmacstr)
- return s.decode("utf-8")
-
- def _build_req_with_signature(self, secret_key, params, header):
- query = sorted(params.items(), key=lambda d: d[0])
- signstr = self._format_sign_string(query)
- signature = self._sign(signstr, secret_key)
- header["authorization"] = signature
- requrl = "https://"
- requrl += signstr[4::]
- return requrl
-
- def _create_query_arr(self, req):
- query_arr = {}
- query_arr["appid"] = self.appid
- query_arr["secretid"] = self.credential.secret_id
- query_arr["timestamp"] = str(int(time.time()))
- query_arr["engine_type"] = req.engine_type
- query_arr["voice_format"] = req.voice_format
- query_arr["speaker_diarization"] = req.speaker_diarization
- if req.hotword_id != "":
- query_arr["hotword_id"] = req.hotword_id
- if req.hotword_list != "":
- query_arr["hotword_list"] = req.hotword_list
- if req.input_sample_rate != 0:
- query_arr["input_sample_rate"] = req.input_sample_rate
- query_arr["customization_id"] = req.customization_id
- query_arr["filter_dirty"] = req.filter_dirty
- query_arr["filter_modal"] = req.filter_modal
- query_arr["filter_punc"] = req.filter_punc
- query_arr["convert_num_mode"] = req.convert_num_mode
- query_arr["word_info"] = req.word_info
- query_arr["first_channel_only"] = req.first_channel_only
- query_arr["reinforce_hotword"] = req.reinforce_hotword
- query_arr["sentence_max_length"] = req.sentence_max_length
- return query_arr
-
- async def recognize(self, req, data) -> dict:
- header = self._build_header()
- query_arr = self._create_query_arr(req)
- req_url = self._build_req_with_signature(self.credential.secret_key, query_arr, header)
- resp = await hx_req(req_url, method="POST", headers=header, post_data=data, timeout=30, proxy=PROXY.TENCENT, check_kv={"code": 0})
- if resp.get("hx_error"):
- logger.error(f"ASR failed: {resp.get('hx_error')}")
- return {}
- return resp
-
-
-class Credential:
- def __init__(self, secret_id, secret_key, token=""):
- self.secret_id = secret_id
- self.secret_key = secret_key
- self.token = token
+async def flash_asr(path: str | Path, engine: str, voice_format: str):
+ """Tencent Flash ASR.
+
+ https://cloud.tencent.com/document/product/1093/52097
+ """
+ params = {
+ "secretid": TOKEN.TENCENT_ASR_SECRET_ID,
+ "engine_type": engine,
+ "voice_format": voice_format,
+ "timestamp": str(int(time.time())),
+ }
+ signstr = f"POSTasr.cloud.tencent.com/asr/flash/v1/{TOKEN.TENCENT_ASR_APPID}?"
+ for k, v in dict(sorted(params.items())).items():
+ signstr += f"{k}={v}&"
+ signstr = signstr[:-1] # strip last "&"
+
+ hmacstr = hmac.new(TOKEN.TENCENT_ASR_SECRET_KEY.encode("utf-8"), signstr.encode("utf-8"), hashlib.sha1).digest()
+ signature = base64.b64encode(hmacstr).decode("utf-8")
+ headers = {"Host": "asr.cloud.tencent.com", "authorization": signature}
+ url = f"https://{signstr.removeprefix('POST')}"
+ async with await anyio.open_file(path, "rb") as f:
+ return await hx_req(url, method="POST", headers=headers, post_content=await f.read(), timeout=60, proxy=PROXY.TENCENT, check_kv={"code": 0}, check_keys=["flash_result"])
src/asr/voice_recognition.py
@@ -4,12 +4,13 @@ import contextlib
import re
from pathlib import Path
+from glom import glom
from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
-from asr.tecent_asr import Credential, FlashRecognitionRequest, FlashRecognizer
-from config import ASR_MAX_DURATION, CAPTION_LENGTH, PREFIX, TOKEN
+from asr.tecent_asr import flash_asr
+from config import ASR_MAX_DURATION, CAPTION_LENGTH, PREFIX
from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg, send_texts
@@ -149,23 +150,10 @@ async def voice_to_text(
return
logger.debug(f"Recognizing {voice_format} audio by {asr_engine}: {path.as_posix()}")
- credential_var = Credential(TOKEN.TENCENT_ASR_SECRET_ID, TOKEN.TENCENT_ASR_SECRET_KEY)
- recognizer = FlashRecognizer(TOKEN.TENCENT_ASR_APPID, credential_var)
- req = FlashRecognitionRequest(engine_type=asr_engine)
- req.set_voice_format(voice_format)
- final = ""
try:
- with path.open("rb") as f:
- resp = await recognizer.recognize(req, f.read())
- logger.trace(resp)
- texts = [channel.get("text", "") for channel in resp.get("flash_result", [])]
- if len(set(texts)) == 1: # single channel
- final = texts[0]
- else:
- for cid, text in enumerate(texts):
- final += f"通道{cid + 1}: {text}\n"
- if final:
- final = f"{BEGINNING}\n{final}".replace("。", "。\n")
+ resp = await flash_asr(path, asr_engine, voice_format)
+ texts = glom(resp, "flash_result.0.text") or "❌无法识别"
+ final = f"{BEGINNING}\n{texts}".replace("。", "。\n")
logger.success(f"{final!r}")
# send results
src/networking.py
@@ -10,6 +10,7 @@ from typing import Any
from urllib.parse import parse_qs, quote_plus, urlparse
import anyio
+import httpx
from httpx import AsyncClient, HTTPStatusError, Request, RequestError, Response
from httpx_curl_cffi import AsyncCurlTransport, CurlOpt
from loguru import logger
@@ -39,7 +40,7 @@ async def hx_req(
cookies: dict | None = None,
params: dict | None = None,
post_json: dict | None = None,
- post_data: dict | None = None,
+ post_content: httpx._types.RequestContent | None = None,
proxy: str | None = None,
follow_redirects: bool = True,
check_keys: list[str] | None = None,
@@ -60,7 +61,7 @@ async def hx_req(
cookies (dict, optional): The cookies to use for the request.
params (dict, optional): The parameters to use for the request.
post_json (dict, optional): The JSON data to use for the request.
- post_data (dict, optional): The form data to use for the request.
+ post_content (dict, optional): The form data to use for the request.
proxy (str, optional): The proxy to use for the request.
follow_redirects (bool, optional): Whether to follow redirects.
check_keys (list[str], optional): The keys to check in the response.
@@ -94,7 +95,7 @@ async def hx_req(
if method == "GET":
response = await client.get(url, cookies=cookies, headers=headers, params=params)
else:
- response = await client.post(url, cookies=cookies, headers=headers, json=post_json, data=post_data, params=params)
+ response = await client.post(url, cookies=cookies, headers=headers, json=post_json, content=post_content, params=params)
response.raise_for_status()
data = response.text
check_data(data, check_keys=check_keys, check_kv=check_kv)