about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/default.nix')
-rw-r--r--pkgs/development/python-modules/jaxlib/default.nix17
1 files changed, 11 insertions, 6 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix
index bfb7f494ce1a3..664e109719adf 100644
--- a/pkgs/development/python-modules/jaxlib/default.nix
+++ b/pkgs/development/python-modules/jaxlib/default.nix
@@ -4,7 +4,7 @@
 
   # Build-time dependencies:
 , addOpenGLRunpath
-, bazel_4
+, bazel_5
 , binutils
 , buildBazelPackage
 , buildPythonPackage
@@ -50,7 +50,7 @@
 let
 
   pname = "jaxlib";
-  version = "0.1.75";
+  version = "0.3.0";
 
   meta = with lib; {
     description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@@ -82,13 +82,13 @@ let
   bazel-build = buildBazelPackage {
     name = "bazel-build-${pname}-${version}";
 
-    bazel = bazel_4;
+    bazel = bazel_5;
 
     src = fetchFromGitHub {
       owner = "google";
       repo = "jax";
       rev = "${pname}-v${version}";
-      sha256 = "01ks4djbpjsxjy2zwdwv3h00sgwi4ps3jz75swddrw2f56zjdmw4";
+      sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72";
     };
 
     nativeBuildInputs = [
@@ -216,9 +216,9 @@ let
     fetchAttrs = {
       sha256 =
         if cudaSupport then
-          "1lyipbflqd1y5cdj4hdml5h1inbr0wwfgp6xw5p5623qv3im16lh"
+          "1k0rjxqjm703gd9navwzx5x3874b4dxamr62m1fxhm79d271zxis"
         else
-          "09kapzpfwnlr6ghmgwac232bqf2a57mm1brz4cvfx8mlg8bbaw63";
+          "0ivah1w41jcj13jm740qzwx5h0ia8vbj71pjgd0zrfk3c92kll41";
     };
 
     buildAttrs = {
@@ -229,12 +229,17 @@ let
       # 2) Force static protobuf linkage to prevent crashes on loading multiple extensions
       #    in the same python program due to duplicate protobuf DBs.
       # 3) Patch python path in the compiler driver.
+      # 4) Patch tensorflow sources to work with later versions of protobuf. See
+      #    https://github.com/google/jax/issues/9534. Note that this should be
+      #    removed on the next release after 0.3.0.
       preBuild = ''
         for src in ./jaxlib/*.{cc,h}; do
           sed -i 's@include/pybind11@pybind11@g' $src
         done
         sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
         sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
+        substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \
+          --replace "status.message()" "std::string{status.message()}"
       '' + lib.optionalString cudaSupport ''
         patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
       '';