Move ZMQ detector to onnx runner

This commit is contained in:
Nicolas Mowen 2025-09-22 12:55:36 -06:00
parent fbcf64d7bd
commit 4ab8de91a9
2 changed files with 172 additions and 198 deletions

View File

@ -1,5 +1,6 @@
"""Base runner implementation for ONNX models."""
import json
import logging
import os
from abc import ABC, abstractmethod
@ -7,7 +8,9 @@ from typing import Any
import numpy as np
import onnxruntime as ort
import zmq
from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT
from frigate.util.model import get_ort_providers
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible
@ -301,6 +304,163 @@ class OpenVINOModelRunner(BaseModelRunner):
return outputs
class ZmqIpcRunner(BaseModelRunner):
"""Runner that forwards inference over ZMQ REQ/ROUTER to backend workers.
This allows reusing the same interface as local runners while delegating
inference to the external ZMQ workers.
"""
def __init__(
self,
model_path: str,
model_type: str,
request_timeout_ms: int = 200,
linger_ms: int = 0,
endpoint: str = REQ_ROUTER_ENDPOINT,
):
self.model_type = model_type
self.model_path = model_path
self.model_name = os.path.basename(model_path)
self._endpoint = endpoint
self._context = zmq.Context()
self._socket = self._context.socket(zmq.REQ)
self._socket.setsockopt(zmq.RCVTIMEO, request_timeout_ms)
self._socket.setsockopt(zmq.SNDTIMEO, request_timeout_ms)
self._socket.setsockopt(zmq.LINGER, linger_ms)
self._socket.connect(self._endpoint)
self._model_ready = False
def get_input_names(self) -> list[str]:
return ["input"]
def get_input_width(self) -> int:
# Not known/required for ZMQ forwarding
return -1
def _build_header(self, tensor_input: np.ndarray) -> bytes:
header: dict[str, object] = {
"shape": list(tensor_input.shape),
"dtype": str(tensor_input.dtype.name),
"model_type": self.model_type,
"model_name": self.model_name,
}
return json.dumps(header).encode("utf-8")
def _decode_response(self, frames: list[bytes]) -> np.ndarray:
if len(frames) == 1:
buf = frames[0]
if len(buf) != 20 * 6 * 4:
raise ValueError(f"Unexpected payload size: {len(buf)}")
return np.frombuffer(buf, dtype=np.float32).reshape((20, 6))
if len(frames) >= 2:
header = json.loads(frames[0].decode("utf-8"))
shape = tuple(header.get("shape", []))
dtype = np.dtype(header.get("dtype", "float32"))
return np.frombuffer(frames[1], dtype=dtype).reshape(shape)
raise ValueError("Empty or malformed reply from ZMQ detector")
def run(self, input: dict[str, np.ndarray]) -> np.ndarray | None:
if not self._model_ready:
if not self.ensure_model_ready(self.model_path):
raise TimeoutError("ZMQ detector model is not ready after transfer")
self._model_ready = True
input_name = next(iter(input))
tensor = input[input_name]
header = self._build_header(tensor)
payload = memoryview(tensor.tobytes(order="C"))
try:
self._socket.send_multipart([header, payload])
frames = self._socket.recv_multipart()
except zmq.Again as e:
raise TimeoutError("ZMQ detector request timed out") from e
except zmq.ZMQError as e:
raise RuntimeError(f"ZMQ error: {e}") from e
return self._decode_response(frames)
def ensure_model_ready(self, model_path: str) -> bool:
"""Ensure the remote has the model and it is loaded.
1) Send model_request with model_name
2) If not available, send model_data with the file contents
3) Wait for loaded confirmation
Returns True on success.
"""
# Check model availability
req = {"model_request": True, "model_name": self.model_name}
self._socket.send_multipart([json.dumps(req).encode("utf-8")])
# Temporarily extend timeout for model ops
original_rcv = self._socket.getsockopt(zmq.RCVTIMEO)
try:
self._socket.setsockopt(zmq.RCVTIMEO, max(30000, int(original_rcv or 0)))
resp_frames = self._socket.recv_multipart()
except zmq.Again:
self._socket.setsockopt(zmq.RCVTIMEO, original_rcv)
return False
finally:
try:
self._socket.setsockopt(zmq.RCVTIMEO, original_rcv)
except Exception:
pass
try:
if len(resp_frames) != 1:
return False
resp = json.loads(resp_frames[0].decode("utf-8"))
if resp.get("model_available") and resp.get("model_loaded"):
logger.info(f"ZMQ detector model {self.model_name} is ready")
return True
except Exception:
return False
try:
with open(model_path, "rb") as f:
model_bytes = f.read()
except Exception:
return False
header = {"model_data": True, "model_name": self.model_name}
self._socket.send_multipart([json.dumps(header).encode("utf-8"), model_bytes])
original_rcv2 = self._socket.getsockopt(zmq.RCVTIMEO)
try:
self._socket.setsockopt(zmq.RCVTIMEO, max(30000, int(original_rcv2 or 0)))
resp2 = self._socket.recv_multipart()
except zmq.Again:
self._socket.setsockopt(zmq.RCVTIMEO, original_rcv2)
return False
finally:
try:
self._socket.setsockopt(zmq.RCVTIMEO, original_rcv2)
except Exception:
pass
try:
if len(resp2) != 1:
return False
j = json.loads(resp2[0].decode("utf-8"))
return bool(j.get("model_saved") and j.get("model_loaded"))
except Exception:
return False
def __del__(self) -> None:
try:
if self._socket is not None:
self._socket.close()
except Exception:
pass
try:
if self._context is not None:
self._context.term()
except Exception:
pass
class RKNNModelRunner(BaseModelRunner):
"""Run RKNN models for embeddings."""

