Commit bff3b60

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-06-13 14:17:04
style(gpt): show reply text as prompt if available
1 parent 1b1cf63
Changed files (3)
src
src/llm/gemini.py
@@ -80,7 +80,8 @@ async def gemini_response(
         extra_config_str = GEMINI.IMG_CONFIG if modality == "image" else GEMINI.TEXT_CONFIG
         genconfig = json.loads(extra_config_str)
     try:
-        msg = f"πŸ€–**{model_name}**: 思考中...\nπŸ‘€**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: β€œ{clean_cmd_prefix(info['text'])}”"[:TEXT_LENGTH]
+        real_prompt = clean_cmd_prefix(info["text"]) or clean_cmd_prefix(info["reply_text"])
+        msg = f"πŸ€–**{model_name}**: 思考中...\nπŸ‘€**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: β€œ{real_prompt}”"[:TEXT_LENGTH]
         status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
         kwargs["progress"] = status_msg
         genconfig |= {"response_modalities": response_modalities}
src/llm/gpt.py
@@ -42,7 +42,7 @@ HELP = f"""πŸ€–**GPT对话**
 """
 
 
-def is_gpt_conversation(minfo: dict, reply_text: str) -> bool:
+def is_gpt_conversation(minfo: dict) -> bool:
     # to avoid a potential infinite loop,
     # we do not respond to bot messages & GPT responses.
     if minfo["is_bot"]:
@@ -84,7 +84,7 @@ def is_gpt_conversation(minfo: dict, reply_text: str) -> bool:
         GPT.GEMINI_MODEL_NAME,
         GEMINI.IMG_MODEL_NAME,
     ]
-    return startswith_prefix(reply_text, prefix=[f"πŸ€–{x}".lower() for x in model_names])
+    return startswith_prefix(minfo["reply_text"], prefix=[f"πŸ€–{x}".lower() for x in model_names])
 
 
 async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = True, **kwargs) -> dict:
