Commit a46a7bc
Changed files (4)
src
src/asr/gemini_asr.py
@@ -24,15 +24,12 @@ async def gemini_asr(path: str | Path, voice_format: str) -> str:
https://ai.google.dev/gemini-api/docs/audio
"""
path = Path(path)
- if voice_format == "ogg-opus":
- voice_format = "ogg"
-
api_keys = [x.strip() for x in ASR.GEMINI_API_KEY.split(",") if x.strip()]
random.shuffle(api_keys)
res = ""
for key in api_keys:
try:
- logger.debug(f"ASR {path.as_posix()} via {ASR.GEMINI_BASR_URL}, proxy={ASR.GEMINI_PROXY}")
+ logger.debug(f"ASR via {ASR.GEMINI_MODEL}: {path.as_posix()} , proxy={ASR.GEMINI_PROXY}")
client = genai.Client(api_key=key, http_options=HttpOptions(base_url=ASR.GEMINI_BASR_URL, async_client_args={"proxy": ASR.GEMINI_PROXY}))
uploaded_audio = await client.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
logger.debug(uploaded_audio)
src/asr/utils.py
@@ -5,7 +5,7 @@
from config import ASR, FILE_SERVER
-def get_asr_method(duration: float, file_size: int) -> tuple[str, list[str]]:
+def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> tuple[str, list[str]]:
"""Get ASR method and supported file types."""
if duration < 60:
asr_engine = ASR.SHORT_ENGINE
@@ -14,9 +14,9 @@ def get_asr_method(duration: float, file_size: int) -> tuple[str, list[str]]:
else:
asr_engine = ASR.LONG_ENGINE
- if asr_engine.lower() == "tencent":
+ if asr_engine == "tencent" or force_engine == "tencent":
return get_tencent_asr_method(duration, file_size)
- if asr_engine.lower() == "gemini":
+ if asr_engine.lower() == "gemini" or force_engine == "gemini":
return get_gemini_asr_method(duration)
return f"ASR Engine: {asr_engine} is not support for duration: {duration}, filesize: {file_size}", []
src/asr/voice_recognition.py
@@ -9,17 +9,18 @@ from pathlib import Path
from glom import glom
from loguru import logger
from pyrogram.client import Client
+from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
from pyrogram.types import Message
from asr.gemini_asr import gemini_asr
from asr.tecent_asr import create_async_asr, flash_asr, query_async_asr, single_sentence_asr
from asr.utils import get_asr_method
-from config import CAPTION_LENGTH, FILE_SERVER, PREFIX
+from config import CAPTION_LENGTH, FILE_SERVER, PREFIX, TEXT_LENGTH
from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg
from messages.utils import count_without_entities, equal_prefix, get_reply_to, startswith_prefix
-from multimedia import convert_to_audio
+from multimedia import convert_to_audio, parse_media_info
from utils import rand_string, to_int
# ruff: noqa: RUF001
@@ -104,25 +105,23 @@ async def voice_to_text(
this_info = parse_msg(message, silent=True)
trigger_info = parse_msg(trigger_message, silent=True)
- asr_engine = "16k_zh-PY" # default: 中英粤
+ asr_language = "16k_zh-PY" # default: 中英粤
+ force_engine = "" # gemini or tencent
if matched := re.match(r"/asr\s+([^.。,,/\s]+)", this_info["text"]): # /asr yue
- asr_engine = f"16k_{matched.group(1)}"
- asr_engine = asr_engine.replace("16k_fy", "16k_zh_dialect") # fix dialect engine code
+ custom_code = matched.group(1)
+ if custom_code == "fy": # re-map dialect
+ custom_code = "zh_dialect"
+ custom_code = custom_code.replace("fy", "zh_dialect")
+ if f"16k_{custom_code}" in ENGINE_MAP:
+ asr_language = f"16k_{custom_code}"
+ elif custom_code in ["gemini", "tencent"]:
+ force_engine = custom_code
- duration = trigger_info["duration"]
- asr_method, supported_ext = get_asr_method(duration, trigger_info["file_size"])
- if asr_method not in ["single_sentence_asr", "flash_asr", "async_asr", "gemini"]:
- await modify_progress(text=asr_method, force_update=True, **kwargs)
- return
-
- msg = f"Recieved {trigger_info['mtype']} message, start recognizing by {ENGINE_MAP.get(asr_engine, 'Unknown')}..."
+ msg = f"[ASR] 收到消息: {trigger_info['mtype']}, 开始识别..."
logger.info(msg)
if kwargs.get("show_progress"):
res = await send2tg(client, message, texts=msg, **kwargs)
kwargs["progress"] = res[0]
- if asr_method != "gemini" and asr_engine not in ENGINE_MAP:
- await modify_progress(text=f"Unsupported ASR engine: {asr_engine}", force_update=True, **kwargs)
- return
path: str | Path = await trigger_message.download() # type: ignore
path = Path(path).expanduser().resolve()
@@ -131,23 +130,13 @@ async def voice_to_text(
logger.error(msg)
await modify_progress(text=msg, force_update=True, **kwargs)
return
- voice_format = path.suffix.lstrip(".")
- if voice_format not in supported_ext:
- path = convert_to_audio(path, ext="aac", codec="aac")
- voice_format = "aac"
- asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size) # match again based on converted file
- path = path.rename(path.with_stem(rand_string())) # sanitize filename. (for Tencent Signature v3)
-
- if voice_format in ["oga", "ogg", "opus"]: # rename format
- voice_format = "ogg-opus"
- logger.debug(f"Recognizing {voice_format} audio [{duration}s] by {asr_engine}: {path.as_posix()}")
- res = await asr_file(path, asr_method, voice_format, asr_engine)
+ res = await asr_file(path, engine=force_engine, duration=trigger_info["duration"], language=asr_language)
if error := res.get("error"):
await modify_progress(text=error, force_update=True, **kwargs)
return
if texts := res.get("texts"):
- final = f"{BEGINNING}\n{texts}"
+ final = f"{BEGINNING}\n{BLOCKQUOTE_EXPANDABLE_DELIM}{texts}{BLOCKQUOTE_EXPANDABLE_END_DELIM}"
logger.success(f"{final!r}")
# send results
target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else this_info["cid"]
@@ -155,9 +144,10 @@ async def voice_to_text(
length = await count_without_entities(final)
if length < CAPTION_LENGTH: # short
await client.copy_message(chat_id=to_int(target_chat), from_chat_id=trigger_info["cid"], message_id=trigger_info["mid"], caption=final, reply_parameters=reply_parameters)
+ elif length < TEXT_LENGTH: # middle
+ await client.send_message(to_int(target_chat), final, reply_parameters=reply_parameters)
else: # long
- final = final.removeprefix(f"{BEGINNING}\n")
- with io.BytesIO(final.encode("utf-8")) as f:
+ with io.BytesIO(texts.encode("utf-8")) as f:
await client.send_document(to_int(target_chat), f, file_name="语音识别结果.txt", reply_parameters=reply_parameters)
await modify_progress(del_status=True, **kwargs)
@@ -166,21 +156,51 @@ async def voice_to_text(
await message.delete()
-async def asr_file(path: str | Path, method: str, voice_format: str, asr_engine: str = "16k_zh-PY") -> dict:
- """Get ASR texts of an audio file."""
+async def asr_file(
+ path: str | Path,
+ engine: str = "",
+ duration: int = 0,
+ language: str = "16k_zh-PY",
+) -> dict:
+ """Get ASR results of an audio file."""
res = {}
path = Path(path).expanduser().resolve()
if not path.is_file():
return {"error": f"{path} is not exist"}
+ info = parse_media_info(path)
+ if duration == 0:
+ duration = info["duration"]
+ asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size, force_engine=engine)
+ if asr_method not in ["single_sentence_asr", "flash_asr", "async_asr", "gemini"]:
+ return {"error": asr_method}
+
+ voice_format = path.suffix.lstrip(".")
+ if voice_format not in supported_ext:
+ if info["audio_codec"].split("/")[-1] in supported_ext and not info["video_codec"]:
+ voice_format = info["audio_codec"].split("/")[-1]
+ else:
+ path = convert_to_audio(path, ext="aac", codec="aac")
+ voice_format = "aac"
+ # match again based on converted file
+ asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size, force_engine=engine)
+
+ ogg_names = ["oga", "ogg-opus", "ogg", "opus"] # unify format name
+ if asr_method in ["single_sentence_asr", "flash_asr", "async_asr"] and voice_format in ogg_names:
+ voice_format = "ogg-opus"
+ path = path.rename(path.with_stem(rand_string())) # sanitize filename. (for Tencent Signature v3)
+ if asr_method == "gemini" and voice_format in ogg_names:
+ voice_format = "ogg"
+
+ logger.debug(f"Recognizing {voice_format} audio [{duration}s] by {language}: {path.as_posix()}")
try:
- if method == "single_sentence_asr":
- resp = await single_sentence_asr(path, asr_engine, voice_format)
+ if asr_method == "single_sentence_asr":
+ resp = await single_sentence_asr(path, language, voice_format)
texts = glom(resp, "Response.Result").replace("。", "。\n")
- elif method == "flash_asr":
- resp = await flash_asr(path, asr_engine, voice_format)
+ elif asr_method == "flash_asr":
+ resp = await flash_asr(path, language, voice_format)
texts = glom(resp, "flash_result.0.text").replace("。", "。\n")
- elif method == "async_asr":
- resp = await create_async_asr(f"{FILE_SERVER}/{path.name}", asr_engine)
+ elif asr_method == "async_asr":
+ resp = await create_async_asr(f"{FILE_SERVER}/{path.name}", language)
task_id = resp["Response"]["Data"]["TaskId"]
logger.success(f"ASR任务提交成功, TaskID: {task_id}")
result = await query_async_asr(task_id)
@@ -196,7 +216,7 @@ async def asr_file(path: str | Path, method: str, voice_format: str, asr_engine:
else:
texts = glom(result, "Response.Data.ErrorMsg")
res["error"] = texts
- elif method == "gemini":
+ elif asr_method == "gemini":
texts = await gemini_asr(path, voice_format)
res["texts"] = texts
logger.success(f"{texts!r}")
src/preview/ytdlp.py
@@ -19,7 +19,8 @@ from pyrogram.types import Message
from yt_dlp import YoutubeDL
from yt_dlp.utils import DownloadError, ExtractorError, YoutubeDLError
-from config import API, CAPTION_LENGTH, DB, DOWNLOAD_DIR, MAX_FILE_BYTES, PROVIDER, PROXY, TID, TOKEN, YTDLP_DOWNLOAD_MAX_FILE_BYTES, YTDLP_RE_ENCODING_MAX_FILE_BYTES, cache
+from asr.voice_recognition import asr_file
+from config import API, CAPTION_LENGTH, DB, DOWNLOAD_DIR, MAX_FILE_BYTES, PROVIDER, PROXY, READING_SPEED, TID, TOKEN, YTDLP_DOWNLOAD_MAX_FILE_BYTES, YTDLP_RE_ENCODING_MAX_FILE_BYTES, cache
from cookies import cookie_cloud_bilibili
from database import get_db
from messages.database import copy_messages_from_db, save_messages
@@ -51,6 +52,8 @@ async def preview_ytdlp(
youtube_comments_provider: str = PROVIDER.YOUTUBE_COMMENTS,
proxy: str | None = None,
append_youtube_subtitle: bool = True,
+ append_transcription: bool = True,
+ ytdlp_transcription_engine: str = "gemini",
**kwargs,
):
"""Preview ytdlp link in the message.
@@ -66,6 +69,8 @@ async def preview_ytdlp(
youtube_comments_provider (str, optional): The youtube comments extractor: "free" or "false".
proxy (str, optional): Proxy to use. Defaults to None.
append_youtube_subtitle (bool, optional): Also send youtube subtitle.
+ append_transcription (bool, optional): Also append transcription.
+ ytdlp_transcription_engine (str, optional): ASR engine used to get the transcription.
"""
logger.trace(f"{url=} {kwargs=}")
if kwargs.get("show_progress") and "progress" not in kwargs:
@@ -232,6 +237,15 @@ async def preview_ytdlp(
caption = f"{emoji}[{info['author']}]({info['author_url']})\n🕒{create_time}\n📝[{info['title']}]({url})\n字符数: {res['num_chars']}\n阅读时长: {res['reading_minutes']:.1f}分钟"
with io.BytesIO(subtitles.encode("utf-8")) as f:
await client.send_document(to_int(target_chat), f, file_name="字幕文件.txt", caption=caption)
+ append_transcription = False # disable asr transcription
+
+ if any(x in info["extractor"] for x in ["youtube", "bilibili"]) and append_transcription and audio_path.is_file():
+ asr_res = await asr_file(audio_path, ytdlp_transcription_engine, duration)
+ if texts := asr_res.get("texts"):
+ caption = f"{emoji}[{info['author']}]({info['author_url']})\n🕒{create_time}\n📝[{info['title']}]({url})\n字符数: {len(texts)}\n阅读时长: {len(texts) / READING_SPEED:.1f}分钟"
+ with io.BytesIO(texts.encode("utf-8")) as f:
+ await client.send_document(to_int(target_chat), f, file_name="字幕文件.txt", caption=caption)
+
Path(json_file).unlink(missing_ok=True)
cleanup_ytdlp(info["id"])