diff --git a/frigate/detectors/plugins/rocm.py b/frigate/detectors/plugins/rocm.py index aa49e0544..5203934e9 100644 --- a/frigate/detectors/plugins/rocm.py +++ b/frigate/detectors/plugins/rocm.py @@ -4,6 +4,7 @@ import os import subprocess import sys +import cv2 import numpy as np from pydantic import Field from typing_extensions import Literal @@ -12,6 +13,7 @@ from frigate.detectors.detection_api import DetectionApi from frigate.detectors.detector_config import ( BaseDetectorConfig, ModelTypeEnum, + PixelFormatEnum, ) logger = logging.getLogger(__name__) @@ -123,6 +125,17 @@ class ROCmDetector(DetectionApi): def detect_raw(self, tensor_input): model_input_name = self.model.get_parameter_names()[0] + model_input_name = self.model.get_inputs()[0].name + model_input_shape = self.model.get_inputs()[0].shape + + tensor_input = cv2.dnn.blobFromImage( + tensor_input[0], + 1.0, + (model_input_shape[3], model_input_shape[2]), + None, + swapRB=self.rocm_model_px == PixelFormatEnum.bgr, + ).astype(np.uint8) + detector_result = self.model.run({model_input_name: tensor_input})[0] addr = ctypes.cast(detector_result.data_ptr(), ctypes.POINTER(ctypes.c_float))