about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib
diff options
context:
space:
mode:
authorAlexander Tsvyashchenko <ndl@endl.ch>2021-12-28 01:19:10 +0100
committerGitHub <noreply@github.com>2021-12-27 16:19:10 -0800
commitbe5272250926e352427b3c62c6066a95c6592375 (patch)
tree826d9be930dc2c701209d84eb6abbda59cff853c /pkgs/development/python-modules/jaxlib
parent8efd318b108e44673cfcb0643ddd1fd224e25dc1 (diff)
python3Packages.jaxlib: refactor to support Nix-based builds (#151909)
* python3Packages.jaxlib: rename to `jaxlib-bin`

Refactoring `jaxlib` to have a similar structure to `tensorflow` with the 'bin' and 'build' options.

* python3Packages.jaxlib: init the 'build' variant at 0.1.75

Similar to `tensorflow-build`, now there's an option to build `jaxlib` using Nix-provided environment and dependencies.

* python3Packages.jax: 0.2.24 -> 0.2.26

* Addressed review comments.

* Fixed `cudaSupport` missing property on some arches.

* Unified the versions of CUDA-related packages with TF.

Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
Diffstat (limited to 'pkgs/development/python-modules/jaxlib')
-rw-r--r--pkgs/development/python-modules/jaxlib/bin.nix90
-rw-r--r--pkgs/development/python-modules/jaxlib/default.nix345
2 files changed, 360 insertions, 75 deletions
diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix
new file mode 100644
index 0000000000000..f597eeacfced4
--- /dev/null
+++ b/pkgs/development/python-modules/jaxlib/bin.nix
@@ -0,0 +1,90 @@
+# For the moment we only support the CPU and GPU backends of jaxlib. The TPU
+# backend will require some additional work. Those wheels are located here:
+# https://storage.googleapis.com/jax-releases/libtpu_releases.html.
+
+# For future reference, the easiest way to test the GPU backend is to run
+#   NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }"
+#   export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
+#   python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'"
+#   python -c "from jax import random; random.PRNGKey(0)"
+#   python -c "from jax import random; x = random.normal(random.PRNGKey(0), (100, 100)); x @ x"
+# There's no convenient way to test the GPU backend in the derivation since the
+# nix build environment blocks access to the GPU. See also:
+#   * https://github.com/google/jax/issues/971#issuecomment-508216439
+#   * https://github.com/google/jax/issues/5723#issuecomment-913038780
+
+{ addOpenGLRunpath, autoPatchelfHook, buildPythonPackage, config
+, fetchurl, isPy39, lib, stdenv
+# propagatedBuildInputs
+, absl-py, flatbuffers, scipy, cudatoolkit_11
+# Options:
+, cudaSupport ? config.cudaSupport or false
+}:
+
+assert cudaSupport -> lib.versionAtLeast cudatoolkit_11.version "11.1";
+
+let
+  device = if cudaSupport then "gpu" else "cpu";
+in
+buildPythonPackage rec {
+  pname = "jaxlib";
+  version = "0.1.71";
+  format = "wheel";
+
+  # At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting
+  # all of them is a pain, so we focus on 3.9, the current nixpkgs python3
+  # version.
+  disabled = !isPy39;
+
+  src = {
+    cpu = fetchurl {
+      url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl";
+      sha256 = "sha256:0rqhs6qabydizlv5d3rb20dbv6612rr7dqfniy9r6h4kazdinsn6";
+    };
+    gpu = fetchurl {
+      url = "https://storage.googleapis.com/jax-releases/cuda111/jaxlib-${version}+cuda111-cp39-none-manylinux2010_x86_64.whl";
+      sha256 = "sha256:065kyzjsk9m84d138p99iymdiiicm1qz8a3iwxz8rspl43rwrw89";
+    };
+  }.${device};
+
+  # Prebuilt wheels are dynamically linked against things that nix can't find.
+  # Run `autoPatchelfHook` to automagically fix them.
+  nativeBuildInputs = [ autoPatchelfHook ] ++ lib.optional cudaSupport addOpenGLRunpath;
+  # Dynamic link dependencies
+  buildInputs = [ stdenv.cc.cc ];
+
+  # jaxlib contains shared libraries that open other shared libraries via dlopen
+  # and these implicit dependencies are not recognized by ldd or
+  # autoPatchelfHook. That means we need to sneak them into rpath. This step
+  # must be done after autoPatchelfHook and the automatic stripping of
+  # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the
+  # patchPhase. Dependencies:
+  #   * libcudart.so.11.0 -> cudatoolkit_11.lib
+  #   * libcublas.so.11   -> cudatoolkit_11
+  #   * libcuda.so.1      -> opengl driver in /run/opengl-driver/lib
+  preInstallCheck = lib.optional cudaSupport ''
+    shopt -s globstar
+
+    addOpenGLRunpath $out/**/*.so
+
+    for file in $out/**/*.so; do
+      rpath=$(patchelf --print-rpath $file)
+      # For some reason `makeLibraryPath` on `cudatoolkit_11` maps to
+      # <cudatoolkit_11.lib>/lib which is different from <cudatoolkit_11>/lib.
+      patchelf --set-rpath "$rpath:${cudatoolkit_11}/lib:${lib.makeLibraryPath [ cudatoolkit_11.lib ]}" $file
+    done
+  '';
+
+  # pip dependencies and optionally cudatoolkit.
+  propagatedBuildInputs = [ absl-py flatbuffers scipy ] ++ lib.optional cudaSupport cudatoolkit_11;
+
+  pythonImportsCheck = [ "jaxlib" ];
+
+  meta = with lib; {
+    description = "XLA library for JAX";
+    homepage    = "https://github.com/google/jax";
+    license     = licenses.asl20;
+    maintainers = with maintainers; [ samuela ];
+    platforms = [ "x86_64-linux" ];
+  };
+}
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix
index f597eeacfced4..bfb7f494ce1a3 100644
--- a/pkgs/development/python-modules/jaxlib/default.nix
+++ b/pkgs/development/python-modules/jaxlib/default.nix
@@ -1,90 +1,285 @@
-# For the moment we only support the CPU and GPU backends of jaxlib. The TPU
-# backend will require some additional work. Those wheels are located here:
-# https://storage.googleapis.com/jax-releases/libtpu_releases.html.
-
-# For future reference, the easiest way to test the GPU backend is to run
-#   NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }"
-#   export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
-#   python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'"
-#   python -c "from jax import random; random.PRNGKey(0)"
-#   python -c "from jax import random; x = random.normal(random.PRNGKey(0), (100, 100)); x @ x"
-# There's no convenient way to test the GPU backend in the derivation since the
-# nix build environment blocks access to the GPU. See also:
-#   * https://github.com/google/jax/issues/971#issuecomment-508216439
-#   * https://github.com/google/jax/issues/5723#issuecomment-913038780
-
-{ addOpenGLRunpath, autoPatchelfHook, buildPythonPackage, config
-, fetchurl, isPy39, lib, stdenv
-# propagatedBuildInputs
-, absl-py, flatbuffers, scipy, cudatoolkit_11
-# Options:
-, cudaSupport ? config.cudaSupport or false
-}:
+{ lib
+, pkgs
+, stdenv
+
+  # Build-time dependencies:
+, addOpenGLRunpath
+, bazel_4
+, binutils
+, buildBazelPackage
+, buildPythonPackage
+, cython
+, fetchFromGitHub
+, git
+, jsoncpp
+, pybind11
+, setuptools
+, symlinkJoin
+, wheel
+, which
+
+  # Build-time and runtime CUDA dependencies:
+, cudatoolkit ? null
+, cudnn ? null
+, nccl ? null
 
-assert cudaSupport -> lib.versionAtLeast cudatoolkit_11.version "11.1";
+  # Python dependencies:
+, absl-py
+, flatbuffers
+, numpy
+, scipy
+, six
+
+  # Runtime dependencies:
+, double-conversion
+, giflib
+, grpc
+, libjpeg_turbo
+, python
+, snappy
+, zlib
+
+  # CUDA flags:
+, cudaCapabilities ? [ "sm_35" "sm_50" "sm_60" "sm_70" "sm_75" "compute_80" ]
+, cudaSupport ? false
+
+  # MKL:
+, mklSupport ? true
+}:
 
 let
-  device = if cudaSupport then "gpu" else "cpu";
-in
-buildPythonPackage rec {
+
   pname = "jaxlib";
-  version = "0.1.71";
-  format = "wheel";
+  version = "0.1.75";
+
+  meta = with lib; {
+    description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
+    homepage = "https://github.com/google/jax";
+    license = licenses.asl20;
+    maintainers = with maintainers; [ ndl ];
+  };
+
+  cudatoolkit_joined = symlinkJoin {
+    name = "${cudatoolkit.name}-merged";
+    paths = [
+      cudatoolkit.lib
+      cudatoolkit.out
+    ] ++ lib.optionals (lib.versionOlder cudatoolkit.version "11") [
+      # for some reason some of the required libs are in the targets/x86_64-linux
+      # directory; not sure why but this works around it
+      "${cudatoolkit}/targets/${stdenv.system}"
+    ];
+  };
+
+  cudatoolkit_cc_joined = symlinkJoin {
+    name = "${cudatoolkit.cc.name}-merged";
+    paths = [
+      cudatoolkit.cc
+      binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
+    ];
+  };
 
-  # At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting
-  # all of them is a pain, so we focus on 3.9, the current nixpkgs python3
-  # version.
-  disabled = !isPy39;
+  bazel-build = buildBazelPackage {
+    name = "bazel-build-${pname}-${version}";
 
-  src = {
-    cpu = fetchurl {
-      url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl";
-      sha256 = "sha256:0rqhs6qabydizlv5d3rb20dbv6612rr7dqfniy9r6h4kazdinsn6";
+    bazel = bazel_4;
+
+    src = fetchFromGitHub {
+      owner = "google";
+      repo = "jax";
+      rev = "${pname}-v${version}";
+      sha256 = "01ks4djbpjsxjy2zwdwv3h00sgwi4ps3jz75swddrw2f56zjdmw4";
     };
-    gpu = fetchurl {
-      url = "https://storage.googleapis.com/jax-releases/cuda111/jaxlib-${version}+cuda111-cp39-none-manylinux2010_x86_64.whl";
-      sha256 = "sha256:065kyzjsk9m84d138p99iymdiiicm1qz8a3iwxz8rspl43rwrw89";
+
+    nativeBuildInputs = [
+      cython
+      pkgs.flatbuffers
+      git
+      setuptools
+      wheel
+      which
+    ];
+
+    buildInputs = [
+      double-conversion
+      giflib
+      grpc
+      jsoncpp
+      libjpeg_turbo
+      numpy
+      pkgs.flatbuffers
+      pkgs.protobuf
+      pybind11
+      scipy
+      six
+      snappy
+      zlib
+    ] ++ lib.optionals cudaSupport [
+      cudatoolkit
+      cudnn
+    ];
+
+    postPatch = ''
+      rm -f .bazelversion
+    '';
+
+    bazelTarget = "//build:build_wheel";
+
+    removeRulesCC = false;
+
+    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
+    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";
+
+    preConfigure = ''
+      # dummy ldconfig
+      mkdir dummy-ldconfig
+      echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
+      chmod +x dummy-ldconfig/ldconfig
+      export PATH="$PWD/dummy-ldconfig:$PATH"
+      cat <<CFG > ./.jax_configure.bazelrc
+      build --strategy=Genrule=standalone
+      build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
+      build --action_env=PYENV_ROOT
+      build --python_path="${python}/bin/python"
+      build --distinct_host_configuration=false
+    '' + lib.optionalString cudaSupport ''
+      build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
+      build --action_env CUDNN_INSTALL_PATH="${cudnn}"
+      build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
+      build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
+      build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
+      build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${lib.concatStringsSep "," cudaCapabilities}"
+    '' + ''
+      CFG
+    '';
+
+    # Copy-paste from TF derivation.
+    # Most of these are not really used in jaxlib compilation but it's simpler to keep it
+    # 'as is' so that it's more compatible with TF derivation.
+    TF_SYSTEM_LIBS = lib.concatStringsSep "," [
+      "absl_py"
+      "astor_archive"
+      "astunparse_archive"
+      "boringssl"
+      # Not packaged in nixpkgs
+      # "com_github_googleapis_googleapis"
+      # "com_github_googlecloudplatform_google_cloud_cpp"
+      "com_github_grpc_grpc"
+      "com_google_protobuf"
+      # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
+      # "com_googlesource_code_re2"
+      "curl"
+      "cython"
+      "dill_archive"
+      "double_conversion"
+      "enum34_archive"
+      "flatbuffers"
+      "functools32_archive"
+      "gast_archive"
+      "gif"
+      "hwloc"
+      "icu"
+      "jsoncpp_git"
+      "libjpeg_turbo"
+      "lmdb"
+      "nasm"
+      # "nsync" # not packaged in nixpkgs
+      "opt_einsum_archive"
+      "org_sqlite"
+      "pasta"
+      "pcre"
+      "png"
+      "pybind11"
+      "six_archive"
+      "snappy"
+      "tblib_archive"
+      "termcolor_archive"
+      "typing_extensions_archive"
+      "wrapt"
+      "zlib"
+    ];
+
+    # Make sure Bazel knows about our configuration flags during fetching so that the
+    # relevant dependencies can be downloaded.
+    bazelFetchFlags = bazel-build.bazelBuildFlags;
+
+    bazelBuildFlags = [
+      "-c opt"
+    ] ++ lib.optional (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
+      "--config=avx_posix"
+    ] ++ lib.optional cudaSupport [
+      "--config=cuda"
+    ] ++ lib.optional mklSupport [
+      "--config=mkl_open_source_only"
+    ];
+
+    fetchAttrs = {
+      sha256 =
+        if cudaSupport then
+          "1lyipbflqd1y5cdj4hdml5h1inbr0wwfgp6xw5p5623qv3im16lh"
+        else
+          "09kapzpfwnlr6ghmgwac232bqf2a57mm1brz4cvfx8mlg8bbaw63";
+    };
+
+    buildAttrs = {
+      outputs = [ "out" ];
+
+      # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
+      # 1) Fix pybind11 include paths.
+      # 2) Force static protobuf linkage to prevent crashes on loading multiple extensions
+      #    in the same python program due to duplicate protobuf DBs.
+      # 3) Patch python path in the compiler driver.
+      preBuild = ''
+        for src in ./jaxlib/*.{cc,h}; do
+          sed -i 's@include/pybind11@pybind11@g' $src
+        done
+        sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
+        sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
+      '' + lib.optionalString cudaSupport ''
+        patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
+      '';
+
+      installPhase = ''
+        ./bazel-bin/build/build_wheel --output_path=$out --cpu=${stdenv.targetPlatform.linuxArch}
+      '';
     };
-  }.${device};
-
-  # Prebuilt wheels are dynamically linked against things that nix can't find.
-  # Run `autoPatchelfHook` to automagically fix them.
-  nativeBuildInputs = [ autoPatchelfHook ] ++ lib.optional cudaSupport addOpenGLRunpath;
-  # Dynamic link dependencies
-  buildInputs = [ stdenv.cc.cc ];
-
-  # jaxlib contains shared libraries that open other shared libraries via dlopen
-  # and these implicit dependencies are not recognized by ldd or
-  # autoPatchelfHook. That means we need to sneak them into rpath. This step
-  # must be done after autoPatchelfHook and the automatic stripping of
-  # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the
-  # patchPhase. Dependencies:
-  #   * libcudart.so.11.0 -> cudatoolkit_11.lib
-  #   * libcublas.so.11   -> cudatoolkit_11
-  #   * libcuda.so.1      -> opengl driver in /run/opengl-driver/lib
-  preInstallCheck = lib.optional cudaSupport ''
-    shopt -s globstar
-
-    addOpenGLRunpath $out/**/*.so
-
-    for file in $out/**/*.so; do
-      rpath=$(patchelf --print-rpath $file)
-      # For some reason `makeLibraryPath` on `cudatoolkit_11` maps to
-      # <cudatoolkit_11.lib>/lib which is different from <cudatoolkit_11>/lib.
-      patchelf --set-rpath "$rpath:${cudatoolkit_11}/lib:${lib.makeLibraryPath [ cudatoolkit_11.lib ]}" $file
+
+    inherit meta;
+  };
+
+in
+buildPythonPackage {
+  inherit meta pname version;
+  format = "wheel";
+
+  src = "${bazel-build}/jaxlib-${version}-cp${builtins.replaceStrings ["."] [""] python.pythonVersion}-none-manylinux2010_${stdenv.targetPlatform.linuxArch}.whl";
+
+  postInstall = lib.optionalString cudaSupport ''
+    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
+      addOpenGLRunpath "$lib"
+      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
     done
   '';
 
-  # pip dependencies and optionally cudatoolkit.
-  propagatedBuildInputs = [ absl-py flatbuffers scipy ] ++ lib.optional cudaSupport cudatoolkit_11;
+  nativeBuildInputs = lib.optional cudaSupport addOpenGLRunpath;
+
+  propagatedBuildInputs = [
+    absl-py
+    double-conversion
+    flatbuffers
+    giflib
+    grpc
+    jsoncpp
+    libjpeg_turbo
+    numpy
+    scipy
+    six
+    snappy
+  ];
 
   pythonImportsCheck = [ "jaxlib" ];
 
-  meta = with lib; {
-    description = "XLA library for JAX";
-    homepage    = "https://github.com/google/jax";
-    license     = licenses.asl20;
-    maintainers = with maintainers; [ samuela ];
-    platforms = [ "x86_64-linux" ];
-  };
+  # Without it there are complaints about libcudart.so.11.0 not being found
+  # because RPATH path entries added above are stripped.
+  dontPatchELF = cudaSupport;
 }