mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-21 19:07:46 +01:00
Fix ROCm inference (#13988)
This commit is contained in:
parent
4a1da3ebc5
commit
a5595189ed
@ -4,6 +4,7 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
@ -12,6 +13,7 @@ from frigate.detectors.detection_api import DetectionApi
|
|||||||
from frigate.detectors.detector_config import (
|
from frigate.detectors.detector_config import (
|
||||||
BaseDetectorConfig,
|
BaseDetectorConfig,
|
||||||
ModelTypeEnum,
|
ModelTypeEnum,
|
||||||
|
PixelFormatEnum,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -123,6 +125,17 @@ class ROCmDetector(DetectionApi):
|
|||||||
|
|
||||||
def detect_raw(self, tensor_input):
|
def detect_raw(self, tensor_input):
|
||||||
model_input_name = self.model.get_parameter_names()[0]
|
model_input_name = self.model.get_parameter_names()[0]
|
||||||
|
model_input_name = self.model.get_inputs()[0].name
|
||||||
|
model_input_shape = self.model.get_inputs()[0].shape
|
||||||
|
|
||||||
|
tensor_input = cv2.dnn.blobFromImage(
|
||||||
|
tensor_input[0],
|
||||||
|
1.0,
|
||||||
|
(model_input_shape[3], model_input_shape[2]),
|
||||||
|
None,
|
||||||
|
swapRB=self.rocm_model_px == PixelFormatEnum.bgr,
|
||||||
|
).astype(np.uint8)
|
||||||
|
|
||||||
detector_result = self.model.run({model_input_name: tensor_input})[0]
|
detector_result = self.model.run({model_input_name: tensor_input})[0]
|
||||||
addr = ctypes.cast(detector_result.data_ptr(), ctypes.POINTER(ctypes.c_float))
|
addr = ctypes.cast(detector_result.data_ptr(), ctypes.POINTER(ctypes.c_float))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user