diff options
author | Nick Cao <nickcao@nichi.co> | 2023-08-01 21:23:27 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-01 21:23:27 -0600 |
commit | 8423edb179dea2f55e013a985a806dcb6bca60ea (patch) | |
tree | 2e281969831bb6c0bfa2732e9fe0b15eab08dfad | |
parent | cf6c3918387253e938d6fbdccf67694fe96bf733 (diff) |
Revert "Update JAX"
-rw-r--r-- | pkgs/build-support/build-bazel-package/default.nix | 39 | ||||
-rw-r--r-- | pkgs/development/python-modules/jax/default.nix | 22 | ||||
-rw-r--r-- | pkgs/development/python-modules/jaxlib/bin.nix | 84 | ||||
-rw-r--r-- | pkgs/development/python-modules/jaxlib/default.nix | 68 | ||||
-rwxr-xr-x | pkgs/development/python-modules/jaxlib/prefetch.sh | 22 | ||||
-rw-r--r-- | pkgs/development/python-modules/ml-dtypes/default.nix | 38 | ||||
-rw-r--r-- | pkgs/top-level/python-packages.nix | 3 |
7 files changed, 103 insertions, 173 deletions
diff --git a/pkgs/build-support/build-bazel-package/default.nix b/pkgs/build-support/build-bazel-package/default.nix index 3ffff74f70e23..f9de0ad468b2b 100644 --- a/pkgs/build-support/build-bazel-package/default.nix +++ b/pkgs/build-support/build-bazel-package/default.nix @@ -10,12 +10,9 @@ args@{ , bazelFlags ? [] , bazelBuildFlags ? [] , bazelTestFlags ? [] -, bazelRunFlags ? [] -, runTargetFlags ? [] , bazelFetchFlags ? [] -, bazelTargets ? [] +, bazelTargets , bazelTestTargets ? [] -, bazelRunTarget ? null , buildAttrs , fetchAttrs @@ -49,23 +46,17 @@ args@{ let fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // { - inherit - name - bazelFlags - bazelBuildFlags - bazelTestFlags - bazelRunFlags - runTargetFlags - bazelFetchFlags - bazelTargets - bazelTestTargets - bazelRunTarget - dontAddBazelOpts - ; + name = name; + bazelFlags = bazelFlags; + bazelBuildFlags = bazelBuildFlags; + bazelTestFlags = bazelTestFlags; + bazelFetchFlags = bazelFetchFlags; + bazelTestTargets = bazelTestTargets; + dontAddBazelOpts = dontAddBazelOpts; }; fBuildAttrs = fArgs // buildAttrs; fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ]; - bazelCmd = { cmd, additionalFlags, targets, targetRunFlags ? [ ] }: + bazelCmd = { cmd, additionalFlags, targets }: lib.optionalString (targets != [ ]) '' # See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables] BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \ @@ -82,8 +73,7 @@ let "''${host_linkopts[@]}" \ $bazelFlags \ ${lib.strings.concatStringsSep " " additionalFlags} \ - ${lib.strings.concatStringsSep " " targets} \ - ${lib.optionalString (targetRunFlags != []) " -- " + lib.strings.concatStringsSep " " targetRunFlags} + ${lib.strings.concatStringsSep " " targets} ''; # we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so: # chmod: cannot operate on dangling symlink '$symlink' @@ -272,15 +262,6 @@ stdenv.mkDerivation (fBuildAttrs // { targets = fBuildAttrs.bazelTargets; } } - ${ - bazelCmd { - cmd = "run"; - additionalFlags = fBuildAttrs.bazelRunFlags ++ [ "--jobs" "$NIX_BUILD_CORES" ]; - # Bazel run only accepts a single target, but `bazelCmd` expects `targets` to be a list. - targets = lib.optionals (fBuildAttrs.bazelRunTarget != null) [ fBuildAttrs.bazelRunTarget ]; - targetRunFlags = fBuildAttrs.runTargetFlags; - } - } runHook postBuild ''; }) diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix index 84b7d5c303b2d..4901467262f38 100644 --- a/pkgs/development/python-modules/jax/default.nix +++ b/pkgs/development/python-modules/jax/default.nix @@ -8,7 +8,6 @@ , jaxlib-bin , lapack , matplotlib -, ml-dtypes , numpy , opt-einsum , pytestCheckHook @@ -28,7 +27,7 @@ let in buildPythonPackage rec { pname = "jax"; - version = "0.4.12"; + version = "0.4.5"; format = "setuptools"; disabled = pythonOlder "3.7"; @@ -38,7 +37,7 @@ buildPythonPackage rec { repo = pname; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/${pname}-v${version}"; - hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y="; + hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA="; }; # jaxlib is _not_ included in propagatedBuildInputs because there are @@ -47,7 +46,6 @@ buildPythonPackage rec { propagatedBuildInputs = [ absl-py etils - ml-dtypes numpy opt-einsum scipy @@ -98,12 +96,24 @@ buildPythonPackage rec { "testScanGrad_jit_scan" ]; - disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ + # See https://github.com/google/jax/issues/11722. This is a temporary fix in + # order to unblock etils, and upgrading jax/jaxlib to the latest version. See + # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993. + disabledTestPaths = [ + "tests/api_test.py" + "tests/core_test.py" + "tests/lax_numpy_indexing_test.py" + "tests/lax_numpy_test.py" + "tests/nn_test.py" + "tests/random_test.py" + "tests/sparse_test.py" + ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ # RuntimeWarning: invalid value encountered in cast "tests/lax_test.py" ]; - pythonImportsCheck = [ "jax" ]; + # As of 0.3.22, `import jax` does not work without jaxlib being installed. + pythonImportsCheck = [ ]; meta = with lib; { description = "Differentiate, compile, and transform Numpy code"; diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index c92e7117028d8..b3d3138ab4431 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -18,12 +18,11 @@ , autoPatchelfHook , buildPythonPackage , config -, fetchPypi +, cudnn ? cudaPackages.cudnn , fetchurl , flatbuffers -, jaxlib +, isPy39 , lib -, ml-dtypes , python , scipy , stdenv @@ -36,57 +35,46 @@ let inherit (cudaPackages) cudatoolkit cudnn; in -assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux; +assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1"; +assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2"; let - version = "0.4.12"; - - inherit (python) pythonVersion; - - # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the - # official instructions recommend installing CPU-only versions via PyPI. - cpuSrcs = - let - getSrcFromPypi = { platform, hash }: fetchPypi { - inherit version platform hash; - pname = "jaxlib"; - format = "wheel"; - # See the `disabled` attr comment below. - dist = "cp310"; - python = "cp310"; - abi = "cp310"; - }; - in - { - "x86_64-linux" = getSrcFromPypi { - platform = "manylinux2014_x86_64"; - hash = "sha256-8ef5aMP7M3/FetSqfdz2OCaVCt6CLHRSMMsVtV2bCLc="; - }; - "aarch64-darwin" = getSrcFromPypi { - platform = "macosx_11_0_arm64"; - hash = "sha256-Opg/DB4wAVSm5L3+G470HiBPDoR/BO4qP0OX9HSbeSo="; - }; - "x86_64-darwin" = getSrcFromPypi { - platform = "macosx_10_14_x86_64"; - hash = "sha256-I4zX1vv4L5Ik9eWrJ8fKd0EIt5C9XTN4JlfB8hH+l5c="; - }; - }; + version = "0.4.4"; + pythonVersion = python.pythonVersion; # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html. # When upgrading, you can get these hashes from prefetch.sh. See - # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. - gpuSrc = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; - hash = "sha256-xc6Nje0WHtMC5nV75zvdN53xSuNTbFSsz1FzHKd8Muo="; + # https://github.com/google/jax/issues/12879 as to why this specific URL is + # the correct index. + cpuSrcs = { + "x86_64-linux" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl"; + hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ="; + }; + "aarch64-darwin" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl"; + hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U="; + }; + "x86_64-darwin" = fetchurl { + url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl"; + hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok="; + }; }; + gpuSrc = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl"; + hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk="; + }; in -buildPythonPackage { +buildPythonPackage rec { pname = "jaxlib"; inherit version; format = "wheel"; + # At the time of writing (2022-10-19), there are releases for <=3.10. + # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs + # python version. disabled = !(pythonVersion == "3.10"); # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6. @@ -99,10 +87,9 @@ buildPythonPackage { # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. - nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ] - ++ lib.optionals cudaSupport [ addOpenGLRunpath ]; + nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ]; # Dynamic link dependencies - buildInputs = [ stdenv.cc.cc.lib ]; + buildInputs = [ stdenv.cc.cc ]; # jaxlib contains shared libraries that open other shared libraries via dlopen # and these implicit dependencies are not recognized by ldd or @@ -126,12 +113,7 @@ buildPythonPackage { done ''; - propagatedBuildInputs = [ - absl-py - flatbuffers - ml-dtypes - scipy - ]; + propagatedBuildInputs = [ absl-py flatbuffers scipy ]; # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH. # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for @@ -141,7 +123,7 @@ buildPythonPackage { ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas ''; - inherit (jaxlib) pythonImportsCheck; + pythonImportsCheck = [ "jaxlib" ]; meta = with lib; { description = "XLA library for JAX"; diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix index 2fa754af8c6a3..bf93bf1a5a26e 100644 --- a/pkgs/development/python-modules/jaxlib/default.nix +++ b/pkgs/development/python-modules/jaxlib/default.nix @@ -4,7 +4,7 @@ # Build-time dependencies: , addOpenGLRunpath -, bazel_6 +, bazel_5 , binutils , buildBazelPackage , buildPythonPackage @@ -26,7 +26,6 @@ # Python dependencies: , absl-py , flatbuffers -, ml-dtypes , numpy , scipy , six @@ -36,6 +35,7 @@ , giflib , grpc , libjpeg_turbo +, protobuf , python , snappy , zlib @@ -53,7 +53,7 @@ let inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl; pname = "jaxlib"; - version = "0.4.12"; + version = "0.4.4"; meta = with lib; { description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; @@ -138,15 +138,14 @@ let bazel-build = buildBazelPackage rec { name = "bazel-build-${pname}-${version}"; - # See https://github.com/google/jax/blob/main/.bazelversion for the latest. - bazel = bazel_6; + bazel = bazel_5; src = fetchFromGitHub { owner = "google"; repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jaxlib tags! rev = "refs/tags/${pname}-v${version}"; - hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y="; + hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo="; }; nativeBuildInputs = [ @@ -170,7 +169,7 @@ let numpy openssl pkgs.flatbuffers - pkgs.protobuf + protobuf pybind11 scipy six @@ -189,8 +188,7 @@ let rm -f .bazelversion ''; - bazelRunTarget = "//build:build_wheel"; - runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ]; + bazelTargets = [ "//build:build_wheel" ]; removeRulesCC = false; @@ -209,11 +207,7 @@ let build --action_env=PYENV_ROOT build --python_path="${python}/bin/python" build --distinct_host_configuration=false - build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include" - '' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) '' - build --config=avx_posix - '' + lib.optionalString mklSupport '' - build --config=mkl_open_source_only + build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include" '' + lib.optionalString cudaSupport '' build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}" build --action_env CUDNN_INSTALL_PATH="${cudnn}" @@ -240,7 +234,7 @@ let fetchAttrs = { TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs; # we have to force @mkl_dnn_v1 since it's not needed on darwin - bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ]; + bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ]; bazelFlags = bazelFlags ++ [ "--config=avx_posix" ] ++ lib.optionals cudaSupport [ @@ -255,9 +249,9 @@ let sha256 = if cudaSupport then - "sha256-wpucplv03HQHZ2gWhVq4R798ouPH99T3X4hbu7IRxj4=" + "sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk=" else - "sha256-v2tCFifMBJbqweZQ2rsw707Zxehu+B+YtxFk1iHdDgc="; + "sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI="; }; buildAttrs = { @@ -267,13 +261,25 @@ let "nsync" # fails to build on darwin ]); + bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [ + "--config=avx_posix" + ] ++ lib.optionals cudaSupport [ + "--config=cuda" + ] ++ lib.optionals mklSupport [ + "--config=mkl_open_source_only" + ]; # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet. - # 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on + # 1) Fix pybind11 include paths. + # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on # loading multiple extensions in the same python program due to duplicate protobuf DBs. - # 2) Patch python path in the compiler driver. - preBuild = lib.optionalString cudaSupport '' + # 3) Patch python path in the compiler driver. + preBuild = '' + for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do + sed -i 's@include/pybind11@pybind11@g' $src + done + '' + lib.optionalString cudaSupport '' export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib" - patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl + patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl '' + lib.optionalString stdenv.isDarwin '' # Framework search paths aren't added by bintools hook # https://github.com/NixOS/nixpkgs/pull/41914 @@ -283,12 +289,16 @@ let substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \ --replace "/usr/bin/libtool" "${cctools}/bin/libtool" '' + (if stdenv.cc.isGNU then '' - sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD - sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD + sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD + sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD '' else if stdenv.cc.isClang then '' - sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD - sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD + sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD + sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD '' else throw "Unsupported stdenv.cc: ${stdenv.cc}"); + + installPhase = '' + ./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch} + ''; }; inherit meta; @@ -335,19 +345,13 @@ buildPythonPackage { grpc jsoncpp libjpeg_turbo - ml-dtypes numpy scipy six snappy ]; - pythonImportsCheck = [ - "jaxlib" - # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade. - "jaxlib.cpu_feature_guard" - "jaxlib.xla_client" - ]; + pythonImportsCheck = [ "jaxlib" ]; # Without it there are complaints about libcudart.so.11.0 not being found # because RPATH path entries added above are stripped. diff --git a/pkgs/development/python-modules/jaxlib/prefetch.sh b/pkgs/development/python-modules/jaxlib/prefetch.sh index 3362e2d0b7813..31db6530639fc 100755 --- a/pkgs/development/python-modules/jaxlib/prefetch.sh +++ b/pkgs/development/python-modules/jaxlib/prefetch.sh @@ -1,15 +1,7 @@ -#!/usr/bin/env bash - -prefetch () { - expr="(import <nixpkgs> { system = \"$1\"; config.cudaSupport = $2; }).python3.pkgs.jaxlib-bin.src.url" - url=$(NIX_PATH=.. nix-instantiate --eval -E "$expr" | jq -r) - echo "$url" - sha256=$(nix-prefetch-url "$url") - nix hash to-sri --type sha256 "$sha256" - echo -} - -prefetch "x86_64-linux" "false" -prefetch "aarch64-darwin" "false" -prefetch "x86_64-darwin" "false" -prefetch "x86_64-linux" "true" +version="$1" +nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)" +nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)" +nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)" +nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)" +nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)" +nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)" diff --git a/pkgs/development/python-modules/ml-dtypes/default.nix b/pkgs/development/python-modules/ml-dtypes/default.nix deleted file mode 100644 index e0d1df9c13c70..0000000000000 --- a/pkgs/development/python-modules/ml-dtypes/default.nix +++ /dev/null @@ -1,38 +0,0 @@ -{ lib -, buildPythonPackage -, fetchFromGitHub -, numpy -, pybind11 -, pythonOlder -}: - -buildPythonPackage rec { - pname = "ml-dtypes"; - version = "0.2.0"; - - disabled = pythonOlder "3.7"; - - src = fetchFromGitHub { - owner = "jax-ml"; - repo = "ml_dtypes"; - rev = "refs/tags/v${version}"; - hash = "sha256-eqajWUwylIYsS8gzEaCZLLr+1+34LXWhfKBjuwsEhhI="; - # Since this upstream patch (https://github.com/jax-ml/ml_dtypes/commit/1bfd097e794413b0d465fa34f2eff0f3828ff521), - # the attempts to use the nixpkgs packaged eigen dependency have failed. - # Hence, we rely on the bundled eigen library. - fetchSubmodules = true; - }; - - nativeBuildInputs = [ pybind11 ]; - - propagatedBuildInputs = [ numpy ]; - - pythonImportsCheck = [ "ml_dtypes" ]; - - meta = with lib; { - description = "A stand-alone implementation of several NumPy dtype extensions used in machine learning libraries"; - homepage = "https://github.com/jax-ml/ml_dtypes"; - license = licenses.asl20; - maintainers = with maintainers; [ GaetanLepage samuela ]; - }; -} diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index af7fa4a4693f1..5c498ff9519aa 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -5310,6 +5310,7 @@ self: super: with self; { # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'. inherit (pkgs.config) cudaSupport; IOKit = pkgs.darwin.apple_sdk_11_0.IOKit; + protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21 }; jaxlib = self.jaxlib-build; @@ -6563,8 +6564,6 @@ self: super: with self; { ml-collections = callPackage ../development/python-modules/ml-collections { }; - ml-dtypes = callPackage ../development/python-modules/ml-dtypes { }; - mlflow = callPackage ../development/python-modules/mlflow { }; mlrose = callPackage ../development/python-modules/mlrose { }; |