NumPyro Versions

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.

0.9.2

2 years ago

New Features

Enhancements and Bug Fixes

  • #1386 Support JAX 0.3.5
  • #1388 Update jaxns to 1.0.0
  • #1372 Fix batch shapes of SineBivariateVonMises distribution
  • #1375 Add __repr__ method for constraints
  • #1358 Force validate_args to be keyword argument
  • #1350 Enhance the AR2 example

Thanks, @cstoafer, @hesenp, @tcbegley, @themrzmaster, @karm-patel, @theorashid

0.9.1

2 years ago

New Features

Enhancements and Bug Fixes

  • #1320 Fix provenance logic for numpyro.render_model
  • #1330 Fix vectorize sampling for sites with size=0
  • #1331 Allow nested params in SteinVI
  • #1333 Reduce memory consumption for SteinVI
  • #1325 Fix compat MCMC signature

Thanks, @Vinnie-Palazeti, @wataruhashimoto52, @hessammehr, @OlaRonning, @d-diaz!

0.9.0

2 years ago

New Features

Enhancements and Bug Fixes

  • #1305 Fixes HMCECS bug for likelihoods with multiple plates
  • #1304 Improves warning mechanism when plates are missing.
  • #1301 Fixes sparse Poisson density sometimes returning int output.
  • #1289 Make HMC Gibbs algorithms work with improper distributions
  • #1284 Adds various KL divergences for Gamma/Beta families
  • #1281 Raises error if there are duplicated deterministic sites
  • #1271 Better warning mechanism with stacklevel
  • #1270 Incorporate KL divergences of TensorFlow Probability distributions
  • #1259 #1266 Allow TruncatedNormal/Cauchy to take both low and high
  • #1254 numpyro.contrib.indexing is moved to numpyro.ops.indexing
  • #1252 Use multipledispatch for kl_registry
  • #1250 Added cdf methods for gamma, inverse gamma, and log-normal densities
  • #1248 Add ProvenanceArray to infer relational structure in a model
  • #1244 Raise warning for the automatic enumeration behavior
  • #1237 Enhance warnings for invalid parameters of BetaProportion distribution
  • #1227 Allow prior to be callable in random_flax_module and random_haiku_module
  • #1226 Allow init_to_sample to work with scalar values
  • #1225 Add color for divergences in Neal's example
  • #1196 Allow custom precision function in laplace approximation autoguide
  • #1194 Option to specify init state for SVI run
  • #1185 #1189 Avoid -inf/nan samples in truncated distributions
  • #1182 Extend scope handler for plate stack frames
  • #1179 Support enumerate_support for zero-inflated distributions
  • #1169 Allow pickling autoguides

This release is composed of great contributions and feedback from the Pyro community: @amalvaidya @MarcoGorelli @omarfsosa @maw501 @bjeffrey92 @hessammehr @OlaRonning @dykim29 @Carlosbogo @wataruhashimoto52 @Vedranh13 @ahmadsalim @austereantelope and many others. Thank you!

0.8.0

2 years ago

Breaking changes

Switch to softplus transforms for autoguide scales (thanks to experiments performed by @vitkl).

New Features

Enhancements and Bug Fixes

  • #1108 Avoid numerical problems when using BernoulliProbs
  • #1118 Recommend AutoNormal guide when the Hessian in AutoLaplace is singular
  • #1126 Smarter warning about discrete inference in SVI models
  • #1136 Support using the SA sampler with ArviZ
  • #1139 Document Poisson is_sparse argument
  • #1140 Make Sigmoid and StickBreakingTransform more stable
  • #1149 Raise ValueError if num_steps is invalid in svi.run
  • #1162 Use black[jupyter] in notebooks

This release is composed of great contributions and feedback from the Pyro community: @MarcoGorelli @OlaRonning @d-diaz @quattro @svilupp @peterroelants @prashjet @freddyaboulton @tcbegley @julianstastny @alexlyttle and many others. Thank you!

0.7.2

2 years ago

This is a patch release with the following new feature and fixes:

0.7.1

2 years ago

In the 0.7.0 release, the wheel file uploaded to PyPI contained some outdated files. This release fixes that issue.

0.7.0

2 years ago

Starting with this release, NumPyro can be installed alongside the latest jax and jaxlib releases (their version restrictions have been relaxed). In addition, NumPyro uses the default JAX platform, so if you installed JAX with GPU/TPU support, those devices will be used by default.

New Features

