Commit 250d1f6
Changed files (19)
src
database
history
preview
price
subtitles
src/asr/gemini.py
@@ -8,20 +8,16 @@ from pathlib import Path
import soundfile as sf
from glom import glom
from google import genai
-from google.genai.types import GenerateContentConfig, GoogleSearch, HttpOptions, ThinkingConfig, Tool, UploadFileConfig, UrlContext
+from google.genai.types import File, GenerateContentConfig, HttpOptions, ThinkingConfig, UploadFileConfig
from loguru import logger
from pydantic import BaseModel, Field
-from pyrogram.client import Client
from pyrogram.types import Message
from asr.groq import merge_transcripts
from asr.utils import GEMINI_AUDIO_EXT, audio_duration, convert_single_channel, downsampe_audio
from config import ASR, DOWNLOAD_DIR, GEMINI
-from llm.gemini import gemini_stream
from llm.hooks import hook_gemini_httpoptions
-from llm.utils import shuffle_keys
-from messages.progress import modify_progress
-from utils import count_subtitles, guess_mime, rand_string, seconds_to_time, strings_list, zhcn
+from utils import guess_mime, rand_string, seconds_to_time, strings_list, zhcn
class Transcription(BaseModel):
@@ -80,11 +76,12 @@ async def gemini_single_file(
if not model_id:
model_id = GEMINI.ASR_MODEL
for api_key in strings_list(GEMINI.API_KEY, shuffle=True):
+ logger.debug(f"ASR via {model_id}: {path.as_posix()} , proxy={GEMINI.PROXY}")
+ http_options = HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
+ http_options = hook_gemini_httpoptions(http_options, message)
+ app = genai.Client(api_key=api_key, http_options=http_options)
+ uploaded_audio = File()
try:
- logger.debug(f"ASR via {model_id}: {path.as_posix()} , proxy={GEMINI.PROXY}")
- http_options = HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
- http_options = hook_gemini_httpoptions(http_options, message)
- app = genai.Client(api_key=api_key, http_options=http_options)
uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=guess_mime(path)))
genconfig = {}
with contextlib.suppress(Exception):
@@ -122,7 +119,7 @@ async def gemini_single_file(
if delete_local_file:
path.unlink(missing_ok=True)
with contextlib.suppress(Exception):
- if "uploaded_audio" in locals() and uploaded_audio.name:
+ if uploaded_audio.name:
if delete_gemini_file:
await app.aio.files.delete(name=uploaded_audio.name)
else:
@@ -193,109 +190,3 @@ async def gemini_file_chunks(
logger.error(e)
return {"error": str(e)}
return transcription
-
-
-async def gemini_stream_asr(
- client: Client,
- message: Message,
- path: str | Path,
- voice_format: str,
- model_id: str | None = None,
- prompt: str = "请转录这段音频",
- *,
- silent: bool = False,
- delete_gemini_file: bool = True,
- **kwargs,
-) -> dict:
- """(Deprecated) Gemini stream ASR.
-
- https://ai.google.dev/gemini-api/docs/audio
-
- Args:
- silent (bool, optional): If True, do not update the status, return all results at the end.
- """
- system_instruction = """You are a transcription assistant tasked with converting audio files into text.
-
-Your output must follow these requirements:
-- Format each sentence as `[hh:mm:ss] sentence` with punctuation included, where `hh:mm:ss` is the start time of the sentence in the audio.
-- Omit the hour (`hh`) if it is zero, displaying only `mm:ss`.
-- Directly transcribe the audio content without any greetings or content unrelated to the audio itself.
-
-Steps:
-1. Listen to the audio file carefully and identify the start time of each sentence.
-2. Transcribe the audio content word-for-word, including punctuation, according to the specified format.
-3. Ensure that all time codes (hh:mm:ss or mm:ss) are precise.
-
-Output Format:
-- Each sentence should be formatted in a line as `[hh:mm:ss] sentence`.
-- Exclude any hour segment that equals zero, converting `[00:mm:ss]` to `[mm:ss]`.
-- Do not include any additional commentary or greetings.
-
-Example-1:
-- Input: Audio with content starting at 2 seconds.
-- Output: [00:02] 大家好, 我是小明, 欢迎来到我的频道。
-
-Example-2:
-- Input: Audio with content at 8 seconds and 1 hour, 12 minutes, and 32 seconds.
-- Output: [00:08] 今天要和大家聊一个一直以来都很有争议的话题。
-[01:12:32] 谢谢大家收听。
-
-
-Notes:
-- Focus on accuracy in capturing both the timing and the spoken content.
-- Maintain consistent formatting to ensure clarity and readability."""
- path = Path(path)
- res = {}
- sent_messages = []
- status = None if silent else kwargs.get("progress")
- api_keys = shuffle_keys(GEMINI.API_KEY)
- if model_id is None:
- model_id = GEMINI.ASR_MODEL
- for api_key in api_keys.split(","):
- try:
- logger.debug(f"ASR via {model_id}: {path.as_posix()} , proxy={GEMINI.PROXY}")
- http_options = HttpOptions(base_url=GEMINI.BASE_URL, async_client_args={"proxy": GEMINI.PROXY})
- http_options = hook_gemini_httpoptions(http_options, message)
- app = genai.Client(api_key=api_key, http_options=http_options)
- uploaded_audio = await app.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
- genconfig = {}
- with contextlib.suppress(Exception):
- genconfig = json.loads(GEMINI.ASR_CONFIG)
- genconfig |= {"response_modalities": ["TEXT"]} # force text response
- genconfig |= {"system_instruction": system_instruction} # pin system instruction
- if GEMINI.ASR_THINKING_BUDGET is not None:
- thinking_budget = min(round(float(GEMINI.ASR_THINKING_BUDGET)), GEMINI.MAX_THINKING_BUDGET)
- genconfig |= {"thinking_config": ThinkingConfig(include_thoughts=False, thinking_budget=thinking_budget)}
- if GEMINI.ASR_USE_GROUNDING:
- genconfig |= {"tools": [Tool(url_context=UrlContext()), Tool(google_search=GoogleSearch())]}
- contents = [prompt, uploaded_audio] if prompt else [uploaded_audio]
- params = {"model": model_id, "contents": contents, "config": GenerateContentConfig(**genconfig)}
- res = await gemini_stream(
- client,
- message,
- model_name="ASR",
- params=params,
- prefix="",
- silent=silent,
- max_retry=0,
- gemini_api_key=api_key,
- append_grounding=False,
- **kwargs,
- )
- if res.get("error") or count_subtitles(res.get("texts", "")) == 0:
- continue
- sent_messages = res.get("sent_messages", [])
- break
- except Exception as e:
- logger.error(e)
- with contextlib.suppress(Exception):
- [await modify_progress(msg, del_status=True) for msg in res.get("sent_messages", [])]
- finally:
- with contextlib.suppress(Exception):
- if "uploaded_audio" in locals() and uploaded_audio.name:
- if delete_gemini_file:
- await app.aio.files.delete(name=uploaded_audio.name)
- else:
- res["gemini_file"] = uploaded_audio
- res["sent_messages"] = [status, *sent_messages]
- return res
src/asr/voice_recognition.py
@@ -212,8 +212,8 @@ async def asr_file(
path = path.rename(path.with_stem(rand_string())) # sanitize filename. (for Tencent Signature v3)
logger.debug(f"[{asr_method}] Recognizing {voice_format} audio [{duration}s] by {language}: {path.as_posix()}")
+ res = {}
try:
- res = {}
if asr_method == "tencent_single_asr":
res = await tencent_single_asr(path, language, voice_format)
elif asr_method == "tencent_flash_asr":
src/danmu/turso.py
@@ -6,12 +6,12 @@ from decimal import Decimal
import anyio
from loguru import logger
-from config import DOWNLOAD_DIR, cache, cutter
+from config import DOWNLOAD_DIR, TZ, cache, cutter
from danmu.utils import TURSO_KWARGS
from database.turso import turso_exec, turso_parse_resp
from messages.progress import modify_progress
from others.emoji import CURRENCY
-from utils import number
+from utils import nowstr, number
async def query_turso(match_time: str, user: str, keyword: str, caption: str, super_chats: defaultdict, qtype: str, **kwargs) -> dict:
@@ -20,6 +20,8 @@ async def query_turso(match_time: str, user: str, keyword: str, caption: str, su
Returns:
{"paths": list[str], "count": int}
"""
+ begin = "1970-01-01 00:00:00"
+ end = nowstr(TZ)
if match_time:
if len(match_time) == 4: # 2025
begin = f"{match_time}-01-01 00:00:00"
@@ -30,6 +32,7 @@ async def query_turso(match_time: str, user: str, keyword: str, caption: str, su
elif len(match_time) == 10: # 2025-01-01
begin = f"{match_time} 00:00:00"
end = f"{match_time} 23:59:59"
+ texts_to_match = ""
if keyword:
segmented = " ".join(cutter.cutword(keyword))
texts_to_match = keyword if segmented == keyword else f'"{keyword}" OR "{segmented}"' # must use double quotes for inner part
src/danmu/utils.py
@@ -151,6 +151,8 @@ def merge_txt_files(paths: list[str], dates: list[str], user: str, keyword: str,
date_name = dates[0][:4]
elif all(len(x) == 4 for x in dates): # all years
date_name = dates[0] if len(set(dates)) == 1 else ""
+ else:
+ date_name = ""
keyword = f"“{keyword}”" if keyword else ""
src/database/alist.py
@@ -50,6 +50,7 @@ async def upload_alist(path: str | Path) -> str:
return ""
api = DB.ALIST_SERVER.removesuffix("/") + "/api/fs/form"
# Headers DO NOT support Unicode characters
+ new_path = Path("/non-exist")
if any(ord(c) > 127 for c in path.name): # has Non-ASCII
new_name = base64.urlsafe_b64encode(path.name.encode("utf-8")).decode("ascii").rstrip("=")
new_path = path.with_stem(new_name)
src/history/utils.py
@@ -87,6 +87,10 @@ def fine_grained_check(info: dict) -> bool:
keywords = [x.strip() for x in os.environ[f"HISTORY_{cid}_SKIP_KEYWORDS"].split(",") if x.strip()]
if any(x in info["text"] for x in keywords):
return False
+ if os.getenv("HISTORY_D1_ONLY_SYNC_CHATS") and cid not in strings_list(env_key="HISTORY_D1_ONLY_SYNC_CHATS"):
+ return False
+ if os.getenv("HISTORY_TURSO_ONLY_SYNC_CHATS") and cid not in strings_list(env_key="HISTORY_TURSO_ONLY_SYNC_CHATS"):
+ return False
return True
src/llm/gemini.py
@@ -150,6 +150,8 @@ async def gemini_stream(
kwargs["gemini_api_keys"] = shuffle_keys(GEMINI.API_KEY)
api_keys = [x.strip() for x in kwargs["gemini_api_keys"].split(",") if x.strip()]
max_retry = len(api_keys) - 1 if max_retry is None else max_retry
+ resp = {}
+ sent_messages = []
try:
if retry > min(len(api_keys) - 1, max_retry):
logger.error(f"[Gemini] Failed after {retry} retries")
@@ -169,10 +171,10 @@ async def gemini_stream(
logger.warning(f"[Gemini] Content is too long: {num_tokens} tokens, fallback to {GEMINI.TEXT_TOKENS_FALLBACK_MODEL}")
params["model"] = GEMINI.TEXT_TOKENS_FALLBACK_MODEL
params["config"].thinking_config = None
- sent_messages = []
is_reasoning = False
is_reasoning_conversation = None # to indicate whether it is a reasoning conversation
genai_params = {"model": params["model"], "contents": params["contents"], "config": params["config"]}
+ length = 0
async for chunk in await app.aio.models.generate_content_stream(**genai_params):
resp = parse_response(chunk.model_dump())
answer = resp.get("texts", "")
@@ -337,9 +339,9 @@ async def gemini_nonstream(
logger.error(e)
error = str(e)
if "res" in locals():
- error += f"\n{res}"
+ error += f"\n{res}" # type: ignore
if "response" in locals():
- error += f"\n{response}"
+ error += f"\n{response}" # type: ignore
await modify_progress(text=error, force_update=True, **kwargs)
return await gemini_nonstream(client, message, model_name, params, retry + 1, clean_marks=clean_marks, append_grounding=append_grounding, **kwargs) # type: ignore
return results
src/llm/response_stream.py
@@ -50,14 +50,15 @@ async def send_to_gpt_stream(
status_msg = None
status_cid = status_msg.chat.id if isinstance(status_msg, Message) else 0
status_mid = status_msg.id if isinstance(status_msg, Message) else 0
+ sent_messages = []
try:
pre_hooks(config["client"], config["completions"], message_info=kwargs.get("message_info"), system_prompt=system_prompt)
openai = AsyncOpenAI(**config["client"])
logger.trace(config)
- sent_messages = []
is_reasoning = False
is_reasoning_conversation = None # 用于指示是否是推理对话
gen = await openai.chat.completions.create(**config["completions"], stream=True)
+ length = 0
async for chunk in gen:
resp = chunk.model_dump()
logger.trace(resp)
@@ -162,7 +163,7 @@ async def send_to_gpt_stream(
except Exception as e:
error = f"🤖{config['friendly_name']}请求失败, 重试次数: {retry + 1}/{GPT.MAX_RETRY + 1}\n{e}"
if "resp" in locals():
- error += f"\n{resp}"
+ error += f"\n{resp}" # type: ignore
logger.error(error)
with contextlib.suppress(Exception):
await modify_progress(text=error, force_update=True, **kwargs)
src/llm/summary.py
@@ -205,6 +205,8 @@ async def parse_history_list(info_list: list[dict]) -> dict:
Currently, we only summarize text contents.
"""
+ begin_time = datetime.fromtimestamp(0, tz=ZoneInfo(TZ))
+ end_time = nowdt(tz=TZ)
messages: list[dict] = [] # hold user messages
for info in info_list:
if info["file_name"] == CONTEXT_FILENAME:
src/others/download_external.py
@@ -55,6 +55,7 @@ async def download_url_in_message(client: Client, message: Message, extra_prefix
else:
await modify_progress(text=msg, force_update=True, **kwargs)
success = False
+ path = Path("non-exist")
try:
path = await download_file(url, workers_proxy=True, **kwargs)
path = Path(path)
src/others/emoji.py
@@ -350,6 +350,8 @@ def emojify(text: str, platform: str = "all") -> str:
EMOJI_MAP = COMMON | DOUYIN
elif platform == "bilibili":
EMOJI_MAP = COMMON | BILIBILI
+ else:
+ EMOJI_MAP = COMMON
pattern = re.compile("|".join(re.escape(k) for k in EMOJI_MAP))
return pattern.sub(lambda match: EMOJI_MAP[match.group(0)], text)
src/preview/bilibili.py
@@ -48,6 +48,9 @@ async def preview_bilibili(
await modify_progress(text=f"❌从{DB.ENGINE}缓存中转发失败, 尝试重新解析...", **kwargs)
if platform == "bilibili-opus":
post_info = await parse_bilibili_opus(post_id, **kwargs)
+ else:
+ msg = f"Unsupported platform: {platform}"
+ raise RuntimeError(msg)
if error_msg := post_info.get("error_msg"):
await modify_progress(text=f"❌B站解析失败: {error_msg}", force_update=True, **kwargs)
msg = ""
src/preview/instagram.py
@@ -53,6 +53,7 @@ async def preview_instagram(
return
await modify_progress(text=f"❌从{DB.ENGINE}缓存中转发失败, 尝试重新解析...", **kwargs)
succ = False
+ resp = {}
if "tikhub" in instagram_provider: # try tikhub
api_url = API.TIKHUB_INSTAGRAM + url
logger.info(f"Preview Instagram TikHub for {api_url}")
src/price/entrypoint.py
@@ -168,7 +168,7 @@ async def match_symbol_category(symbol: str = "", *, crypto_only: bool = False,
category["tradingview"] = ", ".join(tv_symbols)
# skip some crypto ETF (e.g. Grayscale Bitcoin Mini Trust use symbol "AMEX:BTC")
if category.get("crypto") and category.get("tradingview"):
- exchange, coin = tradingview[0][0].split(":")
+ exchange, coin = tradingview[0][0].split(":") # type: ignore
if exchange == "AMEX" and category.get("crypto", "").startswith(coin): # hit crypto ETF
tradingview = []
del category["tradingview"]
src/subtitles/subtitle.py
@@ -78,7 +78,9 @@ async def get_subtitle(
# Fetch subtitle via API
res = await fetch_subtitle(url, youtube_subtitle_provider)
subtitle_file_sent = False
-
+ subtitle_msg = None
+ status_msg = None
+ subtitles = ""
# API failed
if error := res.get("error", ""):
if this_info["mtype"] in ["audio", "video"] or reply_info.get("mtype", "") in ["audio", "video"]:
src/tts/qwen.py
@@ -49,6 +49,9 @@ async def qwen_tts_real(texts: str, model: str, voice_name: str, *, convert_ogg:
Returns:
{"url": str, "duration": int, "caption": str}
"""
+ save_path = Path("/non-exist")
+ duration = 0
+ caption = ""
for api_key in strings_list(TTS.ALI_API_KEY, shuffle=True):
try:
logger.debug(f"TTS via {model}, voice: {voice_name}, texts: {texts}")
src/tts/sambert.py
@@ -53,6 +53,9 @@ async def sambert_tts_real(texts: str, model: str, voice_name: str, *, convert_o
Returns:
{"url": str, "duration": int, "caption": str}
"""
+ save_path = Path("/non-exist")
+ duration = 0
+ caption = ""
for api_key in strings_list(TTS.ALI_API_KEY, shuffle=True):
try:
logger.debug(f"TTS via {model}, voice: {voice_name}, texts: {texts}")
src/tts/tts.py
@@ -63,6 +63,9 @@ async def text_to_speech(client: Client, message: Message, **kwargs):
resp = await qwen_tts(texts, model, voice_name)
elif engine == "sambert":
resp = await sambert_tts(texts, model, voice_name)
+ else:
+ msg = f"Unknown engine: {engine}"
+ raise ValueError(msg)
path = Path(resp.get("voice", ""))
if path.is_file():
src/networking.py
@@ -115,9 +115,9 @@ async def hx_req(
except Exception as e:
error = f"{type(e).__name__}[{retry + 1}/{max_retry + 1}]: Failed to request {url}, {e}"
with contextlib.suppress(Exception):
- error += f"\n{response.json()}"
+ error += f"\n{response.json()}" # type: ignore
if "res" in locals():
- error += f"\n{res}"
+ error += f"\n{res}" # type: ignore
elif "data" in locals():
error += f"\n{data}"
logger.error(error)