Skip to content

Commit

Permalink
fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Oct 29, 2023
1 parent 9b4193a commit 48acc23
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 23 deletions.
22 changes: 9 additions & 13 deletions blackjax/mcmc/proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ class Proposal(NamedTuple):
sum_log_p_accept: float


def proposal_generator(
energy: Callable
) -> tuple[Callable, Callable]:
def proposal_generator(energy: Callable) -> tuple[Callable, Callable]:
"""
Parameters
Expand All @@ -59,7 +57,7 @@ def proposal_generator(
def new(state: TrajectoryState) -> Proposal:
return Proposal(state, energy(state), 0.0, -jnp.inf)

def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, bool]:
def update(initial_energy: float, state: TrajectoryState) -> Proposal:
"""Generate a new proposal from a trajectory state.
The trajectory state records information about the position in the state
Expand All @@ -81,9 +79,7 @@ def update(initial_energy: float, state: TrajectoryState) -> tuple[Proposal, boo
"""
new_energy = energy(state)
return proposal_from_energy_diff(
initial_energy, new_energy, state
)
return proposal_from_energy_diff(initial_energy, new_energy, state)

return new, update

Expand All @@ -92,7 +88,7 @@ def proposal_from_energy_diff(
initial_energy: float,
new_energy: float,
state: TrajectoryState,
) -> tuple[Proposal, bool]:
) -> Proposal:
"""Computes a new proposal from the energy difference between two states.
Parameters
Expand All @@ -118,11 +114,11 @@ def proposal_from_energy_diff(
sum_log_p_accept = jnp.minimum(delta_energy, 0.0)

return Proposal(
state,
new_energy,
weight,
sum_log_p_accept,
)
state,
new_energy,
weight,
sum_log_p_accept,
)


def asymmetric_proposal_generator(
Expand Down
1 change: 0 additions & 1 deletion blackjax/mcmc/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
from typing import Callable, NamedTuple, Optional

import jax
import numpy as np
from jax import numpy as jnp

from blackjax.base import SamplingAlgorithm
Expand Down
8 changes: 2 additions & 6 deletions blackjax/mcmc/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,7 @@ def dynamic_progressive_integration(
which we say a transition is divergent.
"""
_, generate_proposal = proposal_generator(
hmc_energy(kinetic_energy)
)
_, generate_proposal = proposal_generator(hmc_energy(kinetic_energy))
sample_proposal = progressive_uniform_sampling

def integrate(
Expand Down Expand Up @@ -320,9 +318,7 @@ def dynamic_recursive_integration(
Bool to indicate whether to perform additional U turn check between two trajectory.
"""
_, generate_proposal = proposal_generator(
hmc_energy(kinetic_energy)
)
_, generate_proposal = proposal_generator(hmc_energy(kinetic_energy))
sample_proposal = progressive_uniform_sampling

def buildtree_integrate(
Expand Down
4 changes: 1 addition & 3 deletions tests/mcmc/test_proposal_without_chex.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def proposal_factory(prev_energy, new_energy, new_state):
assert new_state == 50
return new_proposal

_, update = asymmetric_proposal_generator(
transition_energy, proposal_factory
)
_, update = asymmetric_proposal_generator(transition_energy, proposal_factory)
proposed = update(30, 50)
assert proposed == new_proposal

Expand Down

0 comments on commit 48acc23

Please sign in to comment.