Commit 7ed8330
Changed files (6)
src
src/asr/gemini_asr.py
@@ -18,60 +18,6 @@ from messages.progress import modify_progress
from messages.utils import blockquote, count_without_entities, smart_split
-class Transcription(BaseModel):
- start_minute: int
- start_second: int
- sentence_with_punctuation: str
-
-
-async def gemini_asr(path: str | Path, voice_format: str) -> str:
- """Gemini ASR.
-
- https://ai.google.dev/gemini-api/docs/audio
- """
- path = Path(path)
- api_keys = [x.strip() for x in ASR.GEMINI_API_KEY.split(",") if x.strip()]
- random.shuffle(api_keys)
- res = ""
- for key in api_keys:
- try:
- logger.debug(f"ASR via {ASR.GEMINI_MODEL}: {path.as_posix()} , proxy={ASR.GEMINI_PROXY}")
- client = genai.Client(api_key=key, http_options=HttpOptions(base_url=ASR.GEMINI_BASR_URL, async_client_args={"proxy": ASR.GEMINI_PROXY}))
- uploaded_audio = await client.aio.files.upload(file=path, config=UploadFileConfig(mime_type=f"audio/{voice_format}"))
- logger.debug(uploaded_audio)
- response = await client.aio.models.generate_content(
- model=ASR.GEMINI_MODEL,
- contents=["请转录这段音频", uploaded_audio],
- config=GenerateContentConfig(
- response_mime_type="application/json",
- response_schema=list[Transcription],
- ),
- )
- if parsed := glom(response.model_dump(), "parsed"):
- return generate_transcription(parsed)
- except Exception as e:
- logger.error(e)
- res = str(e)
- return res
-
-
-def generate_transcription(items: list[dict]) -> str:
- res = ""
- show_timestamp = False
- for idx, item in enumerate(items):
- sentence: str = item["sentence_with_punctuation"]
- if not sentence:
- continue
-
- if idx == 0 or res.endswith((".", "。")):
- show_timestamp = True
- if show_timestamp:
- res += f"\n[{item['start_minute']}:{item['start_second']:02d}] {sentence}"
- else:
- res += sentence
- return res.strip()
-
-
async def gemini_stream_asr(client: Client, message: Message, path: str | Path, voice_format: str, *, slient: bool = False, **kwargs) -> dict:
"""Gemini stream ASR.
@@ -119,7 +65,7 @@ async def gemini_stream_asr(client: Client, message: Message, path: str | Path,
parts = await smart_split(runtime_texts)
await modify_progress(message=status, text=blockquote(parts[0]), force_update=True) # force send the first part
runtime_texts = parts[-1] # keep the last part
- if not status:
+ if not slient:
status = await client.send_message(message.chat.id, runtime_texts) # the new message
sent_messages.append(status)
@@ -128,3 +74,60 @@ async def gemini_stream_asr(client: Client, message: Message, path: str | Path,
except Exception as e:
logger.error(e)
return {"texts": transcriptions, "sent_messages": sent_messages}
+
+
# One transcribed sentence together with the time at which it starts.
# `gemini_nonstream_asr` passes `list[Transcription]` to Gemini as the
# `response_schema`, and `generate_transcription` renders the dumped items.
class Transcription(BaseModel):
    start_minute: int  # minute part of the start time, rendered as `[minute:second]`
    start_second: int  # second part of the start time, rendered zero-padded to two digits
    sentence_with_punctuation: str  # the transcribed text; empty entries are skipped when rendering
+
+
def generate_transcription(items: list[dict]) -> str:
    """Render transcription items as timestamped text.

    A new line prefixed with ``[minute:second]`` is started for the first
    non-empty sentence and after every sentence that ends with a full stop
    (``.`` or ``。``). Any other sentence is treated as a continuation and is
    appended to the current line without a timestamp.

    Args:
        items: Dicts with the keys ``start_minute``, ``start_second`` and
            ``sentence_with_punctuation``. Items whose sentence is empty are
            skipped.

    Returns:
        The rendered text with surrounding whitespace stripped, or ``""`` when
        there is nothing to render.
    """
    terminators = (".", "。")
    pieces: list[str] = []
    previous = ""  # last non-empty sentence emitted so far
    for item in items:
        sentence: str = item["sentence_with_punctuation"]
        if not sentence:
            continue

        # Decide per sentence instead of latching a flag: a timestamp is due at
        # the very start of the output and right after a finished sentence.
        if not previous or previous.endswith(terminators):
            pieces.append(f"\n[{item['start_minute']}:{item['start_second']:02d}] {sentence}")
        else:
            pieces.append(sentence)
        previous = sentence
    return "".join(pieces).strip()
+
+
async def gemini_nonstream_asr(path: str | Path, voice_format: str) -> str:
    """(Deprecated) Transcribe an audio file with one non-streaming Gemini request.

    Kept for backward compatibility and scheduled for removal; use
    `gemini_stream_asr` instead.

    Each key in the comma-separated `ASR.GEMINI_API_KEY` is tried in random
    order. For every key the audio is uploaded, the model is asked for a JSON
    list of `Transcription` items, and the first non-empty parsed reply is
    rendered with `generate_transcription`.

    Args:
        path: Location of the audio file to upload.
        voice_format: Audio subtype used to build the `audio/<voice_format>`
            MIME type of the upload.

    Returns:
        The rendered transcription. When no key yields a parsed reply, the
        message of the most recently caught exception, or `""` if none was
        raised.

    Reference: https://ai.google.dev/gemini-api/docs/audio
    """
    path = Path(path)
    candidate_keys = [part.strip() for part in ASR.GEMINI_API_KEY.split(",") if part.strip()]
    random.shuffle(candidate_keys)
    last_error = ""
    for api_key in candidate_keys:
        # Any failure for one key is logged and remembered; the next key is then tried.
        try:
            logger.debug(f"ASR via {ASR.GEMINI_MODEL}: {path.as_posix()} , proxy={ASR.GEMINI_PROXY}")
            http_options = HttpOptions(base_url=ASR.GEMINI_BASR_URL, async_client_args={"proxy": ASR.GEMINI_PROXY})
            client = genai.Client(api_key=api_key, http_options=http_options)
            upload_config = UploadFileConfig(mime_type=f"audio/{voice_format}")
            uploaded_audio = await client.aio.files.upload(file=path, config=upload_config)
            logger.debug(uploaded_audio)
            generation_config = GenerateContentConfig(
                response_mime_type="application/json",
                response_schema=list[Transcription],
            )
            response = await client.aio.models.generate_content(
                model=ASR.GEMINI_MODEL,
                contents=["请转录这段音频", uploaded_audio],
                config=generation_config,
            )
            parsed = glom(response.model_dump(), "parsed")
            if parsed:
                return generate_transcription(parsed)
        except Exception as e:
            logger.error(e)
            last_error = str(e)
    return last_error
src/asr/voice_recognition.py
@@ -11,7 +11,7 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
-from asr.gemini_asr import gemini_asr, gemini_stream_asr
+from asr.gemini_asr import gemini_stream_asr
from asr.tecent_asr import create_async_asr, flash_asr, query_async_asr, single_sentence_asr
from asr.utils import get_asr_method
from config import CAPTION_LENGTH, FILE_SERVER, PREFIX, TEXT_LENGTH
@@ -20,7 +20,7 @@ from messages.progress import modify_progress
from messages.sender import send2tg
from messages.utils import blockquote, count_without_entities, equal_prefix, get_reply_to, startswith_prefix
from multimedia import convert_to_audio, parse_media_info
-from utils import rand_string, to_int
+from utils import publish_telegraph, rand_string, to_int
# ruff: noqa: RUF001
@@ -80,6 +80,7 @@ async def voice_to_text(
asr_skip_voice: bool | None = None,
asr_skip_audio: bool | None = None,
asr_skip_video: bool | None = None,
+ to_telegraph: bool = True,
**kwargs,
) -> None:
"""Voice, audio, video message to text.
@@ -146,8 +147,12 @@ async def voice_to_text(
elif length < TEXT_LENGTH: # middle
await client.send_message(to_int(target_chat), final, reply_parameters=reply_parameters)
else: # long
+ if to_telegraph:
+ html = "\n".join([f"<p>{s}</p>" for s in texts.split("\n")])
+ if telegraph_url := await publish_telegraph(title=trigger_info["text"] or "语音识别结果", html=html, author=trigger_info["full_name"], url=trigger_info["message_url"]):
+ caption = f"\n**⚡️[Telegraph即时预览]({telegraph_url})**"
with io.BytesIO(texts.encode("utf-8")) as f:
- await client.send_document(to_int(target_chat), f, file_name="语音识别结果.txt", reply_parameters=reply_parameters)
+ await client.send_document(to_int(target_chat), f, file_name="语音识别结果.txt", caption=caption, reply_parameters=reply_parameters)
await modify_progress(del_status=True, **kwargs)
with contextlib.suppress(Exception):
@@ -161,8 +166,6 @@ async def asr_file(
engine: str = "",
duration: int = 0,
language: str = "16k_zh-PY",
- *,
- gemini_stream_mode: bool = True,
**kwargs,
) -> dict:
"""Get ASR results of an audio file."""
@@ -219,10 +222,8 @@ async def asr_file(
else:
texts = glom(result, "Response.Data.ErrorMsg")
res["error"] = texts
- elif asr_method == "gemini" and gemini_stream_mode:
- return await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
elif asr_method == "gemini":
- texts = await gemini_asr(path, voice_format)
+ return await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
res["texts"] = texts
logger.success(f"{texts!r}")
except Exception as e:
src/others/subtitle.py
@@ -18,7 +18,7 @@ from messages.progress import modify_progress
from messages.sender import send2tg
from messages.utils import equal_prefix, startswith_prefix
from networking import hx_req, match_social_media_link
-from utils import to_int
+from utils import publish_telegraph, to_int
HELP = f"""📃**提取字幕**
使用说明:
@@ -29,7 +29,7 @@ HELP = f"""📃**提取字幕**
"""
-async def get_subtitle(client: Client, message: Message, youtube_subtitle_provider: str = PROVIDER.YOUTUBE_SUBTITLE, **kwargs):
+async def get_subtitle(client: Client, message: Message, youtube_subtitle_provider: str = PROVIDER.YOUTUBE_SUBTITLE, *, to_telegraph: bool = True, **kwargs):
"""Get YouTube Subtitle."""
target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else message.chat.id
# send docs if message == "/subtitle", without reply
@@ -67,21 +67,31 @@ async def get_subtitle(client: Client, message: Message, youtube_subtitle_provid
if asr_res.get("error"):
await modify_progress(text=asr_res["error"], force_update=True, **kwargs)
return
- res = {"subtitle": asr_res["texts"], "num_chars": len(asr_res["texts"]), "reading_minutes": len(asr_res["texts"]) / READING_SPEED}
+ res = {"subtitles": asr_res["texts"], "num_chars": len(asr_res["texts"]), "reading_minutes": len(asr_res["texts"]) / READING_SPEED}
+ if asr_res.get("telegraph"):
+ res["telegraph"] = asr_res["telegraph"]
else:
await modify_progress(text=error, force_update=True, **kwargs)
return
- subtitles = res.get("subtitle", "")
+ subtitles = res.get("subtitles", "")
if not subtitles:
return
logger.success(subtitles)
if vinfo := await fetch_youtube_video_info(vid):
caption = f"🔴[{vinfo['author']}]({vinfo['channel']})\n🕒{vinfo['date']:%Y-%m-%d %H:%M:%S}\n"
caption += f"📝[{vinfo['title']}]({yt_url})\n字符数: {res['num_chars']}\n阅读时长: {res['reading_minutes']:.1f}分钟"
+ if to_telegraph:
+ html = "\n".join([f"<p>{s}</p>" for s in subtitles.split("\n")])
+ if telegraph_url := await publish_telegraph(title=vinfo["title"], html=html, author=vinfo["author"], url=yt_url):
+ caption += f"\n**⚡️[Telegraph即时预览]({telegraph_url})**"
with io.BytesIO(subtitles.encode("utf-8")) as f:
await client.send_document(to_int(target_chat), f, file_name=f"{vinfo['title']}.txt", caption=caption)
else:
caption = f"原视频: [{vid}]({yt_url})\n字符数: {res['num_chars']}\n阅读时长: {res['reading_minutes']:.1f}分钟"
+ if to_telegraph:
+ html = "\n".join([f"<p>{s}</p>" for s in subtitles.split("\n")])
+ if telegraph_url := await publish_telegraph(title=f"{vid}字幕", html=html, url=yt_url):
+ caption += f"\n**⚡️[Telegraph即时预览]({telegraph_url})**"
with io.BytesIO(subtitles.encode("utf-8")) as f:
await client.send_document(to_int(target_chat), f, file_name=f"{vid}字幕.txt", caption=caption)
@@ -143,10 +153,10 @@ async def fetch_subtitle(video_id: str, provider: str) -> dict:
except Exception as e:
logger.error(f"Failed to get subtitle: {e}")
return {"error": str(e)}
- return to_transcription(subtitles)
+ return await to_transcription(subtitles)
-def to_transcription(subtitles: list[dict]) -> dict:
+async def to_transcription(subtitles: list[dict]) -> dict:
"""Converts subtitles to "[minute:second] transcription" format.
sample subtitles = [
@@ -156,7 +166,7 @@ def to_transcription(subtitles: list[dict]) -> dict:
Returns:
dict: {
- "subtitle": "[minute:second] transcription",
+ "subtitles": "[minute:second] transcription",
"num_chars": 11,
"num_tokens": 2,
}
@@ -164,17 +174,16 @@ def to_transcription(subtitles: list[dict]) -> dict:
if not subtitles:
return {}
- res = []
+ sentences = []
num_chars = 0
for subtitle in subtitles:
minutes = int(float(subtitle["start"]) // 60)
seconds = int(float(subtitle["start"]) % 60)
- res.append(f"[{minutes}:{seconds:02d}] {subtitle['text']}")
+ sentences.append(f"[{minutes}:{seconds:02d}] {subtitle['text']}")
num_chars += len(subtitle["text"])
-
return {
- "subtitle": "\n".join(res),
+ "subtitles": "\n".join(sentences),
"num_chars": num_chars,
"reading_minutes": num_chars / READING_SPEED,
}
@@ -190,7 +199,7 @@ def to_webvtt(subtitles: list[dict]) -> dict:
Returns:
dict: {
- "subtitle": "strings of subtitles in WebVTT format",
+ "subtitles": "strings of subtitles in WebVTT format",
"num_chars": 11,
"num_tokens": 2,
}
@@ -220,7 +229,7 @@ def to_webvtt(subtitles: list[dict]) -> dict:
vtt_output.append("") # Add blank line between subtitles
# num_tokens = count_tokens("\n".join(vtt_output))
reading_minutes = num_chars / READING_SPEED # minutes
- return {"subtitle": "\n".join(vtt_output), "num_chars": num_chars, "reading_minutes": reading_minutes}
+ return {"subtitles": "\n".join(vtt_output), "num_chars": num_chars, "reading_minutes": reading_minutes}
except Exception as e:
logger.error(f"Failed to convert subtitles to WebVTT: {e}")
return {"error": str(e)}
src/preview/wechat.py
@@ -52,7 +52,7 @@ async def preview_wechat(client: Client, message: Message, url: str = "", db_key
texts = f"{post_info['header']}"
telegraph_url = await publish_telegraph(title=post_info["title"], html=post_info["html"], author=post_info["author"], url=url)
if telegraph_url:
- texts += f"\n[⚡️点击此处即时预览]({telegraph_url})"
+ texts += f"\n[⚡️Telegraph即时预览]({telegraph_url})"
sent_messages.extend(await send2tg(client, message, texts=texts, media=[{"document": post_info["html_path"]}], **kwargs))
elif length < CAPTION_LENGTH - 8: # 有图片短文
texts = f"{post_info['header']}\n{BLOCKQUOTE_EXPANDABLE_DELIM}{post_info['markdown']}\n{BLOCKQUOTE_EXPANDABLE_END_DELIM}"
@@ -61,7 +61,7 @@ async def preview_wechat(client: Client, message: Message, url: str = "", db_key
texts = f"{post_info['header']}"
telegraph_url = await publish_telegraph(title=post_info["title"], html=post_info["html"], author=post_info["author"], url=url)
if telegraph_url:
- texts += f"\n**⚡️[点击此处即时预览]({telegraph_url})**"
+ texts += f"\n**⚡️[Telegraph即时预览]({telegraph_url})**"
sent_messages.extend(await send2tg(client, message, texts=texts, media=[{"document": post_info["path"]}], **kwargs))
kwargs["reply_msg_id"] = -1 # do not send as reply
sent_messages.extend(await send2tg(client, message, texts=texts, media=post_info["media"], **kwargs))
src/preview/ytdlp.py
@@ -33,7 +33,7 @@ from networking import hx_req
from others.emoji import emojify
from others.subtitle import fetch_subtitle
from preview.utils import bv2av, make_bvid_clickable
-from utils import readable_size, readable_time, remove_none_values, soup_to_text, to_int, true, ts_to_dt, unicode_to_ascii
+from utils import publish_telegraph, readable_size, readable_time, remove_none_values, soup_to_text, to_int, true, ts_to_dt, unicode_to_ascii
class ProxyError(Exception):
@@ -54,6 +54,7 @@ async def preview_ytdlp(
append_youtube_subtitle: bool = True,
append_transcription: bool = True,
ytdlp_transcription_engine: str = "gemini",
+ to_telegraph: bool = True,
**kwargs,
):
"""Preview ytdlp link in the message.
@@ -71,6 +72,7 @@ async def preview_ytdlp(
append_youtube_subtitle (bool, optional): Also send youtube subtitle.
append_transcription (bool, optional): Also append transcription.
ytdlp_transcription_method (str, optional): Method to get transcription.
+ to_telegraph (bool, optional): Whether to publish the subtitle or transcription to telegraph.
"""
logger.trace(f"{url=} {kwargs=}")
if kwargs.get("show_progress") and "progress" not in kwargs:
@@ -235,6 +237,10 @@ async def preview_ytdlp(
res = await fetch_subtitle(video_id=info["id"], provider="free")
if subtitles := res.get("subtitle"):
caption = f"{emoji}[{info['author']}]({info['author_url']})\n🕒{create_time}\n📝[{info['title']}]({url})\n字符数: {res['num_chars']}\n阅读时长: {res['reading_minutes']:.1f}分钟"
+ if to_telegraph:
+ html = "\n".join([f"<p>{s}</p>" for s in subtitles.split("\n")])
+ if telegraph_url := await publish_telegraph(title=info["title"], html=html, author=info["author"], url=url):
+ caption += f"\n**⚡️[Telegraph即时预览]({telegraph_url})**"
with io.BytesIO(subtitles.encode("utf-8")) as f:
await client.send_document(to_int(target_chat), f, file_name=f"{info['title']}.txt", caption=caption)
append_transcription = False # disable asr transcription
@@ -243,6 +249,10 @@ async def preview_ytdlp(
asr_res = await asr_file(audio_path, ytdlp_transcription_engine, duration, client=client, message=message, slient=True)
if texts := asr_res.get("texts"):
caption = f"{emoji}[{info['author']}]({info['author_url']})\n🕒{create_time}\n📝[{info['title']}]({url})\n字符数: {len(texts)}\n阅读时长: {len(texts) / READING_SPEED:.1f}分钟"
+ if to_telegraph:
+ html = "\n".join([f"<p>{s}</p>" for s in texts.split("\n")])
+ if telegraph_url := await publish_telegraph(title=info["title"], html=html, author=info["author"], url=url):
+ caption += f"\n**⚡️[Telegraph即时预览]({telegraph_url})**"
with io.BytesIO(texts.encode("utf-8")) as f:
await client.send_document(to_int(target_chat), f, file_name=f"{info['title']}.txt", caption=caption)
[await modify_progress(msg, del_status=True) for msg in asr_res.get("sent_messages", [])]
src/utils.py
@@ -381,11 +381,20 @@ async def publish_telegraph(title: str, texts: str | None = None, html: str | No
account_info = {}
if not (author and url):
with contextlib.suppress(Exception):
- if account_info := await telegraph.get_account_info():
- author = glom(account_info, Coalesce("result.short_name", "result.author_name"), default=None)
- url = glom(account_info, "result.author_url", default=None)
+ account_info = await telegraph.get_account_info()
+ if not author:
+ author = glom(account_info, Coalesce("result.short_name", "result.author_name"), default=None)
+ if not url:
+ url = glom(account_info, "result.author_url", default=None)
+ # sanitize
+ title = title[:256]
+ if isinstance(author, str):
+ author = author[:128]
+ if isinstance(url, str):
+ url = url[:512]
try:
- page = await telegraph.create_page(title=title, author_name=author, author_url=url, html_content=html)
+ page = await telegraph.create_page(title=title[:256], author_name=author, author_url=url, html_content=html)
+ logger.info(f"⚡️Telegraph即时预览: {page['url']}")
return page["url"]
except Exception as e:
logger.error(f"Telegraph publish error: {e}")