Enhancements and Bug Fixes

  • Documentation and examples are greatly enhanced to make features more accessible
  • Fix chain detection for various CPU device strings #1077
  • Fix AutoNormal's quantiles method for models with non-scalar latent sites #1066
  • Fix LocScaleReparam with center=1 #1059
  • Enhance auto guides to support models with deterministic sites #1022
  • Support for mutable states in Flax and Haiku modules #1016
  • Fix a bug in auto guides that happens when using the guide in Predictive #1013
  • Support decorator syntax for effect handlers #1009
  • Implement sparse Poisson log probability #1003
  • Support total_count=0 in Multinomial distribution #1000
  • Add a flag to control mass matrix regularization during mass matrix adaptation #998
  • Add experimental Dockerfiles #996
  • Allow setting max tree depth of NUTS sampler during warmup phase #984
  • Fix dimensions mixed up in ExpandedDistribution.sample method #972
  • MCMC objects can now be pickled #968

This release is composed of great contributions and feedback from the Pyro community: @ahoho, @kpj, @gustavehug, @AndrewCSQ, @jatentaki, @tcbegley, @dominikstrb, @justinrporter, @dirmeier, @irustandi, @MarcoGorelli, @lumip, and many others. Thank you!

0.6.0

3 years ago

New Features

Enhancements and Bug Fixes

  • Improve precision for Dirichlet distributions with small concentration #943
  • Make it easy to use softplus transforms in autoguides #941
  • Improve compilation time in MCMC samplers: compilation is now 2x faster than before #924
  • Reduce memory requirement for AutoLowRankMultivariateNormal.quantiles #921
  • Example of how to use Distribution.mask #917
  • Add goodness of fit helpers for testing distributions #916
  • Enable sampling with intermediates for ExpandedDistribution #909
  • Fix DiscreteHMCGibbs to work with multiple chains #908
  • Fix missing infer key in handlers.lift #892

Thanks @loopylangur, Dominik Straub @dominikstrb, Jeremie Coullon @jeremiecoullon, Ola Rønning @OlaRonning, Lukas Prediger @lumip, Raúl Peralta Lozada @RaulPL, Vitalii Kleshchevnikov @vitkl, Matt Ludkin @ludkinm, and many others for your contributions and feedback!

0.5.0

3 years ago

New documentation page with galleries of tutorials and examples: num.pyro.ai.

New Features

  • New primitive: prng_key to draw a random key under seed handler.
  • New autoguide: AutoDelta
  • New samplers:
    • HMCGibbs: a general HMC/NUTS-within-Gibbs interface.
    • DiscreteHMCGibbs: HMC/NUTS-within-Gibbs for models with discrete latent variables.
    • HMCECS: HMC/NUTS with energy conserving subsampling.
  • New example:
  • New kernels module in numpyro.contrib.einstein, in preparation for (Ein)Stein VI inference in future releases.
  • New user-friendly SVI.run method to simplify the training phase of SVI inference.
  • New feasible_like method in constraints.
  • New methods forward_shape and inverse_shape in Transform to infer output shape given input shape.
  • Transform.inv now returns an inverse transform, enabling many new (inverse) transforms.
  • Support thinning in MCMC.
  • Add post_warmup_state and last_state to allow a sequential sampling strategy in MCMC: the .run method can be called repeatedly to get more samples.
  • New history argument in scan to support Markov models with history > 1.
  • New forward_mode_differentiation argument in HMC/NUTS kernels to allow forward-mode differentiation.

Enhancements and Bug Fixes

  • #886 Make TransformReparam compatible with .to_event()
  • #883 Improve gradient computation of Euclidean kinetic energy.
  • #872 Enhance masked distributions so that gradients propagate properly when using the mask handler for invalid data.
  • #865 Make subsample faster on CPU.
  • #860 Fix for memory leak in MCMC.
  • #849 Expose logits attribute for some discrete distributions
  • #848 Add has_rsample and rsample attribute to distributions
  • #832 Allow a callable to return an init value in param primitive
  • #824 Fix inability to use the sample method of TFP distributions in the sample primitive.
  • #823 Demo on how to use various init strategies in Gaussian Process example.
  • #822 Allow haiku/flax modules to take general args/kwargs in init.
  • #821 Better error messages when rng_key is missing.
  • #818 Better error messages when an error happens in the middle of inference.
  • #805 Display correct progress bar message after running MCMC.warmup.
  • #801 Raise an error early if missing plates for models with discrete latent variables.
  • #797 MCMC vectorized chain method works for models with deterministic sites.
  • #796 Bernoulli distribution returns an int instead of a boolean.
  • #795 Reveal signature for help(Distribution).

Thanks Ola Ronning @OlaRonning, Armin Stepanjan @ab-10, @cerbelaut, Xi Wang @xidulu, Wouter van Amsterdam @vanAmsterdam, @loopylangur, and many others for your contributions and helpful feedback!

0.4.1

3 years ago

New Features

Enhancements and Bug Fixes

  • #764 Make exception chaining more user-friendly. Thanks, @akihironitta!
  • #766 Relax interval constraint.
  • #776 Fix bugs in the log_prob and sample methods of the VonMises distribution.
  • #775 Make validation mechanism compatible with omnistaging since JAX 0.2.
  • #780 Fix name dimensions of sample sites under contrib.funsor's plate handler.