Commit 4eadac6
Changed files (4)
src/asr/gemini_asr.py
@@ -13,12 +13,11 @@ from pydantic import BaseModel
from pyrogram.client import Client
from pyrogram.types import Message
-from config import GEMINI, TEXT_LENGTH
-from llm.gemini import parse_response
+from config import GEMINI
+from llm.gemini import gemini_stream
from llm.hooks import hook_gemini_httpoptions
-from llm.utils import beautify_llm_response
+from llm.utils import shuffle_keys
from messages.progress import modify_progress
-from messages.utils import blockquote, count_without_entities, smart_split
async def gemini_stream_asr(
@@ -29,9 +28,6 @@ async def gemini_stream_asr(
prompt: str = "请转录这段音频",
*,
slient: bool = False,
- retry: int = 0,
- max_retry: int = 2,
- last_error: str = "",
**kwargs,
) -> dict:
"""Gemini stream ASR.
@@ -71,64 +67,41 @@ Example-2:
Notes:
- Focus on accuracy in capturing both the timing and the spoken content.
- Maintain consistent formatting to ensure clarity and readability."""
- if retry > max_retry:
- logger.error(f"[GeminiASR] Failed after {retry} retries")
- return {"error": last_error}
path = Path(path)
- api_keys = [x.strip() for x in GEMINI.API_KEYS.split(",") if x.strip()]
- transcriptions = ""
- runtime_texts = ""
- sent_messages = []
- if status := kwargs.get("progress"):
- sent_messages.append(status)
- if slient:
- status = None
- try:
- logger.debug(f"ASR via {GEMINI.ASR_MODEL}: {path.as_posix()} , proxy={GEMINI.PROXY}")
- http_options = HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY})
- http_options = hook_gemini_httpoptions(http_options, message)
- app = genai.Client(api_key=random.choice(api_keys), http_options=http_options)
- uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
- logger.debug(uploaded_audio)
- genconfig = {}
- with contextlib.suppress(Exception):
- genconfig = json.loads(GEMINI.ASR_CONFIG)
- genconfig |= {"response_modalities": ["TEXT"]} # force text response
- genconfig |= {"system_instruction": system_instruction} # pin system instruction
- if GEMINI.ASR_THINKING_BUDGET is not None:
- thinking_budget = min(round(float(GEMINI.ASR_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
- genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
- contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
- params = {"model": GEMINI.ASR_MODEL, "contents": contents, "config": GenerateContentConfig(**genconfig)}
- async for chunk in await app.aio.models.generate_content_stream(**params):
- resp = parse_response(chunk.model_dump())
- sentence = resp.get("texts", "")
- transcriptions += sentence
- runtime_texts += sentence
- runtime_texts = beautify_llm_response(runtime_texts)
- if await count_without_entities(runtime_texts) <= TEXT_LENGTH:
- if len(runtime_texts) > 5: # start response if sentence is not empty
- await modify_progress(message=status, text=runtime_texts, detail_progress=True, ttl=10)
- else: # transcriptions is too long, split it into multiple messages
- parts = await smart_split(runtime_texts)
- await modify_progress(message=status, text=blockquote(parts[0]), force_update=True) # force send the first part
- runtime_texts = parts[-1] # keep the last part
- if not slient:
- status = await client.send_message(message.chat.id, runtime_texts) # the new message
- sent_messages.append(status)
-
- # all chunks are processed
- await modify_progress(message=status, text=blockquote(beautify_llm_response(runtime_texts)), force_update=True)
- if uploaded_audio.name: # delete file once finished
- await app.aio.files.delete(name=uploaded_audio.name)
- except Exception as e:
- logger.error(e)
- with contextlib.suppress(Exception):
- [await modify_progress(msg, del_status=True) for msg in sent_messages]
- if "uploaded_audio" in locals() and uploaded_audio.name:
- await app.aio.files.delete(name=uploaded_audio.name)
- return await gemini_stream_asr(client, message, path, voice_format, prompt, slient=slient, retry=retry + 1, max_retry=max_retry, last_error=str(e))
- return {"texts": transcriptions, "sent_messages": sent_messages}
+ status = None if slient else kwargs.get("progress")
+ api_keys = shuffle_keys(GEMINI.API_KEYS)
+ for api_key in api_keys.split(","):
+ try:
+ logger.debug(f"ASR via {GEMINI.ASR_MODEL}: {path.as_posix()} , proxy={GEMINI.PROXY}")
+ http_options = HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY})
+ http_options = hook_gemini_httpoptions(http_options, message)
+ app = genai.Client(api_key=api_key, http_options=http_options)
+ uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
+ genconfig = {}
+ with contextlib.suppress(Exception):
+ genconfig = json.loads(GEMINI.ASR_CONFIG)
+ genconfig |= {"response_modalities": ["TEXT"]} # force text response
+ genconfig |= {"system_instruction": system_instruction} # pin system instruction
+ if GEMINI.ASR_THINKING_BUDGET is not None:
+ thinking_budget = min(round(float(GEMINI.ASR_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
+ genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
+ contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
+ params = {"model": GEMINI.ASR_MODEL, "contents": contents, "config": GenerateContentConfig(**genconfig)}
+ res = await gemini_stream(client, message, model_name="ASR", params=params, prefix="", slient=slient, max_retry=0, gemini_api_key=api_key, **kwargs)
+ if res.get("error"):
+ continue
+ sent_messages = res.get("sent_messages", [])
+ break
+ except Exception as e:
+ logger.error(e)
+ with contextlib.suppress(Exception):
+ [await modify_progress(msg, del_status=True) for msg in res.get("sent_messages", [])]
+ finally:
+ with contextlib.suppress(Exception):
+ if "uploaded_audio" in locals() and uploaded_audio.name:
+ await app.aio.files.delete(name=uploaded_audio.name)
+ res["sent_messages"] = [status, *sent_messages]
+ return res
class Transcription(BaseModel):
src/asr/voice_recognition.py
@@ -206,23 +206,22 @@ async def asr_file(
logger.debug(f"[{asr_method}] Recognizing {voice_format} audio [{duration}s] by {language}: {path.as_posix()}")
try:
if asr_method == "tencent_single_asr":
- texts = await tencent_single_asr(path, language, voice_format)
+ res["texts"] = await tencent_single_asr(path, language, voice_format)
elif asr_method == "tencent_flash_asr":
- texts = await tencent_flash_asr(path, language, voice_format)
+ res["texts"] = await tencent_flash_asr(path, language, voice_format)
elif asr_method == "tencent_async_asr":
- texts = await tencent_async_asr(path, language)
+ res["texts"] = await tencent_async_asr(path, language)
elif asr_method == "ali":
- texts = await ali_asr(path)
+ res["texts"] = await ali_asr(path)
elif asr_method == "deepgram":
- texts = await deepgram_asr(path)
+ res["texts"] = await deepgram_asr(path)
elif asr_method == "gemini":
- return await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
- res["texts"] = texts
- logger.success(f"{texts!r}")
+ res |= await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
+ logger.success(f"{res['texts']!r}")
except Exception as e:
error = f"Failed to recognize audio: {e}"
logger.error(error)
- res["error"] = error
+ res["error"] = res.get("error", error)
finally:
path.unlink(missing_ok=True)
return res
src/llm/gemini.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-
import contextlib
import json
from io import BytesIO
@@ -17,7 +16,7 @@ from pyrogram.types import Message, ReplyParameters
from config import CAPTION_LENGTH, DOWNLOAD_DIR, GEMINI, GPT, PREFIX, TEXT_LENGTH
from llm.contexts import get_conversation_contexts
from llm.hooks import hook_gemini_httpoptions
-from llm.utils import BOT_TIPS, beautify_llm_response, clean_cmd_prefix, clean_gemini_sourcemarks, clean_source_marks
+from llm.utils import BOT_TIPS, beautify_llm_response, clean_cmd_prefix, clean_gemini_sourcemarks, clean_source_marks, shuffle_keys
from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg
@@ -85,23 +84,35 @@ async def gemini_stream(
message: Message,
model_name: str,
params: dict,
+ prefix: str | None = None,
retry: int = 0,
+ max_retry: int | None = None,
+ last_error: str = "",
+ *,
+ silent: bool = False,
**kwargs,
-):
- prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
- answers = ""
+) -> dict:
+ if prefix is None:
+ prefix = f"🤖**{model_name}**:{BOT_TIPS}\n"
+ answers = "" # all model responses
+ runtime_texts = "" # for a single telegram message
+ init_status_msg = None if silent else kwargs.get("progress")
+ status_msg = init_status_msg
+ status_mid = status_msg.id if isinstance(status_msg, Message) else message.id
+ if not kwargs.get("gemini_api_keys"):
+ kwargs["gemini_api_keys"] = shuffle_keys(GEMINI.API_KEYS)
+ api_keys = [x.strip() for x in kwargs["gemini_api_keys"].split(",") if x.strip()]
+ max_retry = len(api_keys) - 1 if max_retry is None else max_retry
try:
- status: Message = kwargs.get("progress") # type: ignore
- api_keys = [x.strip() for x in GEMINI.API_KEYS.split(",") if x.strip()]
- if kwargs.get("gemini_api_keys"):
- api_keys = [x.strip() for x in kwargs["gemini_api_keys"].split(",") if x.strip()]
- if retry > len(api_keys) - 1:
- return None
+ if retry > min(len(api_keys) - 1, max_retry):
+ logger.error(f"[Gemini] Failed after {retry} retries")
+ await modify_progress(message=init_status_msg, text=last_error, force_update=True)
+ return {"error": last_error}
api_key = kwargs.get("gemini_api_key", api_keys[retry])
http_options = HttpOptions(base_url=GEMINI.BASR_URL, async_client_args={"proxy": GEMINI.PROXY})
http_options = hook_gemini_httpoptions(http_options, message)
app = genai.Client(api_key=api_key, http_options=http_options)
- runtime_texts = ""
+ sent_messages = []
async for chunk in await app.aio.models.generate_content_stream(**params):
resp = parse_response(chunk.model_dump())
answer = resp.get("texts", "")
@@ -111,36 +122,42 @@ async def gemini_stream(
length = await count_without_entities(prefix + runtime_texts)
if length <= TEXT_LENGTH:
if len(runtime_texts.removeprefix(prefix)) > 10: # start response if answer is not empty
- await modify_progress(message=status, text=prefix + runtime_texts, detail_progress=True)
+ await modify_progress(message=status_msg, text=prefix + runtime_texts, detail_progress=True)
else: # answers is too long, split it into multiple messages
parts = await smart_split(prefix + runtime_texts)
if len(parts) == 1:
continue
- await modify_progress(message=status, text=blockquote(parts[0]), force_update=True) # force send the first part
+ await modify_progress(message=status_msg, text=blockquote(parts[0]), force_update=True) # force send the first part
runtime_texts = parts[-1] # keep the last part
- status = await client.send_message(message.chat.id, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status.id)) # the new message
+ if not silent:
+ status_msg = await client.send_message(message.chat.id, text=prefix + runtime_texts, reply_parameters=ReplyParameters(message_id=status_mid)) # the new message
+ sent_messages.append(status_msg)
+ status_mid = status_msg.id
# all chunks are processed
if not answers.strip(): # empty response
- return await gemini_stream(client, message, model_name, params, retry + 1, **kwargs) # type: ignore
+ return await gemini_stream(client, message, model_name, params, prefix=prefix, retry=retry + 1, last_error=last_error, **kwargs)
if await count_without_entities(prefix + answers) <= TEXT_LENGTH: # short answer in single msg
if length > GPT.COLLAPSE_LENGTH: # collapse the response if the answer is too long
- await modify_progress(message=status, text=f"{prefix}{blockquote(runtime_texts)}", force_update=True)
+ await modify_progress(message=status_msg, text=f"{prefix}{blockquote(runtime_texts)}", force_update=True)
else:
- await modify_progress(message=status, text=f"{prefix}{runtime_texts}", force_update=True)
+ await modify_progress(message=status_msg, text=f"{prefix}{runtime_texts}", force_update=True)
elif length > GPT.COLLAPSE_LENGTH:
- await modify_progress(message=status, text=prefix + blockquote(runtime_texts), force_update=True)
+ await modify_progress(message=status_msg, text=prefix + blockquote(runtime_texts), force_update=True)
else:
- await modify_progress(message=status, text=prefix + runtime_texts, force_update=True)
+ await modify_progress(message=status_msg, text=prefix + runtime_texts, force_update=True)
except Exception as e:
- logger.error(e)
error = str(e)
if "resp" in locals():
error += f"\n{resp}"
- await modify_progress(text=error, force_update=True, **kwargs)
- return await gemini_stream(client, message, model_name, params, retry + 1, **kwargs) # type: ignore
+ logger.error(error)
+ with contextlib.suppress(Exception):
+ await modify_progress(message=init_status_msg, text=error, force_update=True)
+ [await modify_progress(msg, del_status=True) for msg in sent_messages]
+ return await gemini_stream(client, message, model_name, params, prefix=prefix, retry=retry + 1, last_error=error, **kwargs)
+ return {"texts": answers, "prefix": prefix, "model_name": model_name, "sent_messages": sent_messages}
async def gemini_nonstream(
src/llm/utils.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import random
import re
import tempfile
from pathlib import Path
@@ -242,3 +243,15 @@ def split_reasoning(text: str) -> tuple[str, str]:
if matched := re.search(rf"{REASONING_BEGIN}(.*?){REASONING_END}", text, flags=re.DOTALL):
reasoning = REASONING_BEGIN + matched.group(1) + REASONING_END
return reasoning.strip(), content.strip()
+
+
+def shuffle_keys(keys: str | list[str]) -> str:
+    """Shuffle a comma-separated string of keys."""
+ if isinstance(keys, str):
+ keys = [x.strip() for x in keys.split(",") if x.strip()]
+ elif isinstance(keys, list):
+ keys = [x.strip() for x in keys if x.strip()]
+ else:
+ return ""
+ random.shuffle(keys)
+ return ",".join(keys)