From 5a25352bde139751bb556886d7543ab17f5671f7 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Sat, 5 Oct 2024 01:32:43 -0300 Subject: [PATCH] Partial posteriors SMC and refactor to decouple tempering from SMC construction (#729) * extracting taking last * test passing * layering * example * more * Adding another example * tests in place * rolling back changes * Adding test for num_mcmc_steps * format * better test coverage * linter * Flake8 * black * implementation[ * partial posteriors implementation * rolling back some changes * linter * fixing test * adding reference * typo * exposing in top level api * reruning precommit * adding more steps * smaller step size * fixes on comments * small fix on formating * renaming to data mask * linter --- blackjax/__init__.py | 4 +- blackjax/smc/__init__.py | 1 + blackjax/smc/base.py | 28 +++++ blackjax/smc/from_mcmc.py | 64 ++++++++++++ blackjax/smc/partial_posteriors_path.py | 127 +++++++++++++++++++++++ blackjax/smc/tempered.py | 66 ++---------- tests/smc/__init__.py | 31 +++++- tests/smc/test_partial_posteriors_smc.py | 88 ++++++++++++++++ 8 files changed, 350 insertions(+), 59 deletions(-) create mode 100644 blackjax/smc/from_mcmc.py create mode 100644 blackjax/smc/partial_posteriors_path.py create mode 100644 tests/smc/test_partial_posteriors_smc.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index dfdcfc545..5858c34aa 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -35,6 +35,7 @@ from .sgmcmc import sgnht as _sgnht from .smc import adaptive_tempered from .smc import inner_kernel_tuning as _inner_kernel_tuning +from .smc import partial_posteriors_path as _partial_posteriors_smc from .smc import tempered from .vi import meanfield_vi as _meanfield_vi from .vi import pathfinder as _pathfinder @@ -119,8 +120,9 @@ def generate_top_level_api_from(module): adaptive_tempered_smc = generate_top_level_api_from(adaptive_tempered) tempered_smc = generate_top_level_api_from(tempered) inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning) +partial_posteriors_smc = generate_top_level_api_from(_partial_posteriors_smc) -smc_family = [tempered_smc, adaptive_tempered_smc] +smc_family = [tempered_smc, adaptive_tempered_smc, partial_posteriors_smc] "Step_fn returning state has a .particles attribute" # stochastic gradient mcmc diff --git a/blackjax/smc/__init__.py b/blackjax/smc/__init__.py index ef10b10e6..9670fcb6e 100644 --- a/blackjax/smc/__init__.py +++ b/blackjax/smc/__init__.py @@ -6,4 +6,5 @@ "tempered", "inner_kernel_tuning", "extend_params", + "partial_posteriors_path", ] diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 5093cf06b..56df7f010 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -156,3 +156,31 @@ def extend_params(params): """ return jax.tree.map(lambda x: jnp.asarray(x)[None, ...], params) + + +def update_and_take_last( + mcmc_init_fn, + tempered_logposterior_fn, + shared_mcmc_step_fn, + num_mcmc_steps, + n_particles, +): + """Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and + returns the last values, waisting the previous num_mcmc_steps-1 + samples per chain. + """ + + def mcmc_kernel(rng_key, position, step_parameters): + state = mcmc_init_fn(position, tempered_logposterior_fn) + + def body_fn(state, rng_key): + new_state, info = shared_mcmc_step_fn( + rng_key, state, tempered_logposterior_fn, **step_parameters + ) + return new_state, info + + keys = jax.random.split(rng_key, num_mcmc_steps) + last_state, info = jax.lax.scan(body_fn, state, keys) + return last_state.position, info + + return jax.vmap(mcmc_kernel), n_particles diff --git a/blackjax/smc/from_mcmc.py b/blackjax/smc/from_mcmc.py new file mode 100644 index 000000000..0e60b5968 --- /dev/null +++ b/blackjax/smc/from_mcmc.py @@ -0,0 +1,64 @@ +from functools import partial +from typing import Callable + +import jax + +from blackjax import smc +from blackjax.smc.base import SMCState, update_and_take_last +from blackjax.types import PRNGKey + + +def build_kernel( + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + resampling_fn: Callable, + update_strategy: Callable = update_and_take_last, +): + """SMC step from MCMC kernels. + Builds MCMC kernels from the input parameters, which may change across iterations. + Moreover, it defines the way such kernels are used to update the particles. This layer + adapts an API defined in terms of kernels (mcmc_step_fn and mcmc_init_fn) into an API + that depends on an update function over the set of particles. + Returns + ------- + A callable that takes a rng_key and a state with .particles and .weights and returns a base.SMCState + and base.SMCInfo pair. + + """ + + def step( + rng_key: PRNGKey, + state, + num_mcmc_steps: int, + mcmc_parameters: dict, + logposterior_fn: Callable, + log_weights_fn: Callable, + ) -> tuple[smc.base.SMCState, smc.base.SMCInfo]: + shared_mcmc_parameters = {} + unshared_mcmc_parameters = {} + for k, v in mcmc_parameters.items(): + if v.shape[0] == 1: + shared_mcmc_parameters[k] = v[0, ...] + else: + unshared_mcmc_parameters[k] = v + + shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters) + + update_fn, num_resampled = update_strategy( + mcmc_init_fn, + logposterior_fn, + shared_mcmc_step_fn, + n_particles=state.weights.shape[0], + num_mcmc_steps=num_mcmc_steps, + ) + + return smc.base.step( + rng_key, + SMCState(state.particles, state.weights, unshared_mcmc_parameters), + update_fn, + jax.vmap(log_weights_fn), + resampling_fn, + num_resampled, + ) + + return step diff --git a/blackjax/smc/partial_posteriors_path.py b/blackjax/smc/partial_posteriors_path.py new file mode 100644 index 000000000..81f19716d --- /dev/null +++ b/blackjax/smc/partial_posteriors_path.py @@ -0,0 +1,127 @@ +from typing import Callable, NamedTuple, Optional, Tuple + +import jax +import jax.numpy as jnp + +from blackjax import SamplingAlgorithm, smc +from blackjax.smc.base import update_and_take_last +from blackjax.smc.from_mcmc import build_kernel as smc_from_mcmc +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey + + +class PartialPosteriorsSMCState(NamedTuple): + """Current state for the tempered SMC algorithm. + + particles: PyTree + The particles' positions. + weights: + Weights of the particles, so that they represent a probability distribution + data_mask: + A 1D boolean array to indicate which datapoints to include + in the computation of the observed likelihood. + """ + + particles: ArrayTree + weights: Array + data_mask: Array + + +def init(particles: ArrayLikeTree, num_datapoints: int) -> PartialPosteriorsSMCState: + """num_datapoints are the number of observations that could potentially be + used in a partial posterior. Since the initial data_mask is all 0s, it + means that no likelihood term will be added (only prior). + """ + num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] + weights = jnp.ones(num_particles) / num_particles + return PartialPosteriorsSMCState(particles, weights, jnp.zeros(num_datapoints)) + + +def build_kernel( + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + resampling_fn: Callable, + num_mcmc_steps: Optional[int], + mcmc_parameters: ArrayTree, + partial_logposterior_factory: Callable[[Array], Callable], + update_strategy=update_and_take_last, +) -> Callable: + """Build the Partial Posteriors (data tempering) SMC kernel. + The distribution's trajectory includes increasingly adding more + datapoints to the likelihood. See Section 2.2 of https://arxiv.org/pdf/2007.11936 + Parameters + ---------- + mcmc_step_fn + A function that computes the log density of the prior distribution + mcmc_init_fn + A function that returns the probability at a given position. + resampling_fn + A random function that resamples generated particles based of weights + num_mcmc_steps + Number of iterations in the MCMC chain. + mcmc_parameters + A dictionary of parameters to be used by the inner MCMC kernels + partial_logposterior_factory: + A callable that given an array of 0 and 1, returns a function logposterior(x). + The array represents which values to include in the logposterior calculation. The logposterior + must be jax compilable. + + Returns + ------- + A callable that takes a rng_key and PartialPosteriorsSMCState and selectors for + the current and previous posteriors, and takes a data-tempered SMC state. + """ + delegate = smc_from_mcmc(mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy) + + def step( + key, state: PartialPosteriorsSMCState, data_mask: Array + ) -> Tuple[PartialPosteriorsSMCState, smc.base.SMCInfo]: + logposterior_fn = partial_logposterior_factory(data_mask) + + previous_logposterior_fn = partial_logposterior_factory(state.data_mask) + + def log_weights_fn(x): + return logposterior_fn(x) - previous_logposterior_fn(x) + + state, info = delegate( + key, state, num_mcmc_steps, mcmc_parameters, logposterior_fn, log_weights_fn + ) + + return ( + PartialPosteriorsSMCState(state.particles, state.weights, data_mask), + info, + ) + + return step + + +def as_top_level_api( + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + mcmc_parameters: dict, + resampling_fn: Callable, + num_mcmc_steps, + partial_logposterior_factory: Callable, + update_strategy=update_and_take_last, +) -> SamplingAlgorithm: + """A factory that wraps the kernel into a SamplingAlgorithm object. + See build_kernel for full documentation on the parameters. + """ + + kernel = build_kernel( + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + num_mcmc_steps, + mcmc_parameters, + partial_logposterior_factory, + update_strategy, + ) + + def init_fn(position: ArrayLikeTree, num_observations, rng_key=None): + del rng_key + return init(position, num_observations) + + def step(key: PRNGKey, state: PartialPosteriorsSMCState, data_mask: Array): + return kernel(key, state, data_mask) + + return SamplingAlgorithm(init_fn, step) # type: ignore[arg-type] diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index 19de8afb7..88539deaa 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -11,15 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, NamedTuple, Optional import jax import jax.numpy as jnp import blackjax.smc as smc +import blackjax.smc.from_mcmc as smc_from_mcmc from blackjax.base import SamplingAlgorithm -from blackjax.smc.base import SMCState +from blackjax.smc.base import update_and_take_last from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["TemperedSMCState", "init", "build_kernel", "as_top_level_api"] @@ -48,35 +48,6 @@ def init(particles: ArrayLikeTree): return TemperedSMCState(particles, weights, 0.0) -def update_and_take_last( - mcmc_init_fn, - tempered_logposterior_fn, - shared_mcmc_step_fn, - num_mcmc_steps, - n_particles, -): - """ - Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and - returns the last values, waisting the previous num_mcmc_steps-1 - samples per chain. - """ - - def mcmc_kernel(rng_key, position, step_parameters): - state = mcmc_init_fn(position, tempered_logposterior_fn) - - def body_fn(state, rng_key): - new_state, info = shared_mcmc_step_fn( - rng_key, state, tempered_logposterior_fn, **step_parameters - ) - return new_state, info - - keys = jax.random.split(rng_key, num_mcmc_steps) - last_state, info = jax.lax.scan(body_fn, state, keys) - return last_state.position, info - - return jax.vmap(mcmc_kernel), n_particles - - def build_kernel( logprior_fn: Callable, loglikelihood_fn: Callable, @@ -121,6 +92,9 @@ def build_kernel( information about the transition. """ + delegate = smc_from_mcmc.build_kernel( + mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy + ) def kernel( rng_key: PRNGKey, @@ -153,14 +127,6 @@ def kernel( """ delta = lmbda - state.lmbda - shared_mcmc_parameters = {} - unshared_mcmc_parameters = {} - for k, v in mcmc_parameters.items(): - if v.shape[0] == 1: - shared_mcmc_parameters[k] = v[0, ...] - else: - unshared_mcmc_parameters[k] = v - def log_weights_fn(position: ArrayLikeTree) -> float: return delta * loglikelihood_fn(position) @@ -169,23 +135,13 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float: tempered_loglikelihood = state.lmbda * loglikelihood_fn(position) return logprior + tempered_loglikelihood - shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters) - - update_fn, num_resampled = update_strategy( - mcmc_init_fn, - tempered_logposterior_fn, - shared_mcmc_step_fn, - n_particles=state.weights.shape[0], - num_mcmc_steps=num_mcmc_steps, - ) - - smc_state, info = smc.base.step( + smc_state, info = delegate( rng_key, - SMCState(state.particles, state.weights, unshared_mcmc_parameters), - update_fn, - jax.vmap(log_weights_fn), - resampling_fn, - num_resampled, + state, + num_mcmc_steps, + mcmc_parameters, + tempered_logposterior_fn, + log_weights_fn, ) tempered_state = TemperedSMCState( diff --git a/tests/smc/__init__.py b/tests/smc/__init__.py index 7a4e5c029..006d7ba38 100644 --- a/tests/smc/__init__.py +++ b/tests/smc/__init__.py @@ -5,19 +5,27 @@ class SMCLinearRegressionTestCase(chex.TestCase): - def logdensity_fn(self, log_scale, coefs, preds, x): - """Linear regression""" + def logdensity_by_observation(self, log_scale, coefs, preds, x): scale = jnp.exp(log_scale) y = jnp.dot(x, coefs) logpdf = stats.norm.logpdf(preds, y, scale) + return logpdf + + def logdensity_fn(self, log_scale, coefs, preds, x): + """Linear regression""" + logpdf = self.logdensity_by_observation(log_scale, coefs, preds, x) return jnp.sum(logpdf) - def particles_prior_loglikelihood(self): + def observations(self): num_particles = 100 x_data = np.random.normal(0, 1, size=(1000, 1)) y_data = 3 * x_data + np.random.normal(size=x_data.shape) observations = {"x": x_data, "preds": y_data} + return observations, num_particles + + def particles_prior_loglikelihood(self): + observations, num_particles = self.observations() logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf( x["coefs"] @@ -30,6 +38,23 @@ def particles_prior_loglikelihood(self): return init_particles, logprior_fn, loglikelihood_fn + def partial_posterior_test_case(self): + num_particles = 100 + + x_data = np.random.normal(0, 1, size=(1000, 1)) + y_data = 3 * x_data + np.random.normal(size=x_data.shape) + observations = {"x": x_data, "preds": y_data} + + logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf( + x["coefs"] + ) + + log_scale_init = np.random.randn(num_particles) + coeffs_init = np.random.randn(num_particles) + init_particles = {"log_scale": log_scale_init, "coefs": coeffs_init} + + return init_particles, logprior_fn, observations + def assert_linear_regression_test_case(self, result): np.testing.assert_allclose( np.mean(np.exp(result.particles["log_scale"])), 1.0, rtol=1e-1 diff --git a/tests/smc/test_partial_posteriors_smc.py b/tests/smc/test_partial_posteriors_smc.py new file mode 100644 index 000000000..78d57a934 --- /dev/null +++ b/tests/smc/test_partial_posteriors_smc.py @@ -0,0 +1,88 @@ +import chex +import jax +import jax.numpy as jnp +import numpy as np +from absl.testing import absltest + +import blackjax +import blackjax.smc.resampling as resampling +from blackjax.smc import extend_params +from tests.smc import SMCLinearRegressionTestCase + + +class PartialPosteriorsSMCTest(SMCLinearRegressionTestCase): + """Test posterior mean estimate.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + @chex.variants(with_jit=True) + def test_partial_posteriors(self): + ( + init_particles, + logprior_fn, + observations, + ) = self.partial_posterior_test_case() + + hmc_init = blackjax.hmc.init + hmc_kernel = blackjax.hmc.build_kernel() + + hmc_parameters = extend_params( + { + "step_size": 10e-3, + "inverse_mass_matrix": jnp.eye(2), + "num_integration_steps": 50, + }, + ) + + dataset_size = 1000 + + def partial_logposterior_factory(data_mask): + def partial_logposterior(x): + lp = logprior_fn(x) + return lp + jnp.sum( + self.logdensity_by_observation(**x, **observations) + * data_mask.reshape(-1, 1) + ) + + return jax.jit(partial_logposterior) + + init, kernel = blackjax.partial_posteriors_smc( + hmc_kernel, + hmc_init, + hmc_parameters, + resampling.systematic, + 50, + partial_logposterior_factory=partial_logposterior_factory, + ) + + init_state = init(init_particles, 1000) + smc_kernel = self.variant(kernel) + + data_masks = jnp.array( + [ + jnp.concat( + [ + jnp.ones(datapoints_chosen), + jnp.zeros(dataset_size - datapoints_chosen), + ] + ) + for datapoints_chosen in np.arange(100, 1001, 50) + ] + ) + + def body_fn(carry, data_mask): + i, state = carry + subkey = jax.random.fold_in(self.key, i) + new_state, info = smc_kernel(subkey, state, data_mask) + return (i + 1, new_state), (new_state, info) + + (steps, result), it = jax.lax.scan(body_fn, (0, init_state), data_masks) + assert steps == 19 + + self.assert_linear_regression_test_case(result) + + +if __name__ == "__main__": + absltest.main()