Skip to content

Commit

Permalink
Refactor the potential fun flip in HMC
Browse files Browse the repository at this point in the history
Fixes #284
  • Loading branch information
junpenglao authored and rlouf committed Jan 16, 2023
1 parent 25eee35 commit feb810f
Show file tree
Hide file tree
Showing 15 changed files with 190 additions and 214 deletions.
23 changes: 11 additions & 12 deletions blackjax/adaptation/meads.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def base():
"""

def compute_parameters(
positions: PyTree, potential_energy_grad: PyTree, current_iteration: int
positions: PyTree, logdensity_grad: PyTree, current_iteration: int
):
"""Compute values for the parameters based on statistics collected from
multiple chains.
Expand All @@ -82,8 +82,8 @@ def compute_parameters(
----------
positions:
A PyTree that contains the current position of every chains.
potential_energy_grad:
A PyTree that contains the gradients of the potential energy
logdensity_grad:
A PyTree that contains the gradients of the logdensity
function evaluated at the current position of every chains.
current_iteration:
The current iteration index in the adaptation process.
Expand All @@ -103,9 +103,8 @@ def compute_parameters(
sd_position,
)

batch_grad = jax.tree_map(lambda x: -x, potential_energy_grad)
batch_grad_scaled = jax.tree_map(
lambda grad, sd: grad * sd, batch_grad, sd_position
lambda grad, sd: grad * sd, logdensity_grad, sd_position
)

epsilon = jnp.minimum(
Expand All @@ -119,14 +118,14 @@ def compute_parameters(
delta = alpha / 2
return epsilon, sd_position, alpha, delta

def init(positions: PyTree, potential_energy_grad: PyTree):
parameters = compute_parameters(positions, potential_energy_grad, 0)
def init(positions: PyTree, logdensity_grad: PyTree):
parameters = compute_parameters(positions, logdensity_grad, 0)
return MEADSAdaptationState(0, *parameters)

def update(
adaptation_state: MEADSAdaptationState,
positions: PyTree,
potential_energy_grad: PyTree,
logdensity_grad: PyTree,
):
"""Update the adaptation state and parameter values.
Expand All @@ -140,9 +139,9 @@ def update(
The current state of the adaptation algorithm
positions
The current position of every chain.
potential_energy_grad
The gradients of the potential energy function
evaluated at the current position of every chain.
logdensity_grad
The gradients of the logdensity function evaluated at the
current position of every chain.
Returns
-------
Expand All @@ -152,7 +151,7 @@ def update(
"""
current_iteration = adaptation_state.current_iteration
step_size, position_sigma, alpha, delta = compute_parameters(
positions, potential_energy_grad, current_iteration
positions, logdensity_grad, current_iteration
)

return MEADSAdaptationState(
Expand Down
19 changes: 7 additions & 12 deletions blackjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,7 @@ class hmc:
Parameters
----------
logdensity_fn
The log-density function we wish to draw samples from. This
is minus the potential function.
The log-density function we wish to draw samples from.
step_size
The value to use for the step size in the symplectic integrator.
inverse_mass_matrix
Expand Down Expand Up @@ -289,8 +288,7 @@ class mala:
Parameters
----------
logdensity_fn
The log-density function we wish to draw samples from. This
is minus the potential function.
The log-density function we wish to draw samples from.
step_size
The value to use for the step size in the symplectic integrator.
Expand Down Expand Up @@ -353,8 +351,7 @@ class nuts:
Parameters
----------
logdensity_fn
The log-density function we wish to draw samples from. This
is minus the potential function.
The log-density function we wish to draw samples from.
step_size
The value to use for the step size in the symplectic integrator.
inverse_mass_matrix
Expand Down Expand Up @@ -861,7 +858,7 @@ def kernel(rng_key, state):
keys = jax.random.split(rng_key, num_chains)
new_states, info = jax.vmap(kernel)(keys, states)
new_adaptation_state = update(
adaptation_state, new_states.position, new_states.potential_energy_grad
adaptation_state, new_states.position, new_states.logdensity_grad
)

return (new_states, new_adaptation_state), (
Expand All @@ -876,7 +873,7 @@ def run(rng_key: PRNGKey, positions: PyTree, num_steps: int = 1000):

rng_keys = jax.random.split(key_init, num_chains)
init_states = batch_init(rng_keys, positions)
init_adaptation_state = init(positions, init_states.potential_energy_grad)
init_adaptation_state = init(positions, init_states.logdensity_grad)

keys = jax.random.split(key_adapt, num_steps)
(last_states, last_adaptation_state), _ = jax.lax.scan(
Expand Down Expand Up @@ -1045,8 +1042,7 @@ class orbital_hmc:
Parameters
----------
logdensity_fn
The logarithm of the probability density function we wish to draw samples from. This
is minus the potential energy function.
The logarithm of the probability density function we wish to draw samples from.
step_size
The value to use for the step size in for the symplectic integrator to buid the orbit.
inverse_mass_matrix
Expand Down Expand Up @@ -1189,8 +1185,7 @@ class ghmc:
Parameters
----------
logdensity_fn
The log-density function we wish to draw samples from. This
is minus the potential function.
The log-density function we wish to draw samples from.
step_size
A PyTree of the same structure as the target PyTree (position) with the
values used for as a step size for each dimension of the target space in
Expand Down
36 changes: 18 additions & 18 deletions blackjax/mcmc/elliptical_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ class EllipSliceState(NamedTuple):
position
Current position of the chain.
loglikelihood
Current value of the log likelihood only.
logdensity
Current value of the logdensity (evaluated at current position).
"""

position: PyTree
loglikelihood: PyTree
logdensity: PyTree


class EllipSliceInfo(NamedTuple):
Expand All @@ -61,9 +61,9 @@ class EllipSliceInfo(NamedTuple):
subiter: int


def init(position: PyTree, loglikelihood_fn: Callable):
loglikelihood = loglikelihood_fn(position)
return EllipSliceState(position, loglikelihood)
def init(position: PyTree, logdensity_fn: Callable):
logdensity = logdensity_fn(position)
return EllipSliceState(position, logdensity)


def kernel(cov_matrix: Array, mean: Array):
Expand Down Expand Up @@ -108,18 +108,18 @@ def momentum_generator(rng_key, position):
def one_step(
rng_key: PRNGKey,
state: EllipSliceState,
loglikelihood_fn: Callable,
logdensity_fn: Callable,
) -> Tuple[EllipSliceState, EllipSliceInfo]:
proposal_generator = elliptical_proposal(
loglikelihood_fn, momentum_generator, mean
logdensity_fn, momentum_generator, mean
)
return proposal_generator(rng_key, state)

return one_step


def elliptical_proposal(
loglikelihood_fn: Callable,
logdensity_fn: Callable,
momentum_generator: Callable,
mean: Array,
) -> Callable:
Expand All @@ -131,7 +131,7 @@ def elliptical_proposal(
Parameters
----------
loglikelihood_fn
logdensity_fn
A function that returns the log-likelihood at a given position.
momentum_generator
A function that generates a new latent momentum variable.
Expand All @@ -147,20 +147,20 @@ def elliptical_proposal(
def generate(
rng_key: PRNGKey, state: EllipSliceState
) -> Tuple[EllipSliceState, EllipSliceInfo]:
position, loglikelihood = state
position, logdensity = state
key_momentum, key_uniform, key_theta = jax.random.split(rng_key, 3)
# step 1: sample momentum
momentum = momentum_generator(key_momentum, position)
# step 2: get slice (y)
logy = loglikelihood + jnp.log(jax.random.uniform(key_uniform))
logy = logdensity + jnp.log(jax.random.uniform(key_uniform))
# step 3: get theta (ellipsis move), set inital interval
theta = 2 * jnp.pi * jax.random.uniform(key_theta)
theta_min = theta - 2 * jnp.pi
theta_max = theta
# step 4: proposal
p, m = ellipsis(position, momentum, theta, mean)
# step 5: acceptance
loglikelihood = loglikelihood_fn(p)
logdensity = logdensity_fn(p)

def slice_fn(vals):
"""Perform slice sampling around the ellipsis.
Expand All @@ -179,19 +179,19 @@ def slice_fn(vals):
rng, thetak = jax.random.split(rng)
theta = jax.random.uniform(thetak, minval=theta_min, maxval=theta_max)
p, m = ellipsis(position, momentum, theta, mean)
loglikelihood = loglikelihood_fn(p)
logdensity = logdensity_fn(p)
theta_min = jnp.where(theta < 0, theta, theta_min)
theta_max = jnp.where(theta > 0, theta, theta_max)
subiter += 1
return rng, loglikelihood, subiter, theta, theta_min, theta_max, p, m
return rng, logdensity, subiter, theta, theta_min, theta_max, p, m

_, loglikelihood, subiter, theta, *_, position, momentum = jax.lax.while_loop(
_, logdensity, subiter, theta, *_, position, momentum = jax.lax.while_loop(
lambda vals: vals[1] <= logy,
slice_fn,
(rng_key, loglikelihood, 1, theta, theta_min, theta_max, p, m),
(rng_key, logdensity, 1, theta, theta_min, theta_max, p, m),
)
return (
EllipSliceState(position, loglikelihood),
EllipSliceState(position, logdensity),
EllipSliceInfo(momentum, theta, subiter),
)

Expand Down
27 changes: 11 additions & 16 deletions blackjax/mcmc/ghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ class GHMCState(NamedTuple):
to perform a non-reversible Metropolis Hastings update, thus we also
store the current slice variable and return its updated version after
each iteration. To make computations more efficient, we also store
the current potential energy as well as the current gradient of the
potential energy.
the current logdensity as well as the current gradient of the
logdensity.
"""

position: PyTree
momentum: PyTree
potential_energy: float
potential_energy_grad: PyTree
logdensity: float
logdensity_grad: PyTree
slice: float


Expand All @@ -53,16 +53,14 @@ def init(
position: PyTree,
logdensity_fn: Callable,
):
def potential_fn(x):
return -logdensity_fn(x)

potential_energy, potential_energy_grad = jax.value_and_grad(potential_fn)(position)
logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)

key_mometum, key_slice = jax.random.split(rng_key)
momentum = generate_gaussian_noise(key_mometum, position)
slice = jax.random.uniform(key_slice, minval=-1.0, maxval=1.0)

return GHMCState(position, momentum, potential_energy, potential_energy_grad, slice)
return GHMCState(position, momentum, logdensity, logdensity_grad, slice)


def kernel(
Expand Down Expand Up @@ -132,14 +130,11 @@ def one_step(
"""

def potential_fn(x):
return -logdensity_fn(x)

flat_inverse_scale = jax.flatten_util.ravel_pytree(momentum_inverse_scale)[0]
_, kinetic_energy_fn, _ = metrics.gaussian_euclidean(flat_inverse_scale**2)

symplectic_integrator = integrators.velocity_verlet(
potential_fn, kinetic_energy_fn
logdensity_fn, kinetic_energy_fn
)
proposal_generator = hmc.hmc_proposal(
symplectic_integrator,
Expand All @@ -150,23 +145,23 @@ def potential_fn(x):
)

key_momentum, key_noise = jax.random.split(rng_key)
position, momentum, potential_energy, potential_energy_grad, slice = state
position, momentum, logdensity, logdensity_grad, slice = state
# New momentum is persistent
momentum = update_momentum(key_momentum, state, alpha)
momentum = jax.tree_map(lambda m, s: m / s, momentum, momentum_inverse_scale)
# Slice is non-reversible
slice = ((slice + 1.0 + delta + noise_fn(key_noise)) % 2) - 1.0

integrator_state = integrators.IntegratorState(
position, momentum, potential_energy, potential_energy_grad
position, momentum, logdensity, logdensity_grad
)
proposal, info = proposal_generator(slice, integrator_state)
proposal = hmc.flip_momentum(proposal)
state = GHMCState(
proposal.position,
jax.tree_map(lambda m, s: m * s, proposal.momentum, momentum_inverse_scale),
proposal.potential_energy,
proposal.potential_energy_grad,
proposal.logdensity,
proposal.logdensity_grad,
info.acceptance_rate,
)

Expand Down
Loading

0 comments on commit feb810f

Please sign in to comment.