Commit f907c53

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-06 10:56:33
feat(price): support stock and index price query
1 parent 7f178dc
src/bridge/chartimg.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from bridge.utils import forward_bot_message
+from config import cache
+from messages.parser import parse_msg
+from utils import i_am_bot, to_int
+
+CHART_BOT = "chartImgOpnBot"
+
+
+@cache.memoize(ttl=10)
+async def send_to_chartimg_bridge(client: Client, message: Message, symbol: str, interval: str, target_chat: int | str | None = None, reply_msg_id: int = 0, **kwargs):  # noqa: ARG001
+    """See docs in `bridge/README.md` for details.
+
+    Args:
+        target_chat (int | str, optional): Send result to this telegram target chat. If not set, send to the trigger message's chat.
+        reply_msg_id (int, optional): If set to integer > 0, the result is sent as a reply message to this message_id.
+                                             If set to 0, reply to the trigger message itself.
+                                             If set to -1, do not send as a reply message.
+    """
+    if await i_am_bot(client):  # bot can't send message to other bots
+        return
+    target_cid = target_chat if target_chat else message.chat.id  # MSG-A's cid
+    # set MSG-A's mid
+    if to_int(reply_msg_id) == 0:
+        target_mid = message.id
+    elif to_int(reply_msg_id) == -1:
+        target_mid = None
+    else:
+        target_mid = to_int(reply_msg_id)
+    metadata = {"target_cid": target_cid, "target_mid": target_mid, "src": f"{symbol} {interval}"}
+    cache.set(f"bridge-{symbol} {interval}", metadata, ttl=15)  # save metadata to cache
+    logger.warning(f"Trying chartimg bridge (@{CHART_BOT}): {symbol} {interval}")
+    await client.send_message(chat_id=f"@{CHART_BOT}", text=f"/chart {symbol} {interval}")
+
+
+@cache.memoize(ttl=10)
+async def forward_chartimg_results(client: Client, message: Message):
+    """See docs in `bridge/README.md` for details."""
+    if message.from_user.username != CHART_BOT or not message.photo:
+        return
+
+    # got a photo message, format:
+    # [kline chart]\n{symbol} {interval}
+    info = parse_msg(message)
+    if metadata := cache.get(f"bridge-{info['text']}"):
+        await forward_bot_message(client, message, metadata)
src/bridge/social.py
@@ -1,13 +1,12 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import contextlib
-import re
 
 from loguru import logger
 from pyrogram.client import Client
-from pyrogram.types import Message, ReplyParameters
+from pyrogram.types import Message
 
+from bridge.utils import extract_forwarding_params, forward_bot_message, get_recent_msg_from_me
 from config import cache
 from messages.parser import parse_msg
 from utils import i_am_bot
