mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-09-14 17:52:10 +02:00
Optimize OpenVINO and ONNX Model Runners (#20063)
* Use re-usable inference request to reduce CPU usage * Share tensor * Don't count performance * Create openvino runner class * Break apart onnx runner * Add specific note about inability to use CUDA graphs for some models * Adjust rknn to use RKNNRunner * Use optimized runner * Add support for non-complex models for CudaExecutionProvider * Use core mask for rknn * Correctly handle cuda input * Cleanup * Sort imports
This commit is contained in:
parent
41ed013cc4
commit
81d7c47129
323
frigate/detectors/detection_runners.py
Normal file
323
frigate/detectors/detection_runners.py
Normal file
@ -0,0 +1,323 @@
|
||||
"""Base runner implementation for ONNX models."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from frigate.util.model import get_ort_providers
|
||||
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import OpenVINO only when needed to avoid circular dependencies
|
||||
try:
|
||||
import openvino as ov
|
||||
except ImportError:
|
||||
ov = None
|
||||
|
||||
|
||||
class BaseModelRunner(ABC):
|
||||
"""Abstract base class for model runners."""
|
||||
|
||||
def __init__(self, model_path: str, device: str, **kwargs):
|
||||
self.model_path = model_path
|
||||
self.device = device
|
||||
|
||||
@abstractmethod
|
||||
def get_input_names(self) -> list[str]:
|
||||
"""Get input names for the model."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_input_width(self) -> int:
|
||||
"""Get the input width of the model."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run(self, input: dict[str, Any]) -> Any | None:
|
||||
"""Run inference with the model."""
|
||||
pass
|
||||
|
||||
|
||||
class ONNXModelRunner(BaseModelRunner):
|
||||
"""Run ONNX models using ONNX Runtime."""
|
||||
|
||||
def __init__(self, ort: ort.InferenceSession):
|
||||
self.ort = ort
|
||||
|
||||
def get_input_names(self) -> list[str]:
|
||||
return [input.name for input in self.ort.get_inputs()]
|
||||
|
||||
def get_input_width(self) -> int:
|
||||
"""Get the input width of the model."""
|
||||
return self.ort.get_inputs()[0].shape[3]
|
||||
|
||||
def run(self, input: dict[str, Any]) -> Any | None:
|
||||
return self.ort.run(None, input)
|
||||
|
||||
|
||||
class CudaGraphRunner(BaseModelRunner):
|
||||
"""Encapsulates CUDA Graph capture and replay using ONNX Runtime IOBinding.
|
||||
|
||||
This runner assumes a single tensor input and binds all model outputs.
|
||||
|
||||
NOTE: CUDA Graphs limit supported model operations, so they are not usable
|
||||
for more complex models like CLIP or PaddleOCR.
|
||||
"""
|
||||
|
||||
def __init__(self, session: ort.InferenceSession, cuda_device_id: int):
|
||||
self._session = session
|
||||
self._cuda_device_id = cuda_device_id
|
||||
self._captured = False
|
||||
self._io_binding: ort.IOBinding | None = None
|
||||
self._input_name: str | None = None
|
||||
self._output_names: list[str] | None = None
|
||||
self._input_ortvalue: ort.OrtValue | None = None
|
||||
self._output_ortvalues: ort.OrtValue | None = None
|
||||
|
||||
def get_input_names(self) -> list[str]:
|
||||
"""Get input names for the model."""
|
||||
return [input.name for input in self._session.get_inputs()]
|
||||
|
||||
def get_input_width(self) -> int:
|
||||
"""Get the input width of the model."""
|
||||
return self._session.get_inputs()[0].shape[3]
|
||||
|
||||
def run(self, input: dict[str, Any]):
|
||||
# Extract the single tensor input (assuming one input)
|
||||
input_name = list(input.keys())[0]
|
||||
tensor_input = input[input_name]
|
||||
tensor_input = np.ascontiguousarray(tensor_input)
|
||||
|
||||
if not self._captured:
|
||||
# Prepare IOBinding with CUDA buffers and let ORT allocate outputs on device
|
||||
self._io_binding = self._session.io_binding()
|
||||
self._input_name = input_name
|
||||
self._output_names = [o.name for o in self._session.get_outputs()]
|
||||
|
||||
self._input_ortvalue = ort.OrtValue.ortvalue_from_numpy(
|
||||
tensor_input, "cuda", self._cuda_device_id
|
||||
)
|
||||
self._io_binding.bind_ortvalue_input(self._input_name, self._input_ortvalue)
|
||||
|
||||
for name in self._output_names:
|
||||
# Bind outputs to CUDA and allow ORT to allocate appropriately
|
||||
self._io_binding.bind_output(name, "cuda", self._cuda_device_id)
|
||||
|
||||
# First IOBinding run to allocate, execute, and capture CUDA Graph
|
||||
ro = ort.RunOptions()
|
||||
self._session.run_with_iobinding(self._io_binding, ro)
|
||||
self._captured = True
|
||||
return self._io_binding.copy_outputs_to_cpu()
|
||||
|
||||
# Replay using updated input, copy results to CPU
|
||||
self._input_ortvalue.update_inplace(tensor_input)
|
||||
ro = ort.RunOptions()
|
||||
self._session.run_with_iobinding(self._io_binding, ro)
|
||||
return self._io_binding.copy_outputs_to_cpu()
|
||||
|
||||
|
||||
class OpenVINOModelRunner(BaseModelRunner):
|
||||
"""OpenVINO model runner that handles inference efficiently."""
|
||||
|
||||
def __init__(self, model_path: str, device: str, **kwargs):
|
||||
self.model_path = model_path
|
||||
self.device = device
|
||||
|
||||
if not os.path.isfile(model_path):
|
||||
raise FileNotFoundError(f"OpenVINO model file {model_path} not found.")
|
||||
|
||||
if ov is None:
|
||||
raise ImportError(
|
||||
"OpenVINO is not available. Please install openvino package."
|
||||
)
|
||||
|
||||
self.ov_core = ov.Core()
|
||||
|
||||
# Apply performance optimization
|
||||
self.ov_core.set_property(device, {"PERF_COUNT": "NO"})
|
||||
|
||||
# Compile model
|
||||
self.compiled_model = self.ov_core.compile_model(
|
||||
model=model_path, device_name=device
|
||||
)
|
||||
|
||||
# Create reusable inference request
|
||||
self.infer_request = self.compiled_model.create_infer_request()
|
||||
input_shape = self.compiled_model.inputs[0].get_shape()
|
||||
self.input_tensor = ov.Tensor(ov.Type.f32, input_shape)
|
||||
|
||||
def get_input_names(self) -> list[str]:
|
||||
"""Get input names for the model."""
|
||||
return [input.get_any_name() for input in self.compiled_model.inputs]
|
||||
|
||||
def get_input_width(self) -> int:
|
||||
"""Get the input width of the model."""
|
||||
input_shape = self.compiled_model.inputs[0].get_shape()
|
||||
# Assuming NCHW format, width is the last dimension
|
||||
return int(input_shape[-1])
|
||||
|
||||
def run(self, input_data: np.ndarray) -> list[np.ndarray]:
|
||||
"""Run inference with the model.
|
||||
|
||||
Args:
|
||||
input_data: Input tensor data
|
||||
|
||||
Returns:
|
||||
List of output tensors
|
||||
"""
|
||||
# Copy input data to pre-allocated tensor
|
||||
np.copyto(self.input_tensor.data, input_data)
|
||||
|
||||
# Run inference
|
||||
self.infer_request.infer(self.input_tensor)
|
||||
|
||||
# Get all output tensors
|
||||
outputs = []
|
||||
for i in range(len(self.compiled_model.outputs)):
|
||||
outputs.append(self.infer_request.get_output_tensor(i).data)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class RKNNModelRunner(BaseModelRunner):
|
||||
"""Run RKNN models for embeddings."""
|
||||
|
||||
def __init__(self, model_path: str, model_type: str = None, core_mask: int = 0):
|
||||
self.model_path = model_path
|
||||
self.model_type = model_type
|
||||
self.core_mask = core_mask
|
||||
self.rknn = None
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""Load the RKNN model."""
|
||||
try:
|
||||
from rknnlite.api import RKNNLite
|
||||
|
||||
self.rknn = RKNNLite(verbose=False)
|
||||
|
||||
if self.rknn.load_rknn(self.model_path) != 0:
|
||||
logger.error(f"Failed to load RKNN model: {self.model_path}")
|
||||
raise RuntimeError("Failed to load RKNN model")
|
||||
|
||||
if self.rknn.init_runtime(core_mask=self.core_mask) != 0:
|
||||
logger.error("Failed to initialize RKNN runtime")
|
||||
raise RuntimeError("Failed to initialize RKNN runtime")
|
||||
|
||||
logger.info(f"Successfully loaded RKNN model: {self.model_path}")
|
||||
|
||||
except ImportError:
|
||||
logger.error("RKNN Lite not available")
|
||||
raise ImportError("RKNN Lite not available")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading RKNN model: {e}")
|
||||
raise
|
||||
|
||||
def get_input_names(self) -> list[str]:
|
||||
"""Get input names for the model."""
|
||||
# For detection models, we typically use "input" as the default input name
|
||||
# For CLIP models, we need to determine the model type from the path
|
||||
model_name = os.path.basename(self.model_path).lower()
|
||||
|
||||
if "vision" in model_name:
|
||||
return ["pixel_values"]
|
||||
elif "arcface" in model_name:
|
||||
return ["data"]
|
||||
else:
|
||||
# Default fallback - try to infer from model type
|
||||
if self.model_type and "jina-clip" in self.model_type:
|
||||
if "vision" in self.model_type:
|
||||
return ["pixel_values"]
|
||||
|
||||
# Generic fallback
|
||||
return ["input"]
|
||||
|
||||
def get_input_width(self) -> int:
|
||||
"""Get the input width of the model."""
|
||||
# For CLIP vision models, this is typically 224
|
||||
model_name = os.path.basename(self.model_path).lower()
|
||||
if "vision" in model_name:
|
||||
return 224 # CLIP V1 uses 224x224
|
||||
elif "arcface" in model_name:
|
||||
return 112
|
||||
# For detection models, we can't easily determine this from the RKNN model
|
||||
# The calling code should provide this information
|
||||
return -1
|
||||
|
||||
def run(self, inputs: dict[str, Any]) -> Any:
|
||||
"""Run inference with the RKNN model."""
|
||||
if not self.rknn:
|
||||
raise RuntimeError("RKNN model not loaded")
|
||||
|
||||
try:
|
||||
input_names = self.get_input_names()
|
||||
rknn_inputs = []
|
||||
|
||||
for name in input_names:
|
||||
if name in inputs:
|
||||
if name == "pixel_values":
|
||||
# RKNN expects NHWC format, but ONNX typically provides NCHW
|
||||
# Transpose from [batch, channels, height, width] to [batch, height, width, channels]
|
||||
pixel_data = inputs[name]
|
||||
if len(pixel_data.shape) == 4 and pixel_data.shape[1] == 3:
|
||||
# Transpose from NCHW to NHWC
|
||||
pixel_data = np.transpose(pixel_data, (0, 2, 3, 1))
|
||||
rknn_inputs.append(pixel_data)
|
||||
else:
|
||||
rknn_inputs.append(inputs[name])
|
||||
|
||||
outputs = self.rknn.inference(inputs=rknn_inputs)
|
||||
return outputs
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during RKNN inference: {e}")
|
||||
raise
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup when the runner is destroyed."""
|
||||
if self.rknn:
|
||||
try:
|
||||
self.rknn.release()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_optimized_runner(
|
||||
model_path: str, device: str, complex_model: bool = True, **kwargs
|
||||
) -> BaseModelRunner:
|
||||
"""Get an optimized runner for the hardware."""
|
||||
if is_rknn_compatible(model_path):
|
||||
rknn_path = auto_convert_model(model_path)
|
||||
|
||||
if rknn_path:
|
||||
return RKNNModelRunner(rknn_path)
|
||||
|
||||
providers, options = get_ort_providers(device == "CPU", device, **kwargs)
|
||||
|
||||
if device == "CPU":
|
||||
return ONNXModelRunner(
|
||||
ort.InferenceSession(
|
||||
model_path,
|
||||
providers=providers,
|
||||
provider_options=options,
|
||||
)
|
||||
)
|
||||
|
||||
if "OpenVINOExecutionProvider" in providers:
|
||||
return OpenVINOModelRunner(model_path, device, **kwargs)
|
||||
|
||||
ortSession = ort.InferenceSession(
|
||||
model_path,
|
||||
providers=providers,
|
||||
provider_options=options,
|
||||
)
|
||||
|
||||
if not complex_model and providers[0] == "CUDAExecutionProvider":
|
||||
return CudaGraphRunner(ortSession, options[0]["device_id"])
|
||||
|
||||
return ONNXModelRunner(ortSession)
|
@ -6,6 +6,7 @@ from pydantic import Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
from frigate.detectors.detection_api import DetectionApi
|
||||
from frigate.detectors.detection_runners import CudaGraphRunner
|
||||
from frigate.detectors.detector_config import (
|
||||
BaseDetectorConfig,
|
||||
ModelTypeEnum,
|
||||
@ -23,53 +24,6 @@ logger = logging.getLogger(__name__)
|
||||
DETECTOR_KEY = "onnx"
|
||||
|
||||
|
||||
class CudaGraphRunner:
|
||||
"""Encapsulates CUDA Graph capture and replay using ONNX Runtime IOBinding.
|
||||
|
||||
This runner assumes a single tensor input and binds all model outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, session: ort.InferenceSession, cuda_device_id: int):
|
||||
self._session = session
|
||||
self._cuda_device_id = cuda_device_id
|
||||
self._captured = False
|
||||
self._io_binding: ort.IOBinding | None = None
|
||||
self._input_name: str | None = None
|
||||
self._output_names: list[str] | None = None
|
||||
self._input_ortvalue: ort.OrtValue | None = None
|
||||
self._output_ortvalues: ort.OrtValue | None = None
|
||||
|
||||
def run(self, input_name: str, tensor_input: np.ndarray):
|
||||
tensor_input = np.ascontiguousarray(tensor_input)
|
||||
|
||||
if not self._captured:
|
||||
# Prepare IOBinding with CUDA buffers and let ORT allocate outputs on device
|
||||
self._io_binding = self._session.io_binding()
|
||||
self._input_name = input_name
|
||||
self._output_names = [o.name for o in self._session.get_outputs()]
|
||||
|
||||
self._input_ortvalue = ort.OrtValue.ortvalue_from_numpy(
|
||||
tensor_input, "cuda", self._cuda_device_id
|
||||
)
|
||||
self._io_binding.bind_ortvalue_input(self._input_name, self._input_ortvalue)
|
||||
|
||||
for name in self._output_names:
|
||||
# Bind outputs to CUDA and allow ORT to allocate appropriately
|
||||
self._io_binding.bind_output(name, "cuda", self._cuda_device_id)
|
||||
|
||||
# First IOBinding run to allocate, execute, and capture CUDA Graph
|
||||
ro = ort.RunOptions()
|
||||
self._session.run_with_iobinding(self._io_binding, ro)
|
||||
self._captured = True
|
||||
return self._io_binding.copy_outputs_to_cpu()
|
||||
|
||||
# Replay using updated input, copy results to CPU
|
||||
self._input_ortvalue.update_inplace(tensor_input)
|
||||
ro = ort.RunOptions()
|
||||
self._session.run_with_iobinding(self._io_binding, ro)
|
||||
return self._io_binding.copy_outputs_to_cpu()
|
||||
|
||||
|
||||
class ONNXDetectorConfig(BaseDetectorConfig):
|
||||
type: Literal[DETECTOR_KEY]
|
||||
device: str = Field(default="AUTO", title="Device Type")
|
||||
@ -114,7 +68,6 @@ class ONNXDetector(DetectionApi):
|
||||
|
||||
try:
|
||||
if "CUDAExecutionProvider" in providers:
|
||||
cuda_idx = providers.index("CUDAExecutionProvider")
|
||||
self._cuda_device_id = options[cuda_idx].get("device_id", 0)
|
||||
|
||||
if options[cuda_idx].get("enable_cuda_graph"):
|
||||
@ -142,7 +95,7 @@ class ONNXDetector(DetectionApi):
|
||||
if self._cg_runner is not None:
|
||||
try:
|
||||
# Run using CUDA graphs if available
|
||||
tensor_output = self._cg_runner.run(model_input_name, tensor_input)
|
||||
tensor_output = self._cg_runner.run({model_input_name: tensor_input})
|
||||
except Exception as e:
|
||||
logger.warning(f"CUDA Graphs failed, falling back to regular run: {e}")
|
||||
self._cg_runner = None
|
||||
|
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import openvino as ov
|
||||
@ -7,6 +6,7 @@ from pydantic import Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
from frigate.detectors.detection_api import DetectionApi
|
||||
from frigate.detectors.detection_runners import OpenVINOModelRunner
|
||||
from frigate.detectors.detector_config import BaseDetectorConfig, ModelTypeEnum
|
||||
from frigate.util.model import (
|
||||
post_process_dfine,
|
||||
@ -37,20 +37,21 @@ class OvDetector(DetectionApi):
|
||||
|
||||
def __init__(self, detector_config: OvDetectorConfig):
|
||||
super().__init__(detector_config)
|
||||
self.ov_core = ov.Core()
|
||||
self.ov_model_type = detector_config.model.model_type
|
||||
|
||||
self.h = detector_config.model.height
|
||||
self.w = detector_config.model.width
|
||||
|
||||
if not os.path.isfile(detector_config.model.path):
|
||||
logger.error(f"OpenVino model file {detector_config.model.path} not found.")
|
||||
raise FileNotFoundError
|
||||
|
||||
self.interpreter = self.ov_core.compile_model(
|
||||
model=detector_config.model.path, device_name=detector_config.device
|
||||
self.runner = OpenVINOModelRunner(
|
||||
model_path=detector_config.model.path, device=detector_config.device
|
||||
)
|
||||
|
||||
# For dfine models, also pre-allocate target sizes tensor
|
||||
if self.ov_model_type == ModelTypeEnum.dfine:
|
||||
self.target_sizes_tensor = ov.Tensor(
|
||||
np.array([[self.h, self.w]], dtype=np.int64)
|
||||
)
|
||||
|
||||
self.model_invalid = False
|
||||
|
||||
if self.ov_model_type not in self.supported_models:
|
||||
@ -60,8 +61,8 @@ class OvDetector(DetectionApi):
|
||||
self.model_invalid = True
|
||||
|
||||
if self.ov_model_type == ModelTypeEnum.ssd:
|
||||
model_inputs = self.interpreter.inputs
|
||||
model_outputs = self.interpreter.outputs
|
||||
model_inputs = self.runner.compiled_model.inputs
|
||||
model_outputs = self.runner.compiled_model.outputs
|
||||
|
||||
if len(model_inputs) != 1:
|
||||
logger.error(
|
||||
@ -80,8 +81,8 @@ class OvDetector(DetectionApi):
|
||||
self.model_invalid = True
|
||||
|
||||
if self.ov_model_type == ModelTypeEnum.yolonas:
|
||||
model_inputs = self.interpreter.inputs
|
||||
model_outputs = self.interpreter.outputs
|
||||
model_inputs = self.runner.compiled_model.inputs
|
||||
model_outputs = self.runner.compiled_model.outputs
|
||||
|
||||
if len(model_inputs) != 1:
|
||||
logger.error(
|
||||
@ -104,7 +105,9 @@ class OvDetector(DetectionApi):
|
||||
self.output_indexes = 0
|
||||
while True:
|
||||
try:
|
||||
tensor_shape = self.interpreter.output(self.output_indexes).shape
|
||||
tensor_shape = self.runner.compiled_model.output(
|
||||
self.output_indexes
|
||||
).shape
|
||||
logger.info(
|
||||
f"Model Output-{self.output_indexes} Shape: {tensor_shape}"
|
||||
)
|
||||
@ -129,39 +132,32 @@ class OvDetector(DetectionApi):
|
||||
]
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
infer_request = self.interpreter.create_infer_request()
|
||||
# TODO: see if we can use shared_memory=True
|
||||
input_tensor = ov.Tensor(array=tensor_input)
|
||||
if self.model_invalid:
|
||||
return np.zeros((20, 6), np.float32)
|
||||
|
||||
if self.ov_model_type == ModelTypeEnum.dfine:
|
||||
infer_request.set_tensor("images", input_tensor)
|
||||
target_sizes_tensor = ov.Tensor(
|
||||
np.array([[self.h, self.w]], dtype=np.int64)
|
||||
)
|
||||
infer_request.set_tensor("orig_target_sizes", target_sizes_tensor)
|
||||
infer_request.infer()
|
||||
# Use named inputs for dfine models
|
||||
inputs = {
|
||||
"images": tensor_input,
|
||||
"orig_target_sizes": np.array([[self.h, self.w]], dtype=np.int64),
|
||||
}
|
||||
outputs = self.runner.run_with_named_inputs(inputs)
|
||||
tensor_output = (
|
||||
infer_request.get_output_tensor(0).data,
|
||||
infer_request.get_output_tensor(1).data,
|
||||
infer_request.get_output_tensor(2).data,
|
||||
outputs["output0"],
|
||||
outputs["output1"],
|
||||
outputs["output2"],
|
||||
)
|
||||
return post_process_dfine(tensor_output, self.w, self.h)
|
||||
|
||||
infer_request.infer(input_tensor)
|
||||
# Run inference using the runner
|
||||
outputs = self.runner.run(tensor_input)
|
||||
|
||||
detections = np.zeros((20, 6), np.float32)
|
||||
|
||||
if self.model_invalid:
|
||||
return detections
|
||||
elif self.ov_model_type == ModelTypeEnum.rfdetr:
|
||||
return post_process_rfdetr(
|
||||
[
|
||||
infer_request.get_output_tensor(0).data,
|
||||
infer_request.get_output_tensor(1).data,
|
||||
]
|
||||
)
|
||||
if self.ov_model_type == ModelTypeEnum.rfdetr:
|
||||
return post_process_rfdetr(outputs)
|
||||
elif self.ov_model_type == ModelTypeEnum.ssd:
|
||||
results = infer_request.get_output_tensor(0).data[0][0]
|
||||
results = outputs[0][0][0]
|
||||
|
||||
for i, (_, class_id, score, xmin, ymin, xmax, ymax) in enumerate(results):
|
||||
if i == 20:
|
||||
@ -176,7 +172,7 @@ class OvDetector(DetectionApi):
|
||||
]
|
||||
return detections
|
||||
elif self.ov_model_type == ModelTypeEnum.yolonas:
|
||||
predictions = infer_request.get_output_tensor(0).data
|
||||
predictions = outputs[0]
|
||||
|
||||
for i, prediction in enumerate(predictions):
|
||||
if i == 20:
|
||||
@ -195,16 +191,10 @@ class OvDetector(DetectionApi):
|
||||
]
|
||||
return detections
|
||||
elif self.ov_model_type == ModelTypeEnum.yologeneric:
|
||||
out_tensor = []
|
||||
|
||||
for item in infer_request.output_tensors:
|
||||
out_tensor.append(item.data)
|
||||
|
||||
return post_process_yolo(out_tensor, self.w, self.h)
|
||||
return post_process_yolo(outputs, self.w, self.h)
|
||||
elif self.ov_model_type == ModelTypeEnum.yolox:
|
||||
out_tensor = infer_request.get_output_tensor()
|
||||
# [x, y, h, w, box_score, class_no_1, ..., class_no_80],
|
||||
results = out_tensor.data
|
||||
results = outputs[0]
|
||||
results[..., :2] = (results[..., :2] + self.grids) * self.expanded_strides
|
||||
results[..., 2:4] = np.exp(results[..., 2:4]) * self.expanded_strides
|
||||
image_pred = results[0, ...]
|
||||
|
@ -10,6 +10,7 @@ from pydantic import Field
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR
|
||||
from frigate.detectors.detection_api import DetectionApi
|
||||
from frigate.detectors.detection_runners import RKNNModelRunner
|
||||
from frigate.detectors.detector_config import BaseDetectorConfig, ModelTypeEnum
|
||||
from frigate.util.model import post_process_yolo
|
||||
from frigate.util.rknn_converter import auto_convert_model
|
||||
@ -61,18 +62,18 @@ class Rknn(DetectionApi):
|
||||
"For more information, see: https://docs.deci.ai/super-gradients/latest/LICENSE.YOLONAS.html"
|
||||
)
|
||||
|
||||
from rknnlite.api import RKNNLite
|
||||
|
||||
self.rknn = RKNNLite(verbose=False)
|
||||
if self.rknn.load_rknn(model_props["path"]) != 0:
|
||||
logger.error("Error initializing rknn model.")
|
||||
if self.rknn.init_runtime(core_mask=core_mask) != 0:
|
||||
logger.error(
|
||||
"Error initializing rknn runtime. Do you run docker in privileged mode?"
|
||||
)
|
||||
self.runner = RKNNModelRunner(
|
||||
model_path=model_props["path"],
|
||||
model_type=config.model.model_type.value
|
||||
if config.model.model_type
|
||||
else None,
|
||||
core_mask=core_mask,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
self.rknn.release()
|
||||
if hasattr(self, "runner") and self.runner:
|
||||
# The runner's __del__ method will handle cleanup
|
||||
pass
|
||||
|
||||
def get_soc(self):
|
||||
try:
|
||||
@ -305,9 +306,7 @@ class Rknn(DetectionApi):
|
||||
)
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
output = self.rknn.inference(
|
||||
[
|
||||
tensor_input,
|
||||
]
|
||||
)
|
||||
# Prepare input for the runner
|
||||
inputs = {"input": tensor_input}
|
||||
output = self.runner.run(inputs)
|
||||
return self.post_process(output)
|
||||
|
@ -6,12 +6,12 @@ import os
|
||||
import numpy as np
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR
|
||||
from frigate.detectors.detection_runners import get_optimized_runner
|
||||
from frigate.log import redirect_output_to_logger
|
||||
from frigate.util.downloader import ModelDownloader
|
||||
|
||||
from ...config import FaceRecognitionConfig
|
||||
from .base_embedding import BaseEmbedding
|
||||
from .runner import ONNXModelRunner
|
||||
|
||||
try:
|
||||
from tflite_runtime.interpreter import Interpreter
|
||||
@ -148,9 +148,10 @@ class ArcfaceEmbedding(BaseEmbedding):
|
||||
if self.downloader:
|
||||
self.downloader.wait_for_download()
|
||||
|
||||
self.runner = ONNXModelRunner(
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
device=self.config.device or "GPU",
|
||||
complex_model=False,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
|
@ -12,11 +12,11 @@ from transformers.utils.logging import disable_progress_bar
|
||||
|
||||
from frigate.comms.inter_process import InterProcessRequestor
|
||||
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
|
||||
from frigate.detectors.detection_runners import BaseModelRunner, get_optimized_runner
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.downloader import ModelDownloader
|
||||
|
||||
from .base_embedding import BaseEmbedding
|
||||
from .runner import ONNXModelRunner
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
@ -125,7 +125,7 @@ class JinaV1TextEmbedding(BaseEmbedding):
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
self.runner = ONNXModelRunner(
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
)
|
||||
@ -170,7 +170,7 @@ class JinaV1ImageEmbedding(BaseEmbedding):
|
||||
self.device = device
|
||||
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||
self.feature_extractor = None
|
||||
self.runner: ONNXModelRunner | None = None
|
||||
self.runner: BaseModelRunner | None = None
|
||||
files_names = list(self.download_urls.keys())
|
||||
if not all(
|
||||
os.path.exists(os.path.join(self.download_path, n)) for n in files_names
|
||||
@ -203,7 +203,7 @@ class JinaV1ImageEmbedding(BaseEmbedding):
|
||||
f"{MODEL_CACHE_DIR}/{self.model_name}",
|
||||
)
|
||||
|
||||
self.runner = ONNXModelRunner(
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
)
|
||||
|
@ -11,11 +11,11 @@ from transformers.utils.logging import disable_progress_bar, set_verbosity_error
|
||||
|
||||
from frigate.comms.inter_process import InterProcessRequestor
|
||||
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
|
||||
from frigate.detectors.detection_runners import get_optimized_runner
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.downloader import ModelDownloader
|
||||
|
||||
from .base_embedding import BaseEmbedding
|
||||
from .runner import ONNXModelRunner
|
||||
|
||||
# disables the progress bar and download logging for downloading tokenizers and image processors
|
||||
disable_progress_bar()
|
||||
@ -125,7 +125,7 @@ class JinaV2Embedding(BaseEmbedding):
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
self.runner = ONNXModelRunner(
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
)
|
||||
|
@ -7,11 +7,11 @@ import numpy as np
|
||||
|
||||
from frigate.comms.inter_process import InterProcessRequestor
|
||||
from frigate.const import MODEL_CACHE_DIR
|
||||
from frigate.detectors.detection_runners import BaseModelRunner, get_optimized_runner
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.downloader import ModelDownloader
|
||||
|
||||
from .base_embedding import BaseEmbedding
|
||||
from .runner import ONNXModelRunner
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
@ -47,7 +47,7 @@ class PaddleOCRDetection(BaseEmbedding):
|
||||
self.model_size = model_size
|
||||
self.device = device
|
||||
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||
self.runner: ONNXModelRunner | None = None
|
||||
self.runner: BaseModelRunner | None = None
|
||||
files_names = list(self.download_urls.keys())
|
||||
if not all(
|
||||
os.path.exists(os.path.join(self.download_path, n)) for n in files_names
|
||||
@ -76,10 +76,9 @@ class PaddleOCRDetection(BaseEmbedding):
|
||||
if self.downloader:
|
||||
self.downloader.wait_for_download()
|
||||
|
||||
self.runner = ONNXModelRunner(
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
self.model_size,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
@ -107,7 +106,7 @@ class PaddleOCRClassification(BaseEmbedding):
|
||||
self.model_size = model_size
|
||||
self.device = device
|
||||
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||
self.runner: ONNXModelRunner | None = None
|
||||
self.runner: BaseModelRunner | None = None
|
||||
files_names = list(self.download_urls.keys())
|
||||
if not all(
|
||||
os.path.exists(os.path.join(self.download_path, n)) for n in files_names
|
||||
@ -136,10 +135,9 @@ class PaddleOCRClassification(BaseEmbedding):
|
||||
if self.downloader:
|
||||
self.downloader.wait_for_download()
|
||||
|
||||
self.runner = ONNXModelRunner(
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
self.model_size,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
@ -168,7 +166,7 @@ class PaddleOCRRecognition(BaseEmbedding):
|
||||
self.model_size = model_size
|
||||
self.device = device
|
||||
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||
self.runner: ONNXModelRunner | None = None
|
||||
self.runner: BaseModelRunner | None = None
|
||||
files_names = list(self.download_urls.keys())
|
||||
if not all(
|
||||
os.path.exists(os.path.join(self.download_path, n)) for n in files_names
|
||||
@ -197,10 +195,9 @@ class PaddleOCRRecognition(BaseEmbedding):
|
||||
if self.downloader:
|
||||
self.downloader.wait_for_download()
|
||||
|
||||
self.runner = ONNXModelRunner(
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
self.model_size,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
@ -229,7 +226,7 @@ class LicensePlateDetector(BaseEmbedding):
|
||||
self.model_size = model_size
|
||||
self.device = device
|
||||
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||
self.runner: ONNXModelRunner | None = None
|
||||
self.runner: BaseModelRunner | None = None
|
||||
files_names = list(self.download_urls.keys())
|
||||
if not all(
|
||||
os.path.exists(os.path.join(self.download_path, n)) for n in files_names
|
||||
@ -258,10 +255,10 @@ class LicensePlateDetector(BaseEmbedding):
|
||||
if self.downloader:
|
||||
self.downloader.wait_for_download()
|
||||
|
||||
self.runner = ONNXModelRunner(
|
||||
self.runner = get_optimized_runner(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.device,
|
||||
self.model_size,
|
||||
complex_model=False,
|
||||
)
|
||||
|
||||
def _preprocess_inputs(self, raw_inputs):
|
||||
|
@ -1,243 +0,0 @@
|
||||
"""Convenience runner for onnx models."""
|
||||
|
||||
import logging
|
||||
import os.path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR
|
||||
from frigate.util.model import get_ort_providers
|
||||
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible
|
||||
|
||||
try:
|
||||
import openvino as ov
|
||||
except ImportError:
|
||||
# openvino is not included
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ONNXModelRunner:
|
||||
"""Run onnx models optimally based on available hardware."""
|
||||
|
||||
def __init__(self, model_path: str, device: str, requires_fp16: bool = False):
|
||||
self.model_path = model_path
|
||||
self.ort: ort.InferenceSession = None
|
||||
self.ov: ov.Core = None
|
||||
self.rknn = None
|
||||
self.type = "ort"
|
||||
|
||||
try:
|
||||
if device != "CPU" and is_rknn_compatible(model_path):
|
||||
# Try to auto-convert to RKNN format
|
||||
rknn_path = auto_convert_model(model_path)
|
||||
if rknn_path:
|
||||
try:
|
||||
self.rknn = RKNNModelRunner(rknn_path, device)
|
||||
self.type = "rknn"
|
||||
logger.info(f"Using RKNN model: {rknn_path}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Failed to load RKNN model, falling back to ONNX: {e}"
|
||||
)
|
||||
self.rknn = None
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Fall back to standard ONNX providers
|
||||
providers, options = get_ort_providers(
|
||||
device == "CPU",
|
||||
device,
|
||||
requires_fp16,
|
||||
)
|
||||
self.interpreter = None
|
||||
|
||||
if "OpenVINOExecutionProvider" in providers:
|
||||
try:
|
||||
# use OpenVINO directly
|
||||
self.type = "ov"
|
||||
self.ov = ov.Core()
|
||||
self.ov.set_property(
|
||||
{ov.properties.cache_dir: os.path.join(MODEL_CACHE_DIR, "openvino")}
|
||||
)
|
||||
self.interpreter = self.ov.compile_model(
|
||||
model=model_path, device_name=device
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"OpenVINO failed to build model, using CPU instead: {e}"
|
||||
)
|
||||
self.interpreter = None
|
||||
|
||||
# Use ONNXRuntime
|
||||
if self.interpreter is None:
|
||||
self.type = "ort"
|
||||
self.ort = ort.InferenceSession(
|
||||
model_path,
|
||||
providers=providers,
|
||||
provider_options=options,
|
||||
)
|
||||
|
||||
def get_input_names(self) -> list[str]:
|
||||
if self.type == "rknn":
|
||||
return self.rknn.get_input_names()
|
||||
elif self.type == "ov":
|
||||
input_names = []
|
||||
|
||||
for input in self.interpreter.inputs:
|
||||
input_names.extend(input.names)
|
||||
|
||||
return input_names
|
||||
elif self.type == "ort":
|
||||
return [input.name for input in self.ort.get_inputs()]
|
||||
|
||||
def get_input_width(self):
|
||||
"""Get the input width of the model regardless of backend."""
|
||||
if self.type == "rknn":
|
||||
return self.rknn.get_input_width()
|
||||
elif self.type == "ort":
|
||||
return self.ort.get_inputs()[0].shape[3]
|
||||
elif self.type == "ov":
|
||||
input_info = self.interpreter.inputs
|
||||
first_input = input_info[0]
|
||||
|
||||
try:
|
||||
partial_shape = first_input.get_partial_shape()
|
||||
# width dimension
|
||||
if len(partial_shape) >= 4 and partial_shape[3].is_static:
|
||||
return partial_shape[3].get_length()
|
||||
|
||||
# If width is dynamic or we can't determine it
|
||||
return -1
|
||||
except Exception:
|
||||
try:
|
||||
# gemini says some ov versions might still allow this
|
||||
input_shape = first_input.shape
|
||||
return input_shape[3] if len(input_shape) >= 4 else -1
|
||||
except Exception:
|
||||
return -1
|
||||
return -1
|
||||
|
||||
def run(self, input: dict[str, Any]) -> Any | None:
|
||||
if self.type == "rknn":
|
||||
return self.rknn.run(input)
|
||||
elif self.type == "ov":
|
||||
infer_request = self.interpreter.create_infer_request()
|
||||
|
||||
try:
|
||||
# This ensures the model starts with a clean state for each sequence
|
||||
# Important for RNN models like PaddleOCR recognition
|
||||
infer_request.reset_state()
|
||||
except Exception:
|
||||
# this will raise an exception for models with AUTO set as the device
|
||||
pass
|
||||
|
||||
outputs = infer_request.infer(input)
|
||||
|
||||
return outputs
|
||||
elif self.type == "ort":
|
||||
return self.ort.run(None, input)
|
||||
|
||||
|
||||
class RKNNModelRunner:
|
||||
"""Run RKNN models for embeddings."""
|
||||
|
||||
def __init__(self, model_path: str, device: str = "AUTO", model_type: str = None):
|
||||
self.model_path = model_path
|
||||
self.device = device
|
||||
self.model_type = model_type
|
||||
self.rknn = None
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""Load the RKNN model."""
|
||||
try:
|
||||
from rknnlite.api import RKNNLite
|
||||
|
||||
self.rknn = RKNNLite(verbose=False)
|
||||
|
||||
if self.rknn.load_rknn(self.model_path) != 0:
|
||||
logger.error(f"Failed to load RKNN model: {self.model_path}")
|
||||
raise RuntimeError("Failed to load RKNN model")
|
||||
|
||||
if self.rknn.init_runtime() != 0:
|
||||
logger.error("Failed to initialize RKNN runtime")
|
||||
raise RuntimeError("Failed to initialize RKNN runtime")
|
||||
|
||||
logger.info(f"Successfully loaded RKNN model: {self.model_path}")
|
||||
|
||||
except ImportError:
|
||||
logger.error("RKNN Lite not available")
|
||||
raise ImportError("RKNN Lite not available")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading RKNN model: {e}")
|
||||
raise
|
||||
|
||||
def get_input_names(self) -> list[str]:
|
||||
"""Get input names for the model."""
|
||||
# For CLIP models, we need to determine the model type from the path
|
||||
model_name = os.path.basename(self.model_path).lower()
|
||||
|
||||
if "vision" in model_name:
|
||||
return ["pixel_values"]
|
||||
elif "arcface" in model_name:
|
||||
return ["data"]
|
||||
else:
|
||||
# Default fallback - try to infer from model type
|
||||
if self.model_type and "jina-clip" in self.model_type:
|
||||
if "vision" in self.model_type:
|
||||
return ["pixel_values"]
|
||||
|
||||
# Generic fallback
|
||||
return ["input"]
|
||||
|
||||
def get_input_width(self) -> int:
|
||||
"""Get the input width of the model."""
|
||||
# For CLIP vision models, this is typically 224
|
||||
model_name = os.path.basename(self.model_path).lower()
|
||||
if "vision" in model_name:
|
||||
return 224 # CLIP V1 uses 224x224
|
||||
elif "arcface" in model_name:
|
||||
return 112
|
||||
return -1
|
||||
|
||||
def run(self, inputs: dict[str, Any]) -> Any:
|
||||
"""Run inference with the RKNN model."""
|
||||
if not self.rknn:
|
||||
raise RuntimeError("RKNN model not loaded")
|
||||
|
||||
try:
|
||||
input_names = self.get_input_names()
|
||||
rknn_inputs = []
|
||||
|
||||
for name in input_names:
|
||||
if name in inputs:
|
||||
if name == "pixel_values":
|
||||
# RKNN expects NHWC format, but ONNX typically provides NCHW
|
||||
# Transpose from [batch, channels, height, width] to [batch, height, width, channels]
|
||||
pixel_data = inputs[name]
|
||||
if len(pixel_data.shape) == 4 and pixel_data.shape[1] == 3:
|
||||
# Transpose from NCHW to NHWC
|
||||
pixel_data = np.transpose(pixel_data, (0, 2, 3, 1))
|
||||
rknn_inputs.append(pixel_data)
|
||||
else:
|
||||
rknn_inputs.append(inputs[name])
|
||||
|
||||
outputs = self.rknn.inference(inputs=rknn_inputs)
|
||||
return outputs
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during RKNN inference: {e}")
|
||||
raise
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup when the runner is destroyed."""
|
||||
if self.rknn:
|
||||
try:
|
||||
self.rknn.release()
|
||||
except Exception:
|
||||
pass
|
Loading…
Reference in New Issue
Block a user