From 318457113b3d18a988687d2ffd1bc740511a3ffe Mon Sep 17 00:00:00 2001
From: Nicolas Mowen
Date: Mon, 22 Sep 2025 06:02:55 -0600
Subject: [PATCH] Add ability to transfer model via ZMQ Detector (#20161)

* Add ability to transfer model via ZMQ

* Cleanup
---
 frigate/detectors/plugins/zmq_ipc.py | 181 ++++++++++++++++++++++++++-
 1 file changed, 180 insertions(+), 1 deletion(-)

Review notes (not part of the commit message; this text sits between the
three-dash line and the first "diff --git" line).

NOTE(review): line breaks and leading whitespace in this copy were restored
from a whitespace-collapsed version of the patch. Hunk line counts were used
to place blank lines. Indentation of context lines (for example the docstring
lines in the second hunk and the first context line of the fifth hunk) is a
best-effort reconstruction; verify against the upstream commit before
applying.

What the patch adds to ZmqIpcDetector:
  * __init__ stores _model_ready and _model_name, then calls
    _initialize_model() right after _create_socket().
  * _get_model_name() returns the basename of detector_config.model.path.
  * _initialize_model() sets _model_ready = True when
    _check_and_transfer_model() returns True and logs an error otherwise;
    any exception is logged and swallowed.
  * _check_and_transfer_model() sends one JSON frame carrying
    "model_request" and "model_name", raises RCVTIMEO to 30000 for the
    receive, and expects a single JSON frame back. It returns True when the
    reply reports both model_available and model_loaded, returns False when
    the model is available but not loaded, and otherwise delegates to
    _send_model_data().
  * _send_model_data() reads the whole model file into memory and sends
    [header_json, model_bytes]; it returns a truthy value only when the
    reply reports both model_saved and model_loaded.
  * detect_raw() returns the preallocated zero result while _model_ready is
    False, and calls _initialize_model() after _create_socket() in the
    zmq.Again handler and in the handler that logs "ZMQ detector ZMQError".

Points to confirm or follow up:
  1. _initialize_model() never sets _model_ready back to False. If a
     re-initialisation after a socket reset fails, the flag keeps whatever
     value it had before the reset.
  2. While _model_ready is False, detect_raw() returns before anything is
     sent, so the two handlers that call _initialize_model() again cannot be
     reached from that state. Nothing in this patch retries a failed initial
     setup.
  3. That early return logs at warning level on every detect_raw() call
     made while the model is not ready.
  4. _check_model_availability() repeats the request/parse logic of
     _check_and_transfer_model() (without the raised receive timeout) and
     has no caller in this patch.
  5. The "read RCVTIMEO, set 30000, receive, restore" sequence appears twice
     (_check_and_transfer_model and _send_model_data) and is a candidate for
     a shared helper.
  6. Because __init__ calls _initialize_model(), construction can wait on
     the 30000 ms receive for the availability reply and again for the
     transfer reply.

diff --git a/frigate/detectors/plugins/zmq_ipc.py b/frigate/detectors/plugins/zmq_ipc.py
index aa7113c72..cd397aefa 100644
--- a/frigate/detectors/plugins/zmq_ipc.py
+++ b/frigate/detectors/plugins/zmq_ipc.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import os
 from typing import Any, List
 
 import numpy as np
@@ -46,6 +47,11 @@ class ZmqIpcDetector(DetectionApi):
       b) Single frame tensor_bytes of length 20*6*4 bytes (float32).
 
     On any error or timeout, this detector returns a zero array of shape (20, 6).
