Add ability to configure model input dtype (#14659)

* Add input type for dtype

* Add ability to manually enable TRT execution provider

* Formatting
Nicolas Mowen 2024-10-29 09:28:05 -06:00 committed by GitHub
parent abd22d2566
commit 4e25bebdd0
4 changed files with 44 additions and 7 deletions


@@ -27,6 +27,11 @@ class InputTensorEnum(str, Enum):
     nhwc = "nhwc"
 
 
+class InputDTypeEnum(str, Enum):
+    float = "float"
+    int = "int"
+
+
 class ModelTypeEnum(str, Enum):
     ssd = "ssd"
     yolox = "yolox"
@@ -53,6 +58,9 @@ class ModelConfig(BaseModel):
     input_pixel_format: PixelFormatEnum = Field(
         default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
     )
+    input_dtype: InputDTypeEnum = Field(
+        default=InputDTypeEnum.int, title="Model Input D Type"
+    )
     model_type: ModelTypeEnum = Field(
         default=ModelTypeEnum.ssd, title="Object Detection Model Type"
     )
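For illustration, a minimal sketch of how the new field behaves, assuming the remaining ModelConfig fields keep their defaults (pydantic coerces the string form into the str-backed enum):

    from frigate.detectors.detector_config import InputDTypeEnum, ModelConfig

    # the default stays int, so existing configs are unaffected
    cfg = ModelConfig()
    assert cfg.input_dtype == InputDTypeEnum.int

    # opting a model into float input
    cfg = ModelConfig(input_dtype="float")
    assert cfg.input_dtype == InputDTypeEnum.float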


@@ -54,7 +54,7 @@ class ONNXDetector(DetectionApi):
         logger.info(f"ONNX: {path} loaded")
 
-    def detect_raw(self, tensor_input):
+    def detect_raw(self, tensor_input: np.ndarray):
         model_input_name = self.model.get_inputs()[0].name
         tensor_output = self.model.run(None, {model_input_name: tensor_input})
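For reference, the same call pattern in isolation; a hedged sketch assuming onnxruntime is installed, with a hypothetical model.onnx taking a 1x3x320x320 input:

    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession("model.onnx")  # hypothetical model path
    input_name = session.get_inputs()[0].name  # same lookup as the plugin above
    tensor = np.zeros((1, 3, 320, 320), dtype=np.float32)
    outputs = session.run(None, {input_name: tensor})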


@@ -12,7 +12,11 @@ from setproctitle import setproctitle
 import frigate.util as util
 from frigate.detectors import create_detector
-from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum
+from frigate.detectors.detector_config import (
+    BaseDetectorConfig,
+    InputDTypeEnum,
+    InputTensorEnum,
+)
 from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY
 from frigate.util.builtin import EventsPerSecond, load_labels
 from frigate.util.image import SharedMemoryFrameManager
@@ -55,12 +59,15 @@ class LocalObjectDetector(ObjectDetector):
             self.input_transform = tensor_transform(
                 detector_config.model.input_tensor
             )
+            self.dtype = detector_config.model.input_dtype
         else:
             self.input_transform = None
+            self.dtype = InputDTypeEnum.int
 
         self.detect_api = create_detector(detector_config)
 
-    def detect(self, tensor_input, threshold=0.4):
+    def detect(self, tensor_input: np.ndarray, threshold=0.4):
         detections = []
 
         raw_detections = self.detect_raw(tensor_input)
@@ -77,9 +84,13 @@ class LocalObjectDetector(ObjectDetector):
         self.fps.update()
 
         return detections
 
-    def detect_raw(self, tensor_input):
+    def detect_raw(self, tensor_input: np.ndarray):
         if self.input_transform:
             tensor_input = np.transpose(tensor_input, self.input_transform)
 
+        if self.dtype == InputDTypeEnum.float:
+            tensor_input = tensor_input.astype(np.float32)
+
         return self.detect_api.detect_raw(tensor_input=tensor_input)
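The float branch is a plain numpy cast; a minimal sketch of the effect (shape and values are illustrative only):

    import numpy as np

    tensor_input = np.zeros((1, 320, 320, 3), dtype=np.uint8)  # frame data arrives as int
    # equivalent of the InputDTypeEnum.float branch above
    tensor_input = tensor_input.astype(np.float32)
    assert tensor_input.dtype == np.float32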


@@ -13,7 +13,7 @@ except ImportError:
 
 def get_ort_providers(
-    force_cpu: bool = False, openvino_device: str = "AUTO", requires_fp16: bool = False
+    force_cpu: bool = False, device: str = "AUTO", requires_fp16: bool = False
 ) -> tuple[list[str], list[dict[str, any]]]:
     if force_cpu:
         return (
@@ -38,7 +38,25 @@ def get_ort_providers(
             )
         elif provider == "TensorrtExecutionProvider":
             # TensorrtExecutionProvider uses too much memory without options to control it
-            pass
+            # so it is not enabled by default
+            if device == "Tensorrt":
+                os.makedirs(
+                    "/config/model_cache/tensorrt/ort/trt-engines", exist_ok=True
+                )
+
+                providers.append(provider)
+                options.append(
+                    {
+                        "arena_extend_strategy": "kSameAsRequested",
+                        "trt_fp16_enable": requires_fp16
+                        and os.environ.get("USE_FP_16", "True") != "False",
+                        "trt_timing_cache_enable": True,
+                        "trt_engine_cache_enable": True,
+                        "trt_timing_cache_path": "/config/model_cache/tensorrt/ort",
+                        "trt_engine_cache_path": "/config/model_cache/tensorrt/ort/trt-engines",
+                    }
+                )
+            else:
+                continue
         elif provider == "OpenVINOExecutionProvider":
             os.makedirs("/config/model_cache/openvino/ort", exist_ok=True)
             providers.append(provider)
@@ -46,7 +64,7 @@ def get_ort_providers(
                 {
                     "arena_extend_strategy": "kSameAsRequested",
                     "cache_dir": "/config/model_cache/openvino/ort",
-                    "device_type": openvino_device,
+                    "device_type": device,
                 }
             )
         elif provider == "CPUExecutionProvider":