about summary refs log tree commit diff
path: root/pkgs/development/python-modules/diffusers/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/diffusers/default.nix')
-rw-r--r--pkgs/development/python-modules/diffusers/default.nix179
1 files changed, 101 insertions, 78 deletions
diff --git a/pkgs/development/python-modules/diffusers/default.nix b/pkgs/development/python-modules/diffusers/default.nix
index 39464efe47fdb..8762022b06511 100644
--- a/pkgs/development/python-modules/diffusers/default.nix
+++ b/pkgs/development/python-modules/diffusers/default.nix
@@ -1,40 +1,43 @@
-{ lib
-, stdenv
-, buildPythonPackage
-, fetchFromGitHub
-, pythonOlder
-, writeText
-, setuptools
-, wheel
-, filelock
-, huggingface-hub
-, importlib-metadata
-, numpy
-, pillow
-, regex
-, requests
-, safetensors
-# optional dependencies
-, accelerate
-, datasets
-, flax
-, jax
-, jaxlib
-, jinja2
-, peft
-, protobuf
-, tensorboard
-, torch
-# test dependencies
-, parameterized
-, pytest-timeout
-, pytest-xdist
-, pytestCheckHook
-, requests-mock
-, scipy
-, sentencepiece
-, torchsde
-, transformers
+{
+  lib,
+  stdenv,
+  buildPythonPackage,
+  pythonOlder,
+  fetchFromGitHub,
+  fetchpatch,
+  writeText,
+  setuptools,
+  wheel,
+  filelock,
+  huggingface-hub,
+  importlib-metadata,
+  numpy,
+  pillow,
+  regex,
+  requests,
+  safetensors,
+  # optional dependencies
+  accelerate,
+  datasets,
+  flax,
+  jax,
+  jaxlib,
+  jinja2,
+  peft,
+  protobuf,
+  tensorboard,
+  torch,
+  # test dependencies
+  parameterized,
+  pytest-timeout,
+  pytest-xdist,
+  pytestCheckHook,
+  requests-mock,
+  scipy,
+  sentencepiece,
+  torchsde,
+  transformers,
+  pythonAtLeast,
 }:
 
 buildPythonPackage rec {
@@ -51,12 +54,28 @@ buildPythonPackage rec {
     hash = "sha256-aRnbU3jN40xaCsoMFyRt1XB+hyIYMJP2b/T1yZho90c=";
   };
 
-  nativeBuildInputs = [
+  patches = [
+    # fix python3.12 build
+    (fetchpatch {
+      # https://github.com/huggingface/diffusers/pull/7455
+      name = "001-remove-distutils.patch";
+      url = "https://github.com/huggingface/diffusers/compare/363699044e365ef977a7646b500402fa585e1b6b...3c67864c5acb30413911730b1ed4a9ad47c0a15c.patch";
+      hash = "sha256-Qyvyp1GyTVXN+A+lA1r2hf887ubTtaUknbKd4r46NZQ=";
+    })
+    (fetchpatch {
+      # https://github.com/huggingface/diffusers/pull/7461
+      name = "002-fix-removed-distutils.patch";
+      url = "https://github.com/huggingface/diffusers/commit/efbbbc38e436a1abb1df41a6eccfd6f9f0333f97.patch";
+      hash = "sha256-scdtpX1RYFFEDHcaMb+gDZSsPafkvnIO/wQlpzrQhLA=";
+    })
+  ];
+
+  build-system = [
     setuptools
     wheel
   ];
 
-  propagatedBuildInputs = [
+  dependencies = [
     filelock
     huggingface-hub
     importlib-metadata
@@ -87,9 +106,7 @@ buildPythonPackage rec {
     ];
   };
 
-  pythonImportsCheck = [
-    "diffusers"
-  ];
+  pythonImportsCheck = [ "diffusers" ];
 
   # tests crash due to torch segmentation fault
   doCheck = !(stdenv.isLinux && stdenv.isAarch64);
@@ -106,44 +123,50 @@ buildPythonPackage rec {
     transformers
   ] ++ passthru.optional-dependencies.torch;
 
-  preCheck = let
-    # This pytest hook mocks and catches attempts at accessing the network
-    # tests that try to access the network will raise, get caught, be marked as skipped and tagged as xfailed.
-    # cf. python3Packages.shap
-    conftestSkipNetworkErrors = writeText "conftest.py" ''
-      from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
-      import urllib3
-
-      class NetworkAccessDeniedError(RuntimeError): pass
-      def deny_network_access(*a, **kw):
-        raise NetworkAccessDeniedError
-
-      urllib3.connection.HTTPSConnection._new_conn = deny_network_access
-
-      def pytest_runtest_makereport(item, call):
-        tr = orig_pytest_runtest_makereport(item, call)
-        if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
-            tr.outcome = 'skipped'
-            tr.wasxfail = "reason: Requires network access."
-        return tr
+  preCheck =
+    let
+      # This pytest hook mocks and catches attempts at accessing the network
+      # tests that try to access the network will raise, get caught, be marked as skipped and tagged as xfailed.
+      # cf. python3Packages.shap
+      conftestSkipNetworkErrors = writeText "conftest.py" ''
+        from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
+        import urllib3
+
+        class NetworkAccessDeniedError(RuntimeError): pass
+        def deny_network_access(*a, **kw):
+          raise NetworkAccessDeniedError
+
+        urllib3.connection.HTTPSConnection._new_conn = deny_network_access
+
+        def pytest_runtest_makereport(item, call):
+          tr = orig_pytest_runtest_makereport(item, call)
+          if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
+              tr.outcome = 'skipped'
+              tr.wasxfail = "reason: Requires network access."
+          return tr
+      '';
+    in
+    ''
+      export HOME=$TMPDIR
+      cat ${conftestSkipNetworkErrors} >> tests/conftest.py
     '';
-  in ''
-    export HOME=$TMPDIR
-    cat ${conftestSkipNetworkErrors} >> tests/conftest.py
-  '';
 
-  pytestFlagsArray = [
-    "tests/"
-  ];
-
-  disabledTests = [
-    # depends on current working directory
-    "test_deprecate_stacklevel"
-    # fails due to precision of floating point numbers
-    "test_model_cpu_offload_forward_pass"
-    # tries to run ruff which we have intentionally removed from nativeCheckInputs
-    "test_is_copy_consistent"
-  ];
+  pytestFlagsArray = [ "tests/" ];
+
+  disabledTests =
+    [
+      # depends on current working directory
+      "test_deprecate_stacklevel"
+      # fails due to precision of floating point numbers
+      "test_model_cpu_offload_forward_pass"
+      # tries to run ruff which we have intentionally removed from nativeCheckInputs
+      "test_is_copy_consistent"
+    ]
+    ++ lib.optionals (pythonAtLeast "3.12") [
+
+      # RuntimeError: Dynamo is not supported on Python 3.12+
+      "test_from_save_pretrained_dynamo"
+    ];
 
   meta = with lib; {
     description = "State-of-the-art diffusion models for image and audio generation in PyTorch";