diff options
Diffstat (limited to 'pkgs/development/python-modules/torch/gpu-checks.nix')
-rw-r--r-- | pkgs/development/python-modules/torch/gpu-checks.nix | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/pkgs/development/python-modules/torch/gpu-checks.nix b/pkgs/development/python-modules/torch/gpu-checks.nix new file mode 100644 index 0000000000000..55a4b45f71522 --- /dev/null +++ b/pkgs/development/python-modules/torch/gpu-checks.nix @@ -0,0 +1,40 @@ +{ + lib, + torchWithCuda, + torchWithRocm, + callPackage, +}: + +let + accelAvailable = + { + feature, + versionAttr, + torch, + cudaPackages, + }: + cudaPackages.writeGpuTestPython + { + inherit feature; + libraries = [ torch ]; + name = "${feature}Available"; + } + '' + import torch + message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}" + assert torch.cuda.is_available() and torch.version.${versionAttr}, message + print(message) + ''; +in +{ + tester-cudaAvailable = callPackage accelAvailable { + feature = "cuda"; + versionAttr = "cuda"; + torch = torchWithCuda; + }; + tester-rocmAvailable = callPackage accelAvailable { + feature = "rocm"; + versionAttr = "hip"; + torch = torchWithRocm; + }; +} |