main
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import asyncio
4import base64
5from pathlib import Path
6from typing import Literal
7
8import anyio
9from glom import glom
10from loguru import logger
11from pyrogram.client import Client
12from pyrogram.types import Message
13
14from ai.texts.contexts import base64_media
15from ai.utils import EMOJI_VIDEO_BOT, clean_cmd_prefix, prettify, replace_placeholder
16from config import DOWNLOAD_DIR, PROXY
17from messages.progress import modify_progress
18from messages.sender import send2tg
19from messages.utils import delete_message
20from networking import download_file, hx_req
21from utils import rand_string
22
23
async def http_post_video_generation(
    client: Client,
    message: Message,
    *,
    base_url: str = "",
    model_name: str = "",
    api_paths: dict | None = None,
    headers: dict | None = None,
    body: dict | None = None,
    extra_params: dict | None = None,
    proxy: str | None = PROXY.AI_POST,
    generation_type: Literal["text_to_video", "first_frame", "first_last_frame"] = "text_to_video",
    **kwargs,
) -> bool:
    """Submit an HTTP POST video-generation request and deliver the result to Telegram.

    Args:
        client: Pyrogram client used to collect contexts and send replies.
        message: Incoming Telegram message that triggered the generation.
        base_url: API host, joined with ``api_paths["video_gen"]``.
        model_name: Display name shown in progress and caption texts.
        api_paths: Mapping with at least the ``video_gen`` endpoint path.
        headers: Optional HTTP headers forwarded to the request.
        body: Request-body template; ``%PROMPT%`` / ``%FIRST_FRAME%`` /
            ``%LAST_FRAME%`` placeholders are substituted before sending.
        extra_params: Extra keyword arguments merged into the ``hx_req`` call.
        proxy: Optional proxy URL used for both the API call and the download.
        generation_type: Which frame inputs to collect from the reply chain.
        **kwargs: Forwarded to progress/sending helpers; may carry ``progress``.

    Returns:
        True when the video was generated and sent, False on any failure.
    """
    status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_VIDEO_BOT}**{model_name}**:\n正在生成视频...", quote=True)
    try:
        prompt, first_frame, last_frame = await get_video_contexts(client, message, generation_type=generation_type)
        if not prompt:
            await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
            return False
        params: dict = {}
        api_paths = api_paths or {}
        if headers:
            params |= {"headers": headers}
        if proxy:
            params |= {"proxy": proxy}
        params |= {"url": f"{base_url}{api_paths['video_gen']}", "method": "POST"}
        if body:
            params |= {"json_data": replace_placeholder(body, pairs={"%PROMPT%": prompt, "%FIRST_FRAME%": first_frame, "%LAST_FRAME%": last_frame})}
        if extra_params:
            params |= extra_params
        logger.debug(f"hx_req(**{prettify(params)})")
        resp = await hx_req(**params)
        if error := resp.get("hx_error"):
            await modify_progress(status_msg, text=f"❌**{model_name}**:\n{error}", force_update=True, **kwargs)
            return False

        video_url: str = ""
        video_path: str = ""
        metadata = ""
        if task_id := resp.get("id"):  # Seedance-style async task: poll until finished
            video_url, metadata = await waiting_seedance_task(task_id, params | {"base_url": base_url, "api_paths": api_paths})
        if video_url:
            video_path, video_url = await download_generated_video(video_url, proxy=proxy)
        if not video_path:
            # BUG FIX: previously `video_path` stayed unbound on task failure,
            # so send2tg raised NameError (swallowed below) with no user feedback.
            await modify_progress(status_msg, text=f"❌**{model_name}**:\n视频生成失败", force_update=True, **kwargs)
            return False
        caption = f"{EMOJI_VIDEO_BOT}**{model_name}**\n{metadata}\n"
        if video_url:
            caption += f"[下载视频]({video_url})"
        await modify_progress(status_msg, text=caption, force_update=True, **kwargs)
        await send2tg(client, message, texts=caption, media=[{"video": video_path}], **kwargs)
        await delete_message(status_msg)
        return True
    except Exception as e:
        # Fixed copy-paste: this path generates video, not images.
        logger.error(f"HTTP Post Video Generation error: {e}")
        return False
78
79
def extract_metadata(response: dict) -> str:
    """Extract display-worthy metadata from a generation response.

    The returned string is appended to the Telegram caption.

    Args:
        response: Parsed JSON response from the video-generation API.

    Returns:
        Pretty-printed metadata for recognized providers, otherwise "".
    """
    if glom(response, "content.video_url", default=""):  # Seedance
        keep_keys = ("model", "duration", "resolution", "ratio", "framespersecond", "seed", "usage")
        # Membership guard: a missing field previously raised KeyError here.
        return prettify({k: response[k] for k in keep_keys if k in response})
    return ""
89
90
async def get_video_contexts(client: Client, message: Message, generation_type: Literal["text_to_video", "first_frame", "first_last_frame"]) -> tuple[str, str, str]:
    """Collect the prompt and frame images from the reply chain.

    Walks the reply chain from oldest to newest (expanding media groups),
    keeping the most recent prompt text and every photo message.

    Args:
        client: Pyrogram client used to resolve media groups and photos.
        message: Newest message in the chain.
        generation_type: Which frames the generation mode needs.

    Returns:
        tuple: (prompt, first_frame, last_frame). Frames are ``data:`` URIs,
        or "" when not requested or not enough photos are available.
    """
    messages = [message]
    while message.reply_to_message:
        message = message.reply_to_message
        messages.append(message)
    messages.reverse()  # oldest first, so the newest prompt/photos win
    image_messages = []  # photo messages, oldest to newest
    prompt = ""
    for m in messages:
        group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
        for msg in group_messages:
            prompt = clean_cmd_prefix(msg.content) or prompt
            if msg.photo:
                image_messages.append(msg)

    if generation_type == "first_frame" and image_messages:
        return prompt, await _photo_data_uri(client, image_messages[-1]), ""

    if generation_type == "first_last_frame" and len(image_messages) >= 2:
        first_frame = await _photo_data_uri(client, image_messages[-2])
        last_frame = await _photo_data_uri(client, image_messages[-1])
        return prompt, first_frame, last_frame

    # text_to_video, or not enough photos for the requested mode.
    return prompt, "", ""


async def _photo_data_uri(client: Client, msg: Message) -> str:
    """Encode a message's photo as a base64 ``data:image/...`` URI."""
    media = await base64_media(client, msg)
    return f"data:image/{media['ext']};base64,{media['base64']}"
128
129
async def download_generated_video(url: str, proxy: str | None) -> tuple[str, str]:
    """Save a generated video to local disk.

    Args:
        url: Either an http(s) URL or a raw base64-encoded video payload.
        proxy: Optional proxy URL for the HTTP download.

    Returns:
        tuple: (video_path, url). ``url`` is kept only for HTTP downloads
        (reused as a share link in the caption); it is "" for base64
        payloads. Both are "" when nothing could be saved.
    """
    if not url:
        # Guard: an empty string used to be base64-decoded into an empty file.
        return "", ""
    if url.startswith("http"):
        video_path = await download_file(url, impersonate=None, proxy=proxy)
        if Path(video_path).is_file():
            return video_path, url
        return "", ""
    # Not a URL: treat the payload as base64-encoded video bytes.
    video_bytes = base64.b64decode(url)
    save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.mp4"
    async with await anyio.open_file(save_path, "wb") as f:
        await f.write(video_bytes)
    return save_path.as_posix(), ""
147
148
async def waiting_seedance_task(task_id: str, params: dict, *, poll_interval: float = 5, max_wait: float = 3600) -> tuple[str, str]:
    """Poll a Seedance async task until it succeeds, fails, or times out.

    Task submit response:
        {'id': 'cgt-20260118154732-5dvwx'}

    Task check response (abridged):
        {
            "content": { "video_url": "https://..." },
            "id": "cgt-20260118154732-5dvwx",
            "model": "doubao-seedance-1-5-pro-251215",
            "status": "succeeded",
            "usage": {"completion_tokens": 40594, "total_tokens": 40594}
        }

    Args:
        task_id: Task id returned by the submit call.
        params: Request params; must carry ``base_url`` and
            ``api_paths.task_check``; may carry ``headers`` and ``proxy``.
        poll_interval: Seconds between status checks.
        max_wait: Give up after roughly this many seconds. BUG FIX: the
            original ``while True`` loop polled forever when the task never
            left a pending state.

    Returns:
        tuple: (video_url, metadata); both "" on failure or timeout.
    """
    # Resolve the real upstream base_url when requests are routed via Helicone.
    base_url = params["base_url"]
    headers = {k.lower(): v for k, v in params.get("headers", {}).items()}
    if base_url.startswith("https://gateway.helicone.ai"):
        helicone_target_url = headers.get("helicone-target-url", "").rstrip("/")
        base_url = base_url.replace("https://gateway.helicone.ai", helicone_target_url)
        headers.pop("helicone-target-url", None)
        headers.pop("helicone-auth", None)
    task_url = base_url + glom(params, "api_paths.task_check", default="")
    url = replace_placeholder(task_url, {"%TASK_ID%": task_id})
    attempts = max(1, int(max_wait / poll_interval))
    for attempt in range(attempts):
        resp = await hx_req(url, headers=headers, proxy=params.get("proxy"), check_keys=["status"])
        # .get() also fixes a latent KeyError when "status" is absent without hx_error.
        status = resp.get("status", "").upper()
        if "hx_error" in resp or status in {"CANCELLED", "FAILED", "EXPIRED"}:
            logger.error(f"Video Generation Task {task_id} error: {resp}")
            return "", ""
        if status == "SUCCEEDED":
            return glom(resp, "content.video_url", default=""), extract_metadata(resp)
        if attempt < attempts - 1:  # no need to sleep after the final check
            await asyncio.sleep(poll_interval)
    logger.error(f"Video Generation Task {task_id} timed out after {max_wait}s")
    return "", ""