Skip to content

Commit

Permalink
refactor MALA so that it uses the MH component in proposals.py
Browse files Browse the repository at this point in the history
  • Loading branch information
albcab committed Apr 14, 2023
1 parent fb14353 commit 3e8d85c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 42 deletions.
35 changes: 15 additions & 20 deletions blackjax/mcmc/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jax.numpy as jnp

import blackjax.mcmc.diffusions as diffusions
import blackjax.mcmc.proposal as proposal
from blackjax.types import PRNGKey, PyTree

__all__ = ["MALAState", "MALAInfo", "init", "kernel"]
Expand Down Expand Up @@ -74,8 +75,8 @@ def kernel():
"""

def transition_probability(state, new_state, step_size):
"""Transition probability to go from `state` to `new_state`"""
def transition_energy(state, new_state, step_size):
"""Transition energy to go from `state` to `new_state`"""
theta = jax.tree_util.tree_map(
lambda new_x, x, g: new_x - x - step_size * g,
new_state.position,
Expand All @@ -85,7 +86,12 @@ def transition_probability(state, new_state, step_size):
theta_dot = jax.tree_util.tree_reduce(
operator.add, jax.tree_util.tree_map(lambda x: jnp.sum(x * x), theta)
)
return -0.25 * (1.0 / step_size) * theta_dot
return -state.logdensity + 0.25 * (1.0 / step_size) * theta_dot

init_proposal, generate_proposal = proposal.asymmetric_proposal_generator(
transition_energy, divergence_threshold=jnp.inf
)
sample_proposal = proposal.static_binomial_sampling

def one_step(
rng_key: PRNGKey, state: MALAState, logdensity_fn: Callable, step_size: float
Expand All @@ -97,26 +103,15 @@ def one_step(
key_integrator, key_rmh = jax.random.split(rng_key)

new_state = integrator(key_integrator, state, step_size)
new_state = MALAState(*new_state)

delta = (
new_state.logdensity
- state.logdensity
+ transition_probability(new_state, state, step_size)
- transition_probability(state, new_state, step_size)
)
delta = jnp.where(jnp.isnan(delta), -jnp.inf, delta)
p_accept = jnp.clip(jnp.exp(delta), a_max=1)

do_accept = jax.random.bernoulli(key_rmh, p_accept)
proposal = init_proposal(state)
new_proposal, _ = generate_proposal(state, new_state, step_size=step_size)
sampled_proposal, *info = sample_proposal(key_rmh, proposal, new_proposal)
do_accept, p_accept = info

new_state = MALAState(*new_state)
info = MALAInfo(p_accept, do_accept)

return jax.lax.cond(
do_accept,
lambda _: (new_state, info),
lambda _: (state, info),
operand=None,
)
return sampled_proposal.state, info

return one_step
44 changes: 22 additions & 22 deletions blackjax/mcmc/proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import jax
import jax.numpy as jnp
import numpy as np

TrajectoryState = NamedTuple

Expand Down Expand Up @@ -49,18 +48,18 @@ def proposal_generator(
Parameters
----------
energy
A callable that computes the energy associated to a given state
A function that computes the energy associated to a given state
divergence_threshold
max value allowed for the difference in energies not to be considered a divergence
max value allowed for the difference in energies not to be considered a divergence
Returns
-------
Two callables, to generate an initial proposal when no step has been taken,
and to generate proposals after each step.
Two functions, one to generate an initial proposal when no step has been taken,
another to generate proposals after each step.
"""

def new(state: TrajectoryState) -> Proposal:
return Proposal(state, energy(state), 0.0, -np.inf)
return Proposal(state, energy(state), 0.0, -jnp.inf)

def update(initial_energy: float, state: TrajectoryState) -> Tuple[Proposal, bool]:
"""Generate a new proposal from a trajectory state.
Expand Down Expand Up @@ -103,13 +102,13 @@ def proposal_from_energy_diff(
Parameters
----------
initial_energy
the energy from the previous state
the energy from the initial state
new_energy
the energy at the new state
the energy at the proposed state
divergence_threshold
max value allowed for the difference in energies not to be considered a divergence
max value allowed for the difference in energies not to be considered a divergence
state
the state to propose
the proposed state
Returns
-------
Expand Down Expand Up @@ -139,36 +138,37 @@ def proposal_from_energy_diff(
def asymmetric_proposal_generator(
transition_energy_fn: Callable,
divergence_threshold: float,
proposal_factory=proposal_from_energy_diff,
proposal_factory: Callable = proposal_from_energy_diff,
) -> Tuple[Callable, Callable]:
"""A proposal generator that takes into account the transition between
two states to compute a new proposal. In particular, both states are
used to compute the energies to consider in weighting the proposal,
to account for asymmetries.
----------
transition_energy_fn
A Callable that computes the energy of a associated with a transition
from one state to another
A function that computes the energy of a transition from an initial state
to a new state, given some optional keyword arguments.
divergence_threshold
A max number to will be used by the proposal_factory to flag a Proposal
as a divergence.
The maximum value allowed for the difference in energies not to be considered a divergence.
proposal_factory
A callable that builds a proposal from the transitions energies
A function that builds a proposal from the transition energies.
Returns
-------
Two callables, to generate an initial proposal when no step has been taken,
and to generate proposals after each step.
Two functions, one to generate an initial proposal when no step has been taken,
another to generate proposals after each step.
"""

def new(state: TrajectoryState) -> Proposal:
return Proposal(state, 0.0, 0.0, -np.inf)
return Proposal(state, 0.0, 0.0, -jnp.inf)

def update(
initial_state: TrajectoryState, state: TrajectoryState
initial_state: TrajectoryState,
state: TrajectoryState,
**energy_params,
) -> Tuple[Proposal, bool]:
new_energy = transition_energy_fn(initial_state, state)
prev_energy = transition_energy_fn(state, initial_state)
new_energy = transition_energy_fn(initial_state, state, **energy_params)
prev_energy = transition_energy_fn(state, initial_state, **energy_params)
return proposal_factory(prev_energy, new_energy, divergence_threshold, state)

return new, update
Expand Down

0 comments on commit 3e8d85c

Please sign in to comment.