Fix CUDA graph config (#20135)

This commit is contained in:
Nicolas Mowen 2025-09-19 04:59:42 -06:00 committed by GitHub
parent 61d3b370b1
commit b8fd0a2b31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -420,16 +420,27 @@ def get_optimized_runner(
if device != "CPU" and is_openvino_gpu_npu_available():
return OpenVINOModelRunner(model_path, device, model_type, **kwargs)
ortSession = ort.InferenceSession(
model_path,
providers=providers,
provider_options=options,
)
if (
not CudaGraphRunner.is_complex_model(model_type)
and providers[0] == "CUDAExecutionProvider"
):
return CudaGraphRunner(ortSession, options[0]["device_id"])
options[0] = {
**options[0],
"enable_cuda_graph": True,
}
return CudaGraphRunner(
ort.InferenceSession(
model_path,
providers=providers,
provider_options=options,
),
options[0]["device_id"],
)
return ONNXModelRunner(ortSession)
return ONNXModelRunner(
ort.InferenceSession(
model_path,
providers=providers,
provider_options=options,
)
)