Add ability to configure model input dtype (#14659)

* Add input type for dtype

* Add ability to manually enable TRT execution provider

* Formatting
This commit is contained in:
Nicolas Mowen
2024-10-29 09:28:05 -06:00
committed by GitHub
parent abd22d2566
commit 4e25bebdd0
4 changed files with 44 additions and 7 deletions

View File

@@ -13,7 +13,7 @@ except ImportError:
def get_ort_providers(
-force_cpu: bool = False, openvino_device: str = "AUTO", requires_fp16: bool = False
+force_cpu: bool = False, device: str = "AUTO", requires_fp16: bool = False
) -> tuple[list[str], list[dict[str, any]]]:
if force_cpu:
return (
@@ -38,7 +38,25 @@ def get_ort_providers(
)
elif provider == "TensorrtExecutionProvider":
# TensorrtExecutionProvider uses too much memory without options to control it
-pass
+# so it is not enabled by default
+if device == "Tensorrt":
+os.makedirs(
+"/config/model_cache/tensorrt/ort/trt-engines", exist_ok=True
+)
+providers.append(provider)
+options.append(
+{
+"arena_extend_strategy": "kSameAsRequested",
+"trt_fp16_enable": requires_fp16
+and os.environ.get("USE_FP_16", "True") != "False",
+"trt_timing_cache_enable": True,
+"trt_engine_cache_enable": True,
+"trt_timing_cache_path": "/config/model_cache/tensorrt/ort",
+"trt_engine_cache_path": "/config/model_cache/tensorrt/ort/trt-engines",
+}
+)
+else:
+continue
elif provider == "OpenVINOExecutionProvider":
os.makedirs("/config/model_cache/openvino/ort", exist_ok=True)
providers.append(provider)
@@ -46,7 +64,7 @@ def get_ort_providers(
{
"arena_extend_strategy": "kSameAsRequested",
"cache_dir": "/config/model_cache/openvino/ort",
-"device_type": openvino_device,
+"device_type": device,
}
)
elif provider == "CPUExecutionProvider":