about summary refs log tree commit diff
path: root/pkgs/development/cuda-modules/generic-builders
diff options
context:
space:
mode:
authorSomeone Serge <sergei.kozlukov@aalto.fi>2024-01-20 00:48:14 +0000
committerSomeone Serge <sergei.kozlukov@aalto.fi>2024-01-20 01:27:24 +0000
commit5cf0f04f95f73888da861e0c1249dd9b581f30ab (patch)
treef8a101479ce8cb5be6c5f9c823fd24346a527037 /pkgs/development/cuda-modules/generic-builders
parent9a33f8ce5ba7329600b296996d944ec9ddaa7d38 (diff)
cudaPackagesGoogle.cudnn_8_6: ensure present on all platforms
...including jetson

Previously:
❯ nix eval -f . --argstr system aarch64-linux --arg config '{ cudaCapabilities = [ "6.2" ]; cudaEnableForwardCompat = false; cudaSupport = true; allowUnfree = true; }' -L cudaPackagesGoogle.cudnn_8_6.outPath
attribute ... in selection path .... not found

Now:
❯ nix eval -f . --argstr system aarch64-linux --arg config '{ cudaCapabilities = [ "6.2" ]; cudaEnableForwardCompat = false; cudaSupport = true; allowUnfree = true; }' -L cudaPackagesGoogle.cudnn_8_6.outPath
       error: Package ‘cudnn-8.6.0.163’ in ... is not available on the requested hostPlatform:
Diffstat (limited to 'pkgs/development/cuda-modules/generic-builders')
-rw-r--r--pkgs/development/cuda-modules/generic-builders/manifest.nix6
-rw-r--r--pkgs/development/cuda-modules/generic-builders/multiplex.nix18
2 files changed, 14 insertions, 10 deletions
diff --git a/pkgs/development/cuda-modules/generic-builders/manifest.nix b/pkgs/development/cuda-modules/generic-builders/manifest.nix
index 38b80ff29f77f..5e837fa36b5e6 100644
--- a/pkgs/development/cuda-modules/generic-builders/manifest.nix
+++ b/pkgs/development/cuda-modules/generic-builders/manifest.nix
@@ -47,6 +47,8 @@ let
   # The redistArch is the name of the architecture for which the redistributable is built.
   # It is `"unsupported"` if the redistributable is not supported on the target platform.
   redistArch = flags.getRedistArch hostPlatform.system;
+
+  sourceMatchesHost = flags.getNixSystem redistArch == stdenv.hostPlatform.system;
 in
 backendStdenv.mkDerivation (
   finalAttrs: {
@@ -136,7 +138,9 @@ backendStdenv.mkDerivation (
     # badPlatformsConditions :: AttrSet Bool
     # Sets `meta.badPlatforms = meta.platforms` if any of the conditions are true.
     # Example: Broken on a specific architecture when some condition is met (like targeting Jetson).
-    badPlatformsConditions = { };
+    badPlatformsConditions = {
+      "No source" = !sourceMatchesHost;
+    };
 
     # src :: Optional Derivation
     src = trivial.pipe redistArch [
diff --git a/pkgs/development/cuda-modules/generic-builders/multiplex.nix b/pkgs/development/cuda-modules/generic-builders/multiplex.nix
index c1838ac0cf576..abe8ad242a3a3 100644
--- a/pkgs/development/cuda-modules/generic-builders/multiplex.nix
+++ b/pkgs/development/cuda-modules/generic-builders/multiplex.nix
@@ -52,7 +52,7 @@ let
   # - Package: ../modules/${pname}/releases/package.nix
 
   # FIXME: do this at the module system level
-  propagatePlatforms = lib.mapAttrs (platform: subset: map (r: r // { inherit platform; }) subset);
+  propagatePlatforms = lib.mapAttrs (redistArch: packages: map (p: { inherit redistArch; } // p) packages);
 
   # All releases across all platforms
   # See ../modules/${pname}/releases/releases.nix
@@ -67,8 +67,7 @@ let
   # isSupported :: Package -> Bool
   isSupported =
     package:
-    # The `platform` attribute of the package is NVIDIA's name for a redistributable architecture.
-    redistArch == package.platform
+    redistArch == package.redistArch
     && strings.versionAtLeast cudaVersion package.minCudaVersion
     && strings.versionAtLeast package.maxCudaVersion cudaVersion;
 
@@ -77,12 +76,15 @@ let
   # Value is `"unsupported"` if the platform is not supported.
   redistArch = flags.getRedistArch hostPlatform.system;
 
-  allReleases = lists.flatten (builtins.attrValues releaseSets);
+  preferable =
+    p1: p2: (isSupported p2 -> isSupported p1) && (strings.versionAtLeast p1.version p2.version);
 
   # All the supported packages we can build for our platform.
   # perSystemReleases :: List Package
-  perSystemReleases = lib.pipe (releaseSets.${redistArch} or [ ])
+  allReleases = lib.pipe releaseSets
     [
+      (builtins.attrValues)
+      (lists.flatten)
       (builtins.groupBy (p: lib.versions.majorMinor p.version))
       (builtins.mapAttrs (_: builtins.sort preferable))
       (builtins.mapAttrs (_: lib.take 1))
@@ -90,8 +92,6 @@ let
       (builtins.concatMap lib.trivial.id)
     ];
 
-  preferable =
-    p1: p2: (isSupported p2 -> isSupported p1) && (strings.versionAtLeast p1.version p2.version);
   newest = builtins.head (builtins.sort preferable allReleases);
 
   # A function which takes the `final` overlay and the `package` being built and returns
@@ -115,7 +115,7 @@ let
       buildPackage =
         package:
         let
-          shims = final.callPackage shimsFn {inherit package redistArch;};
+          shims = final.callPackage shimsFn {inherit package; inherit (package) redistArch; };
           name = computeName package;
           drv = final.callPackage ./manifest.nix {
             inherit pname;
@@ -127,7 +127,7 @@ let
         attrsets.nameValuePair name fixedDrv;
 
       # versionedDerivations :: AttrSet Derivation
-      versionedDerivations = builtins.listToAttrs (lists.map buildPackage perSystemReleases);
+      versionedDerivations = builtins.listToAttrs (lists.map buildPackage allReleases);
 
       defaultDerivation = { ${pname} = (buildPackage newest).value; };
     in