View File

@ -10,6 +10,7 @@ from typing_extensions import Literal
from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT
from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detection_runners import ZmqIpcRunner
from frigate.detectors.detector_config import BaseDetectorConfig
logger = logging.getLogger(__name__)
@ -50,9 +51,7 @@ class ZmqIpcDetector(DetectionApi):
On any error or timeout, this detector returns a zero array of shape (20, 6).
Model Management:
- On initialization, sends model request to check if model is available
- If model not available, sends model data via ZMQ
- Only starts inference after model is ready
- Model transfer/availability is handled by the runner automatically
"""
type_key = DETECTOR_KEY
@ -67,15 +66,17 @@ class ZmqIpcDetector(DetectionApi):
self._socket = None
self._create_socket()
# Model management
self._model_ready = False
self._model_name = self._get_model_name()
# Initialize model if needed
self._initialize_model()
# Preallocate zero result for error paths
self._zero_result = np.zeros((20, 6), np.float32)
self._runner = ZmqIpcRunner(
model_path=self.detector_config.model.path,
model_type=str(self.detector_config.model.model_type.value),
request_timeout_ms=self._request_timeout_ms,
linger_ms=self._linger_ms,
endpoint=self._endpoint,
)
def _create_socket(self) -> None:
if self._socket is not None:
@ -97,162 +98,6 @@ class ZmqIpcDetector(DetectionApi):
model_path = self.detector_config.model.path
return os.path.basename(model_path)
def _initialize_model(self) -> None:
"""Initialize the model by checking availability and transferring if needed."""
try:
logger.info(f"Initializing model: {self._model_name}")
# Check if model is available and transfer if needed
if self._check_and_transfer_model():
logger.info(f"Model {self._model_name} is ready")
self._model_ready = True
else:
logger.error(f"Failed to initialize model {self._model_name}")
except Exception as e:
logger.error(f"Failed to initialize model: {e}")
def _check_and_transfer_model(self) -> bool:
"""Check if model is available and transfer if needed in one atomic operation."""
try:
# Send model availability request
header = {"model_request": True, "model_name": self._model_name}
header_bytes = json.dumps(header).encode("utf-8")
self._socket.send_multipart([header_bytes])
# Temporarily increase timeout for model operations
original_timeout = self._socket.getsockopt(zmq.RCVTIMEO)
self._socket.setsockopt(zmq.RCVTIMEO, 30000)
try:
response_frames = self._socket.recv_multipart()
finally:
self._socket.setsockopt(zmq.RCVTIMEO, original_timeout)
if len(response_frames) == 1:
try:
response = json.loads(response_frames[0].decode("utf-8"))
model_available = response.get("model_available", False)
model_loaded = response.get("model_loaded", False)
if model_available and model_loaded:
return True
elif model_available and not model_loaded:
logger.error("Model exists but failed to load")
return False
else:
return self._send_model_data()
except json.JSONDecodeError:
logger.warning(
"Received non-JSON response for model availability check"
)
return False
else:
logger.warning(
"Received unexpected response format for model availability check"
)
return False
except Exception as e:
logger.error(f"Failed to check and transfer model: {e}")
return False
def _check_model_availability(self) -> bool:
"""Check if the model is available on the detector."""
try:
# Send model availability request
header = {"model_request": True, "model_name": self._model_name}
header_bytes = json.dumps(header).encode("utf-8")
self._socket.send_multipart([header_bytes])
# Receive response
response_frames = self._socket.recv_multipart()
# Check if this is a JSON response (model management)
if len(response_frames) == 1:
try:
response = json.loads(response_frames[0].decode("utf-8"))
model_available = response.get("model_available", False)
model_loaded = response.get("model_loaded", False)
logger.debug(
f"Model availability check: available={model_available}, loaded={model_loaded}"
)
return model_available and model_loaded
except json.JSONDecodeError:
logger.warning(
"Received non-JSON response for model availability check"
)
return False
else:
logger.warning(
"Received unexpected response format for model availability check"
)
return False
except Exception as e:
logger.error(f"Failed to check model availability: {e}")
return False
def _send_model_data(self) -> bool:
"""Send model data to the detector."""
try:
model_path = self.detector_config.model.path
if not os.path.exists(model_path):
logger.error(f"Model file not found: {model_path}")
return False
logger.info(f"Transferring model to detector: {self._model_name}")
with open(model_path, "rb") as f:
model_data = f.read()
header = {"model_data": True, "model_name": self._model_name}
header_bytes = json.dumps(header).encode("utf-8")
self._socket.send_multipart([header_bytes, model_data])
# Temporarily increase timeout for model loading (can take several seconds)
original_timeout = self._socket.getsockopt(zmq.RCVTIMEO)
self._socket.setsockopt(zmq.RCVTIMEO, 30000)
try:
# Receive response
response_frames = self._socket.recv_multipart()
finally:
# Restore original timeout
self._socket.setsockopt(zmq.RCVTIMEO, original_timeout)
# Check if this is a JSON response (model management)
if len(response_frames) == 1:
try:
response = json.loads(response_frames[0].decode("utf-8"))
model_saved = response.get("model_saved", False)
model_loaded = response.get("model_loaded", False)
if model_saved and model_loaded:
logger.info(
f"Model {self._model_name} transferred and loaded successfully"
)
else:
logger.error(
f"Model transfer failed: saved={model_saved}, loaded={model_loaded}"
)
return model_saved and model_loaded
except json.JSONDecodeError:
logger.warning("Received non-JSON response for model data transfer")
return False
else:
logger.warning(
"Received unexpected response format for model data transfer"
)
return False
except Exception as e:
logger.error(f"Failed to send model data: {e}")
return False
def _build_header(self, tensor_input: np.ndarray) -> bytes:
header: dict[str, Any] = {
"shape": list(tensor_input.shape),
@ -287,42 +132,11 @@ class ZmqIpcDetector(DetectionApi):
return self._zero_result
def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray:
if not self._model_ready:
logger.warning("Model not ready, returning zero detections")
return self._zero_result
try:
header_bytes = self._build_header(tensor_input)
payload_bytes = memoryview(tensor_input.tobytes(order="C"))
# Send request
self._socket.send_multipart([header_bytes, payload_bytes])
# Receive reply
reply_frames = self._socket.recv_multipart()
detections = self._decode_response(reply_frames)
# Ensure output shape and dtype are exactly as expected
return detections
except zmq.Again:
# Timeout
logger.debug("ZMQ detector request timed out; resetting socket")
try:
self._create_socket()
self._initialize_model()
except Exception:
pass
return self._zero_result
except zmq.ZMQError as exc:
logger.error(f"ZMQ detector ZMQError: {exc}; resetting socket")
try:
self._create_socket()
self._initialize_model()
except Exception:
pass
return self._zero_result
result = self._runner.run({"input": tensor_input})
return result if isinstance(result, np.ndarray) else self._zero_result
except Exception as exc: # noqa: BLE001
logger.error(f"ZMQ detector unexpected error: {exc}")
logger.error(f"ZMQ IPC runner error: {exc}")
return self._zero_result
def __del__(self) -> None: # pragma: no cover - best-effort cleanup