Commit 1493f25

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-06-06 03:37:03
fix(gemini): convert audio files if not supported by Gemini
1 parent c576d27
Changed files (2)
src/asr/utils.py
@@ -7,6 +7,12 @@ from pathlib import Path
 from config import ASR, GEMINI
 from multimedia import convert_to_audio
 
+ALI_AUDIO_EXT = [".aac", ".amr", ".avi", ".flac", ".flv", ".m4a", ".mkv", ".mov", ".mp3", ".mp4", ".mpeg", ".ogg-opus", ".ogg", ".opus", ".wav", ".webm", ".wma", ".wmv"]
+GEMINI_AUDIO_EXT = [".aac", ".aiff", ".flac", ".mp3", ".oga", ".ogg", ".opus", ".wav"]
+DEEPGRAM_AUDIO_EXT = [".mp3", ".aac", ".flac", ".m4a", ".mp2", ".mp4", ".ogg", ".opus", ".ogg-opus", ".pcm", ".wav", ".webm"]
+TENCENT_AUDIO_EXT = [".aac", ".amr", ".m4a", ".mp3", ".oga", ".ogg-opus", ".ogg", ".opus", ".pcm", ".silk", ".speex", ".wav"]
+TENCENT_ASYNC_AUDIO_EXT = [".3gp", ".aac", ".amr", ".flac", ".flv", ".m4a", ".mp3", ".mp4", ".oga", ".ogg-opus", ".ogg", ".opus", ".wav", ".wma"]
+
 
 def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> tuple[str, list[str]]:
     """Get ASR method and supported file types."""
@@ -21,7 +27,7 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
     if force_engine == "ali":
         return get_ali_asr_method()
     if force_engine == "deepgram":
-        return "deepgram", ["mp3", "aac", "flac", "m4a", "mp2", "mp4", "ogg", "opus", "ogg-opus", "pcm", "wav", "webm"]
+        return "deepgram", [x.lstrip(".") for x in DEEPGRAM_AUDIO_EXT]
     if force_engine == "tencent":
         return get_tencent_asr_method(duration, file_size)
     if force_engine == "gemini":
@@ -30,7 +36,7 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
     if asr_engine == "ali":
         return get_ali_asr_method()
     if asr_engine == "deepgram":
-        return "deepgram", ["mp3", "aac", "flac", "m4a", "mp2", "mp4", "ogg", "opus", "ogg-opus", "pcm", "wav", "webm"]
+        return "deepgram", [x.lstrip(".") for x in DEEPGRAM_AUDIO_EXT]
     if asr_engine == "tencent":
         return get_tencent_asr_method(duration, file_size)
     if asr_engine.lower() == "gemini":
@@ -41,7 +47,7 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
 def get_ali_asr_method() -> tuple[str, list[str]]:
     if not all([ASR.ALI_MODEL, ASR.ALI_API_KEY]):
         return "请设置阿里云ASR相关环境变量", []
-    supported_ext = ["aac", "amr", "avi", "flac", "flv", "m4a", "mkv", "mov", "mp3", "mp4", "mpeg", "ogg-opus", "ogg", "opus", "wav", "webm", "wma", "wmv"]
+    supported_ext = [x.lstrip(".") for x in ALI_AUDIO_EXT]
     return "ali", supported_ext
 
 
@@ -52,22 +58,22 @@ def get_tencent_asr_method(duration: float, file_size: int) -> tuple[str, list[s
     asr_method = ""
     if duration < 60 and file_size < 3 * 1024 * 1024:
         asr_method = "tencent_single_asr"  # 一句话识别
-        supported_ext = ["aac", "amr", "m4a", "mp3", "oga", "ogg-opus", "ogg", "opus", "pcm", "silk", "speex", "wav"]
+        supported_ext = [x.lstrip(".") for x in TENCENT_AUDIO_EXT]
     elif 60 <= duration <= 300 and file_size < 100 * 1024 * 1024:
         asr_method = "tencent_flash_asr"  # 录音文件识别极速版
-        supported_ext = ["aac", "amr", "m4a", "mp3", "oga", "ogg-opus", "ogg", "opus", "pcm", "silk", "speex", "wav"]
+        supported_ext = [x.lstrip(".") for x in TENCENT_AUDIO_EXT]
     else:
         asr_method = "tencent_async_asr"  # 录音文件识别 (异步请求)
-        supported_ext = ["3gp", "aac", "amr", "flac", "flv", "m4a", "mp3", "mp4", "oga", "ogg-opus", "ogg", "opus", "wav", "wma"]
+        supported_ext = [x.lstrip(".") for x in TENCENT_ASYNC_AUDIO_EXT]
     return asr_method, supported_ext
 
 
-def get_gemini_asr_method(duration: float) -> tuple[str, list[str]]:
-    if duration > GEMINI.ASR_MAX_DURATION:
+def get_gemini_asr_method(duration: float | None = None) -> tuple[str, list[str]]:
+    if duration is not None and duration > GEMINI.ASR_MAX_DURATION:
         return f"无法识别时长超过{GEMINI.ASR_MAX_DURATION}秒的音频, 当前音频时长: {duration}秒", []
     if not GEMINI.API_KEY:
         return "请设置`GEMINI_API_KEY`环境变量", []
-    return "gemini", ["aac", "aiff", "flac", "mp3", "oga", "ogg", "opus", "wav"]
+    return "gemini", [x.lstrip(".") for x in GEMINI_AUDIO_EXT]
 
 
 async def downsampe_audio(path: str | Path, ext: str = "opus", codec: str = "libopus", sample_rate: int = 16000, **kwargs) -> Path:
src/llm/contexts.py
@@ -13,9 +13,11 @@ from openai import AsyncOpenAI
 from pyrogram.client import Client
 from pyrogram.types import Message
 
+from asr.utils import GEMINI_AUDIO_EXT
 from config import GPT
 from llm.utils import BOT_TIPS, clean_context, convert_md
 from messages.parser import parse_msg
+from multimedia import convert_to_audio
 
 if TYPE_CHECKING:
     from io import BytesIO
@@ -143,6 +145,9 @@ async def single_gemini_context(client: Client, message: Message, app: genai.Cli
         try:
             if info["mtype"] in ["video", "photo", "audio", "voice"] or info["mime_type"] in gemini_mime_types or any(info["file_name"].endswith(ext) for ext in gemini_extensions):
                 fpath: str = await client.download_media(msg, in_memory=False)  # type: ignore  # type: ignore
+                if info["mtype"] in ["audio", "voice"] and Path(fpath).suffix not in GEMINI_AUDIO_EXT:
+                    audio_path = await convert_to_audio(fpath, ext="opus", codec="libopus")
+                    fpath = audio_path.as_posix()
                 upload = await app.aio.files.upload(file=fpath, config=UploadFileConfig(display_name=info["file_name"] or f"send from {info['full_name']}"))
                 while upload.state == FileState.PROCESSING:
                     logger.trace("Waiting for upload to complete...")