Commit e4da047
Changed files (2)
src
subtitles
src/asr/voice_recognition.py
@@ -7,7 +7,7 @@ from glom import glom
from loguru import logger
from pyrogram.client import Client
from pyrogram.enums import ParseMode
-from pyrogram.types import Message
+from pyrogram.types import Chat, Message
from asr.ali import ali_asr
from asr.cloudflare import cloudflare_asr
@@ -149,7 +149,7 @@ async def voice_to_text(
await modify_progress(text=msg, force_update=True, **kwargs)
return
- res = await asr_file(path, engine=asr_engine, tencent_language=tencent_language, client=client, message=msg_to_asr, **kwargs)
+ res = await asr_file(path, engine=asr_engine, tencent_language=tencent_language, message=msg_to_asr, **kwargs)
if error := res.get("error"):
await modify_progress(kwargs.get("progress"), text=error, force_update=True)
return
@@ -214,7 +214,7 @@ async def asr_file(
res = await deepgram_asr(path)
elif engine == "gemini":
res = await gemini_asr(
- message=kwargs["message"],
+ message=kwargs.get("message", Message(id=0, chat=Chat(id=0))),
path=path,
model_id=kwargs.get("gemini_asr_model_id", ""),
prompt=kwargs.get("gemini_asr_prompt", ""),
src/subtitles/subtitle.py
@@ -94,7 +94,7 @@ async def get_subtitle(
media_path = f"{DOWNLOAD_DIR}/{this_info['file_name'] or reply_info.get('file_name', '')}"
fpath: str = await client.download_media(msg, media_path) # type: ignore
prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
- res = await asr_file(fpath, engine=asr_engine, prompt=prompt, client=client, message=message, silent=True, **kwargs)
+ res = await asr_file(fpath, engine=asr_engine, prompt=prompt, message=message, silent=True, **kwargs)
if res.get("error"):
await modify_progress(text=res["error"], force_update=True, **kwargs)
return
@@ -106,7 +106,7 @@ async def get_subtitle(
await modify_progress(text="❌下载音频失败", force_update=True, **kwargs)
return
prompt = f"请转录{matched['platform'].title()}视频作者【{vinfo['author']}】的一期节目的音频。\n该期节目标题: {vinfo['title']}\n节目简介: {description}"
- res = await asr_file(downloaded["audio_path"], engine=asr_engine, prompt=prompt, client=client, message=message, silent=True, **kwargs)
+ res = await asr_file(downloaded["audio_path"], engine=asr_engine, prompt=prompt, message=message, silent=True, **kwargs)
if res.get("error"):
await modify_progress(text=res["error"], force_update=True, **kwargs)
return