about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jax/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jax/default.nix')
-rw-r--r--pkgs/development/python-modules/jax/default.nix159
1 files changed, 87 insertions, 72 deletions
diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix
index e160eec612cf..ba9a621b1eab 100644
--- a/pkgs/development/python-modules/jax/default.nix
+++ b/pkgs/development/python-modules/jax/default.nix
@@ -1,35 +1,33 @@
-{ lib
-, blas
-, buildPythonPackage
-, callPackage
-, setuptools
-, importlib-metadata
-, fetchFromGitHub
-, jaxlib
-, jaxlib-bin
-, hypothesis
-, lapack
-, matplotlib
-, ml-dtypes
-, numpy
-, opt-einsum
-, pytestCheckHook
-, pytest-xdist
-, pythonOlder
-, scipy
-, stdenv
+{
+  lib,
+  blas,
+  buildPythonPackage,
+  callPackage,
+  setuptools,
+  importlib-metadata,
+  fetchFromGitHub,
+  jaxlib,
+  jaxlib-bin,
+  jaxlib-build,
+  hypothesis,
+  lapack,
+  matplotlib,
+  ml-dtypes,
+  numpy,
+  opt-einsum,
+  pytestCheckHook,
+  pytest-xdist,
+  pythonOlder,
+  scipy,
+  stdenv,
 }:
 
 let
   usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
-  # jaxlib is broken on aarch64-* as of 2023-03-05, but the binary wheels work
-  # fine. jaxlib is only used in the checkPhase, so switching backends does not
-  # impact package behavior. Get rid of this once jaxlib is fixed on aarch64-*.
-  jaxlib' = if jaxlib.meta.broken then jaxlib-bin else jaxlib;
 in
 buildPythonPackage rec {
   pname = "jax";
-  version = "0.4.25";
+  version = "0.4.28";
   pyproject = true;
 
   disabled = pythonOlder "3.9";
@@ -39,12 +37,10 @@ buildPythonPackage rec {
     repo = "jax";
     # google/jax contains tags for jax and jaxlib. Only use jax tags!
     rev = "refs/tags/jax-v${version}";
-    hash = "sha256-poQQo2ZgEhPYzK3aCs+BjaHTNZbezJAECd+HOdY1Yok=";
+    hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
   };
 
-  nativeBuildInputs = [
-    setuptools
-  ];
+  nativeBuildInputs = [ setuptools ];
 
   # The version is automatically set to ".dev" if this variable is not set.
   # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
@@ -62,7 +58,7 @@ buildPythonPackage rec {
 
   nativeCheckInputs = [
     hypothesis
-    jaxlib'
+    jaxlib
     matplotlib
     pytestCheckHook
     pytest-xdist
@@ -81,46 +77,61 @@ buildPythonPackage rec {
     "tests/"
   ];
 
-  disabledTests = [
-    # Exceeds tolerance when the machine is busy
-    "test_custom_linear_solve_aux"
-    # UserWarning: Explicitly requested dtype <class 'numpy.float64'>
-    #  requested in astype is not available, and will be truncated to
-    # dtype float32. (With numpy 1.24)
-    "testKde3"
-    "testKde5"
-    "testKde6"
-    # Invokes python manually in a subprocess, which does not have the correct dependencies
-    # ImportError: This version of jax requires jaxlib version >= 0.4.19.
-    "test_no_log_spam"
-  ] ++ lib.optionals usingMKL [
-    # See
-    #  * https://github.com/google/jax/issues/9705
-    #  * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
-    #  * https://github.com/NixOS/nixpkgs/issues/161960
-    "test_custom_linear_solve_cholesky"
-    "test_custom_root_with_aux"
-    "testEigvalsGrad_shape"
-  ] ++ lib.optionals stdenv.isAarch64 [
-    # See https://github.com/google/jax/issues/14793.
-    "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop"
-    "testQdwhWithRandomMatrix3"
-    "testScanGrad_jit_scan"
-
-    # See https://github.com/google/jax/issues/17867.
-    "test_array"
-    "test_async"
-    "test_copy0"
-    "test_device_put"
-    "test_make_array_from_callback"
-    "test_make_array_from_single_device_arrays"
-
-    # Fails on some hardware due to some numerical error
-    # See https://github.com/google/jax/issues/18535
-    "testQdwhWithOnRankDeficientInput5"
-  ];
-
-  disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
+  # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with
+  # PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py'
+  # See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241
+  # NOTE: this doesn't seem to be an issue on linux
+  preCheck = lib.optionalString stdenv.hostPlatform.isDarwin ''
+    export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d)
+  '';
+
+  disabledTests =
+    [
+      # Exceeds tolerance when the machine is busy
+      "test_custom_linear_solve_aux"
+      # UserWarning: Explicitly requested dtype <class 'numpy.float64'>
+      #  requested in astype is not available, and will be truncated to
+      # dtype float32. (With numpy 1.24)
+      "testKde3"
+      "testKde5"
+      "testKde6"
+      # Invokes python manually in a subprocess, which does not have the correct dependencies
+      # ImportError: This version of jax requires jaxlib version >= 0.4.19.
+      "test_no_log_spam"
+    ]
+    ++ lib.optionals usingMKL [
+      # See
+      #  * https://github.com/google/jax/issues/9705
+      #  * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
+      #  * https://github.com/NixOS/nixpkgs/issues/161960
+      "test_custom_linear_solve_cholesky"
+      "test_custom_root_with_aux"
+      "testEigvalsGrad_shape"
+    ]
+    ++ lib.optionals stdenv.hostPlatform.isAarch64 [
+      # See https://github.com/google/jax/issues/14793.
+      "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop"
+      "testQdwhWithRandomMatrix3"
+      "testScanGrad_jit_scan"
+
+      # See https://github.com/google/jax/issues/17867.
+      "test_array"
+      "test_async"
+      "test_copy0"
+      "test_device_put"
+      "test_make_array_from_callback"
+      "test_make_array_from_single_device_arrays"
+
+      # Fails on some hardware due to some numerical error
+      # See https://github.com/google/jax/issues/18535
+      "testQdwhWithOnRankDeficientInput5"
+    ];
+
+  disabledTestPaths = [
+    # Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba
+    "tests/linalg_test.py"
+  ]
+  ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64) [
     # RuntimeWarning: invalid value encountered in cast
     "tests/lax_test.py"
   ];
@@ -137,7 +148,7 @@ buildPythonPackage rec {
   #   NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
   passthru.tests = {
     test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
-      jaxlib = jaxlib.override { cudaSupport = true; };
+      jaxlib = jaxlib-build.override { cudaSupport = true; };
     };
     test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
       jaxlib = jaxlib-bin.override { cudaSupport = true; };
@@ -148,7 +159,11 @@ buildPythonPackage rec {
   passthru.skipBulkUpdate = true;
 
   meta = with lib; {
-    description = "Differentiate, compile, and transform Numpy code";
+    description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
+    longDescription = ''
+      This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations,
+      e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`.
+    '';
     homepage = "https://github.com/google/jax";
     license = licenses.asl20;
     maintainers = with maintainers; [ samuela ];