main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import re
  4from collections import defaultdict
  5from datetime import UTC, datetime, timedelta
  6from decimal import Decimal
  7from io import BytesIO
  8from pathlib import Path
  9from zoneinfo import ZoneInfo
 10
 11import pandas as pd
 12from glom import glom
 13from loguru import logger
 14
 15from config import DANMU, DOWNLOAD_DIR, TZ, cache
 16from database.r2 import get_cf_r2, set_cf_r2
 17from messages.progress import modify_progress
 18from others.emoji import CURRENCY
 19from utils import nowdt, number, number_to_emoji
 20
 21
 22async def query_r2(dates: list[str], user: str, keyword: str, caption: str, super_chats: defaultdict, qtype: str, **kwargs) -> dict:
 23    """从R2获取记录.
 24
 25    日期从新到旧, 数据从旧到新
 26    Returns:
 27        {"texts": str, "count": int}
 28    """
 29    if not dates:
 30        return {}
 31
 32    total_count = 0
 33    queried_dates = []
 34    texts = ""
 35    for date in sorted(dates, reverse=True):  # 日期从新到旧
 36        df = await query_r2_for_date(date, qtype)
 37        queried_dates.append(date.upper())
 38        if len(df) == 0:
 39            continue
 40        parsed = await parse_dataframe(df, user, keyword, super_chats, qtype)
 41        count = parsed.get("count", 0)
 42        if count == 0:
 43            continue
 44        total_count += count
 45        texts += parsed.get("texts", "")
 46        await modify_progress(text=caption + f"\n🔍查询时间: {''.join(queried_dates)}\n⏳匹配{qtype}数: {total_count}", force_update=True, **kwargs)
 47        del parsed
 48    return {"texts": texts.strip(), "count": total_count}
 49
 50
 51async def parse_dataframe(df: pd.DataFrame, user: str, keyword: str, super_chats: defaultdict, qtype: str) -> dict:
 52    """解析从R2获取的记录."""
 53    texts = ""
 54    count = 0
 55    if keyword:
 56        df = df[df["content"].str.contains(keyword)]
 57    if user and qtype == "弹幕":
 58        uids = await get_uids_by_name(name=user)
 59        df = df[df["uid"].isin(uids)]
 60    df["livedate"] = df["ts"].apply(ts_to_liveday, args=(qtype,))
 61    df = df.sort_values(by=["livedate", "ts"], ascending=[False, True])
 62    processed_day = set()
 63    for _, row in df.iterrows():
 64        day, title, url = await live_date_info(row["ts"], qtype)
 65        day_str = f"\n{title}\n" if day not in processed_day else ""
 66        processed_day.add(day)
 67        if qtype == "发言":
 68            texts += f"\n{day_str}{ts_with_url(row['ts'], url)}: {row['content'].strip()}"
 69        else:
 70            sc_amount = ""
 71            if super_chat := row["superchat"]:
 72                currency, amount = super_chat.split(" ")
 73                super_chats[currency] += Decimal(amount)
 74                sc_amount = f" ({CURRENCY[currency]}{currency} {number(amount)})" if currency in CURRENCY else ""
 75            username = "" if user else f"|{row['name']}"  # 当指定过滤user时, 隐藏用户名
 76            texts += f"\n{day_str}{ts_time(row['ts'])}{username}{sc_amount}: {row['content'].strip()}"
 77        count += 1
 78    return {"texts": texts.rstrip(), "count": count}
 79
 80
 81async def query_r2_for_date(date: str, qtype: str) -> pd.DataFrame:
 82    """首先尝试从本地磁盘获取, 如果不存在, 则从R2获取."""
 83    r2_key = DANMU.R2_PREFIX.rstrip("/") + f"/{qtype}/{date[:4]}"
 84    path = Path(f"{DOWNLOAD_DIR}/{qtype}/{date[:4]}.parquet")
 85    now = datetime.now(UTC).timestamp()
 86    # always use local file if it is less than 1 hour old
 87    if path.is_file() and now - path.stat().st_mtime < 3600:
 88        logger.trace(f"Load {qtype} from local file: {path.name}")
 89        df = pd.read_parquet(path)
 90
 91    # get from r2 for this year
 92    elif date[:4] == nowdt(TZ).strftime("%Y"):
 93        logger.debug(f"Query {qtype} from R2: {r2_key}")
 94        df = await get_r2_dataframe(r2_key, path)
 95
 96    # use local file if it exists
 97    elif path.is_file():
 98        logger.trace(f"Load {qtype} from local file: {path.name}")
 99        df = pd.read_parquet(path)
