Commit ecbfb32

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-02-14 18:29:28
feat: use `chart-img` API to get tradingview chart
1 parent 285cd5b
Changed files (4)
src/bridge/chartimg.py
@@ -15,7 +15,9 @@ CHART_BOT = "chartImgBot"
 
 @cache.memoize(ttl=10)
 async def send_to_chartimg_bridge(client: Client, message: Message, symbol: str, interval: str, target_chat: int | str | None = None, reply_msg_id: int = 0, **kwargs):  # noqa: ARG001
-    """See docs in `bridge/README.md` for details.
+    """This bridge is now deprecated. We use API from https://chart-img.com to get the chart directly.
+
+    See docs in `bridge/README.md` for details.
 
     Args:
         target_chat (int | str, optional): Send result to this telegram target chat. If not set, send to the trigger message's chat.
src/price/entrypoint.py
@@ -3,9 +3,9 @@
 import re
 
 from pyrogram.client import Client
-from pyrogram.types import Message
+from pyrogram.types import Message, ReplyParameters
 
-from config import ENABLE, PREFIX, cache
+from config import ENABLE, PREFIX, TZ, cache
 from messages.parser import parse_msg
 from messages.progress import modify_progress
 from messages.sender import send2tg
@@ -46,7 +46,7 @@ K线Interval (可选):
 - `{PREFIX.PRICE} BTC @4h`
 
 2. 查询股票价格:
-- 默认返回Interval为15m的K线图
+- 默认返回Interval为5m的K线图
 - `{PREFIX.PRICE}` AAPL 或 NASDAQ:AAPL
 - `{PREFIX.PRICE}` SPX 或 SP:SPX
 - `{PREFIX.PRICE}` 000001 或 SSE:000001
@@ -113,27 +113,36 @@ async def get_asset_price(client: Client, message: Message, **kwargs):
     if not categories:
         await send2tg(client, message, texts=f"不支持此Symbol: {symbol.upper()}\n{HELP}", **kwargs)
         return
+    msg = f"🔍查询价格: {symbol.upper()}"
+    if kwargs.get("show_progress"):
+        res = await send2tg(client, message, texts=msg, **kwargs)
+        kwargs["progress"] = res[0]
     if warnings := categories.get("warnings"):
-        warn_msg = await send2tg(client, message, texts=warnings, **kwargs)
-        warn_msg = warn_msg[0]
-    else:
-        warn_msg = None
+        await modify_progress(text=warnings, **kwargs)
     # Tradingview
-    if not crypto_only and categories.get("tradingview") and await get_tradingview_price(client, message, symbol, interval, **kwargs):
-        await modify_progress(warn_msg, del_status=True, del_delay=10)
+    if not crypto_only and categories.get("tradingview") and (data := await get_tradingview_price(symbol, interval, **kwargs)):
+        await client.send_photo(
+            chat_id=info["cid"],
+            photo=data["url"],
+            caption=f"[{data['symbol']}](https://www.tradingview.com/chart/?symbol={data['symbol']}) @{data['interval']} ({TZ})",
+            reply_parameters=ReplyParameters(message_id=info["mid"]),
+        )
+        await modify_progress(del_status=True, **kwargs)
         return
 
     # Belows are only for crypto market
     if stock_only:
+        await modify_progress(del_status=True, **kwargs)
         return
     # Binance & OKX will return klines chart
     if (res := await get_binance_price(symbol, interval)) or (res := await get_okx_price(symbol, interval)):
         await send2tg(client, message, **res, **kwargs)
+        await modify_progress(del_status=True, **kwargs)
         return
     # other crypto assets supported by CoinMarketCap
     if res := await get_cmc_price(text):
         await send2tg(client, message, texts=res, **kwargs)
-    await modify_progress(warn_msg, del_status=True, del_delay=10)
+    await modify_progress(del_status=True, **kwargs)
 
 
 @cache.memoize(ttl=3600)
