about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxopt/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jaxopt/default.nix')
-rw-r--r--pkgs/development/python-modules/jaxopt/default.nix20
1 files changed, 17 insertions, 3 deletions
diff --git a/pkgs/development/python-modules/jaxopt/default.nix b/pkgs/development/python-modules/jaxopt/default.nix
index af924cea5ab22..1216b15f83f2e 100644
--- a/pkgs/development/python-modules/jaxopt/default.nix
+++ b/pkgs/development/python-modules/jaxopt/default.nix
@@ -6,6 +6,7 @@
 , fetchpatch
 , pytest-xdist
 , pytestCheckHook
+, setuptools
 , absl-py
 , cvxpy
 , jax
@@ -20,7 +21,7 @@
 buildPythonPackage rec {
   pname = "jaxopt";
   version = "0.8.3";
-  format = "setuptools";
+  pyproject = true;
 
   disabled = pythonOlder "3.8";
 
@@ -41,7 +42,11 @@ buildPythonPackage rec {
     })
   ];
 
-  propagatedBuildInputs = [
+  build-system = [
+    setuptools
+  ];
+
+  dependencies = [
     absl-py
     jax
     jaxlib
@@ -66,11 +71,20 @@ buildPythonPackage rec {
     "jaxopt.tree_util"
   ];
 
-  disabledTests = lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
+  disabledTests = [
+    # https://github.com/google/jaxopt/issues/592
+    "test_solve_sparse"
+  ] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
     # https://github.com/google/jaxopt/issues/577
     "test_binary_logit_log_likelihood"
     "test_solve_sparse"
     "test_logreg_with_intercept_manual_loop3"
+
+    # https://github.com/google/jaxopt/issues/593
+    # Makes the test suite crash
+    "test_dtype_consistency"
+    # AssertionError: Array(0.01411963, dtype=float32) not less than or equal to 0.01
+    "test_multiclass_logreg6"
   ];
 
   meta = with lib; {