Fix ROCm inference (#13988)

This commit is contained in:
Nicolas Mowen 2024-09-26 11:16:26 -06:00 committed by GitHub
parent 4a1da3ebc5
commit a5595189ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,6 +4,7 @@ import os
import subprocess import subprocess
import sys import sys
import cv2
import numpy as np import numpy as np
from pydantic import Field from pydantic import Field
from typing_extensions import Literal from typing_extensions import Literal
@ -12,6 +13,7 @@ from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import ( from frigate.detectors.detector_config import (
BaseDetectorConfig, BaseDetectorConfig,
ModelTypeEnum, ModelTypeEnum,
PixelFormatEnum,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -123,6 +125,17 @@ class ROCmDetector(DetectionApi):
def detect_raw(self, tensor_input): def detect_raw(self, tensor_input):
model_input_name = self.model.get_parameter_names()[0] model_input_name = self.model.get_parameter_names()[0]
model_input_name = self.model.get_inputs()[0].name
model_input_shape = self.model.get_inputs()[0].shape
tensor_input = cv2.dnn.blobFromImage(
tensor_input[0],
1.0,
(model_input_shape[3], model_input_shape[2]),
None,
swapRB=self.rocm_model_px == PixelFormatEnum.bgr,
).astype(np.uint8)
detector_result = self.model.run({model_input_name: tensor_input})[0] detector_result = self.model.run({model_input_name: tensor_input})[0]
addr = ctypes.cast(detector_result.data_ptr(), ctypes.POINTER(ctypes.c_float)) addr = ctypes.cast(detector_result.data_ptr(), ctypes.POINTER(ctypes.c_float))