main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import asyncio
4import io
5import re
6from pathlib import Path
7
8import anyio
9from dashscope.audio.asr import Recognition, RecognitionCallback
10from glom import flatten, glom
11from httpx import AsyncHTTPTransport
12from loguru import logger
13
14from asr.utils import convert_single_channel, downsampe_audio
15from config import ASR, DB, FILE_SERVER
16from database.alist import delete_alist, upload_alist
17from database.uguu import upload_uguu
18from networking import hx_req
19from utils import strings_list
20
21
22async def ali_asr(path: str | Path) -> dict:
23 """Create Aliyun ASR Task.
24
25 录音文件识别请求
26
27 Paraformer:
28 https://help.aliyun.com/zh/model-studio/paraformer-recorded-speech-recognition-restful-api
29
30 SenseVoice:
31 https://help.aliyun.com/zh/model-studio/developer-reference/sensevoice-recorded-speech-recognition-restful-applicant
32 """
33 path = Path(path).expanduser().resolve()
34 if not path.is_file():
35 return {"texts": "", "error": "File not found."}
36 supported_ext = [".aac", ".amr", ".avi", ".flac", ".flv", ".m4a", ".mkv", ".mov", ".mp3", ".mp4", ".mpeg", ".oga", ".ogg", ".opus", ".wav", ".webm", ".wma", ".wmv"]
37 audio_path = path if path.suffix.lower() in supported_ext else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
38 audio_path = await convert_single_channel(audio_path, ext="wav", codec="pcm_s16le")
39 api_keys = strings_list(ASR.ALI_API_KEY, shuffle=True)
40 if not api_keys:
41 return {"error": "请配置阿里云语音识别的API Key"}
42 for api_key in api_keys:
43 for model in strings_list(ASR.ALI_MODEL, shuffle=True):
44 logger.debug(f"阿里云ASR {audio_path} via model: {model}")
45 if model.startswith("paraformer-realtime-"):
46 return await ali_realtime_asr(model, audio_path, api_key)
47
48 headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "X-DashScope-Async": "enable"}
49 if ASR.ALI_FS_ENGINE.lower() == "local":
50 url = FILE_SERVER.removesuffix("/") + "/" + path.name
51 elif ASR.ALI_FS_ENGINE.lower() == "uguu":
52 if audio_path.stat().st_size > 100 * 1024 * 1024: # 100 MB
53 audio_path = await downsampe_audio(audio_path, ext="wav", codec="pcm_s16le")
54 url = await upload_uguu(audio_path) # max 100 MB for Uguu
55 elif ASR.ALI_FS_ENGINE.lower() == "alist":
56 url = await upload_alist(audio_path)
57 else:
58 return {"error": f"Unsupported file server engine: {ASR.ALI_FS_ENGINE}"}
59
60 payload = {"model": model, "input": {"file_urls": [url]}}
61 res = await hx_req(
62 "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription",
63 method="POST",
64 headers=headers,
65 json_data=payload,
66 timeout=600,
67 check_keys=["output.task_id"],
68 )
69 if res.get("hx_error"):
70 return {"error": res["hx_error"]}
71 logger.success(f"ASR任务提交成功, TaskID: {res['output']['task_id']}")
72 return await query_ali_asr(task_id=res["output"]["task_id"], api_key=api_key)
73 return {}
74
75
async def query_ali_asr(task_id: str, api_key: str, query_times: int = 0) -> dict:
    """Poll an Aliyun ASR task until it finishes, then fetch its transcription.

    Recorded-audio file recognition result query.

    Paraformer:
    https://help.aliyun.com/zh/model-studio/paraformer-recorded-speech-recognition-restful-api

    SenseVoice:
    https://help.aliyun.com/zh/model-studio/developer-reference/sensevoice-recorded-speech-recognition-restful-applicant

    The task is queried once per second while its status is ``RUNNING`` or
    ``PENDING``, for at most 600 queries. Polling is a flat loop (no
    recursion), so the coroutine stack stays shallow however long the task
    runs.

    Args:
        task_id: ID returned when the transcription task was submitted.
        api_key: DashScope API key used for the ``Authorization`` header.
        query_times: Number of polls already spent; counts towards the
            600-poll budget.

    Returns:
        The dict produced by ``generate_ali_transcription`` on success,
        otherwise ``{"error": ...}``.
    """
    max_queries = 600  # one poll per second -> max 10 minutes
    payload = {"task_id": task_id}
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "X-DashScope-Async": "enable"}
    while True:
        result = await hx_req(
            f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}",
            method="POST",
            headers=headers,
            json_data=payload,
            check_keys=["output.task_status"],
        )
        if result.get("hx_error"):
            return {"error": result["hx_error"]}
        status = glom(result, "output.task_status")
        if status not in ("RUNNING", "PENDING") or query_times >= max_queries:
            break
        await asyncio.sleep(1)
        query_times += 1
        logger.trace(f"Status:[{status} ({query_times}/600)], Wating TaskID: {task_id}")

    # The task is finished (or the poll budget is spent): remove the file uploaded to alist.
    if ASR.ALI_FS_ENGINE.lower() == "alist":
        await clean_alist(glom(result, "output.results.0.file_url", default=""))

    if status == "SUCCEEDED":
        transcription_url = glom(result, "output.results.0.transcription_url", default="")
        if transcription_url:
            # DO NOT use AsyncCurlTransport
            trans_res = await hx_req(transcription_url, transport=AsyncHTTPTransport(), check_keys=["transcripts.0.sentences.0.text"])
            if trans_res.get("hx_error"):
                return {"error": trans_res["hx_error"]}
            sentence_start_ms = glom(trans_res, "transcripts.0.sentences.*.begin_time")
            sentences = glom(trans_res, "transcripts.0.sentences.*.text")
            return generate_ali_transcription(sentence_start_ms, sentences)

    # Prefer the task-level message, then the per-file one, then a generic fallback;
    # `or` also covers a message that is present but None / empty.
    message = glom(result, "output.message", default=None) or glom(result, "output.results.0.message", default=None) or "语音识别失败"
    return {"error": "❌" + str(message)}
