about summary refs log tree commit diff
path: root/pkgs/development/python-modules/equinox/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/equinox/default.nix')
-rw-r--r--pkgs/development/python-modules/equinox/default.nix50
1 files changed, 31 insertions, 19 deletions
diff --git a/pkgs/development/python-modules/equinox/default.nix b/pkgs/development/python-modules/equinox/default.nix
index 976406b399b1e..d89a041be864f 100644
--- a/pkgs/development/python-modules/equinox/default.nix
+++ b/pkgs/development/python-modules/equinox/default.nix
@@ -1,16 +1,17 @@
-{ lib
-, buildPythonPackage
-, pythonOlder
-, fetchFromGitHub
-, hatchling
-, jax
-, jaxlib
-, jaxtyping
-, typing-extensions
-, beartype
-, optax
-, pytest-xdist
-, pytestCheckHook
+{
+  lib,
+  buildPythonPackage,
+  pythonOlder,
+  fetchFromGitHub,
+  hatchling,
+  jax,
+  jaxlib,
+  jaxtyping,
+  typing-extensions,
+  beartype,
+  optax,
+  pytest-xdist,
+  pytestCheckHook,
 }:
 
 buildPythonPackage rec {
@@ -27,9 +28,7 @@ buildPythonPackage rec {
     hash = "sha256-3OwHND1YEdg/SppqiB7pCdp6v+lYwTbtX07tmyEMWDo=";
   };
 
-  nativeBuildInputs = [
-    hatchling
-  ];
+  nativeBuildInputs = [ hatchling ];
 
   propagatedBuildInputs = [
     jax
@@ -48,12 +47,25 @@ buildPythonPackage rec {
   pythonImportsCheck = [ "equinox" ];
 
   disabledTests = [
-    # Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
-    "test_tracetime"
+    # For simplicity, JAX has removed its internal frames from the traceback of the following exception.
+    # https://github.com/patrick-kidger/equinox/issues/716
+    "test_abstract"
+    "test_complicated"
+    "test_grad"
+    "test_jvp"
+    "test_mlp"
+    "test_num_traces"
+    "test_pytree_in"
+    "test_simple"
+    "test_vmap"
+
+    # AssertionError: assert 'foo:\n   pri...pe=float32)\n' == 'foo:\n   pri...pe=float32)\n'
+    # Also reported in patrick-kidger/equinox#716
+    "test_backward_nan"
   ];
 
   meta = with lib; {
-    description = "A JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees";
+    description = "JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees";
     changelog = "https://github.com/patrick-kidger/equinox/releases/tag/v${version}";
     homepage = "https://github.com/patrick-kidger/equinox";
     license = licenses.asl20;