about summary refs log tree commit diff
path: root/pkgs/development/python-modules/flax/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/flax/default.nix')
-rw-r--r--pkgs/development/python-modules/flax/default.nix67
1 files changed, 33 insertions, 34 deletions
diff --git a/pkgs/development/python-modules/flax/default.nix b/pkgs/development/python-modules/flax/default.nix
index ce41f8e561394..7c443368bd41f 100644
--- a/pkgs/development/python-modules/flax/default.nix
+++ b/pkgs/development/python-modules/flax/default.nix
@@ -1,31 +1,32 @@
-{ lib
-, buildPythonPackage
-, cloudpickle
-, einops
-, fetchFromGitHub
-, jax
-, jaxlib
-, keras
-, matplotlib
-, msgpack
-, numpy
-, optax
-, orbax-checkpoint
-, pytest-xdist
-, pytestCheckHook
-, pythonOlder
-, pythonRelaxDepsHook
-, pyyaml
-, rich
-, setuptools-scm
-, tensorflow
-, tensorstore
-, typing-extensions
+{
+  lib,
+  buildPythonPackage,
+  cloudpickle,
+  einops,
+  fetchFromGitHub,
+  jax,
+  jaxlib,
+  keras,
+  matplotlib,
+  msgpack,
+  numpy,
+  optax,
+  orbax-checkpoint,
+  pytest-xdist,
+  pytestCheckHook,
+  pythonOlder,
+  pythonRelaxDepsHook,
+  pyyaml,
+  rich,
+  setuptools-scm,
+  tensorflow,
+  tensorstore,
+  typing-extensions,
 }:
 
 buildPythonPackage rec {
   pname = "flax";
-  version = "0.8.2";
+  version = "0.8.4";
   pyproject = true;
 
   disabled = pythonOlder "3.9";
@@ -34,16 +35,16 @@ buildPythonPackage rec {
     owner = "google";
     repo = "flax";
     rev = "refs/tags/v${version}";
-    hash = "sha256-UABgJGe1grUSkwOJpjeIoFqhXsqG//HlC1YyYPxXV+g=";
+    hash = "sha256-ZwqKZdJ9LOfWTav5nE9xMsMw/DbryqQUuu5fqeugBzY=";
   };
 
-  nativeBuildInputs = [
+  build-system = [
     jaxlib
     pythonRelaxDepsHook
     setuptools-scm
   ];
 
-  propagatedBuildInputs = [
+  dependencies = [
     jax
     msgpack
     numpy
@@ -59,9 +60,7 @@ buildPythonPackage rec {
     all = [ matplotlib ];
   };
 
-  pythonImportsCheck = [
-    "flax"
-  ];
+  pythonImportsCheck = [ "flax" ];
 
   nativeCheckInputs = [
     cloudpickle
@@ -87,7 +86,7 @@ buildPythonPackage rec {
     # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
     # would be limited anyway.
     "examples/*"
-    "flax/experimental/nnx/examples/*"
+    "flax/nnx/examples/*"
     # See https://github.com/google/flax/issues/3232.
     "tests/jax_utils_test.py"
     # Requires tree
@@ -99,11 +98,11 @@ buildPythonPackage rec {
     "test_overwrite_checkpoints0"
   ];
 
-  meta = with lib; {
+  meta = {
     description = "Neural network library for JAX";
     homepage = "https://github.com/google/flax";
     changelog = "https://github.com/google/flax/releases/tag/v${version}";
-    license = licenses.asl20;
-    maintainers = with maintainers; [ ndl ];
+    license = lib.licenses.asl20;
+    maintainers = with lib.maintainers; [ ndl ];
   };
 }