Commit faf0b67
Changed files (4)
src/asr/cloudflare.py
@@ -2,21 +2,21 @@
# -*- coding: utf-8 -*-
import asyncio
import base64
+import io
from collections.abc import Coroutine
from decimal import Decimal
from pathlib import Path
from typing import Any
-import anyio
import soundfile as sf
from glom import glom
from loguru import logger
from asr.groq import merge_transcripts
-from asr.utils import convert_single_channel, downsampe_audio
-from config import ASR, DOWNLOAD_DIR
+from asr.utils import convert_single_channel, downsampe_audio, get_file_bytes
+from config import ASR
from networking import hx_req
-from utils import rand_string, seconds_to_time, strings_list, zhcn
+from utils import seconds_to_time, strings_list, zhcn
async def cloudflare_asr(
@@ -48,12 +48,11 @@ async def cloudflare_asr(
async def cloudflare_single_file(
- path: Path,
+ path_or_bytes: Path | bytes,
model: str | None = "",
prompt: str | None = "",
*,
offset_seconds: int = 0,
- delete_local_file: bool = False,
) -> dict:
"""Transcribe a single audio chunk with Cloudflare API.
@@ -65,14 +64,17 @@ async def cloudflare_single_file(
return {"error": "未配置Cloudflare相关API"}
if not model:
model = ASR.CLOUDFLARE_MODEL
+ audio_bytes = await get_file_bytes(path_or_bytes)
+ if not audio_bytes:
+ return {"error": f"Audio is empty: {path_or_bytes}"}
+ resp = {}
+
for key in strings_list(ASR.CLOUDFLARE_KEYS, shuffle=True):
cf_id, cf_token = key.split(":", 1)
try:
url = f"https://api.cloudflare.com/client/v4/accounts/{cf_id}/ai/run/{model}"
headers = {"Authorization": f"Bearer {cf_token}"}
- async with await anyio.open_file(path, "rb") as f:
- audio_bytes = await f.read()
- audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
+ audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
payload = {"audio": audio_base64, "task": "transcribe", "vad_filter": True}
if prompt:
payload["initial_prompt"] = prompt
@@ -98,8 +100,6 @@ async def cloudflare_single_file(
resp["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in resp["segments"]) # with timestamp
if resp.get("hx_error"):
resp["error"] = resp.pop("hx_error")
- if delete_local_file:
- path.unlink(missing_ok=True)
except Exception as e:
logger.error(e)
return resp
@@ -141,9 +141,7 @@ async def cloudflare_file_chunks(
try:
# Calculate # of chunks
- total_chunks = (duration // (chunk_seconds - overlap_seconds)) + 1
- total_chunks = int(total_chunks)
- chunk_paths = [Path(DOWNLOAD_DIR) / f"{rand_string()}.opus" for _ in range(total_chunks)]
+ total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
tasks = []
# Loop through each chunk, extract current chunk from audio, transcribe
for i in range(total_chunks):
@@ -153,14 +151,16 @@ async def cloudflare_file_chunks(
chunk = audio[start:end]
if chunk.shape[0] == 0: # empty chunk
continue
- chunk_path = chunk_paths[i]
- await asyncio.to_thread(sf.write, chunk_path.as_posix(), chunk, sr, format="ogg", subtype="OPUS")
- task = cloudflare_single_file(chunk_path, model, prompt, offset_seconds=int(start / sr), delete_local_file=False)
+ # Write audio chunk to buffer
+ buffer = io.BytesIO()
+ await asyncio.to_thread(sf.write, buffer, chunk, sr, format="ogg", subtype="OPUS")
+ buffer.seek(0) # move cursor to beginning
+ chunk_bytes = buffer.getvalue() # get chunk bytes
+ task = cloudflare_single_file(chunk_bytes, model, prompt, offset_seconds=int(start / sr))
tasks.append(run_with_semaphore(task))
results = await asyncio.gather(*tasks)
results = [r for r in results if r.get("segments")]
transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
- [p.unlink(missing_ok=True) for p in chunk_paths]
except Exception as e:
logger.error(e)
return {"error": str(e)}
src/asr/groq.py
@@ -3,16 +3,14 @@
import asyncio
import io
import re
-import tempfile
from decimal import Decimal
from pathlib import Path
-import anyio
import soundfile as sf
from glom import glom
from loguru import logger
-from asr.utils import convert_single_channel, downsampe_audio
+from asr.utils import convert_single_channel, downsampe_audio, get_file_bytes
from config import ASR
from networking import hx_req
from utils import seconds_to_time, strings_list, zhcn
@@ -39,28 +37,20 @@ async def groq_asr(path: str | Path, model: str = "", prompt: str = "", temperat
async def groq_single_file(
- path: Path | str,
+ path_or_bytes: Path | bytes,
model: str = "",
temperature: float = 0,
prompt: str = "",
language: str = "",
start_seconds: float = 0,
- *,
- delete_local_file: bool = False,
) -> dict:
"""Transcribe a single audio chunk with Groq API.
Returns:
{"texts": str, "raw_texts": str, "segments": list[dict], "error": str}
"""
- path = Path(path).expanduser().resolve()
- if not path.is_file():
- return {"texts": "", "error": "File not found."}
if not model:
model = strings_list(ASR.GROQ_MODELS, shuffle=True)[0]
- async with await anyio.open_file(path, "rb") as f:
- content = await f.read()
- content_bytes = io.BytesIO(content)
data = {
"model": model,
"temperature": str(temperature), # must be string
@@ -70,11 +60,12 @@ async def groq_single_file(
data["prompt"] = prompt
if language:
data["language"] = language
+ audio_bytes = await get_file_bytes(path_or_bytes)
resp = await hx_req(
"https://api.groq.com/openai/v1/audio/transcriptions",
method="POST",
headers={"Authorization": f"Bearer {strings_list(ASR.GROQ_KEYS, shuffle=True)[0]}"},
- files={"file": ("chunk.ogg", content_bytes, "audio/ogg")},
+ files={"file": ("chunk.ogg", io.BytesIO(audio_bytes), "audio/ogg")},
data=data,
timeout=600,
proxy=ASR.GROQ_PROXY,
@@ -94,8 +85,6 @@ async def groq_single_file(
resp["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in resp["segments"]) # with timestamp
if resp.get("hx_error"):
resp["error"] = resp.pop("hx_error")
- if delete_local_file:
- path.unlink(missing_ok=True)
return resp
@@ -297,8 +286,7 @@ async def groq_file_chunks(
transcription = {}
try:
# Calculate # of chunks
- total_chunks = (duration_seconds // (chunk_seconds - overlap_seconds)) + 1
- total_chunks = int(total_chunks)
+ total_chunks = int(duration_seconds // (chunk_seconds - overlap_seconds)) + 1
tasks = []
# Loop through each chunk, extract current chunk from audio, transcribe
for i in range(total_chunks):
@@ -308,20 +296,21 @@ async def groq_file_chunks(
chunk = audio[start:end]
if chunk.shape[0] == 0: # empty chunk
continue
- with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as chunk_file:
- chunk_path = chunk_file.name
- await asyncio.to_thread(sf.write, chunk_path, chunk, sr, format="ogg", subtype="OPUS")
- tasks.append(
- groq_single_file(
- chunk_file.name,
- start_seconds=start / sr,
- model=model,
- temperature=temperature,
- prompt=prompt,
- language=language,
- delete_local_file=True,
- )
+ # Write audio chunk to buffer
+ buffer = io.BytesIO()
+ await asyncio.to_thread(sf.write, buffer, chunk, sr, format="ogg", subtype="OPUS")
+ buffer.seek(0) # move cursor to beginning
+ chunk_bytes = buffer.getvalue() # get chunk bytes
+ tasks.append(
+ groq_single_file(
+ chunk_bytes,
+ start_seconds=start / sr,
+ model=model,
+ temperature=temperature,
+ prompt=prompt,
+ language=language,
)
+ )
results = await asyncio.gather(*tasks)
transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
except Exception as e:
src/asr/tecent.py
@@ -4,6 +4,7 @@ import asyncio
import base64
import hashlib
import hmac
+import io
from collections.abc import Coroutine
from decimal import Decimal
from pathlib import Path
@@ -15,12 +16,12 @@ from glom import Coalesce, flatten, glom
from loguru import logger
from asr.groq import merge_transcripts
-from asr.utils import audio_duration, convert_single_channel, downsampe_audio, is_english_word
-from config import ASR, DOWNLOAD_DIR, FILE_SERVER
+from asr.utils import audio_duration, convert_single_channel, downsampe_audio, get_file_bytes, is_english_word
+from config import ASR, FILE_SERVER
from database.alist import delete_alist, upload_alist
from database.uguu import upload_uguu
from networking import hx_req
-from utils import nowdt, rand_string, seconds_to_time
+from utils import nowdt, seconds_to_time
def sign(key, msg):
@@ -101,7 +102,7 @@ async def tencent_asr(path: str | Path, language: str, duration: float) -> dict:
return await tencent_file_chunks(audio_path, language=language, duration=duration)
-async def tencent_single_asr(path: Path, language: str, *, offset_seconds: int = 0, delete_local_file: bool = False) -> dict:
+async def tencent_single_asr(path_or_bytes: Path | bytes, language: str, *, offset_seconds: int = 0) -> dict:
"""Tencent Single Sentence ASR.
一句话识别 (每月免费额度: 5000次)
@@ -118,16 +119,19 @@ async def tencent_single_asr(path: Path, language: str, *, offset_seconds: int =
}
"""
final = {"texts": "", "raw_texts": "", "segments": []}
- file_bytes = path.stat().st_size
- # max 3 MB
- audio_path = path if file_bytes < 3 * 1024 * 1024 else await downsampe_audio(path, ext="opus", codec="libopus")
- voice_format = Path(audio_path).suffix.lower().lstrip(".")
- if voice_format in ["ogg", "opus", "oga"]: # tencnet only supports ogg-opus
+ if isinstance(path_or_bytes, Path):
+ # max 3 MB
+ file_size = path_or_bytes.stat().st_size
+ audio_path = path_or_bytes if file_size < 3 * 1024 * 1024 else await downsampe_audio(path_or_bytes, ext="opus", codec="libopus")
+ voice_format = Path(audio_path).suffix.lower().lstrip(".")
+ if voice_format in ["ogg", "opus", "oga"]: # tencnet only supports ogg-opus
+ voice_format = "ogg-opus"
+ audio_bytes = await get_file_bytes(audio_path)
+ elif isinstance(path_or_bytes, bytes):
voice_format = "ogg-opus"
- async with await anyio.open_file(audio_path, "rb") as f:
- content = await f.read()
- data = base64.b64encode(content).decode("utf-8")
- payload = f'{{"EngSerViceType":"{language}","SourceType":1,"WordInfo":2,"VoiceFormat":"{voice_format}","Data":"{data}"}}'
+ audio_bytes = path_or_bytes
+ audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
+ payload = f'{{"EngSerViceType":"{language}","SourceType":1,"WordInfo":2,"VoiceFormat":"{voice_format}","Data":"{audio_base64}"}}'
headers = generate_tencent_cloud_headers(action="SentenceRecognition", payload=payload)
res = await hx_req(
"https://asr.tencentcloudapi.com",
@@ -170,9 +174,6 @@ async def tencent_single_asr(path: Path, language: str, *, offset_seconds: int =
segments.append({"start": start, "end": end, "text": text})
final["texts"] = "\n".join(f"[{seconds_to_time(x['start'])}] {x['text'].lstrip()}" for x in segments) # with timestamp
final["segments"] = segments
- if delete_local_file:
- path.unlink(missing_ok=True)
- audio_path.unlink(missing_ok=True)
return final
@@ -210,9 +211,7 @@ async def tencent_file_chunks(
try:
# Calculate # of chunks
- total_chunks = (duration // (chunk_seconds - overlap_seconds)) + 1
- total_chunks = int(total_chunks)
- chunk_paths = [Path(DOWNLOAD_DIR) / f"{rand_string()}.opus" for _ in range(total_chunks)]
+ total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
tasks = []
# Loop through each chunk, extract current chunk from audio, transcribe
for i in range(total_chunks):
@@ -222,14 +221,16 @@ async def tencent_file_chunks(
chunk = audio[start:end]
if chunk.shape[0] == 0: # empty chunk
continue
- chunk_path = chunk_paths[i]
- await asyncio.to_thread(sf.write, chunk_path.as_posix(), chunk, sr, format="ogg", subtype="OPUS")
- task = tencent_single_asr(chunk_path, language=language, offset_seconds=int(start / sr), delete_local_file=False)
+ # Write audio chunk to buffer
+ buffer = io.BytesIO()
+ await asyncio.to_thread(sf.write, buffer, chunk, sr, format="ogg", subtype="OPUS")
+ buffer.seek(0) # move cursor to beginning
+ chunk_bytes = buffer.getvalue() # get chunk bytes
+ task = tencent_single_asr(chunk_bytes, language=language, offset_seconds=int(start / sr))
tasks.append(run_with_semaphore(task))
results = await asyncio.gather(*tasks)
results = [r for r in results if r.get("segments")]
transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
- [p.unlink(missing_ok=True) for p in chunk_paths]
except Exception as e:
logger.error(e)
return {"error": str(e)}
src/asr/utils.py
@@ -6,8 +6,10 @@ import random
import re
from pathlib import Path
+import anyio
import soundfile as sf
from ffmpeg import FFmpeg
+from loguru import logger
from soundfile import LibsndfileError
from config import ASR, GEMINI
@@ -112,3 +114,16 @@ def audio_duration(path: str | Path) -> float:
return max(map(float, durations))
return 0.0
+
+
+async def get_file_bytes(path_or_bytes: str | Path | bytes) -> bytes:
+ file_bytes = b""
+ if isinstance(path_or_bytes, bytes):
+ return b""
+ if isinstance(path_or_bytes, (str, Path)):
+ if not Path(path_or_bytes).is_file():
+ logger.error(f"{path_or_bytes} is not exist.")
+ return b""
+ async with await anyio.open_file(path_or_bytes, "rb") as f:
+ file_bytes = await f.read()
+ return file_bytes