about summary refs log tree commit diff
path: root/pkgs/development/python-modules/pot/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/pot/default.nix')
-rw-r--r--pkgs/development/python-modules/pot/default.nix78
1 files changed, 51 insertions, 27 deletions
diff --git a/pkgs/development/python-modules/pot/default.nix b/pkgs/development/python-modules/pot/default.nix
index 381bd2bb2193c..9d361a3ff11c0 100644
--- a/pkgs/development/python-modules/pot/default.nix
+++ b/pkgs/development/python-modules/pot/default.nix
@@ -3,25 +3,25 @@
   autograd,
   buildPythonPackage,
   fetchFromGitHub,
-  cupy,
   cvxopt,
   cython,
-  oldest-supported-numpy,
+  jax,
+  jaxlib,
   matplotlib,
   numpy,
-  tensorflow,
   pymanopt,
   pytestCheckHook,
   pythonOlder,
   scikit-learn,
   scipy,
-  enableDimensionalityReduction ? false,
-  enableGPU ? false,
+  setuptools,
+  tensorflow,
+  torch,
 }:
 
 buildPythonPackage rec {
   pname = "pot";
-  version = "0.9.3";
+  version = "0.9.4";
   pyproject = true;
 
   disabled = pythonOlder "3.6";
@@ -30,33 +30,56 @@ buildPythonPackage rec {
     owner = "PythonOT";
     repo = "POT";
     rev = "refs/tags/${version}";
-    hash = "sha256-fdqDM0V6zTFe1lcqi53ZZNHAfmuR2I7fdX4SN9qeNn8=";
+    hash = "sha256-Yx9hjniXebn7ZZeqou0JEsn2Yf9hyJSu/acDlM4kCCI=";
   };
 
-  nativeBuildInputs = [
+  build-system = [
+    setuptools
     cython
-    oldest-supported-numpy
+    numpy
   ];
 
-  propagatedBuildInputs =
-    [
-      numpy
-      scipy
-    ]
-    ++ lib.optionals enableGPU [ cupy ]
-    ++ lib.optionals enableDimensionalityReduction [
-      autograd
+  dependencies = [
+    numpy
+    scipy
+  ];
+
+  optional-dependencies = {
+    backend-numpy = [ ];
+    backend-jax = [
+      jax
+      jaxlib
+    ];
+    backend-cupy = [ ];
+    backend-tf = [ tensorflow ];
+    backend-torch = [ torch ];
+    cvxopt = [ cvxopt ];
+    dr = [
+      scikit-learn
       pymanopt
+      autograd
+    ];
+    gnn = [
+      torch
+      # torch-geometric
     ];
+    plot = [ matplotlib ];
+    all =
+      with optional-dependencies;
+      (
+        backend-numpy
+        ++ backend-jax
+        ++ backend-cupy
+        ++ backend-tf
+        ++ backend-torch
+        ++ optional-dependencies.cvxopt
+        ++ dr
+        ++ gnn
+        ++ plot
+      );
+  };
 
-  nativeCheckInputs = [
-    cvxopt
-    matplotlib
-    numpy
-    tensorflow
-    scikit-learn
-    pytestCheckHook
-  ];
+  nativeCheckInputs = [ pytestCheckHook ];
 
   postPatch = ''
     substituteInPlace setup.cfg \
@@ -64,6 +87,9 @@ buildPythonPackage rec {
       --replace " --durations=20" "" \
       --replace " --junit-xml=junit-results.xml" ""
 
+    substituteInPlace pyproject.toml \
+      --replace-fail "numpy>=2.0.0" "numpy"
+
     # we don't need setup.py to find the macos sdk for us
     sed -i '/sdk_path/d' setup.py
   '';
@@ -106,8 +132,6 @@ buildPythonPackage rec {
     "test_emd1d_device_tf"
   ];
 
-  disabledTestPaths = lib.optionals (!enableDimensionalityReduction) [ "test/test_dr.py" ];
-
   pythonImportsCheck = [
     "ot"
     "ot.lp"