mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-09-23 17:52:05 +02:00
Move ZMQ detector to onnx runner
This commit is contained in:
parent
fbcf64d7bd
commit
4ab8de91a9
@ -1,5 +1,6 @@
|
|||||||
"""Base runner implementation for ONNX models."""
|
"""Base runner implementation for ONNX models."""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@ -7,7 +8,9 @@ from typing import Any
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT
|
||||||
from frigate.util.model import get_ort_providers
|
from frigate.util.model import get_ort_providers
|
||||||
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible
|
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible
|
||||||
|
|
||||||
@ -301,6 +304,163 @@ class OpenVINOModelRunner(BaseModelRunner):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class ZmqIpcRunner(BaseModelRunner):
|
||||||
|
"""Runner that forwards inference over ZMQ REQ/ROUTER to backend workers.
|
||||||
|
|
||||||
|
This allows reusing the same interface as local runners while delegating
|
||||||
|
inference to the external ZMQ workers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path: str,
|
||||||
|
model_type: str,
|
||||||
|
request_timeout_ms: int = 200,
|
||||||
|
linger_ms: int = 0,
|
||||||
|
endpoint: str = REQ_ROUTER_ENDPOINT,
|
||||||
|
):
|
||||||
|
self.model_type = model_type
|
||||||
|
self.model_path = model_path
|
||||||
|
self.model_name = os.path.basename(model_path)
|
||||||
|
self._endpoint = endpoint
|
||||||
|
self._context = zmq.Context()
|
||||||
|
self._socket = self._context.socket(zmq.REQ)
|
||||||
|
self._socket.setsockopt(zmq.RCVTIMEO, request_timeout_ms)
|
||||||
|
self._socket.setsockopt(zmq.SNDTIMEO, request_timeout_ms)
|
||||||
|
self._socket.setsockopt(zmq.LINGER, linger_ms)
|
||||||
|
self._socket.connect(self._endpoint)
|
||||||
|
self._model_ready = False
|
||||||
|
|
||||||
|
def get_input_names(self) -> list[str]:
|
||||||
|
return ["input"]
|
||||||
|
|
||||||
|
def get_input_width(self) -> int:
|
||||||
|
# Not known/required for ZMQ forwarding
|
||||||
|
return -1
|
||||||
|
|
||||||
|
def _build_header(self, tensor_input: np.ndarray) -> bytes:
|
||||||
|
header: dict[str, object] = {
|
||||||
|
"shape": list(tensor_input.shape),
|
||||||
|
"dtype": str(tensor_input.dtype.name),
|
||||||
|
"model_type": self.model_type,
|
||||||
|
"model_name": self.model_name,
|
||||||
|
}
|
||||||
|
return json.dumps(header).encode("utf-8")
|
||||||
|
|
||||||
|
def _decode_response(self, frames: list[bytes]) -> np.ndarray:
|
||||||
|
if len(frames) == 1:
|
||||||
|
buf = frames[0]
|
||||||
|
if len(buf) != 20 * 6 * 4:
|
||||||
|
raise ValueError(f"Unexpected payload size: {len(buf)}")
|
||||||
|
return np.frombuffer(buf, dtype=np.float32).reshape((20, 6))
|
||||||
|
|
||||||
|
if len(frames) >= 2:
|
||||||
|
header = json.loads(frames[0].decode("utf-8"))
|
||||||
|
shape = tuple(header.get("shape", []))
|
||||||
|
dtype = np.dtype(header.get("dtype", "float32"))
|
||||||
|
return np.frombuffer(frames[1], dtype=dtype).reshape(shape)
|
||||||
|
|
||||||
|
raise ValueError("Empty or malformed reply from ZMQ detector")
|
||||||
|
|
||||||
|
def run(self, input: dict[str, np.ndarray]) -> np.ndarray | None:
|
||||||
|
if not self._model_ready:
|
||||||
|
if not self.ensure_model_ready(self.model_path):
|
||||||
|
raise TimeoutError("ZMQ detector model is not ready after transfer")
|
||||||
|
self._model_ready = True
|
||||||
|
|
||||||
|
input_name = next(iter(input))
|
||||||
|
tensor = input[input_name]
|
||||||
|
header = self._build_header(tensor)
|
||||||
|
payload = memoryview(tensor.tobytes(order="C"))
|
||||||
|
try:
|
||||||
|
self._socket.send_multipart([header, payload])
|
||||||
|
frames = self._socket.recv_multipart()
|
||||||
|
except zmq.Again as e:
|
||||||
|
raise TimeoutError("ZMQ detector request timed out") from e
|
||||||
|
except zmq.ZMQError as e:
|
||||||
|
raise RuntimeError(f"ZMQ error: {e}") from e
|
||||||
|
|
||||||
|
return self._decode_response(frames)
|
||||||
|
|
||||||
|
def ensure_model_ready(self, model_path: str) -> bool:
|
||||||
|
"""Ensure the remote has the model and it is loaded.
|
||||||
|
|
||||||
|
1) Send model_request with model_name
|
||||||
|
2) If not available, send model_data with the file contents
|
||||||
|
3) Wait for loaded confirmation
|
||||||
|
Returns True on success.
|
||||||
|
"""
|
||||||
|
# Check model availability
|
||||||
|
req = {"model_request": True, "model_name": self.model_name}
|
||||||
|
self._socket.send_multipart([json.dumps(req).encode("utf-8")])
|
||||||
|
|
||||||
|
# Temporarily extend timeout for model ops
|
||||||
|
original_rcv = self._socket.getsockopt(zmq.RCVTIMEO)
|
||||||
|
try:
|
||||||
|
self._socket.setsockopt(zmq.RCVTIMEO, max(30000, int(original_rcv or 0)))
|
||||||
|
resp_frames = self._socket.recv_multipart()
|
||||||
|
except zmq.Again:
|
||||||
|
self._socket.setsockopt(zmq.RCVTIMEO, original_rcv)
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
self._socket.setsockopt(zmq.RCVTIMEO, original_rcv)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
if len(resp_frames) != 1:
|
||||||
|
return False
|
||||||
|
resp = json.loads(resp_frames[0].decode("utf-8"))
|
||||||
|
if resp.get("model_available") and resp.get("model_loaded"):
|
||||||
|
logger.info(f"ZMQ detector model {self.model_name} is ready")
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(model_path, "rb") as f:
|
||||||
|
model_bytes = f.read()
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
header = {"model_data": True, "model_name": self.model_name}
|
||||||
|
self._socket.send_multipart([json.dumps(header).encode("utf-8"), model_bytes])
|
||||||
|
|
||||||
|
original_rcv2 = self._socket.getsockopt(zmq.RCVTIMEO)
|
||||||
|
try:
|
||||||
|
self._socket.setsockopt(zmq.RCVTIMEO, max(30000, int(original_rcv2 or 0)))
|
||||||
|
resp2 = self._socket.recv_multipart()
|
||||||
|
except zmq.Again:
|
||||||
|
self._socket.setsockopt(zmq.RCVTIMEO, original_rcv2)
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
self._socket.setsockopt(zmq.RCVTIMEO, original_rcv2)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
if len(resp2) != 1:
|
||||||
|
return False
|
||||||
|
j = json.loads(resp2[0].decode("utf-8"))
|
||||||
|
return bool(j.get("model_saved") and j.get("model_loaded"))
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
try:
|
||||||
|
if self._socket is not None:
|
||||||
|
self._socket.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
if self._context is not None:
|
||||||
|
self._context.term()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RKNNModelRunner(BaseModelRunner):
|
class RKNNModelRunner(BaseModelRunner):
|
||||||
"""Run RKNN models for embeddings."""
|
"""Run RKNN models for embeddings."""
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ from typing_extensions import Literal
|
|||||||
|
|
||||||
from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT
|
from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT
|
||||||
from frigate.detectors.detection_api import DetectionApi
|
from frigate.detectors.detection_api import DetectionApi
|
||||||
|
from frigate.detectors.detection_runners import ZmqIpcRunner
|
||||||
from frigate.detectors.detector_config import BaseDetectorConfig
|
from frigate.detectors.detector_config import BaseDetectorConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -50,9 +51,7 @@ class ZmqIpcDetector(DetectionApi):
|
|||||||
On any error or timeout, this detector returns a zero array of shape (20, 6).
|
On any error or timeout, this detector returns a zero array of shape (20, 6).
|
||||||
|
|
||||||
Model Management:
|
Model Management:
|
||||||
- On initialization, sends model request to check if model is available
|
- Model transfer/availability is handled by the runner automatically
|
||||||
- If model not available, sends model data via ZMQ
|
|
||||||
- Only starts inference after model is ready
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type_key = DETECTOR_KEY
|
type_key = DETECTOR_KEY
|
||||||
@ -67,15 +66,17 @@ class ZmqIpcDetector(DetectionApi):
|
|||||||
self._socket = None
|
self._socket = None
|
||||||
self._create_socket()
|
self._create_socket()
|
||||||
|
|
||||||
# Model management
|
|
||||||
self._model_ready = False
|
|
||||||
self._model_name = self._get_model_name()
|
self._model_name = self._get_model_name()
|
||||||
|
|
||||||
# Initialize model if needed
|
|
||||||
self._initialize_model()
|
|
||||||
|
|
||||||
# Preallocate zero result for error paths
|
# Preallocate zero result for error paths
|
||||||
self._zero_result = np.zeros((20, 6), np.float32)
|
self._zero_result = np.zeros((20, 6), np.float32)
|
||||||
|
self._runner = ZmqIpcRunner(
|
||||||
|
model_path=self.detector_config.model.path,
|
||||||
|
model_type=str(self.detector_config.model.model_type.value),
|
||||||
|
request_timeout_ms=self._request_timeout_ms,
|
||||||
|
linger_ms=self._linger_ms,
|
||||||
|
endpoint=self._endpoint,
|
||||||
|
)
|
||||||
|
|
||||||
def _create_socket(self) -> None:
|
def _create_socket(self) -> None:
|
||||||
if self._socket is not None:
|
if self._socket is not None:
|
||||||
@ -97,162 +98,6 @@ class ZmqIpcDetector(DetectionApi):
|
|||||||
model_path = self.detector_config.model.path
|
model_path = self.detector_config.model.path
|
||||||
return os.path.basename(model_path)
|
return os.path.basename(model_path)
|
||||||
|
|
||||||
def _initialize_model(self) -> None:
|
|
||||||
"""Initialize the model by checking availability and transferring if needed."""
|
|
||||||
try:
|
|
||||||
logger.info(f"Initializing model: {self._model_name}")
|
|
||||||
|
|
||||||
# Check if model is available and transfer if needed
|
|
||||||
if self._check_and_transfer_model():
|
|
||||||
logger.info(f"Model {self._model_name} is ready")
|
|
||||||
self._model_ready = True
|
|
||||||
else:
|
|
||||||
logger.error(f"Failed to initialize model {self._model_name}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to initialize model: {e}")
|
|
||||||
|
|
||||||
def _check_and_transfer_model(self) -> bool:
|
|
||||||
"""Check if model is available and transfer if needed in one atomic operation."""
|
|
||||||
try:
|
|
||||||
# Send model availability request
|
|
||||||
header = {"model_request": True, "model_name": self._model_name}
|
|
||||||
header_bytes = json.dumps(header).encode("utf-8")
|
|
||||||
|
|
||||||
self._socket.send_multipart([header_bytes])
|
|
||||||
|
|
||||||
# Temporarily increase timeout for model operations
|
|
||||||
original_timeout = self._socket.getsockopt(zmq.RCVTIMEO)
|
|
||||||
self._socket.setsockopt(zmq.RCVTIMEO, 30000)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response_frames = self._socket.recv_multipart()
|
|
||||||
finally:
|
|
||||||
self._socket.setsockopt(zmq.RCVTIMEO, original_timeout)
|
|
||||||
|
|
||||||
if len(response_frames) == 1:
|
|
||||||
try:
|
|
||||||
response = json.loads(response_frames[0].decode("utf-8"))
|
|
||||||
model_available = response.get("model_available", False)
|
|
||||||
model_loaded = response.get("model_loaded", False)
|
|
||||||
|
|
||||||
if model_available and model_loaded:
|
|
||||||
return True
|
|
||||||
elif model_available and not model_loaded:
|
|
||||||
logger.error("Model exists but failed to load")
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return self._send_model_data()
|
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.warning(
|
|
||||||
"Received non-JSON response for model availability check"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"Received unexpected response format for model availability check"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to check and transfer model: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _check_model_availability(self) -> bool:
|
|
||||||
"""Check if the model is available on the detector."""
|
|
||||||
try:
|
|
||||||
# Send model availability request
|
|
||||||
header = {"model_request": True, "model_name": self._model_name}
|
|
||||||
header_bytes = json.dumps(header).encode("utf-8")
|
|
||||||
|
|
||||||
self._socket.send_multipart([header_bytes])
|
|
||||||
|
|
||||||
# Receive response
|
|
||||||
response_frames = self._socket.recv_multipart()
|
|
||||||
|
|
||||||
# Check if this is a JSON response (model management)
|
|
||||||
if len(response_frames) == 1:
|
|
||||||
try:
|
|
||||||
response = json.loads(response_frames[0].decode("utf-8"))
|
|
||||||
model_available = response.get("model_available", False)
|
|
||||||
model_loaded = response.get("model_loaded", False)
|
|
||||||
logger.debug(
|
|
||||||
f"Model availability check: available={model_available}, loaded={model_loaded}"
|
|
||||||
)
|
|
||||||
return model_available and model_loaded
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.warning(
|
|
||||||
"Received non-JSON response for model availability check"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"Received unexpected response format for model availability check"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to check model availability: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _send_model_data(self) -> bool:
|
|
||||||
"""Send model data to the detector."""
|
|
||||||
try:
|
|
||||||
model_path = self.detector_config.model.path
|
|
||||||
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
logger.error(f"Model file not found: {model_path}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
logger.info(f"Transferring model to detector: {self._model_name}")
|
|
||||||
with open(model_path, "rb") as f:
|
|
||||||
model_data = f.read()
|
|
||||||
|
|
||||||
header = {"model_data": True, "model_name": self._model_name}
|
|
||||||
header_bytes = json.dumps(header).encode("utf-8")
|
|
||||||
|
|
||||||
self._socket.send_multipart([header_bytes, model_data])
|
|
||||||
|
|
||||||
# Temporarily increase timeout for model loading (can take several seconds)
|
|
||||||
original_timeout = self._socket.getsockopt(zmq.RCVTIMEO)
|
|
||||||
self._socket.setsockopt(zmq.RCVTIMEO, 30000)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Receive response
|
|
||||||
response_frames = self._socket.recv_multipart()
|
|
||||||
finally:
|
|
||||||
# Restore original timeout
|
|
||||||
self._socket.setsockopt(zmq.RCVTIMEO, original_timeout)
|
|
||||||
|
|
||||||
# Check if this is a JSON response (model management)
|
|
||||||
if len(response_frames) == 1:
|
|
||||||
try:
|
|
||||||
response = json.loads(response_frames[0].decode("utf-8"))
|
|
||||||
model_saved = response.get("model_saved", False)
|
|
||||||
model_loaded = response.get("model_loaded", False)
|
|
||||||
if model_saved and model_loaded:
|
|
||||||
logger.info(
|
|
||||||
f"Model {self._model_name} transferred and loaded successfully"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
f"Model transfer failed: saved={model_saved}, loaded={model_loaded}"
|
|
||||||
)
|
|
||||||
return model_saved and model_loaded
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.warning("Received non-JSON response for model data transfer")
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"Received unexpected response format for model data transfer"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to send model data: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _build_header(self, tensor_input: np.ndarray) -> bytes:
|
def _build_header(self, tensor_input: np.ndarray) -> bytes:
|
||||||
header: dict[str, Any] = {
|
header: dict[str, Any] = {
|
||||||
"shape": list(tensor_input.shape),
|
"shape": list(tensor_input.shape),
|
||||||
@ -287,42 +132,11 @@ class ZmqIpcDetector(DetectionApi):
|
|||||||
return self._zero_result
|
return self._zero_result
|
||||||
|
|
||||||
def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray:
|
def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray:
|
||||||
if not self._model_ready:
|
|
||||||
logger.warning("Model not ready, returning zero detections")
|
|
||||||
return self._zero_result
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
header_bytes = self._build_header(tensor_input)
|
result = self._runner.run({"input": tensor_input})
|
||||||
payload_bytes = memoryview(tensor_input.tobytes(order="C"))
|
return result if isinstance(result, np.ndarray) else self._zero_result
|
||||||
|
|
||||||
# Send request
|
|
||||||
self._socket.send_multipart([header_bytes, payload_bytes])
|
|
||||||
|
|
||||||
# Receive reply
|
|
||||||
reply_frames = self._socket.recv_multipart()
|
|
||||||
detections = self._decode_response(reply_frames)
|
|
||||||
|
|
||||||
# Ensure output shape and dtype are exactly as expected
|
|
||||||
return detections
|
|
||||||
except zmq.Again:
|
|
||||||
# Timeout
|
|
||||||
logger.debug("ZMQ detector request timed out; resetting socket")
|
|
||||||
try:
|
|
||||||
self._create_socket()
|
|
||||||
self._initialize_model()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return self._zero_result
|
|
||||||
except zmq.ZMQError as exc:
|
|
||||||
logger.error(f"ZMQ detector ZMQError: {exc}; resetting socket")
|
|
||||||
try:
|
|
||||||
self._create_socket()
|
|
||||||
self._initialize_model()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return self._zero_result
|
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
logger.error(f"ZMQ detector unexpected error: {exc}")
|
logger.error(f"ZMQ IPC runner error: {exc}")
|
||||||
return self._zero_result
|
return self._zero_result
|
||||||
|
|
||||||
def __del__(self) -> None: # pragma: no cover - best-effort cleanup
|
def __del__(self) -> None: # pragma: no cover - best-effort cleanup
|
||||||
|
Loading…
Reference in New Issue
Block a user