about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib/bin.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/bin.nix')
-rw-r--r--pkgs/development/python-modules/jaxlib/bin.nix169
1 files changed, 92 insertions, 77 deletions
diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix
index 54abdfe48c345..5d4943a97ced4 100644
--- a/pkgs/development/python-modules/jaxlib/bin.nix
+++ b/pkgs/development/python-modules/jaxlib/bin.nix
@@ -4,117 +4,132 @@
 
 # See `python3Packages.jax.passthru` for CUDA tests.
 
-{ absl-py
-, autoAddDriverRunpath
-, autoPatchelfHook
-, buildPythonPackage
-, config
-, fetchPypi
-, fetchurl
-, flatbuffers
-, jaxlib-build
-, lib
-, ml-dtypes
-, python
-, scipy
-, stdenv
+{
+  absl-py,
+  autoAddDriverRunpath,
+  autoPatchelfHook,
+  buildPythonPackage,
+  config,
+  fetchPypi,
+  fetchurl,
+  flatbuffers,
+  jaxlib-build,
+  lib,
+  ml-dtypes,
+  python,
+  scipy,
+  stdenv,
   # Options:
-, cudaSupport ? config.cudaSupport
-, cudaPackagesGoogle
+  cudaSupport ? config.cudaSupport,
+  cudaPackages,
 }:
 
 let
-  inherit (cudaPackagesGoogle) cudaVersion;
+  inherit (cudaPackages) cudaVersion;
 
-  version = "0.4.24";
+  version = "0.4.28";
 
   inherit (python) pythonVersion;
 
-  cudaLibPath = lib.makeLibraryPath (with cudaPackagesGoogle; [
-    cuda_cudart.lib # libcudart.so
-    cuda_cupti.lib # libcupti.so
-    cudnn.lib # libcudnn.so
-    libcufft.lib # libcufft.so
-    libcusolver.lib # libcusolver.so
-    libcusparse.lib # libcusparse.so
-  ]);
+  cudaLibPath = lib.makeLibraryPath (
+    with cudaPackages;
+    [
+      cuda_cudart.lib # libcudart.so
+      cuda_cupti.lib # libcupti.so
+      cudnn.lib # libcudnn.so
+      libcufft.lib # libcufft.so
+      libcusolver.lib # libcusolver.so
+      libcusparse.lib # libcusparse.so
+    ]
+  );
 
   # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the
   # official instructions recommend installing CPU-only versions via PyPI.
   cpuSrcs =
     let
-      getSrcFromPypi = { platform, dist, hash }: fetchPypi {
-        inherit version platform dist hash;
-        pname = "jaxlib";
-        format = "wheel";
-        # See the `disabled` attr comment below.
-        python = dist;
-        abi = dist;
-      };
+      getSrcFromPypi =
+        {
+          platform,
+          dist,
+          hash,
+        }:
+        fetchPypi {
+          inherit
+            version
+            platform
+            dist
+            hash
+            ;
+          pname = "jaxlib";
+          format = "wheel";
+          # See the `disabled` attr comment below.
+          python = dist;
+          abi = dist;
+        };
     in
     {
       "3.9-x86_64-linux" = getSrcFromPypi {
         platform = "manylinux2014_x86_64";
         dist = "cp39";
-        hash = "sha256-6P5ArMoLZiUkHUoQ/mJccbNj5/7el/op+Qo6cGQ33xE=";
+        hash = "sha256-Slbr8FtKTBeRaZ2HTgcvP4CPCYa0AQsU+1SaackMqdw=";
       };
       "3.9-aarch64-darwin" = getSrcFromPypi {
         platform = "macosx_11_0_arm64";
         dist = "cp39";
-        hash = "sha256-23JQZRwMLtt7sK/JlCBqqRyfTVIAVJFN2sL+nAkQgvU=";
+        hash = "sha256-sBVi7IrXVxm30DiXUkiel+trTctMjBE75JFjTVKCrTw=";
       };
       "3.9-x86_64-darwin" = getSrcFromPypi {
         platform = "macosx_10_14_x86_64";
         dist = "cp39";
-        hash = "sha256-OgMedn9GHGs5THZf3pkP3Aw/jJ0vL5qK1b+Lzf634Ik=";
+        hash = "sha256-T5jMg3srbG3P4Kt/+esQkxSSCUYRmqOvn6oTlxj/J4c=";
       };
 
       "3.10-x86_64-linux" = getSrcFromPypi {
         platform = "manylinux2014_x86_64";
         dist = "cp310";
-        hash = "sha256-/VwUIIa7mTs/wLz0ArsEfNrz2pGriVVT5GX9XRFRxfY=";
+        hash = "sha256-47zcb45g+FVPQVwU2TATTmAuPKM8OOVGJ0/VRfh1dps=";
       };
       "3.10-aarch64-darwin" = getSrcFromPypi {
         platform = "macosx_11_0_arm64";
         dist = "cp310";
-        hash = "sha256-LgICOyDGts840SQQJh+yOMobMASb62llvJjpGvhzrSw=";
+        hash = "sha256-8Djmi9ENGjVUcisLvjbmpEg4RDenWqnSg/aW8O2fjAk=";
       };
       "3.10-x86_64-darwin" = getSrcFromPypi {
         platform = "macosx_10_14_x86_64";
         dist = "cp310";
-        hash = "sha256-vhyULw+zBpz1UEi2tqgBMQEzY9a6YBgEIg6A4PPh3bQ=";
+        hash = "sha256-pCHSN/jCXShQFm0zRgPGc925tsJvUrxJZwS4eCKXvWY=";
       };
 
       "3.11-x86_64-linux" = getSrcFromPypi {
         platform = "manylinux2014_x86_64";
         dist = "cp311";
-        hash = "sha256-VJO/VVwBFkOEtq4y/sLVgAV8Cung01JULiuT6W96E/8=";
+        hash = "sha256-Rc4PPIQM/4I2z/JsN/Jsn/B4aV+T4MFiwyDCgfUEEnU=";
       };
       "3.11-aarch64-darwin" = getSrcFromPypi {
         platform = "macosx_11_0_arm64";
         dist = "cp311";
-        hash = "sha256-VtuwXxurpSp1KI8ty1bizs5cdy8GEBN2MgS227sOCmE=";
+        hash = "sha256-eThX+vN/Nxyv51L+pfyBH0NeQ7j7S1AgWERKf17M+Ck=";
       };
       "3.11-x86_64-darwin" = getSrcFromPypi {
         platform = "macosx_10_14_x86_64";
         dist = "cp311";
-        hash = "sha256-4Dj5dEGKb9hpg3HlVogNO1Gc9UibJhy1eym2mjivxAQ=";
+        hash = "sha256-L/gpDtx7ksfq5SUX9lSSYz4mey6QZ7rT5MMj0hPnfPU=";
       };
 
       "3.12-x86_64-linux" = getSrcFromPypi {
         platform = "manylinux2014_x86_64";
         dist = "cp312";
-        hash = "sha256-TlrGVtb3NTLmhnILWPLJR+jISCZ5SUV4wxNFpSfkCBo=";
+        hash = "sha256-RqGqhX9P7uikP8upXA4Kti1AwmzJcwtsaWVZCLo1n40=";
       };
       "3.12-aarch64-darwin" = getSrcFromPypi {
         platform = "macosx_11_0_arm64";
         dist = "cp312";
-        hash = "sha256-FIwK5CGykQjteuWzLZnbtAggIxLQeGV96bXlZGEytN0=";
+        hash = "sha256-jdi//jhTcC9jzZJNoO4lc0pNGc1ckmvgM9dyun0cF10=";
       };
       "3.12-x86_64-darwin" = getSrcFromPypi {
         platform = "macosx_10_14_x86_64";
         dist = "cp312";
-        hash = "sha256-9/jw/wr6oUD9pOadVAaMRL086iVMUXwVgnUMcG1UNvE=";
+        hash = "sha256-1sCaVFMpciRhrwVuc1FG0sjHTCKsdCaoRetp8ya096A=";
       };
     };
 
@@ -130,57 +145,48 @@ let
   gpuSrcs = {
     "cuda12.2-3.9" = fetchurl {
       url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl";
-      hash = "sha256-xdJKLPtx+CIza2CrWKM3M0cZJzyNFVTTTsvlgh38bfM=";
+      hash = "sha256-d8LIl22gIvmWfoyKfXKElZJXicPQIZxdS4HumhwQGCw=";
     };
     "cuda12.2-3.10" = fetchurl {
       url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl";
-      hash = "sha256-QCjrOczD2mp+CDwVXBc0/4rJnAizeV62AK0Dpx9X6TE=";
+      hash = "sha256-PXtWv+UEcMWF8LhWe6Z1UGkf14PG3dkJ0Iop0LiimnQ=";
     };
     "cuda12.2-3.11" = fetchurl {
       url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl";
-      hash = "sha256-Ipy3vk1yUplpNzECAFt63aOIhgEWgXG7hkoeTIk9bQQ=";
+      hash = "sha256-QO2WSOzmJ48VaCha596mELiOfPsAGLpGctmdzcCHE/o=";
     };
     "cuda12.2-3.12" = fetchurl {
       url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl";
-      hash = "sha256-LSnZHaUga/8Z65iKXWBnZDk4yUpNykFTu3vukCchO6Q=";
-    };
-    "cuda11.8-3.9" = fetchurl {
-      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl";
-      hash = "sha256-UmyugL0VjlXkiD7fuDPWgW8XUpr/QaP5ggp6swoZTzU=";
-    };
-    "cuda11.8-3.10" = fetchurl {
-      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
-      hash = "sha256-luKULEiV1t/sO6eckDxddJTiOFa0dtJeDlrvp+WYmHk=";
-    };
-    "cuda11.8-3.11" = fetchurl {
-      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl";
-      hash = "sha256-4+uJ8Ij6mFGEmjFEgi3fLnSLZs+v18BRoOt7mZuqydw=";
-    };
-    "cuda11.8-3.12" = fetchurl {
-      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp312-cp312-manylinux2014_x86_64.whl";
-      hash = "sha256-bUDFb94Ar/65SzzR9RLIs/SL/HdjaPT1Su5whmjkS00=";
+      hash = "sha256-ixWMaIChy4Ammsn23/3cCoala0lFibuUxyUr3tjfFKU=";
     };
   };
-
 in
 buildPythonPackage {
   pname = "jaxlib";
   inherit version;
   format = "wheel";
 
-  disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10" || pythonVersion == "3.11" || pythonVersion == "3.12");
+  disabled =
+    !(
+      pythonVersion == "3.9"
+      || pythonVersion == "3.10"
+      || pythonVersion == "3.11"
+      || pythonVersion == "3.12"
+    );
 
   # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
   src =
     if !cudaSupport then
-      (
-        cpuSrcs."${pythonVersion}-${stdenv.hostPlatform.system}"
-          or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}")
-      ) else gpuSrcs."${gpuSrcVersionString}";
+      (cpuSrcs."${pythonVersion}-${stdenv.hostPlatform.system}"
+        or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}")
+      )
+    else
+      gpuSrcs."${gpuSrcVersionString}";
 
   # Prebuilt wheels are dynamically linked against things that nix can't find.
   # Run `autoPatchelfHook` to automagically fix them.
