about summary refs log tree commit diff
path: root/pkgs/development/python-modules/diffusers/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/diffusers/default.nix')
-rw-r--r--pkgs/development/python-modules/diffusers/default.nix174
1 files changed, 89 insertions, 85 deletions
diff --git a/pkgs/development/python-modules/diffusers/default.nix b/pkgs/development/python-modules/diffusers/default.nix
index 39464efe47fdb..23580b69ced68 100644
--- a/pkgs/development/python-modules/diffusers/default.nix
+++ b/pkgs/development/python-modules/diffusers/default.nix
@@ -1,45 +1,46 @@
-{ lib
-, stdenv
-, buildPythonPackage
-, fetchFromGitHub
-, pythonOlder
-, writeText
-, setuptools
-, wheel
-, filelock
-, huggingface-hub
-, importlib-metadata
-, numpy
-, pillow
-, regex
-, requests
-, safetensors
-# optional dependencies
-, accelerate
-, datasets
-, flax
-, jax
-, jaxlib
-, jinja2
-, peft
-, protobuf
-, tensorboard
-, torch
-# test dependencies
-, parameterized
-, pytest-timeout
-, pytest-xdist
-, pytestCheckHook
-, requests-mock
-, scipy
-, sentencepiece
-, torchsde
-, transformers
+{
+  lib,
+  buildPythonPackage,
+  pythonOlder,
+  fetchFromGitHub,
+  writeText,
+  setuptools,
+  filelock,
+  huggingface-hub,
+  importlib-metadata,
+  numpy,
+  pillow,
+  regex,
+  requests,
+  safetensors,
+  # optional dependencies
+  accelerate,
+  datasets,
+  flax,
+  jax,
+  jaxlib,
+  jinja2,
+  peft,
+  protobuf,
+  tensorboard,
+  torch,
+  # test dependencies
+  parameterized,
+  pytest-timeout,
+  pytest-xdist,
+  pytestCheckHook,
+  requests-mock,
+  scipy,
+  sentencepiece,
+  torchsde,
+  transformers,
+  pythonAtLeast,
+  diffusers,
 }:
 
 buildPythonPackage rec {
   pname = "diffusers";
-  version = "0.27.2";
+  version = "0.28.2";
   pyproject = true;
 
   disabled = pythonOlder "3.8";
@@ -48,15 +49,12 @@ buildPythonPackage rec {
     owner = "huggingface";
     repo = "diffusers";
     rev = "refs/tags/v${version}";
-    hash = "sha256-aRnbU3jN40xaCsoMFyRt1XB+hyIYMJP2b/T1yZho90c=";
+    hash = "sha256-q1Y7YJSTVkPZF7KeHdOwO7XgTDBvFGioLR57adc1P+o=";
   };
 
-  nativeBuildInputs = [
-    setuptools
-    wheel
-  ];
+  build-system = [ setuptools ];
 
-  propagatedBuildInputs = [
+  dependencies = [
     filelock
     huggingface-hub
     importlib-metadata
@@ -87,12 +85,12 @@ buildPythonPackage rec {
     ];
   };
 
-  pythonImportsCheck = [
-    "diffusers"
-  ];
+  pythonImportsCheck = [ "diffusers" ];
 
-  # tests crash due to torch segmentation fault
-  doCheck = !(stdenv.isLinux && stdenv.isAarch64);
+  # it takes a few hours
+  doCheck = false;
+
+  passthru.tests.pytest = diffusers.overridePythonAttrs { doCheck = true; };
 
   nativeCheckInputs = [
     parameterized
@@ -106,44 +104,50 @@ buildPythonPackage rec {
     transformers
   ] ++ passthru.optional-dependencies.torch;
 
-  preCheck = let
-    # This pytest hook mocks and catches attempts at accessing the network
-    # tests that try to access the network will raise, get caught, be marked as skipped and tagged as xfailed.
-    # cf. python3Packages.shap
-    conftestSkipNetworkErrors = writeText "conftest.py" ''
-      from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
-      import urllib3
-
-      class NetworkAccessDeniedError(RuntimeError): pass
-      def deny_network_access(*a, **kw):
-        raise NetworkAccessDeniedError
-
-      urllib3.connection.HTTPSConnection._new_conn = deny_network_access
-
-      def pytest_runtest_makereport(item, call):
-        tr = orig_pytest_runtest_makereport(item, call)
-        if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
-            tr.outcome = 'skipped'
-            tr.wasxfail = "reason: Requires network access."
-        return tr
+  preCheck =
+    let
+      # This pytest hook mocks and catches attempts at accessing the network
+      # tests that try to access the network will raise, get caught, be marked as skipped and tagged as xfailed.
+      # cf. python3Packages.shap
+      conftestSkipNetworkErrors = writeText "conftest.py" ''
+        from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
+        import urllib3
+
+        class NetworkAccessDeniedError(RuntimeError): pass
+        def deny_network_access(*a, **kw):
+          raise NetworkAccessDeniedError
+
+        urllib3.connection.HTTPSConnection._new_conn = deny_network_access
+
+        def pytest_runtest_makereport(item, call):
+          tr = orig_pytest_runtest_makereport(item, call)
+          if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
+              tr.outcome = 'skipped'
+              tr.wasxfail = "reason: Requires network access."
+          return tr
+      '';
+    in
+    ''
+      export HOME=$TMPDIR
+      cat ${conftestSkipNetworkErrors} >> tests/conftest.py
     '';
-  in ''
-    export HOME=$TMPDIR
-    cat ${conftestSkipNetworkErrors} >> tests/conftest.py
-  '';
 
-  pytestFlagsArray = [
-    "tests/"
-  ];
-
-  disabledTests = [
-    # depends on current working directory
-    "test_deprecate_stacklevel"
-    # fails due to precision of floating point numbers
-    "test_model_cpu_offload_forward_pass"
-    # tries to run ruff which we have intentionally removed from nativeCheckInputs
-    "test_is_copy_consistent"
-  ];
+  pytestFlagsArray = [ "tests/" ];
+
+  disabledTests =
+    [
+      # depends on current working directory
+      "test_deprecate_stacklevel"
+      # fails due to precision of floating point numbers
+      "test_model_cpu_offload_forward_pass"
+      # tries to run ruff which we have intentionally removed from nativeCheckInputs
+      "test_is_copy_consistent"
+    ]
+    ++ lib.optionals (pythonAtLeast "3.12") [
+
+      # RuntimeError: Dynamo is not supported on Python 3.12+
+      "test_from_save_pretrained_dynamo"
+    ];
 
   meta = with lib; {
     description = "State-of-the-art diffusion models for image and audio generation in PyTorch";