diff --git a/frigate/util/model.py b/frigate/util/model.py index 091bb0833..ce2c9538c 100644 --- a/frigate/util/model.py +++ b/frigate/util/model.py @@ -33,10 +33,12 @@ def get_ort_providers( for provider in ort.get_available_providers(): if provider == "CUDAExecutionProvider": + device_id = 0 if not device.isdigit() else int(device) providers.append(provider) options.append( { "arena_extend_strategy": "kSameAsRequested", + "device_id": device_id, } ) elif provider == "TensorrtExecutionProvider": @@ -46,10 +48,11 @@ def get_ort_providers( os.makedirs( "/config/model_cache/tensorrt/ort/trt-engines", exist_ok=True ) + device_id = 0 if not device.isdigit() else int(device) providers.append(provider) options.append( { - "arena_extend_strategy": "kSameAsRequested", + "device_id": device_id, "trt_fp16_enable": requires_fp16 and os.environ.get("USE_FP_16", "True") != "False", "trt_timing_cache_enable": True,