Gaussian processes in JAX.
Full Changelog: https://github.com/JaxGaussianProcesses/GPJax/compare/v0.6.3...v0.6.4
Full Changelog: https://github.com/JaxGaussianProcesses/GPJax/compare/v0.6.2...v0.6.3
Full Changelog: https://github.com/JaxGaussianProcesses/GPJax/compare/v0.6.1...v0.6.2
examples/regression.pct.py by @patel-zeel in https://github.com/JaxGaussianProcesses/GPJax/pull/186
pre-commit setup & code changes by @st-- in https://github.com/JaxGaussianProcesses/GPJax/pull/241
Full Changelog: https://github.com/JaxGaussianProcesses/GPJax/compare/v0.5.9...v0.6
Full Changelog: https://github.com/JaxGaussianProcesses/GPJax/compare/v0.5.8...v0.5.9
Fix a bug in the StochasticVI variational expectation (@Daniel-Dodd).
Remove Chex as a direct GPJax dependency (@Daniel-Dodd).
Move gpjax.parameters to JaxUtils.parameters (@thomaspinder).
Move gpjax.config to JaxUtils.config (@Daniel-Dodd).
Move gpjax.utils to JaxUtils.dict (@Daniel-Dodd).
Remove Chex import and dataclasses
Issue: #157
Fixed: via #176, @Daniel-Dodd
Details:
Chex has been removed as a direct dependency of GPJax.
Note that Distrax depends on Chex, so you still need the relevant version of Chex installed to use GPJax.
As a result, we recommend that users no longer use Chex's @dataclass decorator to define objects. Instead, users should inherit from the abstract types provided in GPJax (e.g., AbstractGP) or from JaxUtils.PyTree, to ensure their objects are registered as JAX PyTrees.
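Inheriting from JaxUtils.PyTree takes care of PyTree registration for you. To illustrate what that registration involves, the sketch below performs it by hand with JAX's own jax.tree_util.register_pytree_node_class decorator rather than JaxUtils. HypotheticalMean and total are illustrative names, not part of GPJax, and JaxUtils.PyTree may register classes differently internally.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class HypotheticalMean:
    """A toy constant mean function registered as a JAX PyTree (hypothetical example)."""

    def __init__(self, constant, name="mean"):
        self.constant = constant  # array leaf: traced and differentiated by JAX
        self.name = name          # static metadata: stored as aux_data, not a leaf

    def tree_flatten(self):
        # Return (children, aux_data): children are the leaves, aux_data is static.
        return (self.constant,), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from its leaves and static metadata.
        return cls(*children, name=aux_data)

    def __call__(self, x):
        # One constant value per input row.
        return jnp.full(x.shape[0], self.constant)


@jax.jit
def total(mean_fn, x):
    # Because HypotheticalMean is a PyTree, it can be passed straight into jit.
    return jnp.sum(mean_fn(x))


mean_fn = HypotheticalMean(jnp.array(2.0))
x = jnp.ones((3, 1))
value = total(mean_fn, x)
grads = jax.grad(lambda m: total(m, x))(mean_fn)
```

Once registered, the object passes directly through jax.jit and jax.grad, and the gradient comes back as an object of the same class.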
Chex was removed to give greater flexibility when defining new classes and to mitigate pain points with class inheritance, which previously forced us to resort to mixin classes.
⚠️ WARNING: Since GPJax objects are no longer Chex dataclasses, the keyword-only argument convention for initialising objects has been removed; users should be careful about the order of their inputs.
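To make the warning concrete, the sketch below contrasts a constructor that accepts positional arguments with one that uses Python's bare * to force keyword-only arguments. Both classes are hypothetical stand-ins, not GPJax objects; the takeaway is that passing arguments by keyword keeps your code independent of parameter order.

```python
class HypotheticalKernel:
    """Toy class whose constructor accepts positional arguments (hypothetical)."""

    def __init__(self, lengthscale: float, variance: float) -> None:
        self.lengthscale = lengthscale
        self.variance = variance


class HypotheticalKwOnlyKernel:
    """Toy class whose constructor is keyword-only (hypothetical)."""

    # The bare `*` makes every following parameter keyword-only.
    def __init__(self, *, lengthscale: float, variance: float) -> None:
        self.lengthscale = lengthscale
        self.variance = variance


# Positional call: arguments bind by order, so passing them in the wrong
# order would not raise an error.
positional = HypotheticalKernel(2.0, 0.5)  # lengthscale=2.0, variance=0.5

# Keyword call: order no longer matters and the intent is explicit.
explicit = HypotheticalKernel(variance=0.5, lengthscale=2.0)
```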
Issue: #172
Fixed: #178, @Daniel-Dodd
Details:
To initialise default parameters, you should call/define init_params instead of _initialise_params. The latter is deprecated and will be removed in v0.6.
import jaxkern as jk
import gpjax as gpx
import jax.random as jr
prior = gpx.Prior(kernel = jk.RBF())
# OLD DO NOT USE:
key = jr.PRNGKey(123)
params = prior._initialise_params(key)
# NEW WAY:
key = jr.PRNGKey(123)
params = prior.init_params(key)
from gpjax.likelihoods import AbstractLikelihood
from jax.random import KeyArray
from typing import Dict
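class Poisson(AbstractLikelihood):
    # Define __init__, etc. as usual.
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    # Define your default params via the `init_params` method instead of the old `_initialise_params`.
    def init_params(self, key: KeyArray) -> Dict:
        # A Poisson likelihood has no hyperparameters, so there is nothing to initialise.
        return {}

    # The remaining abstract methods of AbstractLikelihood are omitted here.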
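While _initialise_params is being phased out, a common pattern is to keep the old name as a thin alias that emits a DeprecationWarning and forwards to the new method. The sketch below shows that pattern on a hypothetical HypotheticalModel class with made-up default parameters; it illustrates the idea and is not GPJax's own code. Custom subclasses should put their logic in init_params so that they keep working once the alias is removed.

```python
import warnings
from typing import Dict

import jax.numpy as jnp
import jax.random as jr


class HypotheticalModel:
    """Toy model with a deprecated alias that forwards to init_params (hypothetical)."""

    def init_params(self, key) -> Dict:
        # New-style default-parameter initialisation (key unused in this toy).
        return {"lengthscale": jnp.array(1.0), "variance": jnp.array(1.0)}

    def _initialise_params(self, key) -> Dict:
        # Old name kept temporarily: warn the caller, then delegate.
        warnings.warn(
            "_initialise_params is deprecated; use init_params instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.init_params(key)


model = HypotheticalModel()
params = model.init_params(jr.PRNGKey(123))
```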
Issue: #170, #169
Fixed: #171 (@thomaspinder)
Details:
The docs now build, and checks are in place to ensure successful builds on new PRs made to the master branch.
Note that the docs are currently built using a GitHub workflow, while unit tests are run via CircleCI workflows.
FillScaleTriL and update config behaviour by @patel-zeel in https://github.com/JaxGaussianProcesses/GPJax/pull/163
v0.4.x compatibility, incorporate CircleCI testing workflows, incorporate versioneer by @thomaspinder in https://github.com/JaxGaussianProcesses/GPJax/pull/164
Full Changelog: https://github.com/JaxGaussianProcesses/GPJax/compare/v0.5.5...v0.5.6
Deprecate gpjax.Dataset and gpjax.kernels. These objects will be removed in v0.6.0. @thomaspinder
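One common way to deprecate module-level names such as gpjax.Dataset without immediately breaking old imports is a module-level __getattr__ (PEP 562), which Python calls only when normal attribute lookup on the module fails. The sketch below demonstrates the mechanism on a throwaway module built with types.ModuleType; hypothetical_pkg, NewDataset and the name mapping are illustrative assumptions, and GPJax may have implemented its deprecation differently.

```python
import types
import warnings

# Stand-in for a package module whose old attribute is being phased out.
hypothetical_pkg = types.ModuleType("hypothetical_pkg")
hypothetical_pkg.NewDataset = type("NewDataset", (), {})  # the replacement object

# Map deprecated attribute names to their replacements.
_DEPRECATED = {"Dataset": "NewDataset"}


def _module_getattr(name):
    # PEP 562: invoked only after normal lookup on the module has failed.
    if name in _DEPRECATED:
        replacement = _DEPRECATED[name]
        warnings.warn(
            f"hypothetical_pkg.{name} is deprecated; use {replacement} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return getattr(hypothetical_pkg, replacement)
    raise AttributeError(f"module 'hypothetical_pkg' has no attribute {name!r}")


# Storing the hook in the module's namespace activates the PEP 562 fallback.
hypothetical_pkg.__getattr__ = _module_getattr
```

In a real package the function would simply be defined as __getattr__ at the top level of the module's source file.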
Implementation of new kernels and a backend move to using a PyTree in place of a Chex dataclass.
Full Changelog: https://github.com/JaxGaussianProcesses/GPJax/compare/v0.5.2...v0.5.4