Commit d76d8a6

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-01-22 14:19:13
style: improve the prefix check
1 parent 557e163
src/asr/voice_recognition.py
@@ -9,7 +9,7 @@ from pyrogram.types import Message
 
 from asr.tecent_asr import Credential, FlashRecognitionRequest, FlashRecognizer
 from config import ASR_MAX_DURATION, ENABLE, PREFIX, TOKEN, cache
-from message_utils import modify_progress, send2tg
+from message_utils import equal_prefix, modify_progress, send2tg, startswith_prefix
 from multimedia import convert_to_audio, parse_media_info
 
 # ruff: noqa: RUF001
@@ -85,7 +85,7 @@ async def voice_to_text(
         asr_skip_video (bool, optional): If True, skip video message.
     """
     # send docs if message == "/asr", without reply
-    if str(message.text).lower().strip() == PREFIX.ASR and not message.reply_to_message:
+    if equal_prefix(message.text, prefix=[PREFIX.ASR]) and not message.reply_to_message:
         await send2tg(client, message, texts=HELP, **kwargs)
         return
 
@@ -197,11 +197,11 @@ def get_trigger_message(
         asr_skip_video = asr_skip_video or False
 
     # only trigger if msg has "/asr" prefix
-    if asr_need_prefix and not str(message.text).lower().strip().startswith(PREFIX.ASR):
+    if asr_need_prefix and not startswith_prefix(message.text, prefix=[PREFIX.ASR]):
         return None
 
     # treat the reply_to_message as the real message need to be recognized
-    trigger_msg = message.reply_to_message if asr_need_prefix or str(message.text).lower().strip().startswith(PREFIX.ASR) else message
+    trigger_msg = message.reply_to_message if asr_need_prefix or startswith_prefix(message.text, prefix=[PREFIX.ASR]) else message
 
     if not trigger_msg:
         return None
src/bridge/ocr.py
@@ -8,7 +8,7 @@ from pyrogram.client import Client
 from pyrogram.types import Message, ReplyParameters
 
 from config import ENABLE, PREFIX, cache
-from message_utils import send2tg
+from message_utils import equal_prefix, send2tg, startswith_prefix
 from utils import i_am_bot
 
 OCR_BOT = "GLBetabot"
@@ -23,11 +23,11 @@ async def send_to_ocr_bridge(client: Client, message: Message, **kwargs):
     if not ENABLE.WGET:
         return
     # send docs if message == "/ocr", without reply
-    if str(message.text).lower().strip() == PREFIX.OCR and not message.reply_to_message:
+    if equal_prefix(message.text, prefix=[PREFIX.OCR]) and not message.reply_to_message:
         await send2tg(client, message, texts=f"**图片转文字**: 以`{PREFIX.OCR}`回复图片消息即可提取文字", **kwargs)
         return
     msg = message.text or message.caption or ""  # /ocr args
-    if not msg.startswith(PREFIX.OCR):
+    if not startswith_prefix(message.text or message.caption, prefix=[PREFIX.OCR]):
         return
     if await i_am_bot(client):  # bot can't send message to other bots
         return
src/others/download_external.py
@@ -11,18 +11,34 @@ from pyrogram.client import Client
 from pyrogram.types import Message, ReplyParameters
 
 from config import ENABLE, MAX_FILE_BYTES, PREFIX
-from message_utils import modify_progress, send2tg
+from message_utils import equal_prefix, modify_progress, send2tg, startswith_prefix
 from multimedia import is_valid_video, validate_img
 from networking import download_file
 from utils import https_url, readable_size
 
+HELP = f"""
+⏬**下载文件**
+使用说明:
+以 `{PREFIX.WGET}` 回复视频消息提取出音频
+2. 发送视频时, 添加`{PREFIX.AUDIO}`文本描述会同时提取视频音频
+"""
+
 
 async def download_url_in_message(client: Client, message: Message, **kwargs):
     """Download the url from the message."""
     if not ENABLE.WGET:
         return
-    if not str(message.text).strip().lower().startswith(PREFIX.WGET):
+    if not startswith_prefix(message.text or message.caption, prefix=[PREFIX.WGET]):
+        return
+    # send docs if message == "/wget", without reply
+    if equal_prefix(message.text, prefix=[PREFIX.WGET]) and not message.reply_to_message:
+        await send2tg(client, message, texts=HELP, **kwargs)
         return
+
+    # reply a message with /wget
+    if message.reply_to_message:
+        message = message.reply_to_message
+
     target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else message.chat.id
     reply_msg_id = kwargs.get("reply_msg_id", 0)
     if reply_msg_id == 0:
src/others/extract_audio.py
@@ -7,7 +7,7 @@ from pyrogram.client import Client
 from pyrogram.types import Message, ReplyParameters
 
 from config import PREFIX, cache
-from message_utils import modify_progress, parse_msg, send2tg
+from message_utils import equal_prefix, modify_progress, parse_msg, send2tg, startswith_prefix
 from multimedia import convert_to_audio, parse_media_info
 
 # ruff: noqa: RUF001
@@ -24,17 +24,14 @@ HELP = f"""
 async def extract_audio_file(client: Client, message: Message, **kwargs) -> None:
     """Extract audio from video message."""
     # send docs if message == "/audio", without reply
-    if str(message.text).lower().strip() == PREFIX.AUDIO and not message.reply_to_message:
+    if equal_prefix(message.text, prefix=[PREFIX.AUDIO]) and not message.reply_to_message:
         await send2tg(client, message, texts=HELP, **kwargs)
         return
 
-    msg_text = message.text or message.caption or ""
-    msg_text = msg_text.lower().strip()
-
-    if not msg_text.startswith(PREFIX.AUDIO):
+    if not startswith_prefix(message.text or message.caption, prefix=[PREFIX.AUDIO]):
         return
 
-    # 以/audio命令回复一条消息
+    # reply a message with /audio
     if message.reply_to_message:
         message = message.reply_to_message
 
src/others/gpt.py
@@ -14,7 +14,7 @@ from pyrogram.client import Client
 from pyrogram.types import Message
 
 from config import DOWNLOAD_DIR, ENABLE, GPT, PREFIX, PROXY, cache
-from message_utils import modify_progress, send2tg
+from message_utils import equal_prefix, modify_progress, send2tg, startswith_prefix
 from multimedia import convert_to_audio
 from networking import hx_req
 
@@ -49,8 +49,7 @@ async def gpt_response(client: Client, message: Message, **kwargs):
         return
 
     # send docs if message == "/ai", without reply
-    texts = message.text or message.caption or ""
-    if str(texts).lower().strip() == PREFIX.GPT and not message.reply_to_message:
+    if equal_prefix(message.text or message.caption, prefix=[PREFIX.GPT]) and not message.reply_to_message:
         await send2tg(client, message, texts=HELP, **kwargs)
         return
 
@@ -287,10 +286,7 @@ def fix_doubao(contexts: list[dict]) -> list[dict]:
 
 
 def is_valid_conversation(message: Message) -> bool:
-    # match commands: /ai
-    if str(message.text).strip()[:3].lower() == PREFIX.GPT:
-        return True
-    if str(message.caption).strip()[:3].lower() == PREFIX.GPT:
+    if startswith_prefix(message.text or message.caption, prefix=[PREFIX.GPT]):
         return True
     # is replying to gpt-bot response message?
     if not message.reply_to_message:
src/others/subtitle.py
@@ -14,7 +14,7 @@ from youtube_transcript_api import YouTubeTranscriptApi
 
 from config import API, ENABLE, PREFIX, PROXY, TOKEN
 from database import cache
-from message_utils import modify_progress, send2tg
+from message_utils import equal_prefix, modify_progress, send2tg, startswith_prefix
 from networking import hx_req, match_social_media_link
 
 HELP = f"""📃**提取字幕**
@@ -32,9 +32,8 @@ async def get_subtitle(client: Client, message: Message, **kwargs):
     if not ENABLE.SUBTITLE:
         return
     target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else message.chat.id
-    # send docs if message == "/ai", without reply
-    texts = message.text or message.caption or ""
-    if str(texts).lower().strip() == PREFIX.SUBTITLE and not message.reply_to_message:
+    # send docs if message == "/subtitle", without reply
+    if equal_prefix(message.text, prefix=[PREFIX.SUBTITLE]) and not message.reply_to_message:
         await send2tg(client, message, texts=HELP, **kwargs)
         return
 
@@ -67,8 +66,7 @@ async def get_subtitle(client: Client, message: Message, **kwargs):
 
 
 async def find_yt_vid(client: Client, message: Message) -> str:
-    msg_text = message.text or message.caption or ""
-    if not msg_text.strip().lower().startswith(PREFIX.SUBTITLE):
+    if not startswith_prefix(message.text or message.caption, prefix=[PREFIX.SUBTITLE]):
         return ""
     url = find_url_in_message(message)
     # /subtitle "link"
src/message_utils.py
@@ -107,7 +107,7 @@ def parse_msg(message: Message, *, verbose: bool = False) -> dict:
 
 
 @cache.memoize(ttl=60)
-def startswith_prefix(text: str, prefix: list[str] | None = None, ignore_prefix: list[str] | None = None) -> bool:
+def startswith_prefix(text: str | None = None, prefix: list[str] | None = None, ignore_prefix: list[str] | None = None) -> bool:
     """Check if the message text starts with the given command prefixes.
 
     Args:
@@ -115,13 +115,15 @@ def startswith_prefix(text: str, prefix: list[str] | None = None, ignore_prefix:
         prefix (list[str], optional): Command prefixes that are effective.
         ignore_prefix (list[str], optional): Ignore these command prefixes.
     """
+    if not text:
+        return False
     if ignore_prefix and any(text.strip().lower().startswith(prefix) for prefix in ignore_prefix):
         return False
     return bool(prefix and any(text.strip().lower().startswith(prefix) for prefix in prefix))
 
 
 @cache.memoize(ttl=60)
-def equal_prefix(text: str, prefix: list[str] | None = None, ignore_prefix: list[str] | None = None) -> bool:
+def equal_prefix(text: str | None = None, prefix: list[str] | None = None, ignore_prefix: list[str] | None = None) -> bool:
     """Check if the message text equal with the given command prefixes.
 
     Args:
@@ -129,6 +131,8 @@ def equal_prefix(text: str, prefix: list[str] | None = None, ignore_prefix: list
         prefix (list[str], optional): Extra command prefixes that are effective.
         ignore_prefix (list[str], optional): Ignore these command prefixes.
     """
+    if not text:
+        return False
     if ignore_prefix and text.strip().lower() in ignore_prefix:
         return False
     return bool(prefix and text.strip().lower() in prefix)