Add support for yolonas via ONNX and allow TensorRT execution provider to work correctly (#13776)

* Add support for yolonas in onnx

* Add correct deps

* Set ld library path

* Refactor cudnn to only be used in amd64

* Add onnx to docs and add explainer at the top

* Undo change

* Update comment

* Remove uneccesary

* Remove line change
This commit is contained in:
Nicolas Mowen 2024-09-16 15:17:31 -06:00 committed by GitHub
parent 9bcb928715
commit 2f69f5afe6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 127 additions and 12 deletions

View File

@ -12,12 +12,27 @@ ARG TARGETARCH
COPY docker/tensorrt/requirements-amd64.txt /requirements-tensorrt.txt
RUN mkdir -p /trt-wheels && pip3 wheel --wheel-dir=/trt-wheels -r /requirements-tensorrt.txt
# Build CuDNN
FROM tensorrt-base AS cudnn-deps
ARG COMPUTE_LEVEL
RUN apt-get update \
&& apt-get install -y git build-essential
RUN wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.1-1_all.deb \
&& dpkg -i cuda-keyring_1.1-1_all.deb \
&& apt-get update \
&& apt-get -y install cuda-toolkit \
&& rm -rf /var/lib/apt/lists/*
FROM tensorrt-base AS frigate-tensorrt
ENV TRT_VER=8.5.3
RUN --mount=type=bind,from=trt-wheels,source=/trt-wheels,target=/deps/trt-wheels \
pip3 install -U /deps/trt-wheels/*.whl && \
ldconfig
ENV LD_LIBRARY_PATH=/usr/local/lib/python3.9/dist-packages/tensorrt:/usr/local/cuda/lib64:/usr/local/lib/python3.9/dist-packages/nvidia/cufft/lib
WORKDIR /opt/frigate/
COPY --from=rootfs / /
@ -26,6 +41,7 @@ FROM devcontainer AS devcontainer-trt
COPY --from=trt-deps /usr/local/lib/libyolo_layer.so /usr/local/lib/libyolo_layer.so
COPY --from=trt-deps /usr/local/src/tensorrt_demos /usr/local/src/tensorrt_demos
COPY --from=cudnn-deps /usr/local/cuda-12.6 /usr/local/cuda
COPY docker/tensorrt/detector/rootfs/ /
COPY --from=trt-deps /usr/local/lib/libyolo_layer.so /usr/local/lib/libyolo_layer.so
RUN --mount=type=bind,from=trt-wheels,source=/trt-wheels,target=/deps/trt-wheels \

View File

@ -7,7 +7,8 @@ cython == 0.29.*; platform_machine == 'x86_64'
nvidia-cuda-runtime-cu12 == 12.1.*; platform_machine == 'x86_64'
nvidia-cuda-runtime-cu11 == 11.8.*; platform_machine == 'x86_64'
nvidia-cublas-cu11 == 11.11.3.6; platform_machine == 'x86_64'
nvidia-cudnn-cu11 == 8.6.0.*; platform_machine == 'x86_64'
nvidia-cudnn-cu11 == 8.5.0.*; platform_machine == 'x86_64'
nvidia-cufft-cu11==10.*; platform_machine == 'x86_64'
onnx==1.14.0; platform_machine == 'x86_64'
onnxruntime-gpu==1.18.0; platform_machine == 'x86_64'
onnxruntime-gpu==1.17.*; platform_machine == 'x86_64'
protobuf==3.20.3; platform_machine == 'x86_64'

View File

@ -3,6 +3,24 @@ id: object_detectors
title: Object Detectors
---
# Supported Hardware
Frigate supports multiple different detectors that work on different types of hardware:
**Most Hardware**
- [Coral EdgeTPU](#edge-tpu-detector): The Google Coral EdgeTPU is available in USB and m.2 format allowing for a wide range of compatibility with devices.
**Intel**
- [OpenVino](#openvino-detector): OpenVino can run on Intel Arc GPUs, Intel integrated GPUs, and Intel CPUs to provide efficient object detection.
- [ONNX](#onnx): OpenVINO will automatically be detected and used as a detector in the default Frigate image when a supported ONNX model is configured.
**Nvidia**
- [TensortRT](#nvidia-tensorrt-detector): TensorRT can run on Nvidia GPUs, using one of many default models.
- [ONNX](#onnx): TensorRT will automatically be detected and used as a detector in the `-tensorrt` Frigate image when a supported ONNX is configured.
**Rockchip**
- [RKNN](#rockchip-platform): RKNN models can run on Rockchip devices with included NPUs.
# Officially Supported Detectors
Frigate provides the following builtin detector types: `cpu`, `edgetpu`, `openvino`, `tensorrt`, `rknn`, and `hailo8l`. By default, Frigate will use a single CPU detector. Other detectors may require additional configuration as described below. When using multiple detectors they will run in dedicated processes, but pull from a common queue of detection requests from across all cameras.
@ -278,6 +296,44 @@ model:
height: 320
```
## ONNX
ONNX is an open format for building machine learning models, these models can run on a wide variety of hardware. Frigate supports running ONNX models on CPU, OpenVINO, and TensorRT.
### Supported Models
There is no default model provided, the following formats are supported:
#### YOLO-NAS
[YOLO-NAS](https://github.com/Deci-AI/super-gradients/blob/master/YOLONAS.md) models are supported, but not included by default. You can build and download a compatible model with pre-trained weights using [this notebook](https://github.com/frigate/blob/dev/notebooks/YOLO_NAS_Pretrained_Export.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/blakeblackshear/frigate/blob/dev/notebooks/YOLO_NAS_Pretrained_Export.ipynb).
:::warning
The pre-trained YOLO-NAS weights from DeciAI are subject to their license and can't be used commercially. For more information, see: https://docs.deci.ai/super-gradients/latest/LICENSE.YOLONAS.html
:::
The input image size in this notebook is set to 320x320. This results in lower CPU usage and faster inference times without impacting performance in most cases due to the way Frigate crops video frames to areas of interest before running detection. The notebook and config can be updated to 640x640 if desired.
After placing the downloaded onnx model in your config folder, you can use the following configuration:
```yaml
detectors:
onnx:
type: onnx
model:
model_type: yolonas
width: 320 # <--- should match whatever was set in notebook
height: 320 # <--- should match whatever was set in notebook
input_pixel_format: bgr
path: /config/yolo_nas_s.onnx
labelmap_path: /labelmap/coco-80.txt
```
Note that the labelmap uses a subset of the complete COCO label set that has only 80 objects.
## Deepstack / CodeProject.AI Server Detector
The Deepstack / CodeProject.AI Server detector for Frigate allows you to integrate Deepstack and CodeProject.AI object detection capabilities into Frigate. CodeProject.AI and DeepStack are open-source AI platforms that can be run on various devices such as the Raspberry Pi, Nvidia Jetson, and other compatible hardware. It is important to note that the integration is performed over the network, so the inference times may not be as fast as native Frigate detectors, but it still provides an efficient and reliable solution for object detection and tracking.

View File

@ -1,11 +1,15 @@
import logging
import cv2
import numpy as np
from typing_extensions import Literal
from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import BaseDetectorConfig
from frigate.detectors.util import preprocess
from frigate.detectors.detector_config import (
BaseDetectorConfig,
ModelTypeEnum,
PixelFormatEnum,
)
logger = logging.getLogger(__name__)
@ -21,7 +25,7 @@ class ONNXDetector(DetectionApi):
def __init__(self, detector_config: ONNXDetectorConfig):
try:
import onnxruntime
import onnxruntime as ort
logger.info("ONNX: loaded onnxruntime module")
except ModuleNotFoundError:
@ -32,16 +36,54 @@ class ONNXDetector(DetectionApi):
path = detector_config.model.path
logger.info(f"ONNX: loading {detector_config.model.path}")
self.model = onnxruntime.InferenceSession(path)
self.model = ort.InferenceSession(path, providers=ort.get_available_providers())
self.h = detector_config.model.height
self.w = detector_config.model.width
self.onnx_model_type = detector_config.model.model_type
self.onnx_model_px = detector_config.model.input_pixel_format
path = detector_config.model.path
logger.info(f"ONNX: {path} loaded")
def detect_raw(self, tensor_input):
model_input_name = self.model.get_inputs()[0].name
model_input_shape = self.model.get_inputs()[0].shape
tensor_input = preprocess(tensor_input, model_input_shape, np.float32)
# ruff: noqa: F841
tensor_output = self.model.run(None, {model_input_name: tensor_input})[0]
raise Exception(
"No models are currently supported via onnx. See the docs for more info."
)
# adjust input shape
if self.onnx_model_type == ModelTypeEnum.yolonas:
tensor_input = cv2.dnn.blobFromImage(
tensor_input[0],
1.0,
(model_input_shape[3], model_input_shape[2]),
None,
swapRB=self.onnx_model_px == PixelFormatEnum.bgr,
).astype(np.uint8)
tensor_output = self.model.run(None, {model_input_name: tensor_input})
if self.onnx_model_type == ModelTypeEnum.yolonas:
predictions = tensor_output[0]
detections = np.zeros((20, 6), np.float32)
for i, prediction in enumerate(predictions):
if i == 20:
break
(_, x_min, y_min, x_max, y_max, confidence, class_id) = prediction
# when running in GPU mode, empty predictions in the output have class_id of -1
if class_id < 0:
break
detections[i] = [
class_id,
confidence,
y_min / self.h,
x_min / self.w,
y_max / self.h,
x_max / self.w,
]
return detections
else:
raise Exception(
f"{self.onnx_model_type} is currently not supported for rocm. See the docs for more info on supported models."
)