about summary refs log tree commit diff
path: root/pkgs/development/cuda-modules/generic-builders/multiplex.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/cuda-modules/generic-builders/multiplex.nix')
-rw-r--r--pkgs/development/cuda-modules/generic-builders/multiplex.nix25
1 files changed, 16 insertions, 9 deletions
diff --git a/pkgs/development/cuda-modules/generic-builders/multiplex.nix b/pkgs/development/cuda-modules/generic-builders/multiplex.nix
index b8ac84bda9133..abe8ad242a3a3 100644
--- a/pkgs/development/cuda-modules/generic-builders/multiplex.nix
+++ b/pkgs/development/cuda-modules/generic-builders/multiplex.nix
@@ -52,7 +52,7 @@ let
   # - Package: ../modules/${pname}/releases/package.nix
 
   # FIXME: do this at the module system level
-  propagatePlatforms = lib.mapAttrs (platform: subset: map (r: r // { inherit platform; }) subset);
+  propagatePlatforms = lib.mapAttrs (redistArch: packages: map (p: { inherit redistArch; } // p) packages);
 
   # All releases across all platforms
   # See ../modules/${pname}/releases/releases.nix
@@ -67,8 +67,7 @@ let
   # isSupported :: Package -> Bool
   isSupported =
     package:
-    # The `platform` attribute of the package is NVIDIA's name for a redistributable architecture.
-    redistArch == package.platform
+    redistArch == package.redistArch
     && strings.versionAtLeast cudaVersion package.minCudaVersion
     && strings.versionAtLeast package.maxCudaVersion cudaVersion;
 
@@ -77,14 +76,22 @@ let
   # Value is `"unsupported"` if the platform is not supported.
   redistArch = flags.getRedistArch hostPlatform.system;
 
-  allReleases = lists.flatten (builtins.attrValues releaseSets);
+  preferable =
+    p1: p2: (isSupported p2 -> isSupported p1) && (strings.versionAtLeast p1.version p2.version);
 
   # All the supported packages we can build for our platform.
   # perSystemReleases :: List Package
-  perSystemReleases = releaseSets.${redistArch} or [ ];
+  allReleases = lib.pipe releaseSets
+    [
+      (builtins.attrValues)
+      (lists.flatten)
+      (builtins.groupBy (p: lib.versions.majorMinor p.version))
+      (builtins.mapAttrs (_: builtins.sort preferable))
+      (builtins.mapAttrs (_: lib.take 1))
+      (builtins.attrValues)
+      (builtins.concatMap lib.trivial.id)
+    ];
 
-  preferable =
-    p1: p2: (isSupported p2 -> isSupported p1) && (strings.versionAtLeast p1.version p2.version);
   newest = builtins.head (builtins.sort preferable allReleases);
 
   # A function which takes the `final` overlay and the `package` being built and returns
@@ -108,7 +115,7 @@ let
       buildPackage =
         package:
         let
-          shims = final.callPackage shimsFn {inherit package redistArch;};
+          shims = final.callPackage shimsFn {inherit package; inherit (package) redistArch; };
           name = computeName package;
           drv = final.callPackage ./manifest.nix {
             inherit pname;
@@ -120,7 +127,7 @@ let
         attrsets.nameValuePair name fixedDrv;
 
       # versionedDerivations :: AttrSet Derivation
-      versionedDerivations = builtins.listToAttrs (lists.map buildPackage perSystemReleases);
+      versionedDerivations = builtins.listToAttrs (lists.map buildPackage allReleases);
 
       defaultDerivation = { ${pname} = (buildPackage newest).value; };
     in