about summary refs log tree commit diff
path: root/pkgs/development/python-modules/diffusers
diff options
context:
space:
mode:
authornatsukium <tomoya.otabi@gmail.com>2023-12-09 18:15:46 +0900
committernatsukium <tomoya.otabi@gmail.com>2023-12-10 11:42:09 +0900
commitbe762d1df86e5615ccd74bbb78f08707bbb1d4cf (patch)
tree57d10ebab5a1e9e2a6d83960ed37869ed2a7159c /pkgs/development/python-modules/diffusers
parentcbef97d927154c450621249eab0bcef4e1bf440c (diff)
python311Packages.diffusers: init at 0.24.0
Diffstat (limited to 'pkgs/development/python-modules/diffusers')
-rw-r--r--pkgs/development/python-modules/diffusers/default.nix153
1 files changed, 153 insertions, 0 deletions
diff --git a/pkgs/development/python-modules/diffusers/default.nix b/pkgs/development/python-modules/diffusers/default.nix
new file mode 100644
index 0000000000000..3485f9e3351d8
--- /dev/null
+++ b/pkgs/development/python-modules/diffusers/default.nix
@@ -0,0 +1,153 @@
+{ lib
+, stdenv
+, buildPythonPackage
+, fetchFromGitHub
+, pythonOlder
+, writeText
+, setuptools
+, wheel
+, filelock
+, huggingface-hub
+, importlib-metadata
+, numpy
+, pillow
+, regex
+, requests
+, safetensors
+# optional dependencies
+, accelerate
+, datasets
+, flax
+, jax
+, jaxlib
+, jinja2
+, protobuf
+, tensorboard
+, torch
+# test dependencies
+, parameterized
+, pytest-timeout
+, pytest-xdist
+, pytestCheckHook
+, requests-mock
+, ruff
+, scipy
+, sentencepiece
+, torchsde
+, transformers
+}:
+
+buildPythonPackage rec {
+  pname = "diffusers";
+  version = "0.24.0";
+  pyproject = true;
+
+  disabled = pythonOlder "3.8";
+
+  src = fetchFromGitHub {
+    owner = "huggingface";
+    repo = "diffusers";
+    rev = "refs/tags/v${version}";
+    hash = "sha256-ccWF8hQzPhFY/kqRum2tbanI+cQiT25MmvPZN+hGadc=";
+  };
+
+  nativeBuildInputs = [
+    setuptools
+    wheel
+  ];
+
+  propagatedBuildInputs = [
+    filelock
+    huggingface-hub
+    importlib-metadata
+    numpy
+    pillow
+    regex
+    requests
+    safetensors
+  ];
+
+  passthru.optional-dependencies = {
+    flax = [
+      flax
+      jax
+      jaxlib
+    ];
+    torch = [
+      accelerate
+      torch
+    ];
+    training = [
+      accelerate
+      datasets
+      jinja2
+      protobuf
+      tensorboard
+    ];
+  };
+
+  pythonImportsCheck = [
+    "diffusers"
+  ];
+
+  # tests crash due to torch segmentation fault
+  doCheck = !(stdenv.isLinux && stdenv.isAarch64);
+
+  nativeCheckInputs = [
+    parameterized
+    pytest-timeout
+    pytest-xdist
+    pytestCheckHook
+    requests-mock
+    ruff
+    scipy
+    sentencepiece
+    torchsde
+    transformers
+  ] ++ passthru.optional-dependencies.torch;
+
+  preCheck = let
+    # This pytest hook mocks and catches attempts at accessing the network
+    # tests that try to access the network will raise, get caught, be marked as skipped and tagged as xfailed.
+    # cf. python3Packages.shap
+    conftestSkipNetworkErrors = writeText "conftest.py" ''
+      from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
+      import urllib3
+
+      class NetworkAccessDeniedError(RuntimeError): pass
+      def deny_network_access(*a, **kw):
+        raise NetworkAccessDeniedError
+
+      urllib3.connection.HTTPSConnection._new_conn = deny_network_access
+
+      def pytest_runtest_makereport(item, call):
+        tr = orig_pytest_runtest_makereport(item, call)
+        if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
+            tr.outcome = 'skipped'
+            tr.wasxfail = "reason: Requires network access."
+        return tr
+    '';
+  in ''
+    export HOME=$TMPDIR
+    cat ${conftestSkipNetworkErrors} >> tests/conftest.py
+  '';
+
+  pytestFlagsArray = [
+    "tests/"
+  ];
+
+  disabledTests = [
+    # depends on current working directory
+    "test_deprecate_stacklevel"
+    # fails due to precision of floating point numbers
+    "test_model_cpu_offload_forward_pass"
+  ];
+
+  meta = with lib; {
+    description = "State-of-the-art diffusion models for image and audio generation in PyTorch";
+    homepage = "https://github.com/huggingface/diffusers";
+    changelog = "https://github.com/huggingface/diffusers/releases/tag/${src.rev}";
+    license = licenses.asl20;
+    maintainers = with maintainers; [ natsukium ];
+  };
+}