diff --git a/frigate/detectors/detector_config.py b/frigate/detectors/detector_config.py index c40ef65bf..45875e2e6 100644 --- a/frigate/detectors/detector_config.py +++ b/frigate/detectors/detector_config.py @@ -27,6 +27,11 @@ class InputTensorEnum(str, Enum): nhwc = "nhwc" +class InputDTypeEnum(str, Enum): + float = "float" + int = "int" + + class ModelTypeEnum(str, Enum): ssd = "ssd" yolox = "yolox" @@ -53,6 +58,9 @@ class ModelConfig(BaseModel): input_pixel_format: PixelFormatEnum = Field( default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format" ) + input_dtype: InputDTypeEnum = Field( + default=InputDTypeEnum.int, title="Model Input D Type" + ) model_type: ModelTypeEnum = Field( default=ModelTypeEnum.ssd, title="Object Detection Model Type" ) diff --git a/frigate/detectors/plugins/onnx.py b/frigate/detectors/plugins/onnx.py index 3e58df72a..7004f28fa 100644 --- a/frigate/detectors/plugins/onnx.py +++ b/frigate/detectors/plugins/onnx.py @@ -54,7 +54,7 @@ class ONNXDetector(DetectionApi): logger.info(f"ONNX: {path} loaded") - def detect_raw(self, tensor_input): + def detect_raw(self, tensor_input: np.ndarray): model_input_name = self.model.get_inputs()[0].name tensor_output = self.model.run(None, {model_input_name: tensor_input}) diff --git a/frigate/object_detection.py b/frigate/object_detection.py index eaa3b4e04..0af32034e 100644 --- a/frigate/object_detection.py +++ b/frigate/object_detection.py @@ -12,7 +12,11 @@ from setproctitle import setproctitle import frigate.util as util from frigate.detectors import create_detector -from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum +from frigate.detectors.detector_config import ( + BaseDetectorConfig, + InputDTypeEnum, + InputTensorEnum, +) from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY from frigate.util.builtin import EventsPerSecond, load_labels from frigate.util.image import SharedMemoryFrameManager @@ -55,12 +59,15 @@ class LocalObjectDetector(ObjectDetector): self.input_transform = tensor_transform( detector_config.model.input_tensor ) + + self.dtype = detector_config.model.input_dtype else: self.input_transform = None + self.dtype = InputDTypeEnum.int self.detect_api = create_detector(detector_config) - def detect(self, tensor_input, threshold=0.4): + def detect(self, tensor_input: np.ndarray, threshold=0.4): detections = [] raw_detections = self.detect_raw(tensor_input) @@ -77,9 +84,13 @@ class LocalObjectDetector(ObjectDetector): self.fps.update() return detections - def detect_raw(self, tensor_input): + def detect_raw(self, tensor_input: np.ndarray): if self.input_transform: tensor_input = np.transpose(tensor_input, self.input_transform) + + if self.dtype == InputDTypeEnum.float: + tensor_input = tensor_input.astype(np.float32) + return self.detect_api.detect_raw(tensor_input=tensor_input) diff --git a/frigate/util/model.py b/frigate/util/model.py index 7aefe8b42..2aa06d0b2 100644 --- a/frigate/util/model.py +++ b/frigate/util/model.py @@ -13,7 +13,7 @@ except ImportError: def get_ort_providers( - force_cpu: bool = False, openvino_device: str = "AUTO", requires_fp16: bool = False + force_cpu: bool = False, device: str = "AUTO", requires_fp16: bool = False ) -> tuple[list[str], list[dict[str, any]]]: if force_cpu: return ( @@ -38,7 +38,25 @@ def get_ort_providers( ) elif provider == "TensorrtExecutionProvider": # TensorrtExecutionProvider uses too much memory without options to control it - pass + # so it is not enabled by default + if device == "Tensorrt": + os.makedirs( + "/config/model_cache/tensorrt/ort/trt-engines", exist_ok=True + ) + providers.append(provider) + options.append( + { + "arena_extend_strategy": "kSameAsRequested", + "trt_fp16_enable": requires_fp16 + and os.environ.get("USE_FP_16", "True") != "False", + "trt_timing_cache_enable": True, + "trt_engine_cache_enable": True, + "trt_timing_cache_path": "/config/model_cache/tensorrt/ort", + "trt_engine_cache_path": "/config/model_cache/tensorrt/ort/trt-engines", + } + ) + else: + continue elif provider == "OpenVINOExecutionProvider": os.makedirs("/config/model_cache/openvino/ort", exist_ok=True) providers.append(provider) @@ -46,7 +64,7 @@ def get_ort_providers( { "arena_extend_strategy": "kSameAsRequested", "cache_dir": "/config/model_cache/openvino/ort", - "device_type": openvino_device, + "device_type": device, } ) elif provider == "CPUExecutionProvider":