Commit 87890e0
Changed files (4)
src/asr/tecent_asr.py
@@ -3,28 +3,34 @@
import base64
import hashlib
import hmac
-import time
from pathlib import Path
import anyio
from config import PROXY, TOKEN
from networking import hx_req
+from utils import nowdt
-async def flash_asr(path: str | Path, engine: str, voice_format: str):
+def sign(key, msg):
+ return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
+
+
+async def flash_asr(path: str | Path, engine: str, voice_format: str) -> dict:
"""Tencent Flash ASR.
+ 录音文件识别极速版
https://cloud.tencent.com/document/product/1093/52097
"""
+ now = nowdt()
params = {
"secretid": TOKEN.TENCENT_ASR_SECRET_ID,
"engine_type": engine,
"voice_format": voice_format,
- "timestamp": str(int(time.time())),
+ "timestamp": str(int(now.timestamp())),
}
signstr = f"POSTasr.cloud.tencent.com/asr/flash/v1/{TOKEN.TENCENT_ASR_APPID}?"
- for k, v in dict(sorted(params.items())).items():
+ for k, v in dict(sorted(params.items())).items(): # type: ignore
signstr += f"{k}={v}&"
signstr = signstr[:-1] # strip last "&"
@@ -34,3 +40,112 @@ async def flash_asr(path: str | Path, engine: str, voice_format: str):
url = f"https://{signstr.removeprefix('POST')}"
async with await anyio.open_file(path, "rb") as f:
return await hx_req(url, method="POST", headers=headers, post_content=await f.read(), timeout=60, proxy=PROXY.TENCENT, check_kv={"code": 0}, check_keys=["flash_result"])
+
+
+def generate_tencent_cloud_headers(
+ action: str,
+ payload: str,
+ service: str = "asr",
+ host: str = "asr.tencentcloudapi.com",
+ version: str = "2019-06-14",
+ secret_id: str = TOKEN.TENCENT_ASR_SECRET_ID,
+ secret_key: str = TOKEN.TENCENT_ASR_SECRET_KEY,
+) -> dict:
+ """Generate TencentCloudAPI Headers (TC3-HMAC-SHA256)."""
+ algorithm = "TC3-HMAC-SHA256"
+ now = nowdt()
+ timestamp = str(int(now.timestamp()))
+ date = f"{now:%Y-%m-%d}"
+
+ # ************* 步骤 1: 拼接规范请求串 *************
+ http_request_method = "POST"
+ canonical_uri = "/"
+ canonical_querystring = ""
+ canonical_headers = f"content-type:application/json; charset=utf-8\nhost:{host}\nx-tc-action:{action.lower()}\n"
+ signed_headers = "content-type;host;x-tc-action"
+ hashed_request_payload = hashlib.sha256(payload.encode("utf-8")).hexdigest()
+ canonical_request = f"{http_request_method}\n{canonical_uri}\n{canonical_querystring}\n{canonical_headers}\n{signed_headers}\n{hashed_request_payload}"
+
+ # ************* 步骤 2: 拼接待签名字符串 *************
+ credential_scope = f"{date}/{service}/tc3_request"
+ hashed_canonical_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()
+ string_to_sign = f"{algorithm}\n{timestamp}\n{credential_scope}\n{hashed_canonical_request}"
+
+ # ************* 步骤 3: 计算签名 *************
+ secret_date = sign(("TC3" + secret_key).encode("utf-8"), date)
+ secret_service = sign(secret_date, service)
+ secret_signing = sign(secret_service, "tc3_request")
+ signature = hmac.new(secret_signing, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()
+
+ # ************* 步骤 4: 拼接 Authorization *************
+ authorization = f"{algorithm} Credential={secret_id}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}"
+
+ # ************* 步骤 5: 构造 Headers *************
+ return {
+ "Authorization": authorization,
+ "Content-Type": "application/json; charset=utf-8",
+ "Host": host,
+ "X-TC-Action": action,
+ "X-TC-Timestamp": timestamp,
+ "X-TC-Version": version,
+ }
+
+
+async def single_sentence_asr(path: str | Path, engine: str, voice_format: str) -> dict:
+ """Tencent Single Sentence ASR.
+
+ 一句话识别
+    https://cloud.tencent.com/document/api/1093/35646
+ """
+ async with await anyio.open_file(path, "rb") as f:
+ content = await f.read()
+ data = base64.b64encode(content).decode("utf-8")
+ payload = f'{{"EngSerViceType":"{engine}","SourceType":1,"VoiceFormat":"{voice_format}","Data":"{data}"}}'
+ headers = generate_tencent_cloud_headers(action="SentenceRecognition", payload=payload)
+ return await hx_req(
+ "https://asr.tencentcloudapi.com",
+ method="POST",
+ headers=headers,
+ post_content=payload.encode("utf-8"),
+ timeout=60,
+ proxy=PROXY.TENCENT,
+ check_keys=["Response.Result"],
+ )
+
+
+async def create_async_asr(url: str, engine: str) -> dict:
+ """Create Tencent ASR Task.
+
+ 录音文件识别请求
+ https://cloud.tencent.com/document/api/1093/37823
+ """
+ payload = f'{{"EngineModelType":"{engine}","ChannelNum":1,"ResTextFormat":0,"SourceType":0,"Url":"{url}"}}'
+ headers = generate_tencent_cloud_headers(action="CreateRecTask", payload=payload)
+ return await hx_req(
+ "https://asr.tencentcloudapi.com",
+ method="POST",
+ headers=headers,
+ post_content=payload.encode("utf-8"),
+ timeout=600,
+ proxy=PROXY.TENCENT,
+ check_keys=["Response.Data.TaskId"],
+ )
+
+
+async def query_async_asr(task_id: int) -> dict:
+ """Query Tencent ASR Task.
+
+ 录音文件识别结果查询
+ https://cloud.tencent.com/document/api/1093/37822
+ """
+ payload = f'{{"TaskId":{task_id}}}'
+ headers = generate_tencent_cloud_headers(action="DescribeTaskStatus", payload=payload)
+ return await hx_req(
+ "https://asr.tencentcloudapi.com",
+ method="POST",
+ headers=headers,
+ post_content=payload.encode("utf-8"),
+ timeout=600,
+ proxy=PROXY.TENCENT,
+ check_keys=["Response.Data.StatusStr"],
+ )
src/asr/voice_recognition.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import asyncio
import contextlib
import re
from pathlib import Path
@@ -9,8 +10,8 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
-from asr.tecent_asr import flash_asr
-from config import ASR_MAX_DURATION, CAPTION_LENGTH, PREFIX
+from asr.tecent_asr import create_async_asr, flash_asr, query_async_asr, single_sentence_asr
+from config import ASR_MAX_DURATION, CAPTION_LENGTH, FILE_SERVER, PREFIX
from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg, send_texts
@@ -143,17 +144,42 @@ async def voice_to_text(
return
# Skip very long
- if trigger_info["duration"] > ASR_MAX_DURATION:
+ duration = trigger_info["duration"]
+ if duration > ASR_MAX_DURATION:
msg = f"无法识别时长超过{ASR_MAX_DURATION}秒的音频, 当前音频时长: {trigger_info['duration']}秒"
logger.error(msg)
await modify_progress(text=msg, force_update=True, **kwargs)
return
- logger.debug(f"Recognizing {voice_format} audio by {asr_engine}: {path.as_posix()}")
+ logger.debug(f"Recognizing {voice_format} audio [{duration}s] by {asr_engine}: {path.as_posix()}")
try:
- resp = await flash_asr(path, asr_engine, voice_format)
- texts = glom(resp, "flash_result.0.text") or "❌无法识别"
- final = f"{BEGINNING}\n{texts}".replace("。", "。\n")
+ if duration < 60 and trigger_info["file_size"] < 3 * 1024 * 1024: # 时长不能超过60s, 文件大小不能超过3MB
+ resp = await single_sentence_asr(path, asr_engine, voice_format)
+ texts = glom(resp, "Response.Result", default="❌无法识别").replace("。", "。\n")
+ elif 60 <= duration <= 300:
+ resp = await flash_asr(path, asr_engine, voice_format)
+ texts = glom(resp, "flash_result.0.text", default="❌无法识别").replace("。", "。\n")
+ elif FILE_SERVER:
+ resp = await create_async_asr(f"{FILE_SERVER}/{path.name}", asr_engine)
+ task_id = resp["Response"]["Data"]["TaskId"]
+ logger.success(task_id)
+ # task_id = 11982077075
+ result = await query_async_asr(task_id)
+ status = result["Response"]["Data"]["StatusStr"]
+ while status in ["waiting", "doing"]:
+ await asyncio.sleep(1)
+            logger.trace(f"Status:[{status}], Waiting TaskID: {task_id}")
+ result = await query_async_asr(task_id)
+ status = result["Response"]["Data"]["StatusStr"]
+ if status == "success":
+ texts = glom(result, "Response.Data.Result", default="❌无法识别")
+ texts = re.sub(r"\[.*?\]\s*", "", texts)
+ else:
+ texts = glom(result, "Response.Data.ErrorMsg", default="❌无法识别")
+ else:
+ texts = "音频过长,请联系管理员配置`FILE_SERVER`变量"
+
+ final = f"{BEGINNING}\n{texts}"
logger.success(f"{final!r}")
# send results
src/llm/contexts.py
@@ -10,7 +10,7 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
-from config import GPT, PREFIX
+from config import FILE_SERVER, GPT, PREFIX
from llm.utils import BOT_TIPS
from messages.parser import parse_msg
@@ -106,9 +106,9 @@ async def single_context(client: Client, message: Message) -> dict:
path: str = await client.download_media(msg) # type: ignore
logger.debug(f"Downloaded GPT media: {path}")
if info["mtype"] == "photo":
- contexts.append({"type": "image_url", "image_url": {"url": f"{GPT.MEDIA_SERVER}/{Path(path).name}"}})
+ contexts.append({"type": "image_url", "image_url": {"url": f"{FILE_SERVER}/{Path(path).name}"}})
# elif info["mtype"] == "video":
- # media.append({"type": "video_url", "video_url": {"url": f"{GPT.MEDIA_SERVER}/{Path(path).name}"}})
+ # media.append({"type": "video_url", "video_url": {"url": f"{FILE_SERVER}/{Path(path).name}"}})
elif info["mtype"] == "document" and info["mime_type"] in ["text/plain", "text/markdown"]:
contexts.append(
{
src/config.py
@@ -12,13 +12,14 @@ cache = Cache(ttl=0, maxsize=2048)
semaphore = asyncio.Semaphore(8) # max 8 concurrent downloads
DOWNLOAD_DIR = os.getenv("DOWNLOAD_DIR", Path(__file__).parent.joinpath("downloads").as_posix())
+FILE_SERVER = os.getenv("FILE_SERVER", "") # expose the download dir to the internet (optional), e.g. https://server.com/dir
TZ = os.getenv("TZ", "Asia/Shanghai")
DEVICE_NAME = os.getenv("DEVICE_NAME", "BennyBot")
REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "60")) # seconds
TEXT_LENGTH = int(os.getenv("TEXT_LENGTH", "4096")) # Maximum length of text message
CAPTION_LENGTH = int(os.getenv("CAPTION_LENGTH", "1024")) # 4096 for Premium user
MAX_FILE_BYTES = int(os.getenv("MAX_FILE_BYTES", "2000")) * 1024 * 1024 # 4000 MB for Premium user
-ASR_MAX_DURATION = int(os.getenv("ASR_MAX_DURATION", "600"))
+ASR_MAX_DURATION = int(os.getenv("ASR_MAX_DURATION", "3600"))
MAX_MESSAGE_RETRIEVED = int(os.getenv("MAX_MESSAGE_RETRIEVED", "1000000")) # Maximum number of messages to retrieve
MAX_MESSAGE_SUMMARY = int(os.getenv("MAX_MESSAGE_SUMMARY", "9999")) # Maximum number of messages to summay
READING_SPEED = int(os.getenv("READING_SPEED", "600")) # words per minute
@@ -164,7 +165,6 @@ class GPT: # see `llm/README.md`
TEMPERATURE = os.getenv("GPT_TEMPERATURE", "1.0")
HISTORY_CONTEXT = os.getenv("GPT_HISTORY_CONTEXT", "20") # 最多携带多少条历史消息
MEDIA_FORMAT = os.getenv("GPT_MEDIA_FORMAT", "base64") # base64 or http
- MEDIA_SERVER = os.getenv("GPT_MEDIA_SERVER", "https://server.com/dir") # only when MEDIA_FORMAT is http
TEXT_API_KEY = os.getenv("GPT_TEXT_API_KEY", "")
TEXT_BASE_URL = os.getenv("GPT_TEXT_BASE_URL", "https://api.openai.com/v1")
IMAGE_API_KEY = os.getenv("GPT_IMAGE_API_KEY", "")