Commit 3e55f26

benny-dou <60535774+benny-dou@users.noreply.github.com>
2025-01-24 14:46:18
fix(networking): prefer http `get` to `stream` method to download_file
1 parent 0c33f8e
Changed files (1)
src/networking.py
@@ -177,6 +177,7 @@ async def download_file(
     skip_exist: bool = False,
     workers_proxy: bool = False,
     headers: dict | None = None,
+    stream: bool = False,
     **kwargs,
 ) -> str:
     """Download a file from the given link and save it to the specified path.
@@ -188,6 +189,7 @@ async def download_file(
         skip_exist (bool, optional): Skip downloading if the file already exists. Defaults to False.
         workers_proxy (bool, optional): Use workers proxy. Defaults to False.
         headers (dict, optional): The headers to use for the request. Defaults to Telegram UA.
+        stream (bool, optional): Stream the download. Defaults to False.
 
     Returns:
         str: Download file path.
@@ -209,18 +211,36 @@ async def download_file(
         headers = {"user-agent": UA.TELEGRAM}
     path.parent.mkdir(parents=True, exist_ok=True)
     logger.trace(f"Downloading {link} to {path}")
-    hx = AsyncClient(transport=AsyncHTTPTransport(retries=3), proxy=PROXY.DOWNLOAD, timeout=60, follow_redirects=True, event_hooks={"request": [log_req], "response": [log_resp]})
-    async with semaphore, hx.stream("GET", link, headers=headers, timeout=60) as response:
-        total = int(response.headers.get("Content-Length", 0))
-        async with await anyio.open_file(path, "wb") as f:
-            num_bytes_downloaded = response.num_bytes_downloaded
-            async for chunk in response.aiter_bytes():
-                await f.write(chunk)
-                msg = f"⏬下载中: {readable_size(num_bytes_downloaded)} / {readable_size(total)}\n💾{path.name}"
-                msg += f" ({num_bytes_downloaded / total:.2%})" if total and total > 0 else ""
-                await modify_progress(text=msg, **kwargs)
-                num_bytes_downloaded = response.num_bytes_downloaded
-
+    hx = AsyncClient(
+        headers=headers,
+        transport=AsyncHTTPTransport(retries=3),
+        proxy=PROXY.DOWNLOAD,
+        timeout=10,
+        follow_redirects=True,
+        event_hooks={"request": [log_req], "response": [log_resp]},
+    )
+    try:
+        if stream:  # can monitor progress, but the retry mechanism does not work
+            async with semaphore, hx.stream("GET", link) as response:
+                total = int(response.headers.get("Content-Length", 0))
+                async with await anyio.open_file(path, "wb") as f:
+                    num_bytes_downloaded = response.num_bytes_downloaded
+                    async for chunk in response.aiter_bytes():
+                        await f.write(chunk)
+                        msg = f"⏬下载中: {readable_size(num_bytes_downloaded)} / {readable_size(total)}\n💾{path.name}"
+                        msg += f" ({num_bytes_downloaded / total:.2%})" if total and total > 0 else ""
+                        await modify_progress(text=msg, **kwargs)
+                        num_bytes_downloaded = response.num_bytes_downloaded
+        else:
+            async with semaphore, hx:
+                response = await hx.get(link)
+                response.raise_for_status()
+                path.write_bytes(response.content)  # Save the file to disk
+    except (RequestError, HTTPStatusError) as e:
+        error = f"Failed to download: {e}"
+        logger.error(error)
+        await modify_progress(text=error, **kwargs)
+        return ""
     if path.is_file():
         logger.info(f"Downloaded file saved to {path}")
         await modify_progress(text=f"🎉下载成功\n{path.name}", **kwargs)