about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/default.nix')
-rw-r--r--pkgs/development/python-modules/jaxlib/default.nix341
1 files changed, 182 insertions, 159 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix
index cfca1f170ea4c..8366f11d8a268 100644
--- a/pkgs/development/python-modules/jaxlib/default.nix
+++ b/pkgs/development/python-modules/jaxlib/default.nix
@@ -1,66 +1,71 @@
-{ lib
-, pkgs
-, stdenv
+{
+  lib,
+  pkgs,
+  stdenv,
 
   # Build-time dependencies:
-, addOpenGLRunpath
-, autoAddDriverRunpath
-, bazel_6
-, binutils
-, buildBazelPackage
-, buildPythonPackage
-, cctools
-, curl
-, cython
-, fetchFromGitHub
-, fetchpatch
-, git
-, IOKit
-, jsoncpp
-, nsync
-, openssl
-, pybind11
-, setuptools
-, symlinkJoin
-, wheel
-, build
-, which
+  addOpenGLRunpath,
+  autoAddDriverRunpath,
+  bazel_6,
+  binutils,
+  buildBazelPackage,
+  buildPythonPackage,
+  cctools,
+  curl,
+  cython,
+  fetchFromGitHub,
+  git,
+  IOKit,
+  jsoncpp,
+  nsync,
+  openssl,
+  pybind11,
+  setuptools,
+  symlinkJoin,
+  wheel,
+  build,
+  which,
 
   # Python dependencies:
-, absl-py
-, flatbuffers
-, ml-dtypes
-, numpy
-, scipy
-, six
+  absl-py,
+  flatbuffers,
+  ml-dtypes,
+  numpy,
+  scipy,
+  six,
 
   # Runtime dependencies:
-, double-conversion
-, giflib
-, libjpeg_turbo
-, python
-, snappy
-, zlib
-
-, config
+  double-conversion,
+  giflib,
+  libjpeg_turbo,
+  python,
+  snappy,
+  zlib,
+
+  config,
   # CUDA flags:
-, cudaSupport ? config.cudaSupport
-, cudaPackagesGoogle
+  cudaSupport ? config.cudaSupport,
+  cudaPackages,
 
   # MKL:
-, mklSupport ? true
+  mklSupport ? true,
 }@inputs:
 
 let
-  inherit (cudaPackagesGoogle) cudaFlags cudaVersion cudnn nccl;
+  inherit (cudaPackages)
+    cudaFlags
+    cudaVersion
+    cudnn
+    nccl
+    ;
 
   pname = "jaxlib";
-  version = "0.4.24";
+  version = "0.4.28";
 
   # It's necessary to consistently use backendStdenv when building with CUDA
   # support, otherwise we get libstdc++ errors downstream
   stdenv = throw "Use effectiveStdenv instead";
-  effectiveStdenv = if cudaSupport then cudaPackagesGoogle.backendStdenv else inputs.stdenv;
+  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;
 
   meta = with lib; {
     description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@@ -78,7 +83,7 @@ let
   # These are necessary at build time and run time.
   cuda_libs_joined = symlinkJoin {
     name = "cuda-joined";
-    paths = with cudaPackagesGoogle; [
+    paths = with cudaPackages; [
       cuda_cudart.lib # libcudart.so
       cuda_cudart.static # libcudart_static.a
       cuda_cupti.lib # libcupti.so
@@ -92,11 +97,11 @@ let
   # These are only necessary at build time.
   cuda_build_deps_joined = symlinkJoin {
     name = "cuda-build-deps-joined";
-    paths = with cudaPackagesGoogle; [
+    paths = with cudaPackages; [
       cuda_libs_joined
 
       # Binaries
-      cudaPackagesGoogle.cuda_nvcc.bin # nvcc
+      cudaPackages.cuda_nvcc.bin # nvcc
 
       # Headers
       cuda_cccl.dev # block_load.cuh
@@ -170,8 +175,10 @@ let
 
   arch =
     # KeyError: ('Linux', 'arm64')
-    if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64" then "aarch64"
-    else effectiveStdenv.hostPlatform.linuxArch;
+    if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64" then
+      "aarch64"
+    else
+      effectiveStdenv.hostPlatform.linuxArch;
 
   xla = effectiveStdenv.mkDerivation {
     pname = "xla-src";
@@ -181,19 +188,10 @@ let
       owner = "openxla";
       repo = "xla";
       # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
-      rev = "12eee889e1f2ad41e27d7b0e970cb92d282d3ec5";
-      hash = "sha256-68kjjgwYjRlcT0TVJo9BN6s+WTkdu5UMJqQcfHpBT90=";
+      rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
+      hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
     };
 
-    patches = [
-      # Resolves "could not convert ‘result’ from ‘SmallVector<[...],6>’ to
-      # ‘SmallVector<[...],4>’" compilation error. See https://github.com/google/jax/issues/19814#issuecomment-1945141259.
-      (fetchpatch {
-        url = "https://github.com/openxla/xla/commit/7a614cd346594fc7ea2fe75570c9c53a4a444f60.patch";
-        hash = "sha256-RtuQTH8wzNiJcOtISLhf+gMlH1gg8hekvxEB+4wX6BM=";
-      })
-    ];
-
     dontBuild = true;
 
     # This is necessary for patchShebangs to know the right path to use.
@@ -220,7 +218,7 @@ let
       repo = "jax";
       # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
       rev = "refs/tags/${pname}-v${version}";
-      hash = "sha256-hmx7eo3pephc6BQfoJ3U0QwWBWmhkAc+7S4QmW32qQs=";
+      hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
     };
 
     nativeBuildInputs = [
@@ -231,30 +229,27 @@ let
       wheel
       build
       which
-    ] ++ lib.optionals effectiveStdenv.isDarwin [
-      cctools
-    ];
-
-    buildInputs = [
-      curl
-      double-conversion
-      giflib
-      jsoncpp
-      libjpeg_turbo
-      numpy
-      openssl
-      pkgs.flatbuffers
-      pkgs.protobuf
-      pybind11
-      scipy
-      six
-      snappy
-      zlib
-    ] ++ lib.optionals effectiveStdenv.isDarwin [
-      IOKit
-    ] ++ lib.optionals (!effectiveStdenv.isDarwin) [
-      nsync
-    ];
+    ] ++ lib.optionals effectiveStdenv.isDarwin [ cctools ];
+
+    buildInputs =
+      [
+        curl
+        double-conversion
+        giflib
+        jsoncpp
+        libjpeg_turbo
+        numpy
+        openssl
+        pkgs.flatbuffers
+        pkgs.protobuf
+        pybind11
+        scipy
+        six
+        snappy
+        zlib
+      ]
+      ++ lib.optionals effectiveStdenv.isDarwin [ IOKit ]
+      ++ lib.optionals (!effectiveStdenv.isDarwin) [ nsync ];
 
     # We don't want to be quite so picky regarding bazel version
     postPatch = ''
@@ -285,30 +280,32 @@ let
         echo "#!${effectiveStdenv.shell}" > dummy-ldconfig/ldconfig
         chmod +x dummy-ldconfig/ldconfig
         export PATH="$PWD/dummy-ldconfig:$PATH"
-      '' +
-
-      # Construct .jax_configure.bazelrc. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345
-      # for more info. We assume
-      # * `cpu = None`
-      # * `enable_nccl = True`
-      # * `target_cpu_features = "release"`
-      # * `rocm_amdgpu_targets = None`
-      # * `enable_rocm = False`
-      # * `build_gpu_plugin = False`
-      # * `use_clang = False` (Should we use `effectiveStdenv.cc.isClang` instead?)
-      #
-      # Note: We should try just running https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266
-      # instead of duplicating the logic here. Perhaps we can leverage the
-      # `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)?
       ''
-        cat <<CFG > ./.jax_configure.bazelrc
-        build --strategy=Genrule=standalone
-        build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
-        build --action_env=PYENV_ROOT
-        build --python_path="${python}/bin/python"
-        build --distinct_host_configuration=false
-        build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
-      '' + lib.optionalString cudaSupport ''
+      +
+
+        # Construct .jax_configure.bazelrc. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345
+        # for more info. We assume
+        # * `cpu = None`
+        # * `enable_nccl = True`
+        # * `target_cpu_features = "release"`
+        # * `rocm_amdgpu_targets = None`
+        # * `enable_rocm = False`
+        # * `build_gpu_plugin = False`
+        # * `use_clang = False` (Should we use `effectiveStdenv.cc.isClang` instead?)
+        #
+        # Note: We should try just running https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266
+        # instead of duplicating the logic here. Perhaps we can leverage the
+        # `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)?
+        ''
+          cat <<CFG > ./.jax_configure.bazelrc
+          build --strategy=Genrule=standalone
+          build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
+          build --action_env=PYENV_ROOT
+          build --python_path="${python}/bin/python"
+          build --distinct_host_configuration=false
+          build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
+        ''
+      + lib.optionalString cudaSupport ''
         build --config=cuda
         build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}"
         build --action_env CUDNN_INSTALL_PATH="${cudnn}"
@@ -316,67 +313,85 @@ let
         build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudaVersion}"
         build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
         build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
-      '' +
-      # Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just
-      # rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so
-      # good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322
-      # for upstream's version.
-      lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix) ''
-        build --config=avx_posix
-      '' + lib.optionalString mklSupport ''
+      ''
+      +
+        # Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just
+        # rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so
+        # good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322
+        # for upstream's version.
+        lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix)
+          ''
+            build --config=avx_posix
+          ''
+      + lib.optionalString mklSupport ''
         build --config=mkl_open_source_only
-      '' +
       ''
+      + ''
         CFG
       '';
 
     # Make sure Bazel knows about our configuration flags during fetching so that the
     # relevant dependencies can be downloaded.
-    bazelFlags = [
-      "-c opt"
-      # See https://bazel.build/external/advanced#overriding-repositories for
-      # information on --override_repository flag.
-      "--override_repository=xla=${xla}"
-    ] ++ lib.optionals effectiveStdenv.cc.isClang [
-      # bazel depends on the compiler frontend automatically selecting these flags based on file
-      # extension but our clang doesn't.
-      # https://github.com/NixOS/nixpkgs/issues/150655
-      "--cxxopt=-x"
-      "--cxxopt=c++"
-      "--host_cxxopt=-x"
-      "--host_cxxopt=c++"
-    ];
+    bazelFlags =
+      [
+        "-c opt"
+        # See https://bazel.build/external/advanced#overriding-repositories for
+        # information on --override_repository flag.
+        "--override_repository=xla=${xla}"
+      ]
+      ++ lib.optionals effectiveStdenv.cc.isClang [
+        # bazel depends on the compiler frontend automatically selecting these flags based on file
+        # extension but our clang doesn't.
+        # https://github.com/NixOS/nixpkgs/issues/150655
+        "--cxxopt=-x"
+        "--cxxopt=c++"
+        "--host_cxxopt=-x"
+        "--host_cxxopt=c++"
+      ];
 
     # We intentionally overfetch so we can share the fetch derivation across all the different configurations
     fetchAttrs = {
       TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
       # we have to force @mkl_dnn_v1 since it's not needed on darwin
-      bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ];
-      bazelFlags = bazelFlags ++ [
-        "--config=avx_posix"
-        "--config=mkl_open_source_only"
-      ] ++ lib.optionals cudaSupport [
-        # ideally we'd add this unconditionally too, but it doesn't work on darwin
-        # we make this conditional on `cudaSupport` instead of the system, so that the hash for both
-        # the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
-        # have access to darwin machines
-        "--config=cuda"
+      bazelTargets = [
+        bazelRunTarget
+        "@mkl_dnn_v1//:mkl_dnn"
       ];
-
-      sha256 = (if cudaSupport then {
-        x86_64-linux = "sha256-8JilAoTbqOjOOJa/Zc/n/quaEDcpdcLXCNb34mfB+OM=";
-      } else {
-        x86_64-linux = "sha256-iqS+I1FQLNWXNMsA20cJp7YkyGUeshee5b2QfRBNZtk=";
-        aarch64-linux = "sha256-qmJ0Fm/VGMTmko4PhKs1P8/GLEJmVxb8xg+ss/HsakY==";
-      }).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
+      bazelFlags =
+        bazelFlags
+        ++ [
+          "--config=avx_posix"
+          "--config=mkl_open_source_only"
+        ]
+        ++ lib.optionals cudaSupport [
+          # ideally we'd add this unconditionally too, but it doesn't work on darwin
+          # we make this conditional on `cudaSupport` instead of the system, so that the hash for both
+          # the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
+          # have access to darwin machines
+          "--config=cuda"
+        ];
+
+      sha256 =
+        (
+          if cudaSupport then
+            { x86_64-linux = "sha256-VGNMf5/DgXbgsu1w5J1Pmrukw+7UO31BNU+crKVsX5k="; }
+          else
+            {
+              x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ=";
+              aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA=";
+            }
+        ).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
     };
 
     buildAttrs = {
       outputs = [ "out" ];
 
-      TF_SYSTEM_LIBS = lib.concatStringsSep "," (tf_system_libs ++ lib.optionals (!effectiveStdenv.isDarwin) [
-        "nsync" # fails to build on darwin
-      ]);
+      TF_SYSTEM_LIBS = lib.concatStringsSep "," (
+        tf_system_libs
+        ++ lib.optionals (!effectiveStdenv.isDarwin) [
+          "nsync" # fails to build on darwin
+        ]
+      );
 
       # Note: we cannot do most of this patching at `patch` phase as the deps
       # are not available yet. Framework search paths aren't added by bintools
@@ -399,31 +414,39 @@ let
       "macosx_10_9_${arch}"
     else if effectiveStdenv.system == "aarch64-darwin" then
       "macosx_11_0_${arch}"
-    else throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}";
-
+    else
+      throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}";
 in
 buildPythonPackage {
   inherit meta pname version;
   format = "wheel";
 
   src =
-    let cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}";
-    in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";
+    let
+      cp = "cp${builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion}";
+    in
+    "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";
 
   # Note that jaxlib looks for "ptxas" in $PATH. See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
   # for more info.
   postInstall = lib.optionalString cudaSupport ''
     mkdir -p $out/bin
-    ln -s ${cudaPackagesGoogle.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
+    ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
 
     find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
-      patchelf --add-rpath "${lib.makeLibraryPath [cuda_libs_joined cudnn nccl]}" "$lib"
+      patchelf --add-rpath "${
+        lib.makeLibraryPath [
+          cuda_libs_joined
+          cudnn
+          nccl
+        ]
+      }" "$lib"
     done
   '';
 
   nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ];
 
-  propagatedBuildInputs = [
+  dependencies = [
     absl-py
     curl
     double-conversion