diff options
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/default.nix')
-rw-r--r-- | pkgs/development/python-modules/jaxlib/default.nix | 84 |
1 files changed, 54 insertions, 30 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix index b77a7de7b3575..7410400ed05a5 100644 --- a/pkgs/development/python-modules/jaxlib/default.nix +++ b/pkgs/development/python-modules/jaxlib/default.nix @@ -55,7 +55,6 @@ let inherit (cudaPackages) cudaFlags cudaVersion - cudnn nccl ; @@ -80,18 +79,26 @@ let broken = effectiveStdenv.isDarwin || nccl.meta.unsupported; }; + # Bazel wants a merged cudnn at configuration time + cudnnMerged = symlinkJoin { + name = "cudnn-merged"; + paths = with cudaPackages; [ + (lib.getDev cudnn) + (lib.getLib cudnn) + ]; + }; + # These are necessary at build time and run time. cuda_libs_joined = symlinkJoin { name = "cuda-joined"; paths = with cudaPackages; [ - cuda_cudart.lib # libcudart.so - cuda_cudart.static # libcudart_static.a - cuda_cupti.lib # libcupti.so - libcublas.lib # libcublas.so - libcufft.lib # libcufft.so - libcurand.lib # libcurand.so - libcusolver.lib # libcusolver.so - libcusparse.lib # libcusparse.so + (lib.getLib cuda_cudart) # libcudart.so + (lib.getLib cuda_cupti) # libcupti.so + (lib.getLib libcublas) # libcublas.so + (lib.getLib libcufft) # libcufft.so + (lib.getLib libcurand) # libcurand.so + (lib.getLib libcusolver) # libcusolver.so + (lib.getLib libcusparse) # libcusparse.so ]; }; # These are only necessary at build time. @@ -101,20 +108,23 @@ let cuda_libs_joined # Binaries - cudaPackages.cuda_nvcc.bin # nvcc + (lib.getBin cuda_nvcc) # nvcc + + # Archives + (lib.getOutput "static" cuda_cudart) # libcudart_static.a # Headers - cuda_cccl.dev # block_load.cuh - cuda_cudart.dev # cuda.h - cuda_cupti.dev # cupti.h - cuda_nvcc.dev # See https://github.com/google/jax/issues/19811 - cuda_nvml_dev # nvml.h - cuda_nvtx.dev # nvToolsExt.h - libcublas.dev # cublas_api.h - libcufft.dev # cufft.h - libcurand.dev # curand.h - libcusolver.dev # cusolver_common.h - libcusparse.dev # cusparse.h + (lib.getDev cuda_cccl) # block_load.cuh + (lib.getDev cuda_cudart) # cuda.h + (lib.getDev cuda_cupti) # cupti.h + (lib.getDev cuda_nvcc) # See https://github.com/google/jax/issues/19811 + (lib.getDev cuda_nvml_dev) # nvml.h + (lib.getDev cuda_nvtx) # nvToolsExt.h + (lib.getDev libcublas) # cublas_api.h + (lib.getDev libcufft) # cufft.h + (lib.getDev libcurand) # curand.h + (lib.getDev libcusolver) # cusolver_common.h + (lib.getDev libcusparse) # cusparse.h ]; }; @@ -308,10 +318,10 @@ let + lib.optionalString cudaSupport '' build --config=cuda build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}" - build --action_env CUDNN_INSTALL_PATH="${cudnn}" - build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnn},${nccl}" + build --action_env CUDNN_INSTALL_PATH="${cudnnMerged}" + build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnnMerged},${lib.getDev nccl}" build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudaVersion}" - build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}" + build --action_env TF_CUDNN_VERSION="${lib.versions.major cudaPackages.cudnn.version}" build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}" '' + @@ -374,13 +384,20 @@ let sha256 = ( if cudaSupport then - { x86_64-linux = "sha256-vUoAPkYKEnHkV4fw6BI0mCeuP2e8BMCJnVuZMm9LwSA="; } + { x86_64-linux = "sha256-Uf0VMRE0jgaWEYiuphWkWloZ5jMeqaWBl3lSvk2y1HI="; } else { - x86_64-linux = "sha256-R1TIIyyyLlDqAlUkuhJhtyTxZMra2q5S/jX0OCInsEQ="; - aarch64-linux = "sha256-P5JEmJljN1DeRA0dNkzyosKzRnJH+5SD2aWdV5JsoiY="; + x86_64-linux = "sha256-NzJJg6NlrPGMiR8Fn8u4+fu0m+AulfmN5Xqk63Um6sw="; + aarch64-linux = "sha256-Ro3qzrUxSR+3TH6ROoJTq+dLSufrDN/9oEo2MRkx7wM="; } ).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}"); + + # Non-reproducible fetch https://github.com/NixOS/nixpkgs/issues/321920#issuecomment-2184940546 + preInstall = '' + cat << \EOF > "$bazelOut/external/go_sdk/versions.json" + [] + EOF + ''; }; buildAttrs = { @@ -418,7 +435,7 @@ let throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}"; in buildPythonPackage { - inherit meta pname version; + inherit pname version; format = "wheel"; src = @@ -431,13 +448,13 @@ buildPythonPackage { # for more info. postInstall = lib.optionalString cudaSupport '' mkdir -p $out/bin - ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas + ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/bin/ptxas find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do patchelf --add-rpath "${ lib.makeLibraryPath [ cuda_libs_joined - cudnn + (lib.getLib cudaPackages.cudnn) nccl ] }" "$lib" @@ -471,4 +488,11 @@ buildPythonPackage { # Without it there are complaints about libcudart.so.11.0 not being found # because RPATH path entries added above are stripped. dontPatchELF = cudaSupport; + + passthru = { + # Note "bazel.*.tar.gz" can be accessed as `jaxlib.bazel-build.deps` + inherit bazel-build; + }; + + inherit meta; } |