Change Typing to follow Jax best practice #543

Merged · 4 commits · Jun 12, 2023
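For context, the convention applied throughout this diff mirrors JAX's typing guidance: function inputs take the permissive `ArrayLike` / `ArrayLikeTree` annotations, while outputs use the strict `Array` / `ArrayTree`. A minimal sketch of the pattern follows; the alias definitions below are an approximation, the real ones live in `blackjax/types.py`.

```python
from typing import Any, Iterable, Mapping, Union

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

# Approximate aliases: a pytree whose leaves are anything convertible to an
# array (inputs) vs. a pytree whose leaves are jax.Array (outputs).
Array = jax.Array
ArrayLikeTree = Union[ArrayLike, Iterable["ArrayLikeTree"], Mapping[Any, "ArrayLikeTree"]]
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[Any, "ArrayTree"]]


def scale_position(position: ArrayLikeTree, factor: float) -> ArrayTree:
    """Accept any array-like pytree, return a pytree of jax.Array leaves."""
    return jax.tree_util.tree_map(lambda leaf: factor * jnp.asarray(leaf), position)
```

This lets callers pass plain Python scalars, lists, or NumPy arrays, while type checkers still see every returned leaf as a `jax.Array`.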
4 changes: 2 additions & 2 deletions blackjax/adaptation/base.py
@@ -13,11 +13,11 @@
# limitations under the License.
from typing import NamedTuple

from blackjax.types import PyTree
from blackjax.types import ArrayTree


class AdaptationResults(NamedTuple):
state: PyTree
state: ArrayTree
parameters: dict


12 changes: 7 additions & 5 deletions blackjax/adaptation/mass_matrix.py
@@ -23,7 +23,7 @@
import jax
import jax.numpy as jnp

from blackjax.types import Array
from blackjax.types import Array, ArrayLike

__all__ = [
"WelfordAlgorithmState",
@@ -111,7 +111,7 @@ def init(n_dims: int) -> MassMatrixAdaptationState:
return MassMatrixAdaptationState(inverse_mass_matrix, wc_state)

def update(
mm_state: MassMatrixAdaptationState, position: Array
mm_state: MassMatrixAdaptationState, position: ArrayLike
) -> MassMatrixAdaptationState:
"""Update the algorithm's state.

@@ -203,14 +203,16 @@ def init(n_dims: int) -> WelfordAlgorithmState:
m2 = jnp.zeros((n_dims, n_dims))
return WelfordAlgorithmState(mean, m2, sample_size)

def update(wa_state: WelfordAlgorithmState, value: Array) -> WelfordAlgorithmState:
def update(
wa_state: WelfordAlgorithmState, value: ArrayLike
) -> WelfordAlgorithmState:
"""Update the M2 matrix using the new value.

Parameters
----------
state:
wa_state:
The current state of the Welford Algorithm
position: Array, shape (1,)
value: Array, shape (1,)
The new sample (typically position of the chain) used to update m2

"""
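The `WelfordAlgorithmState` updated above follows Welford's online scheme for the running mean and the matrix of summed outer products (`m2`). A standalone sketch of that update, with illustrative names rather than blackjax's exact code:

```python
import jax.numpy as jnp


def welford_update(mean, m2, sample_size, value):
    """One Welford step for a new sample `value` of shape (n_dims,)."""
    sample_size = sample_size + 1
    delta = value - mean
    mean = mean + delta / sample_size
    delta2 = value - mean
    # The covariance estimate is m2 / (sample_size - 1) once adaptation ends.
    m2 = m2 + jnp.outer(delta, delta2)
    return mean, m2, sample_size
```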
20 changes: 11 additions & 9 deletions blackjax/adaptation/meads_adaptation.py
@@ -19,7 +19,7 @@
import blackjax.mcmc as mcmc
from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
from blackjax.base import AdaptationAlgorithm
from blackjax.types import PRNGKey, PyTree
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

__all__ = ["MEADSAdaptationState", "base", "maximum_eigenvalue", "meads_adaptation"]

@@ -42,7 +42,7 @@ class MEADSAdaptationState(NamedTuple):

current_iteration: int
step_size: float
position_sigma: PyTree
position_sigma: ArrayTree
alpha: float
delta: float

@@ -72,7 +72,7 @@ def base():
"""

def compute_parameters(
positions: PyTree, logdensity_grad: PyTree, current_iteration: int
positions: ArrayLikeTree, logdensity_grad: ArrayLikeTree, current_iteration: int
):
"""Compute values for the parameters based on statistics collected from
multiple chains.
@@ -117,15 +117,17 @@ def compute_parameters(
delta = alpha / 2
return epsilon, sd_position, alpha, delta

def init(positions: PyTree, logdensity_grad: PyTree):
def init(
positions: ArrayLikeTree, logdensity_grad: ArrayLikeTree
) -> MEADSAdaptationState:
parameters = compute_parameters(positions, logdensity_grad, 0)
return MEADSAdaptationState(0, *parameters)

def update(
adaptation_state: MEADSAdaptationState,
positions: PyTree,
logdensity_grad: PyTree,
):
positions: ArrayLikeTree,
logdensity_grad: ArrayLikeTree,
) -> MEADSAdaptationState:
"""Update the adaptation state and parameter values.

We find new optimal values for the parameters of the generalized HMC
@@ -231,7 +233,7 @@ def one_step(carry, rng_key):
new_adaptation_state,
)

def run(rng_key: PRNGKey, positions: PyTree, num_steps: int = 1000):
def run(rng_key: PRNGKey, positions: ArrayLikeTree, num_steps: int = 1000):
key_init, key_adapt = jax.random.split(rng_key)

rng_keys = jax.random.split(key_init, num_chains)
@@ -255,7 +257,7 @@ def run(rng_key: PRNGKey, positions: PyTree, num_steps: int = 1000):
return AdaptationAlgorithm(run) # type: ignore[arg-type]


def maximum_eigenvalue(matrix: PyTree):
def maximum_eigenvalue(matrix: ArrayLikeTree) -> Array:
"""Estimate the largest eigenvalues of a matrix.

We calculate an unbiased estimate of the ratio between the sum of the
6 changes: 3 additions & 3 deletions blackjax/adaptation/pathfinder_adaptation.py
@@ -26,7 +26,7 @@
)
from blackjax.base import AdaptationAlgorithm
from blackjax.optimizers.lbfgs import lbfgs_inverse_hessian_formula_1
from blackjax.types import Array, PRNGKey, PyTree
from blackjax.types import Array, ArrayLikeTree, PRNGKey

__all__ = ["PathfinderAdaptationState", "base", "pathfinder_adaptation"]

@@ -99,7 +99,7 @@ def init(

def update(
adaptation_state: PathfinderAdaptationState,
position: PyTree,
position: ArrayLikeTree,
acceptance_rate: float,
) -> PathfinderAdaptationState:
"""Update the adaptation state and parameter values.
@@ -192,7 +192,7 @@ def one_step(carry, rng_key):
AdaptationInfo(new_state, info, new_adaptation_state),
)

def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):
def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 400):
init_key, sample_key, rng_key = jax.random.split(rng_key, 3)

pathfinder_state, _ = vi.pathfinder.approximate(
12 changes: 6 additions & 6 deletions blackjax/adaptation/step_size.py
@@ -19,6 +19,7 @@

from blackjax.mcmc.hmc import HMCState
from blackjax.optimizers.dual_averaging import dual_averaging
from blackjax.types import PRNGKey

__all__ = [
"DualAveragingAdaptationState",
@@ -30,7 +31,6 @@
# -------------------------------------------------------------------
# DUAL AVERAGING
# -------------------------------------------------------------------
from blackjax.types import PRNGKey


class DualAveragingAdaptationState(NamedTuple):
@@ -99,7 +99,7 @@ def dual_averaging_adaptation(
gamma:
Controls the speed of convergence of the scheme. The authors of :cite:p:`hoffman2014no` recommend
a value of 0.05.
kappa: float in ]0.5, 1]
kappa: float in [0.5, 1]
Controls the weights of past steps in the current update. The scheme will
quickly forget earlier step for a small value of `kappa`. Introduced
in :cite:p:`hoffman2014no`, with a recommended value of .75
@@ -131,10 +131,10 @@ def update(

Parameters
----------
p_accept: float in [0, 1]
The current metropolis acceptance rate.
state:
da_state:
The current state of the dual averaging algorithm.
acceptance_rate: float in [0, 1]
The current metropolis acceptance rate.

Returns
-------
@@ -241,7 +241,7 @@ def do_continue(rss_state: ReasonableStepSizeState) -> bool:
)
return is_step_size_not_extreme & has_acceptance_rate_not_crossed_threshold

def update(rss_state: ReasonableStepSizeState) -> Tuple:
def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState:
"""Perform one step of the step size search."""
rng_key, direction, _, step_size = rss_state
rng_key, subkey = jax.random.split(rng_key)
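The quantities `gamma`, `kappa`, `t0` and `mu` documented above parametrize the Nesterov dual-averaging scheme of Hoffman & Gelman (2014). A sketch of one update, assuming a flattened state (blackjax bundles these fields in `DualAveragingAdaptationState`):

```python
import jax.numpy as jnp


def da_step(log_step_size_avg, gradient_avg, step, acceptance_rate,
            target=0.8, t0=10, gamma=0.05, kappa=0.75, mu=0.0):
    step = step + 1
    # Running average of the gap between the target and the observed acceptance rate.
    eta = 1.0 / (step + t0)
    gradient_avg = (1 - eta) * gradient_avg + eta * (target - acceptance_rate)
    # New log step size, shrunk towards mu, and its Polyak-style average
    # weighted by step**(-kappa) as described in the docstring above.
    log_step_size = mu - jnp.sqrt(step) / gamma * gradient_avg
    weight = step ** (-kappa)
    log_step_size_avg = weight * log_step_size + (1 - weight) * log_step_size_avg
    return log_step_size, log_step_size_avg, gradient_avg, step
```

At the end of warmup it is the averaged value, `exp(log_step_size_avg)`, that is kept as the sampling step size.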
16 changes: 10 additions & 6 deletions blackjax/adaptation/window_adaptation.py
@@ -29,7 +29,7 @@
)
from blackjax.base import AdaptationAlgorithm
from blackjax.progress_bar import progress_bar_scan
from blackjax.types import Array, PRNGKey, PyTree
from blackjax.types import Array, ArrayLikeTree, PRNGKey
from blackjax.util import pytree_size

__all__ = ["WindowAdaptationState", "base", "build_schedule", "window_adaptation"]
@@ -102,7 +102,9 @@ def base(
mm_init, mm_update, mm_final = mass_matrix_adaptation(is_mass_matrix_diagonal)
da_init, da_update, da_final = dual_averaging_adaptation(target_acceptance_rate)

def init(position: PyTree, initial_step_size: float) -> Tuple:
def init(
position: ArrayLikeTree, initial_step_size: float
) -> WindowAdaptationState:
"""Initialze the adaptation state and parameter values.

Unlike the original Stan window adaptation we do not use the
@@ -123,7 +125,7 @@ def init(position: PyTree, initial_step_size: float) -> Tuple:
)

def fast_update(
position: PyTree,
position: ArrayLikeTree,
acceptance_rate: float,
warmup_state: WindowAdaptationState,
) -> WindowAdaptationState:
@@ -134,6 +136,8 @@ def fast_update(
compared to the covariance estimation with Welford's algorithm

"""
del position

new_ss_state = da_update(warmup_state.ss_state, acceptance_rate)
new_step_size = jnp.exp(new_ss_state.log_step_size)

@@ -145,7 +149,7 @@ def slow_update(
)

def slow_update(
position: PyTree,
position: ArrayLikeTree,
acceptance_rate: float,
warmup_state: WindowAdaptationState,
) -> WindowAdaptationState:
@@ -188,7 +192,7 @@ def slow_final(warmup_state: WindowAdaptationState) -> WindowAdaptationState:
def update(
adaptation_state: WindowAdaptationState,
adaptation_stage: Tuple,
position: PyTree,
position: ArrayLikeTree,
acceptance_rate: float,
) -> WindowAdaptationState:
"""Update the adaptation state and parameter values.
@@ -316,7 +320,7 @@ def one_step(carry, xs):
AdaptationInfo(new_state, info, new_adaptation_state),
)

def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 1000):
def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
init_state = algorithm.init(position, logdensity_fn)
init_adaptation_state = adapt_init(position, initial_step_size)

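For reference, a typical call site for the adaptation touched in this file; the exact return structure of `run` (assumed here to unpack as `AdaptationResults` plus per-step info) may differ slightly across versions:

```python
import jax
import jax.numpy as jnp

import blackjax


def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)


rng_key = jax.random.PRNGKey(0)
initial_position = jnp.ones(10)

warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(last_state, parameters), _ = warmup.run(rng_key, initial_position, num_steps=1000)

# The adapted step size and inverse mass matrix feed directly into the kernel.
nuts = blackjax.nuts(logdensity_fn, **parameters)
```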
6 changes: 3 additions & 3 deletions blackjax/base.py
@@ -14,9 +14,9 @@

from typing_extensions import Protocol

from .types import PRNGKey, PyTree
from .types import ArrayLikeTree, PRNGKey

Position = PyTree
Position = ArrayLikeTree
State = NamedTuple
Info = NamedTuple

@@ -139,7 +139,7 @@ class RunFn(Protocol):
class RunFn(Protocol):
"""A `Callable` used to run the adaptation procedure."""

def __call__(self, rng_key: PRNGKey, position: PyTree):
def __call__(self, rng_key: PRNGKey, position: ArrayLikeTree):
"""Run the compiled algorithm."""


10 changes: 5 additions & 5 deletions blackjax/diagnostics.py
@@ -17,14 +17,14 @@
import numpy as np
from scipy.fftpack import next_fast_len # type: ignore

from blackjax.types import Array
from blackjax.types import Array, ArrayLike

__all__ = ["potential_scale_reduction", "effective_sample_size"]


def potential_scale_reduction(
input_array: Array, chain_axis: int = 0, sample_axis: int = 1
):
input_array: ArrayLike, chain_axis: int = 0, sample_axis: int = 1
) -> Array:
"""Gelman and Rubin (1992)'s potential scale reduction for computing multiple MCMC chain convergence.

Parameters
@@ -76,8 +76,8 @@ def potential_scale_reduction(


def effective_sample_size(
input_array: Array, chain_axis: int = 0, sample_axis: int = 1
):
input_array: ArrayLike, chain_axis: int = 0, sample_axis: int = 1
) -> Array:
"""Compute estimate of the effective sample size (ess).

Parameters
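The `potential_scale_reduction` function annotated above implements Gelman and Rubin's R-hat. As a reminder of the formula, a sketch for a single scalar parameter with input of shape `(n_chains, n_samples)`; the helper name is illustrative:

```python
import jax.numpy as jnp


def rhat_scalar(samples):
    """R-hat for one parameter, samples shaped (n_chains, n_samples)."""
    n_chains, n_samples = samples.shape
    chain_means = samples.mean(axis=1)
    within = samples.var(axis=1, ddof=1).mean()    # W: mean within-chain variance
    between = n_samples * chain_means.var(ddof=1)  # B: between-chain variance
    var_estimate = (n_samples - 1) / n_samples * within + between / n_samples
    return jnp.sqrt(var_estimate / within)
```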
6 changes: 3 additions & 3 deletions blackjax/mcmc/diffusions.py
@@ -17,16 +17,16 @@
import jax
import jax.numpy as jnp

from blackjax.types import PyTree
from blackjax.types import ArrayTree
from blackjax.util import generate_gaussian_noise

__all__ = ["overdamped_langevin"]


class DiffusionState(NamedTuple):
position: PyTree
position: ArrayTree
logdensity: float
logdensity_grad: PyTree
logdensity_grad: ArrayTree


def overdamped_langevin(logdensity_grad_fn):
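`overdamped_langevin` integrates the overdamped Langevin SDE with an Euler-Maruyama step. A single-leaf sketch of that step (the module itself works on whole pytrees via `generate_gaussian_noise`):

```python
import jax
import jax.numpy as jnp


def langevin_step(rng_key, position, logdensity_grad, step_size):
    # x' = x + eps * grad(log pi)(x) + sqrt(2 * eps) * xi,  with xi ~ N(0, I)
    noise = jax.random.normal(rng_key, jnp.shape(position))
    return position + step_size * logdensity_grad + jnp.sqrt(2.0 * step_size) * noise
```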
12 changes: 6 additions & 6 deletions blackjax/mcmc/elliptical_slice.py
@@ -18,7 +18,7 @@
import jax.numpy as jnp

from blackjax.base import MCMCSamplingAlgorithm
from blackjax.types import Array, PRNGKey, PyTree
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.util import generate_gaussian_noise

__all__ = [
@@ -40,8 +40,8 @@ class EllipSliceState(NamedTuple):

"""

position: PyTree
logdensity: PyTree
position: ArrayTree
logdensity: ArrayTree


class EllipSliceInfo(NamedTuple):
@@ -63,12 +63,12 @@ class EllipSliceInfo(NamedTuple):

"""

momentum: PyTree
momentum: ArrayTree
theta: float
subiter: int


def init(position: PyTree, logdensity_fn: Callable):
def init(position: ArrayLikeTree, logdensity_fn: Callable):
logdensity = logdensity_fn(position)
return EllipSliceState(position, logdensity)

@@ -164,7 +164,7 @@ def __new__( # type: ignore[misc]
) -> MCMCSamplingAlgorithm:
kernel = cls.build_kernel(cov, mean)

def init_fn(position: PyTree):
def init_fn(position: ArrayLikeTree):
return cls.init(position, loglikelihood_fn)

def step_fn(rng_key: PRNGKey, state):
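To make the `EllipSliceState` and `EllipSliceInfo` fields above concrete, here is a plain-Python sketch of one elliptical slice move (Murray, Adams & MacKay, 2010) with a zero-mean Gaussian prior; it uses a Python `while` loop, so it is illustrative rather than jit-compatible like the kernel in this module:

```python
import jax
import jax.numpy as jnp


def elliptical_slice_step(rng_key, position, loglikelihood_fn, cov):
    key_momentum, key_height, key_theta = jax.random.split(rng_key, 3)
    # Auxiliary draw (the "momentum" in EllipSliceInfo) defining the ellipse.
    momentum = jax.random.multivariate_normal(
        key_momentum, jnp.zeros(position.shape[0]), cov
    )
    # Slice height and initial bracket on the ellipse angle theta.
    log_height = loglikelihood_fn(position) + jnp.log(jax.random.uniform(key_height))
    theta = jax.random.uniform(key_theta, minval=0.0, maxval=2 * jnp.pi)
    theta_min, theta_max = theta - 2 * jnp.pi, theta
    subiter = 0
    while True:
        subiter += 1
        proposal = position * jnp.cos(theta) + momentum * jnp.sin(theta)
        if loglikelihood_fn(proposal) > log_height:
            return proposal, theta, subiter
        # Shrink the bracket towards theta = 0 and resample the angle.
        if theta < 0:
            theta_min = theta
        else:
            theta_max = theta
        rng_key, subkey = jax.random.split(rng_key)
        theta = jax.random.uniform(subkey, minval=theta_min, maxval=theta_max)
```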