diff options
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/default.nix')
-rw-r--r-- | pkgs/development/python-modules/jaxlib/default.nix | 341 |
1 files changed, 182 insertions, 159 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix index cfca1f170ea4c..8366f11d8a268 100644 --- a/pkgs/development/python-modules/jaxlib/default.nix +++ b/pkgs/development/python-modules/jaxlib/default.nix @@ -1,66 +1,71 @@ -{ lib -, pkgs -, stdenv +{ + lib, + pkgs, + stdenv, # Build-time dependencies: -, addOpenGLRunpath -, autoAddDriverRunpath -, bazel_6 -, binutils -, buildBazelPackage -, buildPythonPackage -, cctools -, curl -, cython -, fetchFromGitHub -, fetchpatch -, git -, IOKit -, jsoncpp -, nsync -, openssl -, pybind11 -, setuptools -, symlinkJoin -, wheel -, build -, which + addOpenGLRunpath, + autoAddDriverRunpath, + bazel_6, + binutils, + buildBazelPackage, + buildPythonPackage, + cctools, + curl, + cython, + fetchFromGitHub, + git, + IOKit, + jsoncpp, + nsync, + openssl, + pybind11, + setuptools, + symlinkJoin, + wheel, + build, + which, # Python dependencies: -, absl-py -, flatbuffers -, ml-dtypes -, numpy -, scipy -, six + absl-py, + flatbuffers, + ml-dtypes, + numpy, + scipy, + six, # Runtime dependencies: -, double-conversion -, giflib -, libjpeg_turbo -, python -, snappy -, zlib - -, config + double-conversion, + giflib, + libjpeg_turbo, + python, + snappy, + zlib, + + config, # CUDA flags: -, cudaSupport ? config.cudaSupport -, cudaPackagesGoogle + cudaSupport ? config.cudaSupport, + cudaPackages, # MKL: -, mklSupport ? true + mklSupport ? true, }@inputs: let - inherit (cudaPackagesGoogle) cudaFlags cudaVersion cudnn nccl; + inherit (cudaPackages) + cudaFlags + cudaVersion + cudnn + nccl + ; pname = "jaxlib"; - version = "0.4.24"; + version = "0.4.28"; # It's necessary to consistently use backendStdenv when building with CUDA # support, otherwise we get libstdc++ errors downstream stdenv = throw "Use effectiveStdenv instead"; - effectiveStdenv = if cudaSupport then cudaPackagesGoogle.backendStdenv else inputs.stdenv; + effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv; meta = with lib; { description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; @@ -78,7 +83,7 @@ let # These are necessary at build time and run time. cuda_libs_joined = symlinkJoin { name = "cuda-joined"; - paths = with cudaPackagesGoogle; [ + paths = with cudaPackages; [ cuda_cudart.lib # libcudart.so cuda_cudart.static # libcudart_static.a cuda_cupti.lib # libcupti.so @@ -92,11 +97,11 @@ let # These are only necessary at build time. cuda_build_deps_joined = symlinkJoin { name = "cuda-build-deps-joined"; - paths = with cudaPackagesGoogle; [ + paths = with cudaPackages; [ cuda_libs_joined # Binaries - cudaPackagesGoogle.cuda_nvcc.bin # nvcc + cudaPackages.cuda_nvcc.bin # nvcc # Headers cuda_cccl.dev # block_load.cuh @@ -170,8 +175,10 @@ let arch = # KeyError: ('Linux', 'arm64') - if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64" then "aarch64" - else effectiveStdenv.hostPlatform.linuxArch; + if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64" then + "aarch64" + else + effectiveStdenv.hostPlatform.linuxArch; xla = effectiveStdenv.mkDerivation { pname = "xla-src"; @@ -181,19 +188,10 @@ let owner = "openxla"; repo = "xla"; # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl. - rev = "12eee889e1f2ad41e27d7b0e970cb92d282d3ec5"; - hash = "sha256-68kjjgwYjRlcT0TVJo9BN6s+WTkdu5UMJqQcfHpBT90="; + rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4"; + hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E="; }; - patches = [ - # Resolves "could not convert ‘result’ from ‘SmallVector<[...],6>’ to - # ‘SmallVector<[...],4>’" compilation error. See https://github.com/google/jax/issues/19814#issuecomment-1945141259. - (fetchpatch { - url = "https://github.com/openxla/xla/commit/7a614cd346594fc7ea2fe75570c9c53a4a444f60.patch"; - hash = "sha256-RtuQTH8wzNiJcOtISLhf+gMlH1gg8hekvxEB+4wX6BM="; - }) - ]; - dontBuild = true; # This is necessary for patchShebangs to know the right path to use. @@ -220,7 +218,7 @@ let repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jaxlib tags! rev = "refs/tags/${pname}-v${version}"; - hash = "sha256-hmx7eo3pephc6BQfoJ3U0QwWBWmhkAc+7S4QmW32qQs="; + hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek="; }; nativeBuildInputs = [ @@ -231,30 +229,27 @@ let wheel build which - ] ++ lib.optionals effectiveStdenv.isDarwin [ - cctools - ]; - - buildInputs = [ - curl - double-conversion - giflib - jsoncpp - libjpeg_turbo - numpy - openssl - pkgs.flatbuffers - pkgs.protobuf - pybind11 - scipy - six - snappy - zlib - ] ++ lib.optionals effectiveStdenv.isDarwin [ - IOKit - ] ++ lib.optionals (!effectiveStdenv.isDarwin) [ - nsync - ]; + ] ++ lib.optionals effectiveStdenv.isDarwin [ cctools ]; + + buildInputs = + [ + curl + double-conversion + giflib + jsoncpp + libjpeg_turbo + numpy + openssl + pkgs.flatbuffers + pkgs.protobuf + pybind11 + scipy + six + snappy + zlib + ] + ++ lib.optionals effectiveStdenv.isDarwin [ IOKit ] + ++ lib.optionals (!effectiveStdenv.isDarwin) [ nsync ]; # We don't want to be quite so picky regarding bazel version postPatch = '' @@ -285,30 +280,32 @@ let echo "#!${effectiveStdenv.shell}" > dummy-ldconfig/ldconfig chmod +x dummy-ldconfig/ldconfig export PATH="$PWD/dummy-ldconfig:$PATH" - '' + - - # Construct .jax_configure.bazelrc. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345 - # for more info. We assume - # * `cpu = None` - # * `enable_nccl = True` - # * `target_cpu_features = "release"` - # * `rocm_amdgpu_targets = None` - # * `enable_rocm = False` - # * `build_gpu_plugin = False` - # * `use_clang = False` (Should we use `effectiveStdenv.cc.isClang` instead?) - # - # Note: We should try just running https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266 - # instead of duplicating the logic here. Perhaps we can leverage the - # `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)? '' - cat <<CFG > ./.jax_configure.bazelrc - build --strategy=Genrule=standalone - build --repo_env PYTHON_BIN_PATH="${python}/bin/python" - build --action_env=PYENV_ROOT - build --python_path="${python}/bin/python" - build --distinct_host_configuration=false - build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include" - '' + lib.optionalString cudaSupport '' + + + + # Construct .jax_configure.bazelrc. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345 + # for more info. We assume + # * `cpu = None` + # * `enable_nccl = True` + # * `target_cpu_features = "release"` + # * `rocm_amdgpu_targets = None` + # * `enable_rocm = False` + # * `build_gpu_plugin = False` + # * `use_clang = False` (Should we use `effectiveStdenv.cc.isClang` instead?) + # + # Note: We should try just running https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266 + # instead of duplicating the logic here. Perhaps we can leverage the + # `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)? + '' + cat <<CFG > ./.jax_configure.bazelrc + build --strategy=Genrule=standalone + build --repo_env PYTHON_BIN_PATH="${python}/bin/python" + build --action_env=PYENV_ROOT + build --python_path="${python}/bin/python" + build --distinct_host_configuration=false + build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include" + '' + + lib.optionalString cudaSupport '' build --config=cuda build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}" build --action_env CUDNN_INSTALL_PATH="${cudnn}" @@ -316,67 +313,85 @@ let build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudaVersion}" build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}" build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}" - '' + - # Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just - # rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so - # good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322 - # for upstream's version. - lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix) '' - build --config=avx_posix - '' + lib.optionalString mklSupport '' + '' + + + # Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just + # rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so + # good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322 + # for upstream's version. + lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix) + '' + build --config=avx_posix + '' + + lib.optionalString mklSupport '' build --config=mkl_open_source_only - '' + '' + + '' CFG ''; # Make sure Bazel knows about our configuration flags during fetching so that the # relevant dependencies can be downloaded. - bazelFlags = [ - "-c opt" - # See https://bazel.build/external/advanced#overriding-repositories for - # information on --override_repository flag. - "--override_repository=xla=${xla}" - ] ++ lib.optionals effectiveStdenv.cc.isClang [ - # bazel depends on the compiler frontend automatically selecting these flags based on file - # extension but our clang doesn't. - # https://github.com/NixOS/nixpkgs/issues/150655 - "--cxxopt=-x" - "--cxxopt=c++" - "--host_cxxopt=-x" - "--host_cxxopt=c++" - ]; + bazelFlags = + [ + "-c opt" + # See https://bazel.build/external/advanced#overriding-repositories for + # information on --override_repository flag. + "--override_repository=xla=${xla}" + ] + ++ lib.optionals effectiveStdenv.cc.isClang [ + # bazel depends on the compiler frontend automatically selecting these flags based on file + # extension but our clang doesn't. + # https://github.com/NixOS/nixpkgs/issues/150655 + "--cxxopt=-x" + "--cxxopt=c++" + "--host_cxxopt=-x" + "--host_cxxopt=c++" + ]; # We intentionally overfetch so we can share the fetch derivation across all the different configurations fetchAttrs = { TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs; # we have to force @mkl_dnn_v1 since it's not needed on darwin - bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ]; - bazelFlags = bazelFlags ++ [ - "--config=avx_posix" - "--config=mkl_open_source_only" - ] ++ lib.optionals cudaSupport [ - # ideally we'd add this unconditionally too, but it doesn't work on darwin - # we make this conditional on `cudaSupport` instead of the system, so that the hash for both - # the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't - # have access to darwin machines - "--config=cuda" + bazelTargets = [ + bazelRunTarget + "@mkl_dnn_v1//:mkl_dnn" ]; - - sha256 = (if cudaSupport then { - x86_64-linux = "sha256-8JilAoTbqOjOOJa/Zc/n/quaEDcpdcLXCNb34mfB+OM="; - } else { - x86_64-linux = "sha256-iqS+I1FQLNWXNMsA20cJp7YkyGUeshee5b2QfRBNZtk="; - aarch64-linux = "sha256-qmJ0Fm/VGMTmko4PhKs1P8/GLEJmVxb8xg+ss/HsakY=="; - }).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}"); + bazelFlags = + bazelFlags + ++ [ + "--config=avx_posix" + "--config=mkl_open_source_only" + ] + ++ lib.optionals cudaSupport [ + # ideally we'd add this unconditionally too, but it doesn't work on darwin + # we make this conditional on `cudaSupport` instead of the system, so that the hash for both + # the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't + # have access to darwin machines + "--config=cuda" + ]; + + sha256 = + ( + if cudaSupport then + { x86_64-linux = "sha256-VGNMf5/DgXbgsu1w5J1Pmrukw+7UO31BNU+crKVsX5k="; } + else + { + x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ="; + aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA="; + } + ).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}"); }; buildAttrs = { outputs = [ "out" ]; - TF_SYSTEM_LIBS = lib.concatStringsSep "," (tf_system_libs ++ lib.optionals (!effectiveStdenv.isDarwin) [ - "nsync" # fails to build on darwin - ]); + TF_SYSTEM_LIBS = lib.concatStringsSep "," ( + tf_system_libs + ++ lib.optionals (!effectiveStdenv.isDarwin) [ + "nsync" # fails to build on darwin + ] + ); # Note: we cannot do most of this patching at `patch` phase as the deps # are not available yet. Framework search paths aren't added by bintools @@ -399,31 +414,39 @@ let "macosx_10_9_${arch}" else if effectiveStdenv.system == "aarch64-darwin" then "macosx_11_0_${arch}" - else throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}"; - + else + throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}"; in buildPythonPackage { inherit meta pname version; format = "wheel"; src = - let cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}"; - in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl"; + let + cp = "cp${builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion}"; + in + "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl"; # Note that jaxlib looks for "ptxas" in $PATH. See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 # for more info. postInstall = lib.optionalString cudaSupport '' mkdir -p $out/bin - ln -s ${cudaPackagesGoogle.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas + ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do - patchelf --add-rpath "${lib.makeLibraryPath [cuda_libs_joined cudnn nccl]}" "$lib" + patchelf --add-rpath "${ + lib.makeLibraryPath [ + cuda_libs_joined + cudnn + nccl + ] + }" "$lib" done ''; nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ]; - propagatedBuildInputs = [ + dependencies = [ absl-py curl double-conversion |