about summary refs log tree commit diff
path: root/pkgs/development/python-modules/numpyro/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/numpyro/default.nix')
-rw-r--r--pkgs/development/python-modules/numpyro/default.nix83
1 files changed, 67 insertions, 16 deletions
diff --git a/pkgs/development/python-modules/numpyro/default.nix b/pkgs/development/python-modules/numpyro/default.nix
index b26e61b945ce3..a1aed20a52686 100644
--- a/pkgs/development/python-modules/numpyro/default.nix
+++ b/pkgs/development/python-modules/numpyro/default.nix
@@ -1,30 +1,43 @@
 {
   lib,
   buildPythonPackage,
-  pythonOlder,
-  fetchPypi,
+  fetchFromGitHub,
+
+  # build-system
   setuptools,
+
+  # dependencies
   jax,
   jaxlib,
   multipledispatch,
   numpy,
   tqdm,
+
+  # tests
+  # Our current version of tensorflow (2.13.0) is too old and doesn't support python>=3.12
+  # We remove optional test dependencies that require tensorflow and skip the corresponding tests to
+  # avoid introducing a useless incompatibility with python 3.12:
+  # dm-haiku,
+  # flax,
+  # tensorflow-probability,
   funsor,
+  graphviz,
+  optax,
+  pyro-api,
   pytestCheckHook,
-# TODO: uncomment when tensorflow-probability gets fixed.
-# , tensorflow-probability
+  scikit-learn,
 }:
 
 buildPythonPackage rec {
   pname = "numpyro";
-  version = "0.15.0";
+  version = "0.15.3";
   pyproject = true;
 
-  disabled = pythonOlder "3.9";
-
-  src = fetchPypi {
-    inherit version pname;
-    hash = "sha256-4WyfR8wx4qollYSgtslEMSCB0zypJAYCJjKtWEsOYA0=";
+  src = fetchFromGitHub {
+    owner = "pyro-ppl";
+    repo = "numpyro";
+    rev = "refs/tags/${version}";
+    hash = "sha256-g+ep221hhLbCjQasKpiEAXkygI5A3Hglqo1tV8lv5eg=";
   };
 
   build-system = [ setuptools ];
@@ -38,9 +51,14 @@ buildPythonPackage rec {
   ];
 
   nativeCheckInputs = [
+    # dm-haiku
+    # flax
     funsor
+    graphviz
+    optax
+    pyro-api
     pytestCheckHook
-    # TODO: uncomment when tensorflow-probability gets fixed.
+    scikit-learn
     # tensorflow-probability
   ];
 
@@ -57,23 +75,56 @@ buildPythonPackage rec {
     "test_kl_dirichlet_dirichlet"
     "test_kl_univariate"
     "test_mean_var"
+
     # Tests want to download data
     "data_load"
     "test_jsb_chorales"
+
     # RuntimeWarning: overflow encountered in cast
     "test_zero_inflated_logits_probs_agree"
+
     # NameError: unbound axis name: _provenance
     "test_model_transformation"
+
+    # require dm-haiku
+    "test_flax_state_dropout_smoke"
+    "test_flax_module"
+    "test_random_module_mcmc"
+
+    # require flax
+    "test_haiku_state_dropout_smoke"
+    "test_haiku_module"
+    "test_random_module_mcmc"
+
+    # require tensorflow-probability
+    "test_modified_bessel_first_kind_vect"
+    "test_diag_spectral_density_periodic"
+    "test_kernel_approx_periodic"
+    "test_modified_bessel_first_kind_one_dim"
+    "test_modified_bessel_first_kind_vect"
+    "test_periodic_gp_one_dim_model"
+    "test_no_tracer_leak_at_lazy_property_sample"
+
+    # flaky on darwin
+    # TODO: uncomment at next release (0.15.4) as it has been fixed:
+    # https://github.com/pyro-ppl/numpyro/pull/1863
+    "test_change_point_x64"
   ];
 
-  # TODO: remove when tensorflow-probability gets fixed.
-  disabledTestPaths = [ "test/test_distributions.py" ];
+  disabledTestPaths = [
+    # require jaxns (unpackaged)
+    "test/contrib/test_nested_sampling.py"
+
+    # requires tensorflow-probability
+    "test/contrib/test_tfp.py"
+    "test/test_distributions.py"
+  ];
 
-  meta = with lib; {
+  meta = {
     description = "Library for probabilistic programming with NumPy";
     homepage = "https://num.pyro.ai/";
     changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${version}";
-    license = licenses.asl20;
-    maintainers = with maintainers; [ fab ];
+    license = lib.licenses.asl20;
+    maintainers = with lib.maintainers; [ fab ];
   };
 }