100
101    # get from r2 for other dates
102    else:
103        logger.debug(f"Save {qtype} to {path.name}")
104        df = await get_r2_dataframe(r2_key, path)
105
106    # filter specific date
107    # use 30-hour system for danmu
108    offset = timedelta(hours=6) if qtype == "弹幕" else timedelta(hours=0)
109    if len(date) == 7:  # YYYY-MM
110        start = datetime.strptime(date, "%Y-%m").replace(day=1, tzinfo=ZoneInfo(TZ)) + offset
111        end = datetime.strptime(date, "%Y-%m").replace(day=31, hour=23, minute=59, second=59, microsecond=999999, tzinfo=ZoneInfo(TZ)) + offset
112        start_ts = int(start.timestamp())
113        end_ts = int(end.timestamp())
114        df = df[(df["ts"] >= start_ts) & (df["ts"] <= end_ts)]
115    elif len(date) == 10:  # YYYY-MM-DD
116        start = datetime.strptime(date, "%Y-%m-%d").replace(tzinfo=ZoneInfo(TZ)) + offset
117        end = datetime.strptime(date, "%Y-%m-%d").replace(hour=23, minute=59, second=59, microsecond=999999, tzinfo=ZoneInfo(TZ)) + offset
118        start_ts = int(start.timestamp())
119        end_ts = int(end.timestamp())
120        df = df[(df["ts"] >= start_ts) & (df["ts"] <= end_ts)]
121    return df.reset_index(drop=True)
122
123
@cache.memoize(ttl=120)
async def get_r2_dataframe(r2_key: str, path: Path | None = None) -> pd.DataFrame:
    """Fetch a parquet object from R2 as a DataFrame, optionally mirroring it locally.

    If ``path`` is an existing file modified less than an hour ago, it is read
    instead of contacting R2. Otherwise the object is downloaded and exact
    duplicate rows are dropped. When ``path`` is given, the frame is then written
    there (brotli-compressed parquet) for later reuse.

    Memoized for 120 seconds via ``cache.memoize``. NOTE(review): callers may share
    the returned object — avoid mutating it in place (confirm cache backend semantics).

    Returns:
        The DataFrame, or an empty DataFrame when R2 does not return bytes.
    """
    if isinstance(path, Path) and path.is_file():
        age_seconds = datetime.now(UTC).timestamp() - path.stat().st_mtime
        if age_seconds < 3600:
            logger.trace(f"Load {r2_key} from local file: {path.name}")
            return pd.read_parquet(path)

    payload = await get_cf_r2(r2_key, rformat="bytes", silent=True)
    if not isinstance(payload, bytes):
        return pd.DataFrame()

    frame = pd.read_parquet(BytesIO(payload)).drop_duplicates()
    if path is not None:
        path.parent.mkdir(parents=True, exist_ok=True)
        frame.to_parquet(path, index=False, compression="brotli")
    return frame
139
140
async def save_dataframe_to_r2(r2_key: str, df: pd.DataFrame):
    """Serialize ``df`` as brotli-compressed parquet and upload it to R2 under ``r2_key``.

    The index is not stored. The upload is made with ``silent=True``.
    """
    # The context manager closes the buffer even if serialization raises.
    with BytesIO() as buffer:
        df.to_parquet(buffer, index=False, compression="brotli")
        parquet_bytes = buffer.getvalue()
    await set_cf_r2(r2_key, parquet_bytes, mime_type="application/x-parquet", silent=True)
