about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jax/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jax/default.nix')
-rw-r--r--pkgs/development/python-modules/jax/default.nix7
1 files changed, 5 insertions, 2 deletions
diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix
index 574341f216abd..c80629960d33f 100644
--- a/pkgs/development/python-modules/jax/default.nix
+++ b/pkgs/development/python-modules/jax/default.nix
@@ -6,6 +6,7 @@
 , numpy
 , opt-einsum
 , pytestCheckHook
+, pytest-xdist
 , pythonOlder
 , scipy
 , typing-extensions
@@ -13,7 +14,7 @@
 
 buildPythonPackage rec {
   pname = "jax";
-  version = "0.2.26";
+  version = "0.3.0";
   format = "setuptools";
 
   disabled = pythonOlder "3.7";
@@ -22,7 +23,7 @@ buildPythonPackage rec {
     owner = "google";
     repo = pname;
     rev = "${pname}-v${version}";
-    sha256 = "155hhwgq6axdrj4x4hw72322qv1wc068n4cv4z2vf5jpl05fg93g";
+    sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72";
   };
 
   patches = [
@@ -45,6 +46,7 @@ buildPythonPackage rec {
   checkInputs = [
     jaxlib
     pytestCheckHook
+    pytest-xdist
   ];
 
   # NOTE: Don't run the tests in the expiremental directory as they require flax
@@ -52,6 +54,7 @@ buildPythonPackage rec {
   # Not a big deal, this is how the JAX docs suggest running the test suite
   # anyhow.
   pytestFlagsArray = [
+    "-n auto"
     "-W ignore::DeprecationWarning"
     "tests/"
   ];