-  nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ]
+  nativeBuildInputs =
+    lib.optionals stdenv.isLinux [ autoPatchelfHook ]
     ++ lib.optionals cudaSupport [ autoAddDriverRunpath ];
   # Dynamic link dependencies
   buildInputs = [ stdenv.cc.cc.lib ];
@@ -213,7 +219,7 @@ buildPythonPackage {
   # for more info.
   postInstall = lib.optional cudaSupport ''
     mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin
-    ln -s ${lib.getExe' cudaPackagesGoogle.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
+    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
   '';
 
   inherit (jaxlib-build) pythonImportsCheck;
@@ -224,11 +230,20 @@ buildPythonPackage {
     sourceProvenance = with sourceTypes; [ binaryNativeCode ];
     license = licenses.asl20;
     maintainers = with maintainers; [ samuela ];
-    platforms = [ "aarch64-darwin" "x86_64-linux" "x86_64-darwin" ];
+    platforms = [
+      "aarch64-darwin"
+      "x86_64-linux"
+      "x86_64-darwin"
+    ];
     broken =
       !(cudaSupport -> lib.versionAtLeast cudaVersion "11.1")
-      || !(cudaSupport -> lib.versionAtLeast cudaPackagesGoogle.cudnn.version "8.2")
+      || !(cudaSupport -> lib.versionAtLeast cudaPackages.cudnn.version "8.2")
       || !(cudaSupport -> stdenv.isLinux)
-      || !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}"));
+      || !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}"))
+      # Fails at pythonImportsCheckPhase:
+      # ...-python-imports-check-hook.sh/nix-support/setup-hook: line 10: 28017 Illegal instruction: 4
+      # /nix/store/5qpssbvkzfh73xih07xgmpkj5r565975-python3-3.11.9/bin/python3.11 -c
+      # 'import os; import importlib; list(map(lambda mod: importlib.import_module(mod), os.environ["pythonImportsCheck"].split()))'
+      || (stdenv.isDarwin && stdenv.isx86_64);
   };
 }