diff --git a/blackjax/adaptation/base.py b/blackjax/adaptation/base.py index 8a612936b..e0a01e596 100644 --- a/blackjax/adaptation/base.py +++ b/blackjax/adaptation/base.py @@ -13,11 +13,11 @@ # limitations under the License. from typing import NamedTuple -from blackjax.types import PyTree +from blackjax.types import ArrayTree class AdaptationResults(NamedTuple): - state: PyTree + state: ArrayTree parameters: dict diff --git a/blackjax/adaptation/mass_matrix.py b/blackjax/adaptation/mass_matrix.py index 30ea9ec0d..4cd84492e 100644 --- a/blackjax/adaptation/mass_matrix.py +++ b/blackjax/adaptation/mass_matrix.py @@ -23,7 +23,7 @@ import jax import jax.numpy as jnp -from blackjax.types import Array +from blackjax.types import Array, ArrayLike __all__ = [ "WelfordAlgorithmState", @@ -111,7 +111,7 @@ def init(n_dims: int) -> MassMatrixAdaptationState: return MassMatrixAdaptationState(inverse_mass_matrix, wc_state) def update( - mm_state: MassMatrixAdaptationState, position: Array + mm_state: MassMatrixAdaptationState, position: ArrayLike ) -> MassMatrixAdaptationState: """Update the algorithm's state. @@ -203,14 +203,16 @@ def init(n_dims: int) -> WelfordAlgorithmState: m2 = jnp.zeros((n_dims, n_dims)) return WelfordAlgorithmState(mean, m2, sample_size) - def update(wa_state: WelfordAlgorithmState, value: Array) -> WelfordAlgorithmState: + def update( + wa_state: WelfordAlgorithmState, value: ArrayLike + ) -> WelfordAlgorithmState: """Update the M2 matrix using the new value. Parameters ---------- - state: + wa_state: The current state of the Welford Algorithm - position: Array, shape (1,) + value: Array, shape (1,) The new sample (typically position of the chain) used to update m2 """ diff --git a/blackjax/adaptation/meads_adaptation.py b/blackjax/adaptation/meads_adaptation.py index f69c63557..e50065710 100644 --- a/blackjax/adaptation/meads_adaptation.py +++ b/blackjax/adaptation/meads_adaptation.py @@ -19,7 +19,7 @@ import blackjax.mcmc as mcmc from blackjax.adaptation.base import AdaptationInfo, AdaptationResults from blackjax.base import AdaptationAlgorithm -from blackjax.types import PRNGKey, PyTree +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["MEADSAdaptationState", "base", "maximum_eigenvalue", "meads_adaptation"] @@ -42,7 +42,7 @@ class MEADSAdaptationState(NamedTuple): current_iteration: int step_size: float - position_sigma: PyTree + position_sigma: ArrayTree alpha: float delta: float @@ -72,7 +72,7 @@ def base(): """ def compute_parameters( - positions: PyTree, logdensity_grad: PyTree, current_iteration: int + positions: ArrayLikeTree, logdensity_grad: ArrayLikeTree, current_iteration: int ): """Compute values for the parameters based on statistics collected from multiple chains. @@ -117,15 +117,17 @@ def compute_parameters( delta = alpha / 2 return epsilon, sd_position, alpha, delta - def init(positions: PyTree, logdensity_grad: PyTree): + def init( + positions: ArrayLikeTree, logdensity_grad: ArrayLikeTree + ) -> MEADSAdaptationState: parameters = compute_parameters(positions, logdensity_grad, 0) return MEADSAdaptationState(0, *parameters) def update( adaptation_state: MEADSAdaptationState, - positions: PyTree, - logdensity_grad: PyTree, - ): + positions: ArrayLikeTree, + logdensity_grad: ArrayLikeTree, + ) -> MEADSAdaptationState: """Update the adaptation state and parameter values. 
        We find new optimal values for the parameters of the generalized HMC
@@ -231,7 +233,7 @@ def one_step(carry, rng_key):
             new_adaptation_state,
         )
 
-    def run(rng_key: PRNGKey, positions: PyTree, num_steps: int = 1000):
+    def run(rng_key: PRNGKey, positions: ArrayLikeTree, num_steps: int = 1000):
         key_init, key_adapt = jax.random.split(rng_key)
 
         rng_keys = jax.random.split(key_init, num_chains)
@@ -255,7 +257,7 @@ def run(rng_key: PRNGKey, positions: PyTree, num_steps: int = 1000):
     return AdaptationAlgorithm(run)  # type: ignore[arg-type]
 
 
-def maximum_eigenvalue(matrix: PyTree):
+def maximum_eigenvalue(matrix: ArrayLikeTree) -> Array:
     """Estimate the largest eigenvalues of a matrix.
 
     We calculate an unbiased estimate of the ratio between the sum of the
diff --git a/blackjax/adaptation/pathfinder_adaptation.py b/blackjax/adaptation/pathfinder_adaptation.py
index eddbb1633..3d05bc1d1 100644
--- a/blackjax/adaptation/pathfinder_adaptation.py
+++ b/blackjax/adaptation/pathfinder_adaptation.py
@@ -26,7 +26,7 @@
 )
 from blackjax.base import AdaptationAlgorithm
 from blackjax.optimizers.lbfgs import lbfgs_inverse_hessian_formula_1
-from blackjax.types import Array, PRNGKey, PyTree
+from blackjax.types import Array, ArrayLikeTree, PRNGKey
 
 __all__ = ["PathfinderAdaptationState", "base", "pathfinder_adaptation"]
 
@@ -99,7 +99,7 @@ def init(
 
     def update(
         adaptation_state: PathfinderAdaptationState,
-        position: PyTree,
+        position: ArrayLikeTree,
         acceptance_rate: float,
     ) -> PathfinderAdaptationState:
         """Update the adaptation state and parameter values.
@@ -192,7 +192,7 @@ def one_step(carry, rng_key):
             AdaptationInfo(new_state, info, new_adaptation_state),
         )
 
-    def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):
+    def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 400):
         init_key, sample_key, rng_key = jax.random.split(rng_key, 3)
 
         pathfinder_state, _ = vi.pathfinder.approximate(
diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py
index a3acdc470..298f702f6 100644
--- a/blackjax/adaptation/step_size.py
+++ b/blackjax/adaptation/step_size.py
@@ -19,6 +19,7 @@
 from blackjax.mcmc.hmc import HMCState
 from blackjax.optimizers.dual_averaging import dual_averaging
+from blackjax.types import PRNGKey
 
 __all__ = [
     "DualAveragingAdaptationState",
@@ -30,7 +31,6 @@
 # -------------------------------------------------------------------
 #                        DUAL AVERAGING
 # -------------------------------------------------------------------
-from blackjax.types import PRNGKey
 
 
 class DualAveragingAdaptationState(NamedTuple):
@@ -99,7 +99,7 @@ def dual_averaging_adaptation(
     gamma:
         Controls the speed of convergence of the scheme. The authors of
         :cite:p:`hoffman2014no` recommend a value of 0.05.
-    kappa: float in ]0.5, 1]
+    kappa: float in (0.5, 1]
         Controls the weights of past steps in the current update. The scheme
         will quickly forget earlier step for a small value of `kappa`. Introduced
         in :cite:p:`hoffman2014no`, with a recommended value of .75
@@ -131,10 +131,10 @@ def update(
 
         Parameters
         ----------
-        p_accept: float in [0, 1]
-            The current metropolis acceptance rate.
-        state:
+        da_state:
             The current state of the dual averaging algorithm.
+        acceptance_rate: float in [0, 1]
+            The current metropolis acceptance rate.
Returns ------- @@ -241,7 +241,7 @@ def do_continue(rss_state: ReasonableStepSizeState) -> bool: ) return is_step_size_not_extreme & has_acceptance_rate_not_crossed_threshold - def update(rss_state: ReasonableStepSizeState) -> Tuple: + def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState: """Perform one step of the step size search.""" rng_key, direction, _, step_size = rss_state rng_key, subkey = jax.random.split(rng_key) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index 392176895..b99d787f1 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -29,7 +29,7 @@ ) from blackjax.base import AdaptationAlgorithm from blackjax.progress_bar import progress_bar_scan -from blackjax.types import Array, PRNGKey, PyTree +from blackjax.types import Array, ArrayLikeTree, PRNGKey from blackjax.util import pytree_size __all__ = ["WindowAdaptationState", "base", "build_schedule", "window_adaptation"] @@ -102,7 +102,9 @@ def base( mm_init, mm_update, mm_final = mass_matrix_adaptation(is_mass_matrix_diagonal) da_init, da_update, da_final = dual_averaging_adaptation(target_acceptance_rate) - def init(position: PyTree, initial_step_size: float) -> Tuple: + def init( + position: ArrayLikeTree, initial_step_size: float + ) -> WindowAdaptationState: """Initialze the adaptation state and parameter values. Unlike the original Stan window adaptation we do not use the @@ -123,7 +125,7 @@ def init(position: PyTree, initial_step_size: float) -> Tuple: ) def fast_update( - position: PyTree, + position: ArrayLikeTree, acceptance_rate: float, warmup_state: WindowAdaptationState, ) -> WindowAdaptationState: @@ -134,6 +136,8 @@ def fast_update( compared to the covariance estimation with Welford's algorithm """ + del position + new_ss_state = da_update(warmup_state.ss_state, acceptance_rate) new_step_size = jnp.exp(new_ss_state.log_step_size) @@ -145,7 +149,7 @@ def fast_update( ) def slow_update( - position: PyTree, + position: ArrayLikeTree, acceptance_rate: float, warmup_state: WindowAdaptationState, ) -> WindowAdaptationState: @@ -188,7 +192,7 @@ def slow_final(warmup_state: WindowAdaptationState) -> WindowAdaptationState: def update( adaptation_state: WindowAdaptationState, adaptation_stage: Tuple, - position: PyTree, + position: ArrayLikeTree, acceptance_rate: float, ) -> WindowAdaptationState: """Update the adaptation state and parameter values. 
@@ -316,7 +320,7 @@ def one_step(carry, xs): AdaptationInfo(new_state, info, new_adaptation_state), ) - def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 1000): + def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): init_state = algorithm.init(position, logdensity_fn) init_adaptation_state = adapt_init(position, initial_step_size) diff --git a/blackjax/base.py b/blackjax/base.py index 0a64c668d..922746423 100644 --- a/blackjax/base.py +++ b/blackjax/base.py @@ -14,9 +14,9 @@ from typing_extensions import Protocol -from .types import PRNGKey, PyTree +from .types import ArrayLikeTree, PRNGKey -Position = PyTree +Position = ArrayLikeTree State = NamedTuple Info = NamedTuple @@ -139,7 +139,7 @@ class VIAlgorithm(NamedTuple): class RunFn(Protocol): """A `Callable` used to run the adaptation procedure.""" - def __call__(self, rng_key: PRNGKey, position: PyTree): + def __call__(self, rng_key: PRNGKey, position: ArrayLikeTree): """Run the compiled algorithm.""" diff --git a/blackjax/diagnostics.py b/blackjax/diagnostics.py index a491b8a58..da861d9b1 100644 --- a/blackjax/diagnostics.py +++ b/blackjax/diagnostics.py @@ -17,14 +17,14 @@ import numpy as np from scipy.fftpack import next_fast_len # type: ignore -from blackjax.types import Array +from blackjax.types import Array, ArrayLike __all__ = ["potential_scale_reduction", "effective_sample_size"] def potential_scale_reduction( - input_array: Array, chain_axis: int = 0, sample_axis: int = 1 -): + input_array: ArrayLike, chain_axis: int = 0, sample_axis: int = 1 +) -> Array: """Gelman and Rubin (1992)'s potential scale reduction for computing multiple MCMC chain convergence. Parameters @@ -76,8 +76,8 @@ def potential_scale_reduction( def effective_sample_size( - input_array: Array, chain_axis: int = 0, sample_axis: int = 1 -): + input_array: ArrayLike, chain_axis: int = 0, sample_axis: int = 1 +) -> Array: """Compute estimate of the effective sample size (ess). 
     Parameters
diff --git a/blackjax/mcmc/diffusions.py b/blackjax/mcmc/diffusions.py
index 82a0e84f4..82274ab95 100644
--- a/blackjax/mcmc/diffusions.py
+++ b/blackjax/mcmc/diffusions.py
@@ -17,16 +17,16 @@
 import jax
 import jax.numpy as jnp
 
-from blackjax.types import PyTree
+from blackjax.types import ArrayTree
 from blackjax.util import generate_gaussian_noise
 
 __all__ = ["overdamped_langevin"]
 
 
 class DiffusionState(NamedTuple):
-    position: PyTree
+    position: ArrayTree
     logdensity: float
-    logdensity_grad: PyTree
+    logdensity_grad: ArrayTree
 
 
 def overdamped_langevin(logdensity_grad_fn):
diff --git a/blackjax/mcmc/elliptical_slice.py b/blackjax/mcmc/elliptical_slice.py
index ecd3c9804..98ae74d4b 100644
--- a/blackjax/mcmc/elliptical_slice.py
+++ b/blackjax/mcmc/elliptical_slice.py
@@ -18,7 +18,7 @@
 import jax.numpy as jnp
 
 from blackjax.base import MCMCSamplingAlgorithm
-from blackjax.types import Array, PRNGKey, PyTree
+from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
 from blackjax.util import generate_gaussian_noise
 
 __all__ = [
@@ -40,8 +40,8 @@ class EllipSliceState(NamedTuple):
 
     """
 
-    position: PyTree
-    logdensity: PyTree
+    position: ArrayTree
+    logdensity: ArrayTree
 
 
 class EllipSliceInfo(NamedTuple):
@@ -63,12 +63,12 @@ class EllipSliceInfo(NamedTuple):
 
     """
 
-    momentum: PyTree
+    momentum: ArrayTree
     theta: float
     subiter: int
 
 
-def init(position: PyTree, logdensity_fn: Callable):
+def init(position: ArrayLikeTree, logdensity_fn: Callable):
     logdensity = logdensity_fn(position)
     return EllipSliceState(position, logdensity)
 
@@ -164,7 +164,7 @@ def __new__(  # type: ignore[misc]
     ) -> MCMCSamplingAlgorithm:
         kernel = cls.build_kernel(cov, mean)
 
-        def init_fn(position: PyTree):
+        def init_fn(position: ArrayLikeTree):
             return cls.init(position, loglikelihood_fn)
 
         def step_fn(rng_key: PRNGKey, state):
diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py
index 64556db04..53a38ab1a 100644
--- a/blackjax/mcmc/ghmc.py
+++ b/blackjax/mcmc/ghmc.py
@@ -22,7 +22,7 @@
 import blackjax.mcmc.metrics as metrics
 import blackjax.mcmc.proposal as proposal
 from blackjax.base import MCMCSamplingAlgorithm
-from blackjax.types import PRNGKey, PyTree
+from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
 from blackjax.util import generate_gaussian_noise, pytree_size
 
 __all__ = ["GHMCState", "init", "build_kernel", "ghmc"]
 
@@ -42,15 +42,15 @@ class GHMCState(NamedTuple):
 
     """
 
-    position: PyTree
-    momentum: PyTree
+    position: ArrayTree
+    momentum: ArrayTree
     logdensity: float
-    logdensity_grad: PyTree
+    logdensity_grad: ArrayTree
     slice: float
 
 
 def init(
-    position: PyTree,
+    position: ArrayLikeTree,
     rng_key: PRNGKey,
     logdensity_fn: Callable,
 ) -> GHMCState:
@@ -101,7 +101,7 @@ def kernel(
         state: GHMCState,
         logdensity_fn: Callable,
         step_size: float,
-        momentum_inverse_scale: PyTree,
+        momentum_inverse_scale: ArrayLikeTree,
         alpha: float,
         delta: float,
     ) -> Tuple[GHMCState, hmc.HMCInfo]:
@@ -170,6 +170,30 @@ def kernel(
     return kernel
 
 
+def update_momentum(rng_key, state, alpha):
+    """Persistent update of the momentum variable.
+
+    Performs a persistent update of the momentum, taking as input the previous
+    momentum, a random number generating key and the parameter alpha. Outputs
+    an updated momentum that is a mixture of the previous momentum and a new
+    sample from a Gaussian density (dependent on alpha). The weights of the
+    mixture of these two components are a function of alpha.
+ + """ + position, momentum, *_ = state + + m_size = pytree_size(momentum) + momentum_generator, *_ = metrics.gaussian_euclidean(1 / alpha * jnp.ones((m_size,))) + momentum = jax.tree_map( + lambda prev_momentum, shifted_momentum: prev_momentum * jnp.sqrt(1.0 - alpha) + + shifted_momentum, + momentum, + momentum_generator(rng_key, position), + ) + + return momentum + + class ghmc: """Implements the (basic) user interface for the Generalized HMC kernel. @@ -239,7 +263,7 @@ def __new__( # type: ignore[misc] cls, logdensity_fn: Callable, step_size: float, - momentum_inverse_scale: PyTree, + momentum_inverse_scale: ArrayLikeTree, alpha: float, delta: float, *, @@ -248,7 +272,7 @@ def __new__( # type: ignore[misc] ) -> MCMCSamplingAlgorithm: kernel = cls.build_kernel(noise_gn, divergence_threshold) - def init_fn(position: PyTree, rng_key: PRNGKey): + def init_fn(position: ArrayLikeTree, rng_key: PRNGKey): return cls.init(position, rng_key, logdensity_fn) def step_fn(rng_key: PRNGKey, state): @@ -263,27 +287,3 @@ def step_fn(rng_key: PRNGKey, state): ) return MCMCSamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] - - -def update_momentum(rng_key, state, alpha): - """Persistent update of the momentum variable. - - Performs a persistent update of the momentum, taking as input the previous - momentum, a random number generating key and the parameter alpha. Outputs - an updated momentum that is a mixture of the previous momentum a new sample - from a Gaussian density (dependent on alpha). The weights of the mixture of - these two components are a function of alpha. - - """ - position, momentum, *_ = state - - m_size = pytree_size(momentum) - momentum_generator, *_ = metrics.gaussian_euclidean(1 / alpha * jnp.ones((m_size,))) - momentum = jax.tree_map( - lambda prev_momentum, shifted_momentum: prev_momentum * jnp.sqrt(1.0 - alpha) - + shifted_momentum, - momentum, - momentum_generator(rng_key, position), - ) - - return momentum diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py index 4738e9d09..86d90c634 100644 --- a/blackjax/mcmc/hmc.py +++ b/blackjax/mcmc/hmc.py @@ -22,7 +22,7 @@ import blackjax.mcmc.trajectory as trajectory from blackjax.base import MCMCSamplingAlgorithm from blackjax.mcmc.trajectory import hmc_energy -from blackjax.types import Array, PRNGKey, PyTree +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["HMCState", "HMCInfo", "init", "build_kernel", "hmc"] @@ -36,9 +36,9 @@ class HMCState(NamedTuple): """ - position: PyTree + position: ArrayTree logdensity: float - logdensity_grad: PyTree + logdensity_grad: ArrayTree class HMCInfo(NamedTuple): @@ -70,7 +70,7 @@ class HMCInfo(NamedTuple): """ - momentum: PyTree + momentum: ArrayTree acceptance_rate: float is_accepted: bool is_divergent: bool @@ -79,7 +79,7 @@ class HMCInfo(NamedTuple): num_integration_steps: int -def init(position: PyTree, logdensity_fn: Callable): +def init(position: ArrayLikeTree, logdensity_fn: Callable): logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) return HMCState(position, logdensity, logdensity_grad) @@ -223,7 +223,7 @@ def __new__( # type: ignore[misc] ) -> MCMCSamplingAlgorithm: kernel = cls.build_kernel(integrator, divergence_threshold) - def init_fn(position: PyTree): + def init_fn(position: ArrayLikeTree): return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): @@ -242,7 +242,7 @@ def step_fn(rng_key: PRNGKey, state): def hmc_proposal( integrator: Callable, kinetic_energy: Callable, - step_size: Union[float, 
PyTree], + step_size: Union[float, ArrayLikeTree], num_integration_steps: int = 1, divergence_threshold: float = 1000, *, diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 15f1808f6..1959dee69 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -17,7 +17,7 @@ import jax from blackjax.mcmc.metrics import EuclideanKineticEnergy -from blackjax.types import PyTree +from blackjax.types import ArrayTree __all__ = ["mclachlan", "velocity_verlet", "yoshida"] @@ -29,10 +29,10 @@ class IntegratorState(NamedTuple): to speedup computations. """ - position: PyTree - momentum: PyTree + position: ArrayTree + momentum: ArrayTree logdensity: float - logdensity_grad: PyTree + logdensity_grad: ArrayTree EuclideanIntegrator = Callable[[IntegratorState, float], IntegratorState] diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py index 87ec36a34..7f94c5d8b 100644 --- a/blackjax/mcmc/mala.py +++ b/blackjax/mcmc/mala.py @@ -21,7 +21,7 @@ import blackjax.mcmc.diffusions as diffusions import blackjax.mcmc.proposal as proposal from blackjax.base import MCMCSamplingAlgorithm -from blackjax.types import PRNGKey, PyTree +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["MALAState", "MALAInfo", "init", "build_kernel", "mala"] @@ -36,9 +36,9 @@ class MALAState(NamedTuple): """ - position: PyTree + position: ArrayTree logdensity: float - logdensity_grad: PyTree + logdensity_grad: ArrayTree class MALAInfo(NamedTuple): @@ -59,7 +59,7 @@ class MALAInfo(NamedTuple): is_accepted: bool -def init(position: PyTree, logdensity_fn: Callable) -> MALAState: +def init(position: ArrayLikeTree, logdensity_fn: Callable) -> MALAState: grad_fn = jax.value_and_grad(logdensity_fn) logdensity, logdensity_grad = grad_fn(position) return MALAState(position, logdensity, logdensity_grad) @@ -179,7 +179,7 @@ def __new__( # type: ignore[misc] ) -> MCMCSamplingAlgorithm: kernel = cls.build_kernel() - def init_fn(position: PyTree): + def init_fn(position: ArrayLikeTree): return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): diff --git a/blackjax/mcmc/marginal_latent_gaussian.py b/blackjax/mcmc/marginal_latent_gaussian.py index d19d9340e..82a90bde0 100644 --- a/blackjax/mcmc/marginal_latent_gaussian.py +++ b/blackjax/mcmc/marginal_latent_gaussian.py @@ -24,6 +24,7 @@ __all__ = ["MarginalState", "MarginalInfo", "init_and_kernel", "mgrad_gaussian"] +# [TODO](https://github.com/blackjax-devs/blackjax/issues/237) class MarginalState(NamedTuple): """State of the RMH chain. diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index e6f10db76..262345abd 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -33,12 +33,12 @@ import jax.scipy as jscipy from jax.flatten_util import ravel_pytree -from blackjax.types import Array, PRNGKey, PyTree +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise __all__ = ["gaussian_euclidean"] -EuclideanKineticEnergy = Callable[[PyTree], float] +EuclideanKineticEnergy = Callable[[ArrayLikeTree], float] def gaussian_euclidean( @@ -103,17 +103,19 @@ def gaussian_euclidean( f" expected 1 or 2, got {jnp.ndim(inverse_mass_matrix)}." 
# type: ignore[arg-type] ) - def momentum_generator(rng_key: PRNGKey, position: PyTree) -> PyTree: + def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree: return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt) - def kinetic_energy(momentum: PyTree) -> float: + def kinetic_energy(momentum: ArrayLikeTree) -> float: momentum, _ = ravel_pytree(momentum) velocity = matmul(inverse_mass_matrix, momentum) kinetic_energy_val = 0.5 * jnp.dot(velocity, momentum) return kinetic_energy_val def is_turning( - momentum_left: PyTree, momentum_right: PyTree, momentum_sum: PyTree + momentum_left: ArrayLikeTree, + momentum_right: ArrayLikeTree, + momentum_sum: ArrayLikeTree, ) -> bool: """Generalized U-turn criterion :cite:p:`betancourt2013generalizing,nuts_uturn`. diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index 90dd74eef..d22b159d1 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -25,7 +25,7 @@ import blackjax.mcmc.termination as termination import blackjax.mcmc.trajectory as trajectory from blackjax.base import MCMCSamplingAlgorithm -from blackjax.types import Array, PRNGKey, PyTree +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["NUTSInfo", "init", "build_kernel", "nuts"] @@ -63,7 +63,7 @@ class NUTSInfo(NamedTuple): """ - momentum: PyTree + momentum: ArrayTree is_divergent: bool is_turning: bool energy: float @@ -226,7 +226,7 @@ def __new__( # type: ignore[misc] ) -> MCMCSamplingAlgorithm: kernel = cls.build_kernel(integrator, divergence_threshold, max_num_doublings) - def init_fn(position: PyTree): + def init_fn(position: ArrayLikeTree): return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): diff --git a/blackjax/mcmc/periodic_orbital.py b/blackjax/mcmc/periodic_orbital.py index c82ae2e6d..6e2892e66 100644 --- a/blackjax/mcmc/periodic_orbital.py +++ b/blackjax/mcmc/periodic_orbital.py @@ -20,7 +20,7 @@ import blackjax.mcmc.integrators as integrators import blackjax.mcmc.metrics as metrics from blackjax.base import MCMCSamplingAlgorithm -from blackjax.types import Array, PRNGKey, PyTree +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["PeriodicOrbitalState", "init", "build_kernel", "orbital_hmc"] @@ -48,11 +48,11 @@ class PeriodicOrbitalState(NamedTuple): function for each point in the orbit. """ - positions: PyTree + positions: ArrayTree weights: Array directions: Array logdensities: Array - logdensities_grad: PyTree + logdensities_grad: ArrayTree class PeriodicOrbitalInfo(NamedTuple): @@ -70,13 +70,13 @@ class PeriodicOrbitalInfo(NamedTuple): variance of the unnormalized weights of the orbit, ideally close to 0. """ - momentums: PyTree + momentums: ArrayTree weights_mean: float weights_variance: float def init( - position: PyTree, logdensity_fn: Callable, period: int + position: ArrayLikeTree, logdensity_fn: Callable, period: int ) -> PeriodicOrbitalState: """Create a periodic orbital state from a position. 
@@ -276,7 +276,7 @@ def __new__( # type: ignore[misc] ) -> MCMCSamplingAlgorithm: kernel = cls.build_kernel(bijection) - def init_fn(position: PyTree): + def init_fn(position: ArrayLikeTree): return cls.init(position, logdensity_fn, period) def step_fn(rng_key: PRNGKey, state): diff --git a/blackjax/mcmc/random_walk.py b/blackjax/mcmc/random_walk.py index ef2c54d95..37e6bcc1e 100644 --- a/blackjax/mcmc/random_walk.py +++ b/blackjax/mcmc/random_walk.py @@ -61,7 +61,7 @@ from blackjax.base import MCMCSamplingAlgorithm from blackjax.mcmc import proposal -from blackjax.types import Array, PRNGKey, PyTree +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise __all__ = [ @@ -95,7 +95,7 @@ def normal(sigma: Array) -> Callable: if jnp.ndim(sigma) > 2: raise ValueError("sigma must be a vector or a matrix.") - def propose(rng_key: PRNGKey, position: PyTree) -> PyTree: + def propose(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree: return generate_gaussian_noise(rng_key, position, sigma=sigma) return propose @@ -111,7 +111,7 @@ class RWState(NamedTuple): """ - position: PyTree + position: ArrayTree logdensity: float @@ -137,7 +137,7 @@ class RWInfo(NamedTuple): proposal: RWState -def init(position: PyTree, logdensity_fn: Callable) -> RWState: +def init(position: ArrayLikeTree, logdensity_fn: Callable) -> RWState: """Create a chain state from a position. Parameters @@ -236,7 +236,7 @@ def __new__( # type: ignore[misc] ) -> MCMCSamplingAlgorithm: kernel = cls.build_kernel() - def init_fn(position: PyTree): + def init_fn(position: ArrayLikeTree): return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): @@ -272,7 +272,8 @@ def kernel( domain of the target distribution. """ - def proposal_generator(rng_key: PRNGKey, position: PyTree): + def proposal_generator(rng_key: PRNGKey, position: ArrayTree): + del position return proposal_distribution(rng_key) inner_kernel = build_rmh() @@ -326,7 +327,7 @@ def __new__( # type: ignore[misc] ) -> MCMCSamplingAlgorithm: kernel = cls.build_kernel() - def init_fn(position: PyTree): + def init_fn(position: ArrayLikeTree): return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): @@ -436,12 +437,12 @@ class rmh: def __new__( # type: ignore[misc] cls, logdensity_fn: Callable, - proposal_generator: Callable[[PRNGKey, PyTree], PyTree], - proposal_logdensity_fn: Optional[Callable[[PyTree], PyTree]] = None, + proposal_generator: Callable[[PRNGKey, ArrayLikeTree], ArrayTree], + proposal_logdensity_fn: Optional[Callable[[ArrayLikeTree], ArrayTree]] = None, ) -> MCMCSamplingAlgorithm: kernel = cls.build_kernel() - def init_fn(position: PyTree): + def init_fn(position: ArrayLikeTree): return cls.init(position, logdensity_fn) def step_fn(rng_key: PRNGKey, state): diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py index f7d432bff..265afc351 100644 --- a/blackjax/mcmc/trajectory.py +++ b/blackjax/mcmc/trajectory.py @@ -48,13 +48,13 @@ progressive_uniform_sampling, proposal_generator, ) -from blackjax.types import PRNGKey, PyTree +from blackjax.types import ArrayTree, PRNGKey class Trajectory(NamedTuple): leftmost_state: IntegratorState rightmost_state: IntegratorState - momentum_sum: PyTree + momentum_sum: ArrayTree num_states: int diff --git a/blackjax/optimizers/lbfgs.py b/blackjax/optimizers/lbfgs.py index 57aab7bdf..ba688ff8e 100644 --- a/blackjax/optimizers/lbfgs.py +++ b/blackjax/optimizers/lbfgs.py @@ -22,7 +22,7 @@ from jaxopt._src.lbfgs 
import LbfgsState from jaxopt.base import OptStep -from blackjax.types import Array, PyTree +from blackjax.types import Array, ArrayLikeTree __all__ = [ "LBFGSHistory", @@ -64,7 +64,7 @@ class LBFGSHistory(NamedTuple): def minimize_lbfgs( fun: Callable, - x0: PyTree, + x0: ArrayLikeTree, maxiter: int = 30, maxcor: float = 10, gtol: float = 1e-08, diff --git a/blackjax/sgmcmc/csgld.py b/blackjax/sgmcmc/csgld.py index 621baf5fe..6b7309119 100644 --- a/blackjax/sgmcmc/csgld.py +++ b/blackjax/sgmcmc/csgld.py @@ -21,7 +21,7 @@ from blackjax.base import MCMCSamplingAlgorithm from blackjax.sgmcmc.diffusions import overdamped_langevin -from blackjax.types import Array, PRNGKey, PyTree +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["ContourSGLDState", "init", "build_kernel", "csgld"] @@ -41,12 +41,12 @@ class ContourSGLDState(NamedTuple): Index `i` such that the current position belongs to :math:`S_i`. """ - position: PyTree + position: ArrayTree energy_pdf: Array energy_idx: int -def init(position: PyTree, num_partitions=512) -> ContourSGLDState: +def init(position: ArrayLikeTree, num_partitions=512) -> ContourSGLDState: energy_pdf = ( jnp.arange(num_partitions, 0, -1) / jnp.arange(num_partitions, 0, -1).sum() ) @@ -80,7 +80,7 @@ def kernel( state: ContourSGLDState, logdensity_estimator: Callable, gradient_estimator: Callable, - minibatch: PyTree, + minibatch: ArrayLikeTree, step_size_diff: float, # step size for Langevin diffusion step_size_stoch: float = 1e-3, # step size for stochastic approximation zeta: float = 1, @@ -223,13 +223,13 @@ def __new__( # type: ignore[misc] ) -> MCMCSamplingAlgorithm: kernel = cls.build_kernel(num_partitions, energy_gap, min_energy) - def init_fn(position: PyTree): + def init_fn(position: ArrayLikeTree): return cls.init(position, num_partitions) def step_fn( rng_key: PRNGKey, state: ContourSGLDState, - minibatch: PyTree, + minibatch: ArrayLikeTree, step_size_diff: float, step_size_stoch: float, temperature: float = 1.0, diff --git a/blackjax/sgmcmc/diffusions.py b/blackjax/sgmcmc/diffusions.py index 9cf3127da..3527653cb 100644 --- a/blackjax/sgmcmc/diffusions.py +++ b/blackjax/sgmcmc/diffusions.py @@ -17,7 +17,7 @@ import jax import jax.numpy as jnp -from blackjax.types import PRNGKey, PyTree +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise, pytree_size __all__ = ["overdamped_langevin", "sghmc", "sgnht"] @@ -32,11 +32,11 @@ def overdamped_langevin(): def one_step( rng_key: PRNGKey, - position: PyTree, - logdensity_grad: PyTree, + position: ArrayLikeTree, + logdensity_grad: ArrayLikeTree, step_size: float, temperature: float = 1.0, - ) -> PyTree: + ) -> ArrayTree: noise = generate_gaussian_noise(rng_key, position) position = jax.tree_util.tree_map( lambda p, g, n: p @@ -62,9 +62,9 @@ def sghmc(alpha: float = 0.01, beta: float = 0): def one_step( rng_key: PRNGKey, - position: PyTree, - momentum: PyTree, - logdensity_grad: PyTree, + position: ArrayLikeTree, + momentum: ArrayLikeTree, + logdensity_grad: ArrayLikeTree, step_size: float, temperature: float = 1.0, ): @@ -98,10 +98,10 @@ def sgnht(alpha: float = 0.01, beta: float = 0): def one_step( rng_key: PRNGKey, - position: PyTree, - momentum: PyTree, + position: ArrayLikeTree, + momentum: ArrayLikeTree, xi: float, - logdensity_grad: PyTree, + logdensity_grad: ArrayLikeTree, step_size: float, temperature: float = 1.0, ): diff --git a/blackjax/sgmcmc/gradients.py b/blackjax/sgmcmc/gradients.py index ed4d2a8ac..a326fefaa 
100644 --- a/blackjax/sgmcmc/gradients.py +++ b/blackjax/sgmcmc/gradients.py @@ -16,7 +16,7 @@ import jax import jax.numpy as jnp -from blackjax.types import PyTree +from blackjax.types import ArrayLikeTree, ArrayTree def logdensity_estimator( @@ -44,7 +44,9 @@ def logdensity_estimator( """ - def logdensity_estimator_fn(position: PyTree, minibatch: PyTree) -> PyTree: + def logdensity_estimator_fn( + position: ArrayLikeTree, minibatch: ArrayLikeTree + ) -> ArrayTree: """Return an approximation of the log-posterior density. Parameters @@ -82,8 +84,8 @@ def grad_estimator( def control_variates( logdensity_grad_estimator: Callable, - centering_position: PyTree, - data: PyTree, + centering_position: ArrayLikeTree, + data: ArrayLikeTree, ) -> Callable: """Builds a control variate gradient estimator :cite:p:`baker2019control`. @@ -101,7 +103,9 @@ def control_variates( """ cv_grad_value = logdensity_grad_estimator(centering_position, data) - def cv_grad_estimator_fn(position: PyTree, minibatch: PyTree) -> PyTree: + def cv_grad_estimator_fn( + position: ArrayLikeTree, minibatch: ArrayLikeTree + ) -> ArrayTree: """Return an approximation of the log-posterior density. Parameters diff --git a/blackjax/sgmcmc/sghmc.py b/blackjax/sgmcmc/sghmc.py index 36db37c5b..0ca430077 100644 --- a/blackjax/sgmcmc/sghmc.py +++ b/blackjax/sgmcmc/sghmc.py @@ -18,13 +18,13 @@ import blackjax.sgmcmc.diffusions as diffusions from blackjax.base import MCMCSamplingAlgorithm -from blackjax.types import PRNGKey, PyTree +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise __all__ = ["init", "build_kernel", "sghmc"] -def init(position: PyTree) -> PyTree: +def init(position: ArrayLikeTree) -> ArrayLikeTree: return position @@ -34,13 +34,13 @@ def build_kernel(alpha: float = 0.01, beta: float = 0) -> Callable: def kernel( rng_key: PRNGKey, - position: PyTree, + position: ArrayLikeTree, grad_estimator: Callable, - minibatch: PyTree, + minibatch: ArrayLikeTree, step_size: float, num_integration_steps: int, temperature: float = 1.0, - ) -> PyTree: + ) -> ArrayTree: def body_fn(state, rng_key): position, momentum = state logdensity_grad = grad_estimator(position, minibatch) @@ -123,16 +123,16 @@ def __new__( # type: ignore[misc] ) -> MCMCSamplingAlgorithm: kernel = cls.build_kernel(alpha, beta) - def init_fn(position: PyTree): + def init_fn(position: ArrayLikeTree): return cls.init(position) def step_fn( rng_key: PRNGKey, - state: PyTree, - minibatch: PyTree, + state: ArrayLikeTree, + minibatch: ArrayLikeTree, step_size: float, temperature: float = 1, - ) -> PyTree: + ) -> ArrayTree: return kernel( rng_key, state, diff --git a/blackjax/sgmcmc/sgld.py b/blackjax/sgmcmc/sgld.py index b3602b321..afd7086b9 100644 --- a/blackjax/sgmcmc/sgld.py +++ b/blackjax/sgmcmc/sgld.py @@ -16,12 +16,12 @@ import blackjax.sgmcmc.diffusions as diffusions from blackjax.base import MCMCSamplingAlgorithm -from blackjax.types import PRNGKey, PyTree +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["init", "build_kernel", "sgld"] -def init(position: PyTree) -> PyTree: +def init(position: ArrayLikeTree) -> ArrayLikeTree: return position @@ -31,12 +31,12 @@ def build_kernel() -> Callable: def kernel( rng_key: PRNGKey, - position: PyTree, + position: ArrayLikeTree, grad_estimator: Callable, - minibatch: PyTree, + minibatch: ArrayLikeTree, step_size: float, temperature: float = 1.0, - ): + ) -> ArrayTree: logdensity_grad = grad_estimator(position, minibatch) new_position = 
integrator( rng_key, position, logdensity_grad, step_size, temperature @@ -109,16 +109,16 @@ def __new__( # type: ignore[misc] ) -> MCMCSamplingAlgorithm: kernel = cls.build_kernel() - def init_fn(position: PyTree): + def init_fn(position: ArrayLikeTree): return cls.init(position) def step_fn( rng_key: PRNGKey, - state: PyTree, - minibatch: PyTree, + state: ArrayLikeTree, + minibatch: ArrayLikeTree, step_size: float, temperature: float = 1, - ) -> PyTree: + ) -> ArrayTree: return kernel( rng_key, state, grad_estimator, minibatch, step_size, temperature ) diff --git a/blackjax/sgmcmc/sgnht.py b/blackjax/sgmcmc/sgnht.py index 616efd04f..5a403080a 100644 --- a/blackjax/sgmcmc/sgnht.py +++ b/blackjax/sgmcmc/sgnht.py @@ -16,7 +16,7 @@ import blackjax.sgmcmc.diffusions as diffusions from blackjax.base import MCMCSamplingAlgorithm -from blackjax.types import PRNGKey, PyTree +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise __all__ = ["SGNHTState", "init", "build_kernel", "sgnht"] @@ -35,12 +35,12 @@ class SGNHTState(NamedTuple): Scalar thermostat controlling kinetic energy. """ - position: PyTree - momentum: PyTree + position: ArrayTree + momentum: ArrayTree xi: float -def init(position: PyTree, rng_key: PRNGKey, xi: float) -> SGNHTState: +def init(position: ArrayLikeTree, rng_key: PRNGKey, xi: float) -> SGNHTState: momentum = generate_gaussian_noise(rng_key, position) return SGNHTState(position, momentum, xi) @@ -53,10 +53,10 @@ def kernel( rng_key: PRNGKey, state: SGNHTState, grad_estimator: Callable, - minibatch: PyTree, + minibatch: ArrayLikeTree, step_size: float, temperature: float = 1.0, - ) -> PyTree: + ) -> ArrayTree: position, momentum, xi = state logdensity_grad = grad_estimator(position, minibatch) position, momentum, xi = integrator( @@ -133,14 +133,16 @@ def __new__( # type: ignore[misc] kernel = cls.build_kernel(alpha, beta) def init_fn( - position: PyTree, rng_key: PRNGKey, init_xi: Union[None, float] = None + position: ArrayLikeTree, + rng_key: PRNGKey, + init_xi: Union[None, float] = None, ): return cls.init(position, rng_key, init_xi or alpha) def step_fn( rng_key: PRNGKey, state: SGNHTState, - minibatch: PyTree, + minibatch: ArrayLikeTree, step_size: float, temperature: float = 1, ) -> SGNHTState: diff --git a/blackjax/smc/adaptive_tempered.py b/blackjax/smc/adaptive_tempered.py index f740ba82d..632dc3f38 100644 --- a/blackjax/smc/adaptive_tempered.py +++ b/blackjax/smc/adaptive_tempered.py @@ -21,7 +21,7 @@ import blackjax.smc.solver as solver import blackjax.smc.tempered as tempered from blackjax.base import MCMCSamplingAlgorithm -from blackjax.types import PRNGKey, PyTree +from blackjax.types import ArrayLikeTree, PRNGKey __all__ = ["build_kernel", "adaptive_tempered_smc"] @@ -159,7 +159,7 @@ def __new__( # type: ignore[misc] root_solver, ) - def init_fn(position: PyTree): + def init_fn(position: ArrayLikeTree): return cls.init(position) def step_fn(rng_key: PRNGKey, state): diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 7fbb54810..0732e0404 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -16,14 +16,14 @@ import jax import jax.numpy as jnp -from blackjax.types import PRNGKey, PyTree +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey class SMCState(NamedTuple): """State of the SMC sampler""" - particles: PyTree - weights: jax.Array + particles: ArrayTree + weights: Array class SMCInfo(NamedTuple): @@ -31,7 +31,7 @@ class SMCInfo(NamedTuple): proposals: PyTree The 
particles that were proposed by the MCMC pass. - ancestors: jnp.ndarray + ancestors: Array The index of the particles proposed by the MCMC pass that were selected by the resampling step. log_likelihood_increment: float @@ -39,12 +39,14 @@ class SMCInfo(NamedTuple): """ - ancestors: jnp.ndarray + ancestors: Array log_likelihood_increment: float update_info: NamedTuple -def init(particles: PyTree): +def init(particles: ArrayLikeTree): + # Infer the number of particles from the size of the leading dimension of + # the first leaf of the inputted PyTree. num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] weights = jnp.ones(num_particles) / num_particles return SMCState(particles, weights) diff --git a/blackjax/smc/ess.py b/blackjax/smc/ess.py index 428e50fb5..67f5fea5f 100644 --- a/blackjax/smc/ess.py +++ b/blackjax/smc/ess.py @@ -18,14 +18,14 @@ import jax.numpy as jnp import jax.scipy as jsp -from blackjax.types import PyTree +from blackjax.types import Array, ArrayLikeTree -def ess(log_weights: jnp.ndarray) -> float: +def ess(log_weights: Array) -> float: return jnp.exp(log_ess(log_weights)) -def log_ess(log_weights: jnp.ndarray) -> float: +def log_ess(log_weights: Array) -> float: """Compute the effective sample size. Parameters @@ -46,7 +46,7 @@ def log_ess(log_weights: jnp.ndarray) -> float: def ess_solver( logdensity_fn: Callable, - particles: PyTree, + particles: ArrayLikeTree, target_ess: float, max_delta: float, root_solver: Callable, diff --git a/blackjax/smc/resampling.py b/blackjax/smc/resampling.py index 7a9b4553e..8e701e1b2 100644 --- a/blackjax/smc/resampling.py +++ b/blackjax/smc/resampling.py @@ -18,7 +18,7 @@ import jax import jax.numpy as jnp -from blackjax.types import PRNGKey +from blackjax.types import Array, PRNGKey def _resampling_func(func, name, desc="", additional_params="") -> Callable: @@ -29,16 +29,16 @@ def _resampling_func(func, name, desc="", additional_params="") -> Callable: Parameters ---------- - key: jnp.ndarray + key: Array PRNGKey to use in resampling - weights: jnp.ndarray + weights: Array Weights to resample num_samples: int Number of particles to sample Returns ------- - idx: jnp.ndarray + idx: Array Array of size `num_samples` to use for resampling """ @@ -47,12 +47,12 @@ def _resampling_func(func, name, desc="", additional_params="") -> Callable: @partial(_resampling_func, name="Systematic") -def systematic(rng_key: PRNGKey, weights: jax.Array, num_samples: int) -> jax.Array: +def systematic(rng_key: PRNGKey, weights: Array, num_samples: int) -> Array: return _systematic_or_stratified(rng_key, weights, num_samples, True) @partial(_resampling_func, name="Stratified") -def stratified(rng_key: PRNGKey, weights: jax.Array, num_samples: int) -> jax.Array: +def stratified(rng_key: PRNGKey, weights: Array, num_samples: int) -> Array: return _systematic_or_stratified(rng_key, weights, num_samples, False) @@ -64,7 +64,7 @@ def stratified(rng_key: PRNGKey, weights: jax.Array, num_samples: int) -> jax.Ar and should only be used for illustration purposes, or if your algorithm *REALLY* needs independent samples.""", ) -def multinomial(rng_key: PRNGKey, weights: jax.Array, num_samples: int) -> jnp.ndarray: +def multinomial(rng_key: PRNGKey, weights: Array, num_samples: int) -> Array: # In practice we don't have to sort the generated uniforms, but searchsorted # works faster and is more stable if both inputs are sorted, so we use the # _sorted_uniforms from N. 
Chopin, but still use searchsorted instead of his
@@ -89,7 +89,7 @@ def multinomial(rng_key: PRNGKey, weights: jax.Array, num_samples: int) -> jnp.n
     compatible. The main difference with Nicolas Chopin's code lies in the
     introduction of N+1 in the array as a 'sink state' for unused indices.""",
 )
-def residual(rng_key: PRNGKey, weights: jax.Array, num_samples: int) -> jax.Array:
+def residual(rng_key: PRNGKey, weights: Array, num_samples: int) -> Array:
     key1, key2 = jax.random.split(rng_key)
     N = weights.shape[0]
     N_sample_weights = num_samples * weights
@@ -122,8 +122,8 @@ def residual(rng_key: PRNGKey, weights: jax.Array, num_samples: int) -> jax.Arra
 
 
 def _systematic_or_stratified(
-    rng_key: PRNGKey, weights: jax.Array, num_samples: int, is_systematic: bool
-) -> jax.Array:
+    rng_key: PRNGKey, weights: Array, num_samples: int, is_systematic: bool
+) -> Array:
     n = weights.shape[0]
     if is_systematic:
         u = jax.random.uniform(rng_key, ())
@@ -135,7 +135,7 @@ def _systematic_or_stratified(
     return jnp.clip(idx, 0, n - 1)
 
 
-def _sorted_uniforms(rng_key: PRNGKey, n) -> jax.Array:
+def _sorted_uniforms(rng_key: PRNGKey, n) -> Array:
     # Credit goes to Nicolas Chopin
     us = jax.random.uniform(rng_key, (n + 1,))
     z = jnp.cumsum(-jnp.log(us))
diff --git a/blackjax/smc/solver.py b/blackjax/smc/solver.py
index 8e4a7b14b..0cd96e77a 100644
--- a/blackjax/smc/solver.py
+++ b/blackjax/smc/solver.py
@@ -38,7 +38,7 @@ def dichotomy(fun, _delta0, min_delta, max_delta, eps=1e-4, max_iter=100):
 
     Returns
     -------
-    delta: jnp.ndarray, shape (,)
+    delta: Array, shape (,)
         The root of `fun`
 
     """
diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py
index 433c5d311..9b52a86a5 100644
--- a/blackjax/smc/tempered.py
+++ b/blackjax/smc/tempered.py
@@ -19,7 +19,7 @@
 import blackjax.smc as smc
 from blackjax.base import MCMCSamplingAlgorithm
 from blackjax.smc.base import SMCState
-from blackjax.types import PRNGKey, PyTree
+from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
 
 __all__ = ["TemperedSMCState", "init", "build_kernel"]
 
@@ -34,12 +34,14 @@ class TemperedSMCState(NamedTuple):
 
     """
 
-    particles: PyTree
-    weights: jax.Array
+    particles: ArrayTree
+    weights: Array
     lmbda: float
 
 
-def init(particles: PyTree):
+def init(particles: ArrayLikeTree):
+    # Infer the number of particles from the size of the leading dimension of
+    # the first leaf of the inputted PyTree.
     num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
     weights = jnp.ones(num_particles) / num_particles
     return TemperedSMCState(particles, weights, 0.0)
@@ -117,10 +119,10 @@ def kernel(
         """
         delta = lmbda - state.lmbda
 
-        def log_weights_fn(position: PyTree) -> float:
+        def log_weights_fn(position: ArrayLikeTree) -> float:
             return delta * loglikelihood_fn(position)
 
-        def tempered_logposterior_fn(position: PyTree) -> float:
+        def tempered_logposterior_fn(position: ArrayLikeTree) -> float:
             logprior = logprior_fn(position)
             tempered_loglikelihood = state.lmbda * loglikelihood_fn(position)
             return logprior + tempered_loglikelihood
@@ -201,7 +203,7 @@ def __new__(  # type: ignore[misc]
             resampling_fn,
         )
 
-        def init_fn(position: PyTree):
+        def init_fn(position: ArrayLikeTree):
             return cls.init(position)
 
         def step_fn(rng_key: PRNGKey, state, lmbda):
diff --git a/blackjax/types.py b/blackjax/types.py
index dc2181a03..db6e7c76f 100644
--- a/blackjax/types.py
+++ b/blackjax/types.py
@@ -1,10 +1,45 @@
+# Copyright 2020- The Blackjax Authors.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from typing import Any, Iterable, Mapping, Union
 
 import jax
-from chex import Array
+from jax import Array
+from jax.typing import ArrayLike
 
+"""
+Following the current best practice (https://jax.readthedocs.io/en/latest/jax.typing.html),
+we use:
+- `ArrayLike` and `ArrayLikeTree` to annotate function inputs,
+- `Array` and `ArrayTree` to annotate function outputs.
+
+Leaves of a PyTree definition in the library are in principle annotated as
+`Array`, as they are mostly internal representations. For example:
+```
+class WelfordAlgorithmState(NamedTuple):
+    mean: Array
+    ...
+```
+
+[TODO] Improve scalar-like typing (e.g. `logdensity`, `acceptance_rate`).
+While they are `Array`s (in most cases they are the output of a JAX
+function), we annotate them as `float` to emphasize that they should be
+scalars (until we introduce shape annotations).
+"""
 
 #: JAX PyTrees
-PyTree = Union[Array, Iterable["PyTree"], Mapping[Any, "PyTree"]]
+ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[Any, "ArrayTree"]]
+ArrayLikeTree = Union[
+    ArrayLike, Iterable["ArrayLikeTree"], Mapping[Any, "ArrayLikeTree"]
+]
 
 #: JAX PRNGKey
 PRNGKey = jax.random.PRNGKeyArray
diff --git a/blackjax/util.py b/blackjax/util.py
index b659b24ab..1a7ebcd09 100644
--- a/blackjax/util.py
+++ b/blackjax/util.py
@@ -8,7 +8,7 @@
 from jax.random import normal
 from jax.tree_util import tree_leaves
 
-from blackjax.types import Array, PRNGKey, PyTree
+from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
 
 
 @partial(jit, static_argnames=("precision",), inline=True)
@@ -56,10 +56,10 @@ def linear_map(diag_or_dense_a, b, *, precision="highest"):
 # Refactor this function to not use ravel_pytree might be more performant.
 def generate_gaussian_noise(
     rng_key: PRNGKey,
-    position: PyTree,
+    position: ArrayLikeTree,
     mu: Union[float, Array] = 0.0,
     sigma: Union[float, Array] = 1.0,
-) -> PyTree:
+) -> ArrayTree:
     """Generate N(mu, sigma) noise with output structure that match a given PyTree.
 
     Parameters
@@ -82,12 +82,12 @@ def generate_gaussian_noise(
     return unravel_fn(mu + linear_map(sigma, sample))
 
 
-def pytree_size(pytree: PyTree) -> int:
+def pytree_size(pytree: ArrayLikeTree) -> int:
     """Return the dimension of the flatten PyTree."""
     return sum(jnp.size(value) for value in tree_leaves(pytree))
 
 
-def index_pytree(input_pytree: PyTree) -> PyTree:
+def index_pytree(input_pytree: ArrayLikeTree) -> ArrayTree:
     """Builds a PyTree with elements indicating its corresponding index on a flat array.
Various algorithms in BlackJAX take as input a 1 or 2 dimensional array which somehow diff --git a/blackjax/vi/meanfield_vi.py b/blackjax/vi/meanfield_vi.py index 662d9c849..e7f5c409d 100644 --- a/blackjax/vi/meanfield_vi.py +++ b/blackjax/vi/meanfield_vi.py @@ -19,7 +19,7 @@ from optax import GradientTransformation, OptState from blackjax.base import VIAlgorithm -from blackjax.types import PRNGKey, PyTree +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey __all__ = [ "MFVIState", @@ -32,8 +32,8 @@ class MFVIState(NamedTuple): - mu: PyTree - rho: PyTree + mu: ArrayTree + rho: ArrayTree opt_state: OptState @@ -42,7 +42,7 @@ class MFVIInfo(NamedTuple): def init( - position: PyTree, + position: ArrayLikeTree, optimizer: GradientTransformation, *optimizer_args, **optimizer_kwargs, @@ -138,7 +138,7 @@ def __new__( optimizer: GradientTransformation, num_samples: int = 100, ): # type: ignore[misc] - def init_fn(position: PyTree): + def init_fn(position: ArrayLikeTree): return cls.init(position, optimizer) def step_fn(rng_key: PRNGKey, state: MFVIState) -> Tuple[MFVIState, MFVIInfo]: diff --git a/blackjax/vi/pathfinder.py b/blackjax/vi/pathfinder.py index 44b766e07..7cb40e437 100644 --- a/blackjax/vi/pathfinder.py +++ b/blackjax/vi/pathfinder.py @@ -23,7 +23,7 @@ bfgs_sample, lbfgs_inverse_hessian_factors, ) -from blackjax.types import Array, PRNGKey, PyTree +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["PathfinderState", "approximate", "sample", "pathfinder"] @@ -50,8 +50,8 @@ class PathfinderState(NamedTuple): """ elbo: Array - position: PyTree - grad_position: PyTree + position: ArrayTree + grad_position: ArrayTree alpha: Array beta: Array gamma: Array @@ -71,7 +71,7 @@ class PathFinderAlgorithm(NamedTuple): def approximate( rng_key: PRNGKey, logdensity_fn: Callable, - initial_position: PyTree, + initial_position: ArrayLikeTree, num_samples: int = 200, *, # lgbfs parameters maxiter=30, @@ -201,7 +201,7 @@ def sample( rng_key: PRNGKey, state: PathfinderState, num_samples: Union[int, Tuple[()], Tuple[int]] = (), -) -> PyTree: +) -> ArrayTree: """Draw from the Pathfinder approximation of the target distribution. Parameters @@ -267,7 +267,7 @@ class pathfinder: def __new__(cls, logdensity_fn: Callable) -> PathFinderAlgorithm: # type: ignore[misc] def approximate_fn( rng_key: PRNGKey, - position: PyTree, + position: ArrayLikeTree, num_samples: int = 200, **lbfgs_parameters, ): diff --git a/blackjax/vi/svgd.py b/blackjax/vi/svgd.py index 4f2ffbd36..838921606 100644 --- a/blackjax/vi/svgd.py +++ b/blackjax/vi/svgd.py @@ -7,19 +7,19 @@ from jax.flatten_util import ravel_pytree from blackjax.base import MCMCSamplingAlgorithm -from blackjax.types import PyTree +from blackjax.types import ArrayLikeTree, ArrayTree __all__ = ["svgd", "rbf_kernel", "update_median_heuristic"] class SVGDState(NamedTuple): - particles: PyTree - kernel_parameters: Dict[str, PyTree] + particles: ArrayTree + kernel_parameters: Dict[str, ArrayTree] opt_state: Any def init( - initial_particles: PyTree, + initial_particles: ArrayLikeTree, kernel_parameters: Dict[str, Any], optimizer: optax.GradientTransformation, ) -> SVGDState: @@ -155,7 +155,7 @@ def __new__( kernel_ = cls.build_kernel(optimizer) def init_fn( - initial_position: PyTree, + initial_position: ArrayLikeTree, kernel_parameters: Dict[str, Any] = {"length_scale": 1.0}, ): return cls.init(initial_position, kernel_parameters, optimizer)
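
Reviewer note: the sketch below is illustrative only and is not part of the diff. It shows the annotation convention that `blackjax/types.py` now documents, `ArrayLikeTree` for PyTree-valued inputs and `ArrayTree` for PyTree-valued outputs, applied to a hypothetical helper (`scale_position` does not exist in the library).

```python
# Hypothetical usage sketch (not in this diff): annotate inputs with
# ArrayLikeTree and outputs with ArrayTree, per blackjax/types.py.
import jax
import jax.numpy as jnp

from blackjax.types import ArrayLikeTree, ArrayTree


def scale_position(position: ArrayLikeTree, factor: float) -> ArrayTree:
    """Scale every leaf of a position PyTree by `factor`.

    The input may contain anything convertible to a JAX array (Python
    floats, NumPy arrays, ...), hence `ArrayLikeTree`; every output
    leaf is a `jax.Array`, hence `ArrayTree`.
    """
    return jax.tree_util.tree_map(lambda leaf: factor * jnp.asarray(leaf), position)


# Accepts a flat array or an arbitrary PyTree of array-likes.
scale_position(jnp.ones(3), 2.0)
scale_position({"loc": jnp.zeros(2), "log_scale": 0.5}, 2.0)
```

This mirrors how the diff annotates, for example, `generate_gaussian_noise` in `blackjax/util.py` (`position: ArrayLikeTree, ...` returning `ArrayTree`).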