Add locking

This commit is contained in:
Nicolas Mowen 2025-09-22 13:59:30 -06:00
parent 4ab8de91a9
commit 7a02a448cb

View File

@ -3,6 +3,7 @@
import json import json
import logging import logging
import os import os
import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any
@ -330,8 +331,27 @@ class ZmqIpcRunner(BaseModelRunner):
self._socket.setsockopt(zmq.LINGER, linger_ms) self._socket.setsockopt(zmq.LINGER, linger_ms)
self._socket.connect(self._endpoint) self._socket.connect(self._endpoint)
self._model_ready = False self._model_ready = False
self._io_lock = threading.Lock()
@staticmethod
def is_complex_model(model_type: str) -> bool:
    """Report whether ``model_type`` is one of the model types listed below.

    Args:
        model_type: Model type string, compared against the ``.value`` of
            the enum members listed in the body.

    Returns:
        True when ``model_type`` equals any of the listed values.
    """
    # Deferred imports: importing these at module load would be circular.
    from frigate.detectors.detector_config import ModelTypeEnum
    from frigate.embeddings.types import EnrichmentModelTypeEnum

    complex_types = (
        ModelTypeEnum.yolonas.value,
        EnrichmentModelTypeEnum.paddleocr.value,
        EnrichmentModelTypeEnum.jina_v1.value,
        EnrichmentModelTypeEnum.jina_v2.value,
    )
    return model_type in complex_types
def get_input_names(self) -> list[str]:
    """Return the input name list selected by ``self.model_name``.

    The choice is made by substring match on the model name, checked in
    this order:

    * contains ``"vision"``  -> ``["pixel_values"]``
    * contains ``"arcface"`` -> ``["data"]``
    * anything else          -> ``["input"]``

    Returns:
        A single-element list holding the input name.
    """
    # "vision" is tested first, so a name containing both substrings
    # resolves to "pixel_values".
    if "vision" in self.model_name:
        return ["pixel_values"]
    elif "arcface" in self.model_name:
        return ["data"]
    else:
        return ["input"]
def get_input_width(self) -> int: def get_input_width(self) -> int:
@ -373,6 +393,7 @@ class ZmqIpcRunner(BaseModelRunner):
header = self._build_header(tensor) header = self._build_header(tensor)
payload = memoryview(tensor.tobytes(order="C")) payload = memoryview(tensor.tobytes(order="C"))
try: try:
with self._io_lock:
self._socket.send_multipart([header, payload]) self._socket.send_multipart([header, payload])
frames = self._socket.recv_multipart() frames = self._socket.recv_multipart()
except zmq.Again as e: except zmq.Again as e:
@ -392,6 +413,7 @@ class ZmqIpcRunner(BaseModelRunner):
""" """
# Check model availability # Check model availability
req = {"model_request": True, "model_name": self.model_name} req = {"model_request": True, "model_name": self.model_name}
with self._io_lock:
self._socket.send_multipart([json.dumps(req).encode("utf-8")]) self._socket.send_multipart([json.dumps(req).encode("utf-8")])
# Temporarily extend timeout for model ops # Temporarily extend timeout for model ops
@ -425,7 +447,10 @@ class ZmqIpcRunner(BaseModelRunner):
return False return False
header = {"model_data": True, "model_name": self.model_name} header = {"model_data": True, "model_name": self.model_name}
self._socket.send_multipart([json.dumps(header).encode("utf-8"), model_bytes]) with self._io_lock:
self._socket.send_multipart(
[json.dumps(header).encode("utf-8"), model_bytes]
)
original_rcv2 = self._socket.getsockopt(zmq.RCVTIMEO) original_rcv2 = self._socket.getsockopt(zmq.RCVTIMEO)
try: try:
@ -575,6 +600,10 @@ def get_optimized_runner(
if rknn_path: if rknn_path:
return RKNNModelRunner(rknn_path) return RKNNModelRunner(rknn_path)
if device == "ZMQ" and not ZmqIpcRunner.is_complex_model(model_type):
logger.info(f"Using ZMQ detector model {model_path}")
return ZmqIpcRunner(model_path, model_type, **kwargs)
providers, options = get_ort_providers(device == "CPU", device, **kwargs) providers, options = get_ort_providers(device == "CPU", device, **kwargs)
if providers[0] == "CPUExecutionProvider": if providers[0] == "CPUExecutionProvider":