@@ -36,7 +35,7 @@ async def send_to_social_media_bridge(client: Client, message: Message, url: str
         mid = None
     else:
         mid = kwargs["reply_msg_id"]
-    msg = f"#URL=( {url} ) \n#ID=({cid},{mid})".replace("None", "0")
+    msg = f"#SRC=( {url} ) \n#ID=({cid},{mid})".replace("None", "0")
 
     # add progress message
     if prog := kwargs.get("progress"):
@@ -54,73 +53,17 @@ async def forward_social_media_results(client: Client, message: Message):
     if message.from_user.username not in SOCIAL_BOTS or not message.media:
         return
     #  got a media message
-    parse_msg(message)
-
-    # Helper to extract forwarding parameters
-    def extract_forwarding_params(msg_text: str) -> dict:
-        """Extract target chat ID, message ID, and URL from message text."""
-        params = {}
-        id_match = re.search(r"#ID=\((-?\d+),(\d+)\)", msg_text)
-        url_match = re.search(r"#URL=\( (.*?) \)", msg_text)
-        if id_match and url_match:
-            params = {
-                "target_cid": id_match.group(1),
-                "target_mid": int(id_match.group(2)) if int(id_match.group(2)) != 0 else None,
-                "url": url_match.group(1),
-            }
-        if prog_match := re.search(r"#PROGRESS=\((-?\d+),(\d+)\)", msg_text):
-            params["prog_cid"] = int(prog_match.group(1))
-            params["prog_mid"] = int(prog_match.group(2)) if int(prog_match.group(2)) != 0 else None
-        return params
-
-    async def forward_message(client: Client, message: Message, params: dict):
-        """Forward the message to the target chat and delete the pending cache."""
-        logger.info(f"Forwarding chat=@{message.from_user.username}, id={message.id} -> chat={params['target_cid']}, id={params['target_mid']}")
-        if message.media_group_id and not cache.get(f"bridge-{params['url']}-{message.media_group_id}"):
-            # send media_group only once
-            cache.set(f"bridge-{params['url']}-{message.media_group_id}", "1", ttl=120)
-            await client.copy_media_group(
-                chat_id=params["target_cid"],
-                from_chat_id=message.chat.id,
-                message_id=message.id,
-                reply_parameters=ReplyParameters(message_id=params["target_mid"]),  # type: ignore
-            )
-        elif cache.get(f"bridge-{params['url']}"):
-            await client.copy_message(
-                chat_id=params["target_cid"],
-                from_chat_id=message.chat.id,
-                message_id=message.id,
-                reply_parameters=ReplyParameters(message_id=params["target_mid"]),  # type: ignore
-            )
-        cache.delete(f"bridge-{params['url']}")
-        with contextlib.suppress(Exception):
-            if params.get("prog_cid") and params.get("prog_mid"):
-                await client.delete_messages(chat_id=params["prog_cid"], message_ids=params["prog_mid"])
+    info = parse_msg(message)
 
     # Process reply-to messages
     if message.reply_to_message:
         params = extract_forwarding_params(str(message.reply_to_message.text))
-        if params and cache.get(f"bridge-{params['url']}"):
-            await forward_message(client, message, params)
+        if params and cache.get(f"bridge-{params['src']}"):
+            await forward_bot_message(client, message, params)
             return
 
     # Process messages not in reply context
-    my_msg = await get_last_message_from_myself(client, message.from_user.username, message.from_user.id)
+    my_msg = await get_recent_msg_from_me(client, info["handle"], info["uid"])
     params = extract_forwarding_params(my_msg)
-    if params and cache.get(f"bridge-{params['url']}"):
-        await forward_message(client, message, params)
-
-
-@cache.memoize(ttl=3)
-async def get_last_message_from_myself(client: Client, chat_id: int | str, opponent_id: int) -> str:
-    """Get the last message from me in the chat.
-
-    Args:
-        client (Client): The Pyrogram client.
-        chat_id (int | str): The chat id.
-        opponent_id (int): The opponent id.
-    """
-    async for message in client.get_chat_history(chat_id, limit=20):  # type: ignore
-        if message.from_user.id != opponent_id:
-            return message.text or message.caption or ""
-    return ""
+    if params and cache.get(f"bridge-{params['src']}"):
+        await forward_bot_message(client, message, params)
src/bridge/utils.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import contextlib
+import re
+
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.types import Message, ReplyParameters
+
+from config import cache
+from utils import to_int
+
+
+@cache.memoize(ttl=3)
+async def get_recent_msg_from_me(client: Client, chat_id: int | str, opponent_id: int, idx: int = 1) -> str:
+    """Get the last message from me in the chat.
+
+    Args:
+        client (Client): The Pyrogram client.
+        chat_id (int | str): The chat id.
+        opponent_id (int): The opponent id.
+        idx (int): The index of the message to get. 1 for the last message, 2 for the second last message, etc.
+    """
+    hit = 0
+    async for message in client.get_chat_history(to_int(chat_id), limit=20):  # type: ignore
+        if message.from_user.id != opponent_id:
+            hit += 1
+            if hit == idx:
+                return message.text or message.caption or ""
+    return ""
+
+
+def extract_forwarding_params(msg_text: str) -> dict:
+    """Extract target chat ID, message ID, and SRC from message text."""
+    params = {}
+    id_match = re.search(r"#ID=\((-?\d+),(\d+)\)", msg_text)
+    src_match = re.search(r"#SRC=\( (.*?) \)", msg_text)
+    if id_match and src_match:
+        params = {
+            "target_cid": id_match.group(1),
+            "target_mid": int(id_match.group(2)) if int(id_match.group(2)) != 0 else None,
+            "src": src_match.group(1),
+        }
+    if prog_match := re.search(r"#PROGRESS=\((-?\d+),(\d+)\)", msg_text):
+        params["prog_cid"] = int(prog_match.group(1))
+        params["prog_mid"] = int(prog_match.group(2)) if int(prog_match.group(2)) != 0 else None
+    return params
+
+
+async def forward_bot_message(client: Client, message: Message, params: dict):
+    """Forward the message to the target chat and delete the pending cache."""
+    logger.info(f"Forwarding chat=@{message.from_user.username}, id={message.id} -> chat={params['target_cid']}, id={params['target_mid']}")
+    if message.media_group_id and not cache.get(f"bridge-{params['src']}-{message.media_group_id}"):
+        # send media_group only once
+        cache.set(f"bridge-{params['src']}-{message.media_group_id}", "1", ttl=120)
+        await client.copy_media_group(
+            chat_id=params["target_cid"],
+            from_chat_id=message.chat.id,
+            message_id=message.id,
+            reply_parameters=ReplyParameters(message_id=params["target_mid"]),  # type: ignore
+        )
+    elif cache.get(f"bridge-{params['src']}"):
+        await client.copy_message(
+            chat_id=params["target_cid"],
+            from_chat_id=message.chat.id,
+            message_id=message.id,
+            reply_parameters=ReplyParameters(message_id=params["target_mid"]),  # type: ignore
+        )
+    cache.delete(f"bridge-{params['src']}")
+    with contextlib.suppress(Exception):
+        if params.get("prog_cid") and params.get("prog_mid"):
+            await client.delete_messages(chat_id=params["prog_cid"], message_ids=params["prog_mid"])
src/price/binance.py
@@ -29,8 +29,16 @@ async def get_binance_symbols() -> dict[str, str]:
 
 
 @cache.memoize(ttl=60)
-async def get_binance_price(coin: str, interval: str = "30m") -> dict:
+async def get_binance_price(coin: str, interval: str | None = None) -> dict:
     """Get the price of a crypto asset from Binance."""
+    if interval is None:
+        interval = "30m"
+    # Binance interval unit: m, h, d, w, M
+    if interval.endswith(("H", "D", "W")):
+        interval = interval.lower()
+    if interval not in ["1m", "3m", "5m", "15m", "30m", "1h", "2h", "6h", "8h", "12h", "1d", "3d", "1w", "1M"]:
+        interval = "30m"
+
     symbols = await get_binance_symbols()
     symbol = coin.upper()
     stablecoins = ["USDT", "USDC", "FDUSD", "TUSD"]
src/price/coinmarketcap.py
@@ -49,12 +49,18 @@ async def get_cmc_price(coin: str, fiat: str = "USD") -> str:
     If the market cap of the coin is less than 10M, we skip it.
     """
     cmc_coins = await get_cmc_coins()
-    if coin.upper() in cmc_coins:
-        params = {"symbol": coin.upper(), "convert": fiat}
-    elif coin.lower() in cmc_coins.values():
-        params = {"slug": coin.lower(), "convert": fiat}
-    else:
+    if coin.upper() not in cmc_coins and coin.lower() not in cmc_coins.values():
         return ""
+
+    # Some shitcoins use the common name as the symbol. (e.g. "bitcoin")
+    # So we use the highest rank symbol.
+    symbol_index = 10e8
+    slug_index = 10e8
+    if coin.upper() in cmc_coins:
+        symbol_index = list(cmc_coins.keys()).index(coin.upper())
+    if coin.lower() in cmc_coins.values():
+        slug_index = list(cmc_coins.values()).index(coin.lower())
+    params = {"symbol": coin.upper(), "convert": fiat} if symbol_index <= slug_index else {"slug": coin.lower(), "convert": fiat}
     url = "https://pro-api.coinmarketcap.com/v1/cryptocurrency/quotes/latest"
     response = await hx_req(url, params=params, headers=HEADERS, merge_headers=False, proxy=PROXY.CRYPTO, check_has_kv=["data"])
     data = next(iter(response.json()["data"].values()), {})
src/price/entrypoint.py
@@ -13,37 +13,74 @@ from messages.utils import equal_prefix, startswith_prefix
 from price.binance import get_binance_price
 from price.coinmarketcap import cmc_convert_price, get_cmc_price
 from price.okx import get_okx_price
+from price.tradingview import get_tradingview_price
 
 HELP = f"""
 💵**查询价格**
-示例: (Symbol大小写均可)
-1. `{PREFIX.PRICE} BTC` 查询比特币价格
-2. `{PREFIX.PRICE} BTC CNY` 计算1枚BTC的CNY价值
-3. `{PREFIX.PRICE} USD CNY` 计算USD与CNY的汇率
-4. `{PREFIX.PRICE} 1.5 BTC` 计算1.5枚BTC的USD价值
-5. `{PREFIX.PRICE} 1.5 BTC CNY` 计算1.5枚BTC的CNY价值
+使用说明: `{PREFIX.PRICE}` + Symbol + [@Interval]
+其中symbol(大小写不限)支持如下类别:
+1. 加密货币, 如 `BTC`
+2. 股票, 如 `AAPL` (A股, 港股, 美股)
+3. 指数, 如 `SPX`
+4. 汇率, 如 `USD CNY` (中间有空格)
+
+K线Interval (可选):
+- 加密货币(默认30m)
+1m,3m,5m,15m,30m,1h,2h,6h,8h,12h,1D,3D,1W,1M
+- 股票&指数(默认15m)
+1m,3m,5m,15m,30m,45m,1h,2h,3h,4h,1D,1W,1M,3M,6M,1Y
+
+说明:
+- 加密货币支持币种代码(BTC), 币种名称(bitcoin), 或交易对(BTCUSDC).
+- 对于符号冲突的Symbol, 可使用完整代码表示。例如灰度基金推出的Grayscale Ethereum Mini Trust也使用`ETH`作为Symbol, 和加密货币`ETH`冲突. 查询时可指定完整代码, 如 `{PREFIX.PRICE} AMEX:ETH`
+- 此外在Symbol前添加数字可以计算对应数量的价值。当前仅支持对加密货币和法币汇率进行计算。
+
+示例:
+1. 查询加密货币价格
+- 对于Binance和OKX支持的币种, 还会返回K线图(默认30m)
+- `{PREFIX.PRICE} BTC`
+- `{PREFIX.PRICE} ethereum`
+- `{PREFIX.PRICE} DOGEUSDT`
+- `{PREFIX.PRICE} BTC @4h`
+
+2. 查询股票&指数价格:
+- 默认返回Interval为15m的K线图
+- `{PREFIX.PRICE}` AAPL 或 NASDAQ:AAPL
+- `{PREFIX.PRICE}` SPX 或 SP:SPX
+- `{PREFIX.PRICE}` 000001 或 SSE:000001
+- `{PREFIX.PRICE} AAPL @1m`
+
+3. 查询汇率:
+- `{PREFIX.PRICE} USD CNY`
+- `{PREFIX.PRICE} BTC CNY`
+- `{PREFIX.PRICE} DOGE BTC`
+
+4. 计算价值:
+- `{PREFIX.PRICE} 1.5 BTC` (默认计算美元价值)
+- `{PREFIX.PRICE} 1.5 BTC CNY`
+- `{PREFIX.PRICE} 3000 JPY CNY`
 """
 
 
 @cache.memoize(ttl=60)
-async def get_asset_price(client: Client, message: Message, **kwargs) -> None:
+async def get_asset_price(client: Client, message: Message, **kwargs):
     """Get asset price."""
     if not ENABLE.PRICE:
-        return
+        return None
     info = parse_msg(message)
     # send docs if message == "/price"
     if equal_prefix(info["text"], prefix=[PREFIX.PRICE]):
         await send2tg(client, message, texts=HELP, **kwargs)
-        return
+        return None
 
     if not startswith_prefix(info["text"], prefix=[PREFIX.PRICE]):
-        return
+        return None
     text = info["text"].removeprefix(PREFIX.PRICE).strip()
-
     # these patterns should use CoinMarketCap API
-    pattern_1 = r"^([\d.]+)\s+([A-Za-z]+)\s+([A-Za-z]+)"  # match "1.5 BTC CNY"
-    pattern_2 = r"^([\d.]+)\s+([A-Za-z]+)"  # match "1.5 BTC"
-    pattern_3 = r"^([A-Za-z]+)\s+([A-Za-z]+)"  # match "BTC CNY"
+    # some coin has "$" in symbol, so we need to match it
+    pattern_1 = r"^([\d.]+)\s+([$\dA-Za-z]+)\s+([$\dA-Za-z]+)$"  # match "1.5 BTC CNY"
+    pattern_2 = r"^([\d.]+)\s+([$\dA-Za-z]+)$"  # match "1.5 BTC"
+    pattern_3 = r"^([$\dA-Za-z]+)\s+([$\dA-Za-z]+)$"  # match "BTC CNY"
     amount, base, quote = 0, "", ""
     if matched := re.search(pattern_1, text, re.IGNORECASE):
         amount = float(matched.group(1))
@@ -58,13 +95,26 @@ async def get_asset_price(client: Client, message: Message, **kwargs) -> None:
         base = matched.group(1)
         quote = matched.group(2)
     if amount > 0 and base and quote and (msg := await cmc_convert_price(amount, base, quote)):
-        await send2tg(client, message, texts=msg, **kwargs)
-        return
-
-    # match "BTC"
-    if (res := await get_binance_price(text)) or (res := await get_okx_price(text)):
-        await send2tg(client, message, **res, **kwargs)  # with klines chart
-    elif res := await get_cmc_price(text):
-        await send2tg(client, message, texts=res, **kwargs)
-    else:
-        await send2tg(client, message, texts=f"不支持此Symbol: {text}\n{HELP}", **kwargs)
+        return await send2tg(client, message, texts=msg, **kwargs)
+
+    # match interval: "BTC @1m" or "000001 @15m"
+    if matched := re.search(r"^([$\dA-Za-z]+)\s+@(\d+[A-Za-z])$", text, re.IGNORECASE):
+        symbol = matched.group(1)
+        interval = matched.group(2)
+    else:  # match single symbol: "BTC" / "AAPL" / "SPX" / "000001"
+        symbol = text
+        interval = None
+
+    # Binance & OKX will return klines chart
+    if (res := await get_binance_price(symbol, interval)) or (res := await get_okx_price(symbol, interval)):
+        return await send2tg(client, message, **res, **kwargs)
+
+    # other crypto assets supported by CoinMarketCap
+    if res := await get_cmc_price(text):
+        return await send2tg(client, message, texts=res, **kwargs)
+
+    # other assets to tradingview
+    if await get_tradingview_price(client, message, symbol, interval, **kwargs):
+        return None
+
+    return await send2tg(client, message, texts=f"不支持此Symbol: {text}\n{HELP}", **kwargs)
src/price/okx.py
@@ -20,8 +20,16 @@ async def get_okx_symbols() -> dict[str, str]:
 
 
 @cache.memoize(ttl=60)
-async def get_okx_price(coin: str, interval: str = "30m") -> dict:
+async def get_okx_price(coin: str, interval: str | None = None) -> dict:
     """Get the price of a crypto asset from OKX."""
+    if interval is None:
+        interval = "30m"
+    # OKX interval unit: m, H, D, W, M
+    if interval.endswith(("h", "d", "w")):
+        interval = interval.upper()
+    if interval not in ["1m", "3m", "5m", "15m", "30m", "1H", "2H", "4H", "6H", "12H", "1D", "2D", "3D", "1W", "1M", "3M"]:
+        interval = "30m"
+
     symbols = await get_okx_symbols()
     coin = coin.upper().replace("-", "")
     symbol = coin
src/price/tradingview.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from bridge.chartimg import send_to_chartimg_bridge
+from config import PROXY, cache
+from networking import hx_req
+
+
+@cache.memoize(ttl=7200)
+async def get_tradingview_symbols() -> dict[str, str]:
+    """Get all symbols from TradingView.
+
+    Returns: {
+        "AAPL": "NASDAQ:AAPL",  # (simple symbol)
+        "NASDAQ:AAPL": "NASDAQ:AAPL",  # (full symbol)
+        }
+    """
+    full = {}
+    for region in ["hongkong", "china", "america", "cfd"]:  # always put cfd at the last
+        url = f"https://scanner.tradingview.com/{region}/scan"
+        response = await hx_req(url, proxy=PROXY.CRYPTO, check_has_kv=["data"])
+        data = response.json()["data"]
+        full |= {coin["s"]: coin["s"] for coin in data}
+    simple = {k.split(":")[-1]: v for k, v in full.items()}
+    return simple | full
+
+
+async def get_tradingview_price(client: Client, message: Message, symbol: str, interval: str | None = None, **kwargs) -> bool:
+    """Get the price of a crypto asset from TradingView.
+
+    This function is currently supported by third-party bot: https://t.me/chartImgOpnBot
+    """
+    if interval is None:
+        interval = "15m"
+    # TradingView interval unit: m, h, D, W, M, Y
+    if interval.endswith("h"):
+        interval = interval.upper()
+    elif interval.endswith(("D", "W", "M", "Y")):
+        interval = interval.lower()
+    if interval not in ["1m", "3m", "5m", "15m", "30m", "45m", "1h", "2h", "3h", "4h", "1D", "1W", "1M", "3M", "6M", "1Y"]:
+        interval = "15m"
+    symbol = symbol.upper()
+    symbols = await get_tradingview_symbols()
+    if symbol not in symbols:
+        return False
+    ticker = symbols[symbol]
+    await send_to_chartimg_bridge(client, message, ticker, interval, **kwargs)
+    return True
src/main.py
@@ -18,6 +18,7 @@ from pyrogram.client import Client
 from pyrogram.sync import idle
 from pyrogram.types import LinkPreviewOptions, Message
 
+from bridge.chartimg import forward_chartimg_results
 from bridge.ocr import forward_ocr_results
 from bridge.social import forward_social_media_results
 from config import DAILY_MESSAGES, DEVICE_NAME, ENABLE, PROXY, TID, TOKEN, TZ, cache
@@ -89,6 +90,7 @@ async def main():
         parse_msg(message, verbose=True)
         await forward_social_media_results(client, message)
         await forward_ocr_results(client, message)
+        await forward_chartimg_results(client, message)
         await handle_utilities(client, message, detail_progress=True)
         await handle_social_media(client, message, detail_progress=True)