Add locking

This commit is contained in:
Nicolas Mowen 2025-09-22 13:59:30 -06:00
parent 4ab8de91a9
commit 7a02a448cb

View File

@ -3,6 +3,7 @@
import json
import logging
import os
import threading
from abc import ABC, abstractmethod
from typing import Any
@ -330,8 +331,27 @@ class ZmqIpcRunner(BaseModelRunner):
self._socket.setsockopt(zmq.LINGER, linger_ms)
self._socket.connect(self._endpoint)
self._model_ready = False
self._io_lock = threading.Lock()
@staticmethod
def is_complex_model(model_type: str) -> bool:
# Import here to avoid circular imports
from frigate.detectors.detector_config import ModelTypeEnum
from frigate.embeddings.types import EnrichmentModelTypeEnum
return model_type in [
ModelTypeEnum.yolonas.value,
EnrichmentModelTypeEnum.paddleocr.value,
EnrichmentModelTypeEnum.jina_v1.value,
EnrichmentModelTypeEnum.jina_v2.value,
]
def get_input_names(self) -> list[str]:
if "vision" in self.model_name:
return ["pixel_values"]
elif "arcface" in self.model_name:
return ["data"]
else:
return ["input"]
def get_input_width(self) -> int:
@ -373,6 +393,7 @@ class ZmqIpcRunner(BaseModelRunner):
header = self._build_header(tensor)
payload = memoryview(tensor.tobytes(order="C"))
try:
with self._io_lock:
self._socket.send_multipart([header, payload])
frames = self._socket.recv_multipart()
except zmq.Again as e:
@ -392,6 +413,7 @@ class ZmqIpcRunner(BaseModelRunner):
"""
# Check model availability
req = {"model_request": True, "model_name": self.model_name}
with self._io_lock:
self._socket.send_multipart([json.dumps(req).encode("utf-8")])
# Temporarily extend timeout for model ops
@ -425,7 +447,10 @@ class ZmqIpcRunner(BaseModelRunner):
return False
header = {"model_data": True, "model_name": self.model_name}
self._socket.send_multipart([json.dumps(header).encode("utf-8"), model_bytes])
with self._io_lock:
self._socket.send_multipart(
[json.dumps(header).encode("utf-8"), model_bytes]
)
original_rcv2 = self._socket.getsockopt(zmq.RCVTIMEO)
try:
@ -575,6 +600,10 @@ def get_optimized_runner(
if rknn_path:
return RKNNModelRunner(rknn_path)
if device == "ZMQ" and not ZmqIpcRunner.is_complex_model(model_type):
logger.info(f"Using ZMQ detector model {model_path}")
return ZmqIpcRunner(model_path, model_type, **kwargs)
providers, options = get_ort_providers(device == "CPU", device, **kwargs)
if providers[0] == "CPUExecutionProvider":