Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor divergence check to each sampler #579

Merged
merged 2 commits into from
Oct 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions blackjax/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def hmc_proposal(
"""
build_trajectory = trajectory.static_integration(integrator)
init_proposal, generate_proposal = proposal.proposal_generator(
hmc_energy(kinetic_energy), divergence_threshold
hmc_energy(kinetic_energy)
)

def generate(
Expand All @@ -286,7 +286,8 @@ def generate(
end_state = build_trajectory(state, step_size, num_integration_steps)
end_state = flip_momentum(end_state)
proposal = init_proposal(state)
new_proposal, is_diverging = generate_proposal(proposal.energy, end_state)
new_proposal = generate_proposal(proposal.energy, end_state)
is_diverging = -new_proposal.weight > divergence_threshold
sampled_proposal, *info = sample_proposal(rng_key, proposal, new_proposal)
do_accept, p_accept = info

Expand Down
4 changes: 2 additions & 2 deletions blackjax/mcmc/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def transition_energy(state, new_state, step_size):
return -state.logdensity + 0.25 * (1.0 / step_size) * theta_dot

init_proposal, generate_proposal = proposal.asymmetric_proposal_generator(
transition_energy, divergence_threshold=jnp.inf
transition_energy
)
sample_proposal = proposal.static_binomial_sampling

Expand All @@ -107,7 +107,7 @@ def kernel(
new_state = MALAState(*new_state)

proposal = init_proposal(state)
new_proposal, _ = generate_proposal(state, new_state, step_size=step_size)
new_proposal = generate_proposal(state, new_state, step_size=step_size)
sampled_proposal, do_accept, p_accept = sample_proposal(
key_rmh, proposal, new_proposal
)
Expand Down
39 changes: 10 additions & 29 deletions blackjax/mcmc/proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,13 @@ class Proposal(NamedTuple):
sum_log_p_accept: float


def proposal_generator(
energy: Callable, divergence_threshold: float
) -> tuple[Callable, Callable]:
def proposal_generator(energy: Callable) -> tuple[Callable, Callable]:
"""

Parameters
----------
energy
A function that computes the energy associated to a given state
divergence_threshold
max value allowed for the difference in energies not to be considered a divergence

Returns
-------
Expand All @@ -61,7 +57,7 @@ def proposal_generator(
def new(state: TrajectoryState) -> Proposal:
return Proposal(state, energy(state), 0.0, -jnp.inf)

def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, bool]:
def update(initial_energy: float, state: TrajectoryState) -> Proposal:
"""Generate a new proposal from a trajectory state.

The trajectory state records information about the position in the state
Expand All @@ -83,32 +79,24 @@ def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, boo

"""
new_energy = energy(state)
return proposal_from_energy_diff(
initial_energy, new_energy, divergence_threshold, state
)
return proposal_from_energy_diff(initial_energy, new_energy, state)

return new, update


def proposal_from_energy_diff(
initial_energy: float,
new_energy: float,
divergence_threshold: float,
state: TrajectoryState,
) -> tuple[Proposal, bool]:
) -> Proposal:
"""Computes a new proposal from the energy difference between two states.

It also verifies whether this difference is a divergence, if the
energy diff is above divergence_threshold.

Parameters
----------
initial_energy
the energy from the initial state
new_energy
the energy at the proposed state
divergence_threshold
max value allowed for an increase in energies not to be considered a divergence
state
the proposed state

Expand All @@ -118,28 +106,23 @@ def proposal_from_energy_diff(
"""
delta_energy = initial_energy - new_energy
delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy)
is_transition_divergent = -delta_energy > divergence_threshold

# The weight of the new proposal is equal to H0 - H(z_new)
weight = delta_energy

# Acceptance statistic min(e^{H0 - H(z_new)}, 1)
sum_log_p_accept = jnp.minimum(delta_energy, 0.0)