@@ -157,22 +166,23 @@ async def match_symbol_category(symbol: str = "", *, crypto_only: bool = False,
             category["crypto"] = binance_symbol
 
     # Stock market
-    tradingview = []
-    if not crypto_only and (tradingview := await tradingview_supported(symbol)):  # ["SSE:000001", "SZSE:000001"]
-        category["tradingview"] = ", ".join(tradingview)
+    tv_symbols = []
+    if not crypto_only and (tradingview := await tradingview_supported(symbol)):  # [("SSE:000001", "china"), ("SZSE:000001", "china")]
+        tv_symbols = [x[0] for x in tradingview]  # ["SSE:000001", "SZSE:000001"]
+        category["tradingview"] = ", ".join(tv_symbols)
     # skip some crypto ETF (e.g. Grayscale Bitcoin Mini Trust use symbol "AMEX:BTC")
     if category.get("crypto") and category.get("tradingview"):
-        exchange, coin = tradingview[0].split(":")
+        exchange, coin = tradingview[0][0].split(":")
         if exchange == "AMEX" and category["crypto"].startswith(coin):  # hit crypto ETF
             tradingview = []
             del category["tradingview"]
 
     # if tradingview has multiles symbols
-    if len(tradingview) > 1:
+    if len(tv_symbols) > 1:
         msg = f"⚠️**{symbol.upper()}**代码重复:\n"
         msg += f"股票: **{category['tradingview']}**\n"
-        msg += f"本次查询: **{tradingview[0]}**\n"
-        msg += f"查询其他请使用完整代码:\n`{PREFIX.PRICE} {'/ '.join(tradingview[1:])}`"
+        msg += f"本次查询: **{tv_symbols[0]}**\n"
+        msg += f"查询其他请使用完整代码:\n`{PREFIX.PRICE} {'/ '.join(tv_symbols[1:])}`"
         category["warnings"] = msg
 
     # if symbol matches tradingview & crypto categories
@@ -180,7 +190,7 @@ async def match_symbol_category(symbol: str = "", *, crypto_only: bool = False,
         msg = f"⚠️**{symbol.upper()}**代码重复:\n"
         msg += f"股票: **{category['tradingview']}**\n"
         msg += f"加密货币: **{category['crypto']}**\n"
-        msg += f"默认查询股票: **{tradingview[0]}**\n"
+        msg += f"默认查询股票: **{tv_symbols[0]}**\n"
         msg += f"查询其他`{PREFIX.CRYPTO}` 或 `{PREFIX.STOCK}`限定市场"
         category["warnings"] = msg
 
src/price/tradingview.py
@@ -1,24 +1,23 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import asyncio
 from collections import defaultdict
 
 from loguru import logger
-from pyrogram.client import Client
-from pyrogram.types import Message
 
-from bridge.chartimg import send_to_chartimg_bridge
-from config import PROXY, cache
+from config import PROXY, TOKEN, TZ, cache
+from messages.progress import modify_progress
 from networking import hx_req
 
 
 @cache.memoize(ttl=43200)  # 12 hours
-async def get_tradingview_symbols() -> dict[str, list[str]]:
+async def get_tradingview_symbols() -> dict[str, list[tuple]]:
     """Get all symbols from TradingView.
 
     Returns: {
-        "AAPL": ["NASDAQ:AAPL"],  # (simple symbol)
-        "NASDAQ:AAPL": ["NASDAQ:AAPL"],  # (full symbol)
-        "000001": ["SSE:000001", "SZSE:000001"],  # (multile tickers with same symbol)
+        "AAPL": [("NASDAQ:AAPL", "america")]  # (simple symbol)
+        "NASDAQ:AAPL": [("NASDAQ:AAPL", "america")],  # (full symbol)
+        "000001": [("SSE:000001", "china"), ("SZSE:000001", "china")]  # (multile tickers with same symbol)
         }
     """
     logger.info("Fetching TradingView symbols...")
@@ -28,8 +27,8 @@ async def get_tradingview_symbols() -> dict[str, list[str]]:
         response = await hx_req(url, proxy=PROXY.CRYPTO, check_keys=["data"], silent=True)
         if response.get("hx_error"):
             continue
-        data = response["data"]
-        full |= {coin["s"]: [coin["s"]] for coin in data}
+        response = response["data"]
+        full |= {coin["s"]: [(coin["s"], region)] for coin in response}
     simple = defaultdict(list)
     for k, v in full.items():
         if k.startswith("CRYPTOCAP"):
@@ -38,32 +37,57 @@ async def get_tradingview_symbols() -> dict[str, list[str]]:
     return simple | full
 
 
-async def tradingview_supported(symbol: str) -> list[str]:
+async def tradingview_supported(symbol: str) -> list[tuple[str, str]]:
     """Check if the coin is supported by TradingView.
 
     If supported, return the list of full symbol format.
+    e.g. [("SSE:000001", "china"), ("SZSE:000001", "china")]
     """
     symbols = await get_tradingview_symbols()
     return symbols.get(symbol.upper(), [])
 
 
-async def get_tradingview_price(client: Client, message: Message, symbol: str, interval: str | None = None, **kwargs) -> bool:
+async def get_tradingview_price(symbol: str, interval: str | None = None, **kwargs) -> dict:
     """Get the price of a crypto asset from TradingView.
 
-    This function is currently supported by third-party bot: https://t.me/chartImgOpnBot
+    Returns: {
+        "url": "remote url of the chart image",
+        "symbol": "NADAQ:AAPL",
+        "interval": "5m"
+    }
     """
     if interval is None:
-        interval = "15m"
+        interval = "5m"
     # TradingView interval unit: m, h, D, W, M, Y
     if interval.endswith("h"):
         interval = interval.upper()
     elif interval.endswith(("D", "W", "M", "Y")):
         interval = interval.lower()
     if interval not in ["1m", "3m", "5m", "15m", "30m", "45m", "1h", "2h", "3h", "4h", "1D", "1W", "1M", "3M", "6M", "1Y"]:
-        interval = "15m"
+        interval = "5m"
 
     symbols = await tradingview_supported(symbol)  # list of supported full symbols
     if not symbols:
-        return False
-    await send_to_chartimg_bridge(client, message, symbols[0], interval, **kwargs)
-    return True
+        return {}
+    if not TOKEN.CHART_IMG:
+        await modify_progress(text="❌CHART_IMG_API is not set. Get it from: https://chart-img.com", force_update=True, **kwargs)
+        await asyncio.sleep(5)
+        return {}
+    query_symbol, market = symbols[0]  # the first supported symbol
+    logger.info(f"Fetching TradingView chart for {query_symbol} @{interval} in {market.capitalize()}...")
+
+    params = {
+        "theme": "dark",
+        "interval": interval,
+        "session": "extended",
+        "symbol": query_symbol,
+        "timezone": TZ,
+        "studies": [{"name": "Volume", "forceOverlay": True}, {"name": "MA Cross", "override": {"PlotCrosses.visible": False}}],
+        "override": {"showStudyLastValue": False},
+    }
+    logger.trace(params)
+    resp = await hx_req("https://api.chart-img.com/v2/tradingview/advanced-chart/storage", "POST", max_retry=0, headers={"x-api-key": TOKEN.CHART_IMG}, post_json=params, check_keys=["url"])
+    if error := resp.get("hx_error"):
+        await modify_progress(text=f"❌Failed to fetch TradingView chart for {query_symbol} @{interval} in {market.capitalize()}\n{error}", force_update=True, **kwargs)
+        return {}
+    return {"url": resp["url"], "symbol": query_symbol, "interval": interval} if resp["url"] else {}
src/config.py
@@ -111,6 +111,7 @@ class TOKEN:
     CMC_API_KEY = os.getenv("CMC_API_KEY", "")
     GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY", "")
     GOOGLE_SEARCH_CX = os.getenv("GOOGLE_SEARCH_CX", "")
+    CHART_IMG = os.getenv("CHART_IMG_KEY", "")
 
 
 class PROXY:  # format: socks5://127.0.0.1:7890