mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-21 19:07:46 +01:00
Add ability to configure model input dtype (#14659)
* Add input type for dtype * Add ability to manually enable TRT execution provider * Formatting
This commit is contained in:
parent
abd22d2566
commit
4e25bebdd0
@ -27,6 +27,11 @@ class InputTensorEnum(str, Enum):
|
|||||||
nhwc = "nhwc"
|
nhwc = "nhwc"
|
||||||
|
|
||||||
|
|
||||||
|
class InputDTypeEnum(str, Enum):
|
||||||
|
float = "float"
|
||||||
|
int = "int"
|
||||||
|
|
||||||
|
|
||||||
class ModelTypeEnum(str, Enum):
|
class ModelTypeEnum(str, Enum):
|
||||||
ssd = "ssd"
|
ssd = "ssd"
|
||||||
yolox = "yolox"
|
yolox = "yolox"
|
||||||
@ -53,6 +58,9 @@ class ModelConfig(BaseModel):
|
|||||||
input_pixel_format: PixelFormatEnum = Field(
|
input_pixel_format: PixelFormatEnum = Field(
|
||||||
default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
|
default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
|
||||||
)
|
)
|
||||||
|
input_dtype: InputDTypeEnum = Field(
|
||||||
|
default=InputDTypeEnum.int, title="Model Input D Type"
|
||||||
|
)
|
||||||
model_type: ModelTypeEnum = Field(
|
model_type: ModelTypeEnum = Field(
|
||||||
default=ModelTypeEnum.ssd, title="Object Detection Model Type"
|
default=ModelTypeEnum.ssd, title="Object Detection Model Type"
|
||||||
)
|
)
|
||||||
|
@ -54,7 +54,7 @@ class ONNXDetector(DetectionApi):
|
|||||||
|
|
||||||
logger.info(f"ONNX: {path} loaded")
|
logger.info(f"ONNX: {path} loaded")
|
||||||
|
|
||||||
def detect_raw(self, tensor_input):
|
def detect_raw(self, tensor_input: np.ndarray):
|
||||||
model_input_name = self.model.get_inputs()[0].name
|
model_input_name = self.model.get_inputs()[0].name
|
||||||
tensor_output = self.model.run(None, {model_input_name: tensor_input})
|
tensor_output = self.model.run(None, {model_input_name: tensor_input})
|
||||||
|
|
||||||
|
@ -12,7 +12,11 @@ from setproctitle import setproctitle
|
|||||||
|
|
||||||
import frigate.util as util
|
import frigate.util as util
|
||||||
from frigate.detectors import create_detector
|
from frigate.detectors import create_detector
|
||||||
from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum
|
from frigate.detectors.detector_config import (
|
||||||
|
BaseDetectorConfig,
|
||||||
|
InputDTypeEnum,
|
||||||
|
InputTensorEnum,
|
||||||
|
)
|
||||||
from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY
|
from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY
|
||||||
from frigate.util.builtin import EventsPerSecond, load_labels
|
from frigate.util.builtin import EventsPerSecond, load_labels
|
||||||
from frigate.util.image import SharedMemoryFrameManager
|
from frigate.util.image import SharedMemoryFrameManager
|
||||||
@ -55,12 +59,15 @@ class LocalObjectDetector(ObjectDetector):
|
|||||||
self.input_transform = tensor_transform(
|
self.input_transform = tensor_transform(
|
||||||
detector_config.model.input_tensor
|
detector_config.model.input_tensor
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.dtype = detector_config.model.input_dtype
|
||||||
else:
|
else:
|
||||||
self.input_transform = None
|
self.input_transform = None
|
||||||
|
self.dtype = InputDTypeEnum.int
|
||||||
|
|
||||||
self.detect_api = create_detector(detector_config)
|
self.detect_api = create_detector(detector_config)
|
||||||
|
|
||||||
def detect(self, tensor_input, threshold=0.4):
|
def detect(self, tensor_input: np.ndarray, threshold=0.4):
|
||||||
detections = []
|
detections = []
|
||||||
|
|
||||||
raw_detections = self.detect_raw(tensor_input)
|
raw_detections = self.detect_raw(tensor_input)
|
||||||
@ -77,9 +84,13 @@ class LocalObjectDetector(ObjectDetector):
|
|||||||
self.fps.update()
|
self.fps.update()
|
||||||
return detections
|
return detections
|
||||||
|
|
||||||
def detect_raw(self, tensor_input):
|
def detect_raw(self, tensor_input: np.ndarray):
|
||||||
if self.input_transform:
|
if self.input_transform:
|
||||||
tensor_input = np.transpose(tensor_input, self.input_transform)
|
tensor_input = np.transpose(tensor_input, self.input_transform)
|
||||||
|
|
||||||
|
if self.dtype == InputDTypeEnum.float:
|
||||||
|
tensor_input = tensor_input.astype(np.float32)
|
||||||
|
|
||||||
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
def get_ort_providers(
|
def get_ort_providers(
|
||||||
force_cpu: bool = False, openvino_device: str = "AUTO", requires_fp16: bool = False
|
force_cpu: bool = False, device: str = "AUTO", requires_fp16: bool = False
|
||||||
) -> tuple[list[str], list[dict[str, any]]]:
|
) -> tuple[list[str], list[dict[str, any]]]:
|
||||||
if force_cpu:
|
if force_cpu:
|
||||||
return (
|
return (
|
||||||
@ -38,7 +38,25 @@ def get_ort_providers(
|
|||||||
)
|
)
|
||||||
elif provider == "TensorrtExecutionProvider":
|
elif provider == "TensorrtExecutionProvider":
|
||||||
# TensorrtExecutionProvider uses too much memory without options to control it
|
# TensorrtExecutionProvider uses too much memory without options to control it
|
||||||
pass
|
# so it is not enabled by default
|
||||||
|
if device == "Tensorrt":
|
||||||
|
os.makedirs(
|
||||||
|
"/config/model_cache/tensorrt/ort/trt-engines", exist_ok=True
|
||||||
|
)
|
||||||
|
providers.append(provider)
|
||||||
|
options.append(
|
||||||
|
{
|
||||||
|
"arena_extend_strategy": "kSameAsRequested",
|
||||||
|
"trt_fp16_enable": requires_fp16
|
||||||
|
and os.environ.get("USE_FP_16", "True") != "False",
|
||||||
|
"trt_timing_cache_enable": True,
|
||||||
|
"trt_engine_cache_enable": True,
|
||||||
|
"trt_timing_cache_path": "/config/model_cache/tensorrt/ort",
|
||||||
|
"trt_engine_cache_path": "/config/model_cache/tensorrt/ort/trt-engines",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
elif provider == "OpenVINOExecutionProvider":
|
elif provider == "OpenVINOExecutionProvider":
|
||||||
os.makedirs("/config/model_cache/openvino/ort", exist_ok=True)
|
os.makedirs("/config/model_cache/openvino/ort", exist_ok=True)
|
||||||
providers.append(provider)
|
providers.append(provider)
|
||||||
@ -46,7 +64,7 @@ def get_ort_providers(
|
|||||||
{
|
{
|
||||||
"arena_extend_strategy": "kSameAsRequested",
|
"arena_extend_strategy": "kSameAsRequested",
|
||||||
"cache_dir": "/config/model_cache/openvino/ort",
|
"cache_dir": "/config/model_cache/openvino/ort",
|
||||||
"device_type": openvino_device,
|
"device_type": device,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif provider == "CPUExecutionProvider":
|
elif provider == "CPUExecutionProvider":
|
||||||
|
Loading…
Reference in New Issue
Block a user