about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxopt/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jaxopt/default.nix')
-rw-r--r--pkgs/development/python-modules/jaxopt/default.nix65
1 files changed, 40 insertions, 25 deletions
diff --git a/pkgs/development/python-modules/jaxopt/default.nix b/pkgs/development/python-modules/jaxopt/default.nix
index af924cea5ab22..83a847b5a2f45 100644
--- a/pkgs/development/python-modules/jaxopt/default.nix
+++ b/pkgs/development/python-modules/jaxopt/default.nix
@@ -1,26 +1,28 @@
-{ lib
-, stdenv
-, buildPythonPackage
-, pythonOlder
-, fetchFromGitHub
-, fetchpatch
-, pytest-xdist
-, pytestCheckHook
-, absl-py
-, cvxpy
-, jax
-, jaxlib
-, matplotlib
-, numpy
-, optax
-, scipy
-, scikit-learn
+{
+  lib,
+  stdenv,
+  buildPythonPackage,
+  pythonOlder,
+  fetchFromGitHub,
+  fetchpatch,
+  pytest-xdist,
+  pytestCheckHook,
+  setuptools,
+  absl-py,
+  cvxpy,
+  jax,
+  jaxlib,
+  matplotlib,
+  numpy,
+  optax,
+  scipy,
+  scikit-learn,
 }:
 
 buildPythonPackage rec {
   pname = "jaxopt";
   version = "0.8.3";
-  format = "setuptools";
+  pyproject = true;
 
   disabled = pythonOlder "3.8";
 
@@ -41,7 +43,9 @@ buildPythonPackage rec {
     })
   ];
 
-  propagatedBuildInputs = [
+  build-system = [ setuptools ];
+
+  dependencies = [
     absl-py
     jax
     jaxlib
@@ -66,12 +70,23 @@ buildPythonPackage rec {
     "jaxopt.tree_util"
   ];
 
-  disabledTests = lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
-    # https://github.com/google/jaxopt/issues/577
-    "test_binary_logit_log_likelihood"
-    "test_solve_sparse"
-    "test_logreg_with_intercept_manual_loop3"
-  ];
+  disabledTests =
+    [
+      # https://github.com/google/jaxopt/issues/592
+      "test_solve_sparse"
+    ]
+    ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
+      # https://github.com/google/jaxopt/issues/577
+      "test_binary_logit_log_likelihood"
+      "test_solve_sparse"
+      "test_logreg_with_intercept_manual_loop3"
+
+      # https://github.com/google/jaxopt/issues/593
+      # Makes the test suite crash
+      "test_dtype_consistency"
+      # AssertionError: Array(0.01411963, dtype=float32) not less than or equal to 0.01
+      "test_multiclass_logreg6"
+    ];
 
   meta = with lib; {
     homepage = "https://jaxopt.github.io";