mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-09-23 17:52:05 +02:00
Add locking
This commit is contained in:
parent
4ab8de91a9
commit
7a02a448cb
@ -3,6 +3,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import threading
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -330,9 +331,28 @@ class ZmqIpcRunner(BaseModelRunner):
|
|||||||
self._socket.setsockopt(zmq.LINGER, linger_ms)
|
self._socket.setsockopt(zmq.LINGER, linger_ms)
|
||||||
self._socket.connect(self._endpoint)
|
self._socket.connect(self._endpoint)
|
||||||
self._model_ready = False
|
self._model_ready = False
|
||||||
|
self._io_lock = threading.Lock()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_complex_model(model_type: str) -> bool:
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from frigate.detectors.detector_config import ModelTypeEnum
|
||||||
|
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||||
|
|
||||||
|
return model_type in [
|
||||||
|
ModelTypeEnum.yolonas.value,
|
||||||
|
EnrichmentModelTypeEnum.paddleocr.value,
|
||||||
|
EnrichmentModelTypeEnum.jina_v1.value,
|
||||||
|
EnrichmentModelTypeEnum.jina_v2.value,
|
||||||
|
]
|
||||||
|
|
||||||
def get_input_names(self) -> list[str]:
|
def get_input_names(self) -> list[str]:
|
||||||
return ["input"]
|
if "vision" in self.model_name:
|
||||||
|
return ["pixel_values"]
|
||||||
|
elif "arcface" in self.model_name:
|
||||||
|
return ["data"]
|
||||||
|
else:
|
||||||
|
return ["input"]
|
||||||
|
|
||||||
def get_input_width(self) -> int:
|
def get_input_width(self) -> int:
|
||||||
# Not known/required for ZMQ forwarding
|
# Not known/required for ZMQ forwarding
|
||||||
@ -373,8 +393,9 @@ class ZmqIpcRunner(BaseModelRunner):
|
|||||||
header = self._build_header(tensor)
|
header = self._build_header(tensor)
|
||||||
payload = memoryview(tensor.tobytes(order="C"))
|
payload = memoryview(tensor.tobytes(order="C"))
|
||||||
try:
|
try:
|
||||||
self._socket.send_multipart([header, payload])
|
with self._io_lock:
|
||||||
frames = self._socket.recv_multipart()
|
self._socket.send_multipart([header, payload])
|
||||||
|
frames = self._socket.recv_multipart()
|
||||||
except zmq.Again as e:
|
except zmq.Again as e:
|
||||||
raise TimeoutError("ZMQ detector request timed out") from e
|
raise TimeoutError("ZMQ detector request timed out") from e
|
||||||
except zmq.ZMQError as e:
|
except zmq.ZMQError as e:
|
||||||
@ -392,7 +413,8 @@ class ZmqIpcRunner(BaseModelRunner):
|
|||||||
"""
|
"""
|
||||||
# Check model availability
|
# Check model availability
|
||||||
req = {"model_request": True, "model_name": self.model_name}
|
req = {"model_request": True, "model_name": self.model_name}
|
||||||
self._socket.send_multipart([json.dumps(req).encode("utf-8")])
|
with self._io_lock:
|
||||||
|
self._socket.send_multipart([json.dumps(req).encode("utf-8")])
|
||||||
|
|
||||||
# Temporarily extend timeout for model ops
|
# Temporarily extend timeout for model ops
|
||||||
original_rcv = self._socket.getsockopt(zmq.RCVTIMEO)
|
original_rcv = self._socket.getsockopt(zmq.RCVTIMEO)
|
||||||
@ -425,7 +447,10 @@ class ZmqIpcRunner(BaseModelRunner):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
header = {"model_data": True, "model_name": self.model_name}
|
header = {"model_data": True, "model_name": self.model_name}
|
||||||
self._socket.send_multipart([json.dumps(header).encode("utf-8"), model_bytes])
|
with self._io_lock:
|
||||||
|
self._socket.send_multipart(
|
||||||
|
[json.dumps(header).encode("utf-8"), model_bytes]
|
||||||
|
)
|
||||||
|
|
||||||
original_rcv2 = self._socket.getsockopt(zmq.RCVTIMEO)
|
original_rcv2 = self._socket.getsockopt(zmq.RCVTIMEO)
|
||||||
try:
|
try:
|
||||||
@ -575,6 +600,10 @@ def get_optimized_runner(
|
|||||||
if rknn_path:
|
if rknn_path:
|
||||||
return RKNNModelRunner(rknn_path)
|
return RKNNModelRunner(rknn_path)
|
||||||
|
|
||||||
|
if device == "ZMQ" and not ZmqIpcRunner.is_complex_model(model_type):
|
||||||
|
logger.info(f"Using ZMQ detector model {model_path}")
|
||||||
|
return ZmqIpcRunner(model_path, model_type, **kwargs)
|
||||||
|
|
||||||
providers, options = get_ort_providers(device == "CPU", device, **kwargs)
|
providers, options = get_ort_providers(device == "CPU", device, **kwargs)
|
||||||
|
|
||||||
if providers[0] == "CPUExecutionProvider":
|
if providers[0] == "CPUExecutionProvider":
|
||||||
|
Loading…
Reference in New Issue
Block a user