This commit is contained in:
GuoQing Liu 2025-09-04 18:26:58 +08:00 committed by GitHub
commit 83efa2b45a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

71
frigate/util/downloader.py Normal file → Executable file
View File

@ -3,7 +3,8 @@ import os
import threading import threading
import time import time
from pathlib import Path from pathlib import Path
from typing import Callable, List from typing import Callable, List, Tuple
from urllib.parse import urlparse
import requests import requests
@ -13,6 +14,53 @@ from frigate.types import ModelStatusTypesEnum
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Mirror mapping configuration
MIRROR_MAPPING = {
"raw.githubusercontent.com": "raw.gh.mirror.frigate-cn.video",
"github.com": "gh.mirror.frigate-cn.video",
"huggingface.co": "hf.mirror.frigate-cn.video",
}
# Global flag to force using official sources only
FORCE_OFFICIAL_SOURCE = False
def set_force_official_source(force: bool = True) -> None:
"""
Set the global flag to force using only official sources.
"""
global FORCE_OFFICIAL_SOURCE
FORCE_OFFICIAL_SOURCE = force
if force:
logger.info("Forced to use official sources only")
def get_best_mirror(url: str, latency_threshold: int = 20) -> Tuple[str, bool]:
"""
Determine whether to use mirror based on environment variable USE_MIRROR_SOURCE.
Returns a tuple of (url_to_use, is_mirror).
"""
if FORCE_OFFICIAL_SOURCE:
return url, False
parsed_url = urlparse(url)
host = parsed_url.netloc
# Check if this URL has a mirror
mirror_host = MIRROR_MAPPING.get(host)
if not mirror_host:
return url, False
# Check environment variable to determine if mirror should be used
use_mirror = os.environ.get("USE_MIRROR_SOURCE", "").lower() == "true"
if use_mirror:
mirror_url = url.replace(host, mirror_host)
logger.info(f"Using mirror: {mirror_url}")
return mirror_url, True
return url, False
class FileLock: class FileLock:
def __init__(self, path): def __init__(self, path):
@ -106,17 +154,24 @@ class ModelDownloader:
self.download_complete.set() self.download_complete.set()
@staticmethod @staticmethod
def download_from_url(url: str, save_path: str, silent: bool = False) -> Path: def download_from_url(
url: str, save_path: str, silent: bool = False, use_mirror: bool = True
) -> Path:
temporary_filename = Path(save_path).with_name( temporary_filename = Path(save_path).with_name(
os.path.basename(save_path) + ".part" os.path.basename(save_path) + ".part"
) )
temporary_filename.parent.mkdir(parents=True, exist_ok=True) temporary_filename.parent.mkdir(parents=True, exist_ok=True)
# Check if we should use a mirror
download_url = url
if use_mirror:
download_url, is_mirror = get_best_mirror(url)
if not silent: if not silent:
logger.info(f"Downloading model file from: {url}") logger.info(f"Downloading model file from: {download_url}")
try: try:
with requests.get(url, stream=True, allow_redirects=True) as r: with requests.get(download_url, stream=True, allow_redirects=True) as r:
r.raise_for_status() r.raise_for_status()
with open(temporary_filename, "wb") as f: with open(temporary_filename, "wb") as f:
for chunk in r.iter_content(chunk_size=8192): for chunk in r.iter_content(chunk_size=8192):
@ -125,10 +180,16 @@ class ModelDownloader:
temporary_filename.rename(save_path) temporary_filename.rename(save_path)
except Exception as e: except Exception as e:
logger.error(f"Error downloading model: {str(e)}") logger.error(f"Error downloading model: {str(e)}")
# If mirror failed, try official source as fallback
if download_url != url:
logger.info(f"Mirror download failed, trying official source: {url}")
return ModelDownloader.download_from_url(
url, save_path, silent, use_mirror=False
)
raise raise
if not silent: if not silent:
logger.info(f"Downloading complete: {url}") logger.info(f"Downloading complete: {download_url}")
return Path(save_path) return Path(save_path)