diff --git a/frigate/util/downloader.py b/frigate/util/downloader.py old mode 100644 new mode 100755 index 49b05dd05..c44a17434 --- a/frigate/util/downloader.py +++ b/frigate/util/downloader.py @@ -3,7 +3,8 @@ import os import threading import time from pathlib import Path -from typing import Callable, List +from typing import Callable, List, Tuple +from urllib.parse import urlparse import requests @@ -13,6 +14,53 @@ from frigate.types import ModelStatusTypesEnum logger = logging.getLogger(__name__) +# Mirror mapping configuration +MIRROR_MAPPING = { + "raw.githubusercontent.com": "raw.gh.mirror.frigate-cn.video", + "github.com": "gh.mirror.frigate-cn.video", + "huggingface.co": "hf.mirror.frigate-cn.video", +} + +# Global flag to force using official sources only +FORCE_OFFICIAL_SOURCE = False + + +def set_force_official_source(force: bool = True) -> None: + """ + Set the global flag to force using only official sources. + """ + global FORCE_OFFICIAL_SOURCE + FORCE_OFFICIAL_SOURCE = force + if force: + logger.info("Forced to use official sources only") + + +def get_best_mirror(url: str, latency_threshold: int = 20) -> Tuple[str, bool]: + """ + Determine whether to use mirror based on environment variable USE_MIRROR_SOURCE. + Returns a tuple of (url_to_use, is_mirror). + """ + if FORCE_OFFICIAL_SOURCE: + return url, False + + parsed_url = urlparse(url) + host = parsed_url.netloc + + # Check if this URL has a mirror + mirror_host = MIRROR_MAPPING.get(host) + if not mirror_host: + return url, False + + # Check environment variable to determine if mirror should be used + use_mirror = os.environ.get("USE_MIRROR_SOURCE", "").lower() == "true" + + if use_mirror: + mirror_url = url.replace(host, mirror_host) + logger.info(f"Using mirror: {mirror_url}") + return mirror_url, True + + return url, False + class FileLock: def __init__(self, path): @@ -106,17 +154,24 @@ class ModelDownloader: self.download_complete.set() @staticmethod - def download_from_url(url: str, save_path: str, silent: bool = False) -> Path: + def download_from_url( + url: str, save_path: str, silent: bool = False, use_mirror: bool = True + ) -> Path: temporary_filename = Path(save_path).with_name( os.path.basename(save_path) + ".part" ) temporary_filename.parent.mkdir(parents=True, exist_ok=True) + # Check if we should use a mirror + download_url = url + if use_mirror: + download_url, is_mirror = get_best_mirror(url) + if not silent: - logger.info(f"Downloading model file from: {url}") + logger.info(f"Downloading model file from: {download_url}") try: - with requests.get(url, stream=True, allow_redirects=True) as r: + with requests.get(download_url, stream=True, allow_redirects=True) as r: r.raise_for_status() with open(temporary_filename, "wb") as f: for chunk in r.iter_content(chunk_size=8192): @@ -125,10 +180,16 @@ class ModelDownloader: temporary_filename.rename(save_path) except Exception as e: logger.error(f"Error downloading model: {str(e)}") + # If mirror failed, try official source as fallback + if download_url != url: + logger.info(f"Mirror download failed, trying official source: {url}") + return ModelDownloader.download_from_url( + url, save_path, silent, use_mirror=False + ) raise if not silent: - logger.info(f"Downloading complete: {url}") + logger.info(f"Downloading complete: {download_url}") return Path(save_path)