diff options
Diffstat (limited to 'pkgs/development/python-modules/flax/default.nix')
-rw-r--r-- | pkgs/development/python-modules/flax/default.nix | 34 |
1 files changed, 12 insertions, 22 deletions
diff --git a/pkgs/development/python-modules/flax/default.nix b/pkgs/development/python-modules/flax/default.nix index b68f56a11f412..fa0f053f86de1 100644 --- a/pkgs/development/python-modules/flax/default.nix +++ b/pkgs/development/python-modules/flax/default.nix @@ -4,14 +4,17 @@ , jaxlib , pythonRelaxDepsHook , setuptools-scm -, cloudpickle , jax -, matplotlib , msgpack , numpy , optax +, pyyaml , rich , tensorstore +, typing-extensions +, matplotlib +, cloudpickle +, einops , keras , pytest-xdist , pytestCheckHook @@ -37,24 +40,27 @@ buildPythonPackage rec { ]; propagatedBuildInputs = [ - cloudpickle jax - matplotlib msgpack numpy optax + pyyaml rich tensorstore + typing-extensions ]; - # See https://github.com/google/flax/pull/2882. - pythonRemoveDeps = [ "orbax" ]; + passthru.optional-dependencies = { + all = [ matplotlib ]; + }; pythonImportsCheck = [ "flax" ]; nativeCheckInputs = [ + cloudpickle + einops keras pytest-xdist pytestCheckHook @@ -85,22 +91,6 @@ buildPythonPackage rec { "tests/checkpoints_test.py" ]; - disabledTests = [ - # See https://github.com/google/flax/issues/2554. - "test_async_save_checkpoints" - "test_jax_array0" - "test_jax_array1" - "test_keep0" - "test_keep1" - "test_optimized_lstm_cell_matches_regular" - "test_overwrite_checkpoints" - "test_save_restore_checkpoints_target_empty" - "test_save_restore_checkpoints_target_none" - "test_save_restore_checkpoints_target_singular" - "test_save_restore_checkpoints_w_float_steps" - "test_save_restore_checkpoints" - ]; - meta = with lib; { description = "Neural network library for JAX"; homepage = "https://github.com/google/flax"; |