about summary refs log tree commit diff
path: root/pkgs/development/cuda-modules/nccl
diff options
context:
space:
mode:
authorConnor Baker <connor.baker@tweag.io>2024-04-03 22:27:03 +0000
committerConnor Baker <connor.baker@tweag.io>2024-04-18 16:31:18 +0000
commit0494330fad2dde171bbb3f09795e4e6347f50ed8 (patch)
tree1d1f155810f79435ae77b73079508c130d8a002e /pkgs/development/cuda-modules/nccl
parent5ed9f23d218223ce5ea280e43bdcf6739d8ace07 (diff)
cudaPackages.nccl: switch to cudaAtLeast, cudaOlder, and __structuredAttrs
Diffstat (limited to 'pkgs/development/cuda-modules/nccl')
-rw-r--r--pkgs/development/cuda-modules/nccl/default.nix30
1 files changed, 16 insertions, 14 deletions
diff --git a/pkgs/development/cuda-modules/nccl/default.nix b/pkgs/development/cuda-modules/nccl/default.nix
index 9db08c722acd7..ec84b8dfb9062 100644
--- a/pkgs/development/cuda-modules/nccl/default.nix
+++ b/pkgs/development/cuda-modules/nccl/default.nix
@@ -17,9 +17,10 @@ let
     cuda_cccl
     cuda_cudart
     cuda_nvcc
+    cudaAtLeast
     cudaFlags
+    cudaOlder
     cudatoolkit
-    cudaVersion
     ;
 in
 backendStdenv.mkDerivation (finalAttrs: {
@@ -33,6 +34,7 @@ backendStdenv.mkDerivation (finalAttrs: {
     hash = "sha256-ModIjD6RaRD/57a/PA1oTgYhZsAQPrrvhl5sNVXnO6c=";
   };
 
+  __structuredAttrs = true;
   strictDeps = true;
 
   outputs = [
@@ -46,12 +48,12 @@ backendStdenv.mkDerivation (finalAttrs: {
       autoAddDriverRunpath
       python3
     ]
-    ++ lib.optionals (lib.versionOlder cudaVersion "11.4") [ cudatoolkit ]
-    ++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [ cuda_nvcc ];
+    ++ lib.optionals (cudaOlder "11.4") [ cudatoolkit ]
+    ++ lib.optionals (cudaAtLeast "11.4") [ cuda_nvcc ];
 
   buildInputs =
-    lib.optionals (lib.versionOlder cudaVersion "11.4") [ cudatoolkit ]
-    ++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [
+    lib.optionals (cudaOlder "11.4") [ cudatoolkit ]
+    ++ lib.optionals (cudaAtLeast "11.4") [
       cuda_nvcc.dev # crt/host_config.h
       cuda_cudart
     ]
@@ -59,25 +61,25 @@ backendStdenv.mkDerivation (finalAttrs: {
     # against other version, like below, it's important that we use the same format. Otherwise,
     # we'll get incorrect results.
     # For example, lib.versionAtLeast "12.0" "12.0.0" == false.
-    ++ lib.optionals (lib.versionAtLeast cudaVersion "12.0") [ cuda_cccl ];
+    ++ lib.optionals (cudaAtLeast "12.0") [ cuda_cccl ];
 
   env.NIX_CFLAGS_COMPILE = toString [ "-Wno-unused-function" ];
 
-  preConfigure = ''
+  postPatch = ''
     patchShebangs ./src/device/generate.py
-    makeFlagsArray+=(
-      "NVCC_GENCODE=${lib.concatStringsSep " " cudaFlags.gencode}"
-    )
   '';
 
-  makeFlags =
-    [ "PREFIX=$(out)" ]
-    ++ lib.optionals (lib.versionOlder cudaVersion "11.4") [
+  makeFlagsArray =
+    [
+      "PREFIX=$(out)"
+      "NVCC_GENCODE=${cudaFlags.gencodeString}"
+    ]
+    ++ lib.optionals (cudaOlder "11.4") [
       "CUDA_HOME=${cudatoolkit}"
       "CUDA_LIB=${lib.getLib cudatoolkit}/lib"
       "CUDA_INC=${lib.getDev cudatoolkit}/include"
     ]
-    ++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [
+    ++ lib.optionals (cudaAtLeast "11.4") [
       "CUDA_HOME=${cuda_nvcc}"
       "CUDA_LIB=${lib.getLib cuda_cudart}/lib"
       "CUDA_INC=${lib.getDev cuda_cudart}/include"