* Catch bird classification resize error

* Improve openvino width detection

* Use auto by default

* Set type
This commit is contained in:
Nicolas Mowen 2025-09-16 16:06:51 -06:00 committed by GitHub
parent 975c8485f9
commit 26178444f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 37 additions and 10 deletions

View File

@ -128,7 +128,11 @@ class BirdRealTimeProcessor(RealTimeProcessorApi):
] ]
if input.shape != (224, 224): if input.shape != (224, 224):
try:
input = cv2.resize(input, (224, 224)) input = cv2.resize(input, (224, 224))
except Exception:
logger.warning("Failed to resize image for bird classification")
return
input = np.expand_dims(input, axis=0) input = np.expand_dims(input, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input) self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)

View File

@ -133,8 +133,12 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
x:x2, x:x2,
] ]
if frame.shape != (224, 224): if input.shape != (224, 224):
frame = cv2.resize(frame, (224, 224)) try:
input = cv2.resize(input, (224, 224))
except Exception:
logger.warning("Failed to resize image for state classification")
return
input = np.expand_dims(frame, axis=0) input = np.expand_dims(frame, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input) self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
@ -254,8 +258,12 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
x:x2, x:x2,
] ]
if crop.shape != (224, 224): if input.shape != (224, 224):
crop = cv2.resize(crop, (224, 224)) try:
input = cv2.resize(input, (224, 224))
except Exception:
logger.warning("Failed to resize image for object classification")
return
input = np.expand_dims(crop, axis=0) input = np.expand_dims(crop, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input) self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)

View File

@ -195,9 +195,24 @@ class OpenVINOModelRunner(BaseModelRunner):
def get_input_width(self) -> int: def get_input_width(self) -> int:
"""Get the input width of the model.""" """Get the input width of the model."""
input_shape = self.compiled_model.inputs[0].get_shape() input_info = self.compiled_model.inputs
# Assuming NCHW format, width is the last dimension first_input = input_info[0]
return int(input_shape[-1])
try:
partial_shape = first_input.get_partial_shape()
# width dimension
if len(partial_shape) >= 4 and partial_shape[3].is_static:
return partial_shape[3].get_length()
# If width is dynamic or we can't determine it
return -1
except Exception:
try:
# gemini says some ov versions might still allow this
input_shape = first_input.shape
return input_shape[3] if len(input_shape) >= 4 else -1
except Exception:
return -1
def run(self, inputs: dict[str, Any]) -> list[np.ndarray]: def run(self, inputs: dict[str, Any]) -> list[np.ndarray]:
"""Run inference with the model. """Run inference with the model.
@ -354,7 +369,7 @@ class RKNNModelRunner(BaseModelRunner):
def get_optimized_runner( def get_optimized_runner(
model_path: str, device: str, complex_model: bool = True, **kwargs model_path: str, device: str | None, complex_model: bool = True, **kwargs
) -> BaseModelRunner: ) -> BaseModelRunner:
"""Get an optimized runner for the hardware.""" """Get an optimized runner for the hardware."""
if is_rknn_compatible(model_path): if is_rknn_compatible(model_path):
@ -364,7 +379,7 @@ def get_optimized_runner(
return RKNNModelRunner(rknn_path) return RKNNModelRunner(rknn_path)
if device != "CPU" and is_openvino_gpu_npu_available(): if device != "CPU" and is_openvino_gpu_npu_available():
return OpenVINOModelRunner(model_path, device, **kwargs) return OpenVINOModelRunner(model_path, device or "AUTO", **kwargs)
providers, options = get_ort_providers(device == "CPU", device, **kwargs) providers, options = get_ort_providers(device == "CPU", device, **kwargs)
ortSession = ort.InferenceSession( ortSession = ort.InferenceSession(