Fix/dv3 cont dist (#224)
* Scaled normal as in the paper

* Fix: shift the std before the sigmoid

* init_std=2.0 as in Hafner for continuous actions

* fix: continuous action clip

* Compute action_clip once

* Removed unused parameter

---------

Co-authored-by: Michele Milesi <michele.milesi@studio.unibo.it>
belerico and michele-milesi authored Mar 30, 2024
1 parent 5bc7d78 commit 5e83d51
Showing 5 changed files with 34 additions and 22 deletions.
37 changes: 25 additions & 12 deletions sheeprl/algos/dreamer_v3/agent.py
@@ -25,7 +25,6 @@
from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state
from sheeprl.algos.dreamer_v3.utils import init_weights, uniform_init_weights
from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell, MultiDecoder, MultiEncoder
-from sheeprl.utils.distribution import TruncatedNormal
from sheeprl.utils.fabric import get_single_device_fabric
from sheeprl.utils.model import LayerNormChannelLast, ModuleType, cnn_forward
from sheeprl.utils.utils import symlog
@@ -680,10 +679,12 @@ class Actor(nn.Module):
The number of actions if continuous, the dimension of the action if discrete.
is_continuous (bool): whether or not the actions are continuous.
distribution_cfg (Dict[str, Any]): The configs of the distributions.
-init_std (float): the amount to sum to the input of the softplus function for the standard deviation.
+init_std (float): the amount added to the standard deviation before the sigmoid is applied.
Default to 0.0.
min_std (float): the minimum standard deviation for the actions.
-Default to 0.1.
+Default to 1.0.
+max_std (float): the maximum standard deviation for the actions.
+Default to 1.0.
dense_units (int): the dimension of the hidden dense layers.
Default to 1024.
activation (nn.Module): the activation function to apply after the dense layers.
@@ -697,6 +698,8 @@ class Actor(nn.Module):
then `p = (1 - self.unimix) * p + self.unimix * unif`,
where `unif = 1 / self.discrete`.
Defaults to 0.01.
+action_clip (float): the maximum absolute value at which continuous actions are clipped.
+Default to 1.0.
"""

def __init__(
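
To make the new parametrization concrete, here is a minimal, self-contained sketch of the std squashing that `init_std`, `min_std`, and `max_std` feed, mirroring the `scaled_normal` branch of `forward` further down. The helper name is ours, and the worked numbers assume the defaults shipped in `dreamer_v3.yaml`:

```python
import torch

def scaled_std(raw_std: torch.Tensor,
               init_std: float = 2.0,
               min_std: float = 0.1,
               max_std: float = 1.0) -> torch.Tensor:
    # Shift the raw std head output by init_std, squash it with a sigmoid,
    # and rescale it into the open interval (min_std, max_std).
    return (max_std - min_std) * torch.sigmoid(raw_std + init_std) + min_std

# At initialization the std head outputs roughly zero, so:
# sigmoid(0 + 2.0) ~= 0.881  ->  std ~= 0.9 * 0.881 + 0.1 ~= 0.89
print(scaled_std(torch.tensor(0.0)))  # tensor(0.8927)
```
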
@@ -706,26 +709,28 @@ def __init__(
is_continuous: bool,
distribution_cfg: Dict[str, Any],
init_std: float = 0.0,
-min_std: float = 0.1,
+min_std: float = 1.0,
+max_std: float = 1.0,
dense_units: int = 1024,
activation: nn.Module = nn.SiLU,
mlp_layers: int = 5,
layer_norm: bool = True,
unimix: float = 0.01,
+action_clip: float = 1.0,
) -> None:
super().__init__()
self.distribution_cfg = distribution_cfg
self.distribution = distribution_cfg.get("type", "auto").lower()
-if self.distribution not in ("auto", "normal", "tanh_normal", "discrete", "trunc_normal"):
+if self.distribution not in ("auto", "normal", "tanh_normal", "discrete", "scaled_normal"):
raise ValueError(
-"The distribution must be one of: `auto`, `discrete`, `normal`, `tanh_normal` and `trunc_normal`. "
+"The distribution must be one of: `auto`, `discrete`, `normal`, `tanh_normal` and `scaled_normal`. "
f"Found: {self.distribution}"
)
if self.distribution == "discrete" and is_continuous:
raise ValueError("You have chosen a discrete distribution but `is_continuous` is true")
if self.distribution == "auto":
if is_continuous:
self.distribution = "trunc_normal"
self.distribution = "scaled_normal"
else:
self.distribution = "discrete"
self.model = MLP(
@@ -746,9 +751,11 @@ def __init__(
self.mlp_heads = nn.ModuleList([nn.Linear(dense_units, action_dim) for action_dim in actions_dim])
self.actions_dim = actions_dim
self.is_continuous = is_continuous
-self.init_std = torch.tensor(init_std)
+self.init_std = init_std
self.min_std = min_std
+self.max_std = max_std
self._unimix = unimix
+self._action_clip = action_clip

def forward(
self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None
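
As an aside, the `unimix` mixing described in the docstring above can be sketched standalone; this is a hypothetical helper illustrating the formula, not the exact discrete-head wiring:

```python
import torch

def unimix_probs(logits: torch.Tensor, unimix: float = 0.01) -> torch.Tensor:
    # p = (1 - unimix) * softmax(logits) + unimix * uniform:
    # every action keeps at least unimix / num_actions probability mass.
    probs = torch.softmax(logits, dim=-1)
    uniform = torch.ones_like(probs) / probs.shape[-1]
    return (1.0 - unimix) * probs + unimix * uniform

print(unimix_probs(torch.tensor([10.0, 0.0, 0.0])))
# ~ tensor([0.9932, 0.0034, 0.0034]) — no probability collapses to zero
```
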
@@ -780,16 +787,19 @@ def forward(
elif self.distribution == "normal":
actions_dist = Normal(mean, std)
actions_dist = Independent(actions_dist, 1)
elif self.distribution == "trunc_normal":
std = 2 * torch.sigmoid((std + self.init_std) / 2) + self.min_std
dist = TruncatedNormal(torch.tanh(mean), std, -1, 1)
elif self.distribution == "scaled_normal":
std = (self.max_std - self.min_std) * torch.sigmoid(std + self.init_std) + self.min_std
dist = Normal(torch.tanh(mean), std)
actions_dist = Independent(dist, 1)
if sample_actions:
actions = actions_dist.rsample()
else:
sample = actions_dist.sample((100,))
log_prob = actions_dist.log_prob(sample)
actions = sample[log_prob.argmax(0)].view(1, 1, -1)
+if self._action_clip > 0.0:
+action_clip = torch.full_like(actions, self._action_clip)
+actions = actions * (action_clip / torch.maximum(action_clip, torch.abs(actions))).detach()
actions = [actions]
actions_dist = [actions_dist]
else:
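
The new clipping block rescales any out-of-range sample onto the `[-action_clip, action_clip]` boundary, and the `.detach()` hides the rescaling from the backward pass, so gradients flow as if the action were unclipped. A standalone sketch with a hypothetical helper name:

```python
import torch

def clip_actions(actions: torch.Tensor, action_clip: float = 1.0) -> torch.Tensor:
    # Rescale any |action| > action_clip back onto the clip boundary.
    # The scale factor is detached, so backward() sees the raw sample.
    if action_clip <= 0.0:
        return actions
    clip = torch.full_like(actions, action_clip)
    return actions * (clip / torch.maximum(clip, actions.abs())).detach()

sample = torch.tensor([0.5, -1.7, 2.0], requires_grad=True)
print(clip_actions(sample))
# tensor([ 0.5000, -1.0000,  1.0000], grad_fn=<MulBackward0>)
```

Note also that the `else` branch above estimates a deterministic action by drawing 100 samples and keeping the one with the highest log-probability.
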
@@ -826,6 +836,7 @@ def __init__(
mlp_layers: int = 5,
layer_norm: bool = True,
unimix: float = 0.01,
+action_clip: float = 1.0,
) -> None:
super().__init__(
latent_state_size=latent_state_size,
@@ -839,6 +850,7 @@
mlp_layers=mlp_layers,
layer_norm=layer_norm,
unimix=unimix,
+action_clip=action_clip,
)

def forward(
@@ -1093,7 +1105,7 @@ def build_agent(
continue_model.apply(init_weights),
)
actor_cls = hydra.utils.get_class(cfg.algo.actor.cls)
-actor: nn.Module = actor_cls(
+actor: Actor | MinedojoActor = actor_cls(
latent_state_size=latent_state_size,
actions_dim=actions_dim,
is_continuous=is_continuous,
@@ -1105,6 +1117,7 @@
distribution_cfg=cfg.distribution,
layer_norm=actor_cfg.layer_norm,
unimix=cfg.algo.unimix,
+action_clip=actor_cfg.action_clip,
)
critic = MLP(
input_dims=latent_state_size,
3 changes: 0 additions & 3 deletions sheeprl/algos/dreamer_v3/loss.py
@@ -20,7 +20,6 @@ def reconstruction_loss(
pc: Optional[Distribution] = None,
continue_targets: Optional[Tensor] = None,
continue_scale_factor: float = 1.0,
-validate_args: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""
Compute the reconstruction loss as described in Eq. 5 in
@@ -49,8 +48,6 @@
Default to None.
continue_scale_factor (float): the scale factor for the continue loss.
Default to 1.0.
-validate_args (bool): Whether or not to validate distributions arguments.
-Default to False.
Returns:
observation_loss (Tensor): the value of the observation loss.
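
For orientation, a toy illustration of the negative log-likelihood terms `reconstruction_loss` computes; the distributions and shapes below are stand-ins, not sheeprl's actual decoder heads:

```python
import torch
from torch.distributions import Bernoulli, Independent, Normal

# Toy stand-ins for the decoder heads: po (observations), pr (rewards),
# pc (continues). Batch of 8, flat 64-dimensional observations.
po = Independent(Normal(torch.zeros(8, 64), 1.0), 1)
pr = Normal(torch.zeros(8), 1.0)
pc = Bernoulli(logits=torch.zeros(8))

observation_loss = -po.log_prob(torch.zeros(8, 64)).mean()
reward_loss = -pr.log_prob(torch.zeros(8)).mean()
continue_loss = -pc.log_prob(torch.ones(8)).mean()  # scaled by continue_scale_factor
```
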
6 changes: 4 additions & 2 deletions sheeprl/configs/algo/dreamer_v3.yaml
@@ -108,13 +108,15 @@ actor:
cls: sheeprl.algos.dreamer_v3.agent.Actor
ent_coef: 3e-4
min_std: 0.1
-init_std: 0.0
-objective_mix: 1.0
+max_std: 1.0
+init_std: 2.0
dense_act: ${algo.dense_act}
mlp_layers: ${algo.mlp_layers}
layer_norm: ${algo.layer_norm}
dense_units: ${algo.dense_units}
clip_gradients: 100.0
unimix: ${algo.unimix}
+action_clip: 1.0

# Distributed percentile model (used to scale the values)
moments:
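
A side note on the `${...}` entries above: they are OmegaConf interpolations, resolved against the enclosing config tree at access time. A miniature, hypothetical config showing the mechanism (not the real sheeprl tree):

```python
from omegaconf import OmegaConf

# Miniature stand-in config: the ${algo.*} entries resolve against the
# enclosing tree, exactly as in dreamer_v3.yaml.
cfg = OmegaConf.create(
    """
    algo:
      dense_units: 1024
      unimix: 0.01
      actor:
        min_std: 0.1
        max_std: 1.0
        init_std: 2.0
        action_clip: 1.0
        dense_units: ${algo.dense_units}
        unimix: ${algo.unimix}
    """
)
print(cfg.algo.actor.dense_units)  # 1024
print(cfg.algo.actor.unimix)       # 0.01
```
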
@@ -19,21 +19,20 @@ env:
task_name: swingup_sparse
from_vectors: False
from_pixels: True
-seed: ${seed}

# Checkpoint
checkpoint:
every: 10000

# Buffer
buffer:
-size: 100000
+size: 1_000_000
checkpoint: True
memmap: True

# Algorithm
algo:
-total_steps: 1000000
+total_steps: 500_000
cnn_keys:
encoder:
- rgb
5 changes: 3 additions & 2 deletions sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml
@@ -12,6 +12,7 @@ seed: 5
# Environment
env:
num_envs: 4
+action_repeat: 2
max_episode_steps: -1
wrapper:
domain_name: walker
@@ -25,13 +26,13 @@

# Buffer
buffer:
-size: 100000
+size: 1_000_000
checkpoint: True
memmap: True

# Algorithm
algo:
-total_steps: 1000000
+total_steps: 500_000
cnn_keys:
encoder:
- rgb
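
A hedged sanity check on the combined change above: assuming `total_steps` counts policy decisions, that the previous config effectively ran with `action_repeat: 1`, and that each decision is now repeated twice in the environment, the environment-frame budget stays at one million:

```python
# Assumption: total_steps counts agent decisions; each decision is
# repeated action_repeat times in the environment.
action_repeat = 2
total_steps = 500_000
print(total_steps * action_repeat)  # 1_000_000 env frames, as before
```
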
