Commit fc3479d

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-05-23 08:30:54
feat(asr): add `delete_gemini_file` and `delete_local_file` options
1 parent afdb353
src/asr/ali_asr.py
@@ -115,12 +115,12 @@ def ali_realtime_asr(model: str, path: str | Path, api_key: str) -> dict:
     # convert audio file
     sample_rate = 8000 if "8k" in model else 16000
     ext = "opus"
-    path = downsampe_audio(path, ext=ext, sample_rate=sample_rate, ac=1)
+    audio_path = downsampe_audio(path, ext=ext, sample_rate=sample_rate, ac=1)
     recognition = Recognition(model=model, format=ext, sample_rate=sample_rate, callback=RecognitionCallback(), api_key=api_key)
-    result = recognition.call(Path(path).as_posix())
+    result = recognition.call(Path(audio_path).as_posix())
     if result.status_code != 200:
         return {"error": f"❌语音识别失败: {result.message}"}
-    Path(path).unlink(missing_ok=True)
+    Path(audio_path).unlink(missing_ok=True)
     data = result.get_sentence()
     if not data:
         return {"error": "⚠️该音频未识别到文字"}
src/asr/gemini_asr.py
@@ -30,6 +30,7 @@ async def gemini_stream_asr(
     prompt: str = "请转录这段音频",
     *,
     silent: bool = False,
+    delete_gemini_file: bool = True,
     **kwargs,
 ) -> dict:
     """Gemini stream ASR.
@@ -70,6 +71,7 @@ Notes:
 - Focus on accuracy in capturing both the timing and the spoken content.
 - Maintain consistent formatting to ensure clarity and readability."""
     path = Path(path)
+    res = {}
     sent_messages = []
     status = None if silent else kwargs.get("progress")
     api_keys = shuffle_keys(GEMINI.API_KEY)
@@ -117,7 +119,10 @@ Notes:
         finally:
             with contextlib.suppress(Exception):
                 if "uploaded_audio" in locals() and uploaded_audio.name:
-                    await app.aio.files.delete(name=uploaded_audio.name)
+                    if delete_gemini_file:
+                        await app.aio.files.delete(name=uploaded_audio.name)
+                    else:
+                        res["gemini_file"] = uploaded_audio
     res["sent_messages"] = [status, *sent_messages]
     return res
 
src/asr/voice_recognition.py
@@ -173,6 +173,9 @@ async def asr_file(
     engine: str = "",
     duration: int = 0,
     language: str = "16k_zh-PY",
+    *,
+    delete_local_file: bool = True,
+    delete_gemini_file: bool = True,
     **kwargs,
 ) -> dict:
     """Get ASR results of an audio file."""
@@ -217,7 +220,7 @@ async def asr_file(
         elif asr_method == "deepgram":
             res = await deepgram_asr(path)
         elif asr_method == "gemini":
-            res = await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
+            res = await gemini_stream_asr(path=path, voice_format=voice_format, delete_gemini_file=delete_gemini_file, **kwargs)
         else:
             return {"error": "ASR method not supported"}
         if res.get("texts"):
@@ -227,7 +230,10 @@ async def asr_file(
         logger.error(error)
         res["error"] = res.get("error", error)
     finally:
-        path.unlink(missing_ok=True)
+        if delete_local_file:
+            path.unlink(missing_ok=True)
+        elif path.is_file():
+            res["audio_file"] = path
     return res