about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jax/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jax/default.nix')
-rw-r--r--pkgs/development/python-modules/jax/default.nix21
1 files changed, 13 insertions, 8 deletions
diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix
index 95e85bf8e2b20..ec8c53daab464 100644
--- a/pkgs/development/python-modules/jax/default.nix
+++ b/pkgs/development/python-modules/jax/default.nix
@@ -8,6 +8,7 @@
   fetchFromGitHub,
   jaxlib,
   jaxlib-bin,
+  jaxlib-build,
   hypothesis,
   lapack,
   matplotlib,
@@ -23,10 +24,6 @@
 
 let
   usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
-  # jaxlib is broken on aarch64-* as of 2023-03-05, but the binary wheels work
-  # fine. jaxlib is only used in the checkPhase, so switching backends does not
-  # impact package behavior. Get rid of this once jaxlib is fixed on aarch64-*.
-  jaxlib' = if jaxlib.meta.broken then jaxlib-bin else jaxlib;
 in
 buildPythonPackage rec {
   pname = "jax";
@@ -61,7 +58,7 @@ buildPythonPackage rec {
 
   nativeCheckInputs = [
     hypothesis
-    jaxlib'
+    jaxlib
     matplotlib
     pytestCheckHook
     pytest-xdist
@@ -130,7 +127,11 @@ buildPythonPackage rec {
       "testQdwhWithOnRankDeficientInput5"
     ];
 
-  disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
+  disabledTestPaths = [
+    # Segmentation fault. See https://gist.github.com/zimbatm/e9b61891f3bcf5e4aaefd13f94344fba
+    "tests/linalg_test.py"
+  ]
+  ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
     # RuntimeWarning: invalid value encountered in cast
     "tests/lax_test.py"
   ];
@@ -147,7 +148,7 @@ buildPythonPackage rec {
   #   NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
   passthru.tests = {
     test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
-      jaxlib = jaxlib.override { cudaSupport = true; };
+      jaxlib = jaxlib-build.override { cudaSupport = true; };
     };
     test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
       jaxlib = jaxlib-bin.override { cudaSupport = true; };
@@ -158,7 +159,11 @@ buildPythonPackage rec {
   passthru.skipBulkUpdate = true;
 
   meta = with lib; {
-    description = "Differentiate, compile, and transform Numpy code";
+    description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
+    longDescription = ''
+      This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations,
+      e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`.
+    '';
     homepage = "https://github.com/google/jax";
     license = licenses.asl20;
     maintainers = with maintainers; [ samuela ];