From 41376f127300544646137283341ad382827d7e8f Mon Sep 17 00:00:00 2001 From: ZhaiSoul <842607283@qq.com> Date: Tue, 29 Jul 2025 09:08:40 +0000 Subject: [PATCH] feat: add mirror download source --- frigate/util/downloader.py | 105 +++++++++++++++++++++++++++++++++++-- 1 file changed, 100 insertions(+), 5 deletions(-) diff --git a/frigate/util/downloader.py b/frigate/util/downloader.py index 49b05dd05..e8a2dfcc6 100644 --- a/frigate/util/downloader.py +++ b/frigate/util/downloader.py @@ -1,9 +1,11 @@ import logging import os +import socket import threading import time from pathlib import Path -from typing import Callable, List +from typing import Callable, List, Optional, Tuple +from urllib.parse import urlparse import requests @@ -13,6 +15,86 @@ from frigate.types import ModelStatusTypesEnum logger = logging.getLogger(__name__) +# Mirror mapping configuration +MIRROR_MAPPING = { + "raw.githubusercontent.com": "raw.gh.mirror.frigate-cn.video", + "github.com": "gh.mirror.frigate-cn.video", + "huggingface.co": "hf.mirror.frigate-cn.video", +} + +# Global flag to force using official sources only +FORCE_OFFICIAL_SOURCE = False + + +def measure_latency(host: str, port: int = 80, timeout: float = 2.0) -> Optional[float]: + """ + Measure the latency to a host using a TCP connection. + Returns the latency in milliseconds or None if the connection failed. + """ + try: + start_time = time.time() + with socket.create_connection((host, port), timeout=timeout): + end_time = time.time() + return (end_time - start_time) * 1000 # Convert to milliseconds + except (socket.timeout, socket.error): + return None + + +def set_force_official_source(force: bool = True) -> None: + """ + Set the global flag to force using only official sources. + """ + global FORCE_OFFICIAL_SOURCE + FORCE_OFFICIAL_SOURCE = force + if force: + logger.info("Forced to use official sources only") + + +def get_best_mirror(url: str, latency_threshold: int = 20) -> Tuple[str, bool]: + """ + Determine the best URL to use based on latency measurements. + Returns a tuple of (url_to_use, is_mirror). + """ + if FORCE_OFFICIAL_SOURCE: + return url, False + + parsed_url = urlparse(url) + host = parsed_url.netloc + + # Check if this URL has a mirror + mirror_host = MIRROR_MAPPING.get(host) + if not mirror_host: + return url, False + + # Measure latency to both hosts + official_latency = measure_latency(host) + mirror_latency = measure_latency(mirror_host) + + # Log latency information + if official_latency is not None and mirror_latency is not None: + logger.info( + f"Latency - Official: {official_latency:.2f}ms, Mirror: {mirror_latency:.2f}ms" + ) + + # Determine which URL to use + use_mirror = False + if official_latency is None: + # Official site unreachable, try mirror + use_mirror = mirror_latency is not None + elif mirror_latency is None: + # Mirror unreachable, use official + use_mirror = False + else: + # Both reachable, compare latency + use_mirror = mirror_latency < (official_latency - latency_threshold) + + if use_mirror: + mirror_url = url.replace(host, mirror_host) + logger.info(f"Using mirror: {mirror_url}") + return mirror_url, True + + return url, False + class FileLock: def __init__(self, path): @@ -106,17 +188,24 @@ class ModelDownloader: self.download_complete.set() @staticmethod - def download_from_url(url: str, save_path: str, silent: bool = False) -> Path: + def download_from_url( + url: str, save_path: str, silent: bool = False, use_mirror: bool = True + ) -> Path: temporary_filename = Path(save_path).with_name( os.path.basename(save_path) + ".part" ) temporary_filename.parent.mkdir(parents=True, exist_ok=True) + # Check if we should use a mirror + download_url = url + if use_mirror: + download_url, is_mirror = get_best_mirror(url) + if not silent: - logger.info(f"Downloading model file from: {url}") + logger.info(f"Downloading model file from: {download_url}") try: - with requests.get(url, stream=True, allow_redirects=True) as r: + with requests.get(download_url, stream=True, allow_redirects=True) as r: r.raise_for_status() with open(temporary_filename, "wb") as f: for chunk in r.iter_content(chunk_size=8192): @@ -125,10 +214,16 @@ class ModelDownloader: temporary_filename.rename(save_path) except Exception as e: logger.error(f"Error downloading model: {str(e)}") + # If mirror failed, try official source as fallback + if download_url != url: + logger.info(f"Mirror download failed, trying official source: {url}") + return ModelDownloader.download_from_url( + url, save_path, silent, use_mirror=False + ) raise if not silent: - logger.info(f"Downloading complete: {url}") + logger.info(f"Downloading complete: {download_url}") return Path(save_path)