Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
Breaking change: Predictive will try to avoid recomputing "deterministic" sites if they are provided in posterior_samples. In previous releases, those deterministic sites were excluded.
jax.vmap and validate_args=True by @pierreglaser in https://github.com/pyro-ppl/numpyro/pull/1686
setup.cfg by @juanitorduz in https://github.com/pyro-ppl/numpyro/pull/1704
Raise NotImplementedError rather than return it by @tillahoffmann in https://github.com/pyro-ppl/numpyro/pull/1732
Full Changelog: https://github.com/pyro-ppl/numpyro/compare/0.13.2...0.14.0
A minor release to fix NumPyro 0.13.1, which is broken on jax 0.4.14.
Drops support for Python 3.8 and requires jax version >= 0.4.14.
This release is composed of great contributions and feedback from the Pyro community. Thank you!
This release includes a fix for jax 0.4.11 (#1595).
Drop Python 3.7 support and require the minimum jax version 0.4.
betainc (which supports taking gradients w.r.t. parameters) in StudentT.cdf (thanks to @colehaus)
contrib.render functionality to infer.inspect (thanks to @mattja)
obs argument is not respected when sample primitive is not executed under any handler (thanks to @gcskoenig)
potential_fn in BarkerMH
scan's transition function
history=0 as plate
render_params argument for render_model
SineBivariateVonMises distribution
__repr__ method for constraints
Thanks, @cstoafer, @hesenp, @tcbegley, @themrzmaster, @karm-patel, @theorashid