diff options
Diffstat (limited to 'pkgs/development/python-modules/jax/default.nix')
-rw-r--r-- | pkgs/development/python-modules/jax/default.nix | 159 |
1 files changed, 87 insertions, 72 deletions
diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix index e160eec612cf..ba9a621b1eab 100644 --- a/pkgs/development/python-modules/jax/default.nix +++ b/pkgs/development/python-modules/jax/default.nix @@ -1,35 +1,33 @@ -{ lib -, blas -, buildPythonPackage -, callPackage -, setuptools -, importlib-metadata -, fetchFromGitHub -, jaxlib -, jaxlib-bin -, hypothesis -, lapack -, matplotlib -, ml-dtypes -, numpy -, opt-einsum -, pytestCheckHook -, pytest-xdist -, pythonOlder -, scipy -, stdenv +{ + lib, + blas, + buildPythonPackage, + callPackage, + setuptools, + importlib-metadata, + fetchFromGitHub, + jaxlib, + jaxlib-bin, + jaxlib-build, + hypothesis, + lapack, + matplotlib, + ml-dtypes, + numpy, + opt-einsum, + pytestCheckHook, + pytest-xdist, + pythonOlder, + scipy, + stdenv, }: let usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl"; - # jaxlib is broken on aarch64-* as of 2023-03-05, but the binary wheels work - # fine. jaxlib is only used in the checkPhase, so switching backends does not - # impact package behavior. Get rid of this once jaxlib is fixed on aarch64-*. - jaxlib' = if jaxlib.meta.broken then jaxlib-bin else jaxlib; in buildPythonPackage rec { pname = "jax"; - version = "0.4.25"; + version = "0.4.28"; pyproject = true; disabled = pythonOlder "3.9"; @@ -39,12 +37,10 @@ buildPythonPackage rec { repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jax tags! rev = "refs/tags/jax-v${version}"; - hash = "sha256-poQQo2ZgEhPYzK3aCs+BjaHTNZbezJAECd+HOdY1Yok="; + hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek="; }; - nativeBuildInputs = [ - setuptools - ]; + nativeBuildInputs = [ setuptools ]; # The version is automatically set to ".dev" if this variable is not set. # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3 @@ -62,7 +58,7 @@ buildPythonPackage rec { nativeCheckInputs = [ hypothesis - jaxlib' + jaxlib matplotlib pytestCheckHook pytest-xdist @@ -81,46 +77,61 @@ buildPythonPackage rec { "tests/" ]; - disabledTests = [ - # Exceeds tolerance when the machine is busy - "test_custom_linear_solve_aux" - # UserWarning: Explicitly requested dtype <class 'numpy.float64'> - # requested in astype is not available, and will be truncated to - # dtype float32. (With numpy 1.24) - "testKde3" - "testKde5" - "testKde6" - # Invokes python manually in a subprocess, which does not have the correct dependencies - # ImportError: This version of jax requires jaxlib version >= 0.4.19. - "test_no_log_spam" - ] ++ lib.optionals usingMKL [ - # See - # * https://github.com/google/jax/issues/9705 - # * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921 - # * https://github.com/NixOS/nixpkgs/issues/161960 - "test_custom_linear_solve_cholesky" - "test_custom_root_with_aux" - "testEigvalsGrad_shape" - ] ++ lib.optionals stdenv.isAarch64 [ - # See https://github.com/google/jax/issues/14793. - "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop" - "testQdwhWithRandomMatrix3" - "testScanGrad_jit_scan" - - # See https://github.com/google/jax/issues/17867. - "test_array" - "test_async" - "test_copy0" - "test_device_put" - "test_make_array_from_callback" - "test_make_array_from_single_device_arrays" - - # Fails on some hardware due to some numerical error - # See https://github.com/google/jax/issues/18535 - "testQdwhWithOnRankDeficientInput5" - ]; - - disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ + # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with + # PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py' + # See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241 + # NOTE: this doesn't seem to be an issue on linux + preCheck = lib.optionalString stdenv.hostPlatform.isDarwin '' + export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d) + ''; + + disabledTests = + [ + # Exceeds tolerance when the machine is busy + "test_custom_linear_solve_aux" + # UserWarning: Explicitly requested dtype <class 'numpy.float64'> + # requested in astype is not available, and will be truncated to + # dtype float32. (With numpy 1.24) + "testKde3" + "testKde5" + "testKde6" + # Invokes python manually in a subprocess, which does not have the correct dependencies + # ImportError: This version of jax requires jaxlib version >= 0.4.19. + "test_no_log_spam" + ] + ++ lib.optionals usingMKL [ + # See + # * https://github.com/google/jax/issues/9705 + # * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921 + # * https://github.com/NixOS/nixpkgs/issues/161960 + "test_custom_linear_solve_cholesky" + "test_custom_root_with_aux" + "testEigvalsGrad_shape" + ] + ++ lib.optionals stdenv.hostPlatform.isAarch64 [ + # See https://github.com/google/jax/issues/14793. + "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop" + "testQdwhWithRandomMatrix3" + "testScanGrad_jit_scan" + + # See https://github.com/google/jax/issues/17867. + "test_array" + "test_async" + "test_copy0" + "test_device_put" + "test_make_array_from_callback" + "test_make_array_from_single_device_arrays" + + # Fails on some hardware due to some numerical error + # See https://github.com/google/jax/issues/18535 + "testQdwhWithOnRankDeficientInput5" + ]; + + disabledTestPaths = [ + # Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba + "tests/linalg_test.py" + ] + ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64) [ # RuntimeWarning: invalid value encountered in cast "tests/lax_test.py" ]; @@ -137,7 +148,7 @@ buildPythonPackage rec { # NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin passthru.tests = { test_cuda_jaxlibSource = callPackage ./test-cuda.nix { - jaxlib = jaxlib.override { cudaSupport = true; }; + jaxlib = jaxlib-build.override { cudaSupport = true; }; }; test_cuda_jaxlibBin = callPackage ./test-cuda.nix { jaxlib = jaxlib-bin.override { cudaSupport = true; }; @@ -148,7 +159,11 @@ buildPythonPackage rec { passthru.skipBulkUpdate = true; meta = with lib; { - description = "Differentiate, compile, and transform Numpy code"; + description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code"; + longDescription = '' + This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations, + e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`. + ''; homepage = "https://github.com/google/jax"; license = licenses.asl20; maintainers = with maintainers; [ samuela ]; |