Commit d9b0f48
Changed files (4)
src
src/asr/gemini_asr.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-import io
import random
from pathlib import Path
@@ -10,7 +9,7 @@ from google.genai.types import GenerateContentConfig, HttpOptions, UploadFileCon
from loguru import logger
from pydantic import BaseModel
from pyrogram.client import Client
-from pyrogram.types import Message, ReplyParameters
+from pyrogram.types import Message
from config import ASR, TEXT_LENGTH
from llm.gemini import parse_response
@@ -120,16 +119,12 @@ async def gemini_stream_asr(client: Client, message: Message, path: str | Path,
parts = await smart_split(runtime_texts)
await modify_progress(message=status, text=blockquote(parts[0]), force_update=True) # force send the first part
runtime_texts = parts[-1] # keep the last part
- if not slient:
+ if not status:
status = await client.send_message(message.chat.id, runtime_texts) # the new message
sent_messages.append(status)
# all chunks are processed
await modify_progress(message=status, text=blockquote(beautify_llm_response(runtime_texts)), force_update=True)
- if len(sent_messages) > 1:
- with io.BytesIO(transcriptions.encode("utf-8")) as f:
- await client.send_document(message.chat.id, f, file_name="语音识别结果.txt", reply_parameters=ReplyParameters(message_id=message.id))
- [await modify_progress(msg, del_status=True) for msg in sent_messages]
except Exception as e:
logger.error(e)
- return {"texts": transcriptions} if slient else {}
+ return {"texts": transcriptions, "sent_messages": sent_messages}
src/asr/voice_recognition.py
@@ -9,7 +9,6 @@ from pathlib import Path
from glom import glom
from loguru import logger
from pyrogram.client import Client
-from pyrogram.parser.markdown import BLOCKQUOTE_EXPANDABLE_DELIM, BLOCKQUOTE_EXPANDABLE_END_DELIM
from pyrogram.types import Message
from asr.gemini_asr import gemini_asr, gemini_stream_asr
@@ -19,7 +18,7 @@ from config import CAPTION_LENGTH, FILE_SERVER, PREFIX, TEXT_LENGTH
from messages.parser import parse_msg
from messages.progress import modify_progress
from messages.sender import send2tg
-from messages.utils import count_without_entities, equal_prefix, get_reply_to, startswith_prefix
+from messages.utils import blockquote, count_without_entities, equal_prefix, get_reply_to, startswith_prefix
from multimedia import convert_to_audio, parse_media_info
from utils import rand_string, to_int
@@ -136,7 +135,7 @@ async def voice_to_text(
await modify_progress(text=error, force_update=True, **kwargs)
return
if texts := res.get("texts"):
- final = f"{BEGINNING}\n{BLOCKQUOTE_EXPANDABLE_DELIM}{texts}{BLOCKQUOTE_EXPANDABLE_END_DELIM}"
+ final = f"{BEGINNING}\n{blockquote(texts)}"
logger.success(f"{final!r}")
# send results
target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else this_info["cid"]
@@ -152,6 +151,7 @@ async def voice_to_text(
await modify_progress(del_status=True, **kwargs)
with contextlib.suppress(Exception):
+ [await modify_progress(msg, del_status=True) for msg in res.get("sent_messages", [])]
if this_info["mtype"] == "text":
await message.delete()
src/others/subtitle.py
@@ -10,6 +10,7 @@ from pyrogram.client import Client
from pyrogram.types import Message
from youtube_transcript_api import YouTubeTranscriptApi # type: ignore
+from asr.voice_recognition import asr_file
from config import API, PREFIX, PROVIDER, PROXY, READING_SPEED, TOKEN, TZ
from database import cache
from messages.parser import parse_msg
@@ -43,19 +44,36 @@ async def get_subtitle(client: Client, message: Message, youtube_subtitle_provid
if kwargs.get("show_progress"):
res = await send2tg(client, message, texts=msg, **kwargs)
kwargs["progress"] = res[0]
+
# cache media_group message
if media_group_id := message.media_group_id:
if cache.get(f"subtitle-{message.chat.id}-{media_group_id}"):
return
cache.set(f"subtitle-{message.chat.id}-{media_group_id}", "1", ttl=120)
+ this_info = parse_msg(message, silent=True)
+ reply_info = parse_msg(message.reply_to_message, silent=True) if message.reply_to_message else {}
+
res = await fetch_subtitle(vid, youtube_subtitle_provider)
if error := res.get("error", ""):
- await modify_progress(text=error, force_update=True, **kwargs)
- return
- if not res.get("subtitle", ""):
- return
+ if "Subtitles are disabled for this video" in error:
+ error = "❌该视频没有提供字幕选项"
+ if this_info["mtype"] in ["audio", "video"] or reply_info.get("mtype", "") in ["audio", "video"]:
+ error += "\n🔄尝试使用语音转文字获取字幕"
+ await modify_progress(text=error, force_update=True, **kwargs)
+ msg = message if this_info["mtype"] in ["audio", "video"] else message.reply_to_message
+ fpath: str = await msg.download() # type: ignore
+ asr_res = await asr_file(fpath, engine="gemini", client=client, message=message, **kwargs)
+ if asr_res.get("error"):
+ await modify_progress(text=asr_res["error"], force_update=True, **kwargs)
+ return
+ res = {"subtitle": asr_res["texts"], "num_chars": len(asr_res["texts"]), "reading_minutes": len(asr_res["texts"]) / READING_SPEED}
+ else:
+ await modify_progress(text=error, force_update=True, **kwargs)
+ return
subtitles = res.get("subtitle", "")
+ if not subtitles:
+ return
logger.success(subtitles)
if vinfo := await fetch_youtube_video_info(vid):
caption = f"🔴[{vinfo['author']}]({vinfo['channel']})\n🕒{vinfo['date']:%Y-%m-%d %H:%M:%S}\n"
@@ -67,6 +85,7 @@ async def get_subtitle(client: Client, message: Message, youtube_subtitle_provid
with io.BytesIO(subtitles.encode("utf-8")) as f:
await client.send_document(to_int(target_chat), f, file_name=f"{vid}字幕.txt", caption=caption)
+ [await modify_progress(msg, del_status=True) for msg in res.get("sent_messages", [])]
await modify_progress(del_status=True, **kwargs)
src/preview/ytdlp.py
@@ -236,7 +236,7 @@ async def preview_ytdlp(
if subtitles := res.get("subtitle"):
caption = f"{emoji}[{info['author']}]({info['author_url']})\n🕒{create_time}\n📝[{info['title']}]({url})\n字符数: {res['num_chars']}\n阅读时长: {res['reading_minutes']:.1f}分钟"
with io.BytesIO(subtitles.encode("utf-8")) as f:
- await client.send_document(to_int(target_chat), f, file_name="字幕文件.txt", caption=caption)
+ await client.send_document(to_int(target_chat), f, file_name=f"{info['title']}.txt", caption=caption)
append_transcription = False # disable asr transcription
if any(x in info["extractor"] for x in ["youtube", "bilibili"]) and append_transcription and audio_path.is_file():
@@ -244,7 +244,8 @@ async def preview_ytdlp(
if texts := asr_res.get("texts"):
caption = f"{emoji}[{info['author']}]({info['author_url']})\n🕒{create_time}\n📝[{info['title']}]({url})\n字符数: {len(texts)}\n阅读时长: {len(texts) / READING_SPEED:.1f}分钟"
with io.BytesIO(texts.encode("utf-8")) as f:
- await client.send_document(to_int(target_chat), f, file_name="字幕文件.txt", caption=caption)
+ await client.send_document(to_int(target_chat), f, file_name=f"{info['title']}.txt", caption=caption)
+ [await modify_progress(msg, del_status=True) for msg in asr_res.get("sent_messages", [])]
Path(json_file).unlink(missing_ok=True)
cleanup_ytdlp(info["id"])