#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncio
import io
import re
from decimal import Decimal
from pathlib import Path

from glom import glom
from loguru import logger

from asr.utils import audio_chunk_to_bytes, convert_single_channel, downsampe_audio, get_file_bytes, load_audio
from config import ASR, PROXY
from networking import hx_req
from utils import guess_mime, seconds_to_time, strings_list, zhcn


async def groq_asr(path: str | Path, model: str = "", temperature: float = 0, language: str = "") -> dict:
    """Transcribe an audio file with the Groq ASR API.

    https://console.groq.com/docs/api-reference#audio-transcription

    Returns:
        {"texts": str, "raw_texts": str, "segments": list[dict], "error": str}
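
    Example (hypothetical path; requires valid ASR.GROQ_KEYS in config):
        >>> result = asyncio.run(groq_asr("~/audio/meeting.mp3", language="en"))
        >>> result["texts"]  # timestamped transcript, one segment per line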
    """
    path = Path(path).expanduser().resolve()
    if not path.is_file():
        return {"texts": "", "error": "File not found."}
    audio_path = path if path.suffix.lower() == ".wav" else await downsampe_audio(path, ext="wav", codec="pcm_s16le")
    audio_path = await convert_single_channel(audio_path, ext="wav", codec="pcm_s16le")
    # Groq's maximum allowed upload size is 25 MB; larger files are transcribed in overlapping chunks.
    if audio_path.stat().st_size < ASR.GROQ_MAX_BYTES:
        return await groq_single_file(audio_path, model=model, temperature=temperature, language=language)
    return await groq_file_chunks(audio_path, model=model, temperature=temperature, language=language)


async def groq_single_file(
    path_or_bytes: Path | str | bytes,
    model: str = "",
    temperature: float = 0,
    language: str = "",
    start_seconds: float = 0,
) -> dict:
    """Transcribe a single audio file or in-memory chunk with the Groq API.

    Returns:
        {"texts": str, "raw_texts": str, "segments": list[dict], "error": str}
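
    Note:
        start_seconds is added to every returned timestamp, so per-chunk
        results can be stitched back onto one absolute timeline.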
    """
    if not model:
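        # shuffle=True randomizes the configured list, so each call rotates models (and, below, API keys)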
        model = strings_list(ASR.GROQ_MODELS, shuffle=True)[0]
    data = {
        "model": model,
        "temperature": str(temperature),  # must be a string in the multipart form
        "response_format": "verbose_json",
    }
    if isinstance(path_or_bytes, Path | str):
        file_name = Path(path_or_bytes).name
        mime = guess_mime(Path(path_or_bytes))
    else:
        file_name = "chunk.wav"
        mime = "audio/wav"
    if language:
        data["language"] = language
    audio_bytes = await get_file_bytes(path_or_bytes)
    resp = await hx_req(
        "https://api.groq.com/openai/v1/audio/transcriptions",
        method="POST",
        headers={"Authorization": f"Bearer {strings_list(ASR.GROQ_KEYS, shuffle=True)[0]}"},
        files={"file": (file_name, io.BytesIO(audio_bytes), mime)},
        data=data,
        timeout=600,
        proxy=PROXY.GROQ,
        check_kv={"task": "transcribe"},
        check_keys=["segments"],
    )
    # Shift each segment by start_seconds; Decimal-from-str avoids float rounding in timestamps
    start = Decimal(start_seconds).quantize(Decimal(".01"))
    resp["segments"] = [
        {
            "start": start + Decimal(str(seg["start"])),
            "end": start + Decimal(str(seg["end"])),
            "text": zhcn(seg["text"]),
        }
        for seg in resp.get("segments", [])
    ]
    resp["raw_texts"] = " ".join(str(x["text"]) for x in resp["segments"])
    # "texts" is the timestamped variant of "raw_texts", one segment per line
    resp["texts"] = "\n".join(f"[{seconds_to_time(float(x['start']))}] {str(x['text']).lstrip()}" for x in resp["segments"])
    if resp.get("hx_error"):
        resp["error"] = resp.pop("hx_error")
    return resp


def find_longest_common_sequence(text_seq: list[str], *, match_by_words: bool = False) -> str:
    """Find the optimal alignment between sequences with longest common sequence and sliding window matching.

    Note: `match_by_words` works well for whitespace-delimited languages such as English,
    but poorly for Chinese, which has no whitespace word boundaries.

    Args:
        text_seq: List of text sequences to align and merge
        match_by_words: Whether to match by words (True) or characters (False)

    Returns:
        str: Merged sequence with optimal alignment

    Raises:
        RuntimeError: If there's a mismatch in window lengths during comparison
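
    Example:
        >>> find_longest_common_sequence(["hello world fri", "world friend"])
        'hello world friend'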
    """
    if not text_seq:
        return ""

    # Convert input based on matching strategy
    sequences = [[word for word in re.split(r"(\s+\w+)", seq) if word] for seq in text_seq] if match_by_words else [list(seq) for seq in text_seq]

    left_sequence = sequences[0]
    left_length = len(left_sequence)
    total_sequence = []
    for right_sequence in sequences[1:]:
        max_matching = 0.0
        right_length = len(right_sequence)
        max_indices = (left_length, left_length, 0, 0)  # left_start, left_stop, right_start, right_stop

        # Try different alignments: i slides the right sequence across the left,
        # from overlapping only the tail of left to overlapping only its head
        for i in range(1, left_length + right_length + 1):
            # Add epsilon to favor longer matches
            eps = float(i) / 10000.0

            left_start = max(0, left_length - i)
            left_stop = min(left_length, left_length + right_length - i)
            left = left_sequence[left_start:left_stop]

            right_start = max(0, i - left_length)
            right_stop = min(right_length, i)
            right = right_sequence[right_start:right_stop]

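            # Both windows have the same length by construction; a mismatch
            # signals broken windowing arithmetic rather than bad input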
            if len(left) != len(right):
                msg = "Mismatched subsequences detected during transcript merging."
                raise RuntimeError(msg)

            matches = sum(a == b for a, b in zip(left, right, strict=False))

            # Normalize matches by position and add epsilon
            matching = matches / float(i) + eps

            # Require at least 2 matches
            if matches > 1 and matching > max_matching:
                max_matching = matching
                max_indices = (left_start, left_stop, right_start, right_stop)

        # Use the best alignment found
        left_start, left_stop, right_start, right_stop = max_indices
        # Take left half from left sequence and right half from right sequence
        left_mid = (left_stop + left_start) // 2
        right_mid = (right_stop + right_start) // 2

        total_sequence.extend(left_sequence[:left_mid])
        left_sequence = right_sequence[right_mid:]
        left_length = len(left_sequence)

    # Add remaining sequence
    total_sequence.extend(left_sequence)

    # Join back into text
    return "".join(total_sequence)


def merge_transcripts(results: list[dict]) -> dict:
    """Merge transcription chunks and handle overlaps.

    Args:
        results: List of per-chunk transcription results, each carrying a "segments" list

    Returns:
        {"texts": str, "raw_texts": str, "segments": list[dict]}
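
    Example:
        >>> merged = merge_transcripts([
        ...     {"segments": [{"start": 0.0, "end": 4.0, "text": "hello world fri"}]},
        ...     {"segments": [{"start": 3.0, "end": 8.0, "text": "world friend"}]},
        ... ])
        >>> merged["raw_texts"]
        'hello world friend'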
    """
    processed_segments = []
    processed_chunks = []
    for i, chunk in enumerate(results):
        segments = chunk.get("segments") or []
        # If not last chunk, find next chunk start time
        if i < len(results) - 1:
            next_start = glom(results[i + 1], "segments.0.start", default=float("inf"))  # in seconds, like all segment times

            # Split segments into current and overlap based on next chunk's start time
            current_segments = []
            overlap_segments = []

            for segment in segments:
                # Segments that end after the next chunk starts fall in the overlap region
                if segment["end"] > next_start:
                    overlap_segments.append(segment)
                else:
                    current_segments.append(segment)

            # Merge overlap segments if any exist
            if overlap_segments:
                merged_overlap = overlap_segments[0].copy()
                merged_overlap |= {
                    "text": " ".join(s.get("text", "") for s in overlap_segments),
                    "end": overlap_segments[-1].get("end", 0),
                }

                current_segments.append(merged_overlap)

            processed_chunks.append(current_segments)
        else:  # last chunk
            processed_chunks.append(segments)
    # Merge boundaries between chunks
    for i in range(len(processed_chunks) - 1):
        # Skip if either chunk has no segments
        if not processed_chunks[i] or not processed_chunks[i + 1]:
            continue

        # Add all segments except last from current chunk
        if len(processed_chunks[i]) > 1:
            processed_segments.extend(processed_chunks[i][:-1])

        # Merge boundary segments
        last_segment = processed_chunks[i][-1]
        first_segment = processed_chunks[i + 1][0]

        merged_text = find_longest_common_sequence([last_segment.get("text", ""), first_segment.get("text", "")])
        merged_segment = last_segment.copy()

        merged_segment |= {"text": merged_text, "end": first_segment.get("end", 0)}
        processed_segments.append(merged_segment)

    # Add all segments from last chunk
    if processed_chunks and processed_chunks[-1]:
        processed_segments.extend(processed_chunks[-1])

    # Filter out duplicated segments
    final_segments = []
    for idx, seg in enumerate(processed_segments):
        if idx == 0:  # the first segment is always added
            final_segments.append(seg)
            continue
        if seg["start"] >= processed_segments[idx - 1]["end"]:  # ensure no overlap
            final_segments.append(seg)
    # Create final transcription
    raw_texts = ""
    texts = ""  # with timestamp
    for x in final_segments:
        raw_texts += x["text"].lstrip() + " "
        texts += f"[{seconds_to_time(float(x['start']))}] {x['text'].lstrip()}\n"
    return {"texts": texts.strip(), "raw_texts": raw_texts.strip(), "segments": final_segments}


async def groq_file_chunks(
    path: str | Path,
    chunk_seconds: float = ASR.GROQ_CHUNK_SECONDS,
    overlap_seconds: float = ASR.GROQ_OVERLAP_SECONDS,
    model: str = "",
    temperature: float = 0,
    language: str = "",
) -> dict:
    """Transcribe audio in overlapping chunks and merge the results.

    Args:
        path: Path to audio file
        chunk_seconds: Length of each chunk in seconds
        overlap_seconds: Overlap between consecutive chunks in seconds

    Returns:
        dict: {"texts": str, "raw_texts": str, "segments": list[dict]}

    Note:
        Failures are reported via the "error" key rather than raised;
        any exception during chunking or transcription is caught and
        returned as a string.
    """
    path = Path(path).expanduser().resolve()
    if not path.is_file():
        return {"texts": "", "error": "File not found."}
    if path.suffix.lower() != ".wav":
        path = await downsampe_audio(path, ext="wav", codec="pcm_s16le")
    audio, duration, sr = load_audio(path)
    if sr == 0:
        return {"error": "Failed to load audio."}
    transcription = {}
    try:
        # Calculate the number of chunks; consecutive chunks advance by (chunk_seconds - overlap_seconds)
        total_chunks = int(duration // (chunk_seconds - overlap_seconds)) + 1
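        # e.g. 600 s of audio with 120 s chunks and 10 s overlap: int(600 // 110) + 1 = 6 chunks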
        tasks = []
        offset_list = []
        # Loop through each chunk, extract current chunk from audio
        for i in range(total_chunks):
            start = int(i * (chunk_seconds - overlap_seconds) * sr)
            end = int(min(start + chunk_seconds * sr, duration * sr))
            logger.trace(f"Processing chunk {i + 1}/{total_chunks}, Time range: {start / sr:.0f}s - {end / sr:.0f}s")
            chunk = audio[start:end]
            if chunk.shape[0] == 0:  # empty chunk
                continue
            tasks.append(audio_chunk_to_bytes(chunk, sr))
            offset_list.append(int(start / sr))
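            # Offsets (whole seconds) later become start_seconds, shifting each chunk's timestamps to absolute time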
        bytes_list = await asyncio.gather(*tasks)  # convert chunks to bytes

        # Process each chunk in parallel (DO NOT do this: it can OOM on large audio files)
        # tasks = [
        #     groq_single_file(
        #         audio_bytes,
        #         start_seconds=offset,
        #         model=model,
        #         temperature=temperature,
        #         language=language,
        #     )
        #     for audio_bytes, offset in zip(bytes_list, offset_list, strict=True)
        # ]
        # results = await asyncio.gather(*tasks)
        # results = [r for r in results if r.get("segments")]
        results = []
        for audio_bytes, offset in zip(bytes_list, offset_list, strict=True):
            res = await groq_single_file(
                audio_bytes,
                start_seconds=offset,
                model=model,
                temperature=temperature,
                language=language,
            )
            if res.get("segments"):
                results.append(res)

        transcription = merge_transcripts(sorted(results, key=lambda x: x["segments"][0]["start"]))
    except Exception as e:
        logger.error(e)
        return {"error": str(e)}
    return transcription