about summary refs log tree commit diff
diff options
context:
space:
mode:
authorNick Cao <nickcao@nichi.co>2023-08-01 21:23:27 -0600
committerGitHub <noreply@github.com>2023-08-01 21:23:27 -0600
commit8423edb179dea2f55e013a985a806dcb6bca60ea (patch)
tree2e281969831bb6c0bfa2732e9fe0b15eab08dfad
parentcf6c3918387253e938d6fbdccf67694fe96bf733 (diff)
Revert "Update JAX"
-rw-r--r--pkgs/build-support/build-bazel-package/default.nix39
-rw-r--r--pkgs/development/python-modules/jax/default.nix22
-rw-r--r--pkgs/development/python-modules/jaxlib/bin.nix84
-rw-r--r--pkgs/development/python-modules/jaxlib/default.nix68
-rwxr-xr-xpkgs/development/python-modules/jaxlib/prefetch.sh22
-rw-r--r--pkgs/development/python-modules/ml-dtypes/default.nix38
-rw-r--r--pkgs/top-level/python-packages.nix3
7 files changed, 103 insertions, 173 deletions
diff --git a/pkgs/build-support/build-bazel-package/default.nix b/pkgs/build-support/build-bazel-package/default.nix
index 3ffff74f70e23..f9de0ad468b2b 100644
--- a/pkgs/build-support/build-bazel-package/default.nix
+++ b/pkgs/build-support/build-bazel-package/default.nix
@@ -10,12 +10,9 @@ args@{
 , bazelFlags ? []
 , bazelBuildFlags ? []
 , bazelTestFlags ? []
-, bazelRunFlags ? []
-, runTargetFlags ? []
 , bazelFetchFlags ? []
-, bazelTargets ? []
+, bazelTargets
 , bazelTestTargets ? []
-, bazelRunTarget ? null
 , buildAttrs
 , fetchAttrs
 
@@ -49,23 +46,17 @@ args@{
 
 let
   fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // {
-    inherit
-      name
-      bazelFlags
-      bazelBuildFlags
-      bazelTestFlags
-      bazelRunFlags
-      runTargetFlags
-      bazelFetchFlags
-      bazelTargets
-      bazelTestTargets
-      bazelRunTarget
-      dontAddBazelOpts
-      ;
+    name = name;
+    bazelFlags = bazelFlags;
+    bazelBuildFlags = bazelBuildFlags;
+    bazelTestFlags = bazelTestFlags;
+    bazelFetchFlags = bazelFetchFlags;
+    bazelTestTargets = bazelTestTargets;
+    dontAddBazelOpts = dontAddBazelOpts;
   };
   fBuildAttrs = fArgs // buildAttrs;
   fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ];
-  bazelCmd = { cmd, additionalFlags, targets, targetRunFlags ? [ ] }:
+  bazelCmd = { cmd, additionalFlags, targets }:
     lib.optionalString (targets != [ ]) ''
       # See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables]
       BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \
@@ -82,8 +73,7 @@ let
         "''${host_linkopts[@]}" \
         $bazelFlags \
         ${lib.strings.concatStringsSep " " additionalFlags} \
-        ${lib.strings.concatStringsSep " " targets} \
-        ${lib.optionalString (targetRunFlags != []) " -- " + lib.strings.concatStringsSep " " targetRunFlags}
+        ${lib.strings.concatStringsSep " " targets}
     '';
   # we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so:
   # chmod: cannot operate on dangling symlink '$symlink'
@@ -272,15 +262,6 @@ stdenv.mkDerivation (fBuildAttrs // {
         targets = fBuildAttrs.bazelTargets;
       }
     }
-    ${
-      bazelCmd {
-        cmd = "run";
-        additionalFlags = fBuildAttrs.bazelRunFlags ++ [ "--jobs" "$NIX_BUILD_CORES" ];
-        # Bazel run only accepts a single target, but `bazelCmd` expects `targets` to be a list.
-        targets = lib.optionals (fBuildAttrs.bazelRunTarget != null) [ fBuildAttrs.bazelRunTarget ];
-        targetRunFlags = fBuildAttrs.runTargetFlags;
-      }
-    }
     runHook postBuild
   '';
 })
diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix
index 84b7d5c303b2d..4901467262f38 100644
--- a/pkgs/development/python-modules/jax/default.nix
+++ b/pkgs/development/python-modules/jax/default.nix
@@ -8,7 +8,6 @@
 , jaxlib-bin
 , lapack
 , matplotlib
-, ml-dtypes
 , numpy
 , opt-einsum
 , pytestCheckHook
@@ -28,7 +27,7 @@ let
 in
 buildPythonPackage rec {
   pname = "jax";
-  version = "0.4.12";
+  version = "0.4.5";
   format = "setuptools";
 
   disabled = pythonOlder "3.7";
@@ -38,7 +37,7 @@ buildPythonPackage rec {
     repo = pname;
     # google/jax contains tags for jax and jaxlib. Only use jax tags!
     rev = "refs/tags/${pname}-v${version}";
-    hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y=";
+    hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA=";
   };
 
   # jaxlib is _not_ included in propagatedBuildInputs because there are
@@ -47,7 +46,6 @@ buildPythonPackage rec {
   propagatedBuildInputs = [
     absl-py
     etils
-    ml-dtypes
     numpy
     opt-einsum
     scipy
@@ -98,12 +96,24 @@ buildPythonPackage rec {
     "testScanGrad_jit_scan"
   ];
 
-  disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
+  # See https://github.com/google/jax/issues/11722. This is a temporary fix in
+  # order to unblock etils, and upgrading jax/jaxlib to the latest version. See
+  # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993.
+  disabledTestPaths = [
+    "tests/api_test.py"
+    "tests/core_test.py"
+    "tests/lax_numpy_indexing_test.py"
+    "tests/lax_numpy_test.py"
+    "tests/nn_test.py"
+    "tests/random_test.py"
+    "tests/sparse_test.py"
+  ] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
     # RuntimeWarning: invalid value encountered in cast
     "tests/lax_test.py"
   ];
 
-  pythonImportsCheck = [ "jax" ];
+  # As of 0.3.22, `import jax` does not work without jaxlib being installed.
+  pythonImportsCheck = [ ];
 
   meta = with lib; {
     description = "Differentiate, compile, and transform Numpy code";
diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix
index c92e7117028d8..b3d3138ab4431 100644
--- a/pkgs/development/python-modules/jaxlib/bin.nix
+++ b/pkgs/development/python-modules/jaxlib/bin.nix
@@ -18,12 +18,11 @@
 , autoPatchelfHook
 , buildPythonPackage
 , config
-, fetchPypi
+, cudnn ? cudaPackages.cudnn
 , fetchurl
 , flatbuffers
-, jaxlib
+, isPy39
 , lib
-, ml-dtypes
 , python
 , scipy
 , stdenv
@@ -36,57 +35,46 @@ let
   inherit (cudaPackages) cudatoolkit cudnn;
 in
 
-assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux;
+assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
+assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2";
 
 let
-  version = "0.4.12";
-
-  inherit (python) pythonVersion;
-
-  # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the
-  # official instructions recommend installing CPU-only versions via PyPI.
-  cpuSrcs =
-    let
-      getSrcFromPypi = { platform, hash }: fetchPypi {
-        inherit version platform hash;
-        pname = "jaxlib";
-        format = "wheel";
-        # See the `disabled` attr comment below.
-        dist = "cp310";
-        python = "cp310";
-        abi = "cp310";
-      };
-    in
-    {
-      "x86_64-linux" = getSrcFromPypi {
-        platform = "manylinux2014_x86_64";
-        hash = "sha256-8ef5aMP7M3/FetSqfdz2OCaVCt6CLHRSMMsVtV2bCLc=";
-      };
-      "aarch64-darwin" = getSrcFromPypi {
-        platform = "macosx_11_0_arm64";
-        hash = "sha256-Opg/DB4wAVSm5L3+G470HiBPDoR/BO4qP0OX9HSbeSo=";
-      };
-      "x86_64-darwin" = getSrcFromPypi {
-        platform = "macosx_10_14_x86_64";
-        hash = "sha256-I4zX1vv4L5Ik9eWrJ8fKd0EIt5C9XTN4JlfB8hH+l5c=";
-      };
-    };
+  version = "0.4.4";
 
+  pythonVersion = python.pythonVersion;
 
   # Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
   # When upgrading, you can get these hashes from prefetch.sh. See
-  # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
-  gpuSrc = fetchurl {
-    url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
-    hash = "sha256-xc6Nje0WHtMC5nV75zvdN53xSuNTbFSsz1FzHKd8Muo=";
+  # https://github.com/google/jax/issues/12879 as to why this specific URL is
+  # the correct index.
+  cpuSrcs = {
+    "x86_64-linux" = fetchurl {
+      url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl";
+      hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ=";
+    };
+    "aarch64-darwin" = fetchurl {
+      url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl";
+      hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U=";
+    };
+    "x86_64-darwin" = fetchurl {
+      url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl";
+      hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok=";
+    };
   };
 
+  gpuSrc = fetchurl {
+    url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl";
+    hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk=";
+  };
 in
-buildPythonPackage {
+buildPythonPackage rec {
   pname = "jaxlib";
   inherit version;
   format = "wheel";
 
+  # At the time of writing (2022-10-19), there are releases for <=3.10.
+  # Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs
+  # python version.
   disabled = !(pythonVersion == "3.10");
 
   # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
@@ -99,10 +87,9 @@ buildPythonPackage {
 
   # Prebuilt wheels are dynamically linked against things that nix can't find.
   # Run `autoPatchelfHook` to automagically fix them.
-  nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ]
-    ++ lib.optionals cudaSupport [ addOpenGLRunpath ];
+  nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ];
   # Dynamic link dependencies
-  buildInputs = [ stdenv.cc.cc.lib ];
+  buildInputs = [ stdenv.cc.cc ];
 
   # jaxlib contains shared libraries that open other shared libraries via dlopen
   # and these implicit dependencies are not recognized by ldd or
@@ -126,12 +113,7 @@ buildPythonPackage {
     done
   '';
 
-  propagatedBuildInputs = [
-    absl-py
-    flatbuffers
-    ml-dtypes
-    scipy
-  ];
+  propagatedBuildInputs = [ absl-py flatbuffers scipy ];
 
   # Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH.
   # See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
@@ -141,7 +123,7 @@ buildPythonPackage {
     ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
   '';
 
-  inherit (jaxlib) pythonImportsCheck;
+  pythonImportsCheck = [ "jaxlib" ];
 
   meta = with lib; {
     description = "XLA library for JAX";
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix
index 2fa754af8c6a3..bf93bf1a5a26e 100644
--- a/pkgs/development/python-modules/jaxlib/default.nix
+++ b/pkgs/development/python-modules/jaxlib/default.nix
@@ -4,7 +4,7 @@
 
   # Build-time dependencies:
 , addOpenGLRunpath
-, bazel_6
+, bazel_5
 , binutils
 , buildBazelPackage
 , buildPythonPackage
@@ -26,7 +26,6 @@
   # Python dependencies:
 , absl-py
 , flatbuffers
-, ml-dtypes
 , numpy
 , scipy
 , six
@@ -36,6 +35,7 @@
 , giflib
 , grpc
 , libjpeg_turbo
+, protobuf
 , python
 , snappy
 , zlib
@@ -53,7 +53,7 @@ let
   inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;
 
   pname = "jaxlib";
-  version = "0.4.12";
+  version = "0.4.4";
 
   meta = with lib; {
     description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@@ -138,15 +138,14 @@ let
   bazel-build = buildBazelPackage rec {
     name = "bazel-build-${pname}-${version}";
 
-    # See https://github.com/google/jax/blob/main/.bazelversion for the latest.
-    bazel = bazel_6;
+    bazel = bazel_5;
 
     src = fetchFromGitHub {
       owner = "google";
       repo = "jax";
       # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
       rev = "refs/tags/${pname}-v${version}";
-      hash = "sha256-2JwEpzB5RwmBjGktppKhCpiaBM0AR20wfsRoQ33lh8Y=";
+      hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo=";
     };
 
     nativeBuildInputs = [
@@ -170,7 +169,7 @@ let
       numpy
       openssl
       pkgs.flatbuffers
-      pkgs.protobuf
+      protobuf
       pybind11
       scipy
       six
@@ -189,8 +188,7 @@ let
       rm -f .bazelversion
     '';
 
-    bazelRunTarget = "//build:build_wheel";
-    runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ];
+    bazelTargets = [ "//build:build_wheel" ];
 
     removeRulesCC = false;
 
@@ -209,11 +207,7 @@ let
       build --action_env=PYENV_ROOT
       build --python_path="${python}/bin/python"
       build --distinct_host_configuration=false
-      build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
-    '' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) ''
-      build --config=avx_posix
-    '' + lib.optionalString mklSupport ''
-      build --config=mkl_open_source_only
+      build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include"
     '' + lib.optionalString cudaSupport ''
       build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
       build --action_env CUDNN_INSTALL_PATH="${cudnn}"
@@ -240,7 +234,7 @@ let
     fetchAttrs = {
       TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
       # we have to force @mkl_dnn_v1 since it's not needed on darwin
-      bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ];
+      bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ];
       bazelFlags = bazelFlags ++ [
         "--config=avx_posix"
       ] ++ lib.optionals cudaSupport [
@@ -255,9 +249,9 @@ let
 
       sha256 =
         if cudaSupport then
-          "sha256-wpucplv03HQHZ2gWhVq4R798ouPH99T3X4hbu7IRxj4="
+          "sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk="
         else
-          "sha256-v2tCFifMBJbqweZQ2rsw707Zxehu+B+YtxFk1iHdDgc=";
+          "sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI=";
     };
 
     buildAttrs = {
@@ -267,13 +261,25 @@ let
         "nsync" # fails to build on darwin
       ]);
 
+      bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
+        "--config=avx_posix"
+      ] ++ lib.optionals cudaSupport [
+        "--config=cuda"
+      ] ++ lib.optionals mklSupport [
+        "--config=mkl_open_source_only"
+      ];
       # Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
-      # 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
+      # 1) Fix pybind11 include paths.
+      # 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
       #    loading multiple extensions in the same python program due to duplicate protobuf DBs.
-      # 2) Patch python path in the compiler driver.
-      preBuild = lib.optionalString cudaSupport ''
+      # 3) Patch python path in the compiler driver.
+      preBuild = ''
+        for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
+          sed -i 's@include/pybind11@pybind11@g' $src
+        done
+      '' + lib.optionalString cudaSupport ''
         export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
-        patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
+        patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
       '' + lib.optionalString stdenv.isDarwin ''
         # Framework search paths aren't added by bintools hook
         # https://github.com/NixOS/nixpkgs/pull/41914
@@ -283,12 +289,16 @@ let
         substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
           --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
       '' + (if stdenv.cc.isGNU then ''
-        sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
-        sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
+        sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
+        sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
       '' else if stdenv.cc.isClang then ''
-        sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
-        sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
+        sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
+        sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
       '' else throw "Unsupported stdenv.cc: ${stdenv.cc}");
+
+      installPhase = ''
+        ./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch}
+      '';
     };
 
     inherit meta;
@@ -335,19 +345,13 @@ buildPythonPackage {
     grpc
     jsoncpp
     libjpeg_turbo
-    ml-dtypes
     numpy
     scipy
     six
     snappy
   ];
 
-  pythonImportsCheck = [
-    "jaxlib"
-    # `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
-    "jaxlib.cpu_feature_guard"
-    "jaxlib.xla_client"
-  ];
+  pythonImportsCheck = [ "jaxlib" ];
 
   # Without it there are complaints about libcudart.so.11.0 not being found
   # because RPATH path entries added above are stripped.
diff --git a/pkgs/development/python-modules/jaxlib/prefetch.sh b/pkgs/development/python-modules/jaxlib/prefetch.sh
index 3362e2d0b7813..31db6530639fc 100755
--- a/pkgs/development/python-modules/jaxlib/prefetch.sh
+++ b/pkgs/development/python-modules/jaxlib/prefetch.sh
@@ -1,15 +1,7 @@
-#!/usr/bin/env bash
-
-prefetch () {
-    expr="(import <nixpkgs> { system = \"$1\"; config.cudaSupport = $2; }).python3.pkgs.jaxlib-bin.src.url"
-    url=$(NIX_PATH=.. nix-instantiate --eval -E "$expr" | jq -r)
-    echo "$url"
-    sha256=$(nix-prefetch-url "$url")
-    nix hash to-sri --type sha256 "$sha256"
-    echo
-}
-
-prefetch "x86_64-linux" "false"
-prefetch "aarch64-darwin" "false"
-prefetch "x86_64-darwin" "false"
-prefetch "x86_64-linux" "true"
+version="$1"
+nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl)"
+nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-none-manylinux2010_x86_64.whl)"
+nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl)"
+nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl)"
+nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp310-none-manylinux2010_x86_64.whl)"
+nix hash to-sri --type sha256 "$(nix-prefetch-url https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-none-manylinux2010_x86_64.whl)"
diff --git a/pkgs/development/python-modules/ml-dtypes/default.nix b/pkgs/development/python-modules/ml-dtypes/default.nix
deleted file mode 100644
index e0d1df9c13c70..0000000000000
--- a/pkgs/development/python-modules/ml-dtypes/default.nix
+++ /dev/null
@@ -1,38 +0,0 @@
-{ lib
-, buildPythonPackage
-, fetchFromGitHub
-, numpy
-, pybind11
-, pythonOlder
-}:
-
-buildPythonPackage rec {
-  pname = "ml-dtypes";
-  version = "0.2.0";
-
-  disabled = pythonOlder "3.7";
-
-  src = fetchFromGitHub {
-    owner = "jax-ml";
-    repo = "ml_dtypes";
-    rev = "refs/tags/v${version}";
-    hash = "sha256-eqajWUwylIYsS8gzEaCZLLr+1+34LXWhfKBjuwsEhhI=";
-    # Since this upstream patch (https://github.com/jax-ml/ml_dtypes/commit/1bfd097e794413b0d465fa34f2eff0f3828ff521),
-    # the attempts to use the nixpkgs packaged eigen dependency have failed.
-    # Hence, we rely on the bundled eigen library.
-    fetchSubmodules = true;
-  };
-
-  nativeBuildInputs = [ pybind11 ];
-
-  propagatedBuildInputs = [ numpy ];
-
-  pythonImportsCheck = [ "ml_dtypes" ];
-
-  meta = with lib; {
-    description = "A stand-alone implementation of several NumPy dtype extensions used in machine learning libraries";
-    homepage = "https://github.com/jax-ml/ml_dtypes";
-    license = licenses.asl20;
-    maintainers = with maintainers; [ GaetanLepage samuela ];
-  };
-}
diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix
index af7fa4a4693f1..5c498ff9519aa 100644
--- a/pkgs/top-level/python-packages.nix
+++ b/pkgs/top-level/python-packages.nix
@@ -5310,6 +5310,7 @@ self: super: with self; {
     # Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
     inherit (pkgs.config) cudaSupport;
     IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
+    protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21
   };
 
   jaxlib = self.jaxlib-build;
@@ -6563,8 +6564,6 @@ self: super: with self; {
 
   ml-collections = callPackage ../development/python-modules/ml-collections { };
 
-  ml-dtypes = callPackage ../development/python-modules/ml-dtypes { };
-
   mlflow = callPackage ../development/python-modules/mlflow { };
 
   mlrose = callPackage ../development/python-modules/mlrose { };