Commit 437de33

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-06 18:17:30
feat(price): warn if symbol matches multiple tickers
1 parent 798261e
src/messages/progress.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+import asyncio
 from pathlib import Path
 
 from loguru import logger
@@ -10,7 +11,16 @@ from pyrogram.types import Message
 from config import TEXT_LENGTH, cache
 
 
-async def modify_progress(message: Message | None = None, text: str = "", *, detail_progress: bool = False, del_status: bool = False, force_update: bool = False, **kwargs):
+async def modify_progress(
+    message: Message | None = None,
+    text: str = "",
+    *,
+    detail_progress: bool = False,
+    del_status: bool = False,
+    del_delay: int = 0,
+    force_update: bool = False,
+    **kwargs,
+):
     """Modify the progress message.
 
     Args:
@@ -18,6 +28,7 @@ async def modify_progress(message: Message | None = None, text: str = "", *, det
         text (str): The new text to update.
         detail_progress(bool): Whether to show the detail progress.
         del_status (bool): Whether the progress is done.
+        del_delay (int): Delay seconds to delete the message.
         force_update (bool): Force update the message.
     """
     if message is None:
@@ -27,6 +38,7 @@ async def modify_progress(message: Message | None = None, text: str = "", *, det
     try:
         if del_status:
             logger.info("Deleting progress message")
+            await asyncio.sleep(del_delay)
             await message.delete()
             return
         if not text:
src/price/binance.py
@@ -29,6 +29,21 @@ async def get_binance_symbols() -> dict[str, str]:
     return res
 
 
+async def binance_supported(coin: str) -> tuple[str, str]:
+    """Check if the coin is supported by Binance.
+
+    If supported, return the supported symbol format and the market.
+
+    e.g. "BTC" -> ("BTCUSDT", "SPOT")
+    """
+    symbols = await get_binance_symbols()
+    symbol = coin.upper()
+    stablecoins = ["USDT", "USDC", "FDUSD", "TUSD"]
+    while symbol not in symbols and stablecoins:
+        symbol = f"{coin}{stablecoins.pop(0)}".upper()
+    return (symbol, symbols[symbol]) if symbol in symbols else ("", "")
+
+
 @cache.memoize(ttl=60)
 async def get_binance_price(coin: str, interval: str | None = None) -> dict:
     """Get the price of a crypto asset from Binance."""
@@ -39,15 +54,10 @@ async def get_binance_price(coin: str, interval: str | None = None) -> dict:
         interval = interval.lower()
     if interval not in ["1m", "3m", "5m", "15m", "30m", "1h", "2h", "4h", "6h", "8h", "12h", "1d", "3d", "1w", "1M"]:
         interval = "30m"
-    symbols = await get_binance_symbols()
-    symbol = coin.upper()
-    stablecoins = ["USDT", "USDC", "FDUSD", "TUSD"]
-    while symbol not in symbols and stablecoins:
-        symbol = f"{coin}{stablecoins.pop(0)}".upper()
-    if symbol not in symbols:
+    symbol, market = await binance_supported(coin)
+    if not symbol:
         return {}
 
-    market = symbols[symbol]
     if market == "SPOT":
         url = f"{API.BINANCE_SPOT}/api/v3/klines?symbol={symbol}&interval={interval}&limit=49"
         klines = await hx_req(url, proxy=PROXY.CRYPTO, silent=True)
src/price/coinmarketcap.py
@@ -47,25 +47,34 @@ async def get_cmc_fiat() -> list:
     return [coin["symbol"] for coin in data]
 
 
-@cache.memoize(ttl=60)
-async def get_cmc_price(coin: str, fiat: str = "USD") -> str:
-    """Get the price of a crypto asset from CoinMarketCap.
+async def cmc_supported(coin: str, fiat: str = "USD") -> dict:
+    """Check if the coin is supported by CoinMarketCap.
 
-    If the market cap of the coin is less than 10M, we skip it.
+    If supported, returns a dict which is needed to pass to CMC API.
+    Some shitcoins use the common name as the symbol. (e.g. "bitcoin")
+    So we determine wether we should use "symbol" or "slug" base on the coin rank
     """
     cmc_coins = await get_cmc_coins()
     if coin.upper() not in cmc_coins and coin.lower() not in cmc_coins.values():
-        return ""
-
-    # Some shitcoins use the common name as the symbol. (e.g. "bitcoin")
-    # So we use the highest rank symbol.
+        return {}
     symbol_index = 10e8
     slug_index = 10e8
     if coin.upper() in cmc_coins:
         symbol_index = list(cmc_coins.keys()).index(coin.upper())
     if coin.lower() in cmc_coins.values():
         slug_index = list(cmc_coins.values()).index(coin.lower())
-    params = {"symbol": coin.upper(), "convert": fiat} if symbol_index <= slug_index else {"slug": coin.lower(), "convert": fiat}
+    return {"symbol": coin.upper(), "convert": fiat} if symbol_index <= slug_index else {"slug": coin.lower(), "convert": fiat}
+
+
+@cache.memoize(ttl=60)
+async def get_cmc_price(coin: str, fiat: str = "USD") -> str:
+    """Get the price of a crypto asset from CoinMarketCap.
+
+    If the market cap of the coin is less than 10M, we skip it.
+    """
+    params = await cmc_supported(coin, fiat)
+    if not params:
+        return ""
     url = "https://pro-api.coinmarketcap.com/v1/cryptocurrency/quotes/latest"
     response = await hx_req(url, params=params, headers=HEADERS, merge_headers=False, proxy=PROXY.CRYPTO, check_keys=["data"], check_kv={"status.error_code": 0})
     data = next(iter(response["data"].values()), {})
src/price/entrypoint.py
@@ -7,21 +7,21 @@ from pyrogram.types import Message
 
 from config import ENABLE, PREFIX, cache
 from messages.parser import parse_msg
+from messages.progress import modify_progress
 from messages.sender import send2tg
 from messages.utils import equal_prefix, startswith_prefix
-from price.binance import get_binance_price, get_binance_symbols
-from price.coinmarketcap import cmc_convert_price, get_cmc_coins, get_cmc_price
-from price.okx import get_okx_price, get_okx_symbols
-from price.tradingview import get_tradingview_price, get_tradingview_symbols
+from price.binance import binance_supported, get_binance_price
+from price.coinmarketcap import cmc_convert_price, cmc_supported, get_cmc_price
+from price.okx import get_okx_price, okx_supported
+from price.tradingview import get_tradingview_price, tradingview_supported
 
 HELP = f"""
 💵**查询价格**
 使用说明: `{PREFIX.PRICE}` + Symbol + [@Interval]
 其中symbol(大小写不限)支持如下类别:
 1. 加密货币, 如 `BTC`
-2. 股票, 如 `AAPL` (A股, 港股, 美股)
-3. 指数, 如 `SPX`
-4. 汇率, 如 `USD CNY` (中间有空格)
+2. 股票&指数, 如 `AAPL` / `SPX` (A股, 港股, 美股)
+3. 汇率, 如 `USD CNY` (中间有空格)
 
 K线Interval (可选):
 - 加密货币(默认30m)
@@ -31,7 +31,6 @@ K线Interval (可选):
 
 说明:
 - 加密货币支持币种代码(BTC), 币种名称(bitcoin), 或交易对(BTCUSDC).
-- 对于符号冲突的Symbol, 可使用完整代码表示。例如灰度基金推出的Grayscale Ethereum Mini Trust也使用`ETH`作为Symbol, 和加密货币`ETH`冲突. 查询时可指定完整代码, 如 `{PREFIX.PRICE} AMEX:ETH`
 - 此外在Symbol前添加数字可以计算对应数量的价值。当前仅支持对加密货币和法币汇率进行计算。
 
 示例:
@@ -65,15 +64,15 @@ K线Interval (可选):
 async def get_asset_price(client: Client, message: Message, **kwargs):
     """Get asset price."""
     if not ENABLE.PRICE:
-        return None
+        return
     info = parse_msg(message)
     # send docs if message == "/price"
     if equal_prefix(info["text"], prefix=[PREFIX.PRICE]):
         await send2tg(client, message, texts=HELP, **kwargs)
-        return None
+        return
 
     if not startswith_prefix(info["text"], prefix=[PREFIX.PRICE]):
-        return None
+        return
     text = info["text"].removeprefix(PREFIX.PRICE).strip()
     # these patterns should use CoinMarketCap API
     # some coin has "$" in symbol, so we need to match it
@@ -94,7 +93,8 @@ async def get_asset_price(client: Client, message: Message, **kwargs):
         base = matched.group(1)
         quote = matched.group(2)
     if amount > 0 and base and quote and (msg := await cmc_convert_price(amount, base, quote)):
-        return await send2tg(client, message, texts=msg, **kwargs)
+        await send2tg(client, message, texts=msg, **kwargs)
+        return
 
     # match interval: "BTC @1m" or "000001 @15m"
     if matched := re.search(r"^([$\dA-Za-z]+)\s+@(\d+[A-Za-z])$", text, re.IGNORECASE):
@@ -104,24 +104,68 @@ async def get_asset_price(client: Client, message: Message, **kwargs):
         symbol = text
         interval = None
 
+    categories = await match_symbol_category(symbol)
+    if not categories:
+        await send2tg(client, message, texts=f"不支持此Symbol: {symbol}\n{HELP}", **kwargs)
+        return
+    if warnings := categories.get("warnings"):
+        warn_msg = await send2tg(client, message, texts=warnings, **kwargs)
+        warn_msg = warn_msg[0]
+    else:
+        warn_msg = None
+    # Tradingview
+    if categories.get("tradingview") and await get_tradingview_price(client, message, symbol, interval, **kwargs):
+        await modify_progress(warn_msg, del_status=True, del_delay=5)
+        return
+
     # Binance & OKX will return klines chart
     if (res := await get_binance_price(symbol, interval)) or (res := await get_okx_price(symbol, interval)):
-        return await send2tg(client, message, **res, **kwargs)
-
+        await send2tg(client, message, **res, **kwargs)
+        return
     # other crypto assets supported by CoinMarketCap
     if res := await get_cmc_price(text):
-        return await send2tg(client, message, texts=res, **kwargs)
-
-    # other assets to tradingview
-    if await get_tradingview_price(client, message, symbol, interval, **kwargs):
-        return None
+        await send2tg(client, message, texts=res, **kwargs)
+    await modify_progress(warn_msg, del_status=True, del_delay=5)
 
-    return await send2tg(client, message, texts=f"不支持此Symbol: {text}\n{HELP}", **kwargs)
 
-
-async def prefetch_price_symbols():
-    if ENABLE.PRICE:
-        await get_binance_symbols()
-        await get_okx_symbols()
-        await get_tradingview_symbols()
-        await get_cmc_coins()
+@cache.memoize(ttl=3600)
+async def match_symbol_category(symbol: str = "") -> dict[str, str]:
+    if not ENABLE.PRICE:
+        return {}
+    category = {}
+    if cmc := await cmc_supported(symbol):  # {"symbol": "BTC"} or {"slug": "bitcoin"}
+        category["cmc"] = cmc.get("symbol", cmc.get("slug"))
+        category["crypto"] = cmc.get("symbol", cmc.get("slug"))
+    if okx := await okx_supported(symbol):  # symbol, instId
+        category["okx"] = okx[0]
+        category["crypto"] = okx[0]
+    if binance := await binance_supported(symbol):  # symbol, market
+        category["binance"] = binance[0]
+        category["crypto"] = binance[0]
+    if tradingview := await tradingview_supported(symbol):  # ["SSE:000001", "SZSE:000001"]
+        category["tradingview"] = ", ".join(tradingview)
+    # skip some crypto ETF (e.g. Grayscale Bitcoin Mini Trust use symbol "AMEX:BTC")
+    if category.get("crypto") and category.get("tradingview"):
+        exchange, coin = tradingview[0].split(":")
+        if exchange == "AMEX" and category["crypto"].startswith(coin):  # hit crypto ETF
+            tradingview = []
+            del category["tradingview"]
+
+    # if tradingview has multiles symbols
+    if len(tradingview) > 1:
+        msg = f"⚠️**{symbol.upper()}**代码重复:\n"
+        msg += f"股票&指数: **{category['tradingview']}**\n"
+        msg += f"本次查询: **{tradingview[0]}**\n"
+        msg += f"查询其他请使用完整代码:\n`{PREFIX.PRICE} {'/ '.join(tradingview[1:])}`"
+        category["warnings"] = msg
+
+    # if symbol matches tradingview & crypto categories
+    if category.get("crypto") and category.get("tradingview"):
+        msg = f"⚠️**{symbol.upper()}**代码重复:\n"
+        msg += f"股票&指数: **{category['tradingview']}**\n"
+        msg += f"加密货币: **{category['crypto']}**\n"
+        msg += f"默认查询股票&指数: **{tradingview[0]}**\n"
+        msg += f"查询其他请使用完整代码:\n`{PREFIX.PRICE} {category['crypto']}`"
+        category["warnings"] = msg
+
+    return category
src/price/okx.py
@@ -24,6 +24,22 @@ async def get_okx_symbols() -> dict[str, str]:
     return res
 
 
+async def okx_supported(coin: str) -> tuple[str, str]:
+    """Check if the coin is supported by OKX.
+
+    If supported, return the supported symbol format and the instId.
+
+    e.g. "BTC" -> ("BTCUSDT", "BTC-USDT")
+    """
+    symbols = await get_okx_symbols()
+    coin = coin.upper().replace("-", "")
+    symbol = coin
+    suffixes = ["USDT", "USDC", "USDTSWAP", "USDCSWAP", "USDSWAP"]
+    while symbol not in symbols and suffixes:
+        symbol = f"{coin}{suffixes.pop(0)}".upper()
+    return (symbol, symbols[symbol]) if symbol in symbols else ("", "")
+
+
 @cache.memoize(ttl=60)
 async def get_okx_price(coin: str, interval: str | None = None) -> dict:
     """Get the price of a crypto asset from OKX."""
@@ -37,15 +53,9 @@ async def get_okx_price(coin: str, interval: str | None = None) -> dict:
     if interval not in ["1m", "3m", "5m", "15m", "30m", "1H", "2H", "4H", "6H", "12H", "1D", "2D", "3D", "1W", "1M", "3M"]:
         interval = "30m"
 
-    symbols = await get_okx_symbols()
-    coin = coin.upper().replace("-", "")
-    symbol = coin
-    suffixes = ["USDT", "USDC", "USDTSWAP", "USDCSWAP", "USDSWAP"]
-    while symbol not in symbols and suffixes:
-        symbol = f"{coin}{suffixes.pop(0)}".upper()
-    if symbol not in symbols:
+    symbol, inst_id = await okx_supported(coin)
+    if not symbol:
         return {}
-    inst_id = symbols[symbol]
     url = f"{API.OKX}/api/v5/market/candles?instId={inst_id}&bar={interval}&limit=49"
     response = await hx_req(url, proxy=PROXY.CRYPTO, check_kv={"code": "0"}, silent=True)
     if response.get("hx_error"):
src/price/tradingview.py
@@ -1,5 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+from collections import defaultdict
+
 from loguru import logger
 from pyrogram.client import Client
 from pyrogram.types import Message
@@ -10,27 +12,41 @@ from networking import hx_req
 
 
 @cache.memoize(ttl=43200)  # 12 hours
-async def get_tradingview_symbols() -> dict[str, str]:
+async def get_tradingview_symbols() -> dict[str, list[str]]:
     """Get all symbols from TradingView.
 
     Returns: {
-        "AAPL": "NASDAQ:AAPL",  # (simple symbol)
-        "NASDAQ:AAPL": "NASDAQ:AAPL",  # (full symbol)
+        "AAPL": ["NASDAQ:AAPL"],  # (simple symbol)
+        "NASDAQ:AAPL": ["NASDAQ:AAPL"],  # (full symbol)
+        "000001": ["SSE:000001", "SZSE:000001"],  # (multile tickers with same symbol)
         }
     """
     logger.info("Fetching TradingView symbols...")
     full = {}
-    for region in ["hongkong", "china", "america", "cfd"]:  # always put cfd at the last
+    for region in ["cfd", "america", "china", "hongkong"]:  # priority: cfd > america > china > hongkong
         url = f"https://scanner.tradingview.com/{region}/scan"
         response = await hx_req(url, proxy=PROXY.CRYPTO, check_keys=["data"], silent=True)
         if response.get("hx_error"):
             continue
         data = response["data"]
-        full |= {coin["s"]: coin["s"] for coin in data}
-    simple = {k.split(":")[-1]: v for k, v in full.items()}
+        full |= {coin["s"]: [coin["s"]] for coin in data}
+    simple = defaultdict(list)
+    for k, v in full.items():
+        if k.startswith("CRYPTOCAP"):
+            continue
+        simple[k.split(":")[-1]].extend(v)
     return simple | full
 
 
+async def tradingview_supported(symbol: str) -> list[str]:
+    """Check if the coin is supported by TradingView.
+
+    If supported, return the list of full symbol format.
+    """
+    symbols = await get_tradingview_symbols()
+    return symbols.get(symbol.upper(), [])
+
+
 async def get_tradingview_price(client: Client, message: Message, symbol: str, interval: str | None = None, **kwargs) -> bool:
     """Get the price of a crypto asset from TradingView.
 
@@ -45,10 +61,9 @@ async def get_tradingview_price(client: Client, message: Message, symbol: str, i
         interval = interval.lower()
     if interval not in ["1m", "3m", "5m", "15m", "30m", "45m", "1h", "2h", "3h", "4h", "1D", "1W", "1M", "3M", "6M", "1Y"]:
         interval = "15m"
-    symbol = symbol.upper()
-    symbols = await get_tradingview_symbols()
-    if symbol not in symbols:
+
+    symbols = await tradingview_supported(symbol)  # list of supported full symbols
+    if not symbols:
         return False
-    ticker = symbols[symbol]
-    await send_to_chartimg_bridge(client, message, ticker, interval, **kwargs)
+    await send_to_chartimg_bridge(client, message, symbols[0], interval, **kwargs)
     return True
src/main.py
@@ -24,7 +24,7 @@ from bridge.social import forward_social_media_results
 from config import DAILY_MESSAGES, DEVICE_NAME, ENABLE, PROXY, TID, TOKEN, TZ, cache
 from handler import handle_social_media, handle_utilities
 from messages.parser import parse_msg
-from price.entrypoint import prefetch_price_symbols
+from price.entrypoint import match_symbol_category
 from utils import cleanup_old_files, nowdt, to_int
 
 # ruff: noqa: RUF001
@@ -125,7 +125,7 @@ async def main():
 async def scheduling(client: Client):
     cache.evict()  # delete expired cache
     cleanup_old_files()
-    await prefetch_price_symbols()
+    await match_symbol_category()  # to cache all supported symbols
 
     # custom crontab jobs
     now = nowdt(TZ)