main
  1#!/usr/bin/env python
  2# -*- coding: utf-8 -*-
  3import asyncio
  4import base64
  5from pathlib import Path
  6from typing import Literal
  7
  8import anyio
  9from glom import glom
 10from loguru import logger
 11from pyrogram.client import Client
 12from pyrogram.types import Message
 13
 14from ai.texts.contexts import base64_media
 15from ai.utils import EMOJI_VIDEO_BOT, clean_cmd_prefix, prettify, replace_placeholder
 16from config import DOWNLOAD_DIR, PROXY
 17from messages.progress import modify_progress
 18from messages.sender import send2tg
 19from messages.utils import delete_message
 20from networking import download_file, hx_req
 21from utils import rand_string
 22
 23
 24async def http_post_video_generation(
 25    client: Client,
 26    message: Message,
 27    *,
 28    base_url: str = "",
 29    model_name: str = "",
 30    api_paths: dict | None = None,
 31    headers: dict | None = None,
 32    body: dict | None = None,
 33    extra_params: dict | None = None,
 34    proxy: str | None = PROXY.AI_POST,
 35    generation_type: Literal["text_to_video", "first_frame", "first_last_frame"] = "text_to_video",
 36    **kwargs,
 37) -> bool:
 38    """Get HTTP Post Video Generation."""
 39    status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_VIDEO_BOT}**{model_name}**:\n正在生成视频...", quote=True)
 40    try:
 41        prompt, first_frame, last_frame = await get_video_contexts(client, message, generation_type=generation_type)
 42        if not prompt:
 43            await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
 44            return False
 45        params = {}
 46        api_paths = api_paths or {}
 47        if headers:
 48            params |= {"headers": headers}
 49        if proxy:
 50            params |= {"proxy": proxy}
 51        params |= {"url": f"{base_url}{api_paths['video_gen']}", "method": "POST"}
 52        if body:
 53            params |= {"json_data": replace_placeholder(body, pairs={"%PROMPT%": prompt, "%FIRST_FRAME%": first_frame, "%LAST_FRAME%": last_frame})}
 54        if extra_params:
 55            params |= extra_params
 56        logger.debug(f"hx_req(**{prettify(params)})")
 57        resp = await hx_req(**params)
 58        if error := resp.get("hx_error"):
 59            await modify_progress(status_msg, text=f"❌**{model_name}**:\n{error}", force_update=True, **kwargs)
 60            return False
 61
 62        video_url: str = ""
 63        metadata = ""
 64        if task_id := resp.get("id"):  # Seedance
 65            video_url, metadata = await waiting_seedance_task(task_id, params | {"base_url": base_url, "api_paths": api_paths})
 66        if video_url:
 67            video_path, video_url = await download_generated_video(video_url, proxy=proxy)
 68            caption = f"{EMOJI_VIDEO_BOT}**{model_name}**\n{metadata}\n"
 69            if video_url:
 70                caption += f"[下载视频]({video_url})"
 71            await modify_progress(status_msg, text=caption, force_update=True, **kwargs)
 72            await send2tg(client, message, texts=caption, media=[{"video": video_path}], **kwargs)
 73            await delete_message(status_msg)
 74            return True
 75    except Exception as e:
 76        logger.error(f"HTTP Post Image Generation error: {e}")
 77    return False
 78
 79
 80def extract_metadata(response: dict) -> str:
 81    """Extract some useful metadata from response.
 82
 83    These information will be sent to Telegram caption.
 84    """
 85    if glom(response, "content.video_url", default=""):  # Seedance
 86        keep_keys = ["model", "duration", "resolution", "ratio", "framespersecond", "seed", "usage"]
 87        return prettify({k: response[k] for k in keep_keys})
 88    return ""
 89
 90
 91async def get_video_contexts(client: Client, message: Message, generation_type: Literal["text_to_video", "first_frame", "first_last_frame"]) -> tuple[str, str, str]:
 92    """Get video generation contexts.
 93
 94    Returns:
 95        tuple: prompt, first_frame, last_frame
 96    """
 97    messages = [message]
 98    while message.reply_to_message:
 99        message = message.reply_to_message
