Commit b918cf5
Changed files (6)
src/asr/ali_asr.py
@@ -18,7 +18,7 @@ from database import delete_alist, upload_alist, upload_uguu
from networking import hx_req
-async def ali_asr(path: str | Path) -> str:
+async def ali_asr(path: str | Path) -> dict:
"""Create Aliyun ASR Task.
录音文件识别请求
@@ -31,7 +31,7 @@ async def ali_asr(path: str | Path) -> str:
"""
api_keys = [x.strip() for x in ASR.ALI_API_KEY.split(",") if x.strip()]
if not api_keys:
- return "请配置阿里云语音识别的API Key"
+ return {"error": "请配置阿里云语音识别的API Key"}
models = [x.strip() for x in ASR.ALI_MODEL.split(",") if x.strip()]
model = random.choice(models)
logger.debug(f"阿里云ASR {path} via model: {model}")
@@ -47,8 +47,10 @@ async def ali_asr(path: str | Path) -> str:
if path.stat().st_size > 100 * 1024 * 1024: # 100 MB
path = downsampe_audio(path)
url = await upload_uguu(path) # max 100 MB for Uguu
- else:
+ elif ASR.ALI_FS_ENGINE.lower() == "alist":
url = await upload_alist(path)
+ else:
+ return {"error": f"Unsupported file server engine: {ASR.ALI_FS_ENGINE}"}
payload = {"model": model, "input": {"file_urls": [url]}}
res = await hx_req(
@@ -59,11 +61,13 @@ async def ali_asr(path: str | Path) -> str:
timeout=600,
check_keys=["output.task_id"],
)
+ if res.get("hx_error"):
+ return {"error": res["hx_error"]}
logger.success(f"ASR任务提交成功, TaskID: {res['output']['task_id']}")
return await query_ali_asr(task_id=res["output"]["task_id"], api_key=api_key)
-async def query_ali_asr(task_id: str, api_key: str) -> str:
+async def query_ali_asr(task_id: str, api_key: str, query_times: int = 0) -> dict:
"""Query Ali ASR Task.
录音文件识别结果查询
@@ -82,27 +86,32 @@ async def query_ali_asr(task_id: str, api_key: str) -> str:
post_json=payload,
check_keys=["output.task_status"],
)
+ if result.get("hx_error"):
+ return {"error": result["hx_error"]}
status = glom(result, "output.task_status")
- query_times = 0
while status in ["RUNNING", "PENDING"] and query_times < 600: # max 10 minutes
await asyncio.sleep(1)
query_times += 1
- logger.trace(f"Status:[{status}], Wating TaskID: {task_id}")
- result = await query_ali_asr(task_id, api_key)
- if isinstance(result, str):
+ logger.trace(f"Status:[{status} ({query_times}/600)], Wating TaskID: {task_id}")
+ result = await query_ali_asr(task_id, api_key, query_times)
+ if result.get("texts") or result.get("error"):
return result
status = glom(result, "output.task_status")
- await clean_alist(glom(result, "output.results.0.file_url", default=""))
+ if ASR.ALI_FS_ENGINE.lower() == "alist":
+ await clean_alist(glom(result, "output.results.0.file_url", default=""))
if status == "SUCCEEDED":
transcription_url = glom(result, "output.results.0.transcription_url")
- trans_res = await hx_req(transcription_url, transport=AsyncHTTPTransport(), check_keys=["transcripts.0.sentences.0.text"]) # DO NOT use AsyncCurlTransport
+ trans_res = await hx_req(transcription_url, transport=AsyncHTTPTransport(), check_keys=["transcripts.0.sentences.0.text"])  # DO NOT use AsyncCurlTransport
+ if trans_res.get("hx_error"):
+ return {"error": trans_res["hx_error"]}
sentence_start_ms = glom(trans_res, "transcripts.0.sentences.*.begin_time")
sentences = glom(trans_res, "transcripts.0.sentences.*.text")
return generate_ali_transcription(sentence_start_ms, sentences)
- return "❌" + glom(result, "output.message", default="语音识别失败")
+ return {"error": "❌" + glom(result, "output.message", default="语音识别失败")}
-def ali_realtime_asr(model: str, path: str | Path, api_key: str) -> str:
+def ali_realtime_asr(model: str, path: str | Path, api_key: str) -> dict:
# convert audio file
sample_rate = 8000 if "8k" in model else 16000
ext = "opus"
@@ -110,9 +119,11 @@ def ali_realtime_asr(model: str, path: str | Path, api_key: str) -> str:
recognition = Recognition(model=model, format=ext, sample_rate=sample_rate, callback=RecognitionCallback(), api_key=api_key)
result = recognition.call(Path(path).as_posix())
if result.status_code != 200:
- return f"❌语音识别失败: {result.message}"
+ return {"error": f"❌语音识别失败: {result.message}"}
Path(path).unlink(missing_ok=True)
data = result.get_sentence()
+ if not data:
+ return {"error": "⚠️该音频未识别到文字"}
start_times = flatten(glom(data, "*.words.*.begin_time"))
texts = flatten(glom(data, "*.words.*.text"))
punctuations = flatten(glom(data, "*.words.*.punctuation"))
@@ -120,7 +131,7 @@ def ali_realtime_asr(model: str, path: str | Path, api_key: str) -> str:
return generate_ali_transcription(start_times, sentences)
-def generate_ali_transcription(sentence_start_ms: list[int], sentences: list[str]) -> str:
+def generate_ali_transcription(sentence_start_ms: list[int], sentences: list[str]) -> dict:
def clean_tags(text: str) -> str:
"""Clean sensevoice tags.
@@ -131,19 +142,23 @@ def generate_ali_transcription(sentence_start_ms: list[int], sentences: list[str]) -> str:
return re.sub(r"<\|.*?\|>", "", text)
res = ""
- indexs = list(range(len(sentences)))
- for idx, start_ms, sentence in zip(indexs, sentence_start_ms, sentences, strict=True):
- text = clean_tags(sentence)
- if not text:
- continue
- if idx == 0 or res.endswith((".", "。", "?", "？")): # noqa: RUF001
- start_seconds = float(start_ms) // 1000
- minutes = int(start_seconds // 60)
- seconds = int(start_seconds % 60)
- res += f"\n[{minutes:02d}:{seconds:02d}] {text}"
- else:
- res += text
- return res.strip()
+ try:
+ indexs = list(range(len(sentences)))
+ for idx, start_ms, sentence in zip(indexs, sentence_start_ms, sentences, strict=True):
+ text = clean_tags(sentence)
+ if not text:
+ continue
+ if idx == 0 or res.endswith((".", "。", "?", "？")): # noqa: RUF001
+ start_seconds = float(start_ms) // 1000
+ minutes = int(start_seconds // 60)
+ seconds = int(start_seconds % 60)
+ res += f"\n[{minutes:02d}:{seconds:02d}] {text}"
+ else:
+ res += text
+ except Exception as e:
+ logger.error(e)
+ return {"error": str(e)}
+ return {"texts": res.strip()}
async def clean_alist(url: str):
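Note: the key behavioral fix in query_ali_asr is that query_times is now threaded through the recursive call, so the 600-poll cap bounds the whole task instead of resetting to 0 at each recursion depth. A minimal sketch of that pattern, with a hypothetical fetch_status stub in place of the real hx_req call:
```python
# Sketch of the bounded recursive polling; fetch_status is a stub,
# not the repo's API, and the timeout message is illustrative.
import asyncio

async def fetch_status(task_id: str) -> dict:
    return {"status": "SUCCEEDED", "texts": "[00:00] hello"}  # stub

async def poll(task_id: str, query_times: int = 0) -> dict:
    result = await fetch_status(task_id)
    status = result.get("status")
    if status in ("RUNNING", "PENDING"):
        if query_times >= 600:  # ~10 minutes at one poll per second
            return {"error": "polling timed out"}
        await asyncio.sleep(1)
        # Carrying the counter is the fix: previously every recursive
        # call restarted at query_times = 0, so the cap never fired.
        return await poll(task_id, query_times + 1)
    if status == "SUCCEEDED":
        return {"texts": result["texts"]}
    return {"error": result.get("message", "语音识别失败")}

print(asyncio.run(poll("demo-task")))
```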
src/asr/deepgram.py
@@ -12,14 +12,14 @@ from networking import hx_req
from utils import zhcn
-async def deepgram_asr(path: str | Path) -> str:
+async def deepgram_asr(path: str | Path) -> dict:
"""Deepgram ASR.
https://developers.deepgram.com/docs/pre-recorded-audio
"""
api_keys = [x.strip() for x in ASR.DEEPGRAM_API.split(",") if x.strip()]
if not api_keys:
- return "请配置DeepGram语音识别的API Key"
+ return {"error": "请配置DeepGram语音识别的API Key"}
logger.debug(f"DeepGram ASR {path}")
headers = {"Authorization": f"Token {random.choice(api_keys)}"}
path = Path(path).expanduser().resolve()
@@ -35,18 +35,24 @@ async def deepgram_asr(path: str | Path) -> str:
timeout=600,
check_keys=["results.channels.0.alternatives.0.words"],
)
- start_seconds = flatten(glom(res, "results.channels.*.alternatives.0.words.*.start"))
- sentences = flatten(glom(res, "results.channels.*.alternatives.0.words.*.punctuated_word"))
- res = ""
- indexs = list(range(len(sentences)))
- for idx, start_time, sentence in zip(indexs, start_seconds, sentences, strict=True):
- if not sentence:
- continue
- if idx == 0 or res.endswith((".", "。", "?", "？")): # noqa: RUF001
- start_seconds = float(start_time)
- minutes = int(start_seconds // 60)
- seconds = int(start_seconds % 60)
- res += f"\n[{minutes:02d}:{seconds:02d}] {sentence}"
- else:
- res += sentence
- return zhcn(res.strip())
+ if res.get("hx_error"):
+ return {"error": res["hx_error"]}
+ try:
+ start_seconds = flatten(glom(res, "results.channels.*.alternatives.0.words.*.start"))
+ sentences = flatten(glom(res, "results.channels.*.alternatives.0.words.*.punctuated_word"))
+ res = ""
+ indexs = list(range(len(sentences)))
+ for idx, start_time, sentence in zip(indexs, start_seconds, sentences, strict=True):
+ if not sentence:
+ continue
+ if idx == 0 or res.endswith((".", "。", "?", "？")): # noqa: RUF001
+ start_seconds = float(start_time)
+ minutes = int(start_seconds // 60)
+ seconds = int(start_seconds % 60)
+ res += f"\n[{minutes:02d}:{seconds:02d}] {sentence}"
+ else:
+ res += sentence
+ except Exception as e:
+ logger.error(e)
+ return {"error": str(e)}
+ return {"texts": zhcn(res.strip())}
src/asr/gemini_asr.py
@@ -18,6 +18,7 @@ from llm.gemini import gemini_stream
from llm.hooks import hook_gemini_httpoptions
from llm.utils import shuffle_keys
from messages.progress import modify_progress
+from utils import count_subtitles
async def gemini_stream_asr(
@@ -69,6 +70,7 @@ Notes:
- Focus on accuracy in capturing both the timing and the spoken content.
- Maintain consistent formatting to ensure clarity and readability."""
path = Path(path)
+ sent_messages = []
status = None if silent else kwargs.get("progress")
api_keys = shuffle_keys(GEMINI.API_KEY)
if model_id is None:
@@ -104,7 +106,7 @@ Notes:
append_grounding=False,
**kwargs,
)
- if res.get("error"):
+ if res.get("error") or count_subtitles(res.get("texts", "")) == 0:
continue
sent_messages = res.get("sent_messages", [])
break
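Note: count_subtitles comes from utils and is not part of this diff. Given the "[MM:SS] text" transcript format, a plausible implementation just counts timestamped lines; the sketch below is a hypothetical reconstruction, not the repo's helper:
```python
# Hypothetical sketch of utils.count_subtitles, assuming it counts
# "[MM:SS]"-prefixed lines in the transcript; not taken from the repo.
import re

def count_subtitles(text: str) -> int:
    return len(re.findall(r"^\[\d{2}:\d{2}\] ", text, flags=re.MULTILINE))

assert count_subtitles("[00:00] hello\n[01:05] world") == 2
assert count_subtitles("") == 0
```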
src/asr/tecent_asr.py
@@ -12,7 +12,7 @@ from loguru import logger
from asr.utils import downsampe_audio
from config import ASR, FILE_SERVER
-from database import upload_alist, upload_uguu
+from database import delete_alist, upload_alist, upload_uguu
from networking import hx_req
from utils import nowdt
@@ -70,7 +70,7 @@ def generate_tencent_cloud_headers(
}
-async def tencent_single_asr(path: str | Path, engine: str, voice_format: str) -> str:
+async def tencent_single_asr(path: str | Path, engine: str, voice_format: str) -> dict:
"""Tencent Single Sentence ASR.
一句话识别
@@ -91,11 +91,14 @@ async def tencent_single_asr(path: str | Path, engine: str, voice_format: str) -> str:
check_keys=["Response.WordList"],
)
if res["Response"]["WordList"] is None:
- return "⚠️该音频未识别到文字"
+ return {"error": "⚠️该音频未识别到文字"}
+ if res.get("hx_error"):
+ return {"error": res["hx_error"]}
+
return generate_tencent_transcription(sentence_start_ms=[0], words=[res["Response"]["WordList"]])
-async def tencent_flash_asr(path: str | Path, engine: str, voice_format: str) -> str:
+async def tencent_flash_asr(path: str | Path, engine: str, voice_format: str) -> dict:
"""Tencent Flash ASR.
录音文件识别极速版
@@ -131,14 +134,14 @@ async def tencent_flash_asr(path: str | Path, engine: str, voice_format: str) -> str:
)
if error := res.get("hx_error", ""):
if "audio data empty" in error:
- return "⚠️该音频未识别到文字"
- return error
+ return {"error": "⚠️该音频未识别到文字"}
+ return {"error": error}
sentence_start_ms = flatten(glom(res, "flash_result.*.sentence_list.*.start_time"), levels=1)
words = flatten(glom(res, "flash_result.*.sentence_list.*.word_list"), levels=1)
return generate_tencent_transcription(sentence_start_ms, words)
-async def tencent_async_asr(path: str | Path, engine: str) -> str:
+async def tencent_async_asr(path: str | Path, engine: str) -> dict:
"""Create Tencent ASR Task.
录音文件识别请求
@@ -151,8 +154,10 @@ async def tencent_async_asr(path: str | Path, engine: str) -> str:
if path.stat().st_size > 100 * 1024 * 1024: # 100 MB
path = downsampe_audio(path)
url = await upload_uguu(path) # max 100 MB for Uguu
- else:
+ elif ASR.TENCENT_FS_ENGINE.lower() == "alist":
url = await upload_alist(path)
+ else:
+ return {"error": f"Unsupported file server engine: {ASR.TENCENT_FS_ENGINE}"}
payload = f'{{"EngineModelType":"{engine}","ChannelNum":1,"ResTextFormat":2,"SourceType":0,"Url":"{url}"}}'
headers = generate_tencent_cloud_headers(action="CreateRecTask", payload=payload)
@@ -165,12 +170,14 @@ async def tencent_async_asr(path: str | Path, engine: str) -> str:
proxy=ASR.TENCENT_PROXY,
check_keys=["Response.Data.TaskId"],
)
+ if resp.get("hx_error"):
+ return {"error": resp["hx_error"]}
task_id = resp["Response"]["Data"]["TaskId"]
logger.success(f"ASR任务提交成功, TaskID: {task_id}")
- return await tencent_query_asr(task_id)
+ return await tencent_query_asr(task_id, file_name=path.name)
-async def tencent_query_asr(task_id: int) -> str:
+async def tencent_query_asr(task_id: int, file_name: str, query_times: int = 0) -> dict:
"""Query Tencent ASR Task.
录音文件识别结果查询
@@ -187,35 +194,44 @@ async def tencent_query_asr(task_id: int) -> str:
proxy=ASR.TENCENT_PROXY,
check_keys=["Response.Data.StatusStr"],
)
+ if result.get("hx_error"):
+ return {"error": result["hx_error"]}
status = glom(result, "Response.Data.StatusStr")
- query_times = 0
while status in ["waiting", "doing"] and query_times < 600: # max 10 minutes
await asyncio.sleep(1)
query_times += 1
- logger.trace(f"Status:[{status}], Wating TaskID: {task_id}")
- result = await tencent_query_asr(task_id)
- if isinstance(result, str):
+ logger.trace(f"Status: [{status} ({query_times}/600)], Wating TaskID: {task_id}")
+ result = await tencent_query_asr(task_id, file_name, query_times)
+ if result.get("texts") or result.get("error"):
return result
status = glom(result, "Response.Data.StatusStr")
+ if ASR.TENCENT_FS_ENGINE.lower() == "alist":
+ await delete_alist(file_name)
if status == "success":
+ if glom(result, "Response.Data.ResultDetail") is None:
+ return {"error": "⚠️该音频未识别到文字"}
sentence_start_ms = glom(result, "Response.Data.ResultDetail.*.StartMs")
words = glom(result, "Response.Data.ResultDetail.*.Words")
return generate_tencent_transcription(sentence_start_ms, words)
- return "❌" + glom(result, "Response.Data.ErrorMsg", default="语音识别失败")
+ return {"error": "❌" + glom(result, "Response.Data.ErrorMsg", default="语音识别失败")}
-def generate_tencent_transcription(sentence_start_ms: list[int], words: list[list[dict]]) -> str:
+def generate_tencent_transcription(sentence_start_ms: list[int], words: list[list[dict]]) -> dict:
res = ""
- for start_offset, items in zip(sentence_start_ms, words, strict=True):
- for idx, item in enumerate(items):
- sentence = glom(item, Coalesce("Word", "word"), default="")
- if not sentence:
- continue
- if idx == 0 or res.endswith((".", "。", "?", "？")): # noqa: RUF001
- start_seconds = float(glom(item, Coalesce("StartTime", "OffsetStartMs", "start_time"), default=0) + float(start_offset)) // 1000
- minutes = int(start_seconds // 60)
- seconds = int(start_seconds % 60)
- res += f"\n[{minutes:02d}:{seconds:02d}] {sentence}"
- else:
- res += sentence
- return res.strip()
+ try:
+ for start_offset, items in zip(sentence_start_ms, words, strict=True):
+ for idx, item in enumerate(items):
+ sentence = glom(item, Coalesce("Word", "word"), default="")
+ if not sentence:
+ continue
+ if idx == 0 or res.endswith((".", "。", "?", "？")): # noqa: RUF001
+ start_seconds = float(glom(item, Coalesce("StartTime", "OffsetStartMs", "start_time"), default=0) + float(start_offset)) // 1000
+ minutes = int(start_seconds // 60)
+ seconds = int(start_seconds % 60)
+ res += f"\n[{minutes:02d}:{seconds:02d}] {sentence}"
+ else:
+ res += sentence
+ except Exception as e:
+ logger.error(e)
+ return {"error": str(e)}
+ return {"texts": res.strip()}
src/asr/voice_recognition.py
@@ -175,7 +175,6 @@ async def asr_file(
**kwargs,
) -> dict:
"""Get ASR results of an audio file."""
- res = {}
path = Path(path).expanduser().resolve()
if not path.is_file():
return {"error": f"{path} is not exist"}
@@ -205,19 +204,23 @@ async def asr_file(
logger.debug(f"[{asr_method}] Recognizing {voice_format} audio [{duration}s] by {language}: {path.as_posix()}")
try:
+ res = {}
if asr_method == "tencent_single_asr":
- res["texts"] = await tencent_single_asr(path, language, voice_format)
+ res = await tencent_single_asr(path, language, voice_format)
elif asr_method == "tencent_flash_asr":
- res["texts"] = await tencent_flash_asr(path, language, voice_format)
+ res = await tencent_flash_asr(path, language, voice_format)
elif asr_method == "tencent_async_asr":
- res["texts"] = await tencent_async_asr(path, language)
+ res = await tencent_async_asr(path, language)
elif asr_method == "ali":
- res["texts"] = await ali_asr(path)
+ res = await ali_asr(path)
elif asr_method == "deepgram":
- res["texts"] = await deepgram_asr(path)
+ res = await deepgram_asr(path)
elif asr_method == "gemini":
- res |= await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
- logger.success(f"{res['texts']!r}")
+ res = await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
+ else:
+ return {"error": "ASR method not supported"}
+ if res.get("texts"):
+ logger.success(f"{res['texts']!r}")
except Exception as e:
error = f"Failed to recognize audio: {e}"
logger.error(error)
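Note: with this change every backend, and asr_file itself, returns a single shape: {"texts": ...} on success or {"error": ...} on failure, which is why the isinstance checks disappear. A hypothetical caller against that contract (asr_file stubbed so the snippet runs standalone):
```python
# Hypothetical caller for the unified result contract; asr_file is
# stubbed here and stands in for src/asr/voice_recognition.asr_file.
import asyncio

async def asr_file(path: str, method: str, duration: int, **kwargs) -> dict:
    return {"texts": "[00:00] demo"}  # stub

async def transcribe(path: str) -> str:
    res = await asr_file(path, "deepgram", 30)
    if err := res.get("error"):
        raise RuntimeError(f"ASR failed: {err}")
    return res["texts"]  # present whenever "error" is absent

print(asyncio.run(transcribe("demo.ogg")))
```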
src/preview/ytdlp.py
@@ -263,6 +263,8 @@ async def preview_ytdlp(
ytdlp_transcription_engine = "gemini" if "youtube" in info["extractor"] else ytdlp_transcription_engine # use gemini to bypass censorship
res = await asr_file(audio_path, ytdlp_transcription_engine, duration, client=client, message=message, silent=True)
subtitles = res.get("texts", "")
+ if count_subtitles(subtitles) < 20:
+ subtitles = "" # ignore too short transcription
if subtitles:
if len(subtitles) > TEXT_LENGTH or transcription_force_file:
caption = f"{emoji}[{info['author']}]({info['author_url']})\n🕒{create_time}"