Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Typing to follow Jax best practice #543

Merged
merged 4 commits into from
Jun 12, 2023
Merged

Change Typing to follow Jax best practice #543

merged 4 commits into from
Jun 12, 2023

Conversation

junpenglao
Copy link
Member

Update typing in the library following current best practice.

General guideline:

  • ArrayLike and ArrayLikeTree to annotate function input,
  • Array and ArrayTree to annotate function output.

Leaves of a Pytree definition in the library are in principle annotated as Array, as they are mostly internal representation. For example:

class WelfordAlgorithmState(NamedTuple):
    mean: Array
    ...

Something I am not happy about is the way logdensity is annotated. While they are Array (as they are in most cases should be output of a Jax function), we annotate them as float to empathizes they should be scalar. I am leaving them as float for now until we introduce shape annotation.

…/jax.readthedocs.io/en/latest/jax.typing.html).

General guideline:
- `ArrayLike` and `ArrayLikeTree` to annotate function input,
- `Array` and `ArrayTree` to annotate function output.

Leaves of a Pytree definition in the library are in principle annotated as `Array`, as they are mostly internal representation. For example:
```
class WelfordAlgorithmState(NamedTuple):
    mean: Array
    ...
```

Something I am not happy about is the way `logdensity` is annotated. While they are `Array` (as they are in most cases should be output of a Jax function), we annotate them as `float` to empathizes they should be scalar. I am leaving them as `float` for now until we introduce shape annotation.
This reverts commit 3e4923a.
@codecov
Copy link

codecov bot commented Jun 12, 2023

Codecov Report

Merging #543 (5a2eba4) into main (9c1b740) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##             main     #543   +/-   ##
=======================================
  Coverage   99.14%   99.14%           
=======================================
  Files          49       49           
  Lines        2094     2099    +5     
=======================================
+ Hits         2076     2081    +5     
  Misses         18       18           
Impacted Files Coverage Δ
blackjax/mcmc/marginal_latent_gaussian.py 100.00% <ø> (ø)
blackjax/smc/solver.py 100.00% <ø> (ø)
blackjax/adaptation/base.py 100.00% <100.00%> (ø)
blackjax/adaptation/mass_matrix.py 100.00% <100.00%> (ø)
blackjax/adaptation/meads_adaptation.py 100.00% <100.00%> (ø)
blackjax/adaptation/pathfinder_adaptation.py 100.00% <100.00%> (ø)
blackjax/adaptation/step_size.py 100.00% <100.00%> (ø)
blackjax/adaptation/window_adaptation.py 100.00% <100.00%> (ø)
blackjax/base.py 100.00% <100.00%> (ø)
blackjax/diagnostics.py 100.00% <100.00%> (ø)
... and 28 more

@albcab albcab merged commit a90ca8e into main Jun 12, 2023
@albcab albcab deleted the typing branch June 12, 2023 17:46
junpenglao added a commit that referenced this pull request Mar 12, 2024
Update typing in the library following current best practice (https://jax.readthedocs.io/en/latest/jax.typing.html).

General guideline:
- `ArrayLike` and `ArrayLikeTree` to annotate function input,
- `Array` and `ArrayTree` to annotate function output.

Leaves of a Pytree definition in the library are in principle annotated as `Array`, as they are mostly internal representation. For example:
```
class WelfordAlgorithmState(NamedTuple):
    mean: Array
    ...
```

Something I am not happy about is the way `logdensity` is annotated. While they are `Array` (as they are in most cases should be output of a Jax function), we annotate them as `float` to empathizes they should be scalar. I am leaving them as `float` for now until we introduce shape annotation.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants