Commit 39a1095
Changed files (10)
src/ai/images/models.py
@@ -130,7 +130,7 @@ async def get_image_model_configs(message: Message) -> list[dict]:
}
Suppose this message is:
- Message(text="/gen A cute cat") -> use `default` as model_alias
+ Message(text="/gen A cute cat") -> use `AI.IMG_GENERATION_DEFAULT_MODEL` as model_alias
Message(text="/gen @seedream hello") -> use `seedream` as model_alias
Returns: list of model config
@@ -156,7 +156,7 @@ async def get_image_model_configs(message: Message) -> list[dict]:
]
"""
texts = str(message.content).strip()
- if matched := re.match(rf"^{PREFIX.AI_IMG_GENERATION}\s+@([a-zA-Z0-9_\-\.]+)(\s+)?", texts): # match /ai @custom_model_id
+ if matched := re.match(rf"^{PREFIX.AI_IMG_GENERATION}\s+@([a-zA-Z0-9_\-\.]+)(\s+)?", texts): # match /gen @custom_model_id
model_alias = matched.group(1).strip()
return await get_config_by_model_alias(model_alias)
return await get_config_by_model_alias(AI.IMG_GENERATION_DEFAULT_MODEL)
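For reference, the `/gen @alias` parsing above behaves like this (a minimal sketch, assuming the default prefix `/gen` from src/config.py):

import re

PREFIX_AI_IMG_GENERATION = "/gen"  # default value from src/config.py

def parse_alias(text: str) -> str | None:
    # Return the @alias following the command, or None when no alias is given.
    if m := re.match(rf"^{PREFIX_AI_IMG_GENERATION}\s+@([a-zA-Z0-9_\-\.]+)(\s+)?", text.strip()):
        return m.group(1).strip()
    return None

assert parse_alias("/gen @seedream hello") == "seedream"
assert parse_alias("/gen A cute cat") is None  # falls back to AI.IMG_GENERATION_DEFAULT_MODEL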
src/ai/images/openai_img.py
@@ -32,7 +32,7 @@ async def openai_image_generation(
proxy: str | None = PROXY.OPENAI,
support_reference_images: bool = False,
**kwargs,
-) -> dict:
+) -> bool:
"""Get OpenAI Image Generation."""
status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\n正在生成图像...", quote=True)
try:
@@ -44,7 +44,7 @@ async def openai_image_generation(
prompt, reference_images = await get_openai_image_contexts(client, message, support_reference_images=support_reference_images)
if not prompt:
await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
- return {}
+ return False
params = {}
if literal_eval(generate_config):
params |= literal_eval(generate_config)
@@ -54,7 +54,7 @@ async def openai_image_generation(
logger.debug(f"openai.images.generate(**{prettify(params)})")
except Exception as e:
logger.error(f"OpenAI client setup error: {e}")
- return {}
+ return False
resp = {}
for api_key in strings_list(api_keys, shuffle=True):
try:
@@ -71,11 +71,11 @@ async def openai_image_generation(
await modify_progress(status_msg, text=prettify(resp), force_update=True, **kwargs)
await send2tg(client, message, texts=caption, media=[{"photo": img["path"]} for img in images], **kwargs)
await delete_message(status_msg)
- return {}
+ return True
except Exception as e:
logger.error(f"OpenAI Image Generation error: {e}\n\n{prettify(resp)}")
await modify_progress(status_msg, text=f"❌{e}\n\n{prettify(resp)}", force_update=True, **kwargs)
- return {}
+ return False
async def get_openai_image_contexts(client: Client, message: Message, *, support_reference_images: bool = False) -> tuple[str, list[str]]:
src/ai/images/post.py
@@ -34,14 +34,14 @@ async def http_post_image_generation(
proxy: str | None = PROXY.AI_POST,
support_reference_images: bool = False,
**kwargs,
-) -> dict:
+) -> bool:
"""Get HTTP Post Image Generation."""
status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_IMG_BOT}**{model_name}**:\n正在生成图像...", quote=True)
try:
prompt, reference_images = await get_image_contexts(client, message, support_reference_images=support_reference_images)
if not prompt:
await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
- return {}
+ return False
params = {}
api_paths = api_paths or {}
if headers:
@@ -58,7 +58,7 @@ async def http_post_image_generation(
resp = await hx_req(**params)
if error := resp.get("hx_error"):
await modify_progress(status_msg, text=f"❌**{model_name}**:\n{error}", force_update=True, **kwargs)
- return {}
+ return False
image_urls: list[str] = []
metadata = ""
@@ -75,10 +75,10 @@ async def http_post_image_generation(
await modify_progress(status_msg, text=caption, force_update=True, **kwargs)
await send2tg(client, message, texts=caption, media=[{"photo": img["path"]} for img in images], **kwargs)
await delete_message(status_msg)
- return {}
+ return True
except Exception as e:
logger.error(f"HTTP Post Image Generation error: {e}")
- return {}
+ return False
def extract_metadata(response: dict) -> str:
@@ -96,7 +96,7 @@ def extract_metadata(response: dict) -> str:
async def get_image_contexts(client: Client, message: Message, *, support_reference_images: bool = False) -> tuple[str, list[str]]:
- """Generate OpenAI image generation contexts.
+ """Get image generation contexts.
Returns:
tuple: prompt, list_of_images
@@ -135,7 +135,7 @@ async def download_generated_images(image_urls: list[str], proxy: str | None) ->
results = []
for url in image_urls:
if url.startswith("http"):
- img_path = await download_file(url, proxy=proxy)
+ img_path = await download_file(url, impersonate=None, proxy=proxy)
if Path(img_path).is_file():
results.append({"path": img_path, "url": url})
else: # base64 json
@@ -183,44 +183,25 @@ async def waiting_modelscope_task(task_id: str, params: dict) -> tuple[list[str]
Returns:
tuple: list of images, metadata
"""
- # get real base_url
+ headers = {k.lower(): v for k, v in params.get("headers", {}).items()} | {"x-modelscope-task-type": "image_generation"}
base_url = params["base_url"]
- headers = {k.lower(): v for k, v in params.get("headers", {}).items()}
if base_url.startswith("https://gateway.helicone.ai"):
- base_url = headers.get("helicone-target-url", "")
-
- api_key = headers.get("authorization", "").replace("Bearer ", "")
+ helicone_target_url = headers.get("helicone-target-url", "").rstrip("/")
+ base_url = base_url.replace("https://gateway.helicone.ai", helicone_target_url)
+ headers.pop("helicone-target-url", None)
+ headers.pop("helicone-auth", None)
task_url = base_url + glom(params, "api_paths.task_check", default="")
url = replace_placeholder(task_url, {"%TASK_ID%": task_id})
- resp = await hx_req(
- url,
- headers={
- "authorization": f"Bearer {api_key}",
- "content-type": "application/json",
- "x-modelscope-task-type": "image_generation",
- },
- check_keys=["task_status"],
- proxy=params.get("proxy"),
- )
+ resp = await hx_req(url, headers=headers, proxy=params.get("proxy"), check_keys=["task_status"])
while True:
- if "hx_error" in resp or not glom(resp, "task_status", default="") or resp["task_status"].upper() in {"FAILED", "CANCELLED", "UNKNOWN"}:
+ if "hx_error" in resp or resp["task_status"].upper() in {"FAILED", "CANCELLED", "UNKNOWN"}:
logger.error(f"Image Generation Task {task_id} error: {resp}")
return [], ""
if resp["task_status"] == "SUCCEED":
return glom(resp, "output_images", default=[]), extract_metadata(resp)
await asyncio.sleep(5)
- resp = await hx_req(
- url,
- headers={
- "authorization": f"Bearer {api_key}",
- "content-type": "application/json",
- "x-modelscope-task-type": "image_generation",
- },
- check_keys=["task_status"],
- proxy=params.get("proxy"),
- )
- return [], extract_metadata(resp)
+ resp = await hx_req(url, headers=headers, proxy=params.get("proxy"), check_keys=["task_status"])
def replace_placeholder(data: dict, pairs: dict[str, str]) -> dict:
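The Helicone handling above reduces to the following (a minimal sketch of the same rewrite; the header values are hypothetical placeholders, not taken from the commit):

base_url = "https://gateway.helicone.ai"
headers = {
    "Authorization": "Bearer PROVIDER_API_KEY",                     # hypothetical
    "Helicone-Target-Url": "https://api-inference.modelscope.cn/",  # hypothetical
    "Helicone-Auth": "Bearer HELICONE_API_KEY",                     # hypothetical
}
headers = {k.lower(): v for k, v in headers.items()} | {"x-modelscope-task-type": "image_generation"}
if base_url.startswith("https://gateway.helicone.ai"):
    target = headers.get("helicone-target-url", "").rstrip("/")
    base_url = base_url.replace("https://gateway.helicone.ai", target)
    headers.pop("helicone-target-url", None)
    headers.pop("helicone-auth", None)
# base_url is now the provider URL and the task endpoint is polled directly,
# reusing the caller's headers instead of rebuilding them from the API key.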
src/ai/videos/models.py
@@ -0,0 +1,161 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import re
+
+from loguru import logger
+from pyrogram.types import Message
+
+from config import AI, PREFIX
+from database.kv import get_cf_kv
+
+
+async def get_video_model_configs(message: Message) -> list[dict]:
+ r"""Get model configs based on the message.
+
+ Model config is retrieved from CF-KV with key: {AI.VIDEO_MODEL_CONFIG_KEY}
+
+ A sample config:
+ {
+ "docs": "📽**AI视频**: `/gvid` + 提示词\n默认使用**Seedance-1.5-Pro**模型\n\n🔄使用以下命令强制切换模型:\n/sdt2v: Seedance 文生视频\n/sdf2v: Seedance 首帧生视频\n/sdfl2v: Seedance 首尾帧生视频",
+ "seedance": [
+ {
+ "model_name": "Seedance-1.5-Pro",
+ "api_type": "post",
+ "generation_type": "text_to_video",
+ "base_url": "https://ark.cn-beijing.volces.com",
+ "api_paths": {
+ "video_gen": "/api/v3/contents/generations/tasks",
+ "task_check": "/api/v3/contents/generations/tasks/%TASK_ID%"
+ },
+ "headers": {
+ "Authorization": "Bearer ARK_API_KEY",
+ "Content-Type": "application/json"
+ },
+ "body": {
+ "model": "doubao-seedance-1-5-pro-251215",
+ "content": [
+ {
+ "type": "text",
+ "text": "%PROMPT%"
+ }
+ ],
+ "generate_audio": true,
+ "resolution": "1080p",
+ "ratio": "adaptive",
+ "duration": -1,
+ "watermark": false
+ }
+ }
+ ],
+ "seedance-first-frame": [
+ {
+ "model_name": "Seedance-1.5-Pro",
+ "api_type": "post",
+ "generation_type": "first_frame",
+ "base_url": "https://ark.cn-beijing.volces.com",
+ "api_paths": {
+ "video_gen": "/api/v3/contents/generations/tasks",
+ "task_check": "/api/v3/contents/generations/tasks/%TASK_ID%"
+ },
+ "headers": {
+ "Authorization": "Bearer ARK_API_KEY",
+ "Content-Type": "application/json"
+ },
+ "body": {
+ "model": "doubao-seedance-1-5-pro-251215",
+ "content": [
+ {
+ "type": "text",
+ "text": "%PROMPT%"
+ },
+ {
+ "type": "image_url",
+ "role": "first_frame",
+ "image_url": {
+ "url": "%FIRST_FRAME%"
+ }
+ }
+ ],
+ "generate_audio": true,
+ "resolution": "720p",
+ "ratio": "adaptive",
+ "duration": -1,
+ "watermark": false
+ }
+ }
+ ],
+ "seedance-first-last-frame": [
+ {
+ "model_name": "Seedance-1.5-Pro",
+ "api_type": "post",
+ "generation_type": "first_last_frame",
+ "base_url": "https://ark.cn-beijing.volces.com",
+ "api_paths": {
+ "video_gen": "/api/v3/contents/generations/tasks",
+ "task_check": "/api/v3/contents/generations/tasks/%TASK_ID%"
+ },
+ "headers": {
+ "Authorization": "Bearer ARK_API_KEY",
+ "Content-Type": "application/json"
+ },
+ "body": {
+ "model": "doubao-seedance-1-5-pro-251215",
+ "content": [
+ {
+ "type": "text",
+ "text": "%PROMPT%"
+ },
+ {
+ "type": "image_url",
+ "role": "first_frame",
+ "image_url": {
+ "url": "%FIRST_FRAME%"
+ }
+ },
+ {
+ "type": "image_url",
+ "role": "last_frame",
+ "image_url": {
+ "url": "%LAST_FRAME%"
+ }
+ }
+ ],
+ "generate_audio": true,
+ "resolution": "720p",
+ "duration": -1,
+ "watermark": false
+ }
+ }
+ ]
+ }
+
+ Suppose this message is:
+ Message(text="/gvid prompt") -> use `AI.VIDEO_GENERATION_DEFAULT_MODEL` as model_alias
+ Message(text="/gvid @seedance-first-frame prompt") -> use `seedance-first-frame` as model_alias
+
+ Returns:
+ list of model config
+ """
+ texts = str(message.content).strip()
+ if matched := re.match(rf"^{PREFIX.AI_VIDEO_GENERATION}\s+@([a-zA-Z0-9_\-\.]+)(\s+)?", texts): # match /gvid @custom_model_id
+ model_alias = matched.group(1).strip()
+ return await get_config_by_model_alias(model_alias)
+ return await get_config_by_model_alias(AI.VIDEO_GENERATION_DEFAULT_MODEL)
+
+
+async def get_config_by_model_alias(model_alias: str) -> list[dict]:
+ """Get model config by model_alias.
+
+ Returns:
+ model_config
+ """
+ kv = await get_cf_kv(AI.VIDEO_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)
+
+ custom_config = kv.get(model_alias, [])
+ if not custom_config:
+ logger.warning(f"Model `{model_alias}` is not configured in KV, using default config")
+ default_config = kv.get(AI.VIDEO_GENERATION_DEFAULT_MODEL, [])
+ if not default_config:
+ logger.warning(f"CF-KV key `{AI.VIDEO_MODEL_CONFIG_KEY}` does not have a default `{AI.VIDEO_GENERATION_DEFAULT_MODEL}` field")
+ return default_config
+ return custom_config
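A small sketch of the alias fallback implemented by get_config_by_model_alias (the real helper is async and reads the mapping from CF-KV; the sample KV content here is hypothetical):

def resolve_alias(kv: dict, model_alias: str, default_alias: str = "seedance") -> list[dict]:
    # Unknown aliases fall back to the default entry (the real helper also logs a warning).
    return kv.get(model_alias) or kv.get(default_alias, [])

kv = {"seedance": [{"model_name": "Seedance-1.5-Pro"}]}  # hypothetical KV content
assert resolve_alias(kv, "seedance") == kv["seedance"]
assert resolve_alias(kv, "seedance-first-frame") == kv["seedance"]  # unconfigured alias -> default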
src/ai/videos/post.py
@@ -0,0 +1,212 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import asyncio
+import base64
+import json
+from pathlib import Path
+from typing import Literal
+
+import anyio
+from glom import glom
+from loguru import logger
+from pyrogram.client import Client
+from pyrogram.types import Message
+
+from ai.texts.contexts import base64_media
+from ai.utils import EMOJI_VIDEO_BOT, clean_cmd_prefix, prettify
+from config import DOWNLOAD_DIR, PROXY
+from messages.progress import modify_progress
+from messages.sender import send2tg
+from messages.utils import delete_message
+from networking import download_file, hx_req
+from utils import rand_string
+
+
+async def http_post_video_generation(
+ client: Client,
+ message: Message,
+ *,
+ base_url: str = "",
+ model_name: str = "",
+ api_paths: dict | None = None,
+ headers: dict | None = None,
+ body: dict | None = None,
+ extra_params: dict | None = None,
+ proxy: str | None = PROXY.AI_POST,
+ generation_type: Literal["text_to_video", "first_frame", "first_last_frame"] = "text_to_video",
+ **kwargs,
+) -> bool:
+ """Get HTTP Post Video Generation."""
+ status_msg = kwargs.get("progress") or await message.reply(f"{EMOJI_VIDEO_BOT}**{model_name}**:\n正在生成视频...", quote=True)
+ try:
+ prompt, first_frame, last_frame = await get_video_contexts(client, message, generation_type=generation_type)
+ if not prompt:
+ await modify_progress(status_msg, text=f"❌**{model_name}**:\n请提供提示词", force_update=True, **kwargs)
+ return False
+ params = {}
+ api_paths = api_paths or {}
+ if headers:
+ params |= {"headers": headers}
+ if proxy:
+ params |= {"proxy": proxy}
+ params |= {"url": f"{base_url}{api_paths['video_gen']}", "method": "POST"}
+ if body:
+ params |= {"json_data": replace_placeholder(body, pairs={"%PROMPT%": prompt, "%FIRST_FRAME%": first_frame, "%LAST_FRAME%": last_frame})}
+ if extra_params:
+ params |= extra_params
+ logger.debug(f"hx_req(**{prettify(params)})")
+ resp = await hx_req(**params)
+ if error := resp.get("hx_error"):
+ await modify_progress(status_msg, text=f"❌**{model_name}**:\n{error}", force_update=True, **kwargs)
+ return False
+
+ video_url: str = ""
+ metadata = ""
+ if task_id := resp.get("id"): # Seedance
+ video_url, metadata = await waiting_seedance_task(task_id, params | {"base_url": base_url, "api_paths": api_paths})
+ if video_url:
+ video_path, video_url = await download_generated_video(video_url, proxy=proxy)
+ caption = f"{EMOJI_VIDEO_BOT}**{model_name}**\n{metadata}\n"
+ if video_url:
+ caption += f"[下载视频]({video_url})"
+ await modify_progress(status_msg, text=caption, force_update=True, **kwargs)
+ await send2tg(client, message, texts=caption, media=[{"video": video_path}], **kwargs)
+ await delete_message(status_msg)
+ return True
+ except Exception as e:
+ logger.error(f"HTTP Post Video Generation error: {e}")
+ return False
+
+
+def extract_metadata(response: dict) -> str:
+ """Extract some useful metadata from response.
+
+ This information will be included in the Telegram caption.
+ """
+ if glom(response, "content.video_url", default=""): # Seedance
+ keep_keys = ["model", "duration", "resolution", "ratio", "framespersecond", "seed", "usage"]
+ return prettify({k: response[k] for k in keep_keys})
+ return ""
+
+
+async def get_video_contexts(client: Client, message: Message, generation_type: Literal["text_to_video", "first_frame", "first_last_frame"]) -> tuple[str, str, str]:
+ """Get video generation contexts.
+
+ Returns:
+ tuple: prompt, first_frame, last_frame
+ """
+ messages = [message]
+ while message.reply_to_message:
+ message = message.reply_to_message
+ messages.append(message)
+ messages.reverse() # old to new
+ image_messages = [] # image message from old to new
+ prompt = ""
+ for m in messages:
+ group_messages = await client.get_media_group(m.chat.id, m.id) if m.media_group_id else [m]
+ for msg in group_messages:
+ prompt = clean_cmd_prefix(msg.content) or prompt
+ if not msg.photo:
+ continue
+ image_messages.append(msg)
+
+ if generation_type == "text_to_video":
+ return prompt, "", ""
+
+ if generation_type == "first_frame" and len(image_messages) >= 1:
+ first_frame = await base64_media(client, image_messages[-1])
+ first_frame = f"data:image/{first_frame['ext']};base64,{first_frame['base64']}"
+ return prompt, first_frame, ""
+
+ if generation_type == "first_last_frame" and len(image_messages) >= 2:
+ first_frame = await base64_media(client, image_messages[-2])
+ first_frame = f"data:image/{first_frame['ext']};base64,{first_frame['base64']}"
+ last_frame = await base64_media(client, image_messages[-1])
+ last_frame = f"data:image/{last_frame['ext']};base64,{last_frame['base64']}"
+ return prompt, first_frame, last_frame
+
+ return prompt, "", ""
+
+
+async def download_generated_video(url: str, proxy: str | None) -> tuple[str, str]:
+ """Download generated video.
+
+ Returns:
+ tuple: video_path, url
+ """
+ if url.startswith("http"):
+ video_path = await download_file(url, impersonate=None, proxy=proxy)
+ if Path(video_path).is_file():
+ return video_path, url
+ else: # base64 json
+ video_bytes = base64.b64decode(url)
+ save_path = Path(DOWNLOAD_DIR) / f"{rand_string(10)}.mp4"
+ async with await anyio.open_file(save_path, "wb") as f:
+ await f.write(video_bytes)
+ return save_path.as_posix(), ""
+ return "", ""
+
+
+async def waiting_seedance_task(task_id: str, params: dict) -> tuple[str, str]:
+ """Wait for the async task status to reach SUCCEEDED.
+
+ Task Submitted Response:
+ {'id': 'cgt-20260118154732-5dvwx'}
+
+ Task Check Response:
+ {
+ "content": { "video_url": "https://..." },
+ "created_at": 1768722453,
+ "draft": false,
+ "duration": 4,
+ "execution_expires_after": 172800,
+ "framespersecond": 24,
+ "generate_audio": true,
+ "id": "cgt-20260118154732-5dvwx",
+ "model": "doubao-seedance-1-5-pro-251215",
+ "ratio": "16:9",
+ "resolution": "480p",
+ "seed": 24000,
+ "service_tier": "default",
+ "status": "succeeded",
+ "updated_at": 1768722486,
+ "usage": {"completion_tokens": 40594, "total_tokens": 40594}
+ }
+
+ Returns:
+ tuple: video_url, metadata
+ """
+ # get real base_url
+ base_url = params["base_url"]
+ headers = {k.lower(): v for k, v in params.get("headers", {}).items()}
+ if base_url.startswith("https://gateway.helicone.ai"):
+ helicone_target_url = headers.get("helicone-target-url", "").rstrip("/")
+ base_url = base_url.replace("https://gateway.helicone.ai", helicone_target_url)
+ headers.pop("helicone-target-url", None)
+ headers.pop("helicone-auth", None)
+ task_url = base_url + glom(params, "api_paths.task_check", default="")
+ url = replace_placeholder(task_url, {"%TASK_ID%": task_id})
+ resp = await hx_req(url, headers=headers, proxy=params.get("proxy"), check_keys=["status"])
+ while True:
+ if "hx_error" in resp or resp.get("status", "").upper() in {"CANCELLED", "FAILED", "EXPIRED"}:
+ logger.error(f"Video Generation Task {task_id} error: {resp}")
+ return "", ""
+ if resp["status"].upper() == "SUCCEEDED":
+ return glom(resp, "content.video_url", default=""), extract_metadata(resp)
+
+ await asyncio.sleep(5)
+ resp = await hx_req(url, headers=headers, proxy=params.get("proxy"), check_keys=["status"])
+
+
+def replace_placeholder(data: dict, pairs: dict[str, str]) -> dict:
+ """Replace placeholders in data.
+
+ Args:
+ data: dict containing placeholders
+ pairs: mapping of placeholder string to its replacement value
+ Returns:
+ dict with placeholders replaced
+ """
+ data_str = json.dumps(data, ensure_ascii=False)
+ for key, value in pairs.items():
+ data_str = data_str.replace(key, value)
+ return json.loads(data_str)
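A minimal sketch of how replace_placeholder (as defined above) fills the %PROMPT% and %TASK_ID% markers via a JSON round-trip; the prompt value is a hypothetical example, the task id is the one from the docstring above:

import json

def replace_placeholder(data, pairs):
    data_str = json.dumps(data, ensure_ascii=False)
    for key, value in pairs.items():
        data_str = data_str.replace(key, value)
    return json.loads(data_str)

body = {"model": "doubao-seedance-1-5-pro-251215", "content": [{"type": "text", "text": "%PROMPT%"}]}
filled = replace_placeholder(body, {"%PROMPT%": "a cat surfing at sunset"})
assert filled["content"][0]["text"] == "a cat surfing at sunset"

# The JSON round-trip also accepts a plain string, which is how the %TASK_ID%
# placeholder in the task_check path is filled in before polling.
url = replace_placeholder("/api/v3/contents/generations/tasks/%TASK_ID%", {"%TASK_ID%": "cgt-20260118154732-5dvwx"})
assert url == "/api/v3/contents/generations/tasks/cgt-20260118154732-5dvwx"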
src/ai/main.py
@@ -13,7 +13,9 @@ from ai.texts.models import get_config_by_model_id, get_text_model_config
from ai.texts.openai_chat import openai_chat_completions
from ai.texts.openai_response import openai_responses_api
from ai.texts.tool_call import get_tool_call_results
-from ai.utils import img_generation_docs
+from ai.utils import img_generation_docs, video_generation_docs
+from ai.videos.models import get_video_model_configs
+from ai.videos.post import http_post_video_generation
from config import AI, PREFIX
from messages.sender import send2tg
from messages.utils import startswith_prefix
@@ -46,9 +48,9 @@ async def ai_text_generation(client: Client, message: Message, *, silent: bool =
return {}
-async def ai_image_generation(client: Client, message: Message, **kwargs) -> dict:
+async def ai_image_generation(client: Client, message: Message, **kwargs) -> None:
if not startswith_prefix(message.content, PREFIX.AI_IMG_GENERATION):
- return {}
+ return
texts = str(message.content).strip()
this_msg = message
prompt = texts.removeprefix(PREFIX.AI_IMG_GENERATION).strip()
@@ -56,14 +58,36 @@ async def ai_image_generation(client: Client, message: Message, **kwargs) -> dic
if not prompt:
if not message.reply_to_message:
await send2tg(client, message, texts=await img_generation_docs(), **kwargs)
- return {}
+ return
message = this_msg.reply_to_message
model_configs = await get_image_model_configs(this_msg)
if not model_configs:
- return {}
+ return
for model_config in model_configs:
- if model_config["api_type"] == "openai":
- return await openai_image_generation(client, message, **model_config)
- if model_config["api_type"] == "post":
- return await http_post_image_generation(client, message, **model_config)
- return {}
+ match model_config["api_type"]:
+ case "openai":
+ if await openai_image_generation(client, message, **model_config):
+ return
+ case "post":
+ if await http_post_image_generation(client, message, **model_config):
+ return
+
+
+async def ai_video_generation(client: Client, message: Message, **kwargs) -> None:
+ if not startswith_prefix(message.content, PREFIX.AI_VIDEO_GENERATION):
+ return
+ texts = str(message.content).strip()
+ this_msg = message
+ prompt = texts.removeprefix(PREFIX.AI_VIDEO_GENERATION).strip()
+ prompt = re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", prompt, flags=re.DOTALL).strip()
+ if not prompt:
+ if not message.reply_to_message:
+ await send2tg(client, message, texts=await video_generation_docs(), **kwargs)
+ return
+ message = this_msg.reply_to_message
+ model_configs = await get_video_model_configs(this_msg)
+ if not model_configs:
+ return
+ for model_config in model_configs:
+ if await http_post_video_generation(client, message, **model_config):
+ return
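The dispatch now forms a fallback chain: each generation handler returns a bool, and the loop moves on to the next configured model only when the previous one reported failure. A minimal sketch (try_model is a hypothetical stand-in for openai_image_generation / http_post_image_generation / http_post_video_generation):

from typing import Awaitable, Callable

async def dispatch(model_configs: list[dict], try_model: Callable[..., Awaitable[bool]]) -> None:
    for model_config in model_configs:
        if await try_model(**model_config):
            return  # first successful model wins; otherwise fall through to the next config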
src/ai/utils.py
@@ -17,6 +17,7 @@ from utils import nowdt, remove_consecutive_newlines, remove_dash, remove_pound,
EMOJI_TEXT_BOT = "🤖"
EMOJI_IMG_BOT = "🌠"
+EMOJI_VIDEO_BOT = "📽"
EMOJI_REASONING_BEGIN = "🤔" # use emoji to separate model reasoning and content
EMOJI_REASONING_END = "💡"
BOT_TIPS = "(回复以继续)" # noqa: RUF001
@@ -24,12 +25,17 @@ BOT_TIPS = "(回复以继续)" # noqa: RUF001
async def text_generation_docs() -> str:
kv = await get_cf_kv(AI.TEXT_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)
- return kv.get("docs", f"🤖**AI对话**: `{PREFIX.AI_TEXT_GENERATION}` + 提示词")
+ return kv.get("docs", f"{EMOJI_TEXT_BOT}**AI对话**: `{PREFIX.AI_TEXT_GENERATION}` + 提示词")
async def img_generation_docs() -> str:
kv = await get_cf_kv(AI.IMG_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)
- return kv.get("docs", f"🌠AI生图: `{PREFIX.AI_IMG_GENERATION}` + 提示词")
+ return kv.get("docs", f"{EMOJI_IMG_BOT}**AI生图**: `{PREFIX.AI_IMG_GENERATION}` + 提示词")
+
+
+async def video_generation_docs() -> str:
+ kv = await get_cf_kv(AI.VIDEO_MODEL_CONFIG_KEY, cache_ttl=600, silent=True)
+ return kv.get("docs", f"{EMOJI_VIDEO_BOT}**AI视频**: `{PREFIX.AI_VIDEO_GENERATION}` + 提示词")
def literal_eval(string: str | dict) -> dict:
@@ -59,7 +65,7 @@ def prettify(data: dict) -> str:
def clean_cmd_prefix(text: str) -> str:
- for prefix in [PREFIX.AI_TEXT_GENERATION, PREFIX.AI_IMG_GENERATION]:
+ for prefix in [PREFIX.AI_TEXT_GENERATION, PREFIX.AI_IMG_GENERATION, PREFIX.AI_VIDEO_GENERATION]:
text = text.removeprefix(prefix).lstrip()
return re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", text, flags=re.DOTALL).strip()
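clean_cmd_prefix now also strips the video prefix; roughly (a minimal sketch using the default prefixes `/ai`, `/gen`, `/gvid` from src/config.py):

import re

def clean_cmd_prefix(text: str) -> str:
    for prefix in ["/ai", "/gen", "/gvid"]:  # default prefixes from src/config.py
        text = text.removeprefix(prefix).lstrip()
    return re.sub(r"^@([a-zA-Z0-9_\-\.]+)(\s+)?", "", text, flags=re.DOTALL).strip()

assert clean_cmd_prefix("/gvid @seedance a dog on the beach") == "a dog on the beach"
assert clean_cmd_prefix("/gen A cute cat") == "A cute cat"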
src/messages/main.py
@@ -7,7 +7,7 @@ from pyrogram.client import Client
from pyrogram.types import Message
from ai.chat_summary import ai_chat_summary
-from ai.main import ai_image_generation, ai_text_generation
+from ai.main import ai_image_generation, ai_text_generation, ai_video_generation
from asr.voice_recognition import voice_to_text
from bridge.ocr import send_to_ocr_bridge
from config import FAVORITE, PREFIX, PROXY
@@ -101,6 +101,7 @@ async def process_message(
if ai:
await ai_text_generation(client, message, **kwargs) # /ai
await ai_image_generation(client, message, **kwargs) # /gen
+ await ai_video_generation(client, message, **kwargs) # /gvid
if asr:
await voice_to_text(client, message, **kwargs) # /asr
if audio_extract:
src/config.py
@@ -84,6 +84,7 @@ class PREFIX:
AI_SUMMARY = os.getenv("PREFIX_AI_SUMMARY", "/summary").lower()
AI_TEXT_GENERATION = os.getenv("PREFIX_AI_TEXT_GENERATION", "/ai").lower()
AI_IMG_GENERATION = os.getenv("PREFIX_AI_IMG_GENERATION", "/gen").lower()
+ AI_VIDEO_GENERATION = os.getenv("PREFIX_AI_VIDEO_GENERATION", "/gvid").lower()
ASR = os.getenv("PREFIX_ASR", "/asr").lower()
AUDIO = os.getenv("PREFIX_AUDIO", "/audio").lower()
CONVERT = os.getenv("PREFIX_CONVERT", "/convert").lower() # convert image file to photo
@@ -401,3 +402,7 @@ class AI:
# Image Generation
IMG_MODEL_CONFIG_KEY = os.getenv("AI_IMG_MODEL_CONFIG_KEY", "AI-IMG") # model configuration key in CF-KV
IMG_GENERATION_DEFAULT_MODEL = os.getenv("AI_IMG_GENERATION_DEFAULT_MODEL", "seedream")
+
+ # Video Generation
+ VIDEO_MODEL_CONFIG_KEY = os.getenv("AI_VIDEO_MODEL_CONFIG_KEY", "AI-VIDEO") # model configuration key in CF-KV
+ VIDEO_GENERATION_DEFAULT_MODEL = os.getenv("AI_VIDEO_GENERATION_DEFAULT_MODEL", "seedance")
src/networking.py
@@ -10,6 +10,7 @@ from typing import Any, Literal
from urllib.parse import parse_qs, urlparse
import anyio
+from curl_cffi.requests.impersonate import BrowserTypeLiteral
from httpx import AsyncClient, AsyncHTTPTransport, HTTPStatusError, Request, RequestError, Response
from httpx._types import RequestContent, RequestData, RequestFiles # type: ignore
from httpx_curl_cffi import AsyncCurlTransport, CurlOpt
@@ -146,6 +147,7 @@ async def download_file(
skip_exist: bool = False,
proxy: str | None = None,
headers: dict | None = None,
+ impersonate: BrowserTypeLiteral | None = "safari_ios",
stream: bool = False,
**kwargs,
) -> str:
@@ -179,7 +181,7 @@ async def download_file(
logger.trace(f"Downloading {link} to {path} with proxy={proxy}")
hx = AsyncClient(
headers=headers,
- transport=AsyncCurlTransport(proxy=proxy, impersonate="safari_ios", default_headers=True, curl_options={CurlOpt.FRESH_CONNECT: True}),
+ transport=AsyncCurlTransport(proxy=proxy, impersonate=impersonate, default_headers=True, curl_options={CurlOpt.FRESH_CONNECT: True}) if isinstance(impersonate, str) else None,
proxy=proxy,
timeout=REQUEST_TIMEOUT,
follow_redirects=True,