mirror of
				https://github.com/blakeblackshear/frigate.git
				synced 2025-10-27 10:52:11 +01:00 
			
		
		
		
	* Add support for TensorRT v10 (multiple api calls have changed) * Remove unnecessary size check in TensorRT v10 block * Refactor to reduce code duplication * Fix wrong function name in new _get_binding_dtype function and only return input check (not assertion) in new _binding_is_input function * Add space around TRT_VERSION variable assignment (=) to respect linting * More linting fix * Update frigate/detectors/plugins/tensorrt.py Co-authored-by: Nicolas Mowen <nickmowen213@gmail.com> * More linting --------- Co-authored-by: Nicolas Mowen <nickmowen213@gmail.com>
		
			
				
	
	
		
			341 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			341 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import ctypes
 | 
						|
import logging
 | 
						|
 | 
						|
import numpy as np
 | 
						|
 | 
						|
try:
 | 
						|
    import tensorrt as trt
 | 
						|
    from cuda import cuda
 | 
						|
 | 
						|
    TRT_VERSION = int(trt.__version__[0 : trt.__version__.find(".")])
 | 
						|
 | 
						|
    TRT_SUPPORT = True
 | 
						|
except ModuleNotFoundError:
 | 
						|
    TRT_SUPPORT = False
 | 
						|
 | 
						|
from pydantic import Field
 | 
						|
from typing_extensions import Literal
 | 
						|
 | 
						|
from frigate.detectors.detection_api import DetectionApi
 | 
						|
from frigate.detectors.detector_config import BaseDetectorConfig
 | 
						|
 | 
						|
logger = logging.getLogger(__name__)
 | 
						|
 | 
						|
DETECTOR_KEY = "tensorrt"
 | 
						|
 | 
						|
if TRT_SUPPORT:
 | 
						|
 | 
						|
    class TrtLogger(trt.ILogger):
 | 
						|
        def __init__(self):
 | 
						|
            trt.ILogger.__init__(self)
 | 
						|
 | 
						|
        def log(self, severity, msg):
 | 
						|
            logger.log(self.getSeverity(severity), msg)
 | 
						|
 | 
						|
        def getSeverity(self, sev: trt.ILogger.Severity) -> int:
 | 
						|
            if sev == trt.ILogger.VERBOSE:
 | 
						|
                return logging.DEBUG
 | 
						|
            elif sev == trt.ILogger.INFO:
 | 
						|
                return logging.INFO
 | 
						|
            elif sev == trt.ILogger.WARNING:
 | 
						|
                return logging.WARNING
 | 
						|
            elif sev == trt.ILogger.ERROR:
 | 
						|
                return logging.ERROR
 | 
						|
            elif sev == trt.ILogger.INTERNAL_ERROR:
 | 
						|
                return logging.CRITICAL
 | 
						|
            else:
 | 
						|
                return logging.DEBUG
 | 
						|
 | 
						|
 | 
						|
class TensorRTDetectorConfig(BaseDetectorConfig):
 | 
						|
    type: Literal[DETECTOR_KEY]
 | 
						|
    device: int = Field(default=0, title="GPU Device Index")
 | 
						|
 | 
						|
 | 
						|
class HostDeviceMem(object):
 | 
						|
    """Simple helper data class that's a little nicer to use than a 2-tuple."""
 | 
						|
 | 
						|
    def __init__(self, host_mem, device_mem, nbytes, size):
 | 
						|
        self.host = host_mem
 | 
						|
        err, self.host_dev = cuda.cuMemHostGetDevicePointer(self.host, 0)
 | 
						|
        self.device = device_mem
 | 
						|
        self.nbytes = nbytes
 | 
						|
        self.size = size
 | 
						|
 | 
						|
    def __str__(self):
 | 
						|
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
 | 
						|
 | 
						|
    def __repr__(self):
 | 
						|
        return self.__str__()
 | 
						|
 | 
						|
    def __del__(self):
 | 
						|
        cuda.cuMemFreeHost(self.host)
 | 
						|
        cuda.cuMemFree(self.device)
 | 
						|
 | 
						|
 | 
						|
