From 9b4193ad2a04214242b9a3d9e7ed486a8cb13438 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Sun, 29 Oct 2023 12:19:16 +0100 Subject: [PATCH 1/2] partially fix #391 --- blackjax/mcmc/hmc.py | 5 +++-- blackjax/mcmc/mala.py | 4 ++-- blackjax/mcmc/proposal.py | 25 +++++---------------- blackjax/mcmc/random_walk.py | 4 ++-- blackjax/mcmc/trajectory.py | 12 +++++----- tests/mcmc/test_proposal_without_chex.py | 18 +++++++-------- tests/mcmc/test_random_walk_without_chex.py | 2 +- 7 files changed, 29 insertions(+), 41 deletions(-) diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py index 228fd0b51..56e1c1790 100644 --- a/blackjax/mcmc/hmc.py +++ b/blackjax/mcmc/hmc.py @@ -276,7 +276,7 @@ def hmc_proposal( """ build_trajectory = trajectory.static_integration(integrator) init_proposal, generate_proposal = proposal.proposal_generator( - hmc_energy(kinetic_energy), divergence_threshold + hmc_energy(kinetic_energy) ) def generate( @@ -286,7 +286,8 @@ def generate( end_state = build_trajectory(state, step_size, num_integration_steps) end_state = flip_momentum(end_state) proposal = init_proposal(state) - new_proposal, is_diverging = generate_proposal(proposal.energy, end_state) + new_proposal = generate_proposal(proposal.energy, end_state) + is_diverging = -new_proposal.weight > divergence_threshold sampled_proposal, *info = sample_proposal(rng_key, proposal, new_proposal) do_accept, p_accept = info diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py index 0f4295a0e..f4aad2fb6 100644 --- a/blackjax/mcmc/mala.py +++ b/blackjax/mcmc/mala.py @@ -90,7 +90,7 @@ def transition_energy(state, new_state, step_size): return -state.logdensity + 0.25 * (1.0 / step_size) * theta_dot init_proposal, generate_proposal = proposal.asymmetric_proposal_generator( - transition_energy, divergence_threshold=jnp.inf + transition_energy ) sample_proposal = proposal.static_binomial_sampling @@ -107,7 +107,7 @@ def kernel( new_state = MALAState(*new_state) proposal = init_proposal(state) - new_proposal, _ = generate_proposal(state, new_state, step_size=step_size) + new_proposal = generate_proposal(state, new_state, step_size=step_size) sampled_proposal, do_accept, p_accept = sample_proposal( key_rmh, proposal, new_proposal ) diff --git a/blackjax/mcmc/proposal.py b/blackjax/mcmc/proposal.py index 2ed3eca87..259a5df0b 100644 --- a/blackjax/mcmc/proposal.py +++ b/blackjax/mcmc/proposal.py @@ -41,7 +41,7 @@ class Proposal(NamedTuple): def proposal_generator( - energy: Callable, divergence_threshold: float + energy: Callable ) -> tuple[Callable, Callable]: """ @@ -49,8 +49,6 @@ def proposal_generator( ---------- energy A function that computes the energy associated to a given state - divergence_threshold - max value allowed for the difference in energies not to be considered a divergence Returns ------- @@ -84,7 +82,7 @@ def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, boo """ new_energy = energy(state) return proposal_from_energy_diff( - initial_energy, new_energy, divergence_threshold, state + initial_energy, new_energy, state ) return new, update @@ -93,22 +91,16 @@ def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, boo def proposal_from_energy_diff( initial_energy: float, new_energy: float, - divergence_threshold: float, state: TrajectoryState, ) -> tuple[Proposal, bool]: """Computes a new proposal from the energy difference between two states. - It also verifies whether this difference is a divergence, if the - energy diff is above divergence_threshold. - Parameters ---------- initial_energy the energy from the initial state new_energy the energy at the proposed state - divergence_threshold - max value allowed for an increase in energies not to be considered a divergence state the proposed state @@ -118,7 +110,6 @@ def proposal_from_energy_diff( """ delta_energy = initial_energy - new_energy delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) - is_transition_divergent = -delta_energy > divergence_threshold # The weight of the new proposal is equal to H0 - H(z_new) weight = delta_energy @@ -126,20 +117,16 @@ def proposal_from_energy_diff( # Acceptance statistic min(e^{H0 - H(z_new)}, 1) sum_log_p_accept = jnp.minimum(delta_energy, 0.0) - return ( - Proposal( + return Proposal( state, new_energy, weight, sum_log_p_accept, - ), - is_transition_divergent, - ) + ) def asymmetric_proposal_generator( transition_energy_fn: Callable, - divergence_threshold: float, proposal_factory: Callable = proposal_from_energy_diff, ) -> tuple[Callable, Callable]: """A proposal generator that takes into account the transition between @@ -153,8 +140,6 @@ def asymmetric_proposal_generator( transition_energy_fn A function that computes the energy of a transition from an initial state to a new state, given some optional keyword arguments. - divergence_threshold - The maximum value allowed for the difference in energies not to be considered a divergence. proposal_factory A function that builds a proposal from the transition energies. @@ -174,7 +159,7 @@ def update( ) -> tuple[Proposal, bool]: new_energy = transition_energy_fn(initial_state, state, **energy_params) prev_energy = transition_energy_fn(state, initial_state, **energy_params) - return proposal_factory(prev_energy, new_energy, divergence_threshold, state) + return proposal_factory(prev_energy, new_energy, state) return new, update diff --git a/blackjax/mcmc/random_walk.py b/blackjax/mcmc/random_walk.py index 6d97c7c08..e8a4b9659 100644 --- a/blackjax/mcmc/random_walk.py +++ b/blackjax/mcmc/random_walk.py @@ -391,7 +391,7 @@ def kernel( transition_energy = build_rmh_transition_energy(proposal_logdensity_fn) init_proposal, generate_proposal = proposal.asymmetric_proposal_generator( - transition_energy, np.inf + transition_energy ) proposal_generator = rmh_proposal( @@ -496,7 +496,7 @@ def build_trajectory(rng_key, initial_state: RWState) -> RWState: def generate(rng_key, state: RWState) -> tuple[RWState, bool, float]: key_proposal, key_accept = jax.random.split(rng_key, 2) end_state = build_trajectory(key_proposal, state) - new_proposal, _ = generate_proposal(state, end_state) + new_proposal = generate_proposal(state, end_state) previous_proposal = init_proposal(state) sampled_proposal, do_accept, p_accept = sample_proposal( key_accept, previous_proposal, new_proposal diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py index 81d369c0b..b06d36930 100644 --- a/blackjax/mcmc/trajectory.py +++ b/blackjax/mcmc/trajectory.py @@ -160,7 +160,7 @@ def dynamic_progressive_integration( """ _, generate_proposal = proposal_generator( - hmc_energy(kinetic_energy), divergence_threshold + hmc_energy(kinetic_energy) ) sample_proposal = progressive_uniform_sampling @@ -215,7 +215,8 @@ def add_one_state(loop_state): rng_key, proposal_key = jax.random.split(rng_key) new_state = integrator(trajectory.rightmost_state, direction * step_size) - new_proposal, is_diverging = generate_proposal(initial_energy, new_state) + new_proposal = generate_proposal(initial_energy, new_state) + is_diverging = -new_proposal.weight > divergence_threshold # At step 0, we always accept the proposal, since we # take one step to get the leftmost state of the tree. @@ -248,7 +249,7 @@ def add_one_state(loop_state): return (rng_key, new_integration_state, (is_diverging, has_terminated)) - proposal_placeholder, _ = generate_proposal(initial_energy, initial_state) + proposal_placeholder = generate_proposal(initial_energy, initial_state) trajectory_placeholder = Trajectory( initial_state, initial_state, initial_state.momentum, 0 ) @@ -320,7 +321,7 @@ def dynamic_recursive_integration( """ _, generate_proposal = proposal_generator( - hmc_energy(kinetic_energy), divergence_threshold + hmc_energy(kinetic_energy) ) sample_proposal = progressive_uniform_sampling @@ -357,7 +358,8 @@ def buildtree_integrate( if tree_depth == 0: # Base case - take one leapfrog step in the direction v. next_state = integrator(initial_state, direction * step_size) - new_proposal, is_diverging = generate_proposal(initial_energy, next_state) + new_proposal = generate_proposal(initial_energy, next_state) + is_diverging = -new_proposal.weight > divergence_threshold trajectory = Trajectory(next_state, next_state, next_state.momentum, 1) return ( rng_key, diff --git a/tests/mcmc/test_proposal_without_chex.py b/tests/mcmc/test_proposal_without_chex.py index b34d758ef..9741c6737 100644 --- a/tests/mcmc/test_proposal_without_chex.py +++ b/tests/mcmc/test_proposal_without_chex.py @@ -14,7 +14,7 @@ class TestAsymmetricProposalGenerator(unittest.TestCase): def test_new(self): state = MagicMock() - new, _ = asymmetric_proposal_generator(None, None, None) + new, _ = asymmetric_proposal_generator(None, None) assert new(state) == Proposal(state, 0.0, 0.0, -np.inf) def test_update(self): @@ -23,15 +23,14 @@ def transition_energy(prev, next): new_proposal = MagicMock() - def proposal_factory(prev_energy, new_energy, divergence_threshold, new_state): + def proposal_factory(prev_energy, new_energy, new_state): assert prev_energy == -20 assert new_energy == 20 - assert divergence_threshold == 50 assert new_state == 50 return new_proposal _, update = asymmetric_proposal_generator( - transition_energy, 50, proposal_factory + transition_energy, proposal_factory ) proposed = update(30, 50) assert proposed == new_proposal @@ -52,25 +51,26 @@ class TestProposalFromEnergyDiff(parameterized.TestCase): ) def test_divergence_threshold(self, before, after, threshold, is_divergent): state = MagicMock() - proposal, divergence = proposal_from_energy_diff(5, 10, threshold, state) + proposal = proposal_from_energy_diff(5, 10, state) + divergence = -proposal.weight > threshold assert divergence == is_divergent def test_sum_log_paccept(self): state = MagicMock() - proposal, _ = proposal_from_energy_diff(5, 10, 0, state) + proposal = proposal_from_energy_diff(5, 10, state) np.testing.assert_allclose(proposal.sum_log_p_accept, -5.0) - proposal, _ = proposal_from_energy_diff(10, 5, 0, state) + proposal = proposal_from_energy_diff(10, 5, state) np.testing.assert_allclose(proposal.sum_log_p_accept, 0.0) def test_delta_energy_is_nan(self): state = MagicMock() - proposal, _ = proposal_from_energy_diff(np.nan, np.nan, 0, state) + proposal = proposal_from_energy_diff(np.nan, np.nan, state) assert np.isneginf(proposal.weight) def test_weight(self): state = MagicMock() - proposal, _ = proposal_from_energy_diff(5, 10, 0, state) + proposal = proposal_from_energy_diff(5, 10, state) assert proposal.state == state np.testing.assert_allclose(proposal.weight, -5) diff --git a/tests/mcmc/test_random_walk_without_chex.py b/tests/mcmc/test_random_walk_without_chex.py index 6e4e7afe1..8bbcd578e 100644 --- a/tests/mcmc/test_random_walk_without_chex.py +++ b/tests/mcmc/test_random_walk_without_chex.py @@ -90,7 +90,7 @@ def init_proposal(self, state): return Proposal(state, 0, 0, 0) def generate_proposal(self, prev, new): - return Proposal(new, 0, 0, 0), False + return Proposal(new, 0, 0, 0) def test_generate_reject(self): """ From 48acc2375cd98a99f2fc6ee5dfa9d6c890475268 Mon Sep 17 00:00:00 2001 From: Junpeng Lao Date: Sun, 29 Oct 2023 12:33:02 +0100 Subject: [PATCH 2/2] fix formatting --- blackjax/mcmc/proposal.py | 22 +++++++++------------- blackjax/mcmc/random_walk.py | 1 - blackjax/mcmc/trajectory.py | 8 ++------ tests/mcmc/test_proposal_without_chex.py | 4 +--- 4 files changed, 12 insertions(+), 23 deletions(-) diff --git a/blackjax/mcmc/proposal.py b/blackjax/mcmc/proposal.py index 259a5df0b..9415438b0 100644 --- a/blackjax/mcmc/proposal.py +++ b/blackjax/mcmc/proposal.py @@ -40,9 +40,7 @@ class Proposal(NamedTuple): sum_log_p_accept: float -def proposal_generator( - energy: Callable -) -> tuple[Callable, Callable]: +def proposal_generator(energy: Callable) -> tuple[Callable, Callable]: """ Parameters @@ -59,7 +57,7 @@ def proposal_generator( def new(state: TrajectoryState) -> Proposal: return Proposal(state, energy(state), 0.0, -jnp.inf) - def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, bool]: + def update(initial_energy: float, state: TrajectoryState) -> Proposal: """Generate a new proposal from a trajectory state. The trajectory state records information about the position in the state @@ -81,9 +79,7 @@ def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, boo """ new_energy = energy(state) - return proposal_from_energy_diff( - initial_energy, new_energy, state - ) + return proposal_from_energy_diff(initial_energy, new_energy, state) return new, update @@ -92,7 +88,7 @@ def proposal_from_energy_diff( initial_energy: float, new_energy: float, state: TrajectoryState, -) -> tuple[Proposal, bool]: +) -> Proposal: """Computes a new proposal from the energy difference between two states. Parameters @@ -118,11 +114,11 @@ def proposal_from_energy_diff( sum_log_p_accept = jnp.minimum(delta_energy, 0.0) return Proposal( - state, - new_energy, - weight, - sum_log_p_accept, - ) + state, + new_energy, + weight, + sum_log_p_accept, + ) def asymmetric_proposal_generator( diff --git a/blackjax/mcmc/random_walk.py b/blackjax/mcmc/random_walk.py index e8a4b9659..9d7a0abee 100644 --- a/blackjax/mcmc/random_walk.py +++ b/blackjax/mcmc/random_walk.py @@ -63,7 +63,6 @@ from typing import Callable, NamedTuple, Optional import jax -import numpy as np from jax import numpy as jnp from blackjax.base import SamplingAlgorithm diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py index b06d36930..00f25989d 100644 --- a/blackjax/mcmc/trajectory.py +++ b/blackjax/mcmc/trajectory.py @@ -159,9 +159,7 @@ def dynamic_progressive_integration( which we say a transition is divergent. """ - _, generate_proposal = proposal_generator( - hmc_energy(kinetic_energy) - ) + _, generate_proposal = proposal_generator(hmc_energy(kinetic_energy)) sample_proposal = progressive_uniform_sampling def integrate( @@ -320,9 +318,7 @@ def dynamic_recursive_integration( Bool to indicate whether to perform additional U turn check between two trajectory. """ - _, generate_proposal = proposal_generator( - hmc_energy(kinetic_energy) - ) + _, generate_proposal = proposal_generator(hmc_energy(kinetic_energy)) sample_proposal = progressive_uniform_sampling def buildtree_integrate( diff --git a/tests/mcmc/test_proposal_without_chex.py b/tests/mcmc/test_proposal_without_chex.py index 9741c6737..82097f53f 100644 --- a/tests/mcmc/test_proposal_without_chex.py +++ b/tests/mcmc/test_proposal_without_chex.py @@ -29,9 +29,7 @@ def proposal_factory(prev_energy, new_energy, new_state): assert new_state == 50 return new_proposal - _, update = asymmetric_proposal_generator( - transition_energy, proposal_factory - ) + _, update = asymmetric_proposal_generator(transition_energy, proposal_factory) proposed = update(30, 50) assert proposed == new_proposal