From 4ff81d58777b3c4a2e42081ed04bea8a96ded7c9 Mon Sep 17 00:00:00 2001 From: Simonas Kazlauskas Date: Fri, 20 Jun 2025 18:11:48 +0300 Subject: [PATCH] yolo nas: do not invalidate model when input shape is different (#18799) Model can be adjusted ahead of time to NHWC to avoid transpose on CPU, for example. All the model information is already presented in the configuration, and the stringent checks implemented in openvino are not present on other backends anyway. OpenVINO will properly report issues with mismatched layouts anyhow. --- frigate/detectors/plugins/openvino.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/frigate/detectors/plugins/openvino.py b/frigate/detectors/plugins/openvino.py index 70c2d4725..066b6d311 100644 --- a/frigate/detectors/plugins/openvino.py +++ b/frigate/detectors/plugins/openvino.py @@ -59,7 +59,6 @@ class OvDetector(DetectionApi): ) self.model_invalid = True - # Ensure the SSD model has the right input and output shapes if self.ov_model_type == ModelTypeEnum.ssd: model_inputs = self.interpreter.inputs model_outputs = self.interpreter.outputs @@ -75,12 +74,6 @@ class OvDetector(DetectionApi): ) self.model_invalid = True - if model_inputs[0].get_shape() != ov.Shape([1, self.w, self.h, 3]): - logger.error( - f"SSD model input doesn't match. Found {model_inputs[0].get_shape()}." - ) - self.model_invalid = True - output_shape = model_outputs[0].get_shape() if output_shape[0] != 1 or output_shape[1] != 1 or output_shape[3] != 7: logger.error(f"SSD model output doesn't match. Found {output_shape}.") @@ -100,13 +93,6 @@ class OvDetector(DetectionApi): f"YoloNAS models must be exported in flat format and only have 1 output. Found {len(model_outputs)}." ) self.model_invalid = True - - if model_inputs[0].get_shape() != ov.Shape([1, 3, self.w, self.h]): - logger.error( - f"YoloNAS model input doesn't match. Found {model_inputs[0].get_shape()}, but expected {[1, 3, self.w, self.h]}." - ) - self.model_invalid = True - output_shape = model_outputs[0].partial_shape if output_shape[-1] != 7: logger.error(