diff --git a/blackjax/__init__.py b/blackjax/__init__.py
index a0dcaecd9..b1d16e517 100644
--- a/blackjax/__init__.py
+++ b/blackjax/__init__.py
@@ -23,6 +23,7 @@
 from .sgmcmc.sgld import sgld
 from .sgmcmc.sgnht import sgnht
 from .smc.adaptive_tempered import adaptive_tempered_smc
+from .smc.inner_kernel_tuning import inner_kernel_tuning
 from .smc.tempered import tempered_smc
 from .vi.meanfield_vi import meanfield_vi
 from .vi.pathfinder import pathfinder
@@ -57,6 +58,7 @@
     "mclmc_find_L_and_step_size",  # mclmc adaptation
     "adaptive_tempered_smc",  # smc
     "tempered_smc",
+    "inner_kernel_tuning",
     "meanfield_vi",  # variational inference
     "pathfinder",
     "schrodinger_follmer",
diff --git a/blackjax/smc/__init__.py b/blackjax/smc/__init__.py
index 59fdbbdd4..180cd8259 100644
--- a/blackjax/smc/__init__.py
+++ b/blackjax/smc/__init__.py
@@ -1,3 +1,3 @@
-from . import adaptive_tempered, tempered
+from . import adaptive_tempered, inner_kernel_tuning, tempered
 
-__all__ = ["adaptive_tempered", "tempered"]
+__all__ = ["adaptive_tempered", "tempered", "inner_kernel_tuning"]
diff --git a/blackjax/smc/inner_kernel_tuning.py b/blackjax/smc/inner_kernel_tuning.py
new file mode 100644
index 000000000..83bfdb50d
--- /dev/null
+++ b/blackjax/smc/inner_kernel_tuning.py
@@ -0,0 +1,150 @@
+from typing import Callable, Dict, NamedTuple, Tuple, Union
+
+from blackjax.base import SamplingAlgorithm
+from blackjax.smc.adaptive_tempered import adaptive_tempered_smc
+from blackjax.smc.base import SMCInfo, SMCState
+from blackjax.smc.tempered import tempered_smc
+from blackjax.types import ArrayTree, PRNGKey
+
+
+class StateWithParameterOverride(NamedTuple):
+    sampler_state: ArrayTree
+    parameter_override: ArrayTree
+
+
+def init(alg_init_fn, position, initial_parameter_value):
+    return StateWithParameterOverride(alg_init_fn(position), initial_parameter_value)
+
+
+def build_kernel(
+    smc_algorithm,
+    logprior_fn: Callable,
+    loglikelihood_fn: Callable,
+    mcmc_factory: Callable,
+    mcmc_init_fn: Callable,
+    mcmc_parameters: Dict,
+    resampling_fn: Callable,
+    mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree],
+    num_mcmc_steps: int = 10,
+    **extra_parameters,
+) -> Callable:
+    """In the context of an SMC sampler (whose step_fn returns a state with a
+    .particles attribute), an inner MCMC kernel is used to perturb/update each of
+    the particles. This adaptation tunes a parameter of that inner kernel based on
+    the particles. The parameter type must be a valid JAX type.
+
+    Parameters
+    ----------
+    smc_algorithm
+        Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other
+        implementation of a sampling algorithm that returns an SMCState and SMCInfo pair).
+    logprior_fn
+        A function that computes the log density of the prior distribution.
+    loglikelihood_fn
+        A function that returns the log-likelihood at a given position.
+    mcmc_factory
+        A callable that constructs an inner kernel out of the newly computed parameter.
+    mcmc_init_fn
+        A callable that initializes the inner kernel.
+    mcmc_parameters
+        Other parameters for the inner kernel, fixed across SMC iterations.
+    mcmc_parameter_update_fn
+        A callable that takes the SMCState and SMCInfo at step i and constructs the
+        parameter to be used by the inner kernel at iteration i+1.
+    extra_parameters
+        Parameters to be used for the creation of the smc_algorithm.
+    """
+
+    def kernel(
+        rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters
+    ) -> Tuple[StateWithParameterOverride, SMCInfo]:
+        step_fn = smc_algorithm(
+            logprior_fn=logprior_fn,
+            loglikelihood_fn=loglikelihood_fn,
+            mcmc_step_fn=mcmc_factory(state.parameter_override),
+            mcmc_init_fn=mcmc_init_fn,
+            mcmc_parameters=mcmc_parameters,
+            resampling_fn=resampling_fn,
+            num_mcmc_steps=num_mcmc_steps,
+            **extra_parameters,
+        ).step
+        new_state, info = step_fn(rng_key, state.sampler_state, **extra_step_parameters)
+        new_parameter_override = mcmc_parameter_update_fn(new_state, info)
+        return StateWithParameterOverride(new_state, new_parameter_override), info
+
+    return kernel
+
+
+class inner_kernel_tuning:
+    """In the context of an SMC sampler (whose step_fn returns a state with a
+    .particles attribute), an inner MCMC kernel is used to perturb/update each of
+    the particles. This adaptation tunes a parameter of that inner kernel based on
+    the particles. The parameter type must be a valid JAX type.
+
+    Parameters
+    ----------
+    smc_algorithm
+        Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other
+        implementation of a sampling algorithm that returns an SMCState and SMCInfo pair).
+    logprior_fn
+        A function that computes the log density of the prior distribution.
+    loglikelihood_fn
+        A function that returns the log-likelihood at a given position.
+    mcmc_factory
+        A callable that constructs an inner kernel out of the newly computed parameter.
+    mcmc_init_fn
+        A callable that initializes the inner kernel.
+    mcmc_parameters
+        Other parameters for the inner kernel step, fixed across SMC iterations.
+    mcmc_parameter_update_fn
+        A callable that takes the SMCState and SMCInfo at step i and constructs the
+        parameter to be used by the inner kernel at iteration i+1.
+    initial_parameter_value
+        Parameter to be used by the mcmc_factory before the first iteration.
+    extra_parameters
+        Parameters to be used for the creation of the smc_algorithm.
+
+    Returns
+    -------
+    A ``SamplingAlgorithm``.
+
+    """
+
+    init = staticmethod(init)
+    build_kernel = staticmethod(build_kernel)
+
+    def __new__(  # type: ignore[misc]
+        cls,
+        smc_algorithm: Union[adaptive_tempered_smc, tempered_smc],
+        logprior_fn: Callable,
+        loglikelihood_fn: Callable,
+        mcmc_factory: Callable,
+        mcmc_init_fn: Callable,
+        mcmc_parameters: Dict,
+        resampling_fn: Callable,
+        mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree],
+        initial_parameter_value,
+        num_mcmc_steps: int = 10,
+        **extra_parameters,
+    ) -> SamplingAlgorithm:
+        kernel = cls.build_kernel(
+            smc_algorithm,
+            logprior_fn,
+            loglikelihood_fn,
+            mcmc_factory,
+            mcmc_init_fn,
+            mcmc_parameters,
+            resampling_fn,
+            mcmc_parameter_update_fn,
+            num_mcmc_steps,
+            **extra_parameters,
+        )
+
+        def init_fn(position):
+            return cls.init(smc_algorithm.init, position, initial_parameter_value)
+
+        def step_fn(
+            rng_key: PRNGKey, state, **extra_step_parameters
+        ) -> Tuple[StateWithParameterOverride, SMCInfo]:
+            return kernel(rng_key, state, **extra_step_parameters)
+
+        return SamplingAlgorithm(init_fn, step_fn)
diff --git a/blackjax/smc/tuning/__init__.py b/blackjax/smc/tuning/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/blackjax/smc/tuning/from_kernel_info.py b/blackjax/smc/tuning/from_kernel_info.py
new file mode 100644
index 000000000..a039e66c1
--- /dev/null
+++ b/blackjax/smc/tuning/from_kernel_info.py
@@ -0,0 +1,46 @@
+"""
+Strategies to tune the parameters of MCMC kernels
+used within SMC, based on MCMC states.
+"""
+import jax
+import jax.numpy as jnp
+
+__all__ = ["update_scale_from_acceptance_rate"]
+
+
+def update_scale_from_acceptance_rate(
+    scales: jax.Array,
+    acceptance_rates: jax.Array,
+    target_acceptance_rate: float = 0.234,
+) -> jax.Array:
+    """
+    Given N chains from an MCMC algorithm such as Random Walk Metropolis, and N scale
+    factors, each associated with a different chain, update the scale factors taking
+    into account the individual acceptance rates and the average acceptance rate.
+
+    Under certain assumptions it is known that the optimal acceptance rate of
+    Metropolis-Hastings is approximately 0.44 in one dimension and converges to
+    0.234 as the dimension goes to infinity. In practice, 0.234 is a reasonable
+    target for 5 or more dimensions.
+
+    If a chain's acceptance rate is below the target, its scale decreases; if it is
+    above, its scale increases.
+
+    Parameters
+    ----------
+    scales
+        (n_chains) array of N scale factors, each associated with a Markov chain.
+    acceptance_rates
+        (n_chains) acceptance rates of the N Markov chains.
+    target_acceptance_rate
+        The desired acceptance rate for the chains.
+
+    Returns
+    -------
+    (n_chains) new scales, aimed at bringing the acceptance rates closer to the target
+    if the chains were run again.
+    """
+    chain_scales = jnp.exp(jnp.log(scales) + acceptance_rates - target_acceptance_rate)
+    return 0.5 * (chain_scales + chain_scales.mean())
diff --git a/blackjax/smc/tuning/from_particles.py b/blackjax/smc/tuning/from_particles.py
new file mode 100755
index 000000000..1df19ac26
--- /dev/null
+++ b/blackjax/smc/tuning/from_particles.py
@@ -0,0 +1,49 @@
+"""
+Strategies to tune the parameters of MCMC kernels
+used within SMC, based on particles.
+"""
+import jax
+import jax.numpy as jnp
+from jax.flatten_util import ravel_pytree
+
+from blackjax.types import Array
+
+__all__ = [
+    "particles_means",
+    "particles_stds",
+    "particles_covariance_matrix",
+    "mass_matrix_from_particles",
+]
+
+
+def particles_stds(particles):
+    return jnp.std(particles_as_rows(particles), axis=0)
+
+
+def particles_means(particles):
+    return jnp.mean(particles_as_rows(particles), axis=0)
+
+
+def particles_covariance_matrix(particles):
+    return jnp.cov(particles_as_rows(particles), ddof=0, rowvar=False)
+
+
+def mass_matrix_from_particles(particles) -> Array:
+    """
+    Implements the tuning from section 3.1 of https://arxiv.org/pdf/1808.07730.pdf:
+    compute a mass matrix to be used in HMC from the particles.
+    Given the particles' covariance matrix, set all non-diagonal elements to zero,
+    take the inverse, and keep the diagonal.
+
+    Returns
+    -------
+    A mass matrix
+    """
+    return jnp.diag(1.0 / jnp.var(particles_as_rows(particles), axis=0))
+
+
+def particles_as_rows(particles):
+    """
+    Adds a trailing dimension to single-dimension variables, then represents
+    multiple variables as a matrix where each column is a variable and each row
+    is a particle.
+    """
+    return jax.vmap(lambda x: ravel_pytree(x)[0])(particles)
diff --git a/tests/smc/__init__.py b/tests/smc/__init__.py
index e69de29bb..7a4e5c029 100644
--- a/tests/smc/__init__.py
+++ b/tests/smc/__init__.py
@@ -0,0 +1,37 @@
+import chex
+import jax.numpy as jnp
+import jax.scipy.stats as stats
+import numpy as np
+
+
+class SMCLinearRegressionTestCase(chex.TestCase):
+    def logdensity_fn(self, log_scale, coefs, preds, x):
+        """Linear regression"""
+        scale = jnp.exp(log_scale)
+        y = jnp.dot(x, coefs)
+        logpdf = stats.norm.logpdf(preds, y, scale)
+        return jnp.sum(logpdf)
+
+    def particles_prior_loglikelihood(self):
+        num_particles = 100
+
+        x_data = np.random.normal(0, 1, size=(1000, 1))
+        y_data = 3 * x_data + np.random.normal(size=x_data.shape)
+        observations = {"x": x_data, "preds": y_data}
+
+        logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf(
+            x["coefs"]
+        )
+        loglikelihood_fn = lambda x: self.logdensity_fn(**x, **observations)
+
+        log_scale_init = np.random.randn(num_particles)
+        coeffs_init = np.random.randn(num_particles)
+        init_particles = {"log_scale": log_scale_init, "coefs": coeffs_init}
+
+        return init_particles, logprior_fn, loglikelihood_fn
+
+    def assert_linear_regression_test_case(self, result):
+        np.testing.assert_allclose(
+            np.mean(np.exp(result.particles["log_scale"])), 1.0, rtol=1e-1
+        )
+        np.testing.assert_allclose(np.mean(result.particles["coefs"]), 3.0, rtol=1e-1)
diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py
new file mode 100644
index 000000000..038f27e0a
--- /dev/null
+++ b/tests/smc/test_inner_kernel_tuning.py
@@ -0,0 +1,380 @@
+import functools
+import unittest
+from unittest.mock import MagicMock
+
+import chex
+import jax
+import jax.numpy as jnp
+import jax.scipy.stats as stats
+import numpy as np
+from absl.testing import absltest
+
+import blackjax
+import blackjax.smc.resampling as resampling
+from blackjax import adaptive_tempered_smc, tempered_smc
+from blackjax.smc.inner_kernel_tuning import inner_kernel_tuning
+from blackjax.smc.tuning.from_kernel_info import update_scale_from_acceptance_rate
+from blackjax.smc.tuning.from_particles import (
+    mass_matrix_from_particles,
+    particles_as_rows,
+    particles_covariance_matrix,
+    particles_means,
+    particles_stds,
+)
+from tests.mcmc.test_sampling import
irmh_proposal_distribution +from tests.smc import SMCLinearRegressionTestCase + + +class MultivariableParticlesDistribution: + """ + Builds particles for tests belonging to a posterior with more than one variable. + sample from P(x,y) x ~ N(mean, cov) y ~ N(mean, cov) + """ + + def __init__(self, n_particles, mean_x=None, mean_y=None, cov_x=None, cov_y=None): + self.n_particles = n_particles + self.mean_x = mean_x if mean_x is not None else [10.0, 5.0] + self.mean_y = mean_y if mean_y is not None else [0.0, 0.0] + self.cov_x = cov_x if cov_x is not None else [[1.0, 0.0], [0.0, 1.0]] + self.cov_y = cov_y if cov_y is not None else [[1.0, 0.0], [0.0, 1.0]] + + def get_particles(self): + return [ + np.random.multivariate_normal( + mean=self.mean_x, cov=self.cov_x, size=self.n_particles + ), + np.random.multivariate_normal( + mean=self.mean_y, cov=self.cov_y, size=self.n_particles + ), + ] + + +def kernel_logprob_fn(position): + return jnp.sum(stats.norm.logpdf(position)) + + +def log_weights_fn(x, y): + return jnp.sum(stats.norm.logpdf(y - x)) + + +class SMCParameterTuningTest(chex.TestCase): + def setUp(self): + super().setUp() + self.key = jax.random.PRNGKey(42) + + def logdensity_fn(self, log_scale, coefs, preds, x): + """Linear regression""" + scale = jnp.exp(log_scale) + y = jnp.dot(x, coefs) + logpdf = stats.norm.logpdf(preds, y, scale) + return jnp.sum(logpdf) + + def test_smc_inner_kernel_adaptive_tempered(self): + self.smc_inner_kernel_tuning_test_case( + blackjax.adaptive_tempered_smc, + smc_parameters={"target_ess": 0.5}, + step_parameters={}, + ) + + def test_smc_inner_kernel_tempered(self): + self.smc_inner_kernel_tuning_test_case( + blackjax.tempered_smc, smc_parameters={}, step_parameters={"lmbda": 0.75} + ) + + def smc_inner_kernel_tuning_test_case( + self, smc_algorithm, smc_parameters, step_parameters + ): + specialized_log_weights_fn = lambda tree: log_weights_fn(tree, 1.0) + # Don't use exactly the invariant distribution for the MCMC kernel + init_particles = 0.25 + np.random.randn(1000) * 50 + + proposal_factory = MagicMock() + proposal_factory.return_value = 100 + + def mcmc_parameter_update_fn(state, info): + return 100 + + mcmc_factory = MagicMock() + sampling_algorithm = MagicMock() + mcmc_factory.return_value = sampling_algorithm + prior = lambda x: stats.norm.logpdf(x) + + def kernel_factory(proposal_distribution): + kernel = blackjax.irmh.build_kernel() + + def wrapped_kernel(rng_key, state, logdensity): + return kernel(rng_key, state, logdensity, proposal_distribution) + + return wrapped_kernel + + kernel = inner_kernel_tuning( + logprior_fn=prior, + loglikelihood_fn=specialized_log_weights_fn, + mcmc_factory=kernel_factory, + mcmc_init_fn=blackjax.irmh.init, + resampling_fn=resampling.systematic, + smc_algorithm=smc_algorithm, + mcmc_parameters={}, + mcmc_parameter_update_fn=mcmc_parameter_update_fn, + initial_parameter_value=irmh_proposal_distribution, + **smc_parameters, + ) + + new_state, new_info = kernel.step( + self.key, state=kernel.init(init_particles), **step_parameters + ) + assert new_state.parameter_override == 100 + + +class MeanAndStdFromParticlesTest(chex.TestCase): + def setUp(self): + super().setUp() + self.key = jax.random.PRNGKey(42) + + def test_mean_and_std(self): + particles = np.array( + [ + jnp.array([10]) + jax.random.normal(key) * jnp.array([0.5]) + for key in jax.random.split(self.key, 1000) + ] + ) + mean = particles_means(particles) + std = particles_stds(particles) + cov = particles_covariance_matrix(particles) + 
np.testing.assert_allclose(mean, 10.0, rtol=1e-1) + np.testing.assert_allclose(std, 0.5, rtol=1e-1) + np.testing.assert_allclose(cov, 0.24, rtol=1e-1) + + def test_mean_and_std_multivariate_particles(self): + particles = np.array( + [ + jnp.array([10.0, 15.0]) + jax.random.normal(key) * jnp.array([0.5, 0.7]) + for key in jax.random.split(self.key, 1000) + ] + ) + + mean = particles_means(particles) + std = particles_stds(particles) + cov = particles_covariance_matrix(particles) + np.testing.assert_allclose(mean, np.array([10.0, 15.0]), rtol=1e-1) + np.testing.assert_allclose(std, np.array([0.5, 0.7]), rtol=1e-1) + np.testing.assert_allclose( + cov, np.array([[0.249529, 0.34934], [0.34934, 0.489076]]), atol=1e-1 + ) + + def test_mean_and_std_multivariable_particles(self): + var1 = np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) + var2 = np.array([jnp.array([10.0]), jnp.array([3.0])]) + particles = {"var1": var1, "var2": var2} + mean = particles_means(particles) + std = particles_stds(particles) + cov = particles_covariance_matrix(particles) + np.testing.assert_allclose(mean, np.array([6.5, 9.5, 6.5])) + np.testing.assert_allclose(std, np.array([3.5, 5.5, 3.5])) + np.testing.assert_allclose( + cov, + np.array( + [[12.25, 19.25, 12.25], [19.25, 30.25, 19.25], [12.25, 19.25, 12.25]] + ), + ) + + +class InverseMassMatrixFromParticles(chex.TestCase): + def setUp(self): + super().setUp() + self.key = jax.random.PRNGKey(42) + + def test_inverse_mass_matrix_from_particles(self): + inverse_mass_matrix = mass_matrix_from_particles( + np.array([np.array(10.0), np.array(3.0)]) + ) + np.testing.assert_allclose( + inverse_mass_matrix, np.diag(np.array([0.08163])), rtol=1e-4 + ) + + def test_inverse_mass_matrix_from_multivariate_particles(self): + inverse_mass_matrix = mass_matrix_from_particles( + np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) + ) + np.testing.assert_allclose( + inverse_mass_matrix, np.diag(np.array([0.081633, 0.033058])), rtol=1e-4 + ) + + def test_inverse_mass_matrix_from_multivariable_particles(self): + var1 = np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) + var2 = np.array([jnp.array([10.0]), jnp.array([3.0])]) + init_particles = {"var1": var1, "var2": var2} + mass_matrix = mass_matrix_from_particles(init_particles) + assert mass_matrix.shape == (3, 3) + np.testing.assert_allclose( + np.diag(mass_matrix), + np.array([0.081633, 0.033058, 0.081633], dtype="float32"), + rtol=1e-4, + ) + + def test_inverse_mass_matrix_from_multivariable_univariate_particles(self): + var1 = np.array([3.0, 2.0]) + var2 = np.array([10.0, 3.0]) + init_particles = {"var1": var1, "var2": var2} + mass_matrix = mass_matrix_from_particles(init_particles) + assert mass_matrix.shape == (2, 2) + np.testing.assert_allclose( + np.diag(mass_matrix), np.array([4, 0.081633], dtype="float32"), rtol=1e-4 + ) + + +class ScaleCovarianceFromAcceptanceRates(chex.TestCase): + def test_scale_when_aceptance_below_optimal(self): + """ + Given that the acceptance rate is below optimal, + the scale gets reduced. 
+ """ + np.testing.assert_allclose( + update_scale_from_acceptance_rate( + scales=jnp.array([0.5]), acceptance_rates=jnp.array([0.2]) + ), + jnp.array([0.483286]), + rtol=1e-4, + ) + + def test_scale_when_aceptance_above_optimal(self): + """ + Given that the acceptance rate is above optimal + the scale increases + ------- + """ + np.testing.assert_allclose( + update_scale_from_acceptance_rate( + scales=jnp.array([0.5]), acceptance_rates=jnp.array([0.3]) + ), + jnp.array([0.534113]), + rtol=1e-4, + ) + + def test_scale_mean_smoothes(self): + """ + The end result depends on the mean acceptance rate, + smoothing the results + """ + np.testing.assert_allclose( + update_scale_from_acceptance_rate( + scales=jnp.array([0.5, 0.5]), acceptance_rates=jnp.array([0.3, 0.2]) + ), + jnp.array([0.521406, 0.495993]), + rtol=1e-4, + ) + + +class InnerKernelTuningJitTest(SMCLinearRegressionTestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + def mcmc_factory(self, mass_matrix): + return functools.partial( + blackjax.hmc.build_kernel(), + inverse_mass_matrix=mass_matrix, + step_size=10e-2, + num_integration_steps=50, + ) + + @chex.all_variants(with_pmap=False) + def test_with_adaptive_tempered(self): + ( + init_particles, + logprior_fn, + loglikelihood_fn, + ) = self.particles_prior_loglikelihood() + + init, step = blackjax.inner_kernel_tuning( + adaptive_tempered_smc, + logprior_fn, + loglikelihood_fn, + self.mcmc_factory, + blackjax.hmc.init, + {}, + resampling.systematic, + mcmc_parameter_update_fn=lambda state, info: mass_matrix_from_particles( + state.particles + ), + initial_parameter_value=jnp.eye(2), + num_mcmc_steps=10, + target_ess=0.5, + ) + init_state = init(init_particles) + smc_kernel = self.variant(step) + + def inference_loop(kernel, rng_key, initial_state): + def cond(carry): + state, key = carry + return state.sampler_state.lmbda < 1 + + def body(carry): + state, op_key = carry + op_key, subkey = jax.random.split(op_key, 2) + state, _ = kernel(subkey, state) + return state, op_key + + return jax.lax.while_loop(cond, body, (initial_state, rng_key)) + + state, _ = inference_loop(smc_kernel, self.key, init_state) + + assert state.parameter_override.shape == (2, 2) + self.assert_linear_regression_test_case(state.sampler_state) + + @chex.all_variants(with_pmap=False) + def test_with_tempered_smc(self): + num_tempering_steps = 10 + ( + init_particles, + logprior_fn, + loglikelihood_fn, + ) = self.particles_prior_loglikelihood() + + init, step = blackjax.inner_kernel_tuning( + tempered_smc, + logprior_fn, + loglikelihood_fn, + self.mcmc_factory, + blackjax.hmc.init, + {}, + resampling.systematic, + mcmc_parameter_update_fn=lambda state, info: mass_matrix_from_particles( + state.particles + ), + initial_parameter_value=jnp.eye(2), + num_mcmc_steps=10, + ) + + init_state = init(init_particles) + smc_kernel = self.variant(step) + + lambda_schedule = np.logspace(-5, 0, num_tempering_steps) + + def body_fn(carry, lmbda): + rng_key, state = carry + rng_key, subkey = jax.random.split(rng_key) + new_state, info = smc_kernel(subkey, state, lmbda=lmbda) + return (rng_key, new_state), (new_state, info) + + (_, result), _ = jax.lax.scan(body_fn, (self.key, init_state), lambda_schedule) + self.assert_linear_regression_test_case(result.sampler_state) + + +class ParticlesAsRowsTest(unittest.TestCase): + def test_particles_as_rows(self): + n_particles = 1000 + test_particles = { + "a": np.zeros(n_particles), + "b": np.ones([n_particles, 1]), + "c": np.repeat( + (np.arange(3 * 5) + 
2).reshape(3, 5)[None, ...], n_particles, axis=0 + ), + } + flatten_particles = particles_as_rows(test_particles) + assert flatten_particles.shape == (n_particles, 3 * 5 + 2) + np.testing.assert_array_equal(np.arange(3 * 5 + 2), flatten_particles[0]) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py index 4aa453869..f4234d117 100644 --- a/tests/smc/test_tempered_smc.py +++ b/tests/smc/test_tempered_smc.py @@ -12,6 +12,7 @@ import blackjax.smc.resampling as resampling import blackjax.smc.solver as solver from blackjax import adaptive_tempered_smc, tempered_smc +from tests.smc import SMCLinearRegressionTestCase def inference_loop(kernel, rng_key, initial_state): @@ -32,20 +33,13 @@ def body(carry): return total_iter, final_state, log_likelihood -class TemperedSMCTest(chex.TestCase): +class TemperedSMCTest(SMCLinearRegressionTestCase): """Test posterior mean estimate.""" def setUp(self): super().setUp() self.key = jax.random.key(42) - def logdensity_fn(self, log_scale, coefs, preds, x): - """Linear regression""" - scale = jnp.exp(log_scale) - y = jnp.dot(x, coefs) - logpdf = stats.norm.logpdf(preds, y, scale) - return jnp.sum(logpdf) - @chex.variants(with_jit=True) def test_adaptive_tempered_smc(self): num_particles = 100 @@ -105,21 +99,13 @@ def logprior_fn(x): @chex.variants(with_jit=True) def test_fixed_schedule_tempered_smc(self): - num_particles = 100 - num_tempering_steps = 10 - - x_data = np.random.normal(0, 1, size=(1000, 1)) - y_data = 3 * x_data + np.random.normal(size=x_data.shape) - observations = {"x": x_data, "preds": y_data} - - logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf( - x["coefs"] - ) - loglikelihood_fn = lambda x: self.logdensity_fn(**x, **observations) + ( + init_particles, + logprior_fn, + loglikelihood_fn, + ) = self.particles_prior_loglikelihood() - log_scale_init = np.random.randn(num_particles) - coeffs_init = np.random.randn(num_particles) - init_particles = {"log_scale": log_scale_init, "coefs": coeffs_init} + num_tempering_steps = 10 lambda_schedule = np.logspace(-5, 0, num_tempering_steps) hmc_init = blackjax.hmc.init @@ -149,10 +135,7 @@ def body_fn(carry, lmbda): return (rng_key, new_state), (new_state, info) (_, result), _ = jax.lax.scan(body_fn, (self.key, init_state), lambda_schedule) - np.testing.assert_allclose( - np.mean(np.exp(result.particles["log_scale"])), 1.0, rtol=1e-1 - ) - np.testing.assert_allclose(np.mean(result.particles["coefs"]), 3.0, rtol=1e-1) + self.assert_linear_regression_test_case(result) def normal_logdensity_fn(x, chol_cov):
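
For reference, the following is a minimal usage sketch of the API added in this diff (it is not part of the patch itself): it runs adaptive tempered SMC while re-tuning the HMC inverse mass matrix from the particles after every SMC step, mirroring the tests above. The toy target, particle count, step size, and number of integration steps are illustrative choices, not values taken from the patch.

import functools

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

import blackjax
import blackjax.smc.resampling as resampling
from blackjax.smc.inner_kernel_tuning import inner_kernel_tuning
from blackjax.smc.tuning.from_particles import mass_matrix_from_particles

# Toy 2-dimensional target: standard normal prior, Gaussian likelihood centred at 1.
logprior_fn = lambda x: stats.norm.logpdf(x).sum()
loglikelihood_fn = lambda x: stats.norm.logpdf(1.0 - x).sum()


def mcmc_factory(inverse_mass_matrix):
    # Build the inner HMC step function from the parameter computed at the
    # previous SMC iteration (here, the inverse mass matrix).
    return functools.partial(
        blackjax.hmc.build_kernel(),
        inverse_mass_matrix=inverse_mass_matrix,
        step_size=10e-2,
        num_integration_steps=50,
    )


algo = inner_kernel_tuning(
    smc_algorithm=blackjax.adaptive_tempered_smc,
    logprior_fn=logprior_fn,
    loglikelihood_fn=loglikelihood_fn,
    mcmc_factory=mcmc_factory,
    mcmc_init_fn=blackjax.hmc.init,
    mcmc_parameters={},
    resampling_fn=resampling.systematic,
    # Recompute the inverse mass matrix from the particles after each SMC step.
    mcmc_parameter_update_fn=lambda state, info: mass_matrix_from_particles(
        state.particles
    ),
    initial_parameter_value=jnp.eye(2),
    num_mcmc_steps=10,
    target_ess=0.5,  # extra parameter forwarded to adaptive_tempered_smc
)

rng_key = jax.random.PRNGKey(0)
init_key, sample_key = jax.random.split(rng_key)
initial_particles = jax.random.normal(init_key, (100, 2))
state = algo.init(initial_particles)

# Iterate until the tempering parameter reaches 1; the tuned parameter travels
# with the state as state.parameter_override.
while state.sampler_state.lmbda < 1:
    sample_key, step_key = jax.random.split(sample_key)
    state, info = algo.step(step_key, state)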