commit b86b89638e
Author: GuoQing Liu
Date: 2025-07-29 20:40:03 +02:00 (committed by GitHub)


@@ -1,9 +1,11 @@
import logging
import os
import socket
import threading
import time
from pathlib import Path
from typing import Callable, List
from typing import Callable, List, Optional, Tuple
from urllib.parse import urlparse
import requests
@@ -13,6 +15,86 @@ from frigate.types import ModelStatusTypesEnum
logger = logging.getLogger(__name__)
# Mirror mapping configuration
MIRROR_MAPPING = {
"raw.githubusercontent.com": "raw.gh.mirror.frigate-cn.video",
"github.com": "gh.mirror.frigate-cn.video",
"huggingface.co": "hf.mirror.frigate-cn.video",
}
# Global flag to force using official sources only
FORCE_OFFICIAL_SOURCE = False
def measure_latency(host: str, port: int = 80, timeout: float = 2.0) -> Optional[float]:
"""
Measure the latency to a host using a TCP connection.
Returns the latency in milliseconds or None if the connection failed.
"""
try:
start_time = time.time()
with socket.create_connection((host, port), timeout=timeout):
end_time = time.time()
return (end_time - start_time) * 1000 # Convert to milliseconds
except (socket.timeout, socket.error):
return None
def set_force_official_source(force: bool = True) -> None:
"""
Set the global flag to force using only official sources.
"""
global FORCE_OFFICIAL_SOURCE
FORCE_OFFICIAL_SOURCE = force
if force:
logger.info("Forced to use official sources only")
def get_best_mirror(url: str, latency_threshold: int = 20) -> Tuple[str, bool]:
"""
Determine the best URL to use based on latency measurements.
Returns a tuple of (url_to_use, is_mirror).
"""
if FORCE_OFFICIAL_SOURCE:
return url, False
parsed_url = urlparse(url)
host = parsed_url.netloc
# Check if this URL has a mirror
mirror_host = MIRROR_MAPPING.get(host)
if not mirror_host:
return url, False
# Measure latency to both hosts
official_latency = measure_latency(host)
mirror_latency = measure_latency(mirror_host)
# Log latency information
if official_latency is not None and mirror_latency is not None:
logger.info(
f"Latency - Official: {official_latency:.2f}ms, Mirror: {mirror_latency:.2f}ms"
)
# Determine which URL to use
use_mirror = False
if official_latency is None:
# Official site unreachable, try mirror
use_mirror = mirror_latency is not None
elif mirror_latency is None:
# Mirror unreachable, use official
use_mirror = False
else:
# Both reachable, compare latency
use_mirror = mirror_latency < (official_latency - latency_threshold)
if use_mirror:
mirror_url = url.replace(host, mirror_host)
logger.info(f"Using mirror: {mirror_url}")
return mirror_url, True
return url, False
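A minimal usage sketch of the mirror selection above; the model URL is illustrative and not taken from Frigate's configuration, and the latency comparison depends on the caller's network:

# Hypothetical call; with the default 20 ms threshold the mirror is only chosen
# when it is reachable and measurably faster than the official host.
url = "https://github.com/example-org/example-model/releases/download/v1.0/model.onnx"
download_url, is_mirror = get_best_mirror(url, latency_threshold=20)
# download_url stays equal to the original URL when the mirror has no clear
# latency advantage, is unreachable, or FORCE_OFFICIAL_SOURCE is set.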
class FileLock:
def __init__(self, path):
@@ -106,17 +188,24 @@ class ModelDownloader:
self.download_complete.set()
@staticmethod
def download_from_url(url: str, save_path: str, silent: bool = False) -> Path:
def download_from_url(
url: str, save_path: str, silent: bool = False, use_mirror: bool = True
) -> Path:
temporary_filename = Path(save_path).with_name(
os.path.basename(save_path) + ".part"
)
temporary_filename.parent.mkdir(parents=True, exist_ok=True)
# Check if we should use a mirror
download_url = url
if use_mirror:
download_url, is_mirror = get_best_mirror(url)
if not silent:
logger.info(f"Downloading model file from: {url}")
logger.info(f"Downloading model file from: {download_url}")
try:
with requests.get(url, stream=True, allow_redirects=True) as r:
with requests.get(download_url, stream=True, allow_redirects=True) as r:
r.raise_for_status()
with open(temporary_filename, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
@@ -125,10 +214,16 @@ class ModelDownloader:
temporary_filename.rename(save_path)
except Exception as e:
logger.error(f"Error downloading model: {str(e)}")
# If mirror failed, try official source as fallback
if download_url != url:
logger.info(f"Mirror download failed, trying official source: {url}")
return ModelDownloader.download_from_url(
url, save_path, silent, use_mirror=False
)
raise
if not silent:
logger.info(f"Downloading complete: {url}")
logger.info(f"Downloading complete: {download_url}")
return Path(save_path)
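A rough end-to-end sketch of the new behaviour; the model URL and save path below are made up for illustration:

# Mirror selection is transparent to callers: the CN mirror is tried first when
# it is clearly faster, and the official URL is retried automatically on error.
ModelDownloader.download_from_url(
    "https://huggingface.co/example/model/resolve/main/model.onnx",
    "/config/model_cache/model.onnx",
)

# Deployments that should never probe the CN mirrors can set the global flag
# before any download:
set_force_official_source(True)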