about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/default.nix')
-rw-r--r--pkgs/development/python-modules/jaxlib/default.nix84
1 files changed, 54 insertions, 30 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix
index b77a7de7b3575..7410400ed05a5 100644
--- a/pkgs/development/python-modules/jaxlib/default.nix
+++ b/pkgs/development/python-modules/jaxlib/default.nix
@@ -55,7 +55,6 @@ let
   inherit (cudaPackages)
     cudaFlags
     cudaVersion
-    cudnn
     nccl
     ;
 
@@ -80,18 +79,26 @@ let
     broken = effectiveStdenv.isDarwin || nccl.meta.unsupported;
   };
 
+  # Bazel wants a merged cudnn at configuration time
+  cudnnMerged = symlinkJoin {
+    name = "cudnn-merged";
+    paths = with cudaPackages; [
+      (lib.getDev cudnn)
+      (lib.getLib cudnn)
+    ];
+  };
+
   # These are necessary at build time and run time.
   cuda_libs_joined = symlinkJoin {
     name = "cuda-joined";
     paths = with cudaPackages; [
-      cuda_cudart.lib # libcudart.so
-      cuda_cudart.static # libcudart_static.a
-      cuda_cupti.lib # libcupti.so
-      libcublas.lib # libcublas.so
-      libcufft.lib # libcufft.so
-      libcurand.lib # libcurand.so
-      libcusolver.lib # libcusolver.so
-      libcusparse.lib # libcusparse.so
+      (lib.getLib cuda_cudart) # libcudart.so
+      (lib.getLib cuda_cupti) # libcupti.so
+      (lib.getLib libcublas) # libcublas.so
+      (lib.getLib libcufft) # libcufft.so
+      (lib.getLib libcurand) # libcurand.so
+      (lib.getLib libcusolver) # libcusolver.so
+      (lib.getLib libcusparse) # libcusparse.so
     ];
   };
   # These are only necessary at build time.
@@ -101,20 +108,23 @@ let
       cuda_libs_joined
 
       # Binaries
-      cudaPackages.cuda_nvcc.bin # nvcc
+      (lib.getBin cuda_nvcc) # nvcc
+
+      # Archives
+      (lib.getOutput "static" cuda_cudart) # libcudart_static.a
 
       # Headers
-      cuda_cccl.dev # block_load.cuh
-      cuda_cudart.dev # cuda.h
-      cuda_cupti.dev # cupti.h
-      cuda_nvcc.dev # See https://github.com/google/jax/issues/19811
-      cuda_nvml_dev # nvml.h
-      cuda_nvtx.dev # nvToolsExt.h
-      libcublas.dev # cublas_api.h
-      libcufft.dev # cufft.h
-      libcurand.dev # curand.h
-      libcusolver.dev # cusolver_common.h
-      libcusparse.dev # cusparse.h
+      (lib.getDev cuda_cccl) # block_load.cuh
+      (lib.getDev cuda_cudart) # cuda.h
+      (lib.getDev cuda_cupti) # cupti.h
+      (lib.getDev cuda_nvcc) # See https://github.com/google/jax/issues/19811
+      (lib.getDev cuda_nvml_dev) # nvml.h
+      (lib.getDev cuda_nvtx) # nvToolsExt.h
+      (lib.getDev libcublas) # cublas_api.h
+      (lib.getDev libcufft) # cufft.h
+      (lib.getDev libcurand) # curand.h
+      (lib.getDev libcusolver) # cusolver_common.h
+      (lib.getDev libcusparse) # cusparse.h
     ];
   };
 
@@ -308,10 +318,10 @@ let
       + lib.optionalString cudaSupport ''
         build --config=cuda
         build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}"
-        build --action_env CUDNN_INSTALL_PATH="${cudnn}"
-        build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnn},${nccl}"
+        build --action_env CUDNN_INSTALL_PATH="${cudnnMerged}"
+        build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnnMerged},${lib.getDev nccl}"
         build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudaVersion}"
-        build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
+        build --action_env TF_CUDNN_VERSION="${lib.versions.major cudaPackages.cudnn.version}"
         build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
       ''
       +
@@ -374,13 +384,20 @@ let
       sha256 =
         (
           if cudaSupport then
-            { x86_64-linux = "sha256-vUoAPkYKEnHkV4fw6BI0mCeuP2e8BMCJnVuZMm9LwSA="; }
+            { x86_64-linux = "sha256-Uf0VMRE0jgaWEYiuphWkWloZ5jMeqaWBl3lSvk2y1HI="; }
           else
             {
-              x86_64-linux = "sha256-R1TIIyyyLlDqAlUkuhJhtyTxZMra2q5S/jX0OCInsEQ=";
-              aarch64-linux = "sha256-P5JEmJljN1DeRA0dNkzyosKzRnJH+5SD2aWdV5JsoiY=";
+              x86_64-linux = "sha256-NzJJg6NlrPGMiR8Fn8u4+fu0m+AulfmN5Xqk63Um6sw=";
+              aarch64-linux = "sha256-Ro3qzrUxSR+3TH6ROoJTq+dLSufrDN/9oEo2MRkx7wM=";
             }
         ).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
+
+        # Non-reproducible fetch https://github.com/NixOS/nixpkgs/issues/321920#issuecomment-2184940546
+        preInstall = ''
+          cat << \EOF > "$bazelOut/external/go_sdk/versions.json"
+          []
+          EOF
+        '';
     };
 
     buildAttrs = {
@@ -418,7 +435,7 @@ let
       throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}";
 in
 buildPythonPackage {
-  inherit meta pname version;
+  inherit pname version;
   format = "wheel";
 
   src =
@@ -431,13 +448,13 @@ buildPythonPackage {
   # for more info.
   postInstall = lib.optionalString cudaSupport ''
     mkdir -p $out/bin
-    ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
+    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/bin/ptxas
 
     find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
       patchelf --add-rpath "${
         lib.makeLibraryPath [
           cuda_libs_joined
-          cudnn
+          (lib.getLib cudaPackages.cudnn)
           nccl
         ]
       }" "$lib"
@@ -471,4 +488,11 @@ buildPythonPackage {
   # Without it there are complaints about libcudart.so.11.0 not being found
   # because RPATH path entries added above are stripped.
   dontPatchELF = cudaSupport;
+
+  passthru = {
+    # Note "bazel.*.tar.gz" can be accessed as `jaxlib.bazel-build.deps`
+    inherit bazel-build;
+  };
+
+  inherit meta;
 }