mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-12-19 19:06:16 +01:00
Fix ROCm inference (#13988)
This commit is contained in:
parent
4a1da3ebc5
commit
a5595189ed
@ -4,6 +4,7 @@ import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from typing_extensions import Literal
|
||||
@ -12,6 +13,7 @@ from frigate.detectors.detection_api import DetectionApi
|
||||
from frigate.detectors.detector_config import (
|
||||
BaseDetectorConfig,
|
||||
ModelTypeEnum,
|
||||
PixelFormatEnum,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -123,6 +125,17 @@ class ROCmDetector(DetectionApi):
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
model_input_name = self.model.get_parameter_names()[0]
|
||||
model_input_name = self.model.get_inputs()[0].name
|
||||
model_input_shape = self.model.get_inputs()[0].shape
|
||||
|
||||
tensor_input = cv2.dnn.blobFromImage(
|
||||
tensor_input[0],
|
||||
1.0,
|
||||
(model_input_shape[3], model_input_shape[2]),
|
||||
None,
|
||||
swapRB=self.rocm_model_px == PixelFormatEnum.bgr,
|
||||
).astype(np.uint8)
|
||||
|
||||
detector_result = self.model.run({model_input_name: tensor_input})[0]
|
||||
addr = ctypes.cast(detector_result.data_ptr(), ctypes.POINTER(ctypes.c_float))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user