+
+    Model Management:
+    - On initialization, sends model request to check if model is available
+    - If model not available, sends model data via ZMQ
+    - Only starts inference after model is ready
     """
 
     type_key = DETECTOR_KEY
@@ -60,6 +66,13 @@ class ZmqIpcDetector(DetectionApi):
         self._socket = None
         self._create_socket()
 
+        # Model management
+        self._model_ready = False
+        self._model_name = self._get_model_name()
+
+        # Initialize model if needed
+        self._initialize_model()
+
         # Preallocate zero result for error paths
         self._zero_result = np.zeros((20, 6), np.float32)
 
@@ -78,6 +91,167 @@ class ZmqIpcDetector(DetectionApi):
         logger.debug(f"ZMQ detector connecting to {self._endpoint}")
         self._socket.connect(self._endpoint)
 
+    def _get_model_name(self) -> str:
+        """Get the model filename from the detector config."""
+        model_path = self.detector_config.model.path
+        return os.path.basename(model_path)
+
+    def _initialize_model(self) -> None:
+        """Initialize the model by checking availability and transferring if needed."""
+        try:
+            logger.info(f"Initializing model: {self._model_name}")
+
+            # Check if model is available and transfer if needed
+            if self._check_and_transfer_model():
+                logger.info(f"Model {self._model_name} is ready")
+                self._model_ready = True
+            else:
+                logger.error(f"Failed to initialize model {self._model_name}")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize model: {e}")
+
+    def _check_and_transfer_model(self) -> bool:
+        """Check if model is available and transfer if needed in one atomic operation."""
+        try:
+            # Send model availability request
+            header = {"model_request": True, "model_name": self._model_name}
+            header_bytes = json.dumps(header).encode("utf-8")
+
+            self._socket.send_multipart([header_bytes])
+
+            # Temporarily increase timeout for model operations
+            original_timeout = self._socket.getsockopt(zmq.RCVTIMEO)
+            self._socket.setsockopt(zmq.RCVTIMEO, 30000)
+
+            try:
+                response_frames = self._socket.recv_multipart()
+            finally:
+                self._socket.setsockopt(zmq.RCVTIMEO, original_timeout)
+
+            if len(response_frames) == 1:
+                try:
+                    response = json.loads(response_frames[0].decode("utf-8"))
+                    model_available = response.get("model_available", False)
+                    model_loaded = response.get("model_loaded", False)
+
+                    if model_available and model_loaded:
+                        return True
+                    elif model_available and not model_loaded:
+                        logger.error("Model exists but failed to load")
+                        return False
+                    else:
+                        return self._send_model_data()
+
+                except json.JSONDecodeError:
+                    logger.warning(
+                        "Received non-JSON response for model availability check"
+                    )
+                    return False
+            else:
+                logger.warning(
+                    "Received unexpected response format for model availability check"
+                )
+                return False
+
+        except Exception as e:
+            logger.error(f"Failed to check and transfer model: {e}")
+            return False
+
+    def _check_model_availability(self) -> bool:
+        """Check if the model is available on the detector."""
+        try:
+            # Send model availability request
+            header = {"model_request": True, "model_name": self._model_name}
+            header_bytes = json.dumps(header).encode("utf-8")
+
+            self._socket.send_multipart([header_bytes])
+
+            # Receive response
+            response_frames = self._socket.recv_multipart()
+
+            # Check if this is a JSON response (model management)
+            if len(response_frames) == 1:
+                try:
+                    response = json.loads(response_frames[0].decode("utf-8"))
+                    model_available = response.get("model_available", False)
+                    model_loaded = response.get("model_loaded", False)
+                    logger.debug(
+                        f"Model availability check: available={model_available}, loaded={model_loaded}"
+                    )
+                    return model_available and model_loaded
+                except json.JSONDecodeError:
+                    logger.warning(
+                        "Received non-JSON response for model availability check"
+                    )
+                    return False
+            else:
+                logger.warning(
+                    "Received unexpected response format for model availability check"
+                )
+                return False
+
+        except Exception as e:
+            logger.error(f"Failed to check model availability: {e}")
+            return False
+
+    def _send_model_data(self) -> bool:
+        """Send model data to the detector."""
+        try:
+            model_path = self.detector_config.model.path
+
+            if not os.path.exists(model_path):
+                logger.error(f"Model file not found: {model_path}")
+                return False
+
+            logger.info(f"Transferring model to detector: {self._model_name}")
+            with open(model_path, "rb") as f:
+                model_data = f.read()
+
+            header = {"model_data": True, "model_name": self._model_name}
+            header_bytes = json.dumps(header).encode("utf-8")
+
+            self._socket.send_multipart([header_bytes, model_data])
+
+            # Temporarily increase timeout for model loading (can take several seconds)
+            original_timeout = self._socket.getsockopt(zmq.RCVTIMEO)
+            self._socket.setsockopt(zmq.RCVTIMEO, 30000)
+
+            try:
+                # Receive response
+                response_frames = self._socket.recv_multipart()
+            finally:
+                # Restore original timeout
+                self._socket.setsockopt(zmq.RCVTIMEO, original_timeout)
+
+            # Check if this is a JSON response (model management)
+            if len(response_frames) == 1:
+                try:
+                    response = json.loads(response_frames[0].decode("utf-8"))
+                    model_saved = response.get("model_saved", False)
+                    model_loaded = response.get("model_loaded", False)
+                    if model_saved and model_loaded:
+                        logger.info(
+                            f"Model {self._model_name} transferred and loaded successfully"
+                        )
+                    else:
+                        logger.error(
+                            f"Model transfer failed: saved={model_saved}, loaded={model_loaded}"
+                        )
+                    return model_saved and model_loaded
+                except json.JSONDecodeError:
+                    logger.warning("Received non-JSON response for model data transfer")
+                    return False
+            else:
+                logger.warning(
+                    "Received unexpected response format for model data transfer"
+                )
+                return False
+
+        except Exception as e:
+            logger.error(f"Failed to send model data: {e}")
+            return False
+
     def _build_header(self, tensor_input: np.ndarray) -> bytes:
         header: dict[str, Any] = {
             "shape": list(tensor_input.shape),
@@ -111,6 +285,10 @@ class ZmqIpcDetector(DetectionApi):
             return self._zero_result
 
     def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray:
+        if not self._model_ready:
+            logger.warning("Model not ready, returning zero detections")
+            return self._zero_result
+
         try:
             header_bytes = self._build_header(tensor_input)
             payload_bytes = memoryview(tensor_input.tobytes(order="C"))
@@ -123,13 +301,13 @@ class ZmqIpcDetector(DetectionApi):
 
             detections = self._decode_response(reply_frames)
             # Ensure output shape and dtype are exactly as expected
-
             return detections
         except zmq.Again:
             # Timeout
             logger.debug("ZMQ detector request timed out; resetting socket")
             try:
                 self._create_socket()
+                self._initialize_model()
             except Exception:
                 pass
             return self._zero_result
@@ -137,6 +315,7 @@ class ZmqIpcDetector(DetectionApi):
             logger.error(f"ZMQ detector ZMQError: {exc}; resetting socket")
             try:
                 self._create_socket()
+                self._initialize_model()
             except Exception:
                 pass
             return self._zero_result