summary refs log tree commit diff
path: root/pkgs/development/python-modules/dalle-mini/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/dalle-mini/default.nix')
-rw-r--r--pkgs/development/python-modules/dalle-mini/default.nix16
1 files changed, 10 insertions, 6 deletions
diff --git a/pkgs/development/python-modules/dalle-mini/default.nix b/pkgs/development/python-modules/dalle-mini/default.nix
index e50249dc7dd9a..0c768ba5dbe14 100644
--- a/pkgs/development/python-modules/dalle-mini/default.nix
+++ b/pkgs/development/python-modules/dalle-mini/default.nix
@@ -1,6 +1,7 @@
 { lib
 , buildPythonPackage
 , fetchPypi
+, fetchpatch
 , einops
 , emoji
 , flax
@@ -16,16 +17,20 @@
 buildPythonPackage rec {
   pname = "dalle-mini";
   version = "0.1.5";
+  format = "setuptools";
 
   src = fetchPypi {
     inherit pname version;
     hash = "sha256-k4XILjNNz0FPcAzwPEeqe5Lj24S2Y139uc9o/1IUS1c=";
   };
 
-  format = "setuptools";
-
-  buildInputs = [
-    jaxlib
+  # Fix incompatibility with the latest JAX versions
+  # See https://github.com/borisdayma/dalle-mini/pull/338
+  patches = [
+    (fetchpatch {
+      url = "https://github.com/borisdayma/dalle-mini/pull/338/commits/22ffccf03f3e207731a481e3e42bdb564ceebb69.patch";
+      hash = "sha256-LIOyfeq/oVYukG+1rfy5PjjsJcjADCjn18x/hVmLkPY=";
+    })
   ];
 
   propagatedBuildInputs = [
@@ -34,6 +39,7 @@ buildPythonPackage rec {
     flax
     ftfy
     jax
+    jaxlib
     pillow
     transformers
     unidecode
@@ -49,7 +55,5 @@ buildPythonPackage rec {
     homepage = "https://github.com/borisdayma/dalle-mini";
     license = licenses.asl20;
     maintainers = with maintainers; [ r-burns ];
-    # incompatible with recent versions of JAX
-    broken = true;
   };
 }