Rocm yolonas (#13816)

* Implement ROCm detectors

* Cleanup tensor input

* Fixup image creation

* Add support for yolonas in onnx

* Get build working with onnx

* Update docs and simplify config

* Remove unused imports
This commit is contained in:
Nicolas Mowen 2024-09-18 18:34:07 -06:00 committed by GitHub
parent efd1194307
commit 4515eb4637
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 194 additions and 154 deletions

View File

@ -179,57 +179,18 @@ jobs:
h8l.tags=${{ steps.setup.outputs.image-name }}-h8l
*.cache-from=type=registry,ref=${{ steps.setup.outputs.cache-name }}-h8l
*.cache-to=type=registry,ref=${{ steps.setup.outputs.cache-name }}-h8l,mode=max
#- name: AMD/ROCm general build
# env:
# AMDGPU: gfx
# HSA_OVERRIDE: 0
# uses: docker/bake-action@v3
# with:
# push: true
# targets: rocm
# files: docker/rocm/rocm.hcl
# set: |
# rocm.tags=${{ steps.setup.outputs.image-name }}-rocm
# *.cache-from=type=gha
#- name: AMD/ROCm gfx900
# env:
# AMDGPU: gfx900
# HSA_OVERRIDE: 1
# HSA_OVERRIDE_GFX_VERSION: 9.0.0
# uses: docker/bake-action@v3
# with:
# push: true
# targets: rocm
# files: docker/rocm/rocm.hcl
# set: |
# rocm.tags=${{ steps.setup.outputs.image-name }}-rocm-gfx900
# *.cache-from=type=gha
#- name: AMD/ROCm gfx1030
# env:
# AMDGPU: gfx1030
# HSA_OVERRIDE: 1
# HSA_OVERRIDE_GFX_VERSION: 10.3.0
# uses: docker/bake-action@v3
# with:
# push: true
# targets: rocm
# files: docker/rocm/rocm.hcl
# set: |
# rocm.tags=${{ steps.setup.outputs.image-name }}-rocm-gfx1030
# *.cache-from=type=gha
#- name: AMD/ROCm gfx1100
# env:
# AMDGPU: gfx1100
# HSA_OVERRIDE: 1
# HSA_OVERRIDE_GFX_VERSION: 11.0.0
# uses: docker/bake-action@v3
# with:
# push: true
# targets: rocm
# files: docker/rocm/rocm.hcl
# set: |
# rocm.tags=${{ steps.setup.outputs.image-name }}-rocm-gfx1100
# *.cache-from=type=gha
- name: AMD/ROCm general build
env:
AMDGPU: gfx
HSA_OVERRIDE: 0
uses: docker/bake-action@v3
with:
push: true
targets: rocm
files: docker/rocm/rocm.hcl
set: |
rocm.tags=${{ steps.setup.outputs.image-name }}-rocm
*.cache-from=type=gha
# The majority of users running arm64 are rpi users, so the rpi
# build should be the primary arm64 image
assemble_default_build:

View File

@ -23,11 +23,11 @@ COPY docker/rocm/rocm-pin-600 /etc/apt/preferences.d/
RUN apt-get update
RUN apt-get -y install --no-install-recommends migraphx
RUN apt-get -y install --no-install-recommends migraphx hipfft roctracer
RUN apt-get -y install --no-install-recommends migraphx-dev
RUN mkdir -p /opt/rocm-dist/opt/rocm-$ROCM/lib
RUN cd /opt/rocm-$ROCM/lib && cp -dpr libMIOpen*.so* libamd*.so* libhip*.so* libhsa*.so* libmigraphx*.so* librocm*.so* librocblas*.so* /opt/rocm-dist/opt/rocm-$ROCM/lib/
RUN cd /opt/rocm-$ROCM/lib && cp -dpr libMIOpen*.so* libamd*.so* libhip*.so* libhsa*.so* libmigraphx*.so* librocm*.so* librocblas*.so* libroctracer*.so* librocfft*.so* /opt/rocm-dist/opt/rocm-$ROCM/lib/
RUN cd /opt/rocm-dist/opt/ && ln -s rocm-$ROCM rocm
RUN mkdir -p /opt/rocm-dist/etc/ld.so.conf.d/
@ -69,7 +69,11 @@ RUN apt-get -y install libnuma1
WORKDIR /opt/frigate/
COPY --from=rootfs / /
COPY docker/rocm/rootfs/ /
COPY docker/rocm/requirements-wheels-rocm.txt /requirements.txt
RUN python3 -m pip install --upgrade pip \
&& pip3 uninstall -y onnxruntime-openvino \
&& pip3 install -r /requirements.txt
#######################################################################
FROM scratch AS rocm-dist
@ -101,6 +105,3 @@ ENV HSA_OVERRIDE_GFX_VERSION=$HSA_OVERRIDE_GFX_VERSION
#######################################################################
FROM rocm-prelim-hsa-override$HSA_OVERRIDE as rocm-deps
# Request yolov8 download at startup
ENV DOWNLOAD_YOLOV8=1

View File

@ -0,0 +1 @@
onnxruntime-rocm @ https://github.com/NickM-27/frigate-onnxruntime-rocm/releases/download/v1.0.0/onnxruntime_rocm-1.17.3-cp39-cp39-linux_x86_64.whl

View File

@ -1,20 +0,0 @@
#!/command/with-contenv bash
# shellcheck shell=bash
# Compile YoloV8 ONNX files into ROCm MIGraphX files
OVERRIDE=$(cd /opt/frigate && python3 -c 'import frigate.detectors.plugins.rocm as rocm; print(rocm.auto_override_gfx_version())')
if ! test -z "$OVERRIDE"; then
echo "Using HSA_OVERRIDE_GFX_VERSION=${OVERRIDE}"
export HSA_OVERRIDE_GFX_VERSION=$OVERRIDE
fi
for onnx in /config/model_cache/yolov8/*.onnx
do
mxr="${onnx%.onnx}.mxr"
if ! test -f $mxr; then
echo "processing $onnx into $mxr"
/opt/rocm/bin/migraphx-driver compile $onnx --optimize --gpu --enable-offload-copy --binary -o $mxr
fi
done

View File

@ -1 +0,0 @@
/etc/s6-overlay/s6-rc.d/compile-rocm-models/run

View File

@ -9,6 +9,11 @@ Frigate supports multiple different detectors that work on different types of ha
**Most Hardware**
- [Coral EdgeTPU](#edge-tpu-detector): The Google Coral EdgeTPU is available in USB and m.2 format allowing for a wide range of compatibility with devices.
- [Hailo](#hailo-8l): The Hailo8 AI Acceleration module is available in m.2 format with a HAT for RPi devices, offering a wide range of compatibility with devices.
**AMD**
- [ROCm](#amdrocm-gpu-detector): ROCm can run on AMD Discrete GPUs to provide efficient object detection.
- [ONNX](#onnx): ROCm will automatically be detected and used as a detector in the `-rocm` Frigate image when a supported ONNX model is configured.
**Intel**
- [OpenVino](#openvino-detector): OpenVino can run on Intel Arc GPUs, Intel integrated GPUs, and Intel CPUs to provide efficient object detection.
@ -16,7 +21,7 @@ Frigate supports multiple different detectors that work on different types of ha
**Nvidia**
- [TensortRT](#nvidia-tensorrt-detector): TensorRT can run on Nvidia GPUs, using one of many default models.
- [ONNX](#onnx): TensorRT will automatically be detected and used as a detector in the `-tensorrt` Frigate image when a supported ONNX is configured.
- [ONNX](#onnx): TensorRT will automatically be detected and used as a detector in the `-tensorrt` Frigate image when a supported ONNX model is configured.
**Rockchip**
- [RKNN](#rockchip-platform): RKNN models can run on Rockchip devices with included NPUs.
@ -312,6 +317,121 @@ model:
height: 320
```
## AMD/ROCm GPU detector
### Setup
The `rocm` detector supports running YOLO-NAS models on AMD GPUs. Use a frigate docker image with `-rocm` suffix, for example `ghcr.io/blakeblackshear/frigate:stable-rocm`.
### Docker settings for GPU access
ROCm needs access to the `/dev/kfd` and `/dev/dri` devices. When docker or frigate is not run under root then also `video` (and possibly `render` and `ssl/_ssl`) groups should be added.
When running docker directly the following flags should be added for device access:
```bash
$ docker run --device=/dev/kfd --device=/dev/dri \
...
```
When using docker compose:
```yaml
services:
frigate:
---
devices:
- /dev/dri
- /dev/kfd
```
For reference on recommended settings see [running ROCm/pytorch in Docker](https://rocm.docs.amd.com/projects/install-on-linux/en/develop/how-to/3rd-party/pytorch-install.html#using-docker-with-pytorch-pre-installed).
### Docker settings for overriding the GPU chipset
Your GPU might work just fine without any special configuration but in many cases they need manual settings. AMD/ROCm software stack comes with a limited set of GPU drivers and for newer or missing models you will have to override the chipset version to an older/generic version to get things working.
Also AMD/ROCm does not "officially" support integrated GPUs. It still does work with most of them just fine but requires special settings. One has to configure the `HSA_OVERRIDE_GFX_VERSION` environment variable. See the [ROCm bug report](https://github.com/ROCm/ROCm/issues/1743) for context and examples.
For the rocm frigate build there is some automatic detection:
- gfx90c -> 9.0.0
- gfx1031 -> 10.3.0
- gfx1103 -> 11.0.0
If you have something else you might need to override the `HSA_OVERRIDE_GFX_VERSION` at Docker launch. Suppose the version you want is `9.0.0`, then you should configure it from command line as:
```bash
$ docker run -e HSA_OVERRIDE_GFX_VERSION=9.0.0 \
...
```
When using docker compose:
```yaml
services:
frigate:
...
environment:
HSA_OVERRIDE_GFX_VERSION: "9.0.0"
```
Figuring out what version you need can be complicated as you can't tell the chipset name and driver from the AMD brand name.
- first make sure that rocm environment is running properly by running `/opt/rocm/bin/rocminfo` in the frigate container -- it should list both the CPU and the GPU with their properties
- find the chipset version you have (gfxNNN) from the output of the `rocminfo` (see below)
- use a search engine to query what `HSA_OVERRIDE_GFX_VERSION` you need for the given gfx name ("gfxNNN ROCm HSA_OVERRIDE_GFX_VERSION")
- override the `HSA_OVERRIDE_GFX_VERSION` with relevant value
- if things are not working check the frigate docker logs
#### Figuring out if AMD/ROCm is working and found your GPU
```bash
$ docker exec -it frigate /opt/rocm/bin/rocminfo
```
#### Figuring out your AMD GPU chipset version:
We unset the `HSA_OVERRIDE_GFX_VERSION` to prevent an existing override from messing up the result:
```bash
$ docker exec -it frigate /bin/bash -c '(unset HSA_OVERRIDE_GFX_VERSION && /opt/rocm/bin/rocminfo |grep gfx)'
```
### Supported Models
There is no default model provided, the following formats are supported:
#### YOLO-NAS
[YOLO-NAS](https://github.com/Deci-AI/super-gradients/blob/master/YOLONAS.md) models are supported, but not included by default. You can build and download a compatible model with pre-trained weights using [this notebook](https://github.com/frigate/blob/dev/notebooks/YOLO_NAS_Pretrained_Export.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/blakeblackshear/frigate/blob/dev/notebooks/YOLO_NAS_Pretrained_Export.ipynb).
:::warning
The pre-trained YOLO-NAS weights from DeciAI are subject to their license and can't be used commercially. For more information, see: https://docs.deci.ai/super-gradients/latest/LICENSE.YOLONAS.html
:::
The input image size in this notebook is set to 320x320. This results in lower CPU usage and faster inference times without impacting performance in most cases due to the way Frigate crops video frames to areas of interest before running detection. The notebook and config can be updated to 640x640 if desired.
After placing the downloaded onnx model in your config folder, you can use the following configuration:
```yaml
detectors:
onnx:
type: rocm
model:
model_type: yolonas
width: 320 # <--- should match whatever was set in notebook
height: 320 # <--- should match whatever was set in notebook
input_pixel_format: bgr
path: /config/yolo_nas_s.onnx
labelmap_path: /labelmap/coco-80.txt
```
Note that the labelmap uses a subset of the complete COCO label set that has only 80 objects.
## ONNX
ONNX is an open format for building machine learning models, Frigate supports running ONNX models on CPU, OpenVINO, and TensorRT. On startup Frigate will automatically try to use a GPU if one is available.
@ -475,7 +595,7 @@ $ cat /sys/kernel/debug/rknpu/load
## Hailo-8l
This detector is available if you are using the Raspberry Pi 5 with Hailo-8L AI Kit. This has not been tested using the Hailo-8L with other hardware.
This detector is available for use with Hailo-8 AI Acceleration Module.
### Configuration

View File

@ -87,6 +87,10 @@ Inference speeds will vary greatly depending on the GPU and the model used.
| Quadro P400 2GB | 20 - 25 ms |
| Quadro P2000 | ~ 12 ms |
#### AMD GPUs
With the [rocm](../configuration/object_detectors.md#amdrocm-gpu-detector) detector Frigate can take advantage of many AMD GPUs.
### Community Supported:
#### Nvidia Jetson

View File

@ -24,7 +24,6 @@ from typing_extensions import Literal
from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import BaseDetectorConfig
from frigate.detectors.util import preprocess # Assuming this function is available
# Set up logging
logger = logging.getLogger(__name__)
@ -146,17 +145,9 @@ class HailoDetector(DetectionApi):
f"[detect_raw] Converted tensor_input to numpy array: shape {tensor_input.shape}"
)
# Preprocess the tensor input using Frigate's preprocess function
processed_tensor = preprocess(
tensor_input, (1, self.h8l_model_height, self.h8l_model_width, 3), np.uint8
)
input_data = tensor_input
logger.debug(
f"[detect_raw] Tensor data and shape after preprocessing: {processed_tensor} {processed_tensor.shape}"
)
input_data = processed_tensor
logger.debug(
f"[detect_raw] Input data for inference shape: {processed_tensor.shape}, dtype: {processed_tensor.dtype}"
f"[detect_raw] Input data for inference shape: {tensor_input.shape}, dtype: {tensor_input.dtype}"
)
try:

View File

@ -1,7 +1,6 @@
import logging
import os
import cv2
import numpy as np
from typing_extensions import Literal
@ -9,7 +8,6 @@ from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import (
BaseDetectorConfig,
ModelTypeEnum,
PixelFormatEnum,
)
logger = logging.getLogger(__name__)
@ -73,24 +71,13 @@ class ONNXDetector(DetectionApi):
self.w = detector_config.model.width
self.onnx_model_type = detector_config.model.model_type
self.onnx_model_px = detector_config.model.input_pixel_format
self.onnx_model_shape = detector_config.model.input_tensor
path = detector_config.model.path
logger.info(f"ONNX: {path} loaded")
def detect_raw(self, tensor_input):
model_input_name = self.model.get_inputs()[0].name
model_input_shape = self.model.get_inputs()[0].shape
# adjust input shape
if self.onnx_model_type == ModelTypeEnum.yolonas:
tensor_input = cv2.dnn.blobFromImage(
tensor_input[0],
1.0,
(model_input_shape[3], model_input_shape[2]),
None,
swapRB=self.onnx_model_px == PixelFormatEnum.bgr,
).astype(np.uint8)
tensor_output = self.model.run(None, {model_input_name: tensor_input})
if self.onnx_model_type == ModelTypeEnum.yolonas:

View File

@ -9,8 +9,10 @@ from pydantic import Field
from typing_extensions import Literal
from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import BaseDetectorConfig
from frigate.detectors.util import preprocess
from frigate.detectors.detector_config import (
BaseDetectorConfig,
ModelTypeEnum,
)
logger = logging.getLogger(__name__)
@ -74,7 +76,16 @@ class ROCmDetector(DetectionApi):
logger.error("AMD/ROCm: module loading failed, missing ROCm environment?")
raise
if detector_config.conserve_cpu:
logger.info("AMD/ROCm: switching HIP to blocking mode to conserve CPU")
ctypes.CDLL("/opt/rocm/lib/libamdhip64.so").hipSetDeviceFlags(4)
self.h = detector_config.model.height
self.w = detector_config.model.width
self.rocm_model_type = detector_config.model.model_type
self.rocm_model_px = detector_config.model.input_pixel_format
path = detector_config.model.path
mxr_path = os.path.splitext(path)[0] + ".mxr"
if path.endswith(".mxr"):
logger.info(f"AMD/ROCm: loading parsed model from {mxr_path}")
@ -84,6 +95,7 @@ class ROCmDetector(DetectionApi):
self.model = migraphx.load(mxr_path)
else:
logger.info(f"AMD/ROCm: loading model from {path}")
if path.endswith(".onnx"):
self.model = migraphx.parse_onnx(path)
elif (
@ -95,30 +107,51 @@ class ROCmDetector(DetectionApi):
self.model = migraphx.parse_tf(path)
else:
raise Exception(f"AMD/ROCm: unknown model format {path}")
logger.info("AMD/ROCm: compiling the model")
self.model.compile(
migraphx.get_target("gpu"), offload_copy=True, fast_math=True
)
logger.info(f"AMD/ROCm: saving parsed model into {mxr_path}")
os.makedirs("/config/model_cache/rocm", exist_ok=True)
migraphx.save(self.model, mxr_path)
logger.info("AMD/ROCm: model loaded")
def detect_raw(self, tensor_input):
model_input_name = self.model.get_parameter_names()[0]
model_input_shape = tuple(
self.model.get_parameter_shapes()[model_input_name].lens()
)
tensor_input = preprocess(tensor_input, model_input_shape, np.float32)
detector_result = self.model.run({model_input_name: tensor_input})[0]
addr = ctypes.cast(detector_result.data_ptr(), ctypes.POINTER(ctypes.c_float))
# ruff: noqa: F841
tensor_output = np.ctypeslib.as_array(
addr, shape=detector_result.get_shape().lens()
)
if self.rocm_model_type == ModelTypeEnum.yolonas:
predictions = tensor_output
detections = np.zeros((20, 6), np.float32)
for i, prediction in enumerate(predictions):
if i == 20:
break
(_, x_min, y_min, x_max, y_max, confidence, class_id) = prediction
# when running in GPU mode, empty predictions in the output have class_id of -1
if class_id < 0:
break
detections[i] = [
class_id,
confidence,
y_min / self.h,
x_min / self.w,
y_max / self.h,
x_max / self.w,
]
return detections
else:
raise Exception(
"No models are currently supported for rocm. See the docs for more info."
f"{self.rocm_model_type} is currently not supported for rocm. See the docs for more info on supported models."
)

View File

@ -1,36 +0,0 @@
import logging
import cv2
import numpy as np
logger = logging.getLogger(__name__)
def preprocess(tensor_input, model_input_shape, model_input_element_type):
model_input_shape = tuple(model_input_shape)
assert tensor_input.dtype == np.uint8, f"tensor_input.dtype: {tensor_input.dtype}"
if len(tensor_input.shape) == 3:
tensor_input = tensor_input[np.newaxis, :]
if model_input_element_type == np.uint8:
# nothing to do for uint8 model input
assert (
model_input_shape == tensor_input.shape
), f"model_input_shape: {model_input_shape}, tensor_input.shape: {tensor_input.shape}"
return tensor_input
assert (
model_input_element_type == np.float32
), f"model_input_element_type: {model_input_element_type}"
# tensor_input must be nhwc
assert tensor_input.shape[3] == 3, f"tensor_input.shape: {tensor_input.shape}"
if tensor_input.shape[1:3] != model_input_shape[2:4]:
logger.warn(
f"preprocess: tensor_input.shape {tensor_input.shape} and model_input_shape {model_input_shape} do not match!"
)
# cv2.dnn.blobFromImage is faster than running it through numpy
return cv2.dnn.blobFromImage(
tensor_input[0],
1.0 / 255,
(model_input_shape[3], model_input_shape[2]),
None,
swapRB=False,
)