From a98269893363c4fb7743333a9901083be8f7ddc8 Mon Sep 17 00:00:00 2001 From: Indrek Mandre Date: Sun, 4 Feb 2024 22:56:57 +0200 Subject: [PATCH] frigate/detectors: renamed yolov8_preprocess->preprocess, pass input tensor element type --- frigate/detectors/plugins/onnx.py | 2 +- frigate/detectors/plugins/openvino.py | 5 ++++- frigate/detectors/plugins/rocm.py | 2 +- frigate/detectors/yolo_utils.py | 17 +++++++++++++---- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/frigate/detectors/plugins/onnx.py b/frigate/detectors/plugins/onnx.py index c98665d4d..428b68078 100644 --- a/frigate/detectors/plugins/onnx.py +++ b/frigate/detectors/plugins/onnx.py @@ -51,7 +51,7 @@ class ONNXDetector(DetectionApi): model_input_name = self.model.get_inputs()[0].name model_input_shape = self.model.get_inputs()[0].shape - tensor_input = yolo_utils.yolov8_preprocess(tensor_input, model_input_shape) + tensor_input = yolo_utils.preprocess(tensor_input, model_input_shape, np.float32) tensor_output = self.model.run(None, {model_input_name: tensor_input})[0] diff --git a/frigate/detectors/plugins/openvino.py b/frigate/detectors/plugins/openvino.py index 5cb1ea39c..5d748f1e7 100644 --- a/frigate/detectors/plugins/openvino.py +++ b/frigate/detectors/plugins/openvino.py @@ -8,6 +8,8 @@ from typing_extensions import Literal from frigate.detectors.detection_api import DetectionApi from frigate.detectors.detector_config import BaseDetectorConfig, ModelTypeEnum +import frigate.detectors.yolo_utils as yolo_utils + logger = logging.getLogger(__name__) DETECTOR_KEY = "openvino" @@ -33,7 +35,7 @@ class OvDetector(DetectionApi): model=self.ov_model, device_name=detector_config.device ) - logger.info(f"Model Input Shape: {self.interpreter.input(0).shape}") + logger.info(f"Model Input Shape: {self.interpreter.input(0).shape} {self.interpreter.input(0).element_type.to_dtype()}") self.output_indexes = 0 while True: @@ -80,6 +82,7 @@ class OvDetector(DetectionApi): ] def detect_raw(self, tensor_input): + tensor_input = yolo_utils.preprocess(tensor_input, self.interpreter.inputs[0].shape, self.interpreter.inputs[0].element_type.to_dtype()) infer_request = self.interpreter.create_infer_request() infer_request.infer([tensor_input]) diff --git a/frigate/detectors/plugins/rocm.py b/frigate/detectors/plugins/rocm.py index 6f2c3c1a8..d5d0ba585 100644 --- a/frigate/detectors/plugins/rocm.py +++ b/frigate/detectors/plugins/rocm.py @@ -102,7 +102,7 @@ class ROCmDetector(DetectionApi): model_input_name = self.model.get_parameter_names()[0]; model_input_shape = tuple(self.model.get_parameter_shapes()[model_input_name].lens()); - tensor_input = yolo_utils.yolov8_preprocess(tensor_input, model_input_shape) + tensor_input = yolo_utils.preprocess(tensor_input, model_input_shape, np.float32) detector_result = self.model.run({model_input_name: tensor_input})[0] diff --git a/frigate/detectors/yolo_utils.py b/frigate/detectors/yolo_utils.py index b53f85a9d..02442ab65 100644 --- a/frigate/detectors/yolo_utils.py +++ b/frigate/detectors/yolo_utils.py @@ -5,11 +5,20 @@ import cv2 logger = logging.getLogger(__name__) -def yolov8_preprocess(tensor_input, model_input_shape): +def preprocess(tensor_input, model_input_shape, model_input_element_type): + model_input_shape = tuple(model_input_shape) + assert tensor_input.dtype == np.uint8, f'tensor_input.dtype: {tensor_input.dtype}' + if len(tensor_input.shape) == 3: + tensor_input = tensor_input[np.newaxis, :] + if model_input_element_type == np.uint8: + # nothing to do for uint8 model input + assert model_input_shape == tensor_input.shape, f'model_input_shape: {model_input_shape}, tensor_input.shape: {tensor_input.shape}' + return tensor_input + assert model_input_element_type == np.float32, f'model_input_element_type: {model_input_element_type}' # tensor_input must be nhwc - assert tensor_input.shape[3] == 3 - if tuple(tensor_input.shape[1:3]) != tuple(model_input_shape[2:4]): - logger.warn(f"yolov8_preprocess: tensor_input.shape {tensor_input.shape} and model_input_shape {model_input_shape} do not match!") + assert tensor_input.shape[3] == 3, f'tensor_input.shape: {tensor_input.shape}' + if tensor_input.shape[1:3] != model_input_shape[2:4]: + logger.warn(f"preprocess: tensor_input.shape {tensor_input.shape} and model_input_shape {model_input_shape} do not match!") # cv2.dnn.blobFromImage is faster than numpying it return cv2.dnn.blobFromImage(tensor_input[0], 1.0 / 255, (model_input_shape[3], model_input_shape[2]), None, swapRB=False)