From 592b64523153230698f99ce192f9499ed40fe1e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20B=C3=A9dard-Couture?= Date: Wed, 22 May 2024 08:57:52 -0400 Subject: [PATCH] Add support for TensorRT v10 (multiple api calls have changed) (#11166) * Add support for TensorRT v10 (multiple api calls have changed) * Remove unnecessary size check in TensorRT v10 block * Refactor to reduce code duplication * Fix wrong function name in new _get_binding_dtype function and only return input check (not assertion) in new _binding_is_input function * Add space around TRT_VERSION variable assignment (=) to respect linting * More linting fix * Update frigate/detectors/plugins/tensorrt.py Co-authored-by: Nicolas Mowen * More linting --------- Co-authored-by: Nicolas Mowen --- frigate/detectors/plugins/tensorrt.py | 48 +++++++++++++++++++++------ 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/frigate/detectors/plugins/tensorrt.py b/frigate/detectors/plugins/tensorrt.py index 2a57ec2d3..64b0849c7 100644 --- a/frigate/detectors/plugins/tensorrt.py +++ b/frigate/detectors/plugins/tensorrt.py @@ -7,6 +7,8 @@ try: import tensorrt as trt from cuda import cuda + TRT_VERSION = int(trt.__version__[0 : trt.__version__.find(".")]) + TRT_SUPPORT = True except ModuleNotFoundError: TRT_SUPPORT = False @@ -88,20 +90,46 @@ class TensorRtDetector(DetectionApi): with open(model_path, "rb") as f, trt.Runtime(self.trt_logger) as runtime: return runtime.deserialize_cuda_engine(f.read()) + def _binding_is_input(self, binding): + if TRT_VERSION < 10: + return self.engine.binding_is_input(binding) + else: + return binding == "input" + + def _get_binding_dims(self, binding): + if TRT_VERSION < 10: + return self.engine.get_binding_shape(binding) + else: + return self.engine.get_tensor_shape(binding) + + def _get_binding_dtype(self, binding): + if TRT_VERSION < 10: + return self.engine.get_binding_dtype(binding) + else: + return self.engine.get_tensor_dtype(binding) + + def _execute(self): + if TRT_VERSION < 10: + return self.context.execute_async_v2( + bindings=self.bindings, stream_handle=self.stream + ) + else: + return self.context.execute_v2(self.bindings) + def _get_input_shape(self): """Get input shape of the TensorRT YOLO engine.""" binding = self.engine[0] - assert self.engine.binding_is_input(binding) - binding_dims = self.engine.get_binding_shape(binding) + assert self._binding_is_input(binding) + binding_dims = self._get_binding_dims(binding) if len(binding_dims) == 4: return ( tuple(binding_dims[2:]), - trt.nptype(self.engine.get_binding_dtype(binding)), + trt.nptype(self._get_binding_dtype(binding)), ) elif len(binding_dims) == 3: return ( tuple(binding_dims[1:]), - trt.nptype(self.engine.get_binding_dtype(binding)), + trt.nptype(self._get_binding_dtype(binding)), ) else: raise ValueError( @@ -115,7 +143,7 @@ class TensorRtDetector(DetectionApi): bindings = [] output_idx = 0 for binding in self.engine: - binding_dims = self.engine.get_binding_shape(binding) + binding_dims = self._get_binding_dims(binding) if len(binding_dims) == 4: # explicit batch case (TensorRT 7+) size = trt.volume(binding_dims) @@ -126,21 +154,21 @@ class TensorRtDetector(DetectionApi): raise ValueError( "bad dims of binding %s: %s" % (binding, str(binding_dims)) ) - nbytes = size * self.engine.get_binding_dtype(binding).itemsize + nbytes = size * self._get_binding_dtype(binding).itemsize # Allocate host and device buffers err, host_mem = cuda.cuMemHostAlloc( nbytes, Flags=cuda.CU_MEMHOSTALLOC_DEVICEMAP ) assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAllocHost returned {err}" logger.debug( - f"Allocated Tensor Binding {binding} Memory {nbytes} Bytes ({size} * {self.engine.get_binding_dtype(binding)})" + f"Allocated Tensor Binding {binding} Memory {nbytes} Bytes ({size} * {self._get_binding_dtype(binding)})" ) err, device_mem = cuda.cuMemAlloc(nbytes) assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAlloc returned {err}" # Append the device buffer to device bindings. bindings.append(int(device_mem)) # Append to the appropriate list. - if self.engine.binding_is_input(binding): + if self._binding_is_input(binding): logger.debug(f"Input has Shape {binding_dims}") inputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size)) else: @@ -170,9 +198,7 @@ class TensorRtDetector(DetectionApi): ] # Run inference. - if not self.context.execute_async_v2( - bindings=self.bindings, stream_handle=self.stream - ): + if not self._execute(): logger.warn("Execute returned false") # Transfer predictions back from the GPU.