#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncio
import io
import re
from decimal import Decimal
from pathlib import Path

from glom import glom
from loguru import logger

from asr.utils import audio_chunk_to_bytes, convert_single_channel, downsampe_audio, get_file_bytes, load_audio
from config import ASR, PROXY
from networking import hx_req
from utils import guess_mime, seconds_to_time, strings_list, zhcn


async def groq_asr(path: str | Path, model: str = "", temperature: float = 0, language: str = "") -> dict:
    """Groq ASR.

    https://console.groq.com/docs/api-reference#audio-transcription

    Returns:
        {"texts": str, "raw_texts": str, "segments": list[dict], "error": str}
    """
    path = Path(path).expanduser().resolve()
    if not path.is_file():
        return {"texts": "", "error": "File not found."}
    audio_path = path if path.suffix.lower() == ".wav" else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
    audio_path = await convert_single_channel(audio_path, ext="wav", codec="pcm_s16le")
    # max allowed file size is 25 MB
    if audio_path.stat().st_size < ASR.GROQ_MAX_BYTES:
        return await groq_single_file(audio_path, model=model, temperature=temperature, language=language)
    return await groq_file_chunks(audio_path, model=model, temperature=temperature, language=language)


async def groq_single_file(
    path_or_bytes: str | Path | bytes,
    model: str = "",
    temperature: float = 0,
    language: str = "",
    start_seconds: float = 0,
) -> dict:
    """Transcribe a single audio chunk with the Groq API.

    Returns:
        {"texts": str, "raw_texts": str, "segments": list[dict], "error": str}
    """
    if not model:
        model = strings_list(ASR.GROQ_MODELS, shuffle=True)[0]
    data = {
        "model": model,
        "temperature": str(temperature),  # must be a string
        "response_format": "verbose_json",
    }
    if isinstance(path_or_bytes, str | Path):
        file_name = Path(path_or_bytes).name
        mime = guess_mime(Path(path_or_bytes))
    else:
        file_name = "chunk.wav"
        mime = "audio/wav"
    if language:
        data["language"] = language
    audio_bytes = await get_file_bytes(path_or_bytes)
    resp = await hx_req(
        "https://api.groq.com/openai/v1/audio/transcriptions",
        method="POST",
        headers={"Authorization": f"Bearer {strings_list(ASR.GROQ_KEYS, shuffle=True)[0]}"},
        files={"file": (file_name, io.BytesIO(audio_bytes), mime)},
        data=data,
        timeout=600,
        proxy=PROXY.GROQ,
        check_kv={"task": "transcribe"},
        check_keys=["segments"],
    )
    # Shift all segment timestamps by the chunk's offset within the full audio
    start = Decimal(start_seconds).quantize(Decimal(".01"))
    resp["segments"] = [
        {
            "start": start + Decimal(str(seg["start"])),
            "end": start + Decimal(str(seg["end"])),
            "text": zhcn(seg["text"]),
        }
        for seg in resp.get("segments", [])
    ]
    resp["raw_texts"] = " ".join(str(x["text"]) for x in resp["segments"])
    resp["texts"] = "\n".join(f"[{seconds_to_time(float(x['start']))}] {str(x['text']).lstrip()}" for x in resp["segments"])  # with timestamps
    if resp.get("hx_error"):
        resp["error"] = resp.pop("hx_error")
    return resp


def find_longest_common_sequence(text_seq: list[str], *, match_by_words: bool = False) -> str:
    """Find the optimal alignment between sequences with longest common sequence and sliding window matching.

    Note: `match_by_words` works great for English sequences, but not so much for Chinese.

    Args:
        text_seq: List of text sequences to align and merge
        match_by_words: Whether to match by words (True) or characters (False)

    Returns:
        str: Merged sequence with optimal alignment

    Raises:
        RuntimeError: If there's a mismatch in sequence lengths during comparison
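
    Example (illustrative, using the default character-level matching):
        >>> find_longest_common_sequence(["hello world foo", "world foo bar"])
        'hello world foo bar'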
106 """
    if not text_seq:
        return ""

    # Convert input based on the matching strategy: whitespace-prefixed words, or
    # individual characters
    if match_by_words:
        sequences = [[word for word in re.split(r"(\s+\w+)", seq) if word] for seq in text_seq]
    else:
        sequences = [list(seq) for seq in text_seq]

    left_sequence = sequences[0]
    left_length = len(left_sequence)
    total_sequence = []
    for right_sequence in sequences[1:]:
        max_matching = 0.0
        right_length = len(right_sequence)
        max_indices = (left_length, left_length, 0, 0)  # left_start, left_stop, right_start, right_stop

        # Try different alignments
        for i in range(1, left_length + right_length + 1):
            # Add epsilon to favor longer matches
            eps = float(i) / 10000.0

            left_start = max(0, left_length - i)
            left_stop = min(left_length, left_length + right_length - i)
            left = left_sequence[left_start:left_stop]

            right_start = max(0, i - left_length)
            right_stop = min(right_length, i)
            right = right_sequence[right_start:right_stop]

            if len(left) != len(right):
                msg = "Mismatched subsequences detected during transcript merging."
                raise RuntimeError(msg)

            matches = sum(a == b for a, b in zip(left, right, strict=False))

            # Normalize matches by position and add epsilon
            matching = matches / float(i) + eps

            # Require at least 2 matches
            if matches > 1 and matching > max_matching:
                max_matching = matching
                max_indices = (left_start, left_stop, right_start, right_stop)

        # Use the best alignment found
        left_start, left_stop, right_start, right_stop = max_indices
        # Take the left half from the left sequence and the right half from the right sequence
        left_mid = (left_stop + left_start) // 2
        right_mid = (right_stop + right_start) // 2

        total_sequence.extend(left_sequence[:left_mid])
        left_sequence = right_sequence[right_mid:]
        left_length = len(left_sequence)

    # Add remaining sequence
    total_sequence.extend(left_sequence)

    # Join back into text
    return "".join(total_sequence)


def merge_transcripts(results: list[dict]) -> dict:
    """Merge transcription chunks and handle overlaps.

    Args:
        results: List of chunk transcription results, each containing a "segments" list

    Returns:
        {"texts": str, "raw_texts": str, "segments": list[dict]}
    """
    processed_segments = []
    processed_chunks = []
    for i, chunk in enumerate(results):
        segments = chunk.get("segments") or []
        # If not the last chunk, find the next chunk's start time
        if i < len(results) - 1:
            next_start = glom(results[i + 1], "segments.0.start", default=float("inf"))  # in seconds

            # Split segments into current and overlap based on the next chunk's start time
            current_segments = []
            overlap_segments = []

            for segment in segments:
                # Segments that end after the next chunk starts belong to the overlap
                if segment["end"] > next_start:
                    overlap_segments.append(segment)
                else:
                    current_segments.append(segment)

            # Merge overlap segments if any exist
            if overlap_segments:
                merged_overlap = overlap_segments[0].copy()
                merged_overlap |= {
                    "text": " ".join(s.get("text", "") for s in overlap_segments),
                    "end": overlap_segments[-1].get("end", 0),
                }

                current_segments.append(merged_overlap)

            processed_chunks.append(current_segments)
        else:  # last chunk
            processed_chunks.append(segments)
    # Merge boundaries between chunks
    for i in range(len(processed_chunks) - 1):
        # Skip if either chunk has no segments
        if not processed_chunks[i] or not processed_chunks[i + 1]:
            continue

        # Add all segments except the last from the current chunk
        if len(processed_chunks[i]) > 1:
            processed_segments.extend(processed_chunks[i][:-1])

        # Merge boundary segments
        last_segment = processed_chunks[i][-1]
        first_segment = processed_chunks[i + 1][0]

        merged_text = find_longest_common_sequence([last_segment.get("text", ""), first_segment.get("text", "")])
        merged_segment = last_segment.copy()

        merged_segment |= {"text": merged_text, "end": first_segment.get("end", 0)}
        processed_segments.append(merged_segment)

    # Add all segments from the last chunk
    if processed_chunks and processed_chunks[-1]:
        processed_segments.extend(processed_chunks[-1])

    # Filter out overlapping segments
    final_segments = []
    for seg in processed_segments:
        if not final_segments:  # the first segment is always added
            final_segments.append(seg)
            continue
        if seg["start"] >= final_segments[-1]["end"]:  # ensure no overlap with the last kept segment
            final_segments.append(seg)
    # Create the final transcription
    raw_texts = ""
    texts = ""  # with timestamps
    for x in final_segments:
        raw_texts += x["text"].lstrip() + " "
        texts += f"[{seconds_to_time(float(x['start']))}] {x['text'].lstrip()}\n"
    return {"texts": texts.strip(), "raw_texts": raw_texts.strip(), "segments": final_segments}


async def groq_file_chunks(
    path: str | Path,
    chunk_seconds: float = ASR.GROQ_CHUNK_SECONDS,
    overlap_seconds: float = ASR.GROQ_OVERLAP_SECONDS,
    model: str = "",
    temperature: float = 0,
    language: str = "",
) -> dict:
    """Transcribe audio in chunks with overlap.

    Args:
        path: Path to audio file
        chunk_seconds: Length of each chunk in seconds
        overlap_seconds: Overlap between chunks in seconds

    Returns:
        dict: {"texts": str, "raw_texts": str, "segments": list[dict]} on success,
        or {"error": str} on failure (exceptions are caught and reported, not raised)
    """
    path = Path(path).expanduser().resolve()
    if not path.is_file():
        return {"texts": "", "error": "File not found."}
    if path.suffix.lower() != ".wav":
        path = await downsampe_audio(path, ext="wav", codec="pcm_s16le")
    audio, duration, sr = load_audio(path)
    if sr == 0:
        return {"error": "Failed to load audio."}
    transcription = {}
    try:
        # Calculate the number of chunks
        total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
        tasks = []
        offset_list = []
        # Loop through each chunk, extracting the current chunk from the audio
        for i in range(total_chunks):
            start = int(i * (chunk_seconds - overlap_seconds) * sr)
            end = int(min(start + chunk_seconds * sr, duration * sr))
            logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
            chunk = audio[start:end]
            if chunk.shape[0] == 0:  # empty chunk
                continue
            tasks.append(audio_chunk_to_bytes(chunk, sr))
            offset_list.append(start / sr)  # chunk offset in seconds
        bytes_list = await asyncio.gather(*tasks)  # convert chunks to bytes

        # Process each chunk in parallel (DO NOT do this: it causes OOM for large audio files)
        # tasks = [
        #     groq_single_file(
        #         audio_bytes,
        #         start_seconds=offset,
        #         model=model,
        #         temperature=temperature,
        #         language=language,
        #     )
        #     for audio_bytes, offset in zip(bytes_list, offset_list, strict=True)
        # ]
        # results = await asyncio.gather(*tasks)
        # results = [r for r in results if r.get("segments")]
        results = []
        for audio_bytes, offset in zip(bytes_list, offset_list, strict=True):
            res = await groq_single_file(
                audio_bytes,
                start_seconds=offset,
                model=model,
                temperature=temperature,
                language=language,
            )
            if res.get("segments"):
                results.append(res)

        transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
    except Exception as e:
        logger.error(e)
        return {"error": str(e)}
    return transcription
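

if __name__ == "__main__":
    # Minimal usage sketch (assumptions: a local "sample.wav" exists and Groq API
    # keys are configured via config.ASR); illustrative only, not part of the module API.
    async def _demo() -> None:
        result = await groq_asr("sample.wav", language="en")
        print(result.get("error") or result["texts"])

    asyncio.run(_demo())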