# main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import asyncio
  4import io
  5import re
  6from pathlib import Path
  7
  8import anyio
  9from dashscope.audio.asr import Recognition, RecognitionCallback
 10from glom import flatten, glom
 11from httpx import AsyncHTTPTransport
 12from loguru import logger
 13
 14from asr.utils import convert_single_channel, downsampe_audio
 15from config import ASR, DB, FILE_SERVER
 16from database.alist import delete_alist, upload_alist
 17from database.uguu import upload_uguu
 18from networking import hx_req
 19from utils import strings_list
 20
 21
async def ali_asr(path: str | Path) -> dict:
    """Create Aliyun ASR Task.

    Submit a recorded-file transcription request (录音文件识别请求) to
    DashScope and wait for the transcript.

    Paraformer:
    https://help.aliyun.com/zh/model-studio/paraformer-recorded-speech-recognition-restful-api

    SenseVoice:
    https://help.aliyun.com/zh/model-studio/developer-reference/sensevoice-recorded-speech-recognition-restful-applicant

    Returns:
        ``{"texts": ...}`` on success or ``{"error": ...}`` on failure;
        ``{}`` only if the configured model list is empty.
    """
    path = Path(path).expanduser().resolve()
    if not path.is_file():
        return {"texts": "", "error": "File not found."}
    # Formats the service accepts directly; anything else is transcoded to WAV first.
    supported_ext = [".aac", ".amr", ".avi", ".flac", ".flv", ".m4a", ".mkv", ".mov", ".mp3", ".mp4", ".mpeg", ".oga", ".ogg", ".opus", ".wav", ".webm", ".wma", ".wmv"]
    audio_path = path if path.suffix.lower() in supported_ext else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
    audio_path = await convert_single_channel(audio_path, ext="wav", codec="pcm_s16le")
    api_keys = strings_list(ASR.ALI_API_KEY, shuffle=True)
    if not api_keys:
        return {"error": "请配置阿里云语音识别的API Key"}
    # NOTE(review): every branch inside the loops returns, so only the first
    # (shuffled) api_key/model pair is ever tried -- confirm whether falling
    # back to the remaining keys/models on failure was intended.
    for api_key in api_keys:
        for model in strings_list(ASR.ALI_MODEL, shuffle=True):
            logger.debug(f"阿里云ASR {audio_path} via model: {model}")
            # Realtime models go through the SDK-based path, not the async REST task.
            if model.startswith("paraformer-realtime-"):
                return await ali_realtime_asr(model, audio_path, api_key)

            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "X-DashScope-Async": "enable"}
            # The REST API needs a publicly reachable URL, so expose/upload the
            # audio via the configured file-server engine first.
            if ASR.ALI_FS_ENGINE.lower() == "local":
                # NOTE(review): this uses the ORIGINAL path.name, not audio_path --
                # if the file was transcoded above, the URL may point at the
                # un-converted file; confirm against the local file server.
                url = FILE_SERVER.removesuffix("/") + "/" + path.name
            elif ASR.ALI_FS_ENGINE.lower() == "uguu":
                if audio_path.stat().st_size > 100 * 1024 * 1024:  # 100 MB
                    audio_path = await downsampe_audio(audio_path, ext="wav", codec="pcm_s16le")
                url = await upload_uguu(audio_path)  # max 100 MB for Uguu
            elif ASR.ALI_FS_ENGINE.lower() == "alist":
                url = await upload_alist(audio_path)
            else:
                return {"error": f"Unsupported file server engine: {ASR.ALI_FS_ENGINE}"}

            payload = {"model": model, "input": {"file_urls": [url]}}
            res = await hx_req(
                "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription",
                method="POST",
                headers=headers,
                json_data=payload,
                timeout=600,
                check_keys=["output.task_id"],
            )
            if res.get("hx_error"):
                return {"error": res["hx_error"]}
            logger.success(f"ASR任务提交成功, TaskID: {res['output']['task_id']}")
            # Poll the async task until it settles (see query_ali_asr).
            return await query_ali_asr(task_id=res["output"]["task_id"], api_key=api_key)
    return {}
 74
 75
 76async def query_ali_asr(task_id: str, api_key: str, query_times: int = 0) -> dict:
 77    """Query Ali ASR Task.
 78
 79    录音文件识别结果查询
 80    Paraformer:
 81    https://help.aliyun.com/zh/model-studio/paraformer-recorded-speech-recognition-restful-api
 82
 83    SenseVoice:
 84    https://help.aliyun.com/zh/model-studio/developer-reference/sensevoice-recorded-speech-recognition-restful-applicant
 85    """
 86    payload = {"task_id": task_id}
 87    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "X-DashScope-Async": "enable"}
 88    result = await hx_req(
 89        f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}",
 90        method="POST",
 91        headers=headers,
 92        json_data=payload,
 93        check_keys=["output.task_status"],
 94    )
 95    if result.get("hx_error"):
 96        return {"error": result["hx_error"]}
 97    status = glom(result, "output.task_status")
 98    while status in ["RUNNING", "PENDING"] and query_times < 600:  # max 10 minutes
 99        await asyncio.sleep(1)
