NumPyro Versions

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.

0.14.0

2 months ago

Breaking change: Predictive will try to avoid recomputing "deterministic" sites if they are provided in posterior_samples. In previous releases, those deterministic sites were excluded.

What's Changed

New Contributors

Full Changelog: https://github.com/pyro-ppl/numpyro/compare/0.13.2...0.14.0

0.13.2

7 months ago

A minor release to fix NumPyro 0.13.1, which was broken on jax 0.4.14.

0.13.1

7 months ago

Enhancements and Bug Fixes

  • Add promote_batch_shape rule for Independent (#1630) by @deoxyribose
  • Support custom prng key (#1642) by @fehiepsi

0.13.0

8 months ago

Breaking changes

Drops support for Python 3.8 and requires jax >= 0.4.14.

New Features

  • Distributions are now vmap-able (#1529) (a great contribution by @pierreglaser)

Enhancements and Bug Fixes

  • Enhance LocScaleReparam's documentation (#1599) by @Madhav-Kanda
  • Fix incorrect unflattening of inverse transforms (#1600) by @pierreglaser
  • Update Stein mixture (#1601 and #1612) by @OlaRonning
  • Support model without global variables in AutoSemiDAIS (#1610 and #1619) by @fehiepsi
  • Fix mixture assert message: string should be f-string (#1617) by @adrn
  • Add support for local variables in RenyiELBO (#1608) by @fehiepsi
  • Fix quantile computation of mvn autoguides (#1622) by @fehiepsi
  • Respect log_density in kl of delta (#1625) by @fehiepsi
  • Add vectorized_particles to ELBO (#1624) by @fehiepsi
  • Fix bug in SineBivariateVonMises sampler (#1628) by @deoxyribose

This release is composed of great contributions and feedback from the Pyro community. Thank you!

0.12.1

11 months ago

This release includes a fix for jax 0.4.11 (#1595).

0.12.0

11 months ago

New Features

Enhancements and Bug Fixes

  • Make transforms jittable (#1575)
  • Fixed typo in surrogate posterior of beta (#1591)
  • Do not scale mnist label (#1589)
  • Do not mutate shapes of ExpandedDistribution for map-free ops (#1574)
  • Add support for JAX custom PRNG (#1587)
  • Include deterministic variables in AutoDelta's sample_posterior (#1584)
  • Fix forward shape of SimplexToOrderTransform (#1583)
  • Fix inf's in TruncatedNormal log_prob & sample (#1581)
  • Allow users to specify total_count_max in Multinomial (#1557)
  • Allow a pickled MCMC object to run the post-warmup phase (#1558)
  • Add init_params argument to svi.init() and svi.run() (#1561)
  • Support pickling MCMC objects with enumeration (#1577)
  • Raise an error when reparameterizing lognormal (#1548)
  • Avoid initializing model params that are already specified in the guide (#1553)
  • Respect init params if provided to mcmc.run (#1547)
  • Fix provenance for jax 0.4.4 (#1543)
  • Use analytic kl divergence in TraceEnum_ELBO (#1533)
  • Properly handle contraction of guide plates in TraceEnum_ELBO (#1537)
  • Raise an error if there is no common scale when the model is enumerated (#1536)
  • Optimize reduction of enumerated guide sites (#1531)
  • Guess max_plate_nesting in TraceEnum_ELBO (#1528)

0.11.0

1 year ago

Breaking changes

Drops Python 3.7 support and requires jax >= 0.4.

New Features

Enhancements and Bug Fixes

  • #1507 Made constraints robust to pickling (thanks to @pierreglaser)
  • #1515 Fixed log_prob for negative correlation in SineBivariateVonMises distribution (thanks to @OlaRonning)
  • #1509 Fixed a bug in flattening/unflattening distributions that mixed up the order of their arguments (thanks to @hessammehr)
  • #1494 Fixed Mixture distribution with unnormalized logits giving incorrect log_prob
  • #1486 Return a positive ordered vector when applying ExpTransform to an ordered vector
  • #1491 Fixed Mixture intermediate values
  • #1480 Fixed some computations in Bayesian Hierarchical Stacking tutorial (thanks to @cpieringer)
  • #1478 Added icdf methods for Beta, Gamma, StudentT
  • #1477 Allowed multiple arguments to initialize flax/haiku modules (thanks to @fehiepsi)
  • #1475 Used TFP's betainc (which supports taking gradient w.r.t. parameters) in StudentT.cdf (thanks to @colehaus)

0.10.1

1 year ago

New Features

  • #1458 New Mixture distribution to allow component distributions to come from different families (thanks to @dfm)

Enhancements and Bug Fixes

  • #1469 Fix memory leak when using scan
  • #1465 Fix validation logic for transformed distribution (thanks to @xidulu)
  • #1462 Avoid infinite loop in vmapped Binomial (thanks to @tbenthompson)
  • #1449 Move contrib.render functionality to infer.inspect (thanks to @mattja)

0.10.0

1 year ago

New Features

Enhancements and Bug Fixes

  • #1401 Fix obs argument not being respected when the sample primitive is executed outside any handler (thanks to @gcskoenig)
  • #1412 TraceGraph_ELBO implementation using provenance tracking
  • #1418 Fix SA sampler being unable to run parallel chains
  • #1419 Fix categorical sampler occasionally generating out-of-support samples
  • #1436 Allow using potential_fn in BarkerMH
  • #1437 Fix AutoMultivariateNormal.get_posterior returning an incorrect distribution (thanks to @xidulu)
  • #1444 Promote shapes for observed variables inside scan's transition function
  • #1443 Treat the time dimension of markov models with history=0 as a plate
  • #1441 More stable check for corr_cholesky constraint
  • #1400 Fix SineBivariateVonMises sampler

0.9.2

2 years ago

New Features

Enhancements and Bug Fixes

  • #1386 Support JAX 0.3.5
  • #1388 Update jaxns to 1.0.0
  • #1372 Fix batch shapes of SineBivariateVonMises distribution
  • #1375 Add __repr__ method for constraints
  • #1358 Force validate_args to be a keyword argument
  • #1350 Enhance the AR2 example

Thanks, @cstoafer, @hesenp, @tcbegley, @themrzmaster, @karm-patel, @theorashid