147
148
def ts_to_liveday(ts: int, qtype: str) -> str:
    """Map a timestamp to its live-stream day (``YYYY-MM-DD`` in ``TZ``).

    弹幕 timestamps are real wall-clock times, while 发言 timestamps are relative
    to the stream start. For 弹幕, anything before 06:00 still belongs to the
    previous day (30-hour day).
    """
    moment = datetime.fromtimestamp(ts, tz=ZoneInfo(TZ))
    if qtype != "发言" and moment.hour < 6:
        # Past midnight still counts as the previous day's stream.
        moment = moment - timedelta(days=1)
    return moment.strftime("%Y-%m-%d")
161
162
def ts_time(ts: int) -> str:
    """Format a Unix timestamp as ``HH:MM:SS`` local time in ``TZ``."""
    local = datetime.fromtimestamp(ts, tz=ZoneInfo(TZ))
    return f"{local:%H:%M:%S}"
167
168
def ts_with_url(ts: int, url: str) -> str:
    """Render ``ts`` as a Markdown link ``[HH:MM:SS](url?t=seconds)`` into the live video.

    The ``t`` value is the number of seconds since local midnight (in ``TZ``).
    This assumes ``ts`` stores the offset from the stream start as a time of
    day — TODO confirm.
    """
    dt = datetime.fromtimestamp(ts, tz=ZoneInfo(TZ))
    seconds = dt.hour * 3600 + dt.minute * 60 + dt.second
    # Extend an existing query string, or start one. Always appending "&" broke URLs without "?".
    separator = "&" if "?" in url else "?"
    return f"[{ts_time(ts)}]({url}{separator}t={seconds})"
175
176
@cache.memoize(ttl=300)
async def r2_liveinfo() -> list[dict]:
    """Fetch the live-stream info stored at ``<R2_PREFIX>/liveinfo`` (memoized for 300 s).

    NOTE(review): ``live_date_info`` reads this via glom paths like
    ``"<YYYY-MM-DD>.title"``, which suggests a mapping keyed by day rather than
    the annotated list — confirm.
    """
    liveinfo_key = f"{DANMU.R2_PREFIX.rstrip('/')}/liveinfo"
    return await get_cf_r2(liveinfo_key, silent=True)  # ty:ignore[invalid-return-type]
181
182
@cache.memoize(ttl=300)
async def r2_userinfo() -> pd.DataFrame:
    """Fetch the user-info table stored at ``<R2_PREFIX>/userinfo`` (memoized for 300 s).

    No local path is passed, so the table is never mirrored to disk. When R2
    returns nothing, ``get_r2_dataframe`` yields an empty, column-less DataFrame.
    """
    userinfo_key = f"{DANMU.R2_PREFIX.rstrip('/')}/userinfo"
    return await get_r2_dataframe(userinfo_key)
187
188
async def live_date_info(ts: int, qtype: str) -> tuple[str, str, str]:
    """Resolve a timestamp to its live-stream day, a formatted title link and the URL.

    Returns:
        ``(day, title_link, url)``, e.g.
        ``("2023-12-12", "**[【2023-12-12】直播标题](https://...)**", "https://...")``.
        ``title`` and ``url`` fall back to ``""`` when the live info has no entry for that day.
    """
    day = ts_to_liveday(ts, qtype)  # YYYY-MM-DD

    def beautify(title: str, url: str) -> str:
        # Drop a "(202xxxxx)" date tag in half- or full-width parentheses; the day is shown in 【】.
        title = re.sub(r"[((]202\d{5}[))]", "", title).strip()
        # Close the 【 bracket; the original format left it unbalanced.
        return f"**[【{day}】{title}]({url})**"

    live_info = await r2_liveinfo()
    title = glom(live_info, f"{day}.title", default="")
    url = glom(live_info, f"{day}.url", default="")
    return day, beautify(title, url), url