100        query_times += 1
101        logger.trace(f"Status:[{status} ({query_times}/600)], Wating TaskID: {task_id}")
102        result = await query_ali_asr(task_id, api_key, query_times)
103        if result.get("texts") or result.get("error"):
104            return result
105        status = glom(result, "output.task_status")
106    if ASR.ALI_FS_ENGINE.lower() == "alist":
107        await clean_alist(glom(result, "output.results.0.file_url", default=""))
108    if status == "SUCCEEDED":
109        transcription_url = glom(result, "output.results.0.transcription_url")
110        trans_res = await hx_req(transcription_url, transport=AsyncHTTPTransport(), check_keys=["transcripts.0.sentences.0.text"])
111        if trans_res.get("hx_error"):
112            return {"error": trans_res["hx_error"]}
113        # DO NOT use AsyncCurlTransport
114        sentence_start_ms = glom(trans_res, "transcripts.0.sentences.*.begin_time")
115        sentences = glom(trans_res, "transcripts.0.sentences.*.text")
116        return generate_ali_transcription(sentence_start_ms, sentences)
117    return {"error": "" + glom(result, "output.message", default="语音识别失败")}
118
119
async def ali_realtime_asr(model: str, path: str | Path, api_key: str) -> dict:
    """Transcribe an audio file with a DashScope realtime (paraformer-realtime-*) model.

    Args:
        model: Realtime model name; "8k" in the name selects the 8 kHz variant.
        path: Source audio file to transcribe.
        api_key: DashScope API key.

    Returns:
        ``{"texts": ...}`` with timestamped sentences, or ``{"error": ...}``.
    """
    # Realtime models require mono PCM WAV at the model's fixed sample rate.
    sample_rate = 8000 if "8k" in model else 16000
    audio_path = await downsampe_audio(path, ext="wav", codec="pcm_s16le", sample_rate=sample_rate, channel=1)
    recognition = Recognition(model=model, format="wav", sample_rate=sample_rate, callback=RecognitionCallback(), api_key=api_key)
    result = recognition.call(Path(audio_path).as_posix())
    # Delete the temporary converted file regardless of outcome; previously it
    # leaked whenever the recognition call failed.
    Path(audio_path).unlink(missing_ok=True)
    if result.status_code != 200:
        return {"error": f"❌语音识别失败: {result.message}"}
    data = result.get_sentence()
    if not data:
        return {"error": "⚠️该音频未识别到文字"}
    # Word-level output: stitch each word with its trailing punctuation so the
    # transcript generator can detect sentence boundaries.
    start_times = flatten(glom(data, "*.words.*.begin_time"))
    texts = flatten(glom(data, "*.words.*.text"))
    punctuations = flatten(glom(data, "*.words.*.punctuation"))
    sentences = [f"{text}{punc}" for text, punc in zip(texts, punctuations, strict=True)]
    return generate_ali_transcription(start_times, sentences)
137
138
139def generate_ali_transcription(sentence_start_ms: list[int], sentences: list[str]) -> dict:
140    def clean_tags(text: str) -> str:
141        """Clean sensevoice tags.
142
143        Remove <|sense-1|>, <|sense-2|>, ..., etc.
144        """
145        if not text:
146            return text
147        return re.sub(r"<\|.*?\|>", "", text)
148
149    res = ""
150    try:
151        indexs = list(range(len(sentences)))
152        for idx, start_ms, sentence in zip(indexs, sentence_start_ms, sentences, strict=True):
153            text = clean_tags(sentence)
154            if not text:
155                continue
156            if idx == 0 or res.endswith((".", "", "?", "")):  # noqa: RUF001
157                start_seconds = float(start_ms) // 1000
158                minutes = int(start_seconds // 60)
159                seconds = int(start_seconds % 60)
160                res += f"\n[{minutes:02d}:{seconds:02d}] {text}"
161            else:
162                res += text
163    except Exception as e:
164        logger.error(e)
165        return {"error": str(e)}
166    return {"texts": res.strip()}
167
168
async def clean_alist(url: str):
    """Delete the uploaded audio file from Alist once ASR has finished."""
    if not url:
        return
    # Direct-download URLs have the shape <server>/d/<base_path>/<filename>.
    prefix = f"{DB.ALIST_SERVER.removesuffix('/')}/d/{DB.ALIST_BASR_PATH.strip('/')}/"
    if not url.startswith(prefix):
        return
    await delete_alist(url[len(prefix):])
177
178
async def upload_ali_oss(path: str | Path, api_key: str, model_name: str) -> str:
    """Upload a local file to Aliyun's temporary OSS and return its HTTP URL.

    Get OSS url of Aliyun.

    https://help.aliyun.com/zh/model-studio/get-temporary-file-url
    """
    # Step 1: fetch a short-lived upload policy (host, signature, target dir).
    url = "https://dashscope.aliyuncs.com/api/v1/uploads"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    params = {"action": "getPolicy", "model": model_name}

    response = await hx_req(url, headers=headers, params=params, check_keys=["data.upload_host"])
    # NOTE(review): unlike the other helpers in this module, an "hx_error"
    # response is not handled here -- the ["data"] access below would raise
    # KeyError on failure; confirm whether hx_req guarantees "data" whenever
    # check_keys passes.
    policy_data = response["data"]
    path = Path(path)
    key = f"{policy_data['upload_dir']}/{path.name}"
    # Step 2: multipart POST the file bytes to the OSS host using the signed policy.
    async with await anyio.open_file(path, "rb") as f:
        content = await f.read()
        files = {
            "OSSAccessKeyId": (None, policy_data["oss_access_key_id"]),
            "Signature": (None, policy_data["signature"]),
            "policy": (None, policy_data["policy"]),
            "x-oss-object-acl": (None, policy_data["x_oss_object_acl"]),
            "x-oss-forbid-overwrite": (None, policy_data["x_oss_forbid_overwrite"]),
            "key": (None, key),
            "success_action_status": (None, "200"),
            "file": (path.name, io.BytesIO(content)),
        }
        response = await hx_req(policy_data["upload_host"], method="POST", files=files, rformat="text")
    # return f"oss://{key}"
    return f"{policy_data['upload_host']}/{key}"