about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/default.nix')
-rw-r--r--pkgs/development/python-modules/jaxlib/default.nix40
1 files changed, 15 insertions, 25 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix
index cfca1f170ea4c..8854d7927ea6c 100644
--- a/pkgs/development/python-modules/jaxlib/default.nix
+++ b/pkgs/development/python-modules/jaxlib/default.nix
@@ -13,7 +13,6 @@
 , curl
 , cython
 , fetchFromGitHub
-, fetchpatch
 , git
 , IOKit
 , jsoncpp
@@ -45,22 +44,22 @@
 , config
   # CUDA flags:
 , cudaSupport ? config.cudaSupport
-, cudaPackagesGoogle
+, cudaPackages
 
   # MKL:
 , mklSupport ? true
 }@inputs:
 
 let
-  inherit (cudaPackagesGoogle) cudaFlags cudaVersion cudnn nccl;
+  inherit (cudaPackages) cudaFlags cudaVersion cudnn nccl;
 
   pname = "jaxlib";
-  version = "0.4.24";
+  version = "0.4.28";
 
   # It's necessary to consistently use backendStdenv when building with CUDA
   # support, otherwise we get libstdc++ errors downstream
   stdenv = throw "Use effectiveStdenv instead";
-  effectiveStdenv = if cudaSupport then cudaPackagesGoogle.backendStdenv else inputs.stdenv;
+  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;
 
   meta = with lib; {
     description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@@ -78,7 +77,7 @@ let
   # These are necessary at build time and run time.
   cuda_libs_joined = symlinkJoin {
     name = "cuda-joined";
-    paths = with cudaPackagesGoogle; [
+    paths = with cudaPackages; [
       cuda_cudart.lib # libcudart.so
       cuda_cudart.static # libcudart_static.a
       cuda_cupti.lib # libcupti.so
@@ -92,11 +91,11 @@ let
   # These are only necessary at build time.
   cuda_build_deps_joined = symlinkJoin {
     name = "cuda-build-deps-joined";
-    paths = with cudaPackagesGoogle; [
+    paths = with cudaPackages; [
       cuda_libs_joined
 
       # Binaries
-      cudaPackagesGoogle.cuda_nvcc.bin # nvcc
+      cudaPackages.cuda_nvcc.bin # nvcc
 
       # Headers
       cuda_cccl.dev # block_load.cuh
@@ -181,19 +180,10 @@ let
       owner = "openxla";
       repo = "xla";
       # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
-      rev = "12eee889e1f2ad41e27d7b0e970cb92d282d3ec5";
-      hash = "sha256-68kjjgwYjRlcT0TVJo9BN6s+WTkdu5UMJqQcfHpBT90=";
+      rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
+      hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
     };
 
-    patches = [
-      # Resolves "could not convert ‘result’ from ‘SmallVector<[...],6>’ to
-      # ‘SmallVector<[...],4>’" compilation error. See https://github.com/google/jax/issues/19814#issuecomment-1945141259.
-      (fetchpatch {
-        url = "https://github.com/openxla/xla/commit/7a614cd346594fc7ea2fe75570c9c53a4a444f60.patch";
-        hash = "sha256-RtuQTH8wzNiJcOtISLhf+gMlH1gg8hekvxEB+4wX6BM=";
-      })
-    ];
-
     dontBuild = true;
 
     # This is necessary for patchShebangs to know the right path to use.
@@ -220,7 +210,7 @@ let
       repo = "jax";
       # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
       rev = "refs/tags/${pname}-v${version}";
-      hash = "sha256-hmx7eo3pephc6BQfoJ3U0QwWBWmhkAc+7S4QmW32qQs=";
+      hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
     };
 
     nativeBuildInputs = [
@@ -364,10 +354,10 @@ let
       ];
 
       sha256 = (if cudaSupport then {
-        x86_64-linux = "sha256-8JilAoTbqOjOOJa/Zc/n/quaEDcpdcLXCNb34mfB+OM=";
+        x86_64-linux = "sha256-VGNMf5/DgXbgsu1w5J1Pmrukw+7UO31BNU+crKVsX5k=";
       } else {
-        x86_64-linux = "sha256-iqS+I1FQLNWXNMsA20cJp7YkyGUeshee5b2QfRBNZtk=";
-        aarch64-linux = "sha256-qmJ0Fm/VGMTmko4PhKs1P8/GLEJmVxb8xg+ss/HsakY==";
+        x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ=";
+        aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA=";
       }).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
     };
 
@@ -414,7 +404,7 @@ buildPythonPackage {
   # for more info.
   postInstall = lib.optionalString cudaSupport ''
     mkdir -p $out/bin
-    ln -s ${cudaPackagesGoogle.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
+    ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
 
     find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
       patchelf --add-rpath "${lib.makeLibraryPath [cuda_libs_joined cudnn nccl]}" "$lib"
@@ -423,7 +413,7 @@ buildPythonPackage {
 
   nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ];
 
-  propagatedBuildInputs = [
+  dependencies = [
     absl-py
     curl
     double-conversion