Try to limit nvidia GPU queries to included GPUs (#5356)

* Try to limit nvidia GPU queries to included GPUs

* ignore non digit GPU indexes

* formatting

* Formatting

* Remove trailing spaces

---------

Co-authored-by: Nicolas Mowen <nickmowen213@gmail.com>
This commit is contained in:
jvrobert 2023-02-03 18:34:07 -07:00 committed by GitHub
parent db131d4971
commit 7083a5c9b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -926,6 +926,17 @@ def get_nvidia_gpu_stats() -> dict[str, str]:
"--format=csv", "--format=csv",
] ]
if (
"CUDA_VISIBLE_DEVICES" in os.environ
and os.environ["CUDA_VISIBLE_DEVICES"].isdigit()
):
nvidia_smi_command.extend(["--id", os.environ["CUDA_VISIBLE_DEVICES"]])
elif (
"NVIDIA_VISIBLE_DEVICES" in os.environ
and os.environ["NVIDIA_VISIBLE_DEVICES"].isdigit()
):
nvidia_smi_command.extend(["--id", os.environ["NVIDIA_VISIBLE_DEVICES"]])
p = sp.run( p = sp.run(
nvidia_smi_command, nvidia_smi_command,
encoding="ascii", encoding="ascii",