class TensorRtDetector(DetectionApi):
 | 
						|
    type_key = DETECTOR_KEY
 | 
						|
 | 
						|
    def _load_engine(self, model_path):
 | 
						|
        try:
 | 
						|
            trt.init_libnvinfer_plugins(self.trt_logger, "")
 | 
						|
 | 
						|
            ctypes.cdll.LoadLibrary("/usr/local/lib/libyolo_layer.so")
 | 
						|
        except OSError as e:
 | 
						|
            logger.error(
 | 
						|
                "ERROR: failed to load libraries. %s",
 | 
						|
                e,
 | 
						|
            )
 | 
						|
 | 
						|
        with open(model_path, "rb") as f, trt.Runtime(self.trt_logger) as runtime:
 | 
						|
            return runtime.deserialize_cuda_engine(f.read())
 | 
						|
 | 
						|
    def _binding_is_input(self, binding):
 | 
						|
        if TRT_VERSION < 10:
 | 
						|
            return self.engine.binding_is_input(binding)
 | 
						|
        else:
 | 
						|
            return binding == "input"
 | 
						|
 | 
						|
    def _get_binding_dims(self, binding):
 | 
						|
        if TRT_VERSION < 10:
 | 
						|
            return self.engine.get_binding_shape(binding)
 | 
						|
        else:
 | 
						|
            return self.engine.get_tensor_shape(binding)
 | 
						|
 | 
						|
    def _get_binding_dtype(self, binding):
 | 
						|
        if TRT_VERSION < 10:
 | 
						|
            return self.engine.get_binding_dtype(binding)
 | 
						|
        else:
 | 
						|
            return self.engine.get_tensor_dtype(binding)
 | 
						|
 | 
						|
    def _execute(self):
 | 
						|
        if TRT_VERSION < 10:
 | 
						|
            return self.context.execute_async_v2(
 | 
						|
                bindings=self.bindings, stream_handle=self.stream
 | 
						|
            )
 | 
						|
        else:
 | 
						|
            return self.context.execute_v2(self.bindings)
 | 
						|
 | 
						|
    def _get_input_shape(self):
 | 
						|
        """Get input shape of the TensorRT YOLO engine."""
 | 
						|
        binding = self.engine[0]
 | 
						|
        assert self._binding_is_input(binding)
 | 
						|
        binding_dims = self._get_binding_dims(binding)
 | 
						|
        if len(binding_dims) == 4:
 | 
						|
            return (
 | 
						|
                tuple(binding_dims[2:]),
 | 
						|
                trt.nptype(self._get_binding_dtype(binding)),
 | 
						|
            )
 | 
						|
        elif len(binding_dims) == 3:
 | 
						|
            return (
 | 
						|
                tuple(binding_dims[1:]),
 | 
						|
                trt.nptype(self._get_binding_dtype(binding)),
 | 
						|
            )
 | 
						|
        else:
 | 
						|
            raise ValueError(
 | 
						|
                "bad dims of binding %s: %s" % (binding, str(binding_dims))
 | 
						|
            )
 | 
						|
 | 
						|
    def _allocate_buffers(self):
 | 
						|
        """Allocates all host/device in/out buffers required for an engine."""
 | 
						|
        inputs = []
 | 
						|
        outputs = []
 | 
						|
        bindings = []
 | 
						|
        output_idx = 0
 | 
						|
        for binding in self.engine:
 | 
						|
            binding_dims = self._get_binding_dims(binding)
 | 
						|
            if len(binding_dims) == 4:
 | 
						|
                # explicit batch case (TensorRT 7+)
 | 
						|
                size = trt.volume(binding_dims)
 | 
						|
            elif len(binding_dims) == 3:
 | 
						|
                # implicit batch case (TensorRT 6 or older)
 | 
						|
                size = trt.volume(binding_dims) * self.engine.max_batch_size
 | 
						|
            else:
 | 
						|
                raise ValueError(
 | 
						|
                    "bad dims of binding %s: %s" % (binding, str(binding_dims))
 | 
						|
                )
 | 
						|
            nbytes = size * self._get_binding_dtype(binding).itemsize
 | 
						|
            # Allocate host and device buffers
 | 
						|
            err, host_mem = cuda.cuMemHostAlloc(
 | 
						|
                nbytes, Flags=cuda.CU_MEMHOSTALLOC_DEVICEMAP
 | 
						|
            )
 | 
						|
            assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAllocHost returned {err}"
 | 
						|
            logger.debug(
 | 
						|
                f"Allocated Tensor Binding {binding} Memory {nbytes} Bytes ({size} * {self._get_binding_dtype(binding)})"
 | 
						|
            )
 | 
						|
            err, device_mem = cuda.cuMemAlloc(nbytes)
 | 
						|
            assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAlloc returned {err}"
 | 
						|
            # Append the device buffer to device bindings.
 | 
						|
            bindings.append(int(device_mem))
 | 
						|
            # Append to the appropriate list.
 | 
						|
            if self._binding_is_input(binding):
 | 
						|
                logger.debug(f"Input has Shape {binding_dims}")
 | 
						|
                inputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
 | 
						|
            else:
 | 
						|
                # each grid has 3 anchors, each anchor generates a detection
 | 
						|
                # output of 7 float32 values
 | 
						|
                assert size % 7 == 0, f"output size was {size}"
 | 
						|
                logger.debug(f"Output has Shape {binding_dims}")
 | 
						|
                outputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
 | 
						|
                output_idx += 1
 | 
						|
        assert len(inputs) == 1, f"inputs len was {len(inputs)}"
 | 
						|
        assert len(outputs) == 1, f"output len was {len(outputs)}"
 | 
						|
        return inputs, outputs, bindings
 | 
						|
 | 
						|
    def _do_inference(self):
 | 
						|
        """do_inference (for TensorRT 7.0+)
 | 
						|
        This function is generalized for multiple inputs/outputs for full
 | 
						|
        dimension networks.
 | 
						|
        Inputs and outputs are expected to be lists of HostDeviceMem objects.
 | 
						|
        """
 | 
						|
        # Push CUDA Context
 | 
						|
        cuda.cuCtxPushCurrent(self.cu_ctx)
 | 
						|
 | 
						|
        # Transfer input data to the GPU.
 | 
						|
        [
 | 
						|
            cuda.cuMemcpyHtoDAsync(inp.device, inp.host, inp.nbytes, self.stream)
 | 
						|
            for inp in self.inputs
 | 
						|
        ]
 | 
						|
 | 
						|
        # Run inference.
 | 
						|
        if not self._execute():
 | 
						|
            logger.warn("Execute returned false")
 | 
						|
 | 
						|
        # Transfer predictions back from the GPU.
 | 
						|
        [
 | 
						|
            cuda.cuMemcpyDtoHAsync(out.host, out.device, out.nbytes, self.stream)
 | 
						|
            for out in self.outputs
 | 
						|
        ]
 | 
						|
 | 
						|
        # Synchronize the stream
 | 
						|
        cuda.cuStreamSynchronize(self.stream)
 | 
						|
 | 
						|
        # Pop CUDA Context
 | 
						|
        cuda.cuCtxPopCurrent()
 | 
						|
 | 
						|
        # Return only the host outputs.
 | 
						|
        return [
 | 
						|
            np.array(
 | 
						|
                (ctypes.c_float * out.size).from_address(out.host), dtype=np.float32
 | 
						|
            )
 | 
						|
            for out in self.outputs
 | 
						|
        ]
 | 
						|
 | 
						|
    def __init__(self, detector_config: TensorRTDetectorConfig):
 | 
						|
        assert (
 | 
						|
            TRT_SUPPORT
 | 
						|
        ), f"TensorRT libraries not found, {DETECTOR_KEY} detector not present"
 | 
						|
 | 
						|
        (cuda_err,) = cuda.cuInit(0)
 | 
						|
        assert (
 | 
						|
            cuda_err == cuda.CUresult.CUDA_SUCCESS
 | 
						|
        ), f"Failed to initialize cuda {cuda_err}"
 | 
						|
        err, dev_count = cuda.cuDeviceGetCount()
 | 
						|
        logger.debug(f"Num Available Devices: {dev_count}")
 | 
						|
        assert (
 | 
						|
            detector_config.device < dev_count
 | 
						|
        ), f"Invalid TensorRT Device Config. Device {detector_config.device} Invalid."
 | 
						|
        err, self.cu_ctx = cuda.cuCtxCreate(
 | 
						|
            cuda.CUctx_flags.CU_CTX_MAP_HOST, detector_config.device
 | 
						|
        )
 | 
						|
 | 
						|
        self.conf_th = 0.4  ##TODO: model config parameter
 | 
						|
        self.nms_threshold = 0.4
 | 
						|
        err, self.stream = cuda.cuStreamCreate(0)
 | 
						|
        self.trt_logger = TrtLogger()
 | 
						|
        self.engine = self._load_engine(detector_config.model.path)
 | 
						|
        self.input_shape = self._get_input_shape()
 | 
						|
 | 
						|
        try:
 | 
						|
            self.context = self.engine.create_execution_context()
 | 
						|
            (
 | 
						|
                self.inputs,
 | 
						|
                self.outputs,
 | 
						|
                self.bindings,
 | 
						|
            ) = self._allocate_buffers()
 | 
						|
        except Exception as e:
 | 
						|
            logger.error(e)
 | 
						|
            raise RuntimeError("fail to allocate CUDA resources") from e
 | 
						|
 | 
						|
        logger.debug("TensorRT loaded. Input shape is %s", self.input_shape)
 | 
						|
        logger.debug("TensorRT version is %s", trt.__version__[0])
 | 
						|
 | 
						|
    def __del__(self):
 | 
						|
        """Free CUDA memories."""
 | 
						|
        if self.outputs is not None:
 | 
						|
            del self.outputs
 | 
						|
        if self.inputs is not None:
 | 
						|
            del self.inputs
 | 
						|
        if self.stream is not None:
 | 
						|
            cuda.cuStreamDestroy(self.stream)
 | 
						|
            del self.stream
 | 
						|
        del self.engine
 | 
						|
        del self.context
 | 
						|
        del self.trt_logger
 | 
						|
        cuda.cuCtxDestroy(self.cu_ctx)
 | 
						|
 | 
						|
    def _postprocess_yolo(self, trt_outputs, conf_th):
 | 
						|
        """Postprocess TensorRT outputs.
 | 
						|
        # Args
 | 
						|
            trt_outputs: a list of 2 or 3 tensors, where each tensor
 | 
						|
                        contains a multiple of 7 float32 numbers in
 | 
						|
                        the order of [x, y, w, h, box_confidence, class_id, class_prob]
 | 
						|
            conf_th: confidence threshold
 | 
						|
        # Returns
 | 
						|
            boxes, scores, classes
 | 
						|
        """
 | 
						|
        # filter low-conf detections and concatenate results of all yolo layers
 | 
						|
        detections = []
 | 
						|
        for o in trt_outputs:
 | 
						|
            dets = o.reshape((-1, 7))
 | 
						|
            dets = dets[dets[:, 4] * dets[:, 6] >= conf_th]
 | 
						|
            detections.append(dets)
 | 
						|
        detections = np.concatenate(detections, axis=0)
 | 
						|
 | 
						|
        return detections
 | 
						|
 | 
						|
    def detect_raw(self, tensor_input):
 | 
						|
        # Input tensor has the shape of the [height, width, 3]
 | 
						|
        # Output tensor of float32 of shape [20, 6] where:
 | 
						|
        # O - class id
 | 
						|
        # 1 - score
 | 
						|
        # 2..5 - a value between 0 and 1 of the box: [top, left, bottom, right]
 | 
						|
 | 
						|
        # normalize
 | 
						|
        if self.input_shape[-1] != trt.int8:
 | 
						|
            tensor_input = tensor_input.astype(self.input_shape[-1])
 | 
						|
            tensor_input /= 255.0
 | 
						|
 | 
						|
        self.inputs[0].host = np.ascontiguousarray(
 | 
						|
            tensor_input.astype(self.input_shape[-1])
 | 
						|
        )
 | 
						|
        trt_outputs = self._do_inference()
 | 
						|
 | 
						|
        raw_detections = self._postprocess_yolo(trt_outputs, self.conf_th)
 | 
						|
 | 
						|
        if len(raw_detections) == 0:
 | 
						|
            return np.zeros((20, 6), np.float32)
 | 
						|
 | 
						|
        # raw_detections: Nx7 numpy arrays of
 | 
						|
        #             [[x, y, w, h, box_confidence, class_id, class_prob],
 | 
						|
 | 
						|
        # Calculate score as box_confidence x class_prob
 | 
						|
        raw_detections[:, 4] = raw_detections[:, 4] * raw_detections[:, 6]
 | 
						|
        # Reorder elements by the score, best on top, remove class_prob
 | 
						|
        ordered = raw_detections[raw_detections[:, 4].argsort()[::-1]][:, 0:6]
 | 
						|
        # transform width to right with clamp to 0..1
 | 
						|
        ordered[:, 2] = np.clip(ordered[:, 2] + ordered[:, 0], 0, 1)
 | 
						|
        # transform height to bottom with clamp to 0..1
 | 
						|
        ordered[:, 3] = np.clip(ordered[:, 3] + ordered[:, 1], 0, 1)
 | 
						|
        # put result into the correct order and limit to top 20
 | 
						|
        detections = ordered[:, [5, 4, 1, 0, 3, 2]][:20]
 | 
						|
 | 
						|
        # pad to 20x6 shape
 | 
						|
        append_cnt = 20 - len(detections)
 | 
						|
        if append_cnt > 0:
 | 
						|
            detections = np.append(
 | 
						|
                detections, np.zeros((append_cnt, 6), np.float32), axis=0
 | 
						|
            )
 | 
						|
 | 
						|
        return detections
 |