2024-01-25 19:14:22 +01:00
import logging
import sys
import os
import numpy as np
import ctypes
from pydantic import Field
from typing_extensions import Literal
import glob
2024-01-26 09:30:01 +01:00
import cv2
2024-01-25 19:14:22 +01:00
from frigate . detectors . detection_api import DetectionApi
from frigate . detectors . detector_config import BaseDetectorConfig
2024-01-27 21:35:32 +01:00
import frigate . detectors . yolo_utils as yolo_utils
2024-01-25 19:14:22 +01:00
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Value of the detector config's `type` field that selects this detector.
# It must not carry surrounding whitespace: it is used as a Literal type for
# the config's `type` field and as ROCmDetector.type_key.
DETECTOR_KEY = "rocm"
class ROCmDetectorConfig(BaseDetectorConfig):
    """Configuration schema for the AMD/ROCm detector.

    Adds nothing to ``BaseDetectorConfig`` beyond pinning the ``type`` field
    to ``DETECTOR_KEY``.
    """

    type: Literal[DETECTOR_KEY]
class ROCmDetector(DetectionApi):
    """Detector that runs a YOLOv8 model on an AMD GPU via ``migraphx``.

    On construction the configured model is either loaded from an already
    parsed ``.mxr`` file or parsed from its source format, compiled for the
    ``gpu`` target and saved under ``/config/model_cache/rocm/`` so later
    start-ups can skip parsing and compiling.
    """

    type_key = DETECTOR_KEY

    def __init__(self, detector_config: ROCmDetectorConfig):
        """Import ``migraphx`` and load (or parse, compile and cache) the model.

        Args:
            detector_config: Detector configuration; ``model.path``,
                ``model.model_type``, ``model.input_tensor`` and
                ``model.input_pixel_format`` are read.

        Raises:
            ModuleNotFoundError: ``migraphx`` cannot be imported.
            AssertionError: The model type is not ``yolov8``, the input
                tensor layout is not ``nhwc``, or no ``model.path`` is set.
            Exception: ``model.path`` has an unsupported file extension.
        """
        try:
            # migraphx is imported from /opt/rocm/lib (presumably where the
            # ROCm installation keeps its Python bindings). Only add the
            # directory once so repeated instantiation does not grow sys.path.
            rocm_lib_dir = "/opt/rocm/lib"
            if rocm_lib_dir not in sys.path:
                sys.path.append(rocm_lib_dir)
            import migraphx

            logger.info("AMD/ROCm: loaded migraphx module")
        except ModuleNotFoundError:
            logger.error("AMD/ROCm: module loading failed, missing ROCm environment?")
            raise

        assert (
            detector_config.model.model_type == "yolov8"
        ), "AMD/ROCm: detector_config.model.model_type: only yolov8 supported"
        assert (
            detector_config.model.input_tensor == "nhwc"
        ), "AMD/ROCm: detector_config.model.input_tensor: only nhwc supported"
        if detector_config.model.input_pixel_format != "rgb":
            # A mismatch is only reported, not rejected.
            logger.warning(
                "AMD/ROCm: detector_config.model.input_pixel_format: "
                "should be 'rgb' for yolov8, but "
                f"'{detector_config.model.input_pixel_format}' specified!"
            )

        # The message is only built when the assertion fails; it lists
        # candidate model and label files found in the filesystem root.
        assert detector_config.model.path is not None, (
            "No model.path configured, please configure model.path and "
            "model.labelmap_path; some suggestions: "
            + ", ".join(glob.glob("/*.onnx"))
            + " and "
            + ", ".join(glob.glob("/*_labels.txt"))
        )

        path = detector_config.model.path
        # Cache file: same base name as the configured model, ".mxr" suffix.
        mxr_path = "/config/model_cache/rocm/" + os.path.basename(
            os.path.splitext(path)[0] + ".mxr"
        )

        if path.endswith(".mxr"):
            # The configured file is already a parsed model: load it directly,
            # falling back to the cached file of the same name if the
            # configured path does not exist.
            parsed_path = path if os.path.exists(path) else mxr_path
            logger.info(f"AMD/ROCm: loading parsed model from {parsed_path}")
            self.model = migraphx.load(parsed_path)
        elif os.path.exists(mxr_path):
            logger.info(f"AMD/ROCm: loading parsed model from {mxr_path}")
            self.model = migraphx.load(mxr_path)
        else:
            logger.info(f"AMD/ROCm: loading model from {path}")
            if path.endswith(".onnx"):
                self.model = migraphx.parse_onnx(path)
            elif path.endswith((".tf", ".tf2", ".tflite")):
                # untested
                self.model = migraphx.parse_tf(path)
            else:
                raise Exception(f"AMD/ROCm: unkown model format {path}")
            logger.info("AMD/ROCm: compiling the model")
            self.model.compile(
                migraphx.get_target("gpu"), offload_copy=True, fast_math=True
            )
            # Persist the compiled model so the next start-up takes the
            # cached branch above.
            logger.info(f"AMD/ROCm: saving parsed model into {mxr_path}")
            os.makedirs("/config/model_cache/rocm", exist_ok=True)
            migraphx.save(self.model, mxr_path)
        logger.info("AMD/ROCm: model loaded")

    def detect_raw(self, tensor_input):
        """Run one inference pass and return the post-processed detections.

        Args:
            tensor_input: Input frame tensor; it is handed to
                ``yolo_utils.yolov8_preprocess`` together with the model's
                input shape before being fed to the model.

        Returns:
            Whatever ``yolo_utils.yolov8_postprocess`` returns for the
            model's first output.
        """
        # The first model parameter is used as the (single) model input.
        model_input_name = self.model.get_parameter_names()[0]
        model_input_shape = tuple(
            self.model.get_parameter_shapes()[model_input_name].lens()
        )
        tensor_input = yolo_utils.yolov8_preprocess(tensor_input, model_input_shape)
        detector_result = self.model.run({model_input_name: tensor_input})[0]
        # Wrap the result buffer as a float array through its raw pointer.
        # NOTE(review): as_array on a pointer shares memory rather than
        # copying, so detector_result must stay referenced until
        # post-processing has finished (it does: it is a local of this call).
        addr = ctypes.cast(detector_result.data_ptr(), ctypes.POINTER(ctypes.c_float))
        tensor_output = np.ctypeslib.as_array(
            addr, shape=detector_result.get_shape().lens()
        )
        return yolo_utils.yolov8_postprocess(model_input_shape, tensor_output)
2024-01-25 19:14:22 +01:00