Commit 65293e8
Changed files (12)
src
src/asr/ali.py
@@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-
import asyncio
import io
-import random
import re
from pathlib import Path
@@ -12,11 +11,12 @@ from glom import flatten, glom
from httpx import AsyncHTTPTransport
from loguru import logger
-from asr.utils import downsampe_audio
+from asr.utils import convert_single_channel, downsampe_audio
from config import ASR, DB, FILE_SERVER
from database.alist import delete_alist, upload_alist
from database.uguu import upload_uguu
from networking import hx_req
+from utils import strings_list
async def ali_asr(path: str | Path) -> dict:
@@ -30,42 +30,47 @@ async def ali_asr(path: str | Path) -> dict:
SenseVoice:
https://help.aliyun.com/zh/model-studio/developer-reference/sensevoice-recorded-speech-recognition-restful-applicant
"""
- api_keys = [x.strip() for x in ASR.ALI_API_KEY.split(",") if x.strip()]
+ path = Path(path).expanduser().resolve()
+ if not path.is_file():
+ return {"texts": "", "error": "File not found."}
+ supported_ext = [".aac", ".amr", ".avi", ".flac", ".flv", ".m4a", ".mkv", ".mov", ".mp3", ".mp4", ".mpeg", ".oga", ".ogg", ".opus", ".wav", ".webm", ".wma", ".wmv"]
+ audio_path = path if path.suffix.lower() in supported_ext else await downsampe_audio(path, ext="opus", codec="libopus")
+ audio_path = await convert_single_channel(audio_path)
+ api_keys = strings_list(ASR.ALI_API_KEY, shuffle=True)
if not api_keys:
return {"error": "请配置阿里云语音识别的API Key"}
- models = [x.strip() for x in ASR.ALI_MODEL.split(",") if x.strip()]
- model = random.choice(models)
- logger.debug(f"阿里云ASR {path} via model: {model}")
- api_key = random.choice(api_keys)
- if model.startswith("paraformer-realtime-"):
- return await ali_realtime_asr(model, path, api_key)
-
- headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "X-DashScope-Async": "enable"}
- path = Path(path).expanduser().resolve()
- if ASR.ALI_FS_ENGINE.lower() == "local":
- url = FILE_SERVER.removesuffix("/") + "/" + path.name
- elif ASR.ALI_FS_ENGINE.lower() == "uguu":
- if path.stat().st_size > 100 * 1024 * 1024: # 100 MB
- path = await downsampe_audio(path)
- url = await upload_uguu(path) # max 100 MB for Uguu
- elif ASR.ALI_FS_ENGINE.lower() == "alist":
- url = await upload_alist(path)
- else:
- return {"error": f"Unsupported file server engine: {ASR.ALI_FS_ENGINE}"}
-
- payload = {"model": model, "input": {"file_urls": [url]}}
- res = await hx_req(
- "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription",
- method="POST",
- headers=headers,
- json_data=payload,
- timeout=600,
- check_keys=["output.task_id"],
- )
- if res.get("hx_error"):
- return {"error": res["hx_error"]}
- logger.success(f"ASR任务提交成功, TaskID: {res['output']['task_id']}")
- return await query_ali_asr(task_id=res["output"]["task_id"], api_key=api_key)
+ for api_key in api_keys:
+ for model in strings_list(ASR.ALI_MODEL, shuffle=True):
+ logger.debug(f"阿里云ASR {audio_path} via model: {model}")
+ if model.startswith("paraformer-realtime-"):
+ return await ali_realtime_asr(model, audio_path, api_key)
+
+ headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "X-DashScope-Async": "enable"}
+ if ASR.ALI_FS_ENGINE.lower() == "local":
+ url = FILE_SERVER.removesuffix("/") + "/" + path.name
+ elif ASR.ALI_FS_ENGINE.lower() == "uguu":
+ if audio_path.stat().st_size > 100 * 1024 * 1024: # 100 MB
+ audio_path = await downsampe_audio(audio_path)
+ url = await upload_uguu(audio_path) # max 100 MB for Uguu
+ elif ASR.ALI_FS_ENGINE.lower() == "alist":
+ url = await upload_alist(audio_path)
+ else:
+ return {"error": f"Unsupported file server engine: {ASR.ALI_FS_ENGINE}"}
+
+ payload = {"model": model, "input": {"file_urls": [url]}}
+ res = await hx_req(
+ "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription",
+ method="POST",
+ headers=headers,
+ json_data=payload,
+ timeout=600,
+ check_keys=["output.task_id"],
+ )
+ if res.get("hx_error"):
+ return {"error": res["hx_error"]}
+ logger.success(f"ASR任务提交成功, TaskID: {res['output']['task_id']}")
+ return await query_ali_asr(task_id=res["output"]["task_id"], api_key=api_key)
+ return {}
async def query_ali_asr(task_id: str, api_key: str, query_times: int = 0) -> dict:
src/asr/cloudflare.py
@@ -1,21 +1,29 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import asyncio
import base64
+from collections.abc import Coroutine
+from decimal import Decimal
from pathlib import Path
+from typing import Any
import anyio
+import soundfile as sf
from glom import glom
from loguru import logger
-from config import ASR
+from asr.groq import merge_transcripts
+from asr.utils import convert_single_channel, downsampe_audio
+from config import ASR, DOWNLOAD_DIR
from networking import hx_req
-from utils import seconds_to_time, strings_list
+from utils import rand_string, seconds_to_time, strings_list, zhcn
async def cloudflare_asr(
path: str | Path,
- model: str = "",
- prompt: str = "",
+ duration: float,
+ model: str | None = "",
+ prompt: str | None = "",
) -> dict:
"""Cloudflare ASR.
@@ -27,8 +35,32 @@ async def cloudflare_asr(
Returns:
{"texts": str, "error": str}
"""
- path = Path(path)
- res = {}
+ path = Path(path).expanduser().resolve()
+ if not path.is_file():
+ return {"texts": "", "error": "File not found."}
+ supported_ext = [".mp3", ".opus", ".ogg", ".oga", ".wav", ".flac", ".aac"]
+ audio_path = path if path.suffix.lower() in supported_ext else await downsampe_audio(path, ext="opus", codec="libopus")
+ audio_path = await convert_single_channel(audio_path)
+ # max allowed file size is 25MB
+ if audio_path.stat().st_size < ASR.CLOUDFLARE_MAX_BYTES:
+ return await cloudflare_single_file(audio_path, model=model, prompt=prompt)
+ return await cloudflare_file_chunks(audio_path, duration, model=model, prompt=prompt)
+
+
+async def cloudflare_single_file(
+ path: Path,
+ model: str | None = "",
+ prompt: str | None = "",
+ *,
+ offset_seconds: int = 0,
+ delete_local_file: bool = False,
+) -> dict:
+    """Transcribe a single audio chunk with the Cloudflare Workers AI API.
+
+ Returns:
+ {"texts": str, "raw_texts": str, "segments": list[dict], "error": str}
+ """
+ resp = {"texts": "", "raw_texts": "", "segments": []}
if not ASR.CLOUDFLARE_KEYS:
return {"error": "未配置Cloudflare相关API"}
if not model:
@@ -44,20 +76,92 @@ async def cloudflare_asr(
payload = {"audio": audio_base64, "task": "transcribe", "vad_filter": True}
if prompt:
payload["initial_prompt"] = prompt
- resp = await hx_req(url, "POST", headers=headers, json_data=payload, check_kv={"success": True}, check_keys=["result"])
- if texts := generate_transcription(glom(resp, "result.segments", default=[])):
- return {"texts": texts}
+ resp = await hx_req(
+ url,
+ "POST",
+ headers=headers,
+ json_data=payload,
+ proxy=ASR.CLOUDFLARE_PROXY,
+ check_kv={"success": True},
+ check_keys=["result"],
+ )
+ offset = Decimal(offset_seconds).quantize(Decimal(".01"))
+ resp["segments"] = [
+ {
+ "start": offset + Decimal(str(seg["start"])),
+ "end": offset + Decimal(str(seg["end"])),
+ "text": zhcn(seg.get("text", "")),
+ }
+ for seg in glom(resp, "result.segments", default=[])
+ ]
+ resp["raw_texts"] = " ".join(x["text"] for x in resp["segments"])
+ resp["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in resp["segments"]) # with timestamp
+ if resp.get("hx_error"):
+ resp["error"] = resp.pop("hx_error")
+ if delete_local_file:
+ path.unlink(missing_ok=True)
except Exception as e:
logger.error(e)
- return res
-
-
-def generate_transcription(items: list[dict]) -> str:
- res = ""
- for item in items:
- sentence: str = item["text"]
- if not sentence:
- continue
- start = seconds_to_time(item["start"])
- res += f"\n[{start}] {sentence}"
- return res.strip()
+ return resp
+ return resp
+
+
+async def cloudflare_file_chunks(
+ path: Path,
+ duration: float,
+ model: str | None = "",
+ prompt: str | None = "",
+ chunk_seconds: float = ASR.CLOUDFLARE_CHUNK_SECONDS,
+ overlap_seconds: float = ASR.CLOUDFLARE_OVERLAP_SECONDS,
+) -> dict:
+ """Transcribe audio in chunks with overlap.
+
+ Most of this code is copied from `gemini_file_chunks` in `asr/gemini.py`
+
+ Args:
+ chunk_seconds: Length of each chunk in seconds
+ overlap_seconds: Overlap between chunks in seconds
+
+ Returns:
+ dict: {"texts": str, "raw_texts": str, "segments": list[dict]}
+ """
+    # chunking only supports Ogg/Opus files
+ ogg_path = path if path.suffix in [".oga", ".ogg", ".opus"] else await downsampe_audio(path, ext="opus", codec="libopus")
+ with sf.SoundFile(ogg_path, "r") as f:
+ sr = f.samplerate
+ audio = f.read(dtype="float32")
+ logger.trace(f"音频时长: {duration:.2f}s, 采样率: {sr} Hz")
+
+ transcription = {}
+ semaphore = asyncio.Semaphore(30) # max concurrent requests
+
+ async def run_with_semaphore(task: Coroutine[Any, Any, dict]) -> dict:
+ async with semaphore:
+ return await task
+
+ try:
+ # Calculate # of chunks
+ total_chunks = (duration // (chunk_seconds - overlap_seconds)) + 1
+ total_chunks = int(total_chunks)
+ chunk_paths = [Path(DOWNLOAD_DIR) / f"{rand_string()}.opus" for _ in range(total_chunks)]
+ tasks = []
+ # Loop through each chunk, extract current chunk from audio, transcribe
+ for i in range(total_chunks):
+ start = int(i * (chunk_seconds - overlap_seconds) * sr)
+ end = int(min(start + chunk_seconds * sr, duration * sr))
+ logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
+ chunk = audio[start:end]
+ if chunk.shape[0] == 0: # empty chunk
+ continue
+ chunk_path = chunk_paths[i]
+ await asyncio.to_thread(sf.write, chunk_path.as_posix(), chunk, sr, format="ogg", subtype="OPUS")
+ task = cloudflare_single_file(chunk_path, model, prompt, offset_seconds=int(start / sr), delete_local_file=False)
+ tasks.append(run_with_semaphore(task))
+ results = await asyncio.gather(*tasks)
+ results = [r for r in results if r.get("segments")]
+ transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
+ [p.unlink(missing_ok=True) for p in chunk_paths]
+ except Exception as e:
+ logger.error(e)
+ return {"error": str(e)}
+ return transcription
src/asr/deepgram.py
@@ -7,9 +7,10 @@ import anyio
from glom import flatten, glom
from loguru import logger
+from asr.utils import convert_single_channel, downsampe_audio
from config import ASR
from networking import hx_req
-from utils import zhcn
+from utils import strings_list, zhcn
async def deepgram_asr(path: str | Path) -> dict:
@@ -17,12 +18,16 @@ async def deepgram_asr(path: str | Path) -> dict:
https://developers.deepgram.com/docs/pre-recorded-audio
"""
- api_keys = [x.strip() for x in ASR.DEEPGRAM_API.split(",") if x.strip()]
+ path = Path(path).expanduser().resolve()
+ if not path.is_file():
+ return {"texts": "", "error": "File not found."}
+ supported_ext = [".mp3", ".aac", ".flac", ".m4a", ".mp2", ".mp4", ".ogg", ".opus", ".oga", ".pcm", ".wav", ".webm"]
+ audio_path = path if path.suffix.lower() in supported_ext else await downsampe_audio(path, ext="opus", codec="libopus")
+ audio_path = await convert_single_channel(audio_path)
+ api_keys = strings_list(ASR.DEEPGRAM_API, shuffle=True)
if not api_keys:
return {"error": "请配置DeepGram语音识别的API Key"}
- logger.debug(f"DeepGram ASR {path}")
headers = {"Authorization": f"Token {random.choice(api_keys)}"}
- path = Path(path).expanduser().resolve()
url = "https://api.deepgram.com/v1/listen"
params = {"model": "nova-3-general", "detect_language": True, "punctuate": True, "smart_format": True}
async with await anyio.open_file(path, "rb") as f:
src/asr/gemini.py
@@ -44,7 +44,7 @@ async def gemini_asr(
path = Path(path).expanduser().resolve()
if not path.is_file():
return {"texts": "", "error": "File not found."}
- audio_path = path if path.suffix.lower() in GEMINI_AUDIO_EXT else await downsampe_audio(path, ext="ogg", codec="libopus")
+ audio_path = path if path.suffix.lower() in GEMINI_AUDIO_EXT else await downsampe_audio(path, ext="opus", codec="libopus")
audio_path = await convert_single_channel(audio_path)
duration = audio_duration(audio_path)
if duration < ASR.GEMINI_CHUNK_SECONDS:
src/asr/groq.py
@@ -12,10 +12,10 @@ import soundfile as sf
from glom import glom
from loguru import logger
-from asr.utils import COMMON_AUDIO_EXT, convert_single_channel, downsampe_audio
+from asr.utils import convert_single_channel, downsampe_audio
from config import ASR
from networking import hx_req
-from utils import seconds_to_time, strings_list
+from utils import seconds_to_time, strings_list, zhcn
async def groq_asr(path: str | Path, model: str = "", prompt: str = "", temperature: float = 0, language: str = "") -> dict:
@@ -29,7 +29,8 @@ async def groq_asr(path: str | Path, model: str = "", prompt: str = "", temperat
path = Path(path).expanduser().resolve()
if not path.is_file():
return {"texts": "", "error": "File not found."}
- audio_path = path if path.suffix.lower() in COMMON_AUDIO_EXT else await downsampe_audio(path, ext="ogg", codec="libopus")
+ supported_ext = [".aac", ".flac", ".m4a", ".mp3", ".mpeg", ".mpga", ".oga", ".ogg", ".opus", ".wav", ".webm"]
+ audio_path = path if path.suffix.lower() in supported_ext else await downsampe_audio(path, ext="opus", codec="libopus")
audio_path = await convert_single_channel(audio_path)
# max allowed file size is 25MB
if audio_path.stat().st_size < ASR.GROQ_MAX_BYTES:
@@ -85,7 +86,7 @@ async def groq_single_file(
{
"start": start + Decimal(str(seg["start"])),
"end": start + Decimal(str(seg["end"])),
- "text": seg["text"],
+ "text": zhcn(seg["text"]),
}
for seg in resp.get("segments", [])
]
@@ -278,12 +279,12 @@ async def groq_file_chunks(
path = Path(path).expanduser().resolve()
if not path.is_file():
return {"texts": "", "error": "File not found."}
- if path.suffix.lower() not in [".opus", ".ogg"]:
- path = await downsampe_audio(path, ext="ogg", codec="libopus")
+ if path.suffix.lower() not in [".opus", ".ogg", ".oga"]:
+ path = await downsampe_audio(path, ext="opus", codec="libopus")
try:
with sf.SoundFile(path, "r") as f:
if f.channels != 1:
- path = await downsampe_audio(path, ext="ogg", codec="libopus")
+ path = await downsampe_audio(path, ext="opus", codec="libopus")
return await groq_file_chunks(path, chunk_seconds, overlap_seconds, model, temperature, prompt, language)
sr = f.samplerate
audio = f.read(dtype="float32")
src/asr/tecent.py
@@ -4,18 +4,23 @@ import asyncio
import base64
import hashlib
import hmac
+from collections.abc import Coroutine
+from decimal import Decimal
from pathlib import Path
+from typing import Any
import anyio
+import soundfile as sf
from glom import Coalesce, flatten, glom
from loguru import logger
-from asr.utils import downsampe_audio, is_english_word
-from config import ASR, FILE_SERVER
+from asr.groq import merge_transcripts
+from asr.utils import audio_duration, convert_single_channel, downsampe_audio, is_english_word
+from config import ASR, DOWNLOAD_DIR, FILE_SERVER
from database.alist import delete_alist, upload_alist
from database.uguu import upload_uguu
from networking import hx_req
-from utils import nowdt
+from utils import nowdt, rand_string, seconds_to_time
def sign(key, msg):
@@ -71,16 +76,58 @@ def generate_tencent_cloud_headers(
}
-async def tencent_single_asr(path: str | Path, engine: str, voice_format: str) -> dict:
+async def tencent_asr(path: str | Path, language: str, duration: float) -> dict:
+ """Tencent ASR.
+
+ 由于 `录音文件识别` 和 `录音文件识别极速版`免费额度太少
+ 所以现在我们只使用 `一句话识别` 来处理所有ASR请求
+
+ Returns:
+ {"texts": str, "raw_texts": str, "segments": list[dict], "error": str}
+ """
+ path = Path(path).expanduser().resolve()
+ if not path.is_file():
+ return {"texts": "", "error": "File not found."}
+ supported_ext = [".wav", ".pcm", ".ogg", ".opus", ".oga", ".speex", ".silk", ".mp3", ".m4a", ".aac", ".amr"]
+ audio_path = path if path.suffix.lower() in supported_ext else await downsampe_audio(path, ext="opus", codec="libopus")
+ audio_path = await convert_single_channel(audio_path)
+    if duration < 1:  # something went wrong while detecting the duration
+ audio_path = await downsampe_audio(path, ext="opus", codec="libopus")
+ duration = audio_duration(audio_path)
+
+ # max allowed duration is 60s
+ if duration < 60:
+ return await tencent_single_asr(audio_path, language=language)
+ return await tencent_file_chunks(audio_path, language=language, duration=duration)
+
+
+async def tencent_single_asr(path: Path, language: str, *, offset_seconds: int = 0, delete_local_file: bool = False) -> dict:
"""Tencent Single Sentence ASR.
- 一句话识别
+ 一句话识别 (每月免费额度: 5000次)
https://cloud.tencent.com/document/product/1093/35646
+
+ Returns:
+ {"texts": str, "raw_texts": str, "segments": list[dict], "error": str}
+
+ example item of segments:
+ {
+ "start": Decimal,
+ "end": Decimal,
+ "text": str,
+ }
"""
- async with await anyio.open_file(path, "rb") as f:
+ final = {"texts": "", "raw_texts": "", "segments": []}
+ file_bytes = path.stat().st_size
+ # max 3 MB
+ audio_path = path if file_bytes < 3 * 1024 * 1024 else await downsampe_audio(path, ext="opus", codec="libopus")
+ voice_format = Path(audio_path).suffix.lower().lstrip(".")
+    if voice_format in ["ogg", "opus", "oga"]:  # Tencent only supports ogg-opus
+ voice_format = "ogg-opus"
+ async with await anyio.open_file(audio_path, "rb") as f:
content = await f.read()
data = base64.b64encode(content).decode("utf-8")
- payload = f'{{"EngSerViceType":"{engine}","SourceType":1,"WordInfo":2,"VoiceFormat":"{voice_format}","Data":"{data}"}}'
+ payload = f'{{"EngSerViceType":"{language}","SourceType":1,"WordInfo":2,"VoiceFormat":"{voice_format}","Data":"{data}"}}'
headers = generate_tencent_cloud_headers(action="SentenceRecognition", payload=payload)
res = await hx_req(
"https://asr.tencentcloudapi.com",
@@ -91,18 +138,109 @@ async def tencent_single_asr(path: str | Path, engine: str, voice_format: str) -
proxy=ASR.TENCENT_PROXY,
check_keys=["Response.WordList"],
)
- if res["Response"]["WordList"] is None:
- return {"error": "⚠️该音频未识别到文字"}
if res.get("hx_error"):
- return {"error": res["hx_error"]}
+ return final | {"error": res["hx_error"]}
+ words = glom(res, "Response.WordList", default=None)
+ if words is None:
+ return final | {"error": "⚠️该音频未识别到文字"}
+ final["raw_texts"] = glom(res, "Response.Result", default="") or ""
+
+ sentences = [] # list of sentence
+ sentence = [] # list of dict
+ for item in words:
+ word = item.get("Word", "")
+ if is_english_word(word) or word.endswith((",", ".", "?", "!")):
+ item["Word"] = word + " "
+ if word.endswith((".", "。", "?", "?", "!", "!")): # noqa: RUF001
+ sentence.append(item)
+ sentences.append(sentence)
+ sentence = []
+ continue
+ sentence.append(item)
+ if sentence:
+ sentences.append(sentence)
+
+ segments = []
+ offset = Decimal(offset_seconds).quantize(Decimal(".01"))
+ for sentence in sentences:
+ start = offset + Decimal(sentence[0].get("StartTime", 0)) / 1000
+ end = offset + Decimal(sentence[-1].get("EndTime", 0)) / 1000
+ text = "".join(x.get("Word", "") for x in sentence)
+ text = text.replace(" ,", ",").replace(" ,", ",") # noqa: RUF001
+ segments.append({"start": start, "end": end, "text": text})
+ final["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in segments) # with timestamp
+ final["segments"] = segments
+ if delete_local_file:
+ path.unlink(missing_ok=True)
+ audio_path.unlink(missing_ok=True)
+ return final
+
- return generate_tencent_transcription(sentence_start_ms=[0], words=[res["Response"]["WordList"]])
+async def tencent_file_chunks(
+ path: Path,
+ language: str,
+ duration: float,
+ chunk_seconds: float = 60,
+ overlap_seconds: float = 5,
+) -> dict:
+ """Transcribe audio in chunks with overlap.
+
+ Most of this code is copied from `gemini_file_chunks` in `asr/gemini.py`
+
+ Args:
+ chunk_seconds: Length of each chunk in seconds
+ overlap_seconds: Overlap between chunks in seconds
+
+ Returns:
+ dict: {"texts": str, "raw_texts": str, "segments": list[dict]}
+ """
+    # chunking only supports Ogg/Opus files
+ ogg_path = path if path.suffix in [".oga", ".ogg", ".opus"] else await downsampe_audio(path, ext="opus", codec="libopus")
+ with sf.SoundFile(ogg_path, "r") as f:
+ sr = f.samplerate
+ audio = f.read(dtype="float32")
+ logger.trace(f"音频时长: {duration:.2f}s, 采样率: {sr} Hz")
+
+ transcription = {}
+ semaphore = asyncio.Semaphore(30) # max concurrent requests
+
+ async def run_with_semaphore(task: Coroutine[Any, Any, dict]) -> dict:
+ async with semaphore:
+ return await task
+
+ try:
+ # Calculate # of chunks
+ total_chunks = (duration // (chunk_seconds - overlap_seconds)) + 1
+ total_chunks = int(total_chunks)
+ chunk_paths = [Path(DOWNLOAD_DIR) / f"{rand_string()}.opus" for _ in range(total_chunks)]
+ tasks = []
+ # Loop through each chunk, extract current chunk from audio, transcribe
+ for i in range(total_chunks):
+ start = int(i * (chunk_seconds - overlap_seconds) * sr)
+ end = int(min(start + chunk_seconds * sr, duration * sr))
+ logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
+ chunk = audio[start:end]
+ if chunk.shape[0] == 0: # empty chunk
+ continue
+ chunk_path = chunk_paths[i]
+ await asyncio.to_thread(sf.write, chunk_path.as_posix(), chunk, sr, format="ogg", subtype="OPUS")
+ task = tencent_single_asr(chunk_path, language=language, offset_seconds=int(start / sr), delete_local_file=False)
+ tasks.append(run_with_semaphore(task))
+ results = await asyncio.gather(*tasks)
+ results = [r for r in results if r.get("segments")]
+ transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
+ [p.unlink(missing_ok=True) for p in chunk_paths]
+ except Exception as e:
+ logger.error(e)
+ return {"error": str(e)}
+ return transcription
async def tencent_flash_asr(path: str | Path, engine: str, voice_format: str) -> dict:
- """Tencent Flash ASR.
+ """(Deprecated) Tencent Flash ASR.
- 录音文件识别极速版
+ 已弃用, 请使用 `tencent_single_asr`
+ 录音文件识别极速版 (每月免费额度: 5小时)
https://cloud.tencent.com/document/product/1093/52097
"""
now = nowdt()
@@ -143,9 +281,11 @@ async def tencent_flash_asr(path: str | Path, engine: str, voice_format: str) ->
async def tencent_async_asr(path: str | Path, engine: str) -> dict:
- """Create Tencent ASR Task.
+ """(Deprecated) Create Tencent ASR Task.
- 录音文件识别请求
+ 已弃用, 请使用 `tencent_single_asr`
+ 注意: 此接口不支持中文文件名
+ 录音文件识别请求 (每月免费额度: 10小时)
https://cloud.tencent.com/document/api/1093/37823
"""
path = Path(path).expanduser().resolve()
src/asr/utils.py
@@ -12,79 +12,61 @@ from soundfile import LibsndfileError
from config import ASR, GEMINI
from multimedia import convert_to_audio
+from utils import strings_list
-ALI_AUDIO_EXT = [".aac", ".amr", ".avi", ".flac", ".flv", ".m4a", ".mkv", ".mov", ".mp3", ".mp4", ".mpeg", ".ogg-opus", ".ogg", ".opus", ".wav", ".webm", ".wma", ".wmv"]
GEMINI_AUDIO_EXT = [".aac", ".aiff", ".flac", ".mp3", ".oga", ".ogg", ".opus", ".wav"]
-DEEPGRAM_AUDIO_EXT = [".mp3", ".aac", ".flac", ".m4a", ".mp2", ".mp4", ".ogg", ".opus", ".ogg-opus", ".pcm", ".wav", ".webm"]
-TENCENT_AUDIO_EXT = [".aac", ".amr", ".m4a", ".mp3", ".oga", ".ogg-opus", ".ogg", ".opus", ".pcm", ".silk", ".speex", ".wav"]
-TENCENT_ASYNC_AUDIO_EXT = [".3gp", ".aac", ".amr", ".flac", ".flv", ".m4a", ".mp3", ".mp4", ".oga", ".ogg-opus", ".ogg", ".opus", ".wav", ".wma"]
-COMMON_AUDIO_EXT = [".mp3", ".opus", ".ogg", ".wav", ".flac", ".aac"]
-
-
-def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> tuple[str, list[str]]:
- """Get ASR method and supported file types."""
- if duration < ASR.SHORT_DURATION:
- asr_engine = random.choice([x.strip() for x in ASR.SHORT_ENGINE.split(",") if x.strip()])
- elif ASR.SHORT_DURATION <= duration <= ASR.MIDDLE_DURATION:
- asr_engine = random.choice([x.strip() for x in ASR.MIDDLE_ENGINE.split(",") if x.strip()])
- else:
- asr_engine = random.choice([x.strip() for x in ASR.LONG_ENGINE.split(",") if x.strip()])
-
- # respect force_engine
- if force_engine == "ali":
- return get_ali_asr_method()
- if force_engine == "deepgram":
- return "deepgram", [x.lstrip(".") for x in DEEPGRAM_AUDIO_EXT]
- if force_engine == "tencent":
- return get_tencent_asr_method(duration, file_size)
- if force_engine == "gemini":
- return get_gemini_asr_method(duration)
- if force_engine in ["cloudflare", "groq"]:
- return force_engine.lower(), [x.lstrip(".") for x in COMMON_AUDIO_EXT]
-
- if asr_engine == "ali":
- return get_ali_asr_method()
- if asr_engine == "deepgram":
- return "deepgram", [x.lstrip(".") for x in DEEPGRAM_AUDIO_EXT]
- if asr_engine == "tencent":
- return get_tencent_asr_method(duration, file_size)
- if asr_engine.lower() == "gemini":
- return get_gemini_asr_method(duration)
- if asr_engine.lower() in ["cloudflare", "groq"]:
- return asr_engine.lower(), [x.lstrip(".") for x in COMMON_AUDIO_EXT]
- return f"ASR Engine: {asr_engine} is not support for duration: {duration}, filesize: {file_size}", []
-
-
-def get_ali_asr_method() -> tuple[str, list[str]]:
- if not all([ASR.ALI_MODEL, ASR.ALI_API_KEY]):
- return "请设置阿里云ASR相关环境变量", []
- supported_ext = [x.lstrip(".") for x in ALI_AUDIO_EXT]
- return "ali", supported_ext
-
-
-def get_tencent_asr_method(duration: float, file_size: int) -> tuple[str, list[str]]:
- if not all([ASR.TENCENT_APPID, ASR.TENCENT_SECRET_ID, ASR.TENCENT_SECRET_KEY]):
- return "请设置Tencent ASR相关环境变量", []
-
- asr_method = ""
- if duration < 60 and file_size < 3 * 1024 * 1024:
- asr_method = "tencent_single_asr" # 一句话识别
- supported_ext = [x.lstrip(".") for x in TENCENT_AUDIO_EXT]
- elif 60 <= duration <= 300 and file_size < 100 * 1024 * 1024:
- asr_method = "tencent_flash_asr" # 录音文件识别极速版
- supported_ext = [x.lstrip(".") for x in TENCENT_AUDIO_EXT]
- else:
- asr_method = "tencent_async_asr" # 录音文件识别 (异步请求)
- supported_ext = [x.lstrip(".") for x in TENCENT_ASYNC_AUDIO_EXT]
- return asr_method, supported_ext
-
-
-def get_gemini_asr_method(duration: float | None = None) -> tuple[str, list[str]]:
- if duration is not None and duration > GEMINI.ASR_MAX_DURATION:
- return f"无法识别时长超过{GEMINI.ASR_MAX_DURATION}秒的音频, 当前音频时长: {duration}秒", []
- if not GEMINI.API_KEY:
- return "请设置`GEMINI_API_KEY`环境变量", []
- return "gemini", [x.lstrip(".") for x in GEMINI_AUDIO_EXT]
+
+
+def auto_choose_asr_engine(duration: float, engine: str) -> str:
+ """Get ASR engine based on duration or category."""
+ all_engines = ["ali", "tencent", "cloudflare", "groq", "gemini", "deepgram"]
+ categries = {
+ "whisper": ["cloudflare", "groq"],
+ "china": ["ali", "tencent"],
+ "uncensored": ["cloudflare", "groq", "gemini"],
+ }
+
+ def get_enabled_engines() -> list[str]:
+ enabled_engines = []
+ if all([ASR.ALI_API_KEY, ASR.ALI_MODEL, ASR.ALI_FS_ENGINE]):
+ enabled_engines.append("ali")
+ if all([ASR.TENCENT_APPID, ASR.TENCENT_SECRET_ID, ASR.TENCENT_SECRET_KEY, ASR.TENCENT_FS_ENGINE]):
+ enabled_engines.append("tencent")
+ if all([ASR.CLOUDFLARE_MODEL, ASR.CLOUDFLARE_KEYS, ASR.CLOUDFLARE_MAX_BYTES, ASR.CLOUDFLARE_CHUNK_SECONDS]):
+ enabled_engines.append("cloudflare")
+ if all([GEMINI.ASR_MODEL, GEMINI.API_KEY, GEMINI.BASE_URL, ASR.GEMINI_CHUNK_SECONDS]):
+ enabled_engines.append("gemini")
+ if all([ASR.GROQ_MODELS, ASR.GROQ_KEYS, ASR.GROQ_MAX_BYTES, ASR.GROQ_CHUNK_SECONDS]):
+ enabled_engines.append("groq")
+ if ASR.DEEPGRAM_API:
+ enabled_engines.append("deepgram")
+ return enabled_engines
+
+ def parse_engines(eng: str) -> list[str]:
+ res = []
+ for x in strings_list(eng.lower()):
+ if x in all_engines:
+ res.append(x)
+ elif x in categries:
+ res.extend(categries[x])
+ enabled_engines = get_enabled_engines()
+ return [x for x in res if x in enabled_engines]
+
+ fallback_engine = "gemini" # fallback if no match
+ if not engine:
+ return fallback_engine
+
+ if engine.lower() == "auto":
+ if duration < ASR.SHORT_DURATION:
+ engines = parse_engines(ASR.SHORT_ENGINE)
+ elif ASR.SHORT_DURATION <= duration <= ASR.MIDDLE_DURATION:
+ engines = parse_engines(ASR.MIDDLE_ENGINE)
+ else:
+ engines = parse_engines(ASR.LONG_ENGINE)
+ return random.choice(engines) if engines else fallback_engine
+
+ engines = parse_engines(engine)
+ return random.choice(engines) if engines else fallback_engine
async def downsampe_audio(path: str | Path, ext: str = "opus", codec: str = "libopus", sample_rate: int = 16000, channel: int = 1, **kwargs) -> Path:
@@ -98,14 +80,22 @@ def is_english_word(text: str) -> bool:
return bool(re.match(r"^[a-zA-Z]+$", text))
+async def get_audio_channel(path: str | Path) -> int:
+ with contextlib.suppress(Exception), sf.SoundFile(path, "r") as f:
+ return f.channels
+ with contextlib.suppress(Exception):
+ ffprobe = FFmpeg(executable="ffprobe").input(Path(path).as_posix(), print_format="json", show_streams=None)
+ metadata = json.loads(ffprobe.execute())
+ streams = metadata.get("streams", [])
+ return len(streams)
+ return -1
+
+
async def convert_single_channel(path: str | Path) -> Path:
path = Path(path).expanduser().resolve()
- try:
- with sf.SoundFile(path, "r") as f:
- if f.channels != 1:
- return await downsampe_audio(path, ext="ogg", codec="libopus", channel=1)
- except LibsndfileError:
- return await downsampe_audio(path, ext="ogg", codec="libopus", channel=1)
+ num_channel = await get_audio_channel(path)
+ if num_channel != 1:
+ return await downsampe_audio(path, ext="opus", codec="libopus", channel=1)
return path
src/asr/voice_recognition.py
@@ -1,9 +1,9 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import io
-import re
from pathlib import Path
+from glom import glom
from loguru import logger
from pyrogram.client import Client
from pyrogram.enums import ParseMode
@@ -14,18 +14,15 @@ from asr.cloudflare import cloudflare_asr
from asr.deepgram import deepgram_asr
from asr.gemini import gemini_asr
from asr.groq import groq_asr
-from asr.tecent import tencent_async_asr, tencent_flash_asr, tencent_single_asr
-from asr.utils import get_asr_method
-from config import CAPTION_LENGTH, PREFIX, TEXT_LENGTH
+from asr.tecent import tencent_asr
+from asr.utils import audio_duration, auto_choose_asr_engine
+from config import ASR, CAPTION_LENGTH, PREFIX, TEXT_LENGTH
from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg
from messages.utils import blockquote, count_without_entities, delete_message, equal_prefix, get_reply_to, startswith_prefix
-from multimedia import convert_to_audio, parse_media_info
from publish import publish_telegraph
-from utils import rand_string, to_int
-
-# ruff: noqa: RUF001
+from utils import readable_time, to_int
# https://cloud.tencent.com/document/product/1093/52097
HELP = f"""🗣**语音转文字**
@@ -53,7 +50,7 @@ fr: 法语
de: 德语
"""
-ENGINE_MAP = {
+LANG_MAP = {
"16k_zh-PY": "中英粤",
"16k_fy": "多种方言, 上海话、四川话、武汉话、贵阳话、昆明话、西安话、郑州话、太原话、兰州话、银川话、西宁话、南京话、合肥话、南昌话、长沙话、苏州话、杭州话、济南话、天津话、石家庄话、黑龙江话、吉林话、辽宁话",
"16k_ja": "日语",
@@ -73,14 +70,36 @@ ENGINE_MAP = {
}
+def get_msg_to_asr(message: Message, *, asr_need_prefix: bool = True) -> Message | None:
+ """Get the message to be recognized by ASR.
+
+ By default, "/asr" prefix is needed to trigger ASR function.
+ """
+ # skip no "/asr" prefix message if asr_need_prefix
+ if asr_need_prefix and not startswith_prefix(message.content, prefix=PREFIX.ASR):
+ return None
+ # no need prefix or has "/asr" prefix
+
+ mtype = glom(message, "media.value", default="text") or "text"
+ # has "/asr" prefix
+ if startswith_prefix(message.content, prefix=PREFIX.ASR):
+ if mtype in ["voice", "audio", "video"]:
+ return message
+ if reply_msg := message.reply_to_message:
+ reply_mtype = glom(reply_msg, "media.value", default="text") or "text"
+ if reply_mtype in ["voice", "audio", "video"]:
+ return reply_msg
+ elif mtype == "voice": # no need "/asr" prefix
+ return message
+ return None
+
+
async def voice_to_text(
client: Client,
message: Message,
+ asr_engine: str = ASR.DEFAULT_ENGINE,
*,
- asr_need_prefix: bool | None = None,
- asr_skip_voice: bool | None = None,
- asr_skip_audio: bool | None = None,
- asr_skip_video: bool | None = None,
+ asr_need_prefix: bool = True,
to_telegraph: bool = True,
**kwargs,
) -> None:
@@ -92,72 +111,63 @@ async def voice_to_text(
Args:
client (Client): The Pyrogram client.
message (Message): The trigger message object.
- asr_need_prefix (bool, optional): If True, must use "/asr" prefix to reply a audio message.
- asr_skip_voice (bool, optional): If True, skip voice message.
- asr_skip_audio (bool, optional): If True, skip audio message.
- asr_skip_video (bool, optional): If True, skip video message.
+        asr_need_prefix (bool, optional): If True, the "/asr" prefix must be prepended to call the ASR function.
+ to_telegraph (bool, optional): If True, publish the result to Telegraph.
+
"""
# send docs if message == "/asr", without reply
- if equal_prefix(message.text, prefix=[PREFIX.ASR]) and not message.reply_to_message:
+ if equal_prefix(message.text, prefix=PREFIX.ASR) and not message.reply_to_message:
await send2tg(client, message, texts=HELP, **kwargs)
return
- trigger_message = get_trigger_message(
- message,
- asr_need_prefix=asr_need_prefix,
- asr_skip_voice=asr_skip_voice,
- asr_skip_audio=asr_skip_audio,
- asr_skip_video=asr_skip_video,
- )
- if not trigger_message:
+
+ msg_to_asr = get_msg_to_asr(message, asr_need_prefix=asr_need_prefix)
+ if not msg_to_asr:
return
this_info = parse_msg(message, silent=True)
- trigger_info = parse_msg(trigger_message, silent=True)
-
- asr_language = "16k_zh-PY" # default: 中英粤
- force_engine = "" # gemini or tencent
- if matched := re.match(r"/asr\s+([^.。,,/\s]+)", this_info["text"]): # /asr yue
- custom_code = matched.group(1)
- if custom_code == "fy": # re-map dialect
- custom_code = "zh_dialect"
- custom_code = custom_code.replace("fy", "zh_dialect")
- if f"16k_{custom_code}" in ENGINE_MAP:
- asr_language = f"16k_{custom_code}"
- elif custom_code in ["gemini", "tencent", "ali", "deepgram", "cloudflare", "groq"]:
- force_engine = custom_code
-
- msg = f"[ASR] 收到消息: {trigger_info['mtype']}, 开始识别..."
+ asr_msg_info = parse_msg(msg_to_asr, silent=True)
+
+ remain_text = this_info["text"].removeprefix(PREFIX.ASR).strip().lower()
+ tencent_language = "16k_zh-PY" # default: 中英粤
+ if remain_text in ["fy", "ja", "ko", "vi", "ms", "id", "fil", "th", "pt", "tr", "ar", "es", "hi", "fr", "de"]:
+ # tencent asr
+ asr_engine = "tencent"
+ tencent_language = f"16k_{remain_text}".replace("fy", "zh_dialect")
+
+ elif remain_text:
+ asr_engine = remain_text
+ msg = f"[ASR] 收到消息: {asr_msg_info['mtype']}, 开始下载..."
logger.info(msg)
if kwargs.get("show_progress"):
- res = await send2tg(client, trigger_message, texts=msg, **kwargs)
+ res = await send2tg(client, msg_to_asr, texts=msg, **kwargs)
kwargs["progress"] = res[0]
- path: str | Path = await trigger_message.download() # type: ignore
+ path: str | Path = await msg_to_asr.download() # type: ignore
path = Path(path).expanduser().resolve()
if not path.is_file():
- msg = "Failed to download audio, please try again later."
+ msg = f"❌下载 {asr_msg_info['mtype']} 文件失败, 无法识别"
logger.error(msg)
await modify_progress(text=msg, force_update=True, **kwargs)
return
- res = await asr_file(path, engine=force_engine, duration=trigger_info["duration"], language=asr_language, client=client, message=trigger_message, **kwargs)
+ res = await asr_file(path, engine=asr_engine, duration=asr_msg_info["duration"], tencent_language=tencent_language, client=client, message=msg_to_asr, **kwargs)
if error := res.get("error"):
- await modify_progress(text=error, force_update=True, **kwargs)
+ await modify_progress(kwargs.get("progress"), text=error, force_update=True)
return
if texts := res.get("texts"):
final = blockquote(texts) if len(texts) > 300 else texts
# send results
target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else this_info["cid"]
- reply_parameters = get_reply_to(trigger_info["mid"], kwargs.get("reply_msg_id", 0))
+ reply_parameters = get_reply_to(asr_msg_info["mid"], kwargs.get("reply_msg_id", 0))
length = await count_without_entities(final)
if length < CAPTION_LENGTH: # short
- await client.copy_message(chat_id=to_int(target_chat), from_chat_id=trigger_info["cid"], message_id=trigger_info["mid"], caption=final, reply_parameters=reply_parameters)
+ await client.copy_message(chat_id=to_int(target_chat), from_chat_id=asr_msg_info["cid"], message_id=asr_msg_info["mid"], caption=final, reply_parameters=reply_parameters)
elif length < TEXT_LENGTH: # middle
await client.send_message(to_int(target_chat), final, reply_parameters=reply_parameters)
else: # long
- caption = trigger_info["html"]
+ caption = asr_msg_info["html"]
if to_telegraph:
html = "\n".join([f"<p>{s}</p>" for s in texts.split("\n")])
- if telegraph_url := await publish_telegraph(title=trigger_info["text"] or "语音识别结果", html=html, author=trigger_info["full_name"], url=trigger_info["message_url"]):
+ if telegraph_url := await publish_telegraph(title=asr_msg_info["text"] or "语音识别结果", html=html, author=asr_msg_info["full_name"], url=asr_msg_info["message_url"]):
caption += f"\n<a href={telegraph_url}>⚡️即时预览</a>"
with io.BytesIO(texts.encode("utf-8")) as f:
await client.send_document(
@@ -178,9 +188,9 @@ async def voice_to_text(
async def asr_file(
path: str | Path,
engine: str = "",
- duration: int = 0,
- language: str = "16k_zh-PY",
+ duration: float = 0,
*,
+ tencent_language: str = "16k_zh-PY",
delete_local_file: bool = True,
delete_gemini_file: bool = True,
**kwargs,
@@ -189,42 +199,21 @@ async def asr_file(
path = Path(path).expanduser().resolve()
if not path.is_file():
return {"error": f"{path} is not exist"}
- info = await parse_media_info(path)
- if duration == 0:
- duration = info["duration"]
- asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size, force_engine=engine)
- if asr_method not in ["ali", "deepgram", "cloudflare", "tencent_single_asr", "tencent_flash_asr", "tencent_async_asr", "gemini", "groq"]:
- return {"error": asr_method}
-
- voice_format = path.suffix.lstrip(".")
- if voice_format not in supported_ext:
- if info["audio_codec"].split("/")[-1] in supported_ext and not info["video_codec"]:
- voice_format = info["audio_codec"].split("/")[-1]
- else:
- path = await convert_to_audio(path, ext="opus", codec="libopus")
- voice_format = "opus"
- # match again based on converted file
- asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size, force_engine=engine)
+ duration = audio_duration(path)
+ engine = auto_choose_asr_engine(duration=duration, engine=engine)
- ogg_names = ["oga", "ogg-opus", "ogg", "opus"] # unify format name
- if asr_method.startswith("tencent") and voice_format in ogg_names:
- voice_format = "ogg-opus"
- path = path.rename(path.with_stem(rand_string())) # sanitize filename. (for Tencent Signature v3)
-
- logger.debug(f"[{asr_method}] Recognizing {voice_format} audio [{duration}s] by {language}: {path.as_posix()}")
+ log = f"{engine.capitalize()} ASR, 时长: {readable_time(duration)}\n{path.name}"
+ logger.debug(log)
+ await modify_progress(message=kwargs.get("progress"), text=log, force_update=True)
res = {}
try:
- if asr_method == "tencent_single_asr":
- res = await tencent_single_asr(path, language, voice_format)
- elif asr_method == "tencent_flash_asr":
- res = await tencent_flash_asr(path, language, voice_format)
- elif asr_method == "tencent_async_asr":
- res = await tencent_async_asr(path, language)
- elif asr_method == "ali":
+ if engine == "tencent":
+ res = await tencent_asr(path, tencent_language, duration)
+ elif engine == "ali":
res = await ali_asr(path)
- elif asr_method == "deepgram":
+ elif engine == "deepgram":
res = await deepgram_asr(path)
- elif asr_method == "gemini":
+ elif engine == "gemini":
res = await gemini_asr(
message=kwargs["message"],
path=path,
@@ -232,9 +221,9 @@ async def asr_file(
prompt=kwargs.get("gemini_asr_prompt", ""),
delete_gemini_file=delete_gemini_file,
)
- elif asr_method == "cloudflare":
- res = await cloudflare_asr(path=path, model=kwargs.get("cf_asr_model", ""), prompt=kwargs.get("cf_asr_prompt", ""))
- elif asr_method == "groq":
+ elif engine == "cloudflare":
+ res = await cloudflare_asr(path, duration, model=kwargs.get("cf_asr_model"), prompt=kwargs.get("cf_asr_prompt"))
+ elif engine == "groq":
res = await groq_asr(path=path, model=kwargs.get("groq_asr_model", ""), prompt=kwargs.get("groq_asr_prompt", ""))
else:
return {"error": "ASR method not supported"}
@@ -250,57 +239,3 @@ async def asr_file(
elif path.is_file():
res["audio_file"] = path
return res
-
-
-def get_trigger_message(
- message: Message,
- *,
- asr_need_prefix: bool | None = None,
- asr_skip_voice: bool | None = None,
- asr_skip_audio: bool | None = None,
- asr_skip_video: bool | None = None,
-) -> Message | None:
- """Check if the message is triggerable for voice recognition.
-
- By default, "/asr" prefix is needed in in Group & Channel & Bot chats to trigger this function.
- In private chat, no need to add "/asr" prefix for voice message, but the video & audio message still need it.
- """
- info = parse_msg(message)
- this_text = info["text"] # this message
- if info["ctype"].lower() in ["group", "supergroup", "channel", "bot"]:
- asr_need_prefix = asr_need_prefix if asr_need_prefix is not None else True
- asr_skip_voice = asr_skip_voice or False
- asr_skip_audio = asr_skip_audio or False
- asr_skip_audio = asr_skip_audio or False
- else: # private chat
- asr_need_prefix = asr_need_prefix or False
- asr_skip_voice = asr_skip_voice or False
- asr_skip_audio = asr_skip_audio if asr_skip_audio is not None else True
- asr_skip_video = asr_skip_video if asr_skip_video is not None else True
-
- # skip no "/asr" prefix message if asr_need_prefix
- if asr_need_prefix and not startswith_prefix(this_text, prefix=[PREFIX.ASR]):
- return None
-
- # treat the reply_to_message as the real message need to be recognized
- trigger_msg = message.reply_to_message if message.reply_to_message and startswith_prefix(this_text, prefix=[PREFIX.ASR]) else message
- trigger_info = parse_msg(trigger_msg, silent=True)
-
- # skip non voice/audio/video message
- if not trigger_msg:
- return None
- if trigger_info["mtype"] not in ["voice", "audio", "video"]:
- return None
-
- # always trigger if the message has "/asr" prefix
- if startswith_prefix(this_text, prefix=[PREFIX.ASR]):
- return trigger_msg
-
- # match the asr_skip_* settings
- if asr_skip_voice and trigger_info["mtype"] == "voice":
- return None
- if asr_skip_audio and trigger_info["mtype"] == "audio":
- return None
- if asr_skip_video and trigger_info["mtype"] == "video":
- return None
- return trigger_msg
src/others/podcast.py
@@ -27,7 +27,7 @@ from llm.utils import convert_html, convert_md, remove_consecutive_newlines
from messages.sender import send2tg
from networking import download_file, hx_req
from publish import publish_telegraph
-from utils import bare_url, count_subtitles, https_url, nowdt, rand_number, rand_string, readable_time
+from utils import bare_url, count_subtitles, https_url, nowdt, rand_number, rand_string, readable_time, strings_list
HEADERS = {
"User-Agent": "feedparser/6.0.11 +https://github.com/kurtmckee/feedparser/",
@@ -179,10 +179,27 @@ async def get_all_pods() -> dict[str, str]:
def get_pod_asr_engine(feed_title: str, feed_url: str) -> str:
- if feed_title in [x.strip() for x in PODCAST.ASR_FORCE_GEMINI_TITLES.split(",") if x.strip()]:
+ if feed_title in strings_list(PODCAST.ASR_FORCE_GEMINI_TITLES):
return "gemini"
- if urlparse(feed_url.strip()).netloc in [x.strip() for x in PODCAST.ASR_FORCE_GEMINI_DOMAINS.split(",") if x.strip()]:
+ if feed_title in strings_list(PODCAST.ASR_FORCE_GROQ_TITLES):
+ return "groq"
+ if feed_title in strings_list(PODCAST.ASR_FORCE_CLOUDFLARE_TITLES):
+ return "cloudflare"
+ if feed_title in strings_list(PODCAST.ASR_FORCE_WHISPER_TITLES):
+ return "whisper"
+ if feed_title in strings_list(PODCAST.ASR_FORCE_UNCENSORED_TITLES):
+ return "uncensored"
+
+ if urlparse(feed_url.strip()).netloc in strings_list(PODCAST.ASR_FORCE_GEMINI_DOMAINS):
return "gemini"
+ if urlparse(feed_url.strip()).netloc in strings_list(PODCAST.ASR_FORCE_GROQ_DOMAINS):
+ return "groq"
+ if urlparse(feed_url.strip()).netloc in strings_list(PODCAST.ASR_FORCE_CLOUDFLARE_DOMAINS):
+ return "cloudflare"
+ if urlparse(feed_url.strip()).netloc in strings_list(PODCAST.ASR_FORCE_WHISPER_DOMAINS):
+ return "whisper"
+ if urlparse(feed_url.strip()).netloc in strings_list(PODCAST.ASR_FORCE_UNCENSORED_DOMAINS):
+ return "uncensored"
return PODCAST.ASR_ENGINE
src/preview/ytdlp.py
@@ -21,6 +21,7 @@ from yt_dlp.utils import DownloadError, ExtractorError, YoutubeDLError
from asr.voice_recognition import asr_file
from config import (
+ ASR,
CAPTION_LENGTH,
COOKIE,
DB,
@@ -69,7 +70,7 @@ async def preview_ytdlp(
youtube_comments_provider: str = PROVIDER.YOUTUBE_COMMENTS,
proxy: str | None = None,
append_transcription: bool = True,
- ytdlp_transcription_engine: str = "",
+ ytdlp_asr_engine: str = "",
transcription_only: bool = False,
transcription_force_file: bool = False,
to_telegraph: bool = True,
@@ -89,7 +90,7 @@ async def preview_ytdlp(
youtube_comments_provider (str, optional): The youtube comments extractor: "free" or "false".
proxy (str, optional): Proxy to use. Defaults to None.
append_transcription (bool, optional): Also append transcription.
- ytdlp_transcription_engine (str, optional): Method to get transcription.
+ ytdlp_asr_engine (str, optional): Method to get transcription.
transcription_only (str, optional): If True, skip send video and audio file.
transcription_force_file (str, optional): If True, force to send transcription as file.
to_telegraph (bool, optional): Whether to publish the subtitle or transcription to telegraph.
@@ -268,8 +269,10 @@ async def preview_ytdlp(
res = await fetch_subtitle(url=url, provider="free")
subtitles = glom(res, Coalesce("full", "subtitles"), default="")
if not subtitles:
- ytdlp_transcription_engine = "gemini" if "youtube" in info["extractor"] else ytdlp_transcription_engine # use gemini to bypass censorship
- res = await asr_file(audio_path, ytdlp_transcription_engine, duration, client=client, message=message, silent=True)
+ if not ytdlp_asr_engine:
+ # bypass censorship
+ ytdlp_asr_engine = kwargs.get("asr_engine", "uncensored") if "youtube" in info["extractor"] else ASR.DEFAULT_ENGINE
+ res = await asr_file(audio_path, ytdlp_asr_engine, duration, client=client, message=message, silent=True)
subtitles = res.get("texts", "")
if count_subtitles(subtitles) < 20:
subtitles = "" # ignore too short transcription
src/subtitles/subtitle.py
@@ -11,7 +11,7 @@ from pyrogram.types import Message
from pyrogram.types.messages_and_media.message import Str
from asr.voice_recognition import asr_file
-from config import PREFIX, PROVIDER, READING_SPEED, TEXT_LENGTH, cache
+from config import ASR, PREFIX, PROVIDER, READING_SPEED, TEXT_LENGTH, cache
from llm.gpt import gpt_response
from messages.parser import parse_msg
from messages.progress import modify_progress
@@ -83,13 +83,15 @@ async def get_subtitle(
subtitles = ""
# API failed
if error := res.get("error", ""):
+ asr_engine = ASR.DEFAULT_ENGINE
+ if platform == "youtube": # bypass censorship
+ asr_engine = kwargs.get("asr_engine", "uncensored")
if this_info["mtype"] in ["audio", "video"] or reply_info.get("mtype", "") in ["audio", "video"]:
await modify_progress(text=error + "\n正在通过ASR识别字幕", force_update=True, **kwargs)
msg = message if this_info["mtype"] in ["audio", "video"] else message.reply_to_message
fpath: str = await client.download_media(msg) # type: ignore
- engine = "gemini" if platform == "youtube" else "" # use gemini to bypass censorship
prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
- res = await asr_file(fpath, engine=engine, prompt=prompt, client=client, message=message, silent=True, **kwargs)
+ res = await asr_file(fpath, engine=asr_engine, prompt=prompt, client=client, message=message, silent=True, **kwargs)
if res.get("error"):
await modify_progress(text=res["error"], force_update=True, **kwargs)
return
@@ -106,6 +108,7 @@ async def get_subtitle(
"bilibili_comments": False,
"proxy": None,
"use_db": False,
+ "ytdlp_asr_engine": asr_engine,
}
# Download and send subtitle file via ytdlp
subtitle_msg = (await preview_ytdlp(client=client, message=message, **kwargs))[0]
src/config.py
@@ -249,12 +249,13 @@ class HISTORY:
class ASR:
# use different engines based on duration
- # support ali, tencent, gemini engines
- SHORT_ENGINE = os.getenv("ASR_SHORT_ENGINE", "tencent")
+ # support ali, tencent, gemini, deepgram, cloudflare, groq
+ DEFAULT_ENGINE = os.getenv("ASR_DEFAULT_ENGINE", "auto")
+ SHORT_ENGINE = os.getenv("ASR_SHORT_ENGINE", "tencent") # comma separated engine names
SHORT_DURATION = int(os.getenv("ASR_SHORT_DURATION", "60"))
- MIDDLE_ENGINE = os.getenv("ASR_MIDDLE_ENGINE", "tencent,ali")
+ MIDDLE_ENGINE = os.getenv("ASR_MIDDLE_ENGINE", "tencent,ali") # comma separated engine names
MIDDLE_DURATION = int(os.getenv("ASR_MIDDLE_DURATION", "600"))
- LONG_ENGINE = os.getenv("ASR_LONG_ENGINE", "gemini")
+ LONG_ENGINE = os.getenv("ASR_LONG_ENGINE", "gemini") # comma separated engine names
TENCENT_APPID = os.getenv("ASR_TENCENT_APPID", "")
TENCENT_PROXY = os.getenv("ASR_TENCENT_PROXY", None) # Banned oversea IP, need a back to China proxy
@@ -269,7 +270,12 @@ class ASR:
ALI_FS_ENGINE = os.getenv("ASR_ALI_FS_ENGINE", "local") # local, uguu or alist.
DEEPGRAM_API = os.getenv("ASR_DEEPGRAM_API", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
CLOUDFLARE_MODEL = os.getenv("ASR_CLOUDFLARE_MODEL", "@cf/openai/whisper-large-v3-turbo")
+ CLOUDFLARE_MAX_BYTES = int(os.getenv("ASR_CLOUDFLARE_MAX_BYTES", "26214400")) # 25MB (max file bytes for single file)
+ CLOUDFLARE_CHUNK_SECONDS = float(os.getenv("ASR_CLOUDFLARE_CHUNK_SECONDS", "1800")) # split long audio file into chunks
+ CLOUDFLARE_OVERLAP_SECONDS = float(os.getenv("ASR_CLOUDFLARE_OVERLAP_SECONDS", "5")) # overlap seconds between chunks
CLOUDFLARE_KEYS = os.getenv("ASR_CLOUDFLARE_KEYS", "") # comma separated keys for load balance. e.g. "AccountID:API_TOKEN, AccountID:API_TOKEN, ..."
+ CLOUDFLARE_PROXY = os.getenv("ASR_CLOUDFLARE_PROXY", None)
+
GEMINI_CHUNK_SECONDS = float(os.getenv("ASR_GEMINI_CHUNK_SECONDS", "900")) # split long audio file into chunks
GEMINI_OVERLAP_SECONDS = float(os.getenv("ASR_GEMINI_OVERLAP_SECONDS", "5")) # overlap seconds between chunks
GROQ_PROXY = os.getenv("ASR_GROQ_PROXY", None) # Ban CN & HK IP
@@ -286,11 +292,20 @@ class PODCAST:
OPML_URLS = os.getenv("PODCAST_OPML_URLS", "") # comma separated opml urls
TID = int(os.getenv("PODCAST_TID", "0")) # send to this chat id
FS_ENGINE = os.getenv("PODCAST_FS_ENGINE", "CF-R2") # file storage engine for hosting podcast feeds
- ASR_ENGINE = os.getenv("PODCAST_ASR_ENGINE", "gemini") # default ASR engine
- ASR_FORCE_GEMINI_TITLES = os.getenv("PODCAST_ASR_FORCE_GEMINI_TITLES", "") # comma separated titles for force Gemini ASR. (Bypass censorship)
- ASR_FORCE_GEMINI_DOMAINS = os.getenv("PODCAST_ASR_FORCE_GEMINI_DOMAINS", "anchor.fm,feeds.acast.com") # comma separated domains for force Gemini ASR. (Bypass censorship)
+ ASR_ENGINE = os.getenv("PODCAST_ASR_ENGINE", "auto") # default ASR engine
IGNORE_OLD_THAN_SECONDS = int(os.getenv("PODCAST_IGNORE_OLD_THAN_SECONDS", "14400")) # in seconds
KEEP_LATEST_ENTRIES = int(os.getenv("PODCAST_KEEP_LATEST_ENTRIES", "99999999")) # keep latest entries
+    # To bypass censorship, set ASR engines here (comma-separated titles or domains)
+ ASR_FORCE_GEMINI_TITLES = os.getenv("PODCAST_ASR_FORCE_GEMINI_TITLES", "")
+ ASR_FORCE_GEMINI_DOMAINS = os.getenv("PODCAST_ASR_FORCE_GEMINI_DOMAINS", "")
+ ASR_FORCE_GROQ_TITLES = os.getenv("PODCAST_ASR_FORCE_GROQ_TITLES", "")
+ ASR_FORCE_GROQ_DOMAINS = os.getenv("PODCAST_ASR_FORCE_GROQ_DOMAINS", "")
+ ASR_FORCE_CLOUDFLARE_TITLES = os.getenv("PODCAST_ASR_FORCE_CLOUDFLARE_TITLES", "")
+ ASR_FORCE_CLOUDFLARE_DOMAINS = os.getenv("PODCAST_ASR_FORCE_CLOUDFLARE_DOMAINS", "")
+ ASR_FORCE_WHISPER_TITLES = os.getenv("PODCAST_ASR_FORCE_WHISPER_TITLES", "")
+ ASR_FORCE_WHISPER_DOMAINS = os.getenv("PODCAST_ASR_FORCE_WHISPER_DOMAINS", "")
+ ASR_FORCE_UNCENSORED_TITLES = os.getenv("PODCAST_ASR_FORCE_UNCENSORED_TITLES", "")
+ ASR_FORCE_UNCENSORED_DOMAINS = os.getenv("PODCAST_ASR_FORCE_UNCENSORED_DOMAINS", "anchor.fm,feeds.acast.com")
class FAVORITE: