Commit 1908bea
Changed files (10)
src/asr/ali_asr.py
@@ -0,0 +1,182 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import asyncio
+import io
+import random
+import re
+from pathlib import Path
+
+import anyio
+from dashscope.audio.asr import Recognition, RecognitionCallback
+from glom import flatten, glom
+from httpx import AsyncHTTPTransport
+from loguru import logger
+
+from config import ASR, DB, FILE_SERVER
+from database import delete_alist, upload_alist
+from multimedia import convert_to_audio
+from networking import hx_req
+
+
async def ali_asr(path: str | Path) -> str:
    """Create Aliyun ASR Task.

    录音文件识别请求

    Paraformer:
    https://help.aliyun.com/zh/model-studio/paraformer-recorded-speech-recognition-restful-api

    SenseVoice:
    https://help.aliyun.com/zh/model-studio/developer-reference/sensevoice-recorded-speech-recognition-restful-applicant

    Args:
        path: Local audio/video file to recognize.

    Returns:
        Timestamped transcription text, or a Chinese error message on failure.
    """
    api_keys = [x.strip() for x in ASR.ALI_API_KEY.split(",") if x.strip()]
    if not api_keys:
        return "请配置阿里云语音识别的API Key"
    models = [x.strip() for x in ASR.ALI_MODEL.split(",") if x.strip()]
    if not models:  # guard empty model config, consistent with the api_keys check above
        return "请配置阿里云语音识别的模型"
    model = random.choice(models)  # random pick for load balancing across models
    logger.debug(f"阿里云ASR {path} via model: {model}")
    api_key = random.choice(api_keys)  # random pick for load balancing across keys
    if model.startswith("paraformer-realtime-"):
        # realtime models go through the DashScope SDK directly (sync call), no file URL needed
        return ali_realtime_asr(model, path, api_key)

    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "X-DashScope-Async": "enable"}
    path = Path(path).expanduser().resolve()
    # the async transcription API fetches the audio by URL, so the file must be
    # reachable: either via the local FILE_SERVER or uploaded to an alist server
    if ASR.ALI_FS_ENGINE.lower() == "local":
        url = FILE_SERVER.removesuffix("/") + "/" + path.name
    else:
        url = await upload_alist(path)

    payload = {"model": model, "input": {"file_urls": [url]}}
    res = await hx_req(
        "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription",
        method="POST",
        headers=headers,
        post_json=payload,
        timeout=600,
        check_keys=["output.task_id"],
    )
    logger.success(f"ASR任务提交成功, TaskID: {res['output']['task_id']}")
    return await query_ali_asr(task_id=res["output"]["task_id"], api_key=api_key)
+
+
async def query_ali_asr(task_id: str, api_key: str) -> str:
    """Query Ali ASR Task.

    录音文件识别结果查询
    Paraformer:
    https://help.aliyun.com/zh/model-studio/paraformer-recorded-speech-recognition-restful-api

    SenseVoice:
    https://help.aliyun.com/zh/model-studio/developer-reference/sensevoice-recorded-speech-recognition-restful-applicant

    Args:
        task_id: The task id returned by the transcription-create API.
        api_key: DashScope API key used for the create call.

    Returns:
        Timestamped transcription text, or a "❌"-prefixed error message.
    """
    payload = {"task_id": task_id}
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "X-DashScope-Async": "enable"}
    status = ""
    result: dict = {}
    # Poll iteratively. The previous recursive implementation returned a str from
    # every recursive call, so its 600-query cap never triggered and long tasks
    # would recurse once per second until hitting the interpreter recursion limit.
    for query_times in range(601):  # max ~10 minutes
        if query_times:  # no sleep before the very first status request
            logger.trace(f"Status:[{status}], Waiting TaskID: {task_id}")
            await asyncio.sleep(1)
        result = await hx_req(
            f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}",
            method="POST",
            headers=headers,
            post_json=payload,
            check_keys=["output.task_status"],
        )
        status = glom(result, "output.task_status")
        if status not in ["RUNNING", "PENDING"]:
            break
    # best-effort cleanup of the uploaded audio once the task has finished
    await clean_alist(glom(result, "output.results.0.file_url", default=""))
    if status == "SUCCEEDED":
        transcription_url = glom(result, "output.results.0.transcription_url")
        trans_res = await hx_req(transcription_url, transport=AsyncHTTPTransport(), check_keys=["transcripts.0.sentences.0.text"])  # DO NOT use AsyncCurlTransport
        sentence_start_ms = glom(trans_res, "transcripts.0.sentences.*.begin_time")
        sentences = glom(trans_res, "transcripts.0.sentences.*.text")
        return generate_ali_transcription(sentence_start_ms, sentences)
    return "❌" + glom(result, "output.message", default="语音识别失败")
+
+
def ali_realtime_asr(model: str, path: str | Path, api_key: str) -> str:
    """Recognize an audio file with an Aliyun Paraformer realtime model (sync SDK call).

    Converts the input to a mono opus file at the model's expected sample rate,
    runs recognition via the DashScope SDK, and returns a timestamped transcription.

    Args:
        model: Realtime model name, e.g. "paraformer-realtime-v2".
        path: Local audio/video file.
        api_key: DashScope API key.

    Returns:
        Timestamped transcription, or a "❌"-prefixed error message.
    """
    sample_rate = 8000 if "8k" in model else 16000  # "8k" models expect 8kHz input
    ext = "opus"
    path = convert_to_audio(path, ext=ext, codec="libopus", ac=1, ar=sample_rate)
    recognition = Recognition(model=model, format=ext, sample_rate=sample_rate, callback=RecognitionCallback(), api_key=api_key)
    try:
        result = recognition.call(Path(path).as_posix())
    finally:
        # previously the temp file leaked when recognition failed (early return
        # before unlink); always remove the converted audio
        Path(path).unlink(missing_ok=True)
    if result.status_code != 200:
        return f"❌语音识别失败: {result.message}"
    data = result.get_sentence()
    start_times = flatten(glom(data, "*.words.*.begin_time"))
    texts = flatten(glom(data, "*.words.*.text"))
    punctuations = flatten(glom(data, "*.words.*.punctuation"))
    sentences = [f"{text}{punc}" for text, punc in zip(texts, punctuations, strict=True)]
    return generate_ali_transcription(start_times, sentences)
+
+
+def generate_ali_transcription(sentence_start_ms: list[int], sentences: list[str]) -> str:
+ def clean_tags(text: str) -> str:
+ """Clean sensevoice tags.
+
+ Remove <|sense-1|>, <|sense-2|>, ..., etc.
+ """
+ if not text:
+ return text
+ return re.sub(r"<\|.*?\|>", "", text)
+
+ res = ""
+ indexs = list(range(len(sentences)))
+ for idx, start_ms, sentence in zip(indexs, sentence_start_ms, sentences, strict=True):
+ text = clean_tags(sentence)
+ if not text:
+ continue
+ if idx == 0 or res.endswith((".", "。", "?", "?")): # noqa: RUF001
+ start_seconds = float(start_ms) // 1000
+ minutes = int(start_seconds // 60)
+ seconds = int(start_seconds % 60)
+ res += f"\n[{minutes:02d}:{seconds:02d}] {text}"
+ else:
+ res += text
+ return res.strip()
+
+
async def clean_alist(url: str):
    """Remove the uploaded audio from the alist server once ASR is finished."""
    if not url:
        return
    base = DB.ALIST_SERVER.removesuffix("/")
    prefix = f"{base}/d/{DB.ALIST_BASR_PATH.strip('/')}/"
    if not url.startswith(prefix):
        # not a file we uploaded to the configured alist path; nothing to do
        return
    await delete_alist(url.removeprefix(prefix))
+
+
async def upload_ali_oss(path: str | Path, api_key: str, model_name: str) -> str:
    """Get OSS url of Aliyun.

    Uploads a local file to Aliyun's temporary OSS storage via the DashScope
    upload-policy endpoint and returns its URL.

    https://help.aliyun.com/zh/model-studio/get-temporary-file-url
    """
    url = "https://dashscope.aliyuncs.com/api/v1/uploads"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    # "getPolicy" returns temporary signed upload credentials scoped to the model
    params = {"action": "getPolicy", "model": model_name}

    response = await hx_req(url, headers=headers, params=params, check_keys=["data.upload_host"])
    policy_data = response["data"]
    path = Path(path)
    # object key: the policy-provided directory plus the original filename
    key = f"{policy_data['upload_dir']}/{path.name}"
    async with await anyio.open_file(path, "rb") as f:
        content = await f.read()
    # (None, value) entries are plain multipart form fields; the last entry is
    # the actual file part. Field order/names follow the OSS PostObject API.
    files = {
        "OSSAccessKeyId": (None, policy_data["oss_access_key_id"]),
        "Signature": (None, policy_data["signature"]),
        "policy": (None, policy_data["policy"]),
        "x-oss-object-acl": (None, policy_data["x_oss_object_acl"]),
        "x-oss-forbid-overwrite": (None, policy_data["x_oss_forbid_overwrite"]),
        "key": (None, key),
        "success_action_status": (None, "200"),
        "file": (path.name, io.BytesIO(content)),
    }
    # NOTE(review): the upload response is not inspected — presumably hx_req
    # surfaces HTTP errors itself; confirm failures cannot pass silently here.
    response = await hx_req(policy_data["upload_host"], method="POST", files=files, rformat="text")
    # return f"oss://{key}"
    return f"{policy_data['upload_host']}/{key}"
src/asr/tecent_asr.py
@@ -68,7 +68,7 @@ def generate_tencent_cloud_headers(
}
-async def single_sentence_asr(path: str | Path, engine: str, voice_format: str) -> str:
+async def tencent_single_asr(path: str | Path, engine: str, voice_format: str) -> str:
"""Tencent Single Sentence ASR.
一句话识别
@@ -91,7 +91,7 @@ async def single_sentence_asr(path: str | Path, engine: str, voice_format: str)
return generate_tencent_transcription(sentence_start_ms=[0], words=[res["Response"]["WordList"]])
-async def flash_asr(path: str | Path, engine: str, voice_format: str) -> str:
+async def tencent_flash_asr(path: str | Path, engine: str, voice_format: str) -> str:
"""Tencent Flash ASR.
录音文件识别极速版
@@ -130,7 +130,7 @@ async def flash_asr(path: str | Path, engine: str, voice_format: str) -> str:
return generate_tencent_transcription(sentence_start_ms, words)
-async def create_async_asr(url: str, engine: str) -> dict:
+async def tencent_create_asr(url: str, engine: str) -> dict:
"""Create Tencent ASR Task.
录音文件识别请求
@@ -149,7 +149,7 @@ async def create_async_asr(url: str, engine: str) -> dict:
)
-async def query_async_asr(task_id: int) -> str:
+async def tencent_query_asr(task_id: int) -> str:
"""Query Tencent ASR Task.
录音文件识别结果查询
@@ -167,10 +167,12 @@ async def query_async_asr(task_id: int) -> str:
check_keys=["Response.Data.StatusStr"],
)
status = glom(result, "Response.Data.StatusStr")
- while status in ["waiting", "doing"]:
+ query_times = 0
+ while status in ["waiting", "doing"] and query_times < 600: # max 10 minutes
await asyncio.sleep(1)
+ query_times += 1
logger.trace(f"Status:[{status}], Wating TaskID: {task_id}")
- result = await query_async_asr(task_id)
+ result = await tencent_query_asr(task_id)
if isinstance(result, str):
return result
status = glom(result, "Response.Data.StatusStr")
@@ -178,25 +180,21 @@ async def query_async_asr(task_id: int) -> str:
sentence_start_ms = glom(result, "Response.Data.ResultDetail.*.StartMs")
words = glom(result, "Response.Data.ResultDetail.*.Words")
return generate_tencent_transcription(sentence_start_ms, words)
- return glom(result, "Response.Data.ErrorMsg")
+ return "❌" + glom(result, "Response.Data.ErrorMsg", default="语音识别失败")
def generate_tencent_transcription(sentence_start_ms: list[int], words: list[list[dict]]) -> str:
res = ""
- show_timestamp = False
for start_offset, items in zip(sentence_start_ms, words, strict=True):
for idx, item in enumerate(items):
sentence = glom(item, Coalesce("Word", "word"), default="")
if not sentence:
continue
- if idx == 0 or res.endswith((".", "。")):
- show_timestamp = True
- if show_timestamp:
+ if idx == 0 or res.endswith((".", "。", "?", "?")): # noqa: RUF001
start_seconds = float(glom(item, Coalesce("StartTime", "OffsetStartMs", "start_time"), default=0) + float(start_offset)) // 1000
minutes = int(start_seconds // 60)
seconds = int(start_seconds % 60)
res += f"\n[{minutes:02d}:{seconds:02d}] {sentence}"
- show_timestamp = False
else:
res += sentence
return res.strip()
src/asr/utils.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
-
+import random
from config import ASR, FILE_SERVER
@@ -8,13 +8,22 @@ from config import ASR, FILE_SERVER
def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> tuple[str, list[str]]:
"""Get ASR method and supported file types."""
if duration < 60:
- asr_engine = ASR.SHORT_ENGINE
+ asr_engine = random.choice([x.strip() for x in ASR.SHORT_ENGINE.split(",") if x.strip()])
elif 60 <= duration <= 300:
- asr_engine = ASR.MIDDLE_ENGINE
+ asr_engine = random.choice([x.strip() for x in ASR.MIDDLE_ENGINE.split(",") if x.strip()])
else:
- asr_engine = ASR.LONG_ENGINE
- if force_engine:
- return get_tencent_asr_method(duration, file_size) if force_engine == "tencent" else get_gemini_asr_method(duration)
+ asr_engine = random.choice([x.strip() for x in ASR.LONG_ENGINE.split(",") if x.strip()])
+
+ # respect force_engine
+ if force_engine == "ali":
+ return get_ali_asr_method(file_size)
+ if force_engine == "tencent":
+ return get_tencent_asr_method(duration, file_size)
+ if force_engine == "gemini":
+ return get_gemini_asr_method(duration)
+
+ if asr_engine == "ali":
+ return get_ali_asr_method(file_size)
if asr_engine == "tencent":
return get_tencent_asr_method(duration, file_size)
if asr_engine.lower() == "gemini":
@@ -22,6 +31,19 @@ def get_asr_method(duration: float, file_size: int, force_engine: str = "") -> t
return f"ASR Engine: {asr_engine} is not support for duration: {duration}, filesize: {file_size}", []
def get_ali_asr_method(file_size: int) -> tuple[str, list[str]]:
    """Return the Ali ASR method name and supported extensions.

    Args:
        file_size: Audio file size in bytes; files of 2GB or more are rejected.

    Returns:
        ("ali", supported_extensions) on success, or (error_message, []) when
        configuration is missing or the file is too large.
    """
    if not all([ASR.ALI_MODEL, ASR.ALI_API_KEY]):
        return "请设置阿里云ASR相关环境变量", []
    # a file server is required so Aliyun can fetch the audio by URL
    if not FILE_SERVER or file_size >= 2 * 1024 * 1024 * 1024:  # 2GB
        return "请联系管理员配置`FILE_SERVER`变量", []
    supported_ext = ["aac", "amr", "avi", "flac", "flv", "m4a", "mkv", "mov", "mp3", "mp4", "mpeg", "ogg-opus", "ogg", "opus", "wav", "webm", "wma", "wmv"]
    return "ali", supported_ext
+
+
def get_tencent_asr_method(duration: float, file_size: int) -> tuple[str, list[str]]:
if duration > ASR.TENCENT_MAX_DURATION:
return f"无法识别时长超过{ASR.TENCENT_MAX_DURATION}秒的音频, 当前音频时长: {duration}秒", []
@@ -30,13 +52,13 @@ def get_tencent_asr_method(duration: float, file_size: int) -> tuple[str, list[s
asr_method = ""
if duration < 60 and file_size < 3 * 1024 * 1024:
- asr_method = "single_sentence_asr" # 一句话识别
+ asr_method = "tencent_single_asr" # 一句话识别
supported_ext = ["aac", "amr", "m4a", "mp3", "oga", "ogg-opus", "ogg", "opus", "pcm", "silk", "speex", "wav"]
elif 60 <= duration <= 300 and file_size < 100 * 1024 * 1024:
- asr_method = "flash_asr" # 录音文件识别极速版
+ asr_method = "tencent_flash_asr" # 录音文件识别极速版
supported_ext = ["aac", "amr", "m4a", "mp3", "oga", "ogg-opus", "ogg", "opus", "pcm", "silk", "speex", "wav"]
elif FILE_SERVER:
- asr_method = "async_asr" # 录音文件识别 (异步请求)
+ asr_method = "tencent_async_asr" # 录音文件识别 (异步请求)
supported_ext = ["3gp", "aac", "amr", "flac", "flv", "m4a", "mp3", "mp4", "oga", "ogg-opus", "ogg", "opus", "wav", "wma"]
elif not FILE_SERVER:
return "音频过长, 需使用音频URL格式调用ASR\n请联系管理员配置`FILE_SERVER`变量", []
src/asr/voice_recognition.py
@@ -9,8 +9,9 @@ from loguru import logger
from pyrogram.client import Client
from pyrogram.types import Message
+from asr.ali_asr import ali_asr
from asr.gemini_asr import gemini_stream_asr
-from asr.tecent_asr import create_async_asr, flash_asr, query_async_asr, single_sentence_asr
+from asr.tecent_asr import tencent_create_asr, tencent_flash_asr, tencent_query_asr, tencent_single_asr
from asr.utils import get_asr_method
from config import CAPTION_LENGTH, FILE_SERVER, PREFIX, TEXT_LENGTH
from messages.parser import parse_msg
@@ -110,7 +111,7 @@ async def voice_to_text(
custom_code = custom_code.replace("fy", "zh_dialect")
if f"16k_{custom_code}" in ENGINE_MAP:
asr_language = f"16k_{custom_code}"
- elif custom_code in ["gemini", "tencent"]:
+ elif custom_code in ["gemini", "tencent", "ali"]:
force_engine = custom_code
msg = f"[ASR] 收到消息: {trigger_info['mtype']}, 开始识别..."
@@ -133,7 +134,6 @@ async def voice_to_text(
return
if texts := res.get("texts"):
final = blockquote(texts) if len(texts) > 300 else texts
- logger.success(f"{final!r}")
# send results
target_chat = kwargs["target_chat"] if kwargs.get("target_chat") else this_info["cid"]
reply_parameters = get_reply_to(trigger_info["mid"], kwargs.get("reply_msg_id", 0))
@@ -174,7 +174,7 @@ async def asr_file(
if duration == 0:
duration = info["duration"]
asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size, force_engine=engine)
- if asr_method not in ["single_sentence_asr", "flash_asr", "async_asr", "gemini"]:
+ if asr_method not in ["ali", "tencent_single_asr", "tencent_flash_asr", "tencent_async_asr", "gemini"]:
return {"error": asr_method}
voice_format = path.suffix.lstrip(".")
@@ -188,23 +188,25 @@ async def asr_file(
asr_method, supported_ext = get_asr_method(duration, file_size=Path(path).stat().st_size, force_engine=engine)
ogg_names = ["oga", "ogg-opus", "ogg", "opus"] # unify format name
- if asr_method in ["single_sentence_asr", "flash_asr", "async_asr"] and voice_format in ogg_names:
+ if asr_method.startswith("tencent") and voice_format in ogg_names:
voice_format = "ogg-opus"
path = path.rename(path.with_stem(rand_string())) # sanitize filename. (for Tencent Signature v3)
if asr_method == "gemini" and voice_format in ogg_names:
voice_format = "ogg"
- logger.debug(f"Recognizing {voice_format} audio [{duration}s] by {language}: {path.as_posix()}")
+ logger.debug(f"[{asr_method}] Recognizing {voice_format} audio [{duration}s] by {language}: {path.as_posix()}")
try:
- if asr_method == "single_sentence_asr":
- texts = await single_sentence_asr(path, language, voice_format)
- elif asr_method == "flash_asr":
- texts = await flash_asr(path, language, voice_format)
- elif asr_method == "async_asr":
- resp = await create_async_asr(f"{FILE_SERVER}/{path.name}", language)
+ if asr_method == "tencent_single_asr":
+ texts = await tencent_single_asr(path, language, voice_format)
+ elif asr_method == "tencent_flash_asr":
+ texts = await tencent_flash_asr(path, language, voice_format)
+ elif asr_method == "tencent_async_asr":
+ resp = await tencent_create_asr(f"{FILE_SERVER}/{path.name}", language)
task_id = resp["Response"]["Data"]["TaskId"]
logger.success(f"ASR任务提交成功, TaskID: {task_id}")
- texts = await query_async_asr(task_id)
+ texts = await tencent_query_asr(task_id)
+ elif asr_method == "ali":
+ texts = await ali_asr(path)
elif asr_method == "gemini":
return await gemini_stream_asr(path=path, voice_format=voice_format, **kwargs)
res["texts"] = texts
src/config.py
@@ -240,11 +240,17 @@ class DB:
CF_R2_BUCKET_NAME = os.getenv("CF_R2_BUCKET_NAME", "bennybot")
CF_R2_ACCESS_KEY_ID = os.getenv("CF_R2_ACCESS_KEY_ID", "")
CF_R2_SECRET_ACCESS_KEY = os.getenv("CF_R2_SECRET_ACCESS_KEY", "")
+ ALIST_ENABLED = os.getenv("ALIST_ENABLED", "1").lower() in ["1", "y", "yes", "t", "true", "on"]
+ ALIST_USERNAME = os.getenv("ALIST_USERNAME", "guest")
+ ALIST_PASSWORD = os.getenv("ALIST_PASSWORD", "guest")
+ ALIST_SERVER = os.getenv("ALIST_SERVER", "")
+ ALIST_BASR_PATH = os.getenv("ALIST_BASR_PATH", "")
class ASR:
- SHORT_ENGINE = os.getenv("ASR_SHORT_ENGINE", "tencent") # duration < 60s
- MIDDLE_ENGINE = os.getenv("ASR_MIDDLE_ENGINE", "tencent") # 60s <= duration <= 300s
+ # support ali, tencent, gemini engines
+ SHORT_ENGINE = os.getenv("ASR_SHORT_ENGINE", "tencent,ali") # duration < 60s
+ MIDDLE_ENGINE = os.getenv("ASR_MIDDLE_ENGINE", "tencent,ali") # 60s <= duration <= 300s
LONG_ENGINE = os.getenv("ASR_LONG_ENGINE", "gemini") # duration > 300s
GEMINI_BASR_URL = os.getenv("ASR_GEMINI_BASR_URL", "https://generativelanguage.googleapis.com/")
GEMINI_API_KEY = os.getenv("ASR_GEMINI_API_KEY", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
@@ -258,6 +264,12 @@ class ASR:
TENCENT_PROXY = os.getenv("ASR_TENCENT_PROXY", None) # Banned oversea IP, need a back to China proxy
TENCENT_SECRET_ID = os.getenv("ASR_TENCENT_SECRET_ID", "")
TENCENT_SECRET_KEY = os.getenv("ASR_TENCENT_SECRET_KEY", "")
+ # WARN: some models do not allow oversea VPS. Can upload to an alist server in China.
+ ALI_MODEL = os.getenv("ASR_ALI_MODEL", "paraformer-realtime-v2,paraformer-realtime-v1") # comma separated keys for load balance. e.g. "model1,model2,model3"
+ ALI_API_KEY = os.getenv("ASR_ALI_API_KEY", "") # comma separated keys for load balance. e.g. "key1,key2,key3"
+ # If the bot is running on an oversea VPS, and Ali ASR model doesn't allow oversea fileserver.
+ # Change ASR_ALI_FS_ENGINE to alist (configurations in DB class)
+ ALI_FS_ENGINE = os.getenv("ASR_ALI_FS_ENGINE", "local") # local or alist.
class GEMINI: # Official Gemini
src/database.py
@@ -5,17 +5,24 @@
Note: Memory Cache is always enabled.
"""
+import base64
+import io
import json
import os
+import shutil
import warnings
+from pathlib import Path
from urllib.parse import quote_plus, unquote_plus
+import anyio
from aioboto3 import Session
from botocore.exceptions import ClientError
+from glom import glom
from httpx import AsyncClient, AsyncHTTPTransport
from loguru import logger
-from config import DB, cache
+from config import DB, DOWNLOAD_DIR, cache
+from networking import hx_req
from utils import bare_url, stringfy
# hot fix: https://developers.cloudflare.com/r2/examples/aws/boto3/
@@ -272,8 +279,87 @@ async def del_cf_r2(key: str):
return
+async def list_alist() -> list[dict]:
+ """List from Alist."""
+ if not DB.ALIST_ENABLED:
+ return [{}]
+ api = DB.ALIST_SERVER.removesuffix("/") + "/api/fs/list"
+ logger.info(f"List Alist from: {api}")
+ res = await hx_req(api, method="POST", post_json={"path": "/"}, check_kv={"code": 200})
+ return glom(res, "data.content", default=[]) or []
+
+
+async def download_alist(fname: str, save_path: str | Path | None = None) -> str:
+ """Download file from Alist."""
+ if not DB.ALIST_ENABLED:
+ return ""
+ save_path = Path(DOWNLOAD_DIR) / fname if save_path is None else Path(save_path)
+ ext = Path(fname).suffix
+ # Headers DO NOT support Unicode characters
+ if any(ord(c) > 127 for c in fname): # has Non-ASCII
+ b64str = base64.urlsafe_b64encode(fname.encode("utf-8")).decode("ascii").rstrip("=")
+ fname = b64str + ext
+ url = DB.ALIST_SERVER.removesuffix("/") + "/d/" + DB.ALIST_BASR_PATH.strip("/") + f"/{fname.lstrip('/')}"
+ res = await hx_req(url, rformat="content", check_keys=["content"])
+ logger.info(f"Download {fname} to {save_path.name} from: {url}")
+ async with await anyio.open_file(save_path, "wb") as f:
+ await f.write(res["content"])
+ return save_path.as_posix()
+
+
+async def upload_alist(path: str | Path) -> str:
+ """Upload to Alist."""
+ if not DB.ALIST_ENABLED:
+ return ""
+ path = Path(path).expanduser().resolve()
+ if not path.is_file():
+ return ""
+ api = DB.ALIST_SERVER.removesuffix("/") + "/api/fs/form"
+ # Headers DO NOT support Unicode characters
+ if any(ord(c) > 127 for c in path.name): # has Non-ASCII
+ new_name = base64.urlsafe_b64encode(path.name.encode("utf-8")).decode("ascii").rstrip("=")
+ new_path = path.with_stem(new_name)
+ shutil.copy(path, new_path)
+ path = new_path
+ headers = {"File-Path": f"/{path.name}"}
+ logger.info(f"Upload {path.name} to: {api}")
+ async with await anyio.open_file(path, "rb") as f:
+ content = await f.read()
+ await hx_req(api, method="PUT", headers=headers, files={"file": (path.name, io.BytesIO(content))}, check_kv={"code": 200})
+ if "new_path" in locals():
+ new_path.unlink(missing_ok=True)
+ return DB.ALIST_SERVER.removesuffix("/") + "/d/" + DB.ALIST_BASR_PATH.strip("/") + f"/{path.name}"
+
+
+async def delete_alist(fname: str, *, ensure_ascii: bool = True) -> None:
+ """Delete from Alist."""
+ if not DB.ALIST_ENABLED:
+ return
+ # Get JWT Token
+ payload = {"username": DB.ALIST_USERNAME, "password": DB.ALIST_PASSWORD}
+ auth_api = DB.ALIST_SERVER.removesuffix("/") + "/api/auth/login"
+ res = await hx_req(auth_api, method="POST", post_json=payload, check_keys=["data.token"])
+ token = res["data"]["token"]
+
+ # Delete
+ api = DB.ALIST_SERVER.removesuffix("/") + "/api/fs/remove"
+ headers = {"Content-Type": "application/json", "Authorization": token}
+ if ensure_ascii and any(ord(c) > 127 for c in fname): # has Non-ASCII
+ b64str = base64.urlsafe_b64encode(fname.encode("utf-8")).decode("ascii").rstrip("=")
+ payload = {"names": [b64str + Path(fname).suffix]}
+ else:
+ payload = {"names": [fname]}
+ logger.info(f"Delete {fname} from: {api}")
+ res = await hx_req(api, method="POST", headers=headers, post_json=payload, check_kv={"code": 200})
+
+
if __name__ == "__main__":
import asyncio
- asyncio.run(set_cf_r2("test2", metadata={"finished": "1"}))
- asyncio.run(set_cf_r2("test2", data={"finished": "1"}))
+ asyncio.run(list_alist())
+ asyncio.run(upload_alist("测试.mp3"))
+ asyncio.run(download_alist("测试.mp3"))
+ asyncio.run(delete_alist("测试.mp3"))
+ # asyncio.run(download_alist("test.py"))
+ # asyncio.run(set_cf_r2("test2", metadata={"finished": "1"}))
+ # asyncio.run(set_cf_r2("test2", data={"finished": "1"}))
src/multimedia.py
@@ -209,7 +209,7 @@ def convert_to_h264(
return path
-def convert_to_audio(path: str | Path | None, ext: str = "m4a", *, codec: str = "aac", delete: bool = True) -> Path:
+def convert_to_audio(path: str | Path | None, ext: str = "m4a", *, codec: str = "aac", delete: bool = True, **kwargs) -> Path:
if path is None or not Path(path).expanduser().resolve().is_file():
return Path("")
path = Path(path).expanduser().resolve()
@@ -221,19 +221,19 @@ def convert_to_audio(path: str | Path | None, ext: str = "m4a", *, codec: str =
try:
if info["audio_codec"] == codec:
logger.debug(f"Audio stream is already {codec}, without re-encoding: {path.name} -> {tmp_path.name}")
- ffmpeg = FFmpeg().option("y").input(path).output(tmp_path, vn=None, acodec="copy")
+ ffmpeg = FFmpeg().option("y").input(path).output(tmp_path, vn=None, acodec="copy", **kwargs)
ffmpeg.execute()
else:
logger.warning(f"Re-encoding audio: {path.name} -> {tmp_path.name}")
- ffmpeg = FFmpeg().option("y").input(path).output(tmp_path, vn=None, acodec=codec)
+ ffmpeg = FFmpeg().option("y").input(path).output(tmp_path, vn=None, acodec=codec, **kwargs)
@ffmpeg.on("progress")
def on_progress(progress: Progress):
- logger.debug(progress)
+ logger.trace(progress)
@ffmpeg.on("completed")
def on_completed():
- logger.debug("completed")
+ logger.success(f"Converted audio: {path} to {final_path}, {codec=}")
ffmpeg.execute()
if delete:
src/networking.py
@@ -11,7 +11,8 @@ from urllib.parse import parse_qs, quote_plus, urlparse
import anyio
import httpx
-from httpx import AsyncClient, HTTPStatusError, Request, RequestError, Response
+from httpx import AsyncClient, AsyncHTTPTransport, HTTPStatusError, Request, RequestError, Response
+from httpx._types import RequestFiles
from httpx_curl_cffi import AsyncCurlTransport, CurlOpt
from loguru import logger
@@ -35,11 +36,13 @@ async def hx_req(
url,
method: str = "GET",
*,
+ transport: AsyncCurlTransport | AsyncHTTPTransport | None = None,
headers: dict | None = None,
cookies: dict | None = None,
params: dict | None = None,
post_json: dict | None = None,
post_content: httpx._types.RequestContent | None = None,
+ files: RequestFiles | None = None,
proxy: str | None = None,
follow_redirects: bool = True,
check_keys: list[str] | None = None,
@@ -49,7 +52,7 @@ async def hx_req(
max_retry: int = 2,
silent: bool = False,
mobile: bool = False,
- rformat: str = "json", # "json", "text"
+ rformat: str = "json", # "json", "text", "content"
last_error: str = "",
) -> dict[str, Any]:
"""Request the given URL with the given method and return the response as a dictionary.
@@ -79,14 +82,15 @@ async def hx_req(
if retry > max_retry:
logger.error(f"[{method}] Failed after {retry} retries: {url}")
return {"hx_error": last_error}
- transport = AsyncCurlTransport(proxy=proxy, impersonate="safari_ios" if mobile else "chrome", default_headers=True, curl_options={CurlOpt.FRESH_CONNECT: True})
+ if transport is None:
+ transport = AsyncCurlTransport(proxy=proxy, impersonate="safari_ios" if mobile else "chrome", default_headers=True, curl_options={CurlOpt.FRESH_CONNECT: True})
if silent:
client = AsyncClient(http2=True, proxy=proxy, transport=transport, follow_redirects=follow_redirects, timeout=timeout)
else:
client = AsyncClient(http2=True, proxy=proxy, transport=transport, follow_redirects=follow_redirects, timeout=timeout, event_hooks={"request": [log_req], "response": [log_resp]})
- if method not in ["GET", "POST"]:
+ if method not in ["GET", "POST", "PUT"]:
error = f"Invalid method: {method}"
logger.error(error)
return {"hx_error": error}
@@ -94,9 +98,13 @@ async def hx_req(
async with client:
if method == "GET":
response = await client.get(url, cookies=cookies, headers=headers, params=params)
+ elif method == "POST":
+ response = await client.post(url, cookies=cookies, headers=headers, json=post_json, files=files, content=post_content, params=params)
else:
- response = await client.post(url, cookies=cookies, headers=headers, json=post_json, content=post_content, params=params)
+ response = await client.put(url, cookies=cookies, headers=headers, files=files, params=params)
response.raise_for_status()
+ if rformat == "content":
+ return {"content": response.content}
data = response.text
check_data(data, check_keys=check_keys, check_kv=check_kv)
res = json.loads(data) if rformat == "json" else {rformat: data}
pyproject.toml
@@ -5,6 +5,7 @@ dependencies = [
"beautifulsoup4>=4.12.3",
"bilibili-api-python>=17.1.4",
"cacheout>=0.16.0",
+ "dashscope>=1.23.2",
"feedparser>=6.0.11",
"glom>=24.11.0",
"google-genai>=1.12.1",
uv.lock
@@ -223,6 +223,7 @@ dependencies = [
{ name = "beautifulsoup4" },
{ name = "bilibili-api-python" },
{ name = "cacheout" },
+ { name = "dashscope" },
{ name = "feedparser" },
{ name = "glom" },
{ name = "google-genai" },
@@ -262,6 +263,7 @@ requires-dist = [
{ name = "beautifulsoup4", specifier = ">=4.12.3" },
{ name = "bilibili-api-python", specifier = ">=17.1.4" },
{ name = "cacheout", specifier = ">=0.16.0" },
+ { name = "dashscope", specifier = ">=1.23.2" },
{ name = "feedparser", specifier = ">=6.0.11" },
{ name = "glom", specifier = ">=24.11.0" },
{ name = "google-genai", specifier = ">=1.12.1" },
@@ -630,6 +632,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/26/80/dbc3a5119afe233ddc9dfc14d5b175636443e5d6349f4cdce3eaa9e1523a/curl_cffi-0.10.0-cp39-abi3-win_amd64.whl", hash = "sha256:59389773a1556e087120e91eac1e33f84f1599d853e1bc168b153e4cdf360002", size = 1374008 },
]
+[[package]]
+name = "dashscope"
+version = "1.23.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "aiohttp" },
+ { name = "requests" },
+ { name = "websocket-client" },
+]
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/00/c8/afd737ff28f63c3ce846985f0865f29e27590148003f4df1edeb0b5761d6/dashscope-1.23.2-py3-none-any.whl", hash = "sha256:d2d17561ca58fcdeef6eef157efbd7d68b655551895be24d3c80e743eff21aa1", size = 1277881 },
+]
+
[[package]]
name = "decorator"
version = "5.2.1"
@@ -2394,6 +2409,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 },
]
+[[package]]
+name = "websocket-client"
+version = "1.8.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e6/30/fba0d96b4b5fbf5948ed3f4681f7da2f9f64512e1d303f94b4cc174c24a5/websocket_client-1.8.0.tar.gz", hash = "sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da", size = 54648 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/5a/84/44687a29792a70e111c5c477230a72c4b957d88d16141199bf9acb7537a3/websocket_client-1.8.0-py3-none-any.whl", hash = "sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526", size = 58826 },
+]
+
[[package]]
name = "websockets"
version = "15.0.1"