main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import asyncio
  4from collections import defaultdict
  5
  6from glom import Coalesce, glom
  7from httpx import AsyncHTTPTransport
  8from loguru import logger
  9
 10from config import TEXT_LENGTH
 11from emby.account import all_accounts, emby_login, get_account
 12from emby.api import build_params, default_headers
 13from emby.constant import EMBY_PROXY, ITEM_TYPE
 14from messages.utils import count_without_entities
 15from networking import hx_req
 16from publish import publish_telegraph
 17
 18
 19async def emby_search(query: str, *, allow_nsfw: bool = False) -> tuple[str, str, str]:
 20    """Search for movies or series from emby.
 21
 22    Returns:
 23        tuple[str, str]: summary_texts, full_texts, telegraph_url of full_texts
 24
 25    Examples:
 26    Summary:
 27        @emby_1_bot
 28        🎬(2007)Movie-1
 29        🎬(2020)Movie-2
 30
 31        @emby_2_bot
 32        🎬(2007)Movie-1
 33        🎬(2020)Movie-2
 34    """
 35    accounts = await all_accounts()
 36    tasks = [search_single(query, acc_name, allow_nsfw=allow_nsfw) for acc_name in accounts]
 37    results = await asyncio.gather(*tasks)
 38    max_length = TEXT_LENGTH - 50
 39    cur_len = sum([len(name) for name in accounts])
 40    max_num_item = max(len(item) for item in results)
 41    summary = defaultdict(list)  # {"bot-1": ["item-1", "item-2"], "bot-2": ["item-3", "item-4"]}
 42    full = dict(zip(accounts, results, strict=True))
 43    trimmed = False
 44    for item_idx in range(max_num_item):
 45        for acc_name, items in zip(accounts, results, strict=True):
 46            if item := glom(items, f"{item_idx}", default=""):
 47                # check if we can add this item to summary
 48                this_len = await count_without_entities(item)
 49                if cur_len + this_len >= max_length:
 50                    trimmed = True
 51                    break
 52                summary[acc_name].append(item)
 53    summary_texts = ""
 54    full_texts = ""
 55    failed_accounts = []
 56    for acc_name, acc_info in accounts.items():
 57        if items := summary.get(acc_name):
 58            summary_texts += f"🤖[{acc_name}](t.me/{acc_info['bot']})\n"
 59            summary_texts += "".join(items)
 60        else:
 61            failed_accounts.append(f"❌[{acc_name}](t.me/{acc_info['bot']})\n")
 62        if items := full.get(acc_name):
 63            full_texts += f"🤖[{acc_name}](t.me/{acc_info['bot']})\n"
 64            full_texts += "".join(items)
 65    for failed in failed_accounts:
 66        summary_texts += failed
 67    if trimmed:
 68        html = "\n".join([f"<p>{s}</p>" for s in full_texts.split("\n")])
 69        telegraph_url = await publish_telegraph(title=query, html=html)
 70    else:
 71        telegraph_url = ""
 72    return summary_texts.strip(), full_texts.strip(), telegraph_url
 73
 74
 75async def search_single(query: str, account_name: str, *, allow_nsfw: bool = False, refresh: bool = False, retry: int = 0) -> list[str]:
 76    """Search from a single emby server.
 77
 78    https://dev.emby.media/reference/RestAPI/ItemsService/getUsersByUseridItems.html
 79    """
 80    if retry > 2:
 81        return []
 82    account = await get_account(account_name)
 83    if not account:
 84        return []
 85    if not allow_nsfw and account.get("nsfw"):
 86        return []
 87    credentials = await emby_login(account_name, refresh=refresh)
 88    if not credentials:
 89        return []
 90    params = {
 91        "UserId": credentials["User"]["Id"],
 92        "SearchTerm": query,
 93        "Recursive": "true",
 94        "SortBy": "SortName",
 95        "SortOrder": "Ascending",
 96        "GroupProgramsBySeries": "true",
 97        "IncludeItemTypes": "Movie,Series",
 98        "Fields": "PrimaryImageAspectRatio,ProductionYear,Status,EndDate,CommunityRating,RecursiveItemCount,ProviderIds,MediaSources,AlternateMediaSources,PremiereDate",
 99        "EnableImageTypes": "Primary,Backdrop,Thumb,Logo",
100        "ImageTypeLimit": "1",
101        "Limit": "50",
102        "StartIndex": "0",
103    }
104    logger.trace(f"Searching `{query}` on Emby Server: {account_name}")
105    resp = await hx_req(
106        f"{account['server']}/emby/items",
107        headers=default_headers(credentials),
108        params=build_params(credentials, params),
109        transport=AsyncHTTPTransport(),
110        proxy=EMBY_PROXY,
111        verify=False,
112        timeout=2,
113        check_keys=["Items"],
114        silent=True,
115        max_retry=1,
116    )
117    if error := resp.get("hx_error"):
118        logger.error(error)
119        if "401 Unauthorized" in error:
120            return await search_single(query, account_name, allow_nsfw=allow_nsfw, refresh=True, retry=retry + 1)
121        return []
122    if not resp["Items"]:
123        return []
124    logger.success(f"Found {len(resp['Items'])} results from {account_name}")
125    retrevaled = []
126    for item in resp["Items"]:
127        itype = ITEM_TYPE.get(item["Type"], "📺")
128        date = str(glom(item, Coalesce("PremiereDate", "ProductionYear"), default="0"))[:4]
129        text = f"{itype}({date}){item['Name']}\n" if date != "0" else f"{itype}{item['Name']}\n"
130        if text not in retrevaled:
131            retrevaled.append(text)
132    return retrevaled