Add support for specifying the TensorRT device (#14898)

This commit is contained in:
Nicolas Mowen 2024-11-10 07:43:24 -07:00 committed by GitHub
parent a68c7f4ef8
commit 96c0c43dc8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -33,10 +33,12 @@ def get_ort_providers(
for provider in ort.get_available_providers(): for provider in ort.get_available_providers():
if provider == "CUDAExecutionProvider": if provider == "CUDAExecutionProvider":
device_id = 0 if not device.isdigit() else int(device)
providers.append(provider) providers.append(provider)
options.append( options.append(
{ {
"arena_extend_strategy": "kSameAsRequested", "arena_extend_strategy": "kSameAsRequested",
"device_id": device_id,
} }
) )
elif provider == "TensorrtExecutionProvider": elif provider == "TensorrtExecutionProvider":
@ -46,10 +48,11 @@ def get_ort_providers(
os.makedirs( os.makedirs(
"/config/model_cache/tensorrt/ort/trt-engines", exist_ok=True "/config/model_cache/tensorrt/ort/trt-engines", exist_ok=True
) )
device_id = 0 if not device.isdigit() else int(device)
providers.append(provider) providers.append(provider)
options.append( options.append(
{ {
"arena_extend_strategy": "kSameAsRequested", "device_id": device_id,
"trt_fp16_enable": requires_fp16 "trt_fp16_enable": requires_fp16
and os.environ.get("USE_FP_16", "True") != "False", and os.environ.get("USE_FP_16", "True") != "False",
"trt_timing_cache_enable": True, "trt_timing_cache_enable": True,