return (
Proposal(
state,
new_energy,
weight,
sum_log_p_accept,
),
is_transition_divergent,
return Proposal(
state,
new_energy,
weight,
sum_log_p_accept,
)


def asymmetric_proposal_generator(
transition_energy_fn: Callable,
divergence_threshold: float,
proposal_factory: Callable = proposal_from_energy_diff,
) -> tuple[Callable, Callable]:
"""A proposal generator that takes into account the transition between
Expand All @@ -153,8 +136,6 @@ def asymmetric_proposal_generator(
transition_energy_fn
A function that computes the energy of a transition from an initial state
to a new state, given some optional keyword arguments.
divergence_threshold
The maximum value allowed for the difference in energies not to be considered a divergence.
proposal_factory
A function that builds a proposal from the transition energies.

Expand All @@ -174,7 +155,7 @@ def update(
) -> tuple[Proposal, bool]:
new_energy = transition_energy_fn(initial_state, state, **energy_params)
prev_energy = transition_energy_fn(state, initial_state, **energy_params)
return proposal_factory(prev_energy, new_energy, divergence_threshold, state)
return proposal_factory(prev_energy, new_energy, state)

return new, update

Expand Down
5 changes: 2 additions & 3 deletions blackjax/mcmc/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
from typing import Callable, NamedTuple, Optional

import jax
import numpy as np
from jax import numpy as jnp

from blackjax.base import SamplingAlgorithm
Expand Down Expand Up @@ -391,7 +390,7 @@ def kernel(
transition_energy = build_rmh_transition_energy(proposal_logdensity_fn)

init_proposal, generate_proposal = proposal.asymmetric_proposal_generator(
transition_energy, np.inf
transition_energy
)

proposal_generator = rmh_proposal(
Expand Down Expand Up @@ -496,7 +495,7 @@ def build_trajectory(rng_key, initial_state: RWState) -> RWState:
def generate(rng_key, state: RWState) -> tuple[RWState, bool, float]:
key_proposal, key_accept = jax.random.split(rng_key, 2)
end_state = build_trajectory(key_proposal, state)
new_proposal, _ = generate_proposal(state, end_state)
new_proposal = generate_proposal(state, end_state)
previous_proposal = init_proposal(state)
sampled_proposal, do_accept, p_accept = sample_proposal(
key_accept, previous_proposal, new_proposal
Expand Down
16 changes: 7 additions & 9 deletions blackjax/mcmc/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,7 @@ def dynamic_progressive_integration(
which we say a transition is divergent.

"""
_, generate_proposal = proposal_generator(
hmc_energy(kinetic_energy), divergence_threshold
)
_, generate_proposal = proposal_generator(hmc_energy(kinetic_energy))
sample_proposal = progressive_uniform_sampling

def integrate(
Expand Down Expand Up @@ -215,7 +213,8 @@ def add_one_state(loop_state):
rng_key, proposal_key = jax.random.split(rng_key)

new_state = integrator(trajectory.rightmost_state, direction * step_size)
new_proposal, is_diverging = generate_proposal(initial_energy, new_state)
new_proposal = generate_proposal(initial_energy, new_state)
is_diverging = -new_proposal.weight > divergence_threshold

# At step 0, we always accept the proposal, since we
# take one step to get the leftmost state of the tree.
Expand Down Expand Up @@ -248,7 +247,7 @@ def add_one_state(loop_state):

return (rng_key, new_integration_state, (is_diverging, has_terminated))

proposal_placeholder, _ = generate_proposal(initial_energy, initial_state)
proposal_placeholder = generate_proposal(initial_energy, initial_state)
trajectory_placeholder = Trajectory(
initial_state, initial_state, initial_state.momentum, 0
)
Expand Down Expand Up @@ -319,9 +318,7 @@ def dynamic_recursive_integration(
Bool to indicate whether to perform additional U turn check between two trajectory.

"""
_, generate_proposal = proposal_generator(
hmc_energy(kinetic_energy), divergence_threshold
)
_, generate_proposal = proposal_generator(hmc_energy(kinetic_energy))
sample_proposal = progressive_uniform_sampling

def buildtree_integrate(
Expand Down Expand Up @@ -357,7 +354,8 @@ def buildtree_integrate(
if tree_depth == 0:
# Base case - take one leapfrog step in the direction v.
next_state = integrator(initial_state, direction * step_size)
new_proposal, is_diverging = generate_proposal(initial_energy, next_state)
new_proposal = generate_proposal(initial_energy, next_state)
is_diverging = -new_proposal.weight > divergence_threshold
trajectory = Trajectory(next_state, next_state, next_state.momentum, 1)
return (
rng_key,
Expand Down
20 changes: 9 additions & 11 deletions tests/mcmc/test_proposal_without_chex.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class TestAsymmetricProposalGenerator(unittest.TestCase):
def test_new(self):
state = MagicMock()
new, _ = asymmetric_proposal_generator(None, None, None)
new, _ = asymmetric_proposal_generator(None, None)
assert new(state) == Proposal(state, 0.0, 0.0, -np.inf)

def test_update(self):
Expand All @@ -23,16 +23,13 @@ def transition_energy(prev, next):

new_proposal = MagicMock()

def proposal_factory(prev_energy, new_energy, divergence_threshold, new_state):
def proposal_factory(prev_energy, new_energy, new_state):
assert prev_energy == -20
assert new_energy == 20
assert divergence_threshold == 50
assert new_state == 50
return new_proposal

_, update = asymmetric_proposal_generator(
transition_energy, 50, proposal_factory
)
_, update = asymmetric_proposal_generator(transition_energy, proposal_factory)
proposed = update(30, 50)
assert proposed == new_proposal

Expand All @@ -52,25 +49,26 @@ class TestProposalFromEnergyDiff(parameterized.TestCase):
)
def test_divergence_threshold(self, before, after, threshold, is_divergent):
state = MagicMock()
proposal, divergence = proposal_from_energy_diff(5, 10, threshold, state)
proposal = proposal_from_energy_diff(5, 10, state)
divergence = -proposal.weight > threshold
assert divergence == is_divergent

def test_sum_log_paccept(self):
state = MagicMock()
proposal, _ = proposal_from_energy_diff(5, 10, 0, state)
proposal = proposal_from_energy_diff(5, 10, state)
np.testing.assert_allclose(proposal.sum_log_p_accept, -5.0)

proposal, _ = proposal_from_energy_diff(10, 5, 0, state)
proposal = proposal_from_energy_diff(10, 5, state)
np.testing.assert_allclose(proposal.sum_log_p_accept, 0.0)

def test_delta_energy_is_nan(self):
state = MagicMock()
proposal, _ = proposal_from_energy_diff(np.nan, np.nan, 0, state)
proposal = proposal_from_energy_diff(np.nan, np.nan, state)
assert np.isneginf(proposal.weight)

def test_weight(self):
state = MagicMock()
proposal, _ = proposal_from_energy_diff(5, 10, 0, state)
proposal = proposal_from_energy_diff(5, 10, state)

assert proposal.state == state
np.testing.assert_allclose(proposal.weight, -5)
Expand Down
2 changes: 1 addition & 1 deletion tests/mcmc/test_random_walk_without_chex.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def init_proposal(self, state):
return Proposal(state, 0, 0, 0)

def generate_proposal(self, prev, new):
return Proposal(new, 0, 0, 0), False
return Proposal(new, 0, 0, 0)

def test_generate_reject(self):
"""
Expand Down