mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-09-23 17:52:05 +02:00
Move ZMQ detector to onnx runner
This commit is contained in:
parent
fbcf64d7bd
commit
4ab8de91a9
@ -1,5 +1,6 @@
|
||||
"""Base runner implementation for ONNX models."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
@ -7,7 +8,9 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import zmq
|
||||
|
||||
from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT
|
||||
from frigate.util.model import get_ort_providers
|
||||
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible
|
||||
|
||||
@ -301,6 +304,163 @@ class OpenVINOModelRunner(BaseModelRunner):
|
||||
return outputs
|
||||
|
||||
|
||||
class ZmqIpcRunner(BaseModelRunner):
    """Runner that forwards inference over ZMQ REQ/ROUTER to backend workers.

    This allows reusing the same interface as local runners while delegating
    inference to the external ZMQ workers.

    Each inference request is a two-frame message: a UTF-8 JSON header
    (``shape``, ``dtype``, ``model_type``, ``model_name``) followed by the raw
    C-ordered tensor bytes. Before the first inference, and again after any
    socket failure, the runner performs a model handshake and uploads the
    model file when the remote does not report it as available and loaded.

    A REQ socket enforces strict send/recv alternation, so a request that
    times out or fails leaves the socket unable to send again. Every failed
    exchange therefore replaces the socket with a fresh one.
    """

    # Minimum receive timeout (ms) while the remote checks, stores or loads a model.
    _MODEL_OP_TIMEOUT_MS = 30000
    # Shape assumed for a bare single-frame float32 reply.
    _DEFAULT_RESULT_SHAPE = (20, 6)

    def __init__(
        self,
        model_path: str,
        model_type: str,
        request_timeout_ms: int = 200,
        linger_ms: int = 0,
        endpoint: str = REQ_ROUTER_ENDPOINT,
    ):
        """Create the ZMQ context and connect a REQ socket to ``endpoint``.

        Args:
            model_path: Local path of the model file; its basename is sent to
                the remote as the model name.
            model_type: Model type string included in every request header.
            request_timeout_ms: Send/receive timeout for inference requests.
            linger_ms: LINGER value applied to the socket.
            endpoint: ZMQ endpoint to connect to.
        """
        self.model_type = model_type
        self.model_path = model_path
        self.model_name = os.path.basename(model_path)
        self._endpoint = endpoint
        # Kept so the socket can be rebuilt with identical options after a failure.
        self._request_timeout_ms = request_timeout_ms
        self._linger_ms = linger_ms
        self._socket = None
        self._model_ready = False
        self._context = zmq.Context()
        self._open_socket()

    def _open_socket(self) -> None:
        """Close any existing socket and connect a freshly configured REQ socket."""
        if self._socket is not None:
            try:
                self._socket.close()
            except zmq.ZMQError:
                pass
            self._socket = None

        sock = self._context.socket(zmq.REQ)
        try:
            sock.setsockopt(zmq.RCVTIMEO, self._request_timeout_ms)
            sock.setsockopt(zmq.SNDTIMEO, self._request_timeout_ms)
            sock.setsockopt(zmq.LINGER, self._linger_ms)
            sock.connect(self._endpoint)
        except zmq.ZMQError:
            # Do not leave a half-configured socket behind.
            sock.close(linger=0)
            raise
        self._socket = sock

    def _reset_connection(self) -> None:
        """Replace the socket after a failed exchange and force a new handshake."""
        self._model_ready = False
        try:
            self._open_socket()
        except zmq.ZMQError as exc:
            # _socket stays None; the next exchange retries the reconnect.
            logger.error(f"Failed to reopen ZMQ detector socket: {exc}")

    def _exchange(
        self, frames: list, recv_timeout_ms: int | None = None
    ) -> list[bytes]:
        """Send one multipart request and return the reply frames.

        Args:
            frames: Message frames to send.
            recv_timeout_ms: Optional receive timeout used for this exchange
                only; the regular request timeout is restored on success.

        Raises:
            zmq.Again: The send or receive timed out.
            zmq.ZMQError: Any other ZMQ failure.

        On any ZMQ error the socket is recreated before the error is re-raised,
        so the next call starts from a clean REQ state.
        """
        try:
            if self._socket is None:
                self._open_socket()
            if recv_timeout_ms is not None:
                self._socket.setsockopt(zmq.RCVTIMEO, recv_timeout_ms)
            self._socket.send_multipart(frames)
            reply = self._socket.recv_multipart()
            if recv_timeout_ms is not None:
                self._socket.setsockopt(zmq.RCVTIMEO, self._request_timeout_ms)
            return reply
        except zmq.ZMQError:
            self._reset_connection()
            raise

    @staticmethod
    def _parse_json_reply(frames: list[bytes]) -> dict | None:
        """Return the JSON object of a single-frame reply, or None if malformed."""
        if len(frames) != 1:
            return None
        try:
            reply = json.loads(frames[0].decode("utf-8"))
        except ValueError:
            # Covers both UnicodeDecodeError and json.JSONDecodeError.
            return None
        return reply if isinstance(reply, dict) else None

    def get_input_names(self) -> list[str]:
        """Return the single fixed input name used for forwarded tensors."""
        return ["input"]

    def get_input_width(self) -> int:
        """Return -1; the width is not known/required for ZMQ forwarding."""
        return -1

    def _build_header(self, tensor_input: np.ndarray) -> bytes:
        """Serialize the request header describing ``tensor_input`` as UTF-8 JSON."""
        header: dict[str, object] = {
            "shape": list(tensor_input.shape),
            "dtype": str(tensor_input.dtype.name),
            "model_type": self.model_type,
            "model_name": self.model_name,
        }
        return json.dumps(header).encode("utf-8")

    def _decode_response(self, frames: list[bytes]) -> np.ndarray:
        """Convert reply frames into an array.

        A single frame must hold exactly 20 x 6 float32 values. With two or
        more frames, the first is a JSON header giving ``shape`` and ``dtype``
        (default ``float32``) of the raw bytes in the second.

        Raises:
            ValueError: The reply is empty or a single frame has the wrong size.
        """
        if len(frames) == 1:
            buf = frames[0]
            rows, cols = self._DEFAULT_RESULT_SHAPE
            if len(buf) != rows * cols * np.dtype(np.float32).itemsize:
                raise ValueError(f"Unexpected payload size: {len(buf)}")
            return np.frombuffer(buf, dtype=np.float32).reshape(
                self._DEFAULT_RESULT_SHAPE
            )

        if len(frames) >= 2:
            header = json.loads(frames[0].decode("utf-8"))
            shape = tuple(header.get("shape", []))
            dtype = np.dtype(header.get("dtype", "float32"))
            return np.frombuffer(frames[1], dtype=dtype).reshape(shape)

        raise ValueError("Empty or malformed reply from ZMQ detector")

    def run(self, input: dict[str, np.ndarray]) -> np.ndarray | None:
        """Forward the first tensor in ``input`` to the remote and return its reply.

        Raises:
            ValueError: ``input`` is empty or the reply cannot be decoded.
            TimeoutError: The model handshake failed or the request timed out.
            RuntimeError: Any other ZMQ error occurred during the request.
        """
        if not self._model_ready:
            if not self.ensure_model_ready(self.model_path):
                raise TimeoutError("ZMQ detector model is not ready after transfer")
            self._model_ready = True

        if not input:
            raise ValueError("ZMQ detector received no input tensor")
        tensor = next(iter(input.values()))

        header = self._build_header(tensor)
        payload = memoryview(tensor.tobytes(order="C"))
        try:
            frames = self._exchange([header, payload])
        except zmq.Again as e:
            raise TimeoutError("ZMQ detector request timed out") from e
        except zmq.ZMQError as e:
            raise RuntimeError(f"ZMQ error: {e}") from e

        return self._decode_response(frames)

    def ensure_model_ready(self, model_path: str) -> bool:
        """Ensure the remote has the model and it is loaded.

        1) Send model_request with model_name
        2) If not available, send model_data with the file contents
        3) Wait for loaded confirmation

        Both exchanges use a receive timeout of at least 30 seconds.

        Returns:
            True on success; False on timeout, malformed reply, unreadable
            model file, or when the remote does not confirm the load.

        Raises:
            zmq.ZMQError: A non-timeout ZMQ failure occurred.
        """
        timeout_ms = max(
            self._MODEL_OP_TIMEOUT_MS, int(self._request_timeout_ms or 0)
        )

        # Check model availability
        req = {"model_request": True, "model_name": self.model_name}
        try:
            resp_frames = self._exchange(
                [json.dumps(req).encode("utf-8")], recv_timeout_ms=timeout_ms
            )
        except zmq.Again:
            return False

        resp = self._parse_json_reply(resp_frames)
        if resp is None:
            return False
        if resp.get("model_available") and resp.get("model_loaded"):
            logger.info(f"ZMQ detector model {self.model_name} is ready")
            return True

        # Remote does not report the model as available and loaded: upload it.
        try:
            with open(model_path, "rb") as f:
                model_bytes = f.read()
        except (OSError, TypeError, ValueError) as exc:
            logger.error(f"Unable to read model file {model_path}: {exc}")
            return False

        header = {"model_data": True, "model_name": self.model_name}
        try:
            resp2 = self._exchange(
                [json.dumps(header).encode("utf-8"), model_bytes],
                recv_timeout_ms=timeout_ms,
            )
        except zmq.Again:
            return False

        result = self._parse_json_reply(resp2)
        if result is None:
            return False
        return bool(result.get("model_saved") and result.get("model_loaded"))

    def __del__(self) -> None:
        """Best-effort release of the socket and context."""
        # Attributes may be missing if __init__ raised part-way through.
        sock = getattr(self, "_socket", None)
        if sock is not None:
            try:
                sock.close()
            except Exception:
                pass
        context = getattr(self, "_context", None)
        if context is not None:
            try:
                context.term()
            except Exception:
                pass
|
||||
|
||||
|
||||
class RKNNModelRunner(BaseModelRunner):
|
||||
"""Run RKNN models for embeddings."""
|
||||
|
||||
|
@ -10,6 +10,7 @@ from typing_extensions import Literal
|
||||
|
||||
from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT
|
||||
from frigate.detectors.detection_api import DetectionApi
|
||||
from frigate.detectors.detection_runners import ZmqIpcRunner
|
||||
from frigate.detectors.detector_config import BaseDetectorConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -50,9 +51,7 @@ class ZmqIpcDetector(DetectionApi):
|
||||
On any error or timeout, this detector returns a zero array of shape (20, 6).
|
||||
|
||||
Model Management:
|
||||
- On initialization, sends model request to check if model is available
|
||||
- If model not available, sends model data via ZMQ
|
||||
- Only starts inference after model is ready
|
||||
- Model transfer/availability is handled by the runner automatically
|
||||
"""
|
||||
|
||||
type_key = DETECTOR_KEY
|
||||
@ -67,15 +66,17 @@ class ZmqIpcDetector(DetectionApi):
|
||||
self._socket = None
|
||||
self._create_socket()
|
||||
|
||||
# Model management
|
||||
self._model_ready = False
|
||||
self._model_name = self._get_model_name()
|
||||
|
||||
# Initialize model if needed
|
||||
self._initialize_model()
|
||||
|
||||
# Preallocate zero result for error paths
|
||||
self._zero_result = np.zeros((20, 6), np.float32)
|
||||
self._runner = ZmqIpcRunner(
|
||||
model_path=self.detector_config.model.path,
|
||||
model_type=str(self.detector_config.model.model_type.value),
|
||||
request_timeout_ms=self._request_timeout_ms,
|
||||
linger_ms=self._linger_ms,
|
||||
endpoint=self._endpoint,
|
||||
)
|
||||
|
||||
def _create_socket(self) -> None:
|
||||
if self._socket is not None:
|
||||
@ -97,162 +98,6 @@ class ZmqIpcDetector(DetectionApi):
|
||||
model_path = self.detector_config.model.path
|
||||
return os.path.basename(model_path)
|
||||
|
||||
def _initialize_model(self) -> None:
    """Run the model check/transfer and mark the model ready on success.

    Sets ``self._model_ready`` to True only when
    ``_check_and_transfer_model`` succeeds; the flag is left untouched
    otherwise. Never raises: any exception is logged as an error.
    """
    try:
        logger.info(f"Initializing model: {self._model_name}")

        # Availability check and (if needed) transfer happen in one call.
        succeeded = self._check_and_transfer_model()
        if not succeeded:
            logger.error(f"Failed to initialize model {self._model_name}")
            return

        logger.info(f"Model {self._model_name} is ready")
        self._model_ready = True
    except Exception as e:
        logger.error(f"Failed to initialize model: {e}")