100        messages.append(message)
101    messages.reverse()  # old to new
102    image_messages = []  # image message from old to new
103    prompt = ""
104    for m in messages:
105        group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
106        for msg in group_messages:
107            prompt = clean_cmd_prefix(msg.content) or prompt
108            if not msg.photo:
109                continue
110            image_messages.append(msg)
111
112    if generation_type == "text_to_video":
113        return prompt, "", ""
114
115    if generation_type == "first_frame" and len(image_messages) >= 1:
116        first_frame = await base64_media(client, image_messages[-1])
117        first_frame = f"data:image/{first_frame['ext']};base64,{first_frame['base64']}"
118        return prompt, first_frame, ""
119
120    if generation_type == "first_last_frame" and len(image_messages) >= 2:
121        first_frame = await base64_media(client, image_messages[-2])
122        first_frame = f"data:image/{first_frame['ext']};base64,{first_frame['base64']}"
123        last_frame = await base64_media(client, image_messages[-1])
124        last_frame = f"data:image/{last_frame['ext']};base64,{last_frame['base64']}"
125        return prompt, first_frame, last_frame
126
127    return prompt, "", ""
128
129
async def download_generated_video(url: str, proxy: str | None) -> tuple[str, str]:
    """Store a generated video locally.

    ``url`` is either an HTTP(S) link, which is fetched with ``download_file``,
    or a base64 string, which is decoded and written to a randomly named
    ``.mp4`` file under ``DOWNLOAD_DIR``.

    Returns:
        tuple: video_path, url. ``url`` is returned only for a successful HTTP
        download; it is empty for base64 input. Both are empty when the HTTP
        download did not produce a file.
    """
    if not url.startswith("http"):  # base64 payload
        raw_video = base64.b64decode(url)
        target = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.mp4"
        async with await anyio.open_file(target, "wb") as handle:
            await handle.write(raw_video)
        return target.as_posix(), ""

    local_path = await download_file(url, impersonate=None, proxy=proxy)
    return (local_path, url) if Path(local_path).is_file() else ("", "")
147
148
async def waiting_seedance_task(task_id: str, params: dict) -> tuple[str, str]:
    """Poll an asynchronous task until it succeeds or fails.

    Task submitted response:
    {'id': 'cgt-20260118154732-5dvwx'}

    Task check response:
    {
    "content": { "video_url": "https://..." },
    "created_at": 1768722453,
    "draft": false,
    "duration": 4,
    "execution_expires_after": 172800,
    "framespersecond": 24,
    "generate_audio": true,
    "id": "cgt-20260118154732-5dvwx",
    "model": "doubao-seedance-1-5-pro-251215",
    "ratio": "16:9",
    "resolution": "480p",
    "seed": 24000,
    "service_tier": "default",
    "status": "succeeded",
    "updated_at": 1768722486,
    "usage": {"completion_tokens": 40594, "total_tokens": 40594}
    }

    The status URL is ``base_url`` plus ``api_paths.task_check`` with
    ``%TASK_ID%`` replaced. It is requested every 5 seconds until the status
    is succeeded, cancelled, failed or expired, or the request reports an
    ``hx_error``.

    Returns:
        tuple: video_url, metadata. Both are empty strings on failure.
    """
    helicone_gateway = "https://gateway.helicone.ai"
    endpoint = params["base_url"]
    request_headers = {name.lower(): value for name, value in params.get("headers", {}).items()}
    if endpoint.startswith(helicone_gateway):
        # Swap the gateway prefix for the URL in the helicone-target-url header
        # and drop the Helicone-specific headers.
        target = request_headers.get("helicone-target-url", "").rstrip("/")
        endpoint = endpoint.replace(helicone_gateway, target)
        for helicone_header in ("helicone-target-url", "helicone-auth"):
            request_headers.pop(helicone_header, None)

    status_url = replace_placeholder(endpoint + glom(params, "api_paths.task_check", default=""), {"%TASK_ID%": task_id})
    failure_states = {"CANCELLED", "FAILED", "EXPIRED"}
    while True:
        resp = await hx_req(status_url, headers=request_headers, proxy=params.get("proxy"), check_keys=["status"])
        if "hx_error" in resp or resp.get("status", "").upper() in failure_states:
            logger.error(f"Video Generation Task {task_id} error: {resp}")
            return "", ""
        if resp["status"].upper() == "SUCCEEDED":
            return glom(resp, "content.video_url", default=""), extract_metadata(resp)

        # Still pending: wait before asking again.
        await asyncio.sleep(5)