diff options
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/default.nix')
-rw-r--r-- | pkgs/development/python-modules/jaxlib/default.nix | 40 |
1 files changed, 15 insertions, 25 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix index cfca1f170ea4c..8854d7927ea6c 100644 --- a/pkgs/development/python-modules/jaxlib/default.nix +++ b/pkgs/development/python-modules/jaxlib/default.nix @@ -13,7 +13,6 @@ , curl , cython , fetchFromGitHub -, fetchpatch , git , IOKit , jsoncpp @@ -45,22 +44,22 @@ , config # CUDA flags: , cudaSupport ? config.cudaSupport -, cudaPackagesGoogle +, cudaPackages # MKL: , mklSupport ? true }@inputs: let - inherit (cudaPackagesGoogle) cudaFlags cudaVersion cudnn nccl; + inherit (cudaPackages) cudaFlags cudaVersion cudnn nccl; pname = "jaxlib"; - version = "0.4.24"; + version = "0.4.28"; # It's necessary to consistently use backendStdenv when building with CUDA # support, otherwise we get libstdc++ errors downstream stdenv = throw "Use effectiveStdenv instead"; - effectiveStdenv = if cudaSupport then cudaPackagesGoogle.backendStdenv else inputs.stdenv; + effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv; meta = with lib; { description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; @@ -78,7 +77,7 @@ let # These are necessary at build time and run time. cuda_libs_joined = symlinkJoin { name = "cuda-joined"; - paths = with cudaPackagesGoogle; [ + paths = with cudaPackages; [ cuda_cudart.lib # libcudart.so cuda_cudart.static # libcudart_static.a cuda_cupti.lib # libcupti.so @@ -92,11 +91,11 @@ let # These are only necessary at build time. cuda_build_deps_joined = symlinkJoin { name = "cuda-build-deps-joined"; - paths = with cudaPackagesGoogle; [ + paths = with cudaPackages; [ cuda_libs_joined # Binaries - cudaPackagesGoogle.cuda_nvcc.bin # nvcc + cudaPackages.cuda_nvcc.bin # nvcc # Headers cuda_cccl.dev # block_load.cuh @@ -181,19 +180,10 @@ let owner = "openxla"; repo = "xla"; # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl. - rev = "12eee889e1f2ad41e27d7b0e970cb92d282d3ec5"; - hash = "sha256-68kjjgwYjRlcT0TVJo9BN6s+WTkdu5UMJqQcfHpBT90="; + rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4"; + hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E="; }; - patches = [ - # Resolves "could not convert ‘result’ from ‘SmallVector<[...],6>’ to - # ‘SmallVector<[...],4>’" compilation error. See https://github.com/google/jax/issues/19814#issuecomment-1945141259. - (fetchpatch { - url = "https://github.com/openxla/xla/commit/7a614cd346594fc7ea2fe75570c9c53a4a444f60.patch"; - hash = "sha256-RtuQTH8wzNiJcOtISLhf+gMlH1gg8hekvxEB+4wX6BM="; - }) - ]; - dontBuild = true; # This is necessary for patchShebangs to know the right path to use. @@ -220,7 +210,7 @@ let repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jaxlib tags! rev = "refs/tags/${pname}-v${version}"; - hash = "sha256-hmx7eo3pephc6BQfoJ3U0QwWBWmhkAc+7S4QmW32qQs="; + hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek="; }; nativeBuildInputs = [ @@ -364,10 +354,10 @@ let ]; sha256 = (if cudaSupport then { - x86_64-linux = "sha256-8JilAoTbqOjOOJa/Zc/n/quaEDcpdcLXCNb34mfB+OM="; + x86_64-linux = "sha256-VGNMf5/DgXbgsu1w5J1Pmrukw+7UO31BNU+crKVsX5k="; } else { - x86_64-linux = "sha256-iqS+I1FQLNWXNMsA20cJp7YkyGUeshee5b2QfRBNZtk="; - aarch64-linux = "sha256-qmJ0Fm/VGMTmko4PhKs1P8/GLEJmVxb8xg+ss/HsakY=="; + x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ="; + aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA="; }).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}"); }; @@ -414,7 +404,7 @@ buildPythonPackage { # for more info. postInstall = lib.optionalString cudaSupport '' mkdir -p $out/bin - ln -s ${cudaPackagesGoogle.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas + ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do patchelf --add-rpath "${lib.makeLibraryPath [cuda_libs_joined cudnn nccl]}" "$lib" @@ -423,7 +413,7 @@ buildPythonPackage { nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ]; - propagatedBuildInputs = [ + dependencies = [ absl-py curl double-conversion |