|
||||
|
||||
def _check_and_transfer_model(self) -> bool:
    """Ask the remote whether the model is usable and upload it if it is absent.

    Sends a ``model_request`` message and waits up to 30 seconds for the
    reply, restoring the previous receive timeout afterwards.

    Returns:
        True when the remote reports the model as available and loaded;
        the result of ``_send_model_data`` when it is not available;
        False when it is available but not loaded, when the reply is
        malformed, or when any exception occurs (logged).
    """
    try:
        request = json.dumps(
            {"model_request": True, "model_name": self._model_name}
        ).encode("utf-8")
        self._socket.send_multipart([request])

        # Model operations can be slow, so widen the receive timeout for this reply.
        saved_timeout = self._socket.getsockopt(zmq.RCVTIMEO)
        self._socket.setsockopt(zmq.RCVTIMEO, 30000)
        try:
            frames = self._socket.recv_multipart()
        finally:
            self._socket.setsockopt(zmq.RCVTIMEO, saved_timeout)

        if len(frames) != 1:
            logger.warning(
                "Received unexpected response format for model availability check"
            )
            return False

        try:
            status = json.loads(frames[0].decode("utf-8"))
        except json.JSONDecodeError:
            logger.warning("Received non-JSON response for model availability check")
            return False

        available = status.get("model_available", False)
        loaded = status.get("model_loaded", False)

        if not available:
            return self._send_model_data()
        if loaded:
            return True

        logger.error("Model exists but failed to load")
        return False
    except Exception as e:
        logger.error(f"Failed to check and transfer model: {e}")
        return False
