diff options
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/bin.nix')
-rw-r--r-- | pkgs/development/python-modules/jaxlib/bin.nix | 169 |
1 files changed, 92 insertions, 77 deletions
diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index 54abdfe48c345..5d4943a97ced4 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -4,117 +4,132 @@ # See `python3Packages.jax.passthru` for CUDA tests. -{ absl-py -, autoAddDriverRunpath -, autoPatchelfHook -, buildPythonPackage -, config -, fetchPypi -, fetchurl -, flatbuffers -, jaxlib-build -, lib -, ml-dtypes -, python -, scipy -, stdenv +{ + absl-py, + autoAddDriverRunpath, + autoPatchelfHook, + buildPythonPackage, + config, + fetchPypi, + fetchurl, + flatbuffers, + jaxlib-build, + lib, + ml-dtypes, + python, + scipy, + stdenv, # Options: -, cudaSupport ? config.cudaSupport -, cudaPackagesGoogle + cudaSupport ? config.cudaSupport, + cudaPackages, }: let - inherit (cudaPackagesGoogle) cudaVersion; + inherit (cudaPackages) cudaVersion; - version = "0.4.24"; + version = "0.4.28"; inherit (python) pythonVersion; - cudaLibPath = lib.makeLibraryPath (with cudaPackagesGoogle; [ - cuda_cudart.lib # libcudart.so - cuda_cupti.lib # libcupti.so - cudnn.lib # libcudnn.so - libcufft.lib # libcufft.so - libcusolver.lib # libcusolver.so - libcusparse.lib # libcusparse.so - ]); + cudaLibPath = lib.makeLibraryPath ( + with cudaPackages; + [ + cuda_cudart.lib # libcudart.so + cuda_cupti.lib # libcupti.so + cudnn.lib # libcudnn.so + libcufft.lib # libcufft.so + libcusolver.lib # libcusolver.so + libcusparse.lib # libcusparse.so + ] + ); # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the # official instructions recommend installing CPU-only versions via PyPI. cpuSrcs = let - getSrcFromPypi = { platform, dist, hash }: fetchPypi { - inherit version platform dist hash; - pname = "jaxlib"; - format = "wheel"; - # See the `disabled` attr comment below. - python = dist; - abi = dist; - }; + getSrcFromPypi = + { + platform, + dist, + hash, + }: + fetchPypi { + inherit + version + platform + dist + hash + ; + pname = "jaxlib"; + format = "wheel"; + # See the `disabled` attr comment below. + python = dist; + abi = dist; + }; in { "3.9-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp39"; - hash = "sha256-6P5ArMoLZiUkHUoQ/mJccbNj5/7el/op+Qo6cGQ33xE="; + hash = "sha256-Slbr8FtKTBeRaZ2HTgcvP4CPCYa0AQsU+1SaackMqdw="; }; "3.9-aarch64-darwin" = getSrcFromPypi { platform = "macosx_11_0_arm64"; dist = "cp39"; - hash = "sha256-23JQZRwMLtt7sK/JlCBqqRyfTVIAVJFN2sL+nAkQgvU="; + hash = "sha256-sBVi7IrXVxm30DiXUkiel+trTctMjBE75JFjTVKCrTw="; }; "3.9-x86_64-darwin" = getSrcFromPypi { platform = "macosx_10_14_x86_64"; dist = "cp39"; - hash = "sha256-OgMedn9GHGs5THZf3pkP3Aw/jJ0vL5qK1b+Lzf634Ik="; + hash = "sha256-T5jMg3srbG3P4Kt/+esQkxSSCUYRmqOvn6oTlxj/J4c="; }; "3.10-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp310"; - hash = "sha256-/VwUIIa7mTs/wLz0ArsEfNrz2pGriVVT5GX9XRFRxfY="; + hash = "sha256-47zcb45g+FVPQVwU2TATTmAuPKM8OOVGJ0/VRfh1dps="; }; "3.10-aarch64-darwin" = getSrcFromPypi { platform = "macosx_11_0_arm64"; dist = "cp310"; - hash = "sha256-LgICOyDGts840SQQJh+yOMobMASb62llvJjpGvhzrSw="; + hash = "sha256-8Djmi9ENGjVUcisLvjbmpEg4RDenWqnSg/aW8O2fjAk="; }; "3.10-x86_64-darwin" = getSrcFromPypi { platform = "macosx_10_14_x86_64"; dist = "cp310"; - hash = "sha256-vhyULw+zBpz1UEi2tqgBMQEzY9a6YBgEIg6A4PPh3bQ="; + hash = "sha256-pCHSN/jCXShQFm0zRgPGc925tsJvUrxJZwS4eCKXvWY="; }; "3.11-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp311"; - hash = "sha256-VJO/VVwBFkOEtq4y/sLVgAV8Cung01JULiuT6W96E/8="; + hash = "sha256-Rc4PPIQM/4I2z/JsN/Jsn/B4aV+T4MFiwyDCgfUEEnU="; }; "3.11-aarch64-darwin" = getSrcFromPypi { platform = "macosx_11_0_arm64"; dist = "cp311"; - hash = "sha256-VtuwXxurpSp1KI8ty1bizs5cdy8GEBN2MgS227sOCmE="; + hash = "sha256-eThX+vN/Nxyv51L+pfyBH0NeQ7j7S1AgWERKf17M+Ck="; }; "3.11-x86_64-darwin" = getSrcFromPypi { platform = "macosx_10_14_x86_64"; dist = "cp311"; - hash = "sha256-4Dj5dEGKb9hpg3HlVogNO1Gc9UibJhy1eym2mjivxAQ="; + hash = "sha256-L/gpDtx7ksfq5SUX9lSSYz4mey6QZ7rT5MMj0hPnfPU="; }; "3.12-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp312"; - hash = "sha256-TlrGVtb3NTLmhnILWPLJR+jISCZ5SUV4wxNFpSfkCBo="; + hash = "sha256-RqGqhX9P7uikP8upXA4Kti1AwmzJcwtsaWVZCLo1n40="; }; "3.12-aarch64-darwin" = getSrcFromPypi { platform = "macosx_11_0_arm64"; dist = "cp312"; - hash = "sha256-FIwK5CGykQjteuWzLZnbtAggIxLQeGV96bXlZGEytN0="; + hash = "sha256-jdi//jhTcC9jzZJNoO4lc0pNGc1ckmvgM9dyun0cF10="; }; "3.12-x86_64-darwin" = getSrcFromPypi { platform = "macosx_10_14_x86_64"; dist = "cp312"; - hash = "sha256-9/jw/wr6oUD9pOadVAaMRL086iVMUXwVgnUMcG1UNvE="; + hash = "sha256-1sCaVFMpciRhrwVuc1FG0sjHTCKsdCaoRetp8ya096A="; }; }; @@ -130,57 +145,48 @@ let gpuSrcs = { "cuda12.2-3.9" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl"; - hash = "sha256-xdJKLPtx+CIza2CrWKM3M0cZJzyNFVTTTsvlgh38bfM="; + hash = "sha256-d8LIl22gIvmWfoyKfXKElZJXicPQIZxdS4HumhwQGCw="; }; "cuda12.2-3.10" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl"; - hash = "sha256-QCjrOczD2mp+CDwVXBc0/4rJnAizeV62AK0Dpx9X6TE="; + hash = "sha256-PXtWv+UEcMWF8LhWe6Z1UGkf14PG3dkJ0Iop0LiimnQ="; }; "cuda12.2-3.11" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl"; - hash = "sha256-Ipy3vk1yUplpNzECAFt63aOIhgEWgXG7hkoeTIk9bQQ="; + hash = "sha256-QO2WSOzmJ48VaCha596mELiOfPsAGLpGctmdzcCHE/o="; }; "cuda12.2-3.12" = fetchurl { url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl"; - hash = "sha256-LSnZHaUga/8Z65iKXWBnZDk4yUpNykFTu3vukCchO6Q="; - }; - "cuda11.8-3.9" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl"; - hash = "sha256-UmyugL0VjlXkiD7fuDPWgW8XUpr/QaP5ggp6swoZTzU="; - }; - "cuda11.8-3.10" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl"; - hash = "sha256-luKULEiV1t/sO6eckDxddJTiOFa0dtJeDlrvp+WYmHk="; - }; - "cuda11.8-3.11" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl"; - hash = "sha256-4+uJ8Ij6mFGEmjFEgi3fLnSLZs+v18BRoOt7mZuqydw="; - }; - "cuda11.8-3.12" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp312-cp312-manylinux2014_x86_64.whl"; - hash = "sha256-bUDFb94Ar/65SzzR9RLIs/SL/HdjaPT1Su5whmjkS00="; + hash = "sha256-ixWMaIChy4Ammsn23/3cCoala0lFibuUxyUr3tjfFKU="; }; }; - in buildPythonPackage { pname = "jaxlib"; inherit version; format = "wheel"; - disabled = !(pythonVersion == "3.9" || pythonVersion == "3.10" || pythonVersion == "3.11" || pythonVersion == "3.12"); + disabled = + !( + pythonVersion == "3.9" + || pythonVersion == "3.10" + || pythonVersion == "3.11" + || pythonVersion == "3.12" + ); # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6. src = if !cudaSupport then - ( - cpuSrcs."${pythonVersion}-${stdenv.hostPlatform.system}" - or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}") - ) else gpuSrcs."${gpuSrcVersionString}"; + (cpuSrcs."${pythonVersion}-${stdenv.hostPlatform.system}" + or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}") + ) + else + gpuSrcs."${gpuSrcVersionString}"; # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. - nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ] + nativeBuildInputs = + lib.optionals stdenv.isLinux [ autoPatchelfHook ] ++ lib.optionals cudaSupport [ autoAddDriverRunpath ]; # Dynamic link dependencies buildInputs = [ stdenv.cc.cc.lib ]; @@ -213,7 +219,7 @@ buildPythonPackage { # for more info. postInstall = lib.optional cudaSupport '' mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin - ln -s ${lib.getExe' cudaPackagesGoogle.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas + ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas ''; inherit (jaxlib-build) pythonImportsCheck; @@ -224,11 +230,20 @@ buildPythonPackage { sourceProvenance = with sourceTypes; [ binaryNativeCode ]; license = licenses.asl20; maintainers = with maintainers; [ samuela ]; - platforms = [ "aarch64-darwin" "x86_64-linux" "x86_64-darwin" ]; + platforms = [ + "aarch64-darwin" + "x86_64-linux" + "x86_64-darwin" + ]; broken = !(cudaSupport -> lib.versionAtLeast cudaVersion "11.1") - || !(cudaSupport -> lib.versionAtLeast cudaPackagesGoogle.cudnn.version "8.2") + || !(cudaSupport -> lib.versionAtLeast cudaPackages.cudnn.version "8.2") || !(cudaSupport -> stdenv.isLinux) - || !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}")); + || !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}")) + # Fails at pythonImportsCheckPhase: + # ...-python-imports-check-hook.sh/nix-support/setup-hook: line 10: 28017 Illegal instruction: 4 + # /nix/store/5qpssbvkzfh73xih07xgmpkj5r565975-python3-3.11.9/bin/python3.11 -c + # 'import os; import importlib; list(map(lambda mod: importlib.import_module(mod), os.environ["pythonImportsCheck"].split()))' + || (stdenv.isDarwin && stdenv.isx86_64); }; } |