Commit 6b8dc05
Changed files (3)
src/asr/tecent_asr.py
@@ -1,11 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import asyncio
import base64
import hashlib
import hmac
from pathlib import Path
import anyio
+from glom import Coalesce, flatten, glom
+from loguru import logger
from config import ASR
from networking import hx_req
@@ -16,32 +19,6 @@ def sign(key, msg):
return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
-async def flash_asr(path: str | Path, engine: str, voice_format: str) -> dict:
- """Tencent Flash ASR.
-
- 录音文件识别极速版
- https://cloud.tencent.com/document/product/1093/52097
- """
- now = nowdt()
- params = {
- "secretid": ASR.TENCENT_SECRET_ID,
- "engine_type": engine,
- "voice_format": voice_format,
- "timestamp": str(int(now.timestamp())),
- }
- signstr = f"POSTasr.cloud.tencent.com/asr/flash/v1/{ASR.TENCENT_APPID}?"
- for k, v in dict(sorted(params.items())).items(): # type: ignore
- signstr += f"{k}={v}&"
- signstr = signstr[:-1] # strip last "&"
-
- hmacstr = hmac.new(ASR.TENCENT_SECRET_KEY.encode("utf-8"), signstr.encode("utf-8"), hashlib.sha1).digest()
- signature = base64.b64encode(hmacstr).decode("utf-8")
- headers = {"Host": "asr.cloud.tencent.com", "authorization": signature}
- url = f"https://{signstr.removeprefix('POST')}"
- async with await anyio.open_file(path, "rb") as f:
- return await hx_req(url, method="POST", headers=headers, post_content=await f.read(), timeout=60, proxy=ASR.TENCENT_PROXY, check_kv={"code": 0}, check_keys=["flash_result"])
-
-
def generate_tencent_cloud_headers(
action: str,
payload: str,
@@ -91,7 +68,7 @@ def generate_tencent_cloud_headers(
}
-async def single_sentence_asr(path: str | Path, engine: str, voice_format: str) -> dict:
+async def single_sentence_asr(path: str | Path, engine: str, voice_format: str) -> str:
"""Tencent Single Sentence ASR.
一句话识别
@@ -100,17 +77,57 @@ async def single_sentence_asr(path: str | Path, engine: str, voice_format: str)
async with await anyio.open_file(path, "rb") as f:
content = await f.read()
data = base64.b64encode(content).decode("utf-8")
- payload = f'{{"EngSerViceType":"{engine}","SourceType":1,"VoiceFormat":"{voice_format}","Data":"{data}"}}'
+ payload = f'{{"EngSerViceType":"{engine}","SourceType":1,"WordInfo":2,"VoiceFormat":"{voice_format}","Data":"{data}"}}'
headers = generate_tencent_cloud_headers(action="SentenceRecognition", payload=payload)
- return await hx_req(
+ res = await hx_req(
"https://asr.tencentcloudapi.com",
method="POST",
headers=headers,
post_content=payload.encode("utf-8"),
timeout=60,
proxy=ASR.TENCENT_PROXY,
- check_keys=["Response.Result"],
+ check_keys=["Response.WordList"],
)
+ return generate_tencent_transcription(sentence_start_ms=[0], words=[res["Response"]["WordList"]])
+
+
+async def flash_asr(path: str | Path, engine: str, voice_format: str) -> str:
+ """Tencent Flash ASR.
+
+ 录音文件识别极速版
+ https://cloud.tencent.com/document/product/1093/52097
+ """
+ now = nowdt()
+ params = {
+ "secretid": ASR.TENCENT_SECRET_ID,
+ "engine_type": engine,
+ "voice_format": voice_format,
+ "timestamp": str(int(now.timestamp())),
+ "word_info": 2,
+ }
+ signstr = f"POSTasr.cloud.tencent.com/asr/flash/v1/{ASR.TENCENT_APPID}?"
+ for k, v in dict(sorted(params.items())).items(): # type: ignore
+ signstr += f"{k}={v}&"
+ signstr = signstr[:-1] # strip last "&"
+
+ hmacstr = hmac.new(ASR.TENCENT_SECRET_KEY.encode("utf-8"), signstr.encode("utf-8"), hashlib.sha1).digest()
+ signature = base64.b64encode(hmacstr).decode("utf-8")
+ headers = {"Host": "asr.cloud.tencent.com", "authorization": signature}
+ url = f"https://{signstr.removeprefix('POST')}"
+ async with await anyio.open_file(path, "rb") as f:
+ res = await hx_req(
+ url,
+ method="POST",
+ headers=headers,
+ post_content=await f.read(),
+ timeout=60,
+ proxy=ASR.TENCENT_PROXY,
+ check_kv={"code": 0},
+ check_keys=["flash_result.0.sentence_list.0.word_list"],
+ )
+ sentence_start_ms = flatten(glom(res, "flash_result.*.sentence_list.*.start_time"), levels=1)
+ words = flatten(glom(res, "flash_result.*.sentence_list.*.word_list"), levels=1)
+ return generate_tencent_transcription(sentence_start_ms, words)
async def create_async_asr(url: str, engine: str) -> dict:
@@ -119,7 +136,7 @@ async def create_async_asr(url: str, engine: str) -> dict:
录音文件识别请求
https://cloud.tencent.com/document/api/1093/37823
"""
- payload = f'{{"EngineModelType":"{engine}","ChannelNum":1,"ResTextFormat":0,"SourceType":0,"Url":"{url}"}}'
+ payload = f'{{"EngineModelType":"{engine}","ChannelNum":1,"ResTextFormat":2,"SourceType":0,"Url":"{url}"}}'
headers = generate_tencent_cloud_headers(action="CreateRecTask", payload=payload)
return await hx_req(
"https://asr.tencentcloudapi.com",
@@ -132,7 +149,7 @@ async def create_async_asr(url: str, engine: str) -> dict:
)
-async def query_async_asr(task_id: int) -> dict:
+async def query_async_asr(task_id: int) -> str:
"""Query Tencent ASR Task.
录音文件识别结果查询
@@ -140,7 +157,7 @@ async def query_async_asr(task_id: int) -> dict:
"""
payload = f'{{"TaskId":{task_id}}}'
headers = generate_tencent_cloud_headers(action="DescribeTaskStatus", payload=payload)
- return await hx_req(
+ result = await hx_req(
"https://asr.tencentcloudapi.com",
method="POST",
headers=headers,
@@ -149,3 +166,37 @@ async def query_async_asr(task_id: int) -> dict:
proxy=ASR.TENCENT_PROXY,
check_keys=["Response.Data.StatusStr"],
)
+ status = glom(result, "Response.Data.StatusStr")
+ while status in ["waiting", "doing"]:
+ await asyncio.sleep(1)
+ logger.trace(f"Status:[{status}], Wating TaskID: {task_id}")
+ result = await query_async_asr(task_id)
+ if isinstance(result, str):
+ return result
+ status = glom(result, "Response.Data.StatusStr")
+ if status == "success":
+ sentence_start_ms = glom(result, "Response.Data.ResultDetail.*.StartMs")
+ words = glom(result, "Response.Data.ResultDetail.*.Words")
+ return generate_tencent_transcription(sentence_start_ms, words)
+ return glom(result, "Response.Data.ErrorMsg")
+
+
+def generate_tencent_transcription(sentence_start_ms: list[int], words: list[list[dict]]) -> str:
+ res = ""
+ show_timestamp = False
+ for start_offset, items in zip(sentence_start_ms, words, strict=True):
+ for idx, item in enumerate(items):
+ sentence = glom(item, Coalesce("Word", "word"), default="")
+ if not sentence:
+ continue
+ if idx == 0 or res.endswith((".", "。")):
+ show_timestamp = True
+ if show_timestamp:
+ start_seconds = float(glom(item, Coalesce("StartTime", "OffsetStartMs", "start_time"), default=0) + float(start_offset)) // 1000
+ minutes = int(start_seconds // 60)
+ seconds = int(start_seconds % 60)
+ res += f"\n[{minutes:02d}:{seconds:02d}] {sentence}"
+ show_timestamp = False
+ else:
+ res += sentence
+ return res.strip()
src/asr/utils.py
@@ -13,10 +13,11 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
asr_engine = ASR.MIDDLE_ENGINE
else:
asr_engine = ASR.LONG_ENGINE
-
- if asr_engine == "tencent" or force_engine == "tencent":
+ if force_engine:
+ return get_tencent_asr_method(duration, file_size) if force_engine == "tencent" else get_gemini_asr_method(duration)
+ if asr_engine == "tencent":
return get_tencent_asr_method(duration, file_size)
- if asr_engine.lower() == "gemini" or force_engine == "gemini":
+ if asr_engine.lower() == "gemini":
return get_gemini_asr_method(duration)
return f"ASR Engine: {asr_engine} is not support for duration: {duration}, filesize: {file_size}", []
src/asr/voice_recognition.py
@@ -1,12 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-import asyncio
import contextlib
import io
import re
from pathlib import Path
-from glom import glom
from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
@@ -136,7 +134,7 @@ async def voice_to_text(
await modify_progress(text=error, force_update=True, **kwargs)
return
if texts := res.get("texts"):
- final = f"{BEGINNING}\n{blockquote(texts)}"
+ final = blockquote(texts) if len(texts) > 300 else texts
logger.success(f"{final!r}")
# send results
target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else this_info["cid"]
@@ -201,28 +199,14 @@ async def asr_file(
logger.debug(f"Recognizing {voice_format} audio [{duration}s] by {language}: {path.as_posix()}")
try:
if asr_method == "single_sentence_asr":
- resp = await single_sentence_asr(path, language, voice_format)
- texts = glom(resp, "Response.Result").replace("。", "。\n")
+ texts = await single_sentence_asr(path, language, voice_format)
elif asr_method == "flash_asr":
- resp = await flash_asr(path, language, voice_format)
- texts = glom(resp, "flash_result.0.text").replace("。", "。\n")
+ texts = await flash_asr(path, language, voice_format)
elif asr_method == "async_asr":
resp = await create_async_asr(f"{FILE_SERVER}/{path.name}", language)
task_id = resp["Response"]["Data"]["TaskId"]
logger.success(f"ASR任务提交成功, TaskID: {task_id}")
- result = await query_async_asr(task_id)
- status = result["Response"]["Data"]["StatusStr"]
- while status in ["waiting", "doing"]:
- await asyncio.sleep(1)
- logger.trace(f"Status:[{status}], Wating TaskID: {task_id}")
- result = await query_async_asr(task_id)
- status = result["Response"]["Data"]["StatusStr"]
- if status == "success":
- texts = glom(result, "Response.Data.Result")
- texts = re.sub(r"\[.*?\]\s*", "", texts)
- else:
- texts = glom(result, "Response.Data.ErrorMsg")
- res["error"] = texts
+ texts = await query_async_asr(task_id)
elif asr_method == "gemini":
return await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
res["texts"] = texts