|
||||
|
||||
def _check_model_availability(self) -> bool:
    """Query the remote for the model's status without transferring anything.

    Returns:
        The truthiness-combined ``model_available and model_loaded`` values
        from the reply; False when the reply is not a single JSON frame or
        when any exception occurs (logged).
    """
    try:
        request = json.dumps(
            {"model_request": True, "model_name": self._model_name}
        ).encode("utf-8")
        self._socket.send_multipart([request])

        frames = self._socket.recv_multipart()

        # Model-management replies are expected to be one JSON frame.
        if len(frames) != 1:
            logger.warning(
                "Received unexpected response format for model availability check"
            )
            return False

        try:
            status = json.loads(frames[0].decode("utf-8"))
        except json.JSONDecodeError:
            logger.warning("Received non-JSON response for model availability check")
            return False

        model_available = status.get("model_available", False)
        model_loaded = status.get("model_loaded", False)
        logger.debug(
            f"Model availability check: available={model_available}, loaded={model_loaded}"
        )
        return model_available and model_loaded
    except Exception as e:
        logger.error(f"Failed to check model availability: {e}")
        return False
|
||||
|
||||
def _send_model_data(self) -> bool:
    """Upload the configured model file to the remote and wait for the load result.

    Reads the file at ``self.detector_config.model.path``, sends it as a
    ``model_data`` message and waits up to 30 seconds for the reply,
    restoring the previous receive timeout afterwards.

    Returns:
        The combined ``model_saved and model_loaded`` values from the reply;
        False when the file is missing, the reply is malformed, or any
        exception occurs (logged).
    """
    try:
        model_path = self.detector_config.model.path
        if not os.path.exists(model_path):
            logger.error(f"Model file not found: {model_path}")
            return False

        logger.info(f"Transferring model to detector: {self._model_name}")
        with open(model_path, "rb") as model_file:
            blob = model_file.read()

        announce = json.dumps(
            {"model_data": True, "model_name": self._model_name}
        ).encode("utf-8")
        self._socket.send_multipart([announce, blob])

        # Loading can take several seconds, so widen the receive timeout for this reply.
        saved_timeout = self._socket.getsockopt(zmq.RCVTIMEO)
        self._socket.setsockopt(zmq.RCVTIMEO, 30000)
        try:
            frames = self._socket.recv_multipart()
        finally:
            self._socket.setsockopt(zmq.RCVTIMEO, saved_timeout)

        # Model-management replies are expected to be one JSON frame.
        if len(frames) != 1:
            logger.warning(
                "Received unexpected response format for model data transfer"
            )
            return False

        try:
            outcome = json.loads(frames[0].decode("utf-8"))
        except json.JSONDecodeError:
            logger.warning("Received non-JSON response for model data transfer")
            return False

        model_saved = outcome.get("model_saved", False)
        model_loaded = outcome.get("model_loaded", False)
        if not (model_saved and model_loaded):
            logger.error(
                f"Model transfer failed: saved={model_saved}, loaded={model_loaded}"
            )
        else:
            logger.info(
                f"Model {self._model_name} transferred and loaded successfully"
            )
        return model_saved and model_loaded
    except Exception as e:
        logger.error(f"Failed to send model data: {e}")
        return False
|
||||
|
||||
def _build_header(self, tensor_input: np.ndarray) -> bytes:
|
||||
header: dict[str, Any] = {
|
||||
"shape": list(tensor_input.shape),
|
||||
@ -287,42 +132,11 @@ class ZmqIpcDetector(DetectionApi):
|
||||
return self._zero_result
|
||||
|
||||
def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray:
    """Run detection on ``tensor_input`` through the ZMQ runner.

    Delegates to ``self._runner.run`` with the tensor under the ``"input"``
    key. Model availability and transfer are handled by the runner.

    Returns:
        The runner's result when it is an ``np.ndarray``; otherwise, or when
        the runner raises, the preallocated ``self._zero_result``.
    """
    try:
        result = self._runner.run({"input": tensor_input})
        return result if isinstance(result, np.ndarray) else self._zero_result
    except Exception as exc:  # noqa: BLE001
        # Detection must never raise; fall back to "no detections".
        logger.error(f"ZMQ IPC runner error: {exc}")
        return self._zero_result
|
||||
|
||||
def __del__(self) -> None: # pragma: no cover - best-effort cleanup
|
||||
|
Loading…
Reference in New Issue
Block a user