mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-21 19:07:46 +01:00
Add ability to configure model input dtype (#14659)
* Add input type for dtype * Add ability to manually enable TRT execution provider * Formatting
This commit is contained in:
parent
abd22d2566
commit
4e25bebdd0
@ -27,6 +27,11 @@ class InputTensorEnum(str, Enum):
|
||||
nhwc = "nhwc"
|
||||
|
||||
|
||||
class InputDTypeEnum(str, Enum):
    """Element type that a model expects for its input tensor.

    Mixing in ``str`` makes each member compare equal to its plain string
    value, so the values ``"float"`` and ``"int"`` map directly onto members.
    """

    # LocalObjectDetector.detect_raw casts the input to float32 for this value.
    float = "float"
    # No dtype conversion is applied for this value.
    int = "int"
|
||||
|
||||
|
||||
class ModelTypeEnum(str, Enum):
|
||||
ssd = "ssd"
|
||||
yolox = "yolox"
|
||||
@ -53,6 +58,9 @@ class ModelConfig(BaseModel):
|
||||
input_pixel_format: PixelFormatEnum = Field(
|
||||
default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
|
||||
)
|
||||
input_dtype: InputDTypeEnum = Field(
|
||||
default=InputDTypeEnum.int, title="Model Input D Type"
|
||||
)
|
||||
model_type: ModelTypeEnum = Field(
|
||||
default=ModelTypeEnum.ssd, title="Object Detection Model Type"
|
||||
)
|
||||
|
@ -54,7 +54,7 @@ class ONNXDetector(DetectionApi):
|
||||
|
||||
logger.info(f"ONNX: {path} loaded")
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
def detect_raw(self, tensor_input: np.ndarray):
|
||||
model_input_name = self.model.get_inputs()[0].name
|
||||
tensor_output = self.model.run(None, {model_input_name: tensor_input})
|
||||
|
||||
|
@ -12,7 +12,11 @@ from setproctitle import setproctitle
|
||||
|
||||
import frigate.util as util
|
||||
from frigate.detectors import create_detector
|
||||
from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum
|
||||
from frigate.detectors.detector_config import (
|
||||
BaseDetectorConfig,
|
||||
InputDTypeEnum,
|
||||
InputTensorEnum,
|
||||
)
|
||||
from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY
|
||||
from frigate.util.builtin import EventsPerSecond, load_labels
|
||||
from frigate.util.image import SharedMemoryFrameManager
|
||||
@ -55,12 +59,15 @@ class LocalObjectDetector(ObjectDetector):
|
||||
self.input_transform = tensor_transform(
|
||||
detector_config.model.input_tensor
|
||||
)
|
||||
|
||||
self.dtype = detector_config.model.input_dtype
|
||||
else:
|
||||
self.input_transform = None
|
||||
self.dtype = InputDTypeEnum.int
|
||||
|
||||
self.detect_api = create_detector(detector_config)
|
||||
|
||||
def detect(self, tensor_input, threshold=0.4):
|
||||
def detect(self, tensor_input: np.ndarray, threshold=0.4):
|
||||
detections = []
|
||||
|
||||
raw_detections = self.detect_raw(tensor_input)
|
||||
@ -77,9 +84,13 @@ class LocalObjectDetector(ObjectDetector):
|
||||
self.fps.update()
|
||||
return detections
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
def detect_raw(self, tensor_input: np.ndarray):
    """Prepare ``tensor_input`` and run it through the wrapped detector.

    The input is transposed when an input transform is configured, and cast
    to float32 when the configured input dtype is ``InputDTypeEnum.float``.
    The value returned by the detector's own ``detect_raw`` is passed back
    unchanged.
    """
    if self.input_transform:
        # Reorder the axes using the permutation stored at construction time.
        tensor_input = np.transpose(tensor_input, self.input_transform)

    if self.dtype == InputDTypeEnum.float:
        # Only the float dtype triggers a conversion; other inputs pass as-is.
        tensor_input = tensor_input.astype(np.float32)

    return self.detect_api.detect_raw(tensor_input=tensor_input)
|
||||
|
||||
|
||||
|
@ -13,7 +13,7 @@ except ImportError:
|
||||
|
||||
|
||||
def get_ort_providers(
|
||||
force_cpu: bool = False, openvino_device: str = "AUTO", requires_fp16: bool = False
|
||||
force_cpu: bool = False, device: str = "AUTO", requires_fp16: bool = False
|
||||
) -> tuple[list[str], list[dict[str, any]]]:
|
||||
if force_cpu:
|
||||
return (
|
||||
@ -38,7 +38,25 @@ def get_ort_providers(
|
||||
)
|
||||
elif provider == "TensorrtExecutionProvider":
|
||||
# TensorrtExecutionProvider uses too much memory without options to control it
|
||||
pass
|
||||
# so it is not enabled by default
|
||||
if device == "Tensorrt":
|
||||
os.makedirs(
|
||||
"/config/model_cache/tensorrt/ort/trt-engines", exist_ok=True
|
||||
)
|
||||
providers.append(provider)
|
||||
options.append(
|
||||
{
|
||||
"arena_extend_strategy": "kSameAsRequested",
|
||||
"trt_fp16_enable": requires_fp16
|
||||
and os.environ.get("USE_FP_16", "True") != "False",
|
||||
"trt_timing_cache_enable": True,
|
||||
"trt_engine_cache_enable": True,
|
||||
"trt_timing_cache_path": "/config/model_cache/tensorrt/ort",
|
||||
"trt_engine_cache_path": "/config/model_cache/tensorrt/ort/trt-engines",
|
||||
}
|
||||
)
|
||||
else:
|
||||
continue
|
||||
elif provider == "OpenVINOExecutionProvider":
|
||||
os.makedirs("/config/model_cache/openvino/ort", exist_ok=True)
|
||||
providers.append(provider)
|
||||
@ -46,7 +64,7 @@ def get_ort_providers(
|
||||
{
|
||||
"arena_extend_strategy": "kSameAsRequested",
|
||||
"cache_dir": "/config/model_cache/openvino/ort",
|
||||
"device_type": openvino_device,
|
||||
"device_type": device,
|
||||
}
|
||||
)
|
||||
elif provider == "CPUExecutionProvider":
|
||||
|
Loading…
Reference in New Issue
Block a user