diff --git a/frigate/detectors/detection_runners.py b/frigate/detectors/detection_runners.py index 9ad151398..9a9c1865b 100644 --- a/frigate/detectors/detection_runners.py +++ b/frigate/detectors/detection_runners.py @@ -3,6 +3,7 @@ import json import logging import os +import threading from abc import ABC, abstractmethod from typing import Any @@ -330,9 +331,28 @@ class ZmqIpcRunner(BaseModelRunner): self._socket.setsockopt(zmq.LINGER, linger_ms) self._socket.connect(self._endpoint) self._model_ready = False + self._io_lock = threading.Lock() + + @staticmethod + def is_complex_model(model_type: str) -> bool: + # Import here to avoid circular imports + from frigate.detectors.detector_config import ModelTypeEnum + from frigate.embeddings.types import EnrichmentModelTypeEnum + + return model_type in [ + ModelTypeEnum.yolonas.value, + EnrichmentModelTypeEnum.paddleocr.value, + EnrichmentModelTypeEnum.jina_v1.value, + EnrichmentModelTypeEnum.jina_v2.value, + ] def get_input_names(self) -> list[str]: - return ["input"] + if "vision" in self.model_name: + return ["pixel_values"] + elif "arcface" in self.model_name: + return ["data"] + else: + return ["input"] def get_input_width(self) -> int: # Not known/required for ZMQ forwarding @@ -373,8 +393,9 @@ class ZmqIpcRunner(BaseModelRunner): header = self._build_header(tensor) payload = memoryview(tensor.tobytes(order="C")) try: - self._socket.send_multipart([header, payload]) - frames = self._socket.recv_multipart() + with self._io_lock: + self._socket.send_multipart([header, payload]) + frames = self._socket.recv_multipart() except zmq.Again as e: raise TimeoutError("ZMQ detector request timed out") from e except zmq.ZMQError as e: @@ -392,7 +413,8 @@ class ZmqIpcRunner(BaseModelRunner): """ # Check model availability req = {"model_request": True, "model_name": self.model_name} - self._socket.send_multipart([json.dumps(req).encode("utf-8")]) + with self._io_lock: + self._socket.send_multipart([json.dumps(req).encode("utf-8")]) # Temporarily extend timeout for model ops original_rcv = self._socket.getsockopt(zmq.RCVTIMEO) @@ -425,7 +447,10 @@ class ZmqIpcRunner(BaseModelRunner): return False header = {"model_data": True, "model_name": self.model_name} - self._socket.send_multipart([json.dumps(header).encode("utf-8"), model_bytes]) + with self._io_lock: + self._socket.send_multipart( + [json.dumps(header).encode("utf-8"), model_bytes] + ) original_rcv2 = self._socket.getsockopt(zmq.RCVTIMEO) try: @@ -575,6 +600,10 @@ def get_optimized_runner( if rknn_path: return RKNNModelRunner(rknn_path) + if device == "ZMQ" and not ZmqIpcRunner.is_complex_model(model_type): + logger.info(f"Using ZMQ detector model {model_path}") + return ZmqIpcRunner(model_path, model_type, **kwargs) + providers, options = get_ort_providers(device == "CPU", device, **kwargs) if providers[0] == "CPUExecutionProvider":