Commit d3f6834

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-18 17:34:12
refactor: move permission checks to a separate file
1 parent 0932f29
Changed files (3)
src/config.py
@@ -185,26 +185,11 @@ class GPT:  # see `llm/README.md`
     DEEPSEEK_BASE_URL = os.getenv("GPT_DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
 
 
-class TID:
-    ADMIN = os.getenv("TID_ADMIN", "me")
-    ADMIN_GROUP = os.getenv("TID_ADMIN_GROUP", "me")
-    CHANNEL_YTDLP_BACKUP = os.getenv("TID_CHANNEL_YTDLP_BACKUP", "me")
+class TID:  # see more TID usecase in `src/permission.py`
+    # comma separated chat ids of 67373
     GROUP67373 = os.getenv("TID_GROUP67373", "")
-    # ONLY: Only process messages from these chats
-    ONLY_BOTS = os.getenv("TID_ONLY_BOTS", "")
-    ONLY_CHANNELS = os.getenv("TID_ONLY_CHANNELS", "")
-    ONLY_GROUPS = os.getenv("TID_ONLY_GROUPS", "")
-    ONLY_USERS = os.getenv("TID_ONLY_USERS", "")
-    # SKIP: Ignore messages from these chats
-    SKIP_BOTS = os.getenv("TID_SKIP_BOTS", "")
-    SKIP_CHANNELS = os.getenv("TID_SKIP_CHANNELS", "")
-    SKIP_GROUPS = os.getenv("TID_SKIP_GROUPS", "")
-    SKIP_USERS = os.getenv("TID_SKIP_USERS", "")
-    # MUTE: Mark all messages from these chats as read
-    MUTE_BOTS = os.getenv("TID_MUTE_BOTS", "")
-    MUTE_CHANNELS = os.getenv("TID_MUTE_CHANNELS", "")
-    MUTE_GROUPS = os.getenv("TID_MUTE_GROUPS", "")
-    MUTE_USERS = os.getenv("TID_MUTE_USERS", "")
+    # back up ytdlp audio if the user does not request it
+    CHANNEL_YTDLP_BACKUP = os.getenv("TID_CHANNEL_YTDLP_BACKUP", "me")
 
 
 class DB:
src/main.py
@@ -24,6 +24,7 @@ from bridge.social import forward_social_media_results
 from config import DAILY_MESSAGES, DEVICE_NAME, ENABLE, PROXY, TID, TOKEN, TZ, cache
 from handler import handle_social_media, handle_utilities
 from messages.parser import parse_msg
+from permission import check_permission
 from price.entrypoint import match_symbol_category
 from utils import cleanup_old_files, nowdt, to_int
 
@@ -48,13 +49,7 @@ async def main():
 
     @app.on_message(filters.group)
     async def groups(client: Client, message: Message):
-        if not ENABLE.GROUPS:
-            return
-        if TID.MUTE_GROUPS and message.chat.id in [int(x.strip()) for x in TID.MUTE_GROUPS.split(",")]:
-            await message.read()
-        if TID.SKIP_GROUPS and message.chat.id in [int(x.strip()) for x in TID.SKIP_GROUPS.split(",")]:
-            return
-        if TID.ONLY_GROUPS and message.chat.id not in [int(x.strip()) for x in TID.ONLY_GROUPS.split(",")]:
+        if not await check_permission(client, message, "GROUP"):
             return
         parse_msg(message)
         if TID.GROUP67373 and message.chat.id in [int(x.strip()) for x in TID.GROUP67373.split(",")]:
@@ -66,13 +61,7 @@ async def main():
 
     @app.on_message(filters.channel)
     async def channels(client: Client, message: Message):
-        if not ENABLE.CHANNELS:
-            return
-        if TID.MUTE_CHANNELS and message.chat.id in [int(x.strip()) for x in TID.MUTE_CHANNELS.split(",")]:
-            await message.read()
-        if TID.SKIP_CHANNELS and message.chat.id in [int(x.strip()) for x in TID.SKIP_CHANNELS.split(",")]:
-            return
-        if TID.ONLY_CHANNELS and message.chat.id not in [int(x.strip()) for x in TID.ONLY_CHANNELS.split(",")]:
+        if not await check_permission(client, message, "CHANNEL"):
             return
         parse_msg(message)
         await handle_utilities(client, message, detail_progress=True)
@@ -80,13 +69,7 @@ async def main():
 
     @app.on_message(filters.bot)
     async def bots(client: Client, message: Message):
-        if not ENABLE.BOTS:
-            return
-        if TID.MUTE_BOTS and message.chat.id in [int(x.strip()) for x in TID.MUTE_BOTS.split(",")]:
-            await message.read()
-        if TID.SKIP_BOTS and message.chat.id in [int(x.strip()) for x in TID.SKIP_BOTS.split(",")]:
-            return
-        if TID.ONLY_BOTS and message.chat.id not in [int(x.strip()) for x in TID.ONLY_BOTS.split(",")]:
+        if not await check_permission(client, message, "BOT"):
             return
         parse_msg(message, verbose=True)
         await forward_social_media_results(client, message)
@@ -99,13 +82,8 @@ async def main():
     # so the private handler should be placed after the bot handler
     @app.on_message(filters.private)
     async def private(client: Client, message: Message):
-        if not ENABLE.USERS or message.chat.type.name != "PRIVATE":
-            return
-        if TID.MUTE_USERS and message.chat.id in [int(x.strip()) for x in TID.MUTE_USERS.split(",")]:
-            await message.read()
-        if TID.SKIP_USERS and message.chat.id in [int(x.strip()) for x in TID.SKIP_USERS.split(",")]:
-            return
-        if TID.ONLY_USERS and message.chat.id not in [int(x.strip()) for x in TID.ONLY_USERS.split(",")]:
+        ctype = message.chat.type.name if message.chat and message.chat.type else ""
+        if not await check_permission(client, message, "PRIVATE") or ctype != "PRIVATE":
             return
         parse_msg(message, verbose=True)
         await handle_utilities(client, message, raw_img=True, detail_progress=True)
src/permission.py
@@ -0,0 +1,143 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import contextlib
+import os
+
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from config import ENABLE, cache
+from utils import to_int, true
+
+
+# ruff: noqa: SIM103
+async def check_permission(client: Client, message: Message, category: str) -> bool:
+    """Check if the user has permission to use the bot."""
+    if cached := cache.get(f"permission-{category}-{message.chat.id}"):
+        return cached
+    category = category.upper()
+    if category == "GROUP":
+        permission = await check_group(message)
+    elif category == "CHANNEL":
+        permission = await check_channel(message)
+    elif category == "BOT":
+        permission = await check_bot(message)
+    elif category == "PRIVATE":
+        permission = await check_user(client, message)
+    else:
+        permission = False
+    cache.set(f"permission-{category}-{message.chat.id}", permission)
+    return permission
+
+
+async def check_user(client: Client, message: Message) -> bool:
+    if not ENABLE.USERS:
+        return False
+
+    cid = message.chat.id
+
+    """
+    mark as read for these user chats
+    TID_MUTE_USERS=111111,234567
+    TID_MUTE_USER_111111=true
+    """
+    if cid in [to_int(x.strip()) for x in os.getenv("TID_MUTE_USERS", "").split(",")] or true(os.getenv(f"TID_MUTE_USER_{cid}")):
+        await message.read()
+
+    """
+    do not process these chats
+    TID_SKIP_USERS=111111,234567
+    TID_SKIP_USER_111111=true
+    """
+    if cid in [to_int(x.strip()) for x in os.getenv("TID_SKIP_USERS", "").split(",")] or true(os.getenv(f"TID_SKIP_USER_{cid}")):
+        return False
+
+    """
+    whitelist mode, only allow these users
+    TID_USERS_WHITELIST_MODE=true  # enable whitelist mode for users
+    TID_ALLOW_USERS=111111,234567  # these are allowed users
+    TID_ALLOW_USER_111111=true  # also allow this user
+    TID_ALLOW_USER_IN_CHATS=111111,234567,-100234567  # also allow users in these chats
+    """
+    if true(os.getenv("TID_USERS_WHITELIST_MODE")):
+        if cid in [to_int(x.strip()) for x in os.getenv("TID_ALLOW_USERS", "").split(",")]:
+            return True
+        if true(os.getenv(f"TID_ALLOW_USER_{cid}")):
+            return True
+        # check if user is a member of these chats
+        with contextlib.suppress(Exception):
+            for chat_id in [int(x.strip()) for x in os.getenv("TID_ALLOW_USER_IN_CHATS", "").split(",") if x.strip()]:
+                if not str(chat_id).startswith("-100"):
+                    chat_id = to_int(f"-100{chat_id}")  # noqa: PLW2901
+                if await client.get_chat_member(chat_id, cid):
+                    return True
+        return False
+    return True
+
+
+async def check_group(message: Message) -> bool:
+    if not ENABLE.GROUPS:
+        return False
+    cid = to_int(f"{message.chat.id}".removeprefix("-100"))  # strip `-100` prefix
+
+    """
+    mark as read for these group chats
+    TID_MUTE_GROUPS=111111,234567,-100234567
+    TID_MUTE_GROUP_111111=true  # no `-100` prefix
+    """
+    if cid in [to_int(x.strip().removeprefix("-100")) for x in os.getenv("TID_MUTE_GROUPS", "").split(",")] or true(os.getenv(f"TID_MUTE_GROUP_{cid}")):
+        await message.read()
+
+    """
+    do not process these chats
+    TID_SKIP_GROUPS=111111,234567,-100234567
+    TID_SKIP_GROUP_111111=true  # no `-100` prefix
+    """
+    if cid in [to_int(x.strip().removeprefix("-100")) for x in os.getenv("TID_SKIP_GROUPS", "").split(",")] or true(os.getenv(f"TID_SKIP_GROUP_{cid}")):
+        return False
+
+    """
+    only process these chats
+    TID_ONLY_GROUPS=111111,234567,-100234567
+    """
+    if os.getenv("TID_ONLY_GROUPS") and cid not in [to_int(x.strip().removeprefix("-100")) for x in os.getenv("TID_ONLY_GROUPS", "").split(",")]:
+        return False
+    return True
+
+
+async def check_channel(message: Message) -> bool:
+    if not ENABLE.CHANNELS:
+        return False
+    cid = to_int(f"{message.chat.id}".removeprefix("-100"))  # strip `-100` prefix
+
+    # mark as read
+    if cid in [to_int(x.strip().removeprefix("-100")) for x in os.getenv("TID_MUTE_CHANNELS", "").split(",")] or true(os.getenv(f"TID_MUTE_CHANNEL_{cid}")):
+        await message.read()
+
+    # do not process
+    if cid in [to_int(x.strip().removeprefix("-100")) for x in os.getenv("TID_SKIP_CHANNELS", "").split(",")] or true(os.getenv(f"TID_SKIP_CHANNEL_{cid}")):
+        return False
+
+    # only process
+    if os.getenv("TID_ONLY_CHANNELS") and cid not in [to_int(x.strip().removeprefix("-100")) for x in os.getenv("TID_ONLY_CHANNELS", "").split(",")]:
+        return False
+    return True
+
+
+async def check_bot(message: Message) -> bool:
+    if not ENABLE.BOTS:
+        return False
+    cid = message.chat.id
+
+    # mark as read
+    if cid in [to_int(x.strip()) for x in os.getenv("TID_MUTE_BOTS", "").split(",")] or true(os.getenv(f"TID_MUTE_BOT_{cid}")):
+        await message.read()
+
+    # do not process
+    if cid in [to_int(x.strip()) for x in os.getenv("TID_SKIP_BOTS", "").split(",")] or true(os.getenv(f"TID_SKIP_BOT_{cid}")):
+        return False
+
+    # only process
+    if os.getenv("TID_ONLY_BOTS") and cid not in [to_int(x.strip()) for x in os.getenv("TID_ONLY_BOTS", "").split(",")]:
+        return False
+    return True