Commit 3f5da47

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-04-17 14:51:01
feat: add `google` search and `youtube` search
1 parent ea2111d
src/llm/tools.py
@@ -12,7 +12,7 @@ from llm.prompts import add_search_results_to_prompts, modify_prompts
 from llm.response import send_to_gpt
 from llm.tool_scheme import ONLINE_SEARCH
 from messages.progress import modify_progress
-from networking import hx_req
+from others.search_google import query_google
 from utils import nowdt
 
 
@@ -30,22 +30,12 @@ async def get_online_search_result(query: str) -> list[dict]:
 
 
 async def google_search(query: str) -> list[dict]:
-    if not (TOKEN.GOOGLE_SEARCH_API_KEY and TOKEN.GOOGLE_SEARCH_CX):
+    res = await query_google(query)
+    if not res:
         return []
-    try:
-        url = f"https://www.googleapis.com/customsearch/v1?key={TOKEN.GOOGLE_SEARCH_API_KEY}&cx={TOKEN.GOOGLE_SEARCH_CX}&q={query}"
-        response = await hx_req(url, proxy=PROXY.GOOGLE_SEARCH, check_keys=["items"])
-        results = glom(response, "items", default=[]) or []
-        for item in results:
-            keys = copy.copy(item).keys()
-            for key in keys:
-                if key not in ["title", "link", "snippet", "mime"]:
-                    item.pop(key, None)
-        if results:
-            return results[: int(GPT.SEARCH_NUM_RESULTS)]
-    except Exception as e:
-        logger.error(e)
-    return []
+    keep_keys = ["title", "link", "snippet", "mime"]
+    results = [{k: v for k, v in x.items() if k in keep_keys} for x in res]
+    return results[: int(GPT.SEARCH_NUM_RESULTS)]
 
 
 async def glm_search(query: str) -> list[dict]:
