diff options
Diffstat (limited to 'pkgs/development/python-modules/flax/default.nix')
-rw-r--r-- | pkgs/development/python-modules/flax/default.nix | 67 |
1 files changed, 33 insertions, 34 deletions
diff --git a/pkgs/development/python-modules/flax/default.nix b/pkgs/development/python-modules/flax/default.nix index ce41f8e561394..7c443368bd41f 100644 --- a/pkgs/development/python-modules/flax/default.nix +++ b/pkgs/development/python-modules/flax/default.nix @@ -1,31 +1,32 @@ -{ lib -, buildPythonPackage -, cloudpickle -, einops -, fetchFromGitHub -, jax -, jaxlib -, keras -, matplotlib -, msgpack -, numpy -, optax -, orbax-checkpoint -, pytest-xdist -, pytestCheckHook -, pythonOlder -, pythonRelaxDepsHook -, pyyaml -, rich -, setuptools-scm -, tensorflow -, tensorstore -, typing-extensions +{ + lib, + buildPythonPackage, + cloudpickle, + einops, + fetchFromGitHub, + jax, + jaxlib, + keras, + matplotlib, + msgpack, + numpy, + optax, + orbax-checkpoint, + pytest-xdist, + pytestCheckHook, + pythonOlder, + pythonRelaxDepsHook, + pyyaml, + rich, + setuptools-scm, + tensorflow, + tensorstore, + typing-extensions, }: buildPythonPackage rec { pname = "flax"; - version = "0.8.2"; + version = "0.8.4"; pyproject = true; disabled = pythonOlder "3.9"; @@ -34,16 +35,16 @@ buildPythonPackage rec { owner = "google"; repo = "flax"; rev = "refs/tags/v${version}"; - hash = "sha256-UABgJGe1grUSkwOJpjeIoFqhXsqG//HlC1YyYPxXV+g="; + hash = "sha256-ZwqKZdJ9LOfWTav5nE9xMsMw/DbryqQUuu5fqeugBzY="; }; - nativeBuildInputs = [ + build-system = [ jaxlib pythonRelaxDepsHook setuptools-scm ]; - propagatedBuildInputs = [ + dependencies = [ jax msgpack numpy @@ -59,9 +60,7 @@ buildPythonPackage rec { all = [ matplotlib ]; }; - pythonImportsCheck = [ - "flax" - ]; + pythonImportsCheck = [ "flax" ]; nativeCheckInputs = [ cloudpickle @@ -87,7 +86,7 @@ buildPythonPackage rec { # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them # would be limited anyway. "examples/*" - "flax/experimental/nnx/examples/*" + "flax/nnx/examples/*" # See https://github.com/google/flax/issues/3232. "tests/jax_utils_test.py" # Requires tree @@ -99,11 +98,11 @@ buildPythonPackage rec { "test_overwrite_checkpoints0" ]; - meta = with lib; { + meta = { description = "Neural network library for JAX"; homepage = "https://github.com/google/flax"; changelog = "https://github.com/google/flax/releases/tag/v${version}"; - license = licenses.asl20; - maintainers = with maintainers; [ ndl ]; + license = lib.licenses.asl20; + maintainers = with lib.maintainers; [ ndl ]; }; } |