about summary refs log tree commit diff
path: root/pkgs/top-level/cuda-packages.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/top-level/cuda-packages.nix')
-rw-r--r--pkgs/top-level/cuda-packages.nix15
1 files changed, 14 insertions, 1 deletions
diff --git a/pkgs/top-level/cuda-packages.nix b/pkgs/top-level/cuda-packages.nix
index a2f49a98ccd53..3912422785bc4 100644
--- a/pkgs/top-level/cuda-packages.nix
+++ b/pkgs/top-level/cuda-packages.nix
@@ -24,6 +24,7 @@ let
 
     buildCuTensorPackage = final.callPackage ../development/libraries/science/math/cutensor/generic.nix;
 
+    # FIXME: Include non-x86_64 platforms
     cuTensorVersions = {
       "1.2.2.5" = {
         hash = "sha256-lU7iK4DWuC/U3s1Ct/rq2Gr3w4F2U7RYYgpmF05bibY=";
@@ -31,12 +32,24 @@ let
       "1.5.0.3" = {
         hash = "sha256-T96+lPC6OTOkIs/z3QWg73oYVSyidN0SVkBWmT9VRx0=";
       };
+      "2.0.0.7" = {
+        hash = "sha256-32M4rtGOW2rgxJUhBT0WBtKkHhh9f17M+RgK9rvE72g=";
+      };
     };
 
     inherit (final) cudaMajorMinorVersion cudaMajorVersion;
 
+    cudaToCutensor = {
+      "10" = "1.2.25";
+      "11" = "1.5.0.3";
+      "12" = "2.0.0.7";
+    };
+
+    versionNewer = lib.flip lib.versionOlder;
+    latestVersion = (builtins.head (lib.sort versionNewer (builtins.attrNames cuTensorVersions)));
+
     cutensor = buildCuTensorPackage rec {
-      version = if cudaMajorMinorVersion == "10.1" then "1.2.2.5" else "1.5.0.3";
+      version = cudaToCutensor.${cudaMajorVersion} or latestVersion;
       inherit (cuTensorVersions.${version}) hash;
       # This can go into generic.nix
       libPath = "lib/${if cudaMajorVersion == "10" then cudaMajorMinorVersion else cudaMajorVersion}";