about summary refs log tree commit diff
path: root/pkgs/development/cuda-modules/flags.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/cuda-modules/flags.nix')
-rw-r--r--pkgs/development/cuda-modules/flags.nix49
1 files changed, 26 insertions, 23 deletions
diff --git a/pkgs/development/cuda-modules/flags.nix b/pkgs/development/cuda-modules/flags.nix
index 7b7922085ff66..139952bc9dfd9 100644
--- a/pkgs/development/cuda-modules/flags.nix
+++ b/pkgs/development/cuda-modules/flags.nix
@@ -53,11 +53,11 @@ let
   isDefault =
     gpu:
     let
-      inherit (gpu) dontDefaultAfter;
+      inherit (gpu) dontDefaultAfter isJetson;
       newGpu = dontDefaultAfter == null;
       recentGpu = newGpu || strings.versionAtLeast dontDefaultAfter cudaVersion;
     in
-    recentGpu;
+    recentGpu && !isJetson;
 
   # supportedGpus :: List Gpu
   # GPUs which are supported by the provided CUDA version.
@@ -100,11 +100,11 @@ let
   ];
 
   # Find the intersection with the user-specified list of cudaCapabilities.
-  # NOTE: Jetson devices are never built by default because they cannot be targeted along
+  # NOTE: Jetson devices are never built by default because they cannot be targeted along with
   # non-Jetson devices and require an aarch64 host platform. As such, if they're present anywhere,
   # they must be in the user-specified cudaCapabilities.
   # NOTE: We don't need to worry about mixes of Jetson and non-Jetson devices here -- there's
-  # sanity-checking for all that in cudaFlags.
+  # sanity-checking for all that in below.
   jetsonTargets = lists.intersectLists jetsonComputeCapabilities cudaCapabilities;
 
   # dropDot :: String -> String
@@ -146,14 +146,15 @@ let
       builtins.throw "Unsupported Nix system: ${nixSystem}";
 
   # Maps NVIDIA redist arch to Nix system.
+  # It is imperative that we include the boolean condition based on jetsonTargets to ensure
+  # we don't advertise availability of packages only available on server-grade ARM
+  # as being available for the Jetson, since both `linux-sbsa` and `linux-aarch64` are
+  # mapped to the Nix system `aarch64-linux`.
   getNixSystem =
     redistArch:
-    if
-      lists.elem redistArch [
-        "linux-aarch64"
-        "linux-sbsa"
-      ]
-    then
+    if redistArch == "linux-sbsa" && jetsonTargets == [] then
+      "aarch64-linux"
+    else if redistArch == "linux-aarch64" && jetsonTargets != [] then
       "aarch64-linux"
     else if redistArch == "linux-x86_64" then
       "x86_64-linux"
@@ -217,26 +218,28 @@ let
       # isJetsonBuild :: Boolean
       isJetsonBuild =
         let
-          # List of booleans representing whether any of the currently targeted capabilities are
-          # Jetson devices.
-          # isJetsons :: List Boolean
-          isJetsons =
-            lists.map (trivial.flip builtins.getAttr cudaComputeCapabilityToIsJetson)
+          requestedJetsonDevices =
+            lists.filter (cap: cudaComputeCapabilityToIsJetson.${cap})
+              cudaCapabilities;
+          requestedNonJetsonDevices =
+            lists.filter (cap: !(builtins.elem cap requestedJetsonDevices))
               cudaCapabilities;
-          anyJetsons = lists.any (trivial.id) isJetsons;
-          allJetsons = lists.all (trivial.id) isJetsons;
-          hostIsAarch64 = hostPlatform.isAarch64;
+          jetsonBuildSufficientCondition = requestedJetsonDevices != [];
+          jetsonBuildNecessaryCondition = requestedNonJetsonDevices == [] && hostPlatform.isAarch64;
         in
-        trivial.throwIfNot (anyJetsons -> (allJetsons && hostIsAarch64))
+        trivial.throwIf (jetsonBuildSufficientCondition && !jetsonBuildNecessaryCondition)
           ''
             Jetson devices cannot be targeted with non-Jetson devices. Additionally, they require hostPlatform to be aarch64.
             You requested ${builtins.toJSON cudaCapabilities} for host platform ${hostPlatform.system}.
+            Requested Jetson devices: ${builtins.toJSON requestedJetsonDevices}.
+            Requested non-Jetson devices: ${builtins.toJSON requestedNonJetsonDevices}.
             Exactly one of the following must be true:
-            - All CUDA capabilities belong to Jetson devices (${trivial.boolToString allJetsons}) and the hostPlatform is aarch64 (${trivial.boolToString hostIsAarch64}).
-            - No CUDA capabilities belong to Jetson devices (${trivial.boolToString (!anyJetsons)}).
+            - All CUDA capabilities belong to Jetson devices and hostPlatform is aarch64.
+            - No CUDA capabilities belong to Jetson devices.
             See ${./gpus.nix} for a list of architectures supported by this version of Nixpkgs.
           ''
-          allJetsons;
+          jetsonBuildSufficientCondition
+        && jetsonBuildNecessaryCondition;
     };
 in
 # When changing names or formats: pause, validate, and update the assert
@@ -283,7 +286,7 @@ assert let
   };
   actualWrapped = (builtins.tryEval (builtins.deepSeq actual actual)).value;
 in
-asserts.assertMsg (expected == actualWrapped) ''
+asserts.assertMsg ((strings.versionAtLeast cudaVersion "11.2") -> (expected == actualWrapped)) ''
   This test should only fail when using a version of CUDA older than 11.2, the first to support
   8.6.
   Expected: ${builtins.toJSON expected}