diff options
Diffstat (limited to 'pkgs/development/python-modules/optax/default.nix')
-rw-r--r-- | pkgs/development/python-modules/optax/default.nix | 29 |
1 files changed, 19 insertions, 10 deletions
diff --git a/pkgs/development/python-modules/optax/default.nix b/pkgs/development/python-modules/optax/default.nix index 14082067893a3..345b02ec26472 100644 --- a/pkgs/development/python-modules/optax/default.nix +++ b/pkgs/development/python-modules/optax/default.nix @@ -1,19 +1,27 @@ { lib, - absl-py, buildPythonPackage, + pythonOlder, + fetchFromGitHub, + + # build-system flit-core, + + # dependencies + absl-py, chex, - fetchFromGitHub, + jax, jaxlib, numpy, + etils, + + # checks callPackage, - pythonOlder, }: buildPythonPackage rec { pname = "optax"; - version = "0.2.2"; + version = "0.2.3"; pyproject = true; disabled = pythonOlder "3.9"; @@ -22,7 +30,7 @@ buildPythonPackage rec { owner = "deepmind"; repo = "optax"; rev = "refs/tags/v${version}"; - hash = "sha256-sBiKUuQR89mttc9Njrh1aeUJOYdlcF7Nlj3/+Y7OMb4="; + hash = "sha256-D1qKei3IjDP9fC62hf6fNtvHlnn09O/dKuzTBdLwW64="; }; outputs = [ @@ -30,15 +38,16 @@ buildPythonPackage rec { "testsout" ]; - nativeBuildInputs = [ flit-core ]; + build-system = [ flit-core ]; - buildInputs = [ jaxlib ]; - - propagatedBuildInputs = [ + dependencies = [ absl-py chex + etils + jax + jaxlib numpy - ]; + ] ++ etils.optional-dependencies.epy; postInstall = '' mkdir $testsout |