diff options
Diffstat (limited to 'pkgs/development/python-modules/numpyro/default.nix')
-rw-r--r-- | pkgs/development/python-modules/numpyro/default.nix | 83 |
1 files changed, 67 insertions, 16 deletions
diff --git a/pkgs/development/python-modules/numpyro/default.nix b/pkgs/development/python-modules/numpyro/default.nix index b26e61b945ce3..a1aed20a52686 100644 --- a/pkgs/development/python-modules/numpyro/default.nix +++ b/pkgs/development/python-modules/numpyro/default.nix @@ -1,30 +1,43 @@ { lib, buildPythonPackage, - pythonOlder, - fetchPypi, + fetchFromGitHub, + + # build-system setuptools, + + # dependencies jax, jaxlib, multipledispatch, numpy, tqdm, + + # tests + # Our current version of tensorflow (2.13.0) is too old and doesn't support python>=3.12 + # We remove optional test dependencies that require tensorflow and skip the corresponding tests to + # avoid introducing a useless incompatibility with python 3.12: + # dm-haiku, + # flax, + # tensorflow-probability, funsor, + graphviz, + optax, + pyro-api, pytestCheckHook, -# TODO: uncomment when tensorflow-probability gets fixed. -# , tensorflow-probability + scikit-learn, }: buildPythonPackage rec { pname = "numpyro"; - version = "0.15.0"; + version = "0.15.3"; pyproject = true; - disabled = pythonOlder "3.9"; - - src = fetchPypi { - inherit version pname; - hash = "sha256-4WyfR8wx4qollYSgtslEMSCB0zypJAYCJjKtWEsOYA0="; + src = fetchFromGitHub { + owner = "pyro-ppl"; + repo = "numpyro"; + rev = "refs/tags/${version}"; + hash = "sha256-g+ep221hhLbCjQasKpiEAXkygI5A3Hglqo1tV8lv5eg="; }; build-system = [ setuptools ]; @@ -38,9 +51,14 @@ buildPythonPackage rec { ]; nativeCheckInputs = [ + # dm-haiku + # flax funsor + graphviz + optax + pyro-api pytestCheckHook - # TODO: uncomment when tensorflow-probability gets fixed. + scikit-learn # tensorflow-probability ]; @@ -57,23 +75,56 @@ buildPythonPackage rec { "test_kl_dirichlet_dirichlet" "test_kl_univariate" "test_mean_var" + # Tests want to download data "data_load" "test_jsb_chorales" + # RuntimeWarning: overflow encountered in cast "test_zero_inflated_logits_probs_agree" + # NameError: unbound axis name: _provenance "test_model_transformation" + + # require dm-haiku + "test_flax_state_dropout_smoke" + "test_flax_module" + "test_random_module_mcmc" + + # require flax + "test_haiku_state_dropout_smoke" + "test_haiku_module" + "test_random_module_mcmc" + + # require tensorflow-probability + "test_modified_bessel_first_kind_vect" + "test_diag_spectral_density_periodic" + "test_kernel_approx_periodic" + "test_modified_bessel_first_kind_one_dim" + "test_modified_bessel_first_kind_vect" + "test_periodic_gp_one_dim_model" + "test_no_tracer_leak_at_lazy_property_sample" + + # flaky on darwin + # TODO: uncomment at next release (0.15.4) as it has been fixed: + # https://github.com/pyro-ppl/numpyro/pull/1863 + "test_change_point_x64" ]; - # TODO: remove when tensorflow-probability gets fixed. - disabledTestPaths = [ "test/test_distributions.py" ]; + disabledTestPaths = [ + # require jaxns (unpackaged) + "test/contrib/test_nested_sampling.py" + + # requires tensorflow-probability + "test/contrib/test_tfp.py" + "test/test_distributions.py" + ]; - meta = with lib; { + meta = { description = "Library for probabilistic programming with NumPy"; homepage = "https://num.pyro.ai/"; changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${version}"; - license = licenses.asl20; - maintainers = with maintainers; [ fab ]; + license = lib.licenses.asl20; + maintainers = with lib.maintainers; [ fab ]; }; } |