summary refs log tree commit diff
path: root/pkgs/development/python-modules/flax/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/flax/default.nix')
-rw-r--r--pkgs/development/python-modules/flax/default.nix34
1 files changed, 12 insertions, 22 deletions
diff --git a/pkgs/development/python-modules/flax/default.nix b/pkgs/development/python-modules/flax/default.nix
index b68f56a11f412..fa0f053f86de1 100644
--- a/pkgs/development/python-modules/flax/default.nix
+++ b/pkgs/development/python-modules/flax/default.nix
@@ -4,14 +4,17 @@
 , jaxlib
 , pythonRelaxDepsHook
 , setuptools-scm
-, cloudpickle
 , jax
-, matplotlib
 , msgpack
 , numpy
 , optax
+, pyyaml
 , rich
 , tensorstore
+, typing-extensions
+, matplotlib
+, cloudpickle
+, einops
 , keras
 , pytest-xdist
 , pytestCheckHook
@@ -37,24 +40,27 @@ buildPythonPackage rec {
   ];
 
   propagatedBuildInputs = [
-    cloudpickle
     jax
-    matplotlib
     msgpack
     numpy
     optax
+    pyyaml
     rich
     tensorstore
+    typing-extensions
   ];
 
-  # See https://github.com/google/flax/pull/2882.
-  pythonRemoveDeps = [ "orbax" ];
+  passthru.optional-dependencies = {
+    all = [ matplotlib ];
+  };
 
   pythonImportsCheck = [
     "flax"
   ];
 
   nativeCheckInputs = [
+    cloudpickle
+    einops
     keras
     pytest-xdist
     pytestCheckHook
@@ -85,22 +91,6 @@ buildPythonPackage rec {
     "tests/checkpoints_test.py"
   ];
 
-  disabledTests = [
-    # See https://github.com/google/flax/issues/2554.
-    "test_async_save_checkpoints"
-    "test_jax_array0"
-    "test_jax_array1"
-    "test_keep0"
-    "test_keep1"
-    "test_optimized_lstm_cell_matches_regular"
-    "test_overwrite_checkpoints"
-    "test_save_restore_checkpoints_target_empty"
-    "test_save_restore_checkpoints_target_none"
-    "test_save_restore_checkpoints_target_singular"
-    "test_save_restore_checkpoints_w_float_steps"
-    "test_save_restore_checkpoints"
-  ];
-
   meta = with lib; {
     description = "Neural network library for JAX";
     homepage = "https://github.com/google/flax";