118
119
async def ali_realtime_asr(model: str, path: str | Path, api_key: str) -> dict:
    """Transcribe a file with a DashScope ``paraformer-realtime-*`` model.

    The input is first converted to a mono 16-bit PCM wav at the sample rate
    the model name implies (8 kHz when the name contains ``8k``, otherwise
    16 kHz). The converted file is always deleted afterwards, including when
    recognition fails or raises.

    Args:
        model: Realtime model name.
        path: Audio file to transcribe.
        api_key: DashScope API key.

    Returns:
        The dict produced by ``generate_ali_transcription`` on success,
        otherwise ``{"error": ...}``.
    """
    # convert audio file
    sample_rate = 8000 if "8k" in model else 16000
    audio_path = await downsampe_audio(path, ext="wav", codec="pcm_s16le", sample_rate=sample_rate, channel=1)
    recognition = Recognition(model=model, format="wav", sample_rate=sample_rate, callback=RecognitionCallback(), api_key=api_key)
    try:
        # `call` is a synchronous SDK call; run it in a worker thread so the
        # event loop keeps serving other coroutines while the file is processed.
        result = await asyncio.to_thread(recognition.call, Path(audio_path).as_posix())
    finally:
        Path(audio_path).unlink(missing_ok=True)
    if result.status_code != 200:
        return {"error": f"❌语音识别失败: {result.message}"}
    data = result.get_sentence()
    if not data:
        return {"error": "⚠️该音频未识别到文字"}
    # Word-level results: each word is paired with its own punctuation and start time.
    start_times = flatten(glom(data, "*.words.*.begin_time"))
    texts = flatten(glom(data, "*.words.*.text"))
    punctuations = flatten(glom(data, "*.words.*.punctuation"))
    sentences = [f"{text}{punc}" for text, punc in zip(texts, punctuations, strict=True)]
    return generate_ali_transcription(start_times, sentences)
