Add locking

This commit is contained in:
Nicolas Mowen 2025-09-22 13:59:30 -06:00
parent 4ab8de91a9
commit 7a02a448cb

View File

@ -3,6 +3,7 @@
import json import json
import logging import logging
import os import os
import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any
@ -330,9 +331,28 @@ class ZmqIpcRunner(BaseModelRunner):
self._socket.setsockopt(zmq.LINGER, linger_ms) self._socket.setsockopt(zmq.LINGER, linger_ms)
self._socket.connect(self._endpoint) self._socket.connect(self._endpoint)
self._model_ready = False self._model_ready = False
self._io_lock = threading.Lock()
@staticmethod
def is_complex_model(model_type: str) -> bool:
# Import here to avoid circular imports
from frigate.detectors.detector_config import ModelTypeEnum
from frigate.embeddings.types import EnrichmentModelTypeEnum
return model_type in [
ModelTypeEnum.yolonas.value,
EnrichmentModelTypeEnum.paddleocr.value,
EnrichmentModelTypeEnum.jina_v1.value,
EnrichmentModelTypeEnum.jina_v2.value,
]
def get_input_names(self) -> list[str]: def get_input_names(self) -> list[str]:
return ["input"] if "vision" in self.model_name:
return ["pixel_values"]
elif "arcface" in self.model_name:
return ["data"]
else:
return ["input"]
def get_input_width(self) -> int: def get_input_width(self) -> int:
# Not known/required for ZMQ forwarding # Not known/required for ZMQ forwarding
@ -373,8 +393,9 @@ class ZmqIpcRunner(BaseModelRunner):
header = self._build_header(tensor) header = self._build_header(tensor)
payload = memoryview(tensor.tobytes(order="C")) payload = memoryview(tensor.tobytes(order="C"))
try: try:
self._socket.send_multipart([header, payload]) with self._io_lock:
frames = self._socket.recv_multipart() self._socket.send_multipart([header, payload])
frames = self._socket.recv_multipart()
except zmq.Again as e: except zmq.Again as e:
raise TimeoutError("ZMQ detector request timed out") from e raise TimeoutError("ZMQ detector request timed out") from e
except zmq.ZMQError as e: except zmq.ZMQError as e:
@ -392,7 +413,8 @@ class ZmqIpcRunner(BaseModelRunner):
""" """
# Check model availability # Check model availability
req = {"model_request": True, "model_name": self.model_name} req = {"model_request": True, "model_name": self.model_name}
self._socket.send_multipart([json.dumps(req).encode("utf-8")]) with self._io_lock:
self._socket.send_multipart([json.dumps(req).encode("utf-8")])
# Temporarily extend timeout for model ops # Temporarily extend timeout for model ops
original_rcv = self._socket.getsockopt(zmq.RCVTIMEO) original_rcv = self._socket.getsockopt(zmq.RCVTIMEO)
@ -425,7 +447,10 @@ class ZmqIpcRunner(BaseModelRunner):
return False return False
header = {"model_data": True, "model_name": self.model_name} header = {"model_data": True, "model_name": self.model_name}
self._socket.send_multipart([json.dumps(header).encode("utf-8"), model_bytes]) with self._io_lock:
self._socket.send_multipart(
[json.dumps(header).encode("utf-8"), model_bytes]
)
original_rcv2 = self._socket.getsockopt(zmq.RCVTIMEO) original_rcv2 = self._socket.getsockopt(zmq.RCVTIMEO)
try: try:
@ -575,6 +600,10 @@ def get_optimized_runner(
if rknn_path: if rknn_path:
return RKNNModelRunner(rknn_path) return RKNNModelRunner(rknn_path)
if device == "ZMQ" and not ZmqIpcRunner.is_complex_model(model_type):
logger.info(f"Using ZMQ detector model {model_path}")
return ZmqIpcRunner(model_path, model_type, **kwargs)
providers, options = get_ort_providers(device == "CPU", device, **kwargs) providers, options = get_ort_providers(device == "CPU", device, **kwargs)
if providers[0] == "CPUExecutionProvider": if providers[0] == "CPUExecutionProvider":