Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MInor docstring fix #612

Merged
merged 1 commit into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 41 additions & 27 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Algorithms to adapt the MCLMC kernel parameters, namely step size and L.

"""
"""Algorithms to adapt the MCLMC kernel parameters, namely step size and L."""

from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from blackjax.diagnostics import effective_sample_size # type: ignore
from blackjax.diagnostics import effective_sample_size
from blackjax.util import pytree_size


class MCLMCAdaptationState(NamedTuple):
"""Represents the tunable parameters for MCLMC adaptation.

Attributes:
L (float): The momentum decoherent rate for the MCLMC algorithm.
step_size (float): The step size used for the MCLMC algorithm.
L
The momentum decoherent rate for the MCLMC algorithm.
step_size
The step size used for the MCLMC algorithm.
"""

L: float
Expand All @@ -52,25 +51,39 @@ def mclmc_find_L_and_step_size(
"""
Finds the optimal value of the parameters for the MCLMC algorithm.

Args:
mclmc_kernel (callable): The kernel function used for the MCMC algorithm.
num_steps (int): The number of MCMC steps that will subsequently be run, after tuning.
state (MCMCState): The initial state of the MCMC algorithm.
rng_key (jax.random.PRNGKey): The random number generator key.
frac_tune1 (float): The fraction of tuning for the first step of the adaptation.
frac_tune2 (float): The fraction of tuning for the second step of the adaptation.
frac_tune3 (float): The fraction of tuning for the third step of the adaptation.
desired_energy_var (float): The desired energy variance for the MCMC algorithm.
trust_in_estimate (float): The trust in the estimate of optimal stepsize.
num_effective_samples (int): The number of effective samples for the MCMC algorithm.

Returns:
tuple: A tuple containing the final state of the MCMC algorithm and the final hyperparameters.

Raises:
None

Examples:
Parameters
----------
mclmc_kernel
The kernel function used for the MCMC algorithm.
num_steps
The number of MCMC steps that will subsequently be run, after tuning.
state
The initial state of the MCMC algorithm.
rng_key
The random number generator key.
frac_tune1
The fraction of tuning for the first step of the adaptation.
frac_tune2
The fraction of tuning for the second step of the adaptation.
frac_tune3
The fraction of tuning for the third step of the adaptation.
desired_energy_va
The desired energy variance for the MCMC algorithm.
trust_in_estimate
The trust in the estimate of optimal stepsize.
num_effective_samples
The number of effective samples for the MCMC algorithm.

Returns
-------
A tuple containing the final state of the MCMC algorithm and the final hyperparameters.


Examples
-------

.. code::

# Define the kernel function
def kernel(x):
return x ** 2
Expand Down Expand Up @@ -265,7 +278,8 @@ def adaptation_L(state, params, num_steps, key):


def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change):
"""if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case."""
"""if there are nans, let's reduce the stepsize, and not update the state. The
function returns the old state in this case."""

reduced_step_size = 0.8
p, unravel_fn = ravel_pytree(next_state.position)
Expand Down
32 changes: 16 additions & 16 deletions blackjax/mcmc/mclmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ class MCLMCInfo(NamedTuple):
"""
Additional information on the MCLMC transition.

Attributes
----------
transformed_position :
transformed_position
The value of the samples after a transformation. This is typically a projection onto a lower dimensional subspace.
logdensity :
logdensity
The log-density of the distribution at the current step of the MCLMC chain.
energy_change :
kinetic_change
The difference in kinetic energy between the current and previous step.
energy_change
The difference in energy between the current and previous step.
"""

Expand Down Expand Up @@ -68,9 +68,9 @@ def build_kernel(logdensity_fn, integrator, transform):
transform
Value of the difference in energy above which we consider that the transition is divergent.
L
the momentum decoherence rate
the momentum decoherence rate.
step_size
step size of the integrator
step size of the integrator.

Returns
-------
Expand Down Expand Up @@ -136,8 +136,8 @@ class mclmc:

.. code::

step = jax.jit(mclmc.step)
new_state, info = step(rng_key, state)
step = jax.jit(mclmc.step)
new_state, info = step(rng_key, state)

Parameters
----------
Expand All @@ -146,11 +146,11 @@ class mclmc:
transform
A function to perform on the samples drawn from the target distribution
L
the momentum decoherence rate
the momentum decoherence rate
step_size
step size of the integrator
step size of the integrator
integrator
an integrator. We recommend using the default here.
an integrator. We recommend using the default here.

Returns
-------
Expand Down Expand Up @@ -185,13 +185,13 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L):

Parameters
----------
rng_key:
rng_key
The pseudo-random number generator key used to generate random numbers.
momentum:
momentum
PyTree that the structure the output should to match.
step_size:
step_size
Step size
L:
L
controls rate of momentum change

Returns
Expand Down
21 changes: 11 additions & 10 deletions blackjax/smc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,18 @@ class SMCState(NamedTuple):

Particles must be a ArrayTree, each leave represents a variable from the posterior,
being an array of size `(n_particles, ...)`.

Examples (three particles):
- Single univariate posterior:
[ Array([[1.], [1.2], [3.4]]) ]
- Single bivariate posterior:
[Array([[1,2], [3,4], [5,6]])]
- Two variables, each univariate:
[ Array([[1.], [1.2], [3.4]]),
Array([[50.], [51], [55]]) ]
- Two variables, first one bivariate, second one 4-variate:
[ Array([[1., 2.], [1.2, 0.5], [3.4, 50]]),
Array([[50., 51., 52., 51], [51., 52., 52. ,54.], [55., 60, 60, 70]])]
- Single univariate posterior:
[ Array([[1.], [1.2], [3.4]]) ]
- Single bivariate posterior:
[ Array([[1,2], [3,4], [5,6]]) ]
- Two variables, each univariate:
[ Array([[1.], [1.2], [3.4]]),
Array([[50.], [51], [55]]) ]
- Two variables, first one bivariate, second one 4-variate:
[ Array([[1., 2.], [1.2, 0.5], [3.4, 50]]),
Array([[50., 51., 52., 51], [51., 52., 52. ,54.], [55., 60, 60, 70]]) ]
"""

particles: ArrayTree
Expand Down
3 changes: 2 additions & 1 deletion blackjax/smc/tuning/from_particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def mass_matrix_from_particles(particles) -> Array:
Implements tuning from section 3.1 from https://arxiv.org/pdf/1808.07730.pdf
Computing a mass matrix to be used in HMC from particles.
Given the particles covariance matrix, set all non-diagonal elements as zero,
take the inverse, and keep the diagonal.
take the inverse, and keep the diagonal.

Returns
-------
A mass Matrix
Expand Down
2 changes: 1 addition & 1 deletion requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jax>=0.4.16
jaxlib>=0.4.16
jaxopt
jupytext
myst_nb>=1.0.0rc0
myst_nb>=1.0.0
numba
numpyro
optax
Expand Down