main
1#!/venv/bin/python
2# -*- coding: utf-8 -*-
3import shutil
4from pathlib import Path
5from urllib.parse import urlparse
6
7from glom import Coalesce, glom
8
9from asr.utils import audio_duration
10from asr.voice_recognition import asr_file
11from config import PODCAST
12from messages.utils import remove_img_tag
13from networking import match_social_media_link
14from podcast.utils import get_pubdate
15from preview.bilibili import get_bilibili_vinfo
16from preview.youtube import get_youtube_vinfo
17from subtitles.base import fetch_subtitle
18from utils import convert_md, rand_string, readable_time, remove_consecutive_newlines, strings_list
19
20
async def get_transcripts(
    audio_path: str | Path,
    feed_title: str,
    feed_url: str,
    entry: dict,
) -> str:
    """Get the transcript of a podcast episode.

    If the entry's link is hosted on ``www.youtube.com`` or ``www.bilibili.com``,
    subtitles are fetched from the platform first and returned when non-empty.
    Otherwise (or if no subtitles come back), the transcript is generated via ASR
    on a temporary copy of *audio_path*.

    Args:
        audio_path: Local audio file of the episode.
        feed_title: Podcast title; embedded in the prompt/reference text and
            used to choose the ASR engine.
        feed_url: Feed URL; its host may also choose the ASR engine.
        entry: Feed entry. Reads ``title`` and ``link``, the description from
            ``content[0].value`` (falling back to ``summary``), and whatever
            ``get_pubdate`` reads.

    Returns:
        The subtitle text, or the ASR ``texts`` result (``""`` if absent).
    """
    desc = convert_md(html=glom(entry, Coalesce("content.0.value", "summary"), default=""))
    desc, _ = remove_img_tag(desc)
    desc = remove_consecutive_newlines(desc, newline_level=2)
    reference = f"本次转录稿为播客栏目《{feed_title}》的一期节目。\n该期节目标题: [{entry['title']}]({entry['link']})\n播出日期: {get_pubdate(entry):%Y-%m-%d}\n节目简介: {desc}"

    # Prefer platform-provided subtitles over running ASR.
    if urlparse(entry["link"]).netloc in ("www.youtube.com", "www.bilibili.com"):
        res = await fetch_subtitle(entry["link"], reference=reference)
        if res.get("subtitles"):
            return res["subtitles"]

    # Fall back to generating the transcript via ASR.
    duration = await get_duration(audio_path, entry)
    prompt = f"请转录播客栏目《{feed_title}》的一期节目的音频。\n该期节目标题: {entry['title']}\n节目时长: {readable_time(duration)}\n节目简介: {desc}"
    engine = get_asr_engine(feed_title, feed_url)

    # asr_file is expected to delete the audio file it is given once it is done,
    # so run it on a throwaway copy and leave the original file intact.
    tmp_path = backup_audio(audio_path)
    try:
        asr_res = await asr_file(tmp_path, prompt=prompt, engine=engine, corrector_reference=reference, silent=True)
    finally:
        # Always remove the copy, even if ASR raised; missing_ok covers the case
        # where asr_file already deleted it.
        Path(tmp_path).unlink(missing_ok=True)
    return asr_res.get("texts", "")
51
52
def get_asr_engine(feed_title: str, feed_url: str) -> str:
    """Choose which ASR engine to use for a podcast feed.

    Resolution order:
      1. Per-title overrides (``PODCAST.ASR_FORCE_<ENGINE>_TITLES``), checked
         for gemini, groq, cloudflare, whisper, uncensored in that order.
      2. Per-domain overrides (``PODCAST.ASR_FORCE_<ENGINE>_DOMAINS``), matched
         against the network location of the stripped *feed_url*, same order.
      3. The configured default ``PODCAST.ASR_ENGINE``.
    """
    title_overrides = (
        ("gemini", "ASR_FORCE_GEMINI_TITLES"),
        ("groq", "ASR_FORCE_GROQ_TITLES"),
        ("cloudflare", "ASR_FORCE_CLOUDFLARE_TITLES"),
        ("whisper", "ASR_FORCE_WHISPER_TITLES"),
        ("uncensored", "ASR_FORCE_UNCENSORED_TITLES"),
    )
    domain_overrides = (
        ("gemini", "ASR_FORCE_GEMINI_DOMAINS"),
        ("groq", "ASR_FORCE_GROQ_DOMAINS"),
        ("cloudflare", "ASR_FORCE_CLOUDFLARE_DOMAINS"),
        ("whisper", "ASR_FORCE_WHISPER_DOMAINS"),
        ("uncensored", "ASR_FORCE_UNCENSORED_DOMAINS"),
    )

    for engine_name, setting in title_overrides:
        if feed_title in strings_list(getattr(PODCAST, setting)):
            return engine_name

    # Only parse the URL once title overrides have been ruled out.
    host = urlparse(feed_url.strip()).netloc
    for engine_name, setting in domain_overrides:
        if host in strings_list(getattr(PODCAST, setting)):
            return engine_name

    return PODCAST.ASR_ENGINE
76
77
async def get_duration(path: str | Path, entry: dict) -> float:
    """Return the duration of an episode.

    When ``entry["link"]`` is matched as a YouTube or Bilibili video, the
    ``duration`` from that platform's video info is used if it is truthy.
    Otherwise the duration is measured from the audio file at *path* with
    ``audio_duration``.
    """
    link_info = await match_social_media_link(entry["link"])
    site = link_info["platform"]

    if site == "youtube":
        video_info = await get_youtube_vinfo(link_info["vid"])
    elif site == "bilibili":
        video_info = await get_bilibili_vinfo(link_info["bvid"])
    else:
        video_info = {}

    # A missing or zero duration falls back to probing the local file.
    if reported := video_info.get("duration"):
        return reported
    return audio_duration(path)
92
93
def backup_audio(path: str | Path) -> str:
    """Copy an audio file to a randomly named sibling and return the copy's path.

    The copy is placed in the same directory with the same suffix. Only the stem
    is replaced, using ``rand_string(12)``. ``shutil.copy`` copies the file data
    and permission bits. The returned path is rendered with ``Path.as_posix``.
    """
    source = Path(path)
    destination = source.with_stem(rand_string(12))
    shutil.copy(source, destination)
    return destination.as_posix()