about summary refs log tree commit diff
path: root/pkgs/development/python-modules/torch/gpu-checks.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/torch/gpu-checks.nix')
-rw-r--r--pkgs/development/python-modules/torch/gpu-checks.nix40
1 files changed, 40 insertions, 0 deletions
diff --git a/pkgs/development/python-modules/torch/gpu-checks.nix b/pkgs/development/python-modules/torch/gpu-checks.nix
new file mode 100644
index 0000000000000..55a4b45f71522
--- /dev/null
+++ b/pkgs/development/python-modules/torch/gpu-checks.nix
@@ -0,0 +1,40 @@
+{
+  lib,
+  torchWithCuda,
+  torchWithRocm,
+  callPackage,
+}:
+
+let
+  accelAvailable =
+    {
+      feature,
+      versionAttr,
+      torch,
+      cudaPackages,
+    }:
+    cudaPackages.writeGpuTestPython
+      {
+        inherit feature;
+        libraries = [ torch ];
+        name = "${feature}Available";
+      }
+      ''
+        import torch
+        message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}"
+        assert torch.cuda.is_available() and torch.version.${versionAttr}, message
+        print(message)
+      '';
+in
+{
+  tester-cudaAvailable = callPackage accelAvailable {
+    feature = "cuda";
+    versionAttr = "cuda";
+    torch = torchWithCuda;
+  };
+  tester-rocmAvailable = callPackage accelAvailable {
+    feature = "rocm";
+    versionAttr = "hip";
+    torch = torchWithRocm;
+  };
+}