about summary refs log tree commit diff
path: root/pkgs/development/python-modules/torch/tests.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/torch/tests.nix')
-rw-r--r--pkgs/development/python-modules/torch/tests.nix30
1 files changed, 29 insertions, 1 deletions
diff --git a/pkgs/development/python-modules/torch/tests.nix b/pkgs/development/python-modules/torch/tests.nix
index 5a46d0886868c..e3f2ca44ba5a9 100644
--- a/pkgs/development/python-modules/torch/tests.nix
+++ b/pkgs/development/python-modules/torch/tests.nix
@@ -1,3 +1,31 @@
 { callPackage }:
 
-callPackage ./gpu-checks.nix { }
+rec {
+  # To perform the runtime check use either
+  # `nix run .#python3Packages.torch.tests.tester-cudaAvailable` (outside the sandbox), or
+  # `nix build .#python3Packages.torch.tests.tester-cudaAvailable.gpuCheck` (in a relaxed sandbox)
+  tester-cudaAvailable = callPackage ./mk-runtime-check.nix {
+    feature = "cuda";
+    versionAttr = "cuda";
+    libraries = ps: [ ps.torchWithCuda ];
+  };
+  tester-rocmAvailable = callPackage ./mk-runtime-check.nix {
+    feature = "rocm";
+    versionAttr = "hip";
+    libraries = ps: [ ps.torchWithRocm ];
+  };
+
+  compileCpu = tester-compileCpu.gpuCheck;
+  tester-compileCpu = callPackage ./mk-torch-compile-check.nix {
+    feature = null;
+    libraries = ps: [ ps.torch ];
+  };
+  tester-compileCuda = callPackage ./mk-torch-compile-check.nix {
+    feature = "cuda";
+    libraries = ps: [ ps.torchWithCuda ];
+  };
+  tester-compileRocm = callPackage ./mk-torch-compile-check.nix {
+    feature = "rocm";
+    libraries = ps: [ ps.torchWithRocm ];
+  };
+}