@@ -100,20 +100,19 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
     """
     # ruff: noqa: RET502, RET503
     info = parse_msg(message)
-    # send docs if message == "/ai", without reply
-    if info["mtype"] == "text" and equal_prefix(info["text"], prefix=PREFIX.GPT) and not message.reply_to_message:
-        await send2tg(client, message, texts=HELP, **kwargs)
-        return {}
+    # send docs if message == "/ai"
+    if info["mtype"] == "text" and equal_prefix(info["text"], prefix=PREFIX.GPT):
+        if not message.reply_to_message:  # without reply
+            await send2tg(client, message, texts=HELP, **kwargs)
+            return {}
+        # with reply: attribute the request to the author of the replied-to message
+        info["uid"] = info["reply_uid"]
+        info["full_name"] = info["reply_full_name"]
     if info["mtype"] == "text" and equal_prefix(info["text"], prefix=PREFIX.GENIMG) and not message.reply_to_message:
         await send2tg(client, message, texts=AIGC_HELP, **kwargs)
         return {}
-    reply_text = ""
-    if message.reply_to_message:
-        reply_info = parse_msg(message.reply_to_message, silent=True)
-        reply_text = reply_info["text"]
-    if not is_gpt_conversation(info, reply_text):
+    if not is_gpt_conversation(info):
         return {}
-
     # cache media_group message, only process once
     if media_group_id := message.media_group_id:
         if cache.get(f"gpt-{info['cid']}-{media_group_id}"):
@@ -122,7 +121,7 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
     kwargs["message_info"] = info  # save trigger message info
     conversations = get_conversations(message)
     context_type = get_context_type(conversations)  # {"type": "text", "error": None}  # text, image
-    model_id, resp_modality, sdk = get_model_id(info["text"], reply_text, context_type)
+    model_id, resp_modality, sdk = get_model_id(info["text"], info["reply_text"], context_type)
     if "gemini" in model_id.lower() and sdk == "gemini":
         return await gemini_response(client, message, conversations, resp_modality, **kwargs)
 
@@ -135,7 +134,8 @@ async def gpt_response(client: Client, message: Message, *, gpt_stream: bool = T
         return {}
 
     config["completions"]["messages"] = await get_conversation_contexts(client, conversations, ctx_format="openai")
-    msg = f"πŸ€–**{config['friendly_name']}**: 思考中...\nπŸ‘€**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: β€œ{clean_cmd_prefix(info['text'])}”"[:TEXT_LENGTH]
+    real_prompt = clean_cmd_prefix(info["text"]) or clean_cmd_prefix(info["reply_text"])
+    msg = f"πŸ€–**{config['friendly_name']}**: 思考中...\nπŸ‘€**[{info['full_name'] or info['ctitle']}](tg://user?id={info['uid']})**: β€œ{real_prompt}”"[:TEXT_LENGTH]
     status_msg = (await send2tg(client, message, texts=msg, **kwargs))[0]
     kwargs["progress"] = status_msg
     config, response = await merge_tools_response(config, **kwargs)
src/messages/parser.py
@@ -5,61 +5,79 @@
 from datetime import datetime
 from zoneinfo import ZoneInfo
 
+from glom import Coalesce, glom
 from loguru import logger
 from pyrogram.enums import MessageEntityType
-from pyrogram.types import Audio, Message
+from pyrogram.types import Message
 
 from config import TZ, cache
 from utils import nowdt
 
 
-def parse_msg(message: Message, *, silent: bool = False, verbose: bool = False) -> dict:
+def parse_msg(message: Message, *, silent: bool = False, verbose: bool = False, use_cache: bool = True) -> dict:
     """Parse a message object and return a dictionary of its attributes.
 
     Abbreviations: c = chat, m = message, u = user
     """
-    if cached := cache.get(f"parse_msg-{message.chat.id}-{message.id}"):
+    # ruff: noqa: B009
+    if use_cache and (cached := cache.get(f"parse_msg-{message.chat.id}-{message.id}")):
         return cached
     if not silent and verbose:
         logger.trace(f"{message!r}")
-    mtype: str = message.media.value if message.media and message.media.value else "text"  # type: ignore
-    ctype = message.chat.type.name if message.chat and message.chat.type else ""
-    ctitle = message.chat.title if message.chat and message.chat.title else ""
-    uid = message.from_user.id if message.from_user else 1
-    cid = message.chat.id if message.chat else 0
-    mid = message.id if message.id else 0
-    media_group_id = message.media_group_id if message.media_group_id else 0
-    is_bot = bool(message.from_user and message.from_user.is_bot)
-    text = message.text or message.caption or ""
-    html = text.html if hasattr(text, "html") else ""  # type: ignore
+    mtype = glom(message, "media.value", default="") or "text"
+    ctype = glom(message, "chat.type.name", default="") or ""
+    ctitle = glom(message, "chat.title", default="") or ""
+    uid = glom(message, "from_user.id", default=1) or 1  # uid must > 0
+    cid = glom(message, "chat.id", default=0) or 0
+    mid = glom(message, "id", default=0) or 0
+    media_group_id = glom(message, "media_group_id", default=0) or 0
+    is_bot = glom(message, "from_user.is_bot", default=False)
+    text = message.content
     dt = message.date.replace(tzinfo=ZoneInfo(TZ)) if isinstance(message.date, datetime) else nowdt(TZ)
     time = f"{dt:%Y-%m-%d %H:%M:%S}"
 
     # parse user attributes
-    first_name = message.from_user.first_name if message.from_user and message.from_user.first_name else ""
-    last_name = message.from_user.last_name if message.from_user and message.from_user.last_name else ""
-    handle = message.from_user.username if message.from_user and message.from_user.username else ""
-    full_name = f"{first_name} {last_name}".strip() if message.from_user else ""
+    first_name = glom(message, "from_user.first_name", default="") or ""
+    last_name = glom(message, "from_user.last_name", default="") or ""
+    handle = glom(message, "from_user.username", default="") or ""
+    full_name = f"{first_name} {last_name}".strip()
+
+    # parse reply message
+    reply_uid = glom(message, "reply_to_message.from_user.id", default=1) or 1
+    reply_mid = glom(message, "reply_to_message.id", default=0) or 0
+    reply_text = glom(message, "reply_to_message.content", default="") or ""
+    reply_first_name = glom(message, "reply_to_message.from_user.first_name", default="") or ""
+    reply_last_name = glom(message, "reply_to_message.from_user.last_name", default="") or ""
+    reply_handle = glom(message, "reply_to_message.from_user.username", default="") or ""
+    reply_full_name = f"{reply_first_name} {reply_last_name}".strip()
+
+    # parse forward message
+    forward_origin = message.forward_origin
+    fwd_cid = glom(forward_origin, "chat.id", default=0) or 0
+    fwd_ctype = glom(forward_origin, "chat.type.name", default="") or ""
+    fwd_uid = glom(forward_origin, "sender_user.id", default=1) or 1
+    fwd_handle = glom(forward_origin, Coalesce("sender_user.username", "chat.username"), default="") or ""
+    fwd_first_name = glom(forward_origin, "sender_user.first_name", default="") or ""
+    fwd_last_name = glom(forward_origin, "sender_user.last_name", default="") or ""
+    fwd_full_name = f"{fwd_first_name} {fwd_last_name}".strip() or glom(forward_origin, Coalesce("sender_user_name", "chat.title"), default="") or ""
 
     # parse media attributes. for photo, we should use `sizes[-1]`. ref: TelegramPlayground/pyrogram @1ea5e797f920776bfeecf985a51dc03ff22906af
-    media = getattr(message, mtype) if hasattr(message, mtype) else Audio(file_id="", file_unique_id="", duration=0)  # placeholder
     if mtype == "photo":
-        file_id = message.photo.sizes[-1].file_id or ""
-        file_size = message.photo.sizes[-1].file_size or 0
+        file_id = glom(message, f"{mtype}.sizes")[-1].file_id
+        file_size = glom(message, f"{mtype}.sizes")[-1].file_size
     else:
-        file_id = media.file_id if hasattr(media, "file_id") and media.file_id else ""
-        file_size = media.file_size if hasattr(media, "file_size") and media.file_size else 0
+        file_id = glom(message, f"{mtype}.file_id", default=0) or 0
+        file_size = glom(message, f"{mtype}.file_size", default=0) or 0
 
-    file_name = media.file_name if hasattr(media, "file_name") and media.file_name else ""
-    mime_type = media.mime_type if hasattr(media, "mime_type") and media.mime_type else ""
-    duration = media.duration if hasattr(media, "duration") and media.duration else 0
+    file_name = glom(message, f"{mtype}.file_name", default="") or ""
+    mime_type = glom(message, f"{mtype}.mime_type", default="") or ""
+    duration = glom(message, f"{mtype}.duration", default=0) or 0
     # Parse URL from message entities
     entity_urls = []
     if message.entities:
         entity_urls.extend(entity.url for entity in message.entities if entity.type == MessageEntityType.TEXT_LINK)
     if message.caption_entities:
         entity_urls.extend(entity.url for entity in message.caption_entities if entity.type == MessageEntityType.TEXT_LINK)
-    message_url = f"https://t.me/c/{str(cid).removeprefix('-100')}/{mid}"
 
     ctype_emoji = {
         "BOT": "πŸ€–",
@@ -101,9 +119,7 @@ def parse_msg(message: Message, *, silent: bool = False, verbose: bool = False)
         "media_group_id": int(media_group_id),
         "is_bot": bool(is_bot),
         "text": str(text),
-        "html": str(html),
-        "first_name": str(first_name),
-        "last_name": str(last_name),
+        "html": getattr(text, "html", ""),  # type: ignore
         "full_name": str(full_name),
         "handle": str(handle),
         "datetime": dt,
@@ -114,8 +130,19 @@ def parse_msg(message: Message, *, silent: bool = False, verbose: bool = False)
         "file_size": int(file_size),
         "duration": int(duration),
         "summary": str(summary),
-        "message_url": str(message_url),
+        "message_url": str(message.link),
         "entity_urls": entity_urls,
+        "reply_mid": int(reply_mid),
+        "reply_text": str(reply_text),
+        "reply_uid": int(reply_uid),
+        "reply_handle": str(reply_handle),
+        "reply_full_name": str(reply_full_name),
+        "fwd_cid": int(fwd_cid),
+        "fwd_ctype": str(fwd_ctype),
+        "fwd_uid": int(fwd_uid),
+        "fwd_handle": str(fwd_handle),
+        "fwd_full_name": str(fwd_full_name),
     }
-    cache.set(f"parse_msg-{message.chat.id}-{message.id}", info, ttl=120)  # cache the same msg for 2 minutes
+    if use_cache:
+        cache.set(f"parse_msg-{message.chat.id}-{message.id}", info, ttl=120)  # cache the same msg for 2 minutes
     return info