137
138
139def generate_ali_transcription(sentence_start_ms: list[int], sentences: list[str]) -> dict:
140 def clean_tags(text: str) -> str:
141 """Clean sensevoice tags.
142
143 Remove <|sense-1|>, <|sense-2|>, ..., etc.
144 """
145 if not text:
146 return text
147 return re.sub(r"<\|.*?\|>", "", text)
148
149 res = ""
150 try:
151 indexs = list(range(len(sentences)))
152 for idx, start_ms, sentence in zip(indexs, sentence_start_ms, sentences, strict=True):
153 text = clean_tags(sentence)
154 if not text:
155 continue
156 if idx == 0 or res.endswith((".", "。", "?", "?")): # noqa: RUF001
157 start_seconds = float(start_ms) // 1000
158 minutes = int(start_seconds // 60)
159 seconds = int(start_seconds % 60)
160 res += f"\n[{minutes:02d}:{seconds:02d}] {text}"
161 else:
162 res += text
163 except Exception as e:
164 logger.error(e)
165 return {"error": str(e)}
166 return {"texts": res.strip()}
167
168
async def clean_alist(url: str):
    """Delete the alist copy of an audio file once ASR no longer needs it.

    Nothing happens for an empty ``url`` or for a URL that does not start with
    the download prefix built from ``DB.ALIST_SERVER`` and
    ``DB.ALIST_BASR_PATH``. Otherwise the part of the URL after that prefix is
    passed to ``delete_alist``.
    """
    if url:
        server = DB.ALIST_SERVER.removesuffix("/")
        base_dir = DB.ALIST_BASR_PATH.strip("/")
        download_root = f"{server}/d/{base_dir}/"
        if url.startswith(download_root):
            # What follows the download root is the name handed to delete_alist.
            await delete_alist(url[len(download_root):])
177
178
async def upload_ali_oss(path: str | Path, api_key: str, model_name: str):
    """Upload a file through the DashScope upload policy and return its URL.

    https://help.aliyun.com/zh/model-studio/get-temporary-file-url

    Two requests are made: the first fetches an upload policy for
    ``model_name`` (``action=getPolicy``), the second posts the file as
    multipart form data to the ``upload_host`` named by that policy. The
    object key is ``<upload_dir>/<file name>``.

    NOTE(review): neither response is checked for failure here - a failed
    policy request presumably surfaces as a ``KeyError`` on ``"data"``, and a
    failed upload goes unnoticed. Confirm against ``hx_req``.

    Returns:
        ``<upload_host>/<key>`` as a string.
    """
    policy_resp = await hx_req(
        "https://dashscope.aliyuncs.com/api/v1/uploads",
        headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
        params={"action": "getPolicy", "model": model_name},
        check_keys=["data.upload_host"],
    )
    policy = policy_resp["data"]
    path = Path(path)
    key = f"{policy['upload_dir']}/{path.name}"
    async with await anyio.open_file(path, "rb") as fh:
        content = await fh.read()

    # Form field name -> policy entry that supplies its value (order is preserved).
    policy_fields = {
        "OSSAccessKeyId": "oss_access_key_id",
        "Signature": "signature",
        "policy": "policy",
        "x-oss-object-acl": "x_oss_object_acl",
        "x-oss-forbid-overwrite": "x_oss_forbid_overwrite",
    }
    form = {field: (None, policy[source]) for field, source in policy_fields.items()}
    form["key"] = (None, key)
    form["success_action_status"] = (None, "200")
    form["file"] = (path.name, io.BytesIO(content))  # the file part stays last, as before

    await hx_req(policy["upload_host"], method="POST", files=form, rformat="text")
    # The alternative `oss://{key}` form is left disabled, as in the original.
    return f"{policy['upload_host']}/{key}"
206 return f"{policy_data['upload_host']}/{key}"