mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-09-23 17:52:05 +02:00
Add locking
This commit is contained in:
parent
4ab8de91a9
commit
7a02a448cb
@ -3,6 +3,7 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
@ -330,8 +331,27 @@ class ZmqIpcRunner(BaseModelRunner):
|
||||
self._socket.setsockopt(zmq.LINGER, linger_ms)
|
||||
self._socket.connect(self._endpoint)
|
||||
self._model_ready = False
|
||||
self._io_lock = threading.Lock()
|
||||
|
||||
@staticmethod
|
||||
def is_complex_model(model_type: str) -> bool:
|
||||
# Import here to avoid circular imports
|
||||
from frigate.detectors.detector_config import ModelTypeEnum
|
||||
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||
|
||||
return model_type in [
|
||||
ModelTypeEnum.yolonas.value,
|
||||
EnrichmentModelTypeEnum.paddleocr.value,
|
||||
EnrichmentModelTypeEnum.jina_v1.value,
|
||||
EnrichmentModelTypeEnum.jina_v2.value,
|
||||
]
|
||||
|
||||
def get_input_names(self) -> list[str]:
|
||||
if "vision" in self.model_name:
|
||||
return ["pixel_values"]
|
||||
elif "arcface" in self.model_name:
|
||||
return ["data"]
|
||||
else:
|
||||
return ["input"]
|
||||
|
||||
def get_input_width(self) -> int:
|
||||
@ -373,6 +393,7 @@ class ZmqIpcRunner(BaseModelRunner):
|
||||
header = self._build_header(tensor)
|
||||
payload = memoryview(tensor.tobytes(order="C"))
|
||||
try:
|
||||
with self._io_lock:
|
||||
self._socket.send_multipart([header, payload])
|
||||
frames = self._socket.recv_multipart()
|
||||
except zmq.Again as e:
|
||||
@ -392,6 +413,7 @@ class ZmqIpcRunner(BaseModelRunner):
|
||||
"""
|
||||
# Check model availability
|
||||
req = {"model_request": True, "model_name": self.model_name}
|
||||
with self._io_lock:
|
||||
self._socket.send_multipart([json.dumps(req).encode("utf-8")])
|
||||
|
||||
# Temporarily extend timeout for model ops
|
||||
@ -425,7 +447,10 @@ class ZmqIpcRunner(BaseModelRunner):
|
||||
return False
|
||||
|
||||
header = {"model_data": True, "model_name": self.model_name}
|
||||
self._socket.send_multipart([json.dumps(header).encode("utf-8"), model_bytes])
|
||||
with self._io_lock:
|
||||
self._socket.send_multipart(
|
||||
[json.dumps(header).encode("utf-8"), model_bytes]
|
||||
)
|
||||
|
||||
original_rcv2 = self._socket.getsockopt(zmq.RCVTIMEO)
|
||||
try:
|
||||
@ -575,6 +600,10 @@ def get_optimized_runner(
|
||||
if rknn_path:
|
||||
return RKNNModelRunner(rknn_path)
|
||||
|
||||
if device == "ZMQ" and not ZmqIpcRunner.is_complex_model(model_type):
|
||||
logger.info(f"Using ZMQ detector model {model_path}")
|
||||
return ZmqIpcRunner(model_path, model_type, **kwargs)
|
||||
|
||||
providers, options = get_ort_providers(device == "CPU", device, **kwargs)
|
||||
|
||||
if providers[0] == "CPUExecutionProvider":
|
||||
|
Loading…
Reference in New Issue
Block a user