diff options
Diffstat (limited to 'pkgs/development/python-modules/jax/default.nix')
-rw-r--r-- | pkgs/development/python-modules/jax/default.nix | 25 |
1 files changed, 15 insertions, 10 deletions
diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix index 95e85bf8e2b20..ba9a621b1eab7 100644 --- a/pkgs/development/python-modules/jax/default.nix +++ b/pkgs/development/python-modules/jax/default.nix @@ -8,6 +8,7 @@ fetchFromGitHub, jaxlib, jaxlib-bin, + jaxlib-build, hypothesis, lapack, matplotlib, @@ -23,10 +24,6 @@ let usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl"; - # jaxlib is broken on aarch64-* as of 2023-03-05, but the binary wheels work - # fine. jaxlib is only used in the checkPhase, so switching backends does not - # impact package behavior. Get rid of this once jaxlib is fixed on aarch64-*. - jaxlib' = if jaxlib.meta.broken then jaxlib-bin else jaxlib; in buildPythonPackage rec { pname = "jax"; @@ -61,7 +58,7 @@ buildPythonPackage rec { nativeCheckInputs = [ hypothesis - jaxlib' + jaxlib matplotlib pytestCheckHook pytest-xdist @@ -84,7 +81,7 @@ buildPythonPackage rec { # PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py' # See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241 # NOTE: this doesn't seem to be an issue on linux - preCheck = lib.optionalString stdenv.isDarwin '' + preCheck = lib.optionalString stdenv.hostPlatform.isDarwin '' export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d) ''; @@ -111,7 +108,7 @@ buildPythonPackage rec { "test_custom_root_with_aux" "testEigvalsGrad_shape" ] - ++ lib.optionals stdenv.isAarch64 [ + ++ lib.optionals stdenv.hostPlatform.isAarch64 [ # See https://github.com/google/jax/issues/14793. "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop" "testQdwhWithRandomMatrix3" @@ -130,7 +127,11 @@ buildPythonPackage rec { "testQdwhWithOnRankDeficientInput5" ]; - disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [ + disabledTestPaths = [ + # Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba + "tests/linalg_test.py" + ] + ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64) [ # RuntimeWarning: invalid value encountered in cast "tests/lax_test.py" ]; @@ -147,7 +148,7 @@ buildPythonPackage rec { # NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin passthru.tests = { test_cuda_jaxlibSource = callPackage ./test-cuda.nix { - jaxlib = jaxlib.override { cudaSupport = true; }; + jaxlib = jaxlib-build.override { cudaSupport = true; }; }; test_cuda_jaxlibBin = callPackage ./test-cuda.nix { jaxlib = jaxlib-bin.override { cudaSupport = true; }; @@ -158,7 +159,11 @@ buildPythonPackage rec { passthru.skipBulkUpdate = true; meta = with lib; { - description = "Differentiate, compile, and transform Numpy code"; + description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code"; + longDescription = '' + This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations, + e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`. + ''; homepage = "https://github.com/google/jax"; license = licenses.asl20; maintainers = with maintainers; [ samuela ]; |