Commit e902978

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-07 08:06:30
fix(gpt): simplify text contents format
1 parent 6f20cad
Changed files (3)
src/llm/models.py
@@ -5,7 +5,7 @@ from loguru import logger
 from pyrogram.types import Message
 
 from config import GPT
-from llm.utils import BOT_TIPS, fix_doubao
+from llm.utils import BOT_TIPS, simplify_text_contents
 from messages.parser import parse_msg
 
 
@@ -13,6 +13,7 @@ def get_model_type(conversations: list[Message]) -> str:
     """Get model type based on conversation messages."""
     has_image = False
     has_video = False
+    model_type = "text"
     for message in conversations:
         info = parse_msg(message, silent=True)
         if info["mtype"] == "photo":
@@ -21,9 +22,7 @@ def get_model_type(conversations: list[Message]) -> str:
         if info["mtype"] == "video":
             model_type = "video"
             has_video = True
-    if not has_image and not has_video:
-        model_type = "text"
-    elif has_image and has_video:
+    if has_image and has_video:
         model_type = "ERROR: this conversation have both image and video."
     return model_type
 
@@ -58,7 +57,6 @@ def get_model_with_contexts(model_type: str, contexts: list[dict]) -> tuple[dict
         "temperature": float(GPT.TEMPERATURE),
         "bot_msg_prefix": f"🤖**{model_names[model_type]}**: ({BOT_TIPS})",
     }
-    if model.startswith("豆包"):
-        contexts = fix_doubao(contexts)
+    contexts = simplify_text_contents(contexts)
     logger.trace(config)
     return config, contexts
src/llm/summary.py
@@ -9,7 +9,7 @@ from pyrogram.types import Message
 
 from config import ENABLE, GPT, MAX_MESSAGE_SUMMARY, PREFIX, cache
 from llm.response import get_gpt_response
-from llm.utils import fix_doubao
+from llm.utils import simplify_text_contents
 from messages.chat_history import get_parsed_chat_history
 from messages.parser import parse_msg
 from messages.progress import modify_progress
@@ -85,7 +85,7 @@ async def ai_summary(client: Client, message: Message, **kwargs):
     model_conf = get_summay_model(history)
     contexts = await get_contexts(client, history)
     if model_conf["friendly_name"].startswith("豆包"):
-        contexts = fix_doubao(contexts)
+        contexts = simplify_text_contents(contexts)
     msg = f"🤖{model_conf['friendly_name']}: 总结中..."
     if kwargs.get("show_progress"):
         res = await send2tg(client, message, texts=msg, **kwargs)
src/llm/utils.py
@@ -10,12 +10,13 @@ from config import DOWNLOAD_DIR, GPT
 BOT_TIPS = "回复此消息以继续对话"
 
 
-def fix_doubao(contexts: list[dict]) -> list[dict]:
-    """Fix doubao context format.
+def simplify_text_contents(contexts: list[dict]) -> list[dict]:
+    """Simplify the plain text content format.
 
-    Doubao do not support this content for:
+    Some models do not support this format:
         [{'content': [{'text': 'hi', 'type': 'text'}], 'role': 'user'}]
-    It support:
+
+    They only support:
         [{'content': 'hi', 'role': 'user'}]
 
     Args:
@@ -24,16 +25,28 @@ def fix_doubao(contexts: list[dict]) -> list[dict]:
                 "role": "user or assistant",
                 "content": [
                     {'type': 'text', 'text': 'caption this img'},
-                    {'type': 'image_url', 'image_url': {'url': 'data:image/jpeg;base64,base64_image'}},
-                    {'type': 'image_url', 'image_url': {'url': 'https://server.com/dir/image.jpg'}},
                 ]
             }
         ]
+
+    Returns:
+        list[dict]: [
+            {
+                "role": "user or assistant",
+                "content": "caption this img"
+            }
+        ]
     """
     fixed_contexts = []
     for msg in contexts:
-        if (lst := msg.get("content", [])) and all(x.get("type") == "text" for x in lst):
-            msg["content"] = "\n".join([x.get("text") for x in lst])
+        if not msg.get("content") or not isinstance(msg.get("content"), list):
+            fixed_contexts.append(msg)
+            continue
+        contents = msg.get("content", [])
+        if all(x.get("type") == "text" for x in contents):
+            msg["content"] = "\n".join([x.get("text") for x in contents])
+            fixed_contexts.append(msg)
+        else:
             fixed_contexts.append(msg)
     return fixed_contexts