Commit b9d475d
Changed files (3)
src/asr/cloudflare.py
@@ -17,12 +17,7 @@ from networking import hx_req
from utils import seconds_to_time, strings_list, zhcn
-async def cloudflare_asr(
- path: str | Path,
- duration: float,
- model: str | None = "",
- prompt: str | None = "",
-) -> dict:
+async def cloudflare_asr(path: str | Path, duration: float, model: str | None = "") -> dict:
"""Cloudflare ASR.
https://developers.cloudflare.com/workers-ai/models/whisper-large-v3-turbo/
@@ -40,17 +35,11 @@ async def cloudflare_asr(
audio_path = await convert_single_channel(audio_path, ext="wav", codec="pcm_s16le")
# max allowed file size is 25MB
if duration < ASR.CLOUDFLARE_CHUNK_SECONDS:
- return await cloudflare_single_file(audio_path, model=model, prompt=prompt)
- return await cloudflare_file_chunks(audio_path, duration, model=model, prompt=prompt)
+ return await cloudflare_single_file(audio_path, model=model)
+ return await cloudflare_file_chunks(audio_path, duration, model=model)
-async def cloudflare_single_file(
- path_or_bytes: Path | bytes,
- model: str | None = "",
- prompt: str | None = "",
- *,
- offset_seconds: int = 0,
-) -> dict:
+async def cloudflare_single_file(path_or_bytes: Path | bytes, model: str | None = "", *, offset_seconds: int = 0) -> dict:
"""Transcribe a single audio chunk with the Cloudflare Workers AI API.
Returns:
@@ -73,8 +62,6 @@ async def cloudflare_single_file(
headers = {"Authorization": f"Bearer {cf_token}"}
audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
payload = {"audio": audio_base64, "task": "transcribe", "vad_filter": False}
- if prompt:
- payload["initial_prompt"] = prompt
resp = await hx_req(
url,
"POST",
@@ -107,7 +94,6 @@ async def cloudflare_file_chunks(
path: Path,
duration: float,
model: str | None = "",
- prompt: str | None = "",
chunk_seconds: float = 600,
overlap_seconds: float = ASR.CLOUDFLARE_OVERLAP_SECONDS,
) -> dict:
@@ -152,7 +138,7 @@ async def cloudflare_file_chunks(
bytes_list = await asyncio.gather(*tasks) # convert chunks to bytes
tasks = []
for audio_bytes, offset_seconds in zip(bytes_list, offset_list, strict=True):
- task = cloudflare_single_file(audio_bytes, model, prompt, offset_seconds=offset_seconds)
+ task = cloudflare_single_file(audio_bytes, model, offset_seconds=offset_seconds)
tasks.append(run_with_semaphore(task))
results = await asyncio.gather(*tasks)
results = [r for r in results if r.get("segments")]
src/asr/groq.py
@@ -15,7 +15,7 @@ from networking import hx_req
from utils import guess_mime, seconds_to_time, strings_list, zhcn
-async def groq_asr(path: str | Path, model: str = "", prompt: str = "", temperature: float = 0, language: str = "") -> dict:
+async def groq_asr(path: str | Path, model: str = "", temperature: float = 0, language: str = "") -> dict:
"""Groq ASR.
https://console.groq.com/docs/api-reference#audio-transcription
@@ -30,15 +30,14 @@ async def groq_asr(path: str | Path, model: str = "", prompt: str = "", temperat
audio_path = await convert_single_channel(audio_path, ext="wav", codec="pcm_s16le")
# max allowed file size is 25MB
if audio_path.stat().st_size < ASR.GROQ_MAX_BYTES:
- return await groq_single_file(audio_path, model=model, prompt=prompt, temperature=temperature, language=language)
- return await groq_file_chunks(audio_path, model=model, temperature=temperature, prompt=prompt, language=language)
+ return await groq_single_file(audio_path, model=model, temperature=temperature, language=language)
+ return await groq_file_chunks(audio_path, model=model, temperature=temperature, language=language)
async def groq_single_file(
path_or_bytes: Path | bytes,
model: str = "",
temperature: float = 0,
- prompt: str = "",
language: str = "",
start_seconds: float = 0,
) -> dict:
@@ -60,8 +59,6 @@ async def groq_single_file(
else:
file_name = "chunk.wav"
mime = "audio/wav"
- if prompt:
- data["prompt"] = prompt
if language:
data["language"] = language
audio_bytes = await get_file_bytes(path_or_bytes)
@@ -253,7 +250,6 @@ async def groq_file_chunks(
overlap_seconds: float = ASR.GROQ_OVERLAP_SECONDS,
model: str = "",
temperature: float = 0,
- prompt: str = "",
language: str = "",
) -> dict:
"""Transcribe audio in chunks with overlap.
@@ -300,7 +296,6 @@ async def groq_file_chunks(
start_seconds=offset,
model=model,
temperature=temperature,
- prompt=prompt,
language=language,
)
for audio_bytes, offset in zip(bytes_list, offset_list, strict=True)
src/asr/voice_recognition.py
@@ -216,9 +216,9 @@ async def asr_file(
elif engine == "gemini":
res = await gemini_asr(path=path, prompt=prompt, delete_gemini_file=delete_gemini_file)
elif engine == "cloudflare":
- res = await cloudflare_asr(path, duration, model=kwargs.get("cf_asr_model"), prompt=prompt)
+ res = await cloudflare_asr(path, duration, model=kwargs.get("cf_asr_model"))
elif engine == "groq":
- res = await groq_asr(path=path, model=kwargs.get("groq_asr_model", ""), prompt=prompt)
+ res = await groq_asr(path=path, model=kwargs.get("groq_asr_model", ""))
else:
return {"error": "ASR method not supported"}
if res.get("texts"):