mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-08-04 13:47:37 +02:00
download models at runtime
This commit is contained in:
parent
a154a13f98
commit
2304457557
@ -294,18 +294,8 @@ install_and_hold() { \
|
||||
}; \
|
||||
install_and_hold memx-accl'
|
||||
|
||||
# Set the working directory for model files
|
||||
WORKDIR /memryx_models
|
||||
|
||||
# Download and extract MemryX model files
|
||||
RUN wget -O yolox.zip https://developer.memryx.com/example_files/1p2_frigate/yolox.zip && unzip yolox.zip && \
|
||||
wget -O ssdlite.zip https://developer.memryx.com/example_files/1p2_frigate/ssdlite.zip && unzip ssdlite.zip && \
|
||||
wget -O yolo_nas.zip https://developer.memryx.com/example_files/1p2_frigate/yolo_nas.zip && unzip yolo_nas.zip && \
|
||||
wget -O yolov9.zip https://developer.memryx.com/example_files/1p2_frigate/yolov9.zip && unzip yolov9.zip && \
|
||||
wget -O yolov8.zip https://developer.memryx.com/example_files/1p2_frigate/yolov8.zip && unzip yolov8.zip
|
||||
|
||||
# Set permissions for the models directory
|
||||
RUN chmod -R 755 /memryx_models
|
||||
# Copy the 81-class COCO label map
|
||||
RUN wget -O coco_81class_labelmap.txt https://developer.memryx.com/example_files/1p2_frigate/labelmap.txt
|
||||
|
||||
###### End MemryX setup
|
||||
|
||||
|
@ -39,8 +39,6 @@ class ModelTypeEnum(str, Enum):
|
||||
rfdetr = "rfdetr"
|
||||
ssd = "ssd"
|
||||
yolox = "yolox"
|
||||
yolov9 = "yolov9"
|
||||
yolov8 = "yolov8"
|
||||
yolonas = "yolonas"
|
||||
yologeneric = "yolo-generic"
|
||||
|
||||
|
@ -25,7 +25,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
DETECTOR_KEY = "memryx"
|
||||
|
||||
|
||||
# Configuration class for model settings
|
||||
class ModelConfig(BaseModel):
|
||||
path: str = Field(default=None, title="Model Path") # Path to the DFP file
|
||||
@ -38,12 +37,11 @@ class MemryXDetectorConfig(BaseDetectorConfig):
|
||||
|
||||
|
||||
class MemryXDetector(DetectionApi):
|
||||
type_key = DETECTOR_KEY # Set the type key
|
||||
type_key = DETECTOR_KEY # Set the type key
|
||||
supported_models = [
|
||||
ModelTypeEnum.ssd,
|
||||
ModelTypeEnum.yolonas,
|
||||
ModelTypeEnum.yolov9,
|
||||
ModelTypeEnum.yolov8,
|
||||
ModelTypeEnum.yologeneric,
|
||||
ModelTypeEnum.yolox,
|
||||
]
|
||||
|
||||
@ -63,48 +61,49 @@ class MemryXDetector(DetectionApi):
|
||||
self.memx_model_width = detector_config.model.width
|
||||
self.memx_model_type = detector_config.model.model_type
|
||||
|
||||
# If it's yologeneric, treat it as yolov9
|
||||
if self.memx_model_type == ModelTypeEnum.yologeneric:
|
||||
self.memx_model_type = ModelTypeEnum.yolov9
|
||||
|
||||
self.cache_dir = "/memryx_models"
|
||||
|
||||
if self.memx_model_type == ModelTypeEnum.yolov9:
|
||||
self.model_url = (
|
||||
"https://developer.memryx.com/example_files/1p2_frigate/yolov9.zip"
|
||||
)
|
||||
|
||||
elif self.memx_model_type == ModelTypeEnum.yolov8:
|
||||
self.model_url = (
|
||||
"https://developer.memryx.com/example_files/1p2_frigate/yolov8.zip"
|
||||
)
|
||||
|
||||
if self.memx_model_type in [ModelTypeEnum.yolov8, ModelTypeEnum.yolov9]:
|
||||
# Shared constants for both yolov8 and yolov9 post-processing
|
||||
self.const_A = np.load(
|
||||
"/memryx_models/yolov9/_model_22_Constant_9_output_0.npy"
|
||||
)
|
||||
self.const_B = np.load(
|
||||
"/memryx_models/yolov9/_model_22_Constant_10_output_0.npy"
|
||||
)
|
||||
self.const_C = np.load(
|
||||
"/memryx_models/yolov9/_model_22_Constant_12_output_0.npy"
|
||||
self.expected_dfp_model = (
|
||||
"YOLO_v9_small_640_640_3_onnx.dfp"
|
||||
)
|
||||
|
||||
elif self.memx_model_type == ModelTypeEnum.yolonas:
|
||||
self.model_url = (
|
||||
"https://developer.memryx.com/example_files/1p2_frigate/yolo_nas.zip"
|
||||
"https://developer.memryx.com/example_files/1p2_frigate/yolonas.zip"
|
||||
)
|
||||
self.expected_dfp_model = (
|
||||
"yolo_nas_s.dfp"
|
||||
)
|
||||
self.expected_post_model = (
|
||||
"yolo_nas_s_post.onnx"
|
||||
)
|
||||
self.expected_post_model = "yolo_nas/yolo_nas_s_post.onnx"
|
||||
|
||||
elif self.memx_model_type == ModelTypeEnum.yolox:
|
||||
self.model_url = (
|
||||
"https://developer.memryx.com/example_files/1p2_frigate/yolox.zip"
|
||||
)
|
||||
self.expected_dfp_model = (
|
||||
"YOLOX_640_640_3_onnx.dfp"
|
||||
)
|
||||
self.set_strides_grids()
|
||||
|
||||
elif self.memx_model_type == ModelTypeEnum.ssd:
|
||||
self.model_url = (
|
||||
"https://developer.memryx.com/example_files/1p2_frigate/ssdlite.zip"
|
||||
"https://developer.memryx.com/example_files/1p2_frigate/ssd.zip"
|
||||
)
|
||||
self.expected_dfp_model = (
|
||||
"SSDlite_MobileNet_v2_320_320_3_onnx.dfp"
|
||||
)
|
||||
self.expected_post_model = (
|
||||
"ssdlite/SSDlite_MobileNet_v2_320_320_3_onnx_post.onnx"
|
||||
"SSDlite_MobileNet_v2_320_320_3_onnx_post.onnx"
|
||||
)
|
||||
|
||||
self.check_and_prepare_model()
|
||||
@ -136,53 +135,76 @@ class MemryXDetector(DetectionApi):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize MemryX model: {e}")
|
||||
raise
|
||||
|
||||
def load_yolo_constants(self):
|
||||
base = f"{self.cache_dir}/{self.memx_model_type.value}"
|
||||
# constants for yolov9 post-processing
|
||||
self.const_A = np.load(
|
||||
f"{base}/_model_22_Constant_9_output_0.npy"
|
||||
)
|
||||
self.const_B = np.load(
|
||||
f"{base}/_model_22_Constant_10_output_0.npy"
|
||||
)
|
||||
self.const_C = np.load(
|
||||
f"{base}/_model_22_Constant_12_output_0.npy"
|
||||
)
|
||||
|
||||
def check_and_prepare_model(self):
|
||||
"""Check if both models exist; if not, download and extract them."""
|
||||
"""Check if models exist; if not, download and extract them."""
|
||||
if not os.path.exists(self.cache_dir):
|
||||
os.makedirs(self.cache_dir)
|
||||
|
||||
if not self.expected_post_model:
|
||||
logger.info(f"Assigned Model Path: {self.memx_model_path}")
|
||||
dfp_exists = os.path.exists(self.memx_model_path) if self.memx_model_path else False
|
||||
post_exists = os.path.exists(self.expected_post_model) if self.expected_post_model else True # ok if no post model
|
||||
|
||||
else:
|
||||
post_model_file_path = os.path.join(
|
||||
self.cache_dir, self.expected_post_model
|
||||
if dfp_exists and post_exists:
|
||||
logger.info(
|
||||
f"Using cached models."
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Model files not found. Downloading from {self.model_url}..."
|
||||
)
|
||||
zip_path = os.path.join(self.cache_dir, f"{self.memx_model_type.value}.zip")
|
||||
|
||||
try:
|
||||
# Download the ZIP file
|
||||
urllib.request.urlretrieve(self.model_url, zip_path)
|
||||
logger.info(
|
||||
f"Model ZIP downloaded to {zip_path}. Extracting..."
|
||||
)
|
||||
|
||||
# Check if both post model file exist
|
||||
if os.path.isfile(post_model_file_path):
|
||||
self.memx_post_model = post_model_file_path
|
||||
logger.info(
|
||||
f"Post-processing model found at {post_model_file_path}, skipping download."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Model files not found. Downloading from {self.model_url}..."
|
||||
)
|
||||
zip_path = os.path.join(self.cache_dir, "memryx_model.zip")
|
||||
# Extract ZIP file
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(self.cache_dir)
|
||||
logger.info(
|
||||
f"Model extracted to {self.cache_dir}."
|
||||
)
|
||||
|
||||
# Download the ZIP file
|
||||
urllib.request.urlretrieve(self.model_url, zip_path)
|
||||
logger.info(f"Model ZIP downloaded to {zip_path}. Extracting...")
|
||||
# Determine the subfolder to search in
|
||||
model_subdir = os.path.join(self.cache_dir, self.memx_model_type.value)
|
||||
|
||||
# Assign extracted files to correct paths
|
||||
for file in os.listdir(model_subdir):
|
||||
file_path = os.path.join(model_subdir, file)
|
||||
if file == self.expected_dfp_model:
|
||||
self.memx_model_path = file_path
|
||||
elif file == self.expected_post_model:
|
||||
self.memx_post_model = file_path
|
||||
|
||||
# Extract ZIP file
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(self.cache_dir)
|
||||
logger.info(f"Assigned Model Path: {self.memx_model_path}")
|
||||
logger.info(f"Assigned Post-processing Model Path: {self.memx_post_model}")
|
||||
|
||||
logger.info(f"Model extracted to {self.cache_dir}.")
|
||||
if self.memx_model_type == ModelTypeEnum.yolov9:
|
||||
self.load_yolo_constants()
|
||||
|
||||
# Assign extracted files to correct paths
|
||||
for file in os.listdir(self.cache_dir):
|
||||
if file == self.expected_post_model:
|
||||
self.memx_post_model = os.path.join(self.cache_dir, file)
|
||||
|
||||
logger.info(f"Assigned Model Path: {self.memx_model_path}")
|
||||
logger.info(
|
||||
f"Assigned Post-processing Model Path: {self.memx_post_model}"
|
||||
)
|
||||
|
||||
# Cleanup: Remove the ZIP file after extraction
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to prepare model: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Cleanup: Remove the ZIP file after extraction
|
||||
if os.path.exists(zip_path):
|
||||
os.remove(zip_path)
|
||||
logger.info("Cleaned up ZIP file after extraction.")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user