diff --git a/blackjax/types.py b/blackjax/types.py index 9fc697267..5f02bc661 100644 --- a/blackjax/types.py +++ b/blackjax/types.py @@ -1,6 +1,6 @@ from typing import Any, Iterable, Mapping, Union -import jax._src.prng as prng +import jax import jax.numpy as jnp import numpy as np @@ -8,10 +8,7 @@ Array = Union[np.ndarray, jnp.ndarray] #: JAX PyTrees -PyTree = Union[Array, Iterable[Array], Mapping[Any, Array]] -# It is not currently tested but we also support recursive PyTrees. -# Once recursive typing is fully supported (https://github.com/python/mypy/issues/731), we can uncomment the line below. -# PyTree = Union[Array, Iterable["PyTree"], Mapping[Any, "PyTree"]] +PyTree = Union[Array, Iterable["PyTree"], Mapping[Any, "PyTree"]] #: JAX PRNGKey -PRNGKey = prng.PRNGKeyArray +PRNGKey = jax.random.PRNGKeyArray