206
207
@cache.memoize(ttl=60)
async def get_uids_by_name(name: str, queried_names: set[str] | None = None) -> set[str]:
    """Resolve a display name to every uid linked to it, following name changes.

    Each uid found for ``name`` may have used other names; those names are
    resolved recursively, so the result covers every uid reachable through the
    name <-> uid links in userinfo. Names and uids are compared case-insensitively.

    Args:
        name: display name, or a YouTube channel id (``UC`` + 22 chars), which is returned as-is.
        queried_names: lower-cased names already resolved higher up the recursion.
            Updated in place so they are not revisited; callers normally omit it.
            NOTE(review): this mutable set is part of the memoization key — confirm
            the cache backend keys on its contents at call time.

    Returns:
        The matching uids; empty for the anonymous placeholder "无名氏" or when
        userinfo is unavailable.
    """
    if name == "无名氏":
        return set()
    if re.match(r"^UC[\w-]{22}$", name):
        return {name}
    logger.debug(f"Querying name: {name}, queried_names: {queried_names}")
    if queried_names is None:
        queried_names = set()
    queried_names.add(name.lower())
    userinfo = await r2_userinfo()
    # A failed R2 fetch yields a column-less frame; indexing "name" would raise KeyError.
    if userinfo.empty:
        return set()
    uids = userinfo[userinfo["name"].str.lower() == name.lower()]["uid"].to_list()
    logger.info(f"Found uids of {name}: {uids}")

    # Recursive query. matched_uids collects the result; pending is a work list of
    # uids whose names still need scanning. uids surfaced by recursion are queued
    # once each, instead of extending the list being iterated with duplicates.
    matched_uids = set(uids)
    pending = list(dict.fromkeys(uids))
    index = 0
    while index < len(pending):
        uid = str(pending[index])
        index += 1
        logger.debug(f"Querying names of {uid}")
        names = userinfo[userinfo["uid"].str.lower() == uid.lower()]["name"].to_list()
        logger.info(f"Found names of {uid}: {names}")
        for n in names:
            if n.lower() in queried_names:
                continue
            logger.debug(f"Querying uid for name: {n}")
            found = await get_uids_by_name(n, queried_names)
            pending.extend(u for u in found if u not in matched_uids)
            matched_uids.update(found)
    return matched_uids
234
235
async def get_username_history(name: str) -> str:
    """Build a Markdown summary of every name used by the uids linked to ``name``.

    For each uid (numbered with ``number_to_emoji``, YouTube channel ids rendered
    as links), lists each name with its first and last seen dates (``YYYY-MM-DD`` in ``TZ``).
    Uids appear in order of their earliest ``first`` timestamp.

    Returns:
        The formatted text, or ``""`` when no uid or no userinfo record is found.
    """
    uids = await get_uids_by_name(name)
    userinfo = await r2_userinfo()
    # Without uids, or with an empty userinfo, there is nothing to list. The old
    # code then sorted a column-less frame by "first" and raised KeyError.
    if not uids or userinfo.empty:
        return ""

    # One concat instead of concatenating inside the loop (quadratic copying).
    history = pd.concat([userinfo[userinfo["uid"].str.lower() == uid.lower()] for uid in uids], axis=0)
    history = history.sort_values(by="first")

    def format_time(ts) -> str:
        return datetime.fromtimestamp(ts, tz=ZoneInfo(TZ)).strftime("%Y-%m-%d")

    lines: list[str] = []
    # Iterate over each uid group, in order of first appearance after sorting.
    for idx, (uid, group) in enumerate(history.groupby("uid", sort=False), start=1):
        if re.match(r"^UC[\w-]{22}$", str(uid)):
            lines.append(f"\n{number_to_emoji(idx)}**UID: [{uid}](https://www.youtube.com/channel/{uid})**\n")
        else:
            lines.append(f"\n{number_to_emoji(idx)}**UID: {uid}**\n")
        lines.extend(f"**{row['name']}**: {format_time(row['first'])}➡️{format_time(row['last'])}\n" for _, row in group.iterrows())
    return "".join(lines)