about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFabian Affolter <fabian@affolter-engineering.ch>2024-02-25 22:11:29 +0100
committerGitHub <noreply@github.com>2024-02-25 22:11:29 +0100
commit9b640555cf4c8ca756775312e612da431e09bfb6 (patch)
tree70b4645a5aaaf8402f93f58f21a66bf72fa16300
parenta54887d53cec025eebd064dff67910fe55ee1282 (diff)
parent3ec2275f08950e9e51fe1b26224bfb2932df5ee4 (diff)
Merge pull request #291171 from fabaff/orbax-checkpoint
python311Packages.orbax-checkpoint: init at 0.5.3
-rw-r--r--pkgs/development/python-modules/orbax-checkpoint/default.nix78
-rw-r--r--pkgs/top-level/python-packages.nix2
2 files changed, 80 insertions, 0 deletions
diff --git a/pkgs/development/python-modules/orbax-checkpoint/default.nix b/pkgs/development/python-modules/orbax-checkpoint/default.nix
new file mode 100644
index 0000000000000..0f9d467335ce2
--- /dev/null
+++ b/pkgs/development/python-modules/orbax-checkpoint/default.nix
@@ -0,0 +1,78 @@
+{ lib
+, absl-py
+, buildPythonPackage
+, cached-property
+, etils
+, fetchPypi
+, flit-core
+, importlib-resources
+, jax
+, jaxlib
+, msgpack
+, nest-asyncio
+, numpy
+, protobuf
+, pytest-xdist
+, pytestCheckHook
+, pythonOlder
+, pyyaml
+, tensorstore
+, typing-extensions
+}:
+
+buildPythonPackage rec {
+  pname = "orbax-checkpoint";
+  version = "0.5.3";
+  pyproject = true;
+
+  disabled = pythonOlder "3.9";
+
+  src = fetchPypi {
+    pname = "orbax_checkpoint";
+    inherit version;
+    hash = "sha256-FXKQTLv+hROSfg2A+AtzDg7y9oAzLTwoENhENTKTi0U=";
+  };
+
+  nativeBuildInputs = [
+    flit-core
+  ];
+
+  propagatedBuildInputs = [
+    absl-py
+    cached-property
+    etils
+    importlib-resources
+    jax
+    jaxlib
+    msgpack
+    nest-asyncio
+    numpy
+    protobuf
+    pyyaml
+    tensorstore
+    typing-extensions
+  ];
+
+  nativeCheckInputs = [
+    pytest-xdist
+    pytestCheckHook
+  ];
+
+  pythonImportsCheck = [
+    "orbax"
+  ];
+
+  disabledTestPaths = [
+    # Circular dependency flax
+    "orbax/checkpoint/transform_utils_test.py"
+    "orbax/checkpoint/utils_test.py"
+  ];
+
+  meta = with lib; {
+    description = "Orbax provides common utility libraries for JAX users";
+    homepage = "https://github.com/google/orbax/tree/main/checkpoint";
+    changelog = "https://github.com/google/orbax/blob/${version}/CHANGELOG.md";
+    license = licenses.asl20;
+    maintainers = with maintainers; [fab ];
+  };
+}
diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix
index 214fd82a327e8..3bd7e9174ad67 100644
--- a/pkgs/top-level/python-packages.nix
+++ b/pkgs/top-level/python-packages.nix
@@ -8952,6 +8952,8 @@ self: super: with self; {
 
   oras = callPackage ../development/python-modules/oras { };
 
+  orbax-checkpoint = callPackage ../development/python-modules/orbax-checkpoint { };
+
   orderedmultidict = callPackage ../development/python-modules/orderedmultidict { };
 
   ordered-set = callPackage ../development/python-modules/ordered-set { };