SMC Inner kernel tuning #595

Merged

Changes from all commits (20 commits):
f979f43  inner kernel tuning, tests, and some common strategies (ciguaran, Nov 30, 2023)
699a385  Merge branch 'main' into carlosiguaran_smc_inner_kernel_tuning (ciguaran, Nov 30, 2023)
8e60ec7  Adding imports (ciguaran, Nov 30, 2023)
1177372  pre-commit (ciguaran, Nov 30, 2023)
565c2ed  Merge branch 'main' into carlosiguaran_smc_inner_kernel_tuning (ciguaran, Nov 30, 2023)
ca5f781  code review updates (ciguaran, Nov 30, 2023)
6693cf4  Merge branch 'main' into carlosiguaran_smc_inner_kernel_tuning (junpenglao, Dec 2, 2023)
2126486  Adding Chex tests (ciguaran, Dec 4, 2023)
494ee99  Merge branch 'carlosiguaran_smc_inner_kernel_tuning' of github.com:ci… (ciguaran, Dec 4, 2023)
38f1557  Merge branch 'main' into carlosiguaran_smc_inner_kernel_tuning (ciguaran, Dec 4, 2023)
6d7eb40  line alignment comment (ciguaran, Dec 5, 2023)
cf3219e  Merge branch 'carlosiguaran_smc_inner_kernel_tuning' of github.com:ci… (ciguaran, Dec 5, 2023)
f14eddf  Adding particles_as_rows test (ciguaran, Dec 5, 2023)
2af0179  Modifying implementation of particles_as_rows (ciguaran, Dec 5, 2023)
5bdeede  pre-commit (ciguaran, Dec 5, 2023)
ea7e7af  Merge branch 'main' into carlosiguaran_smc_inner_kernel_tuning (ciguaran, Dec 5, 2023)
f5314c8  change in inverse_mass_matrix from particles implementation (ciguaran, Dec 5, 2023)
2f4e67c  replacing particles_as_rows_test (ciguaran, Dec 6, 2023)
b1a2d0f  Merge branch 'carlosiguaran_smc_inner_kernel_tuning' of github.com:ci… (ciguaran, Dec 6, 2023)
ab46961  Merge branch 'main' into carlosiguaran_smc_inner_kernel_tuning (ciguaran, Dec 6, 2023)
2 changes: 2 additions & 0 deletions blackjax/__init__.py
@@ -23,6 +23,7 @@
from .sgmcmc.sgld import sgld
from .sgmcmc.sgnht import sgnht
from .smc.adaptive_tempered import adaptive_tempered_smc
from .smc.inner_kernel_tuning import inner_kernel_tuning
from .smc.tempered import tempered_smc
from .vi.meanfield_vi import meanfield_vi
from .vi.pathfinder import pathfinder
@@ -57,6 +58,7 @@
"mclmc_find_L_and_step_size", # mclmc adaptation
"adaptive_tempered_smc", # smc
"tempered_smc",
"inner_kernel_tuning",
"meanfield_vi", # variational inference
"pathfinder",
"schrodinger_follmer",
4 changes: 2 additions & 2 deletions blackjax/smc/__init__.py
@@ -1,3 +1,3 @@
-from . import adaptive_tempered, tempered
+from . import adaptive_tempered, inner_kernel_tuning, tempered

-__all__ = ["adaptive_tempered", "tempered"]
+__all__ = ["adaptive_tempered", "tempered", "inner_kernel_tuning"]
150 changes: 150 additions & 0 deletions blackjax/smc/inner_kernel_tuning.py
@@ -0,0 +1,150 @@
from typing import Callable, Dict, NamedTuple, Tuple, Union

from blackjax.base import SamplingAlgorithm
from blackjax.smc.adaptive_tempered import adaptive_tempered_smc
from blackjax.smc.base import SMCInfo, SMCState
from blackjax.smc.tempered import tempered_smc
from blackjax.types import ArrayTree, PRNGKey


class StateWithParameterOverride(NamedTuple):
sampler_state: ArrayTree
parameter_override: ArrayTree


def init(alg_init_fn, position, initial_parameter_value):
return StateWithParameterOverride(alg_init_fn(position), initial_parameter_value)


def build_kernel(
smc_algorithm,
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_factory: Callable,
mcmc_init_fn: Callable,
mcmc_parameters: Dict,
resampling_fn: Callable,
mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree],
num_mcmc_steps: int = 10,
**extra_parameters,
) -> Callable:
"""In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner
MCMC that is used to perturbate/update each of the particles. This adaptation tunes some parameter of that MCMC,
based on particles. The parameter type must be a valid JAX type.

Parameters
----------
smc_algorithm
Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of
a sampling algorithm that returns an SMCState and SMCInfo pair).
logprior_fn
A function that computes the log density of the prior distribution
loglikelihood_fn
A function that computes the log-likelihood at a given position.
mcmc_factory
A callable that can construct an inner kernel out of the newly-computed parameter
mcmc_init_fn
A callable that initializes the inner kernel
mcmc_parameters
Other (fixed across SMC iterations) parameters for the inner kernel
mcmc_parameter_update_fn
A callable that takes the SMCState and SMCInfo at step i and constructs the parameter to be used by the inner kernel in iteration i+1.
num_mcmc_steps
Number of iterations of the inner MCMC kernel to run at each SMC step.
extra_parameters:
Additional parameters used to construct the smc_algorithm.
"""

def kernel(
rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters
) -> Tuple[StateWithParameterOverride, SMCInfo]:
step_fn = smc_algorithm(
logprior_fn=logprior_fn,
loglikelihood_fn=loglikelihood_fn,
mcmc_step_fn=mcmc_factory(state.parameter_override),
mcmc_init_fn=mcmc_init_fn,
mcmc_parameters=mcmc_parameters,
resampling_fn=resampling_fn,
num_mcmc_steps=num_mcmc_steps,
**extra_parameters,
).step
new_state, info = step_fn(rng_key, state.sampler_state, **extra_step_parameters)
new_parameter_override = mcmc_parameter_update_fn(new_state, info)
return StateWithParameterOverride(new_state, new_parameter_override), info

return kernel


class inner_kernel_tuning:
"""In the context of an SMC sampler (whose step_fn returning state
has a .particles attribute), there's an inner MCMC that is used
to perturbate/update each of the particles. This adaptation tunes some
parameter of that MCMC, based on particles.
The parameter type must be a valid JAX type.

Parameters
----------
smc_algorithm
Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of
a sampling algorithm that returns an SMCState and SMCInfo pair).
logprior_fn
A function that computes the log density of the prior distribution
loglikelihood_fn
A function that computes the log-likelihood at a given position.
mcmc_factory
A callable that can construct an inner kernel out of the newly-computed parameter
mcmc_init_fn
A callable that initializes the inner kernel
mcmc_parameters
Other (fixed across SMC iterations) parameters for the inner kernel step
mcmc_parameter_update_fn
A callable that takes the SMCState and SMCInfo at step i and constructs the parameter to be used by the
inner kernel in iteration i+1.
initial_parameter_value
Parameter to be used by the mcmc_factory before the first iteration.
num_mcmc_steps
Number of iterations of the inner MCMC kernel to run at each SMC step.
extra_parameters:
Additional parameters used to construct the smc_algorithm.

Returns
-------
A ``SamplingAlgorithm``.

"""

init = staticmethod(init)
build_kernel = staticmethod(build_kernel)

def __new__( # type: ignore[misc]
cls,
smc_algorithm: Union[adaptive_tempered_smc, tempered_smc],
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_factory: Callable,
mcmc_init_fn: Callable,
mcmc_parameters: Dict,
resampling_fn: Callable,
mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree],
initial_parameter_value,
num_mcmc_steps: int = 10,
**extra_parameters,
) -> SamplingAlgorithm:
kernel = cls.build_kernel(
smc_algorithm,
logprior_fn,
loglikelihood_fn,
mcmc_factory,
mcmc_init_fn,
mcmc_parameters,
resampling_fn,
mcmc_parameter_update_fn,
num_mcmc_steps,
**extra_parameters,
)

def init_fn(position):
return cls.init(smc_algorithm.init, position, initial_parameter_value)

def step_fn(
rng_key: PRNGKey, state, **extra_step_parameters
) -> Tuple[StateWithParameterOverride, SMCInfo]:
return kernel(rng_key, state, **extra_step_parameters)

return SamplingAlgorithm(init_fn, step_fn)
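
For illustration, a minimal usage sketch (not part of this diff): the toy target, shapes, and HMC parameters below are placeholders, and it assumes blackjax's HMC API, where blackjax.hmc.build_kernel() returns a step function that accepts inverse_mass_matrix as a keyword argument. The mass matrix is retuned from the particles at every SMC iteration.

import functools

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

import blackjax
import blackjax.smc.resampling as resampling
from blackjax.smc.tuning.from_particles import mass_matrix_from_particles

# Toy two-dimensional target; in practice these come from your model.
logprior_fn = lambda x: stats.norm.logpdf(x).sum()
loglikelihood_fn = lambda x: stats.norm.logpdf(x - 1.0).sum()

# Build a new HMC step function at every SMC iteration, closing over
# the freshly tuned inverse mass matrix.
def mcmc_factory(inverse_mass_matrix):
    return functools.partial(
        blackjax.hmc.build_kernel(), inverse_mass_matrix=inverse_mass_matrix
    )

tuned_smc = blackjax.inner_kernel_tuning(
    smc_algorithm=blackjax.adaptive_tempered_smc,
    logprior_fn=logprior_fn,
    loglikelihood_fn=loglikelihood_fn,
    mcmc_factory=mcmc_factory,
    mcmc_init_fn=blackjax.hmc.init,
    mcmc_parameters={"step_size": 1e-2, "num_integration_steps": 50},
    resampling_fn=resampling.systematic,
    mcmc_parameter_update_fn=lambda state, info: mass_matrix_from_particles(
        state.particles
    ),
    initial_parameter_value=jnp.eye(2),
    num_mcmc_steps=10,
    target_ess=0.5,  # forwarded to adaptive_tempered_smc via **extra_parameters
)

init_key, step_key = jax.random.split(jax.random.PRNGKey(0))
state = tuned_smc.init(jax.random.normal(init_key, (100, 2)))
state, info = tuned_smc.step(step_key, state)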
Empty file added blackjax/smc/tuning/__init__.py
46 changes: 46 additions & 0 deletions blackjax/smc/tuning/from_kernel_info.py
@@ -0,0 +1,46 @@
"""
strategies to tune the parameters of mcmc kernels
used within smc, based on MCMC states
"""
import jax
import jax.numpy as jnp

__all__ = ["update_scale_from_acceptance_rate"]


def update_scale_from_acceptance_rate(
scales: jax.Array,
acceptance_rates: jax.Array,
target_acceptance_rate: float = 0.234,
) -> jax.Array:
"""
Given N chains from some MCMC algorithm like Random Walk Metropolis
and N scale factors, each associated to a different chain.
Updates the scale factors taking into account acceptance rates and
the average acceptance rate.

Under certain assumptions it is known that the optimal acceptance rate
of Metropolis Hastings is 0.4 for 1 dimension and converges to
0.234 in infinite dimensions. In practice, 0.234 is a reasonable
assumption for 5 or more dimensions.

If certain chain is below optimal acceptance rate, its scale will decrease
and if its above, its scale will increase,
-------

Parameters
----------
scales
(n_chains) array of N scale factors, each associated with a Markov chain
acceptance_rates
(n_chains) acceptance rates of the N Markov chains
target_acceptance_rate
the desired acceptance rate for the chains

Returns
-------
(n_chains) new scales, aimed at moving the acceptance rates closer to the
target if the chains were run again.
"""
chain_scales = jnp.exp(jnp.log(scales) + acceptance_rates - target_acceptance_rate)
return 0.5 * (chain_scales + chain_scales.mean())
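
A small self-contained sketch of the update's behavior (the numbers are illustrative):

import jax.numpy as jnp

from blackjax.smc.tuning.from_kernel_info import update_scale_from_acceptance_rate

scales = jnp.array([0.5, 0.5, 0.5])
acceptance_rates = jnp.array([0.10, 0.234, 0.40])

new_scales = update_scale_from_acceptance_rate(scales, acceptance_rates)
# The first chain (below target) shrinks its scale, the third (above
# target) grows it, and averaging with the mean pools information
# across chains: roughly [0.47, 0.50, 0.55].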
49 changes: 49 additions & 0 deletions blackjax/smc/tuning/from_particles.py
@@ -0,0 +1,49 @@
"""
strategies to tune the parameters of mcmc kernels
used within SMC, based on particles.
"""
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from blackjax.types import Array

__all__ = [
"particles_means",
"particles_stds",
"particles_covariance_matrix",
"mass_matrix_from_particles",
]


def particles_stds(particles):
return jnp.std(particles_as_rows(particles), axis=0)


def particles_means(particles):
return jnp.mean(particles_as_rows(particles), axis=0)


def particles_covariance_matrix(particles):
return jnp.cov(particles_as_rows(particles), ddof=0, rowvar=False)


def mass_matrix_from_particles(particles) -> Array:
"""
Implements tuning from section 3.1 from https://arxiv.org/pdf/1808.07730.pdf
Computing a mass matrix to be used in HMC from particles.
Given the particles covariance matrix, set all non-diagonal elements as zero,
take the inverse, and keep the diagonal.
Returns
-------
A mass Matrix
"""
return jnp.diag(1.0 / jnp.var(particles_as_rows(particles), axis=0))


def particles_as_rows(particles):
"""
Adds end dimension for single-dimension variables, and then represents multivariables
as a matrix where each column is a variable, each row a particle.
"""
return jax.vmap(lambda x: ravel_pytree(x)[0])(particles)
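
A short sketch of how the flattening behaves on a pytree of particles (the shapes are illustrative):

import jax

from blackjax.smc.tuning.from_particles import (
    mass_matrix_from_particles,
    particles_as_rows,
)

k1, k2 = jax.random.split(jax.random.PRNGKey(0))

# 100 particles over a pytree with a scalar and a 2-vector per particle.
particles = {
    "log_scale": jax.random.normal(k1, (100,)),
    "coefs": jax.random.normal(k2, (100, 2)),
}

flat = particles_as_rows(particles)  # shape (100, 3): one row per particle
mass_matrix = mass_matrix_from_particles(particles)  # (3, 3) diagonal matrix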
37 changes: 37 additions & 0 deletions tests/smc/__init__.py
@@ -0,0 +1,37 @@
import chex
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np


class SMCLinearRegressionTestCase(chex.TestCase):
def logdensity_fn(self, log_scale, coefs, preds, x):
"""Linear regression"""
scale = jnp.exp(log_scale)
y = jnp.dot(x, coefs)
logpdf = stats.norm.logpdf(preds, y, scale)
return jnp.sum(logpdf)

def particles_prior_loglikelihood(self):
num_particles = 100

x_data = np.random.normal(0, 1, size=(1000, 1))
y_data = 3 * x_data + np.random.normal(size=x_data.shape)
observations = {"x": x_data, "preds": y_data}

logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf(
x["coefs"]
)
loglikelihood_fn = lambda x: self.logdensity_fn(**x, **observations)

log_scale_init = np.random.randn(num_particles)
coeffs_init = np.random.randn(num_particles)
init_particles = {"log_scale": log_scale_init, "coefs": coeffs_init}

return init_particles, logprior_fn, loglikelihood_fn

def assert_linear_regression_test_case(self, result):
np.testing.assert_allclose(
np.mean(np.exp(result.particles["log_scale"])), 1.0, rtol=1e-1
)
np.testing.assert_allclose(np.mean(result.particles["coefs"]), 3.0, rtol=1e-1)
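
For context, a hypothetical concrete subclass sketching how this shared test case is meant to be consumed (the adaptive tempered SMC wiring, HMC parameters, and run loop are illustrative, not part of this diff):

import chex
import jax
import jax.numpy as jnp

import blackjax
import blackjax.smc.resampling as resampling


class AdaptiveTemperedSMCTest(SMCLinearRegressionTestCase):
    @chex.variants(with_jit=True)
    def test_linear_regression(self):
        (
            init_particles,
            logprior_fn,
            loglikelihood_fn,
        ) = self.particles_prior_loglikelihood()

        smc = blackjax.adaptive_tempered_smc(
            logprior_fn,
            loglikelihood_fn,
            blackjax.hmc.build_kernel(),
            blackjax.hmc.init,
            {
                "step_size": 1e-2,
                "inverse_mass_matrix": jnp.eye(2),
                "num_integration_steps": 50,
            },
            resampling.systematic,
            target_ess=0.5,
            num_mcmc_steps=10,
        )

        # Step the sampler until the tempering parameter lmbda reaches 1.
        def full_run(key, state):
            def cond(carry):
                state, _ = carry
                return state.lmbda < 1

            def body(carry):
                state, key = carry
                key, subkey = jax.random.split(key)
                state, _ = smc.step(subkey, state)
                return state, key

            state, _ = jax.lax.while_loop(cond, body, (state, key))
            return state

        state = smc.init(init_particles)
        result = self.variant(full_run)(jax.random.PRNGKey(42), state)
        self.assert_linear_regression_test_case(result)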