From 26178444f38c289f2b9a36062126d9f31d7f9484 Mon Sep 17 00:00:00 2001
From: Nicolas Mowen
Date: Tue, 16 Sep 2025 16:06:51 -0600
Subject: [PATCH] Fixes (#20102)

* Catch bird classification resize error

* Improve openvino width detection

* Use auto by default

* Set type
---
 Reviewer note (ignored by `git am`): in custom_classification.py the resize
 guard operates on `frame` / `crop`. The hunks as first posted tested and
 resized `input`, but the very next line rebuilds `input` from `frame` / `crop`,
 so the resized array was discarded (and `input` appears not to be bound yet at
 that point -- confirm against the full file). Because the added lines differ
 from the posted ones, the post-image blob id on that file's `index` line is
 stale. Indentation below was reconstructed; verify against the target tree.

 frigate/data_processing/real_time/bird.py     |  6 ++++-
 .../real_time/custom_classification.py        | 16 +++++++++---
 frigate/detectors/detection_runners.py        | 25 +++++++++++++++----
 3 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/frigate/data_processing/real_time/bird.py b/frigate/data_processing/real_time/bird.py
index ed2496b90..848860230 100644
--- a/frigate/data_processing/real_time/bird.py
+++ b/frigate/data_processing/real_time/bird.py
@@ -128,7 +128,11 @@ class BirdRealTimeProcessor(RealTimeProcessorApi):
         ]

         if input.shape != (224, 224):
-            input = cv2.resize(input, (224, 224))
+            try:
+                input = cv2.resize(input, (224, 224))
+            except Exception:
+                logger.warning("Failed to resize image for bird classification")
+                return

         input = np.expand_dims(input, axis=0)
         self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py
index b62b29882..daa9fee96 100644
--- a/frigate/data_processing/real_time/custom_classification.py
+++ b/frigate/data_processing/real_time/custom_classification.py
@@ -133,8 +133,12 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
             x:x2,
         ]

-        if frame.shape != (224, 224):
-            frame = cv2.resize(frame, (224, 224))
+        if frame.shape != (224, 224):
+            try:
+                frame = cv2.resize(frame, (224, 224))
+            except Exception:
+                logger.warning("Failed to resize image for state classification")
+                return

         input = np.expand_dims(frame, axis=0)
         self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
@@ -254,8 +258,12 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
             x:x2,
         ]

-        if crop.shape != (224, 224):
-            crop = cv2.resize(crop, (224, 224))
+        if crop.shape != (224, 224):
+            try:
+                crop = cv2.resize(crop, (224, 224))
+            except Exception:
+                logger.warning("Failed to resize image for object classification")
+                return

         input = np.expand_dims(crop, axis=0)
         self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
diff --git a/frigate/detectors/detection_runners.py b/frigate/detectors/detection_runners.py
index 145fe79b7..f6928c481 100644
--- a/frigate/detectors/detection_runners.py
+++ b/frigate/detectors/detection_runners.py
@@ -195,9 +195,24 @@ class OpenVINOModelRunner(BaseModelRunner):

     def get_input_width(self) -> int:
         """Get the input width of the model."""
-        input_shape = self.compiled_model.inputs[0].get_shape()
-        # Assuming NCHW format, width is the last dimension
-        return int(input_shape[-1])
+        input_info = self.compiled_model.inputs
+        first_input = input_info[0]
+
+        try:
+            partial_shape = first_input.get_partial_shape()
+            # width dimension
+            if len(partial_shape) >= 4 and partial_shape[3].is_static:
+                return partial_shape[3].get_length()
+
+            # If width is dynamic or we can't determine it
+            return -1
+        except Exception:
+            try:
+                # gemini says some ov versions might still allow this
+                input_shape = first_input.shape
+                return input_shape[3] if len(input_shape) >= 4 else -1
+            except Exception:
+                return -1

     def run(self, inputs: dict[str, Any]) -> list[np.ndarray]:
         """Run inference with the model.
@@ -354,7 +369,7 @@ class RKNNModelRunner(BaseModelRunner):


 def get_optimized_runner(
-    model_path: str, device: str, complex_model: bool = True, **kwargs
+    model_path: str, device: str | None, complex_model: bool = True, **kwargs
 ) -> BaseModelRunner:
     """Get an optimized runner for the hardware."""
     if is_rknn_compatible(model_path):
@@ -364,7 +379,7 @@ def get_optimized_runner(
             return RKNNModelRunner(rknn_path)

     if device != "CPU" and is_openvino_gpu_npu_available():
-        return OpenVINOModelRunner(model_path, device, **kwargs)
+        return OpenVINOModelRunner(model_path, device or "AUTO", **kwargs)

     providers, options = get_ort_providers(device == "CPU", device, **kwargs)
     ortSession = ort.InferenceSession(