src/others/search_google.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from glom import glom
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import GOOGLE_SEARCH_GL, NUM_GOOGLE_SEARCH_RESULTS, PREFIX, PROXY, TOKEN
+from messages.parser import parse_msg
+from messages.sender import send2tg
+from messages.utils import equal_prefix, startswith_prefix
+from networking import hx_req
+from utils import number_to_emoji
+
+HELP = f"""🔍**搜索Google**
+`{PREFIX.SEARCH_GOOGLE}` + 关键词
+"""
+
+
+async def search_google(client: Client, message: Message, **kwargs):
+    """Search Google."""
+    # send docs if message == "/ytb", without reply
+    if equal_prefix(message.text, prefix=[PREFIX.SEARCH_GOOGLE]) and not message.reply_to_message:
+        await send2tg(client, message, texts=HELP, **kwargs)
+        return
+
+    info = parse_msg(message, silent=True)
+    if not startswith_prefix(info["text"], prefix=[PREFIX.SEARCH_GOOGLE]):
+        return
+    query = info["text"].removeprefix(PREFIX.SEARCH_GOOGLE).strip()
+    if not query:
+        return
+
+    res = await query_google(query)
+    if not res:
+        await send2tg(client, message, texts="❌查询Google失败", **kwargs)
+        return
+    msg = ""
+    for idx, item in enumerate(res):
+        msg += f"{number_to_emoji(idx + 1)}[{item['title']}]({item['link']})\n"
+        msg += f"{item['snippet']}\n"
+    await send2tg(client, message, texts=msg, **kwargs)
+
+
+async def query_google(query: str) -> list[dict]:
+    if not (TOKEN.GOOGLE_SEARCH_API_KEY and TOKEN.GOOGLE_SEARCH_CX):
+        return []
+    try:
+        api = "https://www.googleapis.com/customsearch/v1"
+        params = {
+            "key": TOKEN.GOOGLE_SEARCH_API_KEY,
+            "cx": TOKEN.GOOGLE_SEARCH_CX,
+            "q": query,
+            "num": min(NUM_GOOGLE_SEARCH_RESULTS, 10),
+            "safe": "off",
+            "gl": GOOGLE_SEARCH_GL,
+        }
+        response = await hx_req(api, proxy=PROXY.YOUTUBE_SEARCH, params=params, check_keys=["items"], max_retry=0)
+        return glom(response, "items", default=[]) or []
+    except Exception as e:
+        logger.error(e)
+    return []
src/others/search_ytb.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from datetime import UTC, datetime
+from zoneinfo import ZoneInfo
+
+from glom import glom
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import NUM_YOUTUBE_SEARCH_RESULTS, PREFIX, PROXY, TOKEN, TZ
+from messages.parser import parse_msg
+from messages.sender import send2tg
+from messages.utils import equal_prefix, startswith_prefix
+from networking import hx_req
+from utils import number_to_emoji
+
+HELP = f"""🔍**搜索YouTube**
+`{PREFIX.SEARCH_YOUTUBE}` + 关键词
+"""
+
+
+async def search_youtube(client: Client, message: Message, **kwargs):
+    """Search YouTube."""
+    # send docs if message == "/ytb", without reply
+    if equal_prefix(message.text, prefix=[PREFIX.SEARCH_YOUTUBE]) and not message.reply_to_message:
+        await send2tg(client, message, texts=HELP, **kwargs)
+        return
+    info = parse_msg(message, silent=True)
+    if not startswith_prefix(info["text"], prefix=[PREFIX.SEARCH_YOUTUBE]):
+        return
+    query = info["text"].removeprefix(PREFIX.SEARCH_YOUTUBE).strip()
+    if not query:
+        return
+
+    res = await query_youtube(query)
+    if not res.get("data"):
+        await send2tg(client, message, texts="❌查询YouTube失败", **kwargs)
+        return
+    if error := res.get("error", ""):
+        await send2tg(client, message, texts=error, **kwargs)
+        return
+
+    msg = ""
+    for idx, item in enumerate(res["data"]):
+        video_url = f"https://www.youtube.com/watch?v={item['vid']}"
+        msg += f"👤[{item['author']}]({item['channel']}) 🕒{item['date']:%Y-%m-%d}\n"
+        msg += f"{number_to_emoji(idx + 1)}[{item['title']}]({video_url})\n"
+    await send2tg(client, message, texts=msg, **kwargs)
+
+
+async def query_youtube(query: str) -> dict:
+    results = []
+    try:
+        logger.info(f"Query YouTube info for {query=}, proxy={PROXY.YOUTUBE_SEARCH}")
+        api = "https://www.googleapis.com/youtube/v3/search"
+        params = {
+            "key": TOKEN.YOUTUBE_API_KEY,
+            "part": "snippet",
+            "q": query,
+            "maxResults": min(NUM_YOUTUBE_SEARCH_RESULTS, 50),
+            "safeSearch": "none",
+            "type": "video",
+        }
+        resp = await hx_req(api, proxy=PROXY.YOUTUBE_SEARCH, params=params, check_keys=["items"], max_retry=0)
+        if resp.get("hx_error"):
+            logger.warning(f"Search YouTube API failed: {resp['hx_error']}")
+            return {"error": {resp["hx_error"]}}
+
+        for x in resp["items"]:
+            if glom(x, "id.kind", default="") != "youtube#video":
+                continue
+            results.append(
+                {
+                    "vid": x["id"]["videoId"],
+                    "title": x["snippet"]["title"],
+                    "author": x["snippet"]["channelTitle"],
+                    "channel": f"https://www.youtube.com/channel/{x['snippet']['channelId']}",
+                    "date": datetime.strptime(x["snippet"]["publishedAt"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=UTC).astimezone(ZoneInfo(TZ)),
+                }
+            )
+    except Exception as e:
+        logger.error(f"Failed to get video info: {e}")
+        return {"error": str(e)}
+    return {"data": results}
src/config.py
@@ -28,6 +28,9 @@ YTDLP_RE_ENCODING_MAX_FILE_BYTES = int(os.getenv("YTDLP_RE_ENCODING_MAX_FILE_BYT
 # ytdlp max allowed file bytes. Default: 1PB (Set this if the VPS disk space is limited)
 YTDLP_DOWNLOAD_MAX_FILE_BYTES = int(os.getenv("YTDLP_DOWNLOAD_MAX_FILE_BYTES", "1125899906842624"))
 TELEGRAM_UA = os.getenv("TELEGRAM_UA", "TelegramBot (like TwitterBot)")
+NUM_YOUTUBE_SEARCH_RESULTS = int(os.getenv("NUM_YOUTUBE_SEARCH_RESULTS", "10"))  # Number of youtube search results
+NUM_GOOGLE_SEARCH_RESULTS = int(os.getenv("NUM_GOOGLE_SEARCH_RESULTS", "10"))  # Number of google search results
+GOOGLE_SEARCH_GL = os.getenv("GOOGLE_SEARCH_GL", "cn")  # "gl" parameter (Geolocation)
 
 
 class ENABLE:  # see fine-grained permission in `src/permission.py`
@@ -40,6 +43,8 @@ class ENABLE:  # see fine-grained permission in `src/permission.py`
     INSTAGRAM = os.getenv("ENABLE_INSTAGRAM", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     OCR = os.getenv("ENABLE_OCR", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     PRICE = os.getenv("ENABLE_PRICE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
+    SEARCH_YOUTUBE = os.getenv("ENABLE_SEARCH_YOUTUBE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
+    SEARCH_GOOGLE = os.getenv("ENABLE_SEARCH_GOOGLE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     SUBTITLE = os.getenv("ENABLE_SUBTITLE", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     TIKTOK = os.getenv("ENABLE_TIKTOK", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
     TWITTER = os.getenv("ENABLE_TWITTER", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
@@ -71,6 +76,8 @@ class PREFIX:
     STOCK = os.getenv("PREFIX_STOCK", "/stock").lower()  # stock only
     COMBINATION = os.getenv("PREFIX_COMBINATION", "/combine").lower()
     VOICE = os.getenv("PREFIX_VOICE", "/voice").lower()
+    SEARCH_YOUTUBE = os.getenv("PREFIX_SEARCH_YOUTUBE", "/ytb").lower()
+    SEARCH_GOOGLE = os.getenv("PREFIX_SEARCH_GOOGLE", "/google").lower()
 
 
 class API:
@@ -121,6 +128,7 @@ class PROXY:  # format: socks5://127.0.0.1:7890
     INSTAGRAM = os.getenv("INSTAGRAM_PROXY", None)
     TWITTER = os.getenv("TWITTER_PROXY", None)
     SUBTITLE = os.getenv("SUBTITLE_PROXY", None)
+    YOUTUBE_SEARCH = os.getenv("YOUTUBE_SEARCH_PROXY", None)
     CRYPTO = os.getenv("CRYPTO_PROXY", None)
     GOOGLE_SEARCH = os.getenv("GOOGLE_SEARCH_PROXY", None)
     DOWNLOAD = os.getenv("DOWNLOAD_PROXY", None)
src/handler.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-
 import re
 
 from loguru import logger
@@ -20,6 +19,8 @@ from networking import match_social_media_link
 from others.download_external import download_url_in_message
 from others.extract_audio import extract_audio_file
 from others.raw_img_file import convert_raw_img_file
+from others.search_google import search_google
+from others.search_ytb import search_youtube
 from others.subtitle import get_subtitle
 from permission import check_service
 from preview.bilibili import preview_bilibili
@@ -42,6 +43,8 @@ async def handle_utilities(
     ai: bool = True,
     asr: bool = True,
     audio: bool = True,
+    ytb: bool = True,
+    google: bool = True,
     subtitle: bool = True,
     wget: bool = True,
     ocr: bool = True,
@@ -64,6 +67,8 @@ async def handle_utilities(
         ai (bool, optional): Enable GPT. Defaults to True.
         asr (bool, optional): Enable ASR. Defaults to True.
         audio (bool, optional): Enable Video -> Audio. Defaults to True.
+        google (bool, optional): Enable Google Search. Defaults to True.
+        ytb (bool, optional): Enable YouTube Search. Defaults to True.
         subtitle (bool, optional): Enable YouTube subtitle. Defaults to True.
         wget (bool, optional): Enable WGET. Defaults to True.
         ocr (bool, optional): Enable OCR. Defaults to True.
@@ -84,6 +89,10 @@ async def handle_utilities(
         await get_subtitle(client, message, **kwargs)  # /subtitle
     if wget:
         await download_url_in_message(client, message, **kwargs)  # /wget
+    if google:
+        await search_google(client, message, **kwargs)  # /google
+    if ytb:
+        await search_youtube(client, message, **kwargs)  # /ytb
     if ocr:
         await send_to_ocr_bridge(client, message, **kwargs)  # /ocr
     if price:
@@ -285,6 +294,10 @@ def get_social_media_help(chat_id: int | str, ctype: str, prefixes: list[str] |
         msg += f"\n🤖**总结历史**: `{PREFIX.AI_SUMMARY}` AI总结历史聊天记录"
     if permission["wget"]:
         msg += f"\n⏬**下载文件**: `{PREFIX.WGET}` + URL"
+    if permission["ytb"]:
+        msg += f"\n🔍**搜索YouTube**: `{PREFIX.SEARCH_YOUTUBE}` + 关键词"
+    if permission["google"]:
+        msg += f"\n🔍**搜索Google**: `{PREFIX.SEARCH_GOOGLE}` + 关键词"
 
     msg += "\n\n单独发送每个命令前缀本身可查看该命令详细使用说明"
     return msg
src/permission.py
@@ -103,6 +103,8 @@ def check_service(cid: int | str, ctype: str) -> dict:
         "price": True,
         "raw_img": True,
         "summary": True,
+        "ytb": True,
+        "google": True,
         "show_progress": True,
         "detail_progress": True,
         "douyin": True,
@@ -141,6 +143,10 @@ def check_service(cid: int | str, ctype: str) -> dict:
         permission["audio"] = False
     if not ENABLE.SUBTITLE:
         permission["subtitle"] = False
+    if not ENABLE.SEARCH_YOUTUBE:
+        permission["ytb"] = False
+    if not ENABLE.SEARCH_GOOGLE:
+        permission["google"] = False
     if not ENABLE.WGET:
         permission["wget"] = False
     if not ENABLE.OCR:
src/utils.py
@@ -136,10 +136,10 @@ def soup_to_text(soup: PageElement) -> str:
     return text
 
 
-def number_to_emoji(num: int | str) -> str:
+def number_to_emoji(num: int | str, default: str = "") -> str:
     """Convert a number to an emoji."""
     num = str(num)
-    return {"1": "1️⃣", "2": "2️⃣", "3": "3️⃣", "4": "4️⃣", "5": "5️⃣", "6": "6️⃣", "7": "7️⃣", "8": "8️⃣", "9": "9️⃣", "10": "🔟"}.get(num, "🔢")
+    return {"0": "0️⃣", "1": "1️⃣", "2": "2️⃣", "3": "3️⃣", "4": "4️⃣", "5": "5️⃣", "6": "6️⃣", "7": "7️⃣", "8": "8️⃣", "9": "9️⃣", "10": "🔟"}.get(num, default)
 
 
 def stringfy(d: dict) -> dict: