Nvidia TensorRT detector (#4718)

* Initial WIP dockerfile and scripts to add tensorrt support

* Add tensorRT detector

* WIP attempt to install TensorRT 8.5

* Updates to detector for cuda python library

* TensorRT Cuda library rework WIP

Does not run

* Fixes from rebase to detector factory

* Fix parsing output memory pointer

* Handle TensorRT logs with the python logger

* Use non-async interface and convert input data to float32. Detection runs without error.

* Make TensorRT a separate build from the base Frigate image.

* Add script and documentation for generating TRT Models

* Add support for TensorRT devcontainer

* Add labelmap to trt model script and docs.  Cleanup of old scripts.

* Update detect to normalize input tensor using model input type

* Add config for selecting GPU. Fix Async inference. Update documentation.

* Update some CUDA libraries to clean up version warning

* Add CI stage to build TensorRT tag

* Add note in docs for image tag and model support

Nate Meyer 2022-12-30 11:53:17 -05:00 committed by GitHub
parent e3ec292528
commit 3f05f74ecb
9 changed files with 515 additions and 16 deletions


@ -36,7 +36,19 @@ jobs:
context: .
push: true
platforms: linux/amd64,linux/arm64,linux/arm/v7
target: frigate
tags: |
ghcr.io/blakeblackshear/frigate:${{ github.ref_name }}-${{ env.SHORT_SHA }}
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Build and push TensorRT
uses: docker/build-push-action@v3
with:
context: .
push: true
platforms: linux/amd64
target: frigate-tensorrt
tags: |
ghcr.io/blakeblackshear/frigate:${{ github.ref_name }}-${{ env.SHORT_SHA }}-tensorrt
cache-from: type=gha
cache-to: type=gha,mode=max


@ -71,6 +71,15 @@ WORKDIR /rootfs/usr/local/go2rtc/bin
RUN wget -qO go2rtc "https://github.com/AlexxIT/go2rtc/releases/download/v0.1-rc.5/go2rtc_linux_${TARGETARCH}" \
&& chmod +x go2rtc
####
#
# OpenVino Support
#
# 1. Download and convert a model from Intel's Public Open Model Zoo
# 2. Build libUSB without udev to handle NCS2 enumeration
#
####
# Download and Convert OpenVino model
FROM base_amd64 AS ov-converter
ARG DEBIAN_FRONTEND
@ -115,8 +124,6 @@ RUN /bin/mkdir -p '/usr/local/lib' && \
/usr/bin/install -c -m 644 libusb-1.0.pc '/usr/local/lib/pkgconfig' && \
ldconfig
FROM wget AS models
# Get model and labels
@ -160,7 +167,8 @@ RUN apt-get -qq update \
libtbb2 libtbb-dev libdc1394-22-dev libopenexr-dev \
libgstreamer-plugins-base1.0-dev libgstreamer1.0-dev \
# scipy dependencies
gcc gfortran libopenblas-dev liblapack-dev
gcc gfortran libopenblas-dev liblapack-dev && \
rm -rf /var/lib/apt/lists/*
RUN wget -q https://bootstrap.pypa.io/get-pip.py -O get-pip.py \
&& python3 get-pip.py "pip"
@ -176,6 +184,10 @@ RUN pip3 install -r requirements.txt
COPY requirements-wheels.txt /requirements-wheels.txt
RUN pip3 wheel --wheel-dir=/wheels -r requirements-wheels.txt
# Add TensorRT wheels to another folder
COPY requirements-tensorrt.txt /requirements-tensorrt.txt
RUN mkdir -p /trt-wheels && pip3 wheel --wheel-dir=/trt-wheels -r requirements-tensorrt.txt
# Collect deps in a single layer
FROM scratch AS deps-rootfs
@ -283,7 +295,18 @@ COPY migrations migrations/
COPY --from=web-build /work/dist/ web/
# Frigate final container
FROM deps
FROM deps AS frigate
WORKDIR /opt/frigate/
COPY --from=rootfs / /
# Frigate w/ TensorRT Support as separate image
FROM frigate AS frigate-tensorrt
RUN --mount=type=bind,from=wheels,source=/trt-wheels,target=/deps/trt-wheels \
pip3 install -U /deps/trt-wheels/*.whl
# Dev Container w/ TRT
FROM devcontainer AS devcontainer-trt
RUN --mount=type=bind,from=wheels,source=/trt-wheels,target=/deps/trt-wheels \
pip3 install -U /deps/trt-wheels/*.whl


@ -10,22 +10,27 @@ version:
echo 'VERSION = "$(VERSION)-$(COMMIT_HASH)"' > frigate/version.py
local: version
docker buildx build --tag frigate:latest --load .
docker buildx build --target=frigate --tag frigate:latest --load .
local-trt: version
docker buildx build --target=frigate-tensorrt --tag frigate:latest-tensorrt --load .
amd64:
docker buildx build --platform linux/amd64 --tag $(IMAGE_REPO):$(VERSION)-$(COMMIT_HASH) .
docker buildx build --platform linux/amd64 --target=frigate --tag $(IMAGE_REPO):$(VERSION)-$(COMMIT_HASH) .
docker buildx build --platform linux/amd64 --target=frigate-tensorrt --tag $(IMAGE_REPO):$(VERSION)-$(COMMIT_HASH)-tensorrt .
arm64:
docker buildx build --platform linux/arm64 --tag $(IMAGE_REPO):$(VERSION)-$(COMMIT_HASH) .
docker buildx build --platform linux/arm64 --target=frigate --tag $(IMAGE_REPO):$(VERSION)-$(COMMIT_HASH) .
armv7:
docker buildx build --platform linux/arm/v7 --tag $(IMAGE_REPO):$(VERSION)-$(COMMIT_HASH) .
docker buildx build --platform linux/arm/v7 --target=frigate --tag $(IMAGE_REPO):$(VERSION)-$(COMMIT_HASH) .
build: version amd64 arm64 armv7
docker buildx build --platform linux/arm/v7,linux/arm64/v8,linux/amd64 --tag $(IMAGE_REPO):$(VERSION)-$(COMMIT_HASH) .
docker buildx build --platform linux/arm/v7,linux/arm64/v8,linux/amd64 --target=frigate --tag $(IMAGE_REPO):$(VERSION)-$(COMMIT_HASH) .
push: build
docker buildx build --push --platform linux/arm/v7,linux/arm64/v8,linux/amd64 --tag $(IMAGE_REPO):${GITHUB_REF_NAME}-$(COMMIT_HASH) .
docker buildx build --push --platform linux/arm/v7,linux/arm64/v8,linux/amd64 --target=frigate --tag $(IMAGE_REPO):${GITHUB_REF_NAME}-$(COMMIT_HASH) .
docker buildx build --push --platform linux/amd64 --target=frigate-tensorrt --tag $(IMAGE_REPO):${GITHUB_REF_NAME}-$(COMMIT_HASH)-tensorrt .
run: local
docker run --rm --publish=5000:5000 --volume=${PWD}/config/config.yml:/config/config.yml frigate:latest


@ -11,7 +11,15 @@ services:
shm_size: "256mb"
build:
context: .
# Use target devcontainer-trt for TensorRT dev
target: devcontainer
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
devices:
- /dev/bus/usb:/dev/bus/usb
# - /dev/dri:/dev/dri # for intel hwaccel, needs to be updated for your hardware
@ -21,6 +29,8 @@ services:
- /etc/localtime:/etc/localtime:ro
- ./config/config.yml:/config/config.yml:ro
- ./debug:/media/frigate
# Create the trt-models folder using the documented method of generating TRT models
# - ./debug/trt-models:/trt-models
- /dev/bus/usb:/dev/bus/usb
mqtt:
container_name: mqtt

docker/tensorrt_models.sh (new executable file, 37 lines)

@ -0,0 +1,37 @@
#!/bin/bash
set -euxo pipefail
CUDA_HOME=/usr/local/cuda
LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64
OUTPUT_FOLDER=/tensorrt_models
echo "Generating the following TRT Models: ${YOLO_MODELS:="yolov4-tiny-288,yolov4-tiny-416,yolov7-tiny-416"}"
# Create output folder
mkdir -p ${OUTPUT_FOLDER}
# Install packages
pip install --upgrade pip && pip install onnx==1.9.0 protobuf==3.20.3
# Clone tensorrt_demos repo
git clone --depth 1 https://github.com/yeahme49/tensorrt_demos.git /tensorrt_demos
# Build libyolo
cd /tensorrt_demos/plugins && make all
cp libyolo_layer.so ${OUTPUT_FOLDER}/libyolo_layer.so
# Download yolo weights
cd /tensorrt_demos/yolo && ./download_yolo.sh
# Build trt engine
cd /tensorrt_demos/yolo
for model in ${YOLO_MODELS//,/ }
do
python3 yolo_to_onnx.py -m ${model}
python3 onnx_to_tensorrt.py -m ${model}
cp /tensorrt_demos/yolo/${model}.trt ${OUTPUT_FOLDER}/${model}.trt;
done
# Download Labelmap
wget -q https://github.com/openvinotoolkit/open_model_zoo/raw/master/data/dataset_classes/coco_91cl.txt -O ${OUTPUT_FOLDER}/coco_91cl.txt


@ -3,11 +3,10 @@ id: detectors
title: Detectors
---
Frigate provides the following builtin detector types: `cpu`, `edgetpu`, and `openvino`. By default, Frigate will use a single CPU detector. Other detectors may require additional configuration as described below. When using multiple detectors they will run in dedicated processes, but pull from a common queue of detection requests from across all cameras.
**Note**: There is not yet support for Nvidia GPUs to perform object detection with tensorflow. It can be used for ffmpeg decoding, but not object detection.
Frigate provides the following builtin detector types: `cpu`, `edgetpu`, `openvino`, and `tensorrt`. By default, Frigate will use a single CPU detector. Other detectors may require additional configuration as described below. When using multiple detectors they will run in dedicated processes, but pull from a common queue of detection requests from across all cameras.
## CPU Detector (not recommended)
The CPU detector type runs a TensorFlow Lite model utilizing the CPU without hardware acceleration. It is recommended to use a hardware accelerated detector type instead for better performance. To configure a CPU based detector, set the `"type"` attribute to `"cpu"`.
The number of threads used by the interpreter can be specified using the `"num_threads"` attribute, and defaults to `3.`
@ -60,6 +59,7 @@ detectors:
```
### Native Coral (Dev Board)
_warning: may have [compatibility issues](https://github.com/blakeblackshear/frigate/issues/1706) after `v0.9.x`_
```yaml
@ -99,7 +99,7 @@ The OpenVINO detector type runs an OpenVINO IR model on Intel CPU, GPU and VPU h
The OpenVINO device to be used is specified using the `"device"` attribute according to the naming conventions in the [Device Documentation](https://docs.openvino.ai/latest/openvino_docs_OV_UG_Working_with_devices.html). Other supported devices could be `AUTO`, `CPU`, `GPU`, `MYRIAD`, etc. If not specified, the default OpenVINO device will be selected by the `AUTO` plugin.
OpenVINO is supported on 6th Gen Intel platforms (Skylake) and newer. A supported Intel platform is required to use the `GPU` device with OpenVINO. The `MYRIAD` device may be run on any platform, including Arm devices. For detailed system requirements, see [OpenVINO System Requirements](https://www.intel.com/content/www/us/en/developer/tools/openvino-toolkit/system-requirements.html)
An OpenVINO model is provided in the container at `/openvino-model/ssdlite_mobilenet_v2.xml` and is used by this detector type by default. The model comes from Intel's Open Model Zoo [SSDLite MobileNet V2](https://github.com/openvinotoolkit/open_model_zoo/tree/master/models/public/ssdlite_mobilenet_v2) and is converted to an FP16 precision IR model. Use the model configuration shown below when using the OpenVINO detector.
@ -121,7 +121,7 @@ model:
### Intel NCS2 VPU and Myriad X Setup
Intel produces a neural net inference acceleration chip called Myriad X. This chip was sold in their Neural Compute Stick 2 (NCS2) which has been discontinued. If intending to use the MYRIAD device for acceleration, additional setup is required to pass through the USB device. The host needs a udev rule installed to handle the NCS2 device.
```bash
sudo usermod -a -G users "$(whoami)"
@ -139,11 +139,96 @@ Additionally, the Frigate docker container needs to run with the following confi
```bash
--device-cgroup-rule='c 189:\* rmw' -v /dev/bus/usb:/dev/bus/usb
```
or in your compose file:
```yml
device_cgroup_rules:
- 'c 189:* rmw'
- "c 189:* rmw"
volumes:
- /dev/bus/usb:/dev/bus/usb
```
## NVidia TensorRT Detector
NVidia GPUs may be used for object detection using the TensorRT libraries. Due to the size of the additional libraries, this detector is only provided in images with the `-tensorrt` tag suffix. This detector is designed to work with YOLO models for object detection.
### Minimum Hardware Support
The TensorRT detector uses the 11.x series of CUDA libraries, which have minor version compatibility. The minimum driver version on the host system must be `>=450.80.02`. The GPU must also support a Compute Capability of `5.0` or greater, which generally corresponds to a Maxwell-era GPU or newer; check the NVIDIA GPU Compute Capability table linked below.
> **TODO:** NVidia claims support on compute 3.5 and 3.7, but marks it as deprecated. This would have some, but not all, Kepler GPUs as possibly working. This needs testing before making any claims of support.
Newer GPU architectures offer capabilities that TensorRT can take advantage of, such as INT8 operations and Tensor cores. The features compatible with your hardware are selected when the model is converted to a `.trt` file. The provided model-generation script currently exposes a switch to enable or disable FP16 operations; using newer features such as INT8 optimization requires additional work.
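If you are unsure of your GPU's Compute Capability, it can be queried with the same `cuda-python` bindings this detector uses. The snippet below is only a minimal sketch (not part of Frigate) and assumes the `cuda-python` package is installed wherever it runs:

```python
from cuda import cuda  # cuda-python low-level bindings, the same package the detector uses

# Minimal sketch: print the compute capability of each visible GPU so you can
# verify it meets the 5.0 minimum before switching to a -tensorrt image.
(err,) = cuda.cuInit(0)
assert err == cuda.CUresult.CUDA_SUCCESS, f"cuInit failed: {err}"
err, count = cuda.cuDeviceGetCount()
for idx in range(count):
    err, dev = cuda.cuDeviceGet(idx)
    err, major = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, dev
    )
    err, minor = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, dev
    )
    print(f"GPU {idx}: compute capability {major}.{minor}")
```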
#### Compatibility References:
[NVIDIA TensorRT Support Matrix](https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-841/support-matrix/index.html)
[NVIDIA CUDA Compatibility](https://docs.nvidia.com/deploy/cuda-compatibility/index.html)
[NVIDIA GPU Compute Capability](https://developer.nvidia.com/cuda-gpus)
### Generate Models
The models used for TensorRT must be preprocessed on the same hardware platform that they will run on. This means that each user must run additional setup to generate these model files for the TensorRT library. A script is provided that will build several common models.
To generate the model files, create a new folder to save the models, download the script, and launch a docker container that will run the script.
```bash
mkdir trt-models
wget https://raw.githubusercontent.com/blakeblackshear/frigate/nvidia-detector/docker/tensorrt_models.sh
chmod +x tensorrt_models.sh
docker run --gpus=all --rm -it -v `pwd`/trt-models:/tensorrt_models -v `pwd`/tensorrt_models.sh:/tensorrt_models.sh nvcr.io/nvidia/tensorrt:22.07-py3 /tensorrt_models.sh
```
The `trt-models` folder can then be mapped into your Frigate container as `/trt-models` and the models referenced from the config.
If your GPU does not support FP16 operations, you can pass the environment variable `-e USE_FP16=False` to the `docker run` command to disable it.
Specific models can be selected by passing an environment variable to the `docker run` command. Use the form `-e YOLO_MODELS=yolov4-416,yolov4-tiny-416` to select one or more model names. The models available are shown below.
```
yolov3-288
yolov3-416
yolov3-608
yolov3-spp-288
yolov3-spp-416
yolov3-spp-608
yolov3-tiny-288
yolov3-tiny-416
yolov4-288
yolov4-416
yolov4-608
yolov4-csp-256
yolov4-csp-512
yolov4-p5-448
yolov4-p5-896
yolov4-tiny-288
yolov4-tiny-416
yolov4x-mish-320
yolov4x-mish-640
yolov7-tiny-288
yolov7-tiny-416
```
### Configuration Parameters
The TensorRT detector can be selected by specifying `tensorrt` as the detector type. The GPU will need to be passed through to the docker container using the same methods described in the [Hardware Acceleration](hardware_acceleration.md#nvidia-gpu) section. If you pass through multiple GPUs, you can select which GPU is used for a detector with the `device` configuration parameter. The `device` parameter is an integer value of the GPU index, as shown by `nvidia-smi` within the container.
The TensorRT detector uses `.trt` model files that are located in `/trt-models/` by default. The model file path and dimensions used will depend on which model you have generated.
```yaml
detectors:
tensorrt:
type: tensorrt
device: 0 #This is the default, select the first GPU
model:
path: /trt-models/yolov7-tiny-416.trt
labelmap_path: /trt-models/coco_91cl.txt
input_tensor: nchw
input_pixel_format: rgb
width: 416
height: 416
```
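For reference, a compose service for the TensorRT image might pass the GPU and the generated models through as shown below. The image tag and host paths here are assumptions for illustration, not part of this change:

```yml
services:
  frigate:
    image: ghcr.io/blakeblackshear/frigate:stable-tensorrt # any image with the -tensorrt tag suffix
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    volumes:
      - ./trt-models:/trt-models:ro # folder produced by tensorrt_models.sh
      - ./config/config.yml:/config/config.yml:ro
```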


@ -0,0 +1,318 @@
import logging
import ctypes
import numpy as np
try:
import tensorrt as trt
from cuda import cuda
TRT_SUPPORT = True
except ModuleNotFoundError as e:
TRT_SUPPORT = False
from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import BaseDetectorConfig
from typing import Literal
from pydantic import Field
logger = logging.getLogger(__name__)
DETECTOR_KEY = "tensorrt"
if TRT_SUPPORT:
class TrtLogger(trt.ILogger):
def __init__(self):
trt.ILogger.__init__(self)
def log(self, severity, msg):
logger.log(self.getSeverity(severity), msg)
def getSeverity(self, sev: trt.ILogger.Severity) -> int:
if sev == trt.ILogger.VERBOSE:
return logging.DEBUG
elif sev == trt.ILogger.INFO:
return logging.INFO
elif sev == trt.ILogger.WARNING:
return logging.WARNING
elif sev == trt.ILogger.ERROR:
return logging.ERROR
elif sev == trt.ILogger.INTERNAL_ERROR:
return logging.CRITICAL
else:
return logging.DEBUG
class TensorRTDetectorConfig(BaseDetectorConfig):
type: Literal[DETECTOR_KEY]
device: int = Field(default=0, title="GPU Device Index")
class HostDeviceMem(object):
"""Simple helper data class that's a little nicer to use than a 2-tuple."""
def __init__(self, host_mem, device_mem, nbytes, size):
self.host = host_mem
err, self.host_dev = cuda.cuMemHostGetDevicePointer(self.host, 0)
self.device = device_mem
self.nbytes = nbytes
self.size = size
def __str__(self):
return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
def __repr__(self):
return self.__str__()
def __del__(self):
cuda.cuMemFreeHost(self.host)
cuda.cuMemFree(self.device)
class TensorRtDetector(DetectionApi):
type_key = DETECTOR_KEY
def _load_engine(self, model_path):
try:
ctypes.cdll.LoadLibrary(
"/usr/local/lib/python3.9/dist-packages/nvidia/cuda_runtime/lib/libcudart.so.11.0"
)
ctypes.cdll.LoadLibrary(
"/usr/local/lib/python3.9/dist-packages/tensorrt/libnvinfer.so.8"
)
trt.init_libnvinfer_plugins(self.trt_logger, "")
ctypes.cdll.LoadLibrary("/trt-models/libyolo_layer.so")
except OSError as e:
logger.error(
"ERROR: failed to load libraries. %s",
e,
)
with open(model_path, "rb") as f, trt.Runtime(self.trt_logger) as runtime:
return runtime.deserialize_cuda_engine(f.read())
def _get_input_shape(self):
"""Get input shape of the TensorRT YOLO engine."""
binding = self.engine[0]
assert self.engine.binding_is_input(binding)
binding_dims = self.engine.get_binding_shape(binding)
if len(binding_dims) == 4:
return (
tuple(binding_dims[2:]),
trt.nptype(self.engine.get_binding_dtype(binding)),
)
elif len(binding_dims) == 3:
return (
tuple(binding_dims[1:]),
trt.nptype(self.engine.get_binding_dtype(binding)),
)
else:
raise ValueError(
"bad dims of binding %s: %s" % (binding, str(binding_dims))
)
def _allocate_buffers(self):
"""Allocates all host/device in/out buffers required for an engine."""
inputs = []
outputs = []
bindings = []
output_idx = 0
for binding in self.engine:
binding_dims = self.engine.get_binding_shape(binding)
if len(binding_dims) == 4:
# explicit batch case (TensorRT 7+)
size = trt.volume(binding_dims)
elif len(binding_dims) == 3:
# implicit batch case (TensorRT 6 or older)
size = trt.volume(binding_dims) * self.engine.max_batch_size
else:
raise ValueError(
"bad dims of binding %s: %s" % (binding, str(binding_dims))
)
nbytes = size * self.engine.get_binding_dtype(binding).itemsize
# Allocate host and device buffers
err, host_mem = cuda.cuMemHostAlloc(
nbytes, Flags=cuda.CU_MEMHOSTALLOC_DEVICEMAP
)
assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemHostAlloc returned {err}"
logger.debug(
f"Allocated Tensor Binding {binding} Memory {nbytes} Bytes ({size} * {self.engine.get_binding_dtype(binding)})"
)
err, device_mem = cuda.cuMemAlloc(nbytes)
assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAlloc returned {err}"
# Append the device buffer to device bindings.
bindings.append(int(device_mem))
# Append to the appropriate list.
if self.engine.binding_is_input(binding):
logger.debug(f"Input has Shape {binding_dims}")
inputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
else:
# each grid has 3 anchors, each anchor generates a detection
# output of 7 float32 values
assert size % 7 == 0, f"output size was {size}"
logger.debug(f"Output has Shape {binding_dims}")
outputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
output_idx += 1
assert len(inputs) == 1, f"inputs len was {len(inputs)}"
assert len(outputs) == 1, f"output len was {len(outputs)}"
return inputs, outputs, bindings
def _do_inference(self):
"""do_inference (for TensorRT 7.0+)
This function is generalized for multiple inputs/outputs for full
dimension networks.
Inputs and outputs are expected to be lists of HostDeviceMem objects.
"""
# Push CUDA Context
cuda.cuCtxPushCurrent(self.cu_ctx)
# Transfer input data to the GPU.
[
cuda.cuMemcpyHtoDAsync(inp.device, inp.host, inp.nbytes, self.stream)
for inp in self.inputs
]
# Run inference.
if not self.context.execute_async_v2(
bindings=self.bindings, stream_handle=self.stream
):
logger.warn(f"Execute returned false")
# Transfer predictions back from the GPU.
[
cuda.cuMemcpyDtoHAsync(out.host, out.device, out.nbytes, self.stream)
for out in self.outputs
]
# Synchronize the stream
cuda.cuStreamSynchronize(self.stream)
# Pop CUDA Context
cuda.cuCtxPopCurrent()
# Return only the host outputs.
return [
np.array(
(ctypes.c_float * out.size).from_address(out.host), dtype=np.float32
)
for out in self.outputs
]
def __init__(self, detector_config: TensorRTDetectorConfig):
assert (
TRT_SUPPORT
), f"TensorRT libraries not found, {DETECTOR_KEY} detector not present"
(cuda_err,) = cuda.cuInit(0)
assert (
cuda_err == cuda.CUresult.CUDA_SUCCESS
), f"Failed to initialize cuda {cuda_err}"
err, dev_count = cuda.cuDeviceGetCount()
logger.debug(f"Num Available Devices: {dev_count}")
assert (
detector_config.device < dev_count
), f"Invalid TensorRT Device Config. Device {detector_config.device} Invalid."
err, self.cu_ctx = cuda.cuCtxCreate(
cuda.CUctx_flags.CU_CTX_MAP_HOST, detector_config.device
)
self.conf_th = 0.4 ##TODO: model config parameter
self.nms_threshold = 0.4
err, self.stream = cuda.cuStreamCreate(0)
self.trt_logger = TrtLogger()
self.engine = self._load_engine(detector_config.model.path)
self.input_shape = self._get_input_shape()
try:
self.context = self.engine.create_execution_context()
(
self.inputs,
self.outputs,
self.bindings,
) = self._allocate_buffers()
except Exception as e:
logger.error(e)
raise RuntimeError("fail to allocate CUDA resources") from e
logger.debug("TensorRT loaded. Input shape is %s", self.input_shape)
logger.debug("TensorRT version is %s", trt.__version__[0])
def __del__(self):
"""Free CUDA memories."""
if self.outputs is not None:
del self.outputs
if self.inputs is not None:
del self.inputs
if self.stream is not None:
cuda.cuStreamDestroy(self.stream)
del self.stream
del self.engine
del self.context
del self.trt_logger
cuda.cuCtxDestroy(self.cu_ctx)
def _postprocess_yolo(self, trt_outputs, conf_th):
"""Postprocess TensorRT outputs.
# Args
trt_outputs: a list of 2 or 3 tensors, where each tensor
contains a multiple of 7 float32 numbers in
the order of [x, y, w, h, box_confidence, class_id, class_prob]
conf_th: confidence threshold
# Returns
boxes, scores, classes
"""
# filter low-conf detections and concatenate results of all yolo layers
detections = []
for o in trt_outputs:
dets = o.reshape((-1, 7))
dets = dets[dets[:, 4] * dets[:, 6] >= conf_th]
detections.append(dets)
detections = np.concatenate(detections, axis=0)
return detections
def detect_raw(self, tensor_input):
# Input tensor has the shape of the [height, width, 3]
# Output tensor of float32 of shape [20, 6] where:
# 0 - class id
# 1 - score
# 2..5 - a value between 0 and 1 of the box: [top, left, bottom, right]
# normalize
if self.input_shape[-1] != trt.int8:
tensor_input = tensor_input.astype(self.input_shape[-1])
tensor_input /= 255.0
self.inputs[0].host = np.ascontiguousarray(
tensor_input.astype(self.input_shape[-1])
)
trt_outputs = self._do_inference()
raw_detections = self._postprocess_yolo(trt_outputs, self.conf_th)
if len(raw_detections) == 0:
return np.zeros((20, 6), np.float32)
# raw_detections: Nx7 numpy arrays of
# [[x, y, w, h, box_confidence, class_id, class_prob],
# Calculate score as box_confidence x class_prob
raw_detections[:, 4] = raw_detections[:, 4] * raw_detections[:, 6]
# Reorder elements by the score, best on top, remove class_prob
ordered = raw_detections[raw_detections[:, 4].argsort()[::-1]][:, 0:6]
# transform width to right with clamp to 0..1
ordered[:, 2] = np.clip(ordered[:, 2] + ordered[:, 0], 0, 1)
# transform height to bottom with clamp to 0..1
ordered[:, 3] = np.clip(ordered[:, 3] + ordered[:, 1], 0, 1)
# put result into the correct order and limit to top 20
detections = ordered[:, [5, 4, 1, 0, 3, 2]][:20]
# pad to 20x6 shape
append_cnt = 20 - len(detections)
if append_cnt > 0:
detections = np.append(
detections, np.zeros((append_cnt, 6), np.float32), axis=0
)
return detections
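Note: the reordering performed in `detect_raw` above can be reproduced with plain numpy. The following is an illustrative sketch using a single made-up output row, not part of the detector:

```python
import numpy as np

# One fake TensorRT output row: [x, y, w, h, box_confidence, class_id, class_prob]
raw = np.array([[0.1, 0.2, 0.3, 0.4, 0.9, 17.0, 0.8]], dtype=np.float32)

raw[:, 4] = raw[:, 4] * raw[:, 6]                             # score = box_confidence * class_prob
ordered = raw[raw[:, 4].argsort()[::-1]][:, 0:6]              # sort by score, drop class_prob
ordered[:, 2] = np.clip(ordered[:, 2] + ordered[:, 0], 0, 1)  # width -> right edge, clamped to 0..1
ordered[:, 3] = np.clip(ordered[:, 3] + ordered[:, 1], 0, 1)  # height -> bottom edge, clamped to 0..1
detections = ordered[:, [5, 4, 1, 0, 3, 2]]                   # [class_id, score, top, left, bottom, right]
print(detections)  # [[17.    0.72  0.2   0.1   0.6   0.4 ]]
```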


@ -0,0 +1,8 @@
# NVidia TensorRT Support (amd64 only)
nvidia-pyindex; platform_machine == 'x86_64'
nvidia-tensorrt == 8.4.1.5; platform_machine == 'x86_64'
cuda-python == 11.7; platform_machine == 'x86_64'
cython == 0.29.*; platform_machine == 'x86_64'
nvidia-cuda-runtime-cu11 == 11.7.*; platform_machine == 'x86_64'
nvidia-cublas-cu11 == 11.11.*; platform_machine == 'x86_64'
nvidia-cudnn-cu11 == 8.7.*; platform_machine == 'x86_64'


@ -1 +1,2 @@
scikit-build == 0.14.1
nvidia-pyindex