From 89038623b25eb430a74a044419ea929131ed9868 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Sat, 9 May 2020 22:44:35 +0100 Subject: [PATCH 01/22] multicategorical dist and test --- stable_baselines3/common/distributions.py | 53 +++++++++++++++++++++++ tests/test_distributions.py | 17 +++++++- 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index fc601efd6..bf4ef68b4 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -292,6 +292,59 @@ def log_prob(self, actions: th.Tensor) -> th.Tensor: return self.distribution.log_prob(actions) +class MultiCategoricalDistribution(Distribution): + """ + MultiCategorical distribution for multi discrete actions. + + :param action_dims: ([int]) List of sizes of discrete action spaces. + """ + def __init__(self, action_dims: [int]): + super(MultiCategoricalDistribution, self).__init__() + self.action_dims = action_dims + self.distributions = None + + def proba_distribution_net(self, latent_dim: int) -> nn.Module: + """ + Create the layer that represents the distribution: + it will be the logits of the Categorical distribution. + You can then get probabilities using a softmax. + + :param latent_dim: (int) Dimension of the last layer + of the policy network (before the action layer) + :return: (nn.Linear) + """ + action_logits = nn.Linear(latent_dim, sum(self.action_dims)) + return action_logits + + def proba_distribution(self, action_logits: th.Tensor) -> 'MultiCategoricalDistribution': + reshaped_logits = th.split(action_logits, self.action_dims, dim=1) + self.distributions = [Categorical(logits=l) for l in reshaped_logits] + return self + + def mode(self) -> th.Tensor: + return th.stack([th.argmax(d.probs, dim=1) for d in self.distributions]) + + def sample(self) -> th.Tensor: + return th.stack([d.sample() for d in self.distributions]) + + def entropy(self) -> th.Tensor: + return sum([d.entropy() for d in self.distributions]) + + def actions_from_params(self, action_logits: th.Tensor, + deterministic: bool = False) -> th.Tensor: + # Update the proba distribution + self.proba_distribution(action_logits) + return self.get_actions(deterministic=deterministic) + + def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + actions = self.actions_from_params(action_logits) + log_prob = self.log_prob(actions) + return actions, log_prob + + def log_prob(self, actions: th.Tensor) -> th.Tensor: + return sum(d.log_prob(x) for d, x in zip(self.distributions, th.unbind(actions))) + + class StateDependentNoiseDistribution(Distribution): """ Distribution class for using State Dependent Exploration (SDE). diff --git a/tests/test_distributions.py b/tests/test_distributions.py index e9041f8dc..7c26d2f13 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -4,7 +4,9 @@ from stable_baselines3 import A2C, PPO from stable_baselines3.common.distributions import (DiagGaussianDistribution, TanhBijector, StateDependentNoiseDistribution, - CategoricalDistribution, SquashedDiagGaussianDistribution) + CategoricalDistribution, + MultiCategoricalDistribution, + SquashedDiagGaussianDistribution) from stable_baselines3.common.utils import set_random_seed @@ -12,6 +14,8 @@ N_FEATURES = 3 N_SAMPLES = int(5e6) +N_ACTIONS_MULTI = [4,3,2] + def test_bijector(): """ @@ -96,3 +100,14 @@ def test_categorical(): entropy = dist.entropy() log_prob = dist.log_prob(actions) assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=1e-4) + +def test_multicategorical(): + dist = MultiCategoricalDistribution(N_ACTIONS_MULTI) + set_random_seed(1) + action_logits = th.rand(N_SAMPLES, sum(N_ACTIONS_MULTI)) + dist = dist.proba_distribution(action_logits) + + actions = dist.get_actions() + entropy = dist.entropy() + log_prob = dist.log_prob(actions) + assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=1e-4) \ No newline at end of file From edef83bbafa935a2e0fed849840d61d463c37f19 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Sun, 10 May 2020 13:57:23 +0100 Subject: [PATCH 02/22] fixed List annotation --- stable_baselines3/common/distributions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index bf4ef68b4..f4d5076ba 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Dict, Any +from typing import Optional, Tuple, Dict, Any, List import gym import torch as th @@ -298,7 +298,7 @@ class MultiCategoricalDistribution(Distribution): :param action_dims: ([int]) List of sizes of discrete action spaces. """ - def __init__(self, action_dims: [int]): + def __init__(self, action_dims: List[int]): super(MultiCategoricalDistribution, self).__init__() self.action_dims = action_dims self.distributions = None From f25b9f34ddc340033f7b4c80c5cc955319829df9 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Sun, 10 May 2020 14:48:56 +0100 Subject: [PATCH 03/22] bernoulli dist and test --- stable_baselines3/common/distributions.py | 65 ++++++++++++++++++++--- tests/test_distributions.py | 16 +++++- 2 files changed, 73 insertions(+), 8 deletions(-) diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index f4d5076ba..522a190c2 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -3,7 +3,7 @@ import gym import torch as th import torch.nn as nn -from torch.distributions import Normal, Categorical +from torch.distributions import Normal, Categorical, Bernoulli from gym import spaces from stable_baselines3.common.preprocessing import get_action_dim @@ -306,7 +306,7 @@ def __init__(self, action_dims: List[int]): def proba_distribution_net(self, latent_dim: int) -> nn.Module: """ Create the layer that represents the distribution: - it will be the logits of the Categorical distribution. + it will be the logits of the MultiCategorical distribution. You can then get probabilities using a softmax. :param latent_dim: (int) Dimension of the last layer @@ -345,6 +345,59 @@ def log_prob(self, actions: th.Tensor) -> th.Tensor: return sum(d.log_prob(x) for d, x in zip(self.distributions, th.unbind(actions))) +class BernoulliDistribution(Distribution): + """ + Bernoulli distribution for MultiBinary action spaces. + + :param action_dim: (int) Number of binary action dimensions. + """ + + def __init__(self, action_dims: int): + super(BernoulliDistribution, self).__init__() + self.distribution = None + self.action_dims = action_dims + + def proba_distribution_net(self, latent_dim: int) -> nn.Module: + """ + Create the layer that represents the distribution: + it will be the logits of the Bernoulli distribution. + You can then get probabilities using a softmax. + + :param latent_dim: (int) Dimension of the last layer + of the policy network (before the action layer) + :return: (nn.Linear) + """ + action_logits = nn.Linear(latent_dim, self.action_dims) + return action_logits + + def proba_distribution(self, action_logits: th.Tensor) -> 'BernoulliDistribution': + self.distribution = Bernoulli(logits=action_logits) + return self + + def mode(self) -> th.Tensor: + return th.argmax(self.distribution.probs, dim=1) + + def sample(self) -> th.Tensor: + return self.distribution.sample() + + def entropy(self) -> th.Tensor: + return self.distribution.entropy() + + def actions_from_params(self, action_logits: th.Tensor, + deterministic: bool = False) -> th.Tensor: + # Update the proba distribution + self.proba_distribution(action_logits) + return self.get_actions(deterministic=deterministic) + + def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + actions = self.actions_from_params(action_logits) + log_prob = self.log_prob(actions) + return actions, log_prob + + def log_prob(self, actions: th.Tensor) -> th.Tensor: + return self.distribution.log_prob(actions) + + class StateDependentNoiseDistribution(Distribution): """ Distribution class for using State Dependent Exploration (SDE). @@ -602,10 +655,10 @@ def make_proba_distribution(action_space: gym.spaces.Space, return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs) elif isinstance(action_space, spaces.Discrete): return CategoricalDistribution(action_space.n, **dist_kwargs) - # elif isinstance(action_space, spaces.MultiDiscrete): - # return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs) - # elif isinstance(action_space, spaces.MultiBinary): - # return BernoulliDistribution(action_space.n, **dist_kwargs) + elif isinstance(action_space, spaces.MultiDiscrete): + return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs) + elif isinstance(action_space, spaces.MultiBinary): + return BernoulliDistribution(action_space.n, **dist_kwargs) else: raise NotImplementedError("Error: probability distribution, not implemented for action space" f"of type {type(action_space)}." diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 7c26d2f13..b657980ef 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -3,10 +3,11 @@ from stable_baselines3 import A2C, PPO from stable_baselines3.common.distributions import (DiagGaussianDistribution, TanhBijector, - StateDependentNoiseDistribution, + SquashedDiagGaussianDistribution, CategoricalDistribution, MultiCategoricalDistribution, - SquashedDiagGaussianDistribution) + BernoulliDistribution, + StateDependentNoiseDistribution) from stable_baselines3.common.utils import set_random_seed @@ -107,6 +108,17 @@ def test_multicategorical(): action_logits = th.rand(N_SAMPLES, sum(N_ACTIONS_MULTI)) dist = dist.proba_distribution(action_logits) + actions = dist.get_actions() + entropy = dist.entropy() + log_prob = dist.log_prob(actions) + assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=1e-4) + +def test_bernoulli(): + dist = BernoulliDistribution(N_ACTIONS) + set_random_seed(1) + action_logits = th.rand(N_SAMPLES, N_ACTIONS) + dist = dist.proba_distribution(action_logits) + actions = dist.get_actions() entropy = dist.entropy() log_prob = dist.log_prob(actions) From 7f8dc906a97359b1fa25208df3e7001e2150e062 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Sun, 10 May 2020 16:59:27 +0100 Subject: [PATCH 04/22] added distributions to preprocessing (needs testing) --- stable_baselines3/common/preprocessing.py | 25 +++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 434355dc6..460561682 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -44,7 +44,7 @@ def is_image_space(observation_space: spaces.Space, return n_channels in [1, 3, 4] return False - +#TODO: test def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, normalize_images: bool = True) -> th.Tensor: """ @@ -65,11 +65,17 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, elif isinstance(observation_space, spaces.Discrete): # One hot encoding and convert to float to avoid errors return F.one_hot(obs.long(), num_classes=observation_space.n).float() + elif isinstance(observation_space, spaces.MultiDiscrete): + return th.cat([ + F.one_hot(split.long(), num_classes=len(split)).float() + for split in enumerate( + th.split(obs, observation_space.nvec, dim=1))], dim=1) + elif isinstance(observation_space, spaces.MultiBinary): + return obs.float() else: - # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict raise NotImplementedError() - +#TODO: test def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: """ Get the shape of the observation (useful for the buffers). @@ -81,9 +87,12 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: return observation_space.shape elif isinstance(observation_space, spaces.Discrete): # Observation is an int - return 1, + return 1 + elif isinstance(observation_space, spaces.MultiDiscrete): + return observation_space.shape + elif isinstance(observation_space, spaces.MultiBinary): + return observation_space.shape else: - # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict raise NotImplementedError() @@ -98,7 +107,7 @@ def get_flattened_obs_dim(observation_space: spaces.Space) -> int: # Use Gym internal method return spaces.utils.flatdim(observation_space) - +#TODO: test def get_action_dim(action_space: spaces.Space) -> int: """ Get the dimension of the action space. @@ -111,5 +120,9 @@ def get_action_dim(action_space: spaces.Space) -> int: elif isinstance(action_space, spaces.Discrete): # Action is an int return 1 + elif isinstance(action_space, spaces.MultiDiscrete): + return int(len(action_space.shape)) + elif isinstance(action_space, spaces.MultiBinary): + return int(len(action_space.shape)) else: raise NotImplementedError() From 06276a7f566be3c6f94383063930cd60c9c547f4 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Tue, 12 May 2020 20:46:07 +0100 Subject: [PATCH 05/22] fixed and tested distributions --- stable_baselines3/common/distributions.py | 49 ++++++---- stable_baselines3/common/policies.py | 110 +++++++++++++--------- stable_baselines3/common/preprocessing.py | 32 ++++--- tests/test_distributions.py | 51 +++++----- 4 files changed, 139 insertions(+), 103 deletions(-) diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 522a190c2..74bc57100 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -5,6 +5,7 @@ import torch.nn as nn from torch.distributions import Normal, Categorical, Bernoulli from gym import spaces +import numpy as np from stable_baselines3.common.preprocessing import get_action_dim @@ -88,7 +89,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: :return: (th.Tensor) shape: (n_batch,) """ if len(tensor.shape) > 1: - tensor = tensor.sum(axis=1) + tensor = tensor.sum(dim=1) else: tensor = tensor.sum() return tensor @@ -122,7 +123,8 @@ def proba_distribution_net(self, latent_dim: int, """ mean_actions = nn.Linear(latent_dim, self.action_dim) # TODO: allow action dependent std - log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True) + log_std = nn.Parameter(th.ones(self.action_dim) + * log_std_init, requires_grad=True) return mean_actions, log_std def proba_distribution(self, mean_actions: th.Tensor, @@ -198,7 +200,8 @@ def __init__(self, action_dim: int, epsilon: float = 1e-6): def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> 'SquashedDiagGaussianDistribution': - super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std) + super(SquashedDiagGaussianDistribution, + self).proba_distribution(mean_actions, log_std) return self def mode(self) -> th.Tensor: @@ -232,7 +235,8 @@ def log_prob(self, actions: th.Tensor, gaussian_actions = TanhBijector.inverse(actions) # Log likelihood for a Gaussian distribution - log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_actions) + log_prob = super(SquashedDiagGaussianDistribution, + self).log_prob(gaussian_actions) # Squash correction (from original SAC implementation) # this comes from the fact that tanh is bijective and differentiable log_prob -= th.sum(th.log(1 - actions ** 2 + self.epsilon), dim=1) @@ -296,8 +300,9 @@ class MultiCategoricalDistribution(Distribution): """ MultiCategorical distribution for multi discrete actions. - :param action_dims: ([int]) List of sizes of discrete action spaces. + :param action_dims: ([int]) List of sizes of discrete action spaces """ + def __init__(self, action_dims: List[int]): super(MultiCategoricalDistribution, self).__init__() self.action_dims = action_dims @@ -306,18 +311,18 @@ def __init__(self, action_dims: List[int]): def proba_distribution_net(self, latent_dim: int) -> nn.Module: """ Create the layer that represents the distribution: - it will be the logits of the MultiCategorical distribution. - You can then get probabilities using a softmax. + it will be the logits (flattend) of the MultiCategorical distribution. + You can then get probabilities using a softmax on each sub-space. :param latent_dim: (int) Dimension of the last layer of the policy network (before the action layer) :return: (nn.Linear) """ - action_logits = nn.Linear(latent_dim, sum(self.action_dims)) - return action_logits + action_logits = nn.Linear(latent_dim, np.sum(self.action_dims)) + return action_logits def proba_distribution(self, action_logits: th.Tensor) -> 'MultiCategoricalDistribution': - reshaped_logits = th.split(action_logits, self.action_dims, dim=1) + reshaped_logits = th.split(action_logits, tuple(self.action_dims), dim=1) self.distributions = [Categorical(logits=l) for l in reshaped_logits] return self @@ -328,28 +333,30 @@ def sample(self) -> th.Tensor: return th.stack([d.sample() for d in self.distributions]) def entropy(self) -> th.Tensor: - return sum([d.entropy() for d in self.distributions]) + return th.stack([d.entropy() for d in self.distributions], dim=1).sum(dim=1) def actions_from_params(self, action_logits: th.Tensor, - deterministic: bool = False) -> th.Tensor: + deterministic: bool = True) -> th.Tensor: # Update the proba distribution self.proba_distribution(action_logits) return self.get_actions(deterministic=deterministic) def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: actions = self.actions_from_params(action_logits) + log_prob = self.log_prob(actions) return actions, log_prob def log_prob(self, actions: th.Tensor) -> th.Tensor: - return sum(d.log_prob(x) for d, x in zip(self.distributions, th.unbind(actions))) - + return th.stack([d.log_prob(x) for d, x in zip(self.distributions, + th.unbind(actions))], dim=1).sum(dim=1) + class BernoulliDistribution(Distribution): """ Bernoulli distribution for MultiBinary action spaces. - :param action_dim: (int) Number of binary action dimensions. + :param action_dim: (int) Number of binary actions """ def __init__(self, action_dims: int): @@ -361,7 +368,6 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module: """ Create the layer that represents the distribution: it will be the logits of the Bernoulli distribution. - You can then get probabilities using a softmax. :param latent_dim: (int) Dimension of the last layer of the policy network (before the action layer) @@ -396,7 +402,7 @@ def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th. def log_prob(self, actions: th.Tensor) -> th.Tensor: return self.distribution.log_prob(actions) - + class StateDependentNoiseDistribution(Distribution): """ @@ -502,7 +508,8 @@ def proba_distribution_net(self, latent_dim: int, log_std_init: float = -2.0, # can be different between the policy and the noise network self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim # Reduce the number of parameters if needed - log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones(self.latent_sde_dim, 1) + log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones( + self.latent_sde_dim, 1) # Transform it to a parameter so it can be optimized log_std = nn.Parameter(log_std * log_std_init, requires_grad=True) # Sample an exploration matrix @@ -523,7 +530,8 @@ def proba_distribution(self, mean_actions: th.Tensor, # Stop gradient if we don't want to influence the features self._latent_sde = latent_sde if self.learn_features else latent_sde.detach() variance = th.mm(latent_sde ** 2, self.get_std(log_std) ** 2) - self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon)) + self.distribution = Normal( + mean_actions, th.sqrt(variance + self.epsilon)) return self def mode(self) -> th.Tensor: @@ -649,7 +657,8 @@ def make_proba_distribution(action_space: gym.spaces.Space, dist_kwargs = {} if isinstance(action_space, spaces.Box): - assert len(action_space.shape) == 1, "Error: the action space must be a vector" + assert len( + action_space.shape) == 1, "Error: the action space must be a vector" if use_sde: return StateDependentNoiseDistribution(get_action_dim(action_space), **dist_kwargs) return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index aa375aa61..1394eda87 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -43,7 +43,8 @@ class FlattenExtractor(BaseFeaturesExtractor): """ def __init__(self, observation_space: gym.Space): - super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space)) + super(FlattenExtractor, self).__init__( + observation_space, get_flattened_obs_dim(observation_space)) self.flatten = nn.Flatten() def forward(self, observations: th.Tensor) -> th.Tensor: @@ -70,17 +71,21 @@ def __init__(self, observation_space: gym.spaces.Box, n_input_channels = observation_space.shape[0] self.cnn = nn.Sequential(nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0), nn.ReLU(), - nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), + nn.Conv2d(32, 64, kernel_size=4, + stride=2, padding=0), nn.ReLU(), - nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0), + nn.Conv2d(64, 32, kernel_size=3, + stride=1, padding=0), nn.ReLU(), nn.Flatten()) # Compute shape by doing one forward pass with th.no_grad(): - n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1] + n_flatten = self.cnn(th.as_tensor( + observation_space.sample()[None]).float()).shape[1] - self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU()) + self.linear = nn.Sequential( + nn.Linear(n_flatten, features_dim), nn.ReLU()) def forward(self, observations: th.Tensor) -> th.Tensor: return self.linear(self.cnn(observations)) @@ -148,7 +153,8 @@ def extract_features(self, obs: th.Tensor) -> th.Tensor: :return: (th.Tensor) """ assert self.features_extractor is not None, 'No feature extractor was set' - preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) + preprocessed_obs = preprocess_obs( + obs, self.observation_space, normalize_images=self.normalize_images) return self.features_extractor(preprocessed_obs) @property @@ -216,7 +222,8 @@ def predict(self, observation: np.ndarray, or transpose_obs.shape[1:] == self.observation_space.shape): observation = transpose_obs - vectorized_env = self._is_vectorized_observation(observation, self.observation_space) + vectorized_env = self._is_vectorized_observation( + observation, self.observation_space) observation = observation.reshape((-1,) + self.observation_space.shape) @@ -233,11 +240,13 @@ def predict(self, observation: np.ndarray, clipped_actions = actions # Clip the actions to avoid out of bound error when using gaussian distribution if isinstance(self.action_space, gym.spaces.Box) and not self.squash_output: - clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + clipped_actions = np.clip( + actions, self.action_space.low, self.action_space.high) if not vectorized_env: if state is not None: - raise ValueError("Error: The environment must be vectorized when using recurrent policies.") + raise ValueError( + "Error: The environment must be vectorized when using recurrent policies.") clipped_actions = clipped_actions[0] return clipped_actions, state @@ -291,25 +300,25 @@ def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.s else: raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.") - # TODO: add support for MultiDiscrete and MultiBinary observation spaces - # elif isinstance(observation_space, gym.spaces.MultiDiscrete): - # if observation.shape == (len(observation_space.nvec),): - # return False - # elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec): - # return True - # else: - # raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) + - # "environment, please use ({},) or ".format(len(observation_space.nvec)) + - # "(n_env, {}) for the observation shape.".format(len(observation_space.nvec))) - # elif isinstance(observation_space, gym.spaces.MultiBinary): - # if observation.shape == (observation_space.n,): - # return False - # elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n: - # return True - # else: - # raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) + - # "environment, please use ({},) or ".format(observation_space.n) + - # "(n_env, {}) for the observation shape.".format(observation_space.n)) + + elif isinstance(observation_space, gym.spaces.MultiDiscrete): + if observation.shape == (len(observation_space.nvec),): + return False + elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec): + return True + else: + raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) + + "environment, please use ({},) or ".format(len(observation_space.nvec)) + + "(n_env, {}) for the observation shape.".format(len(observation_space.nvec))) + elif isinstance(observation_space, gym.spaces.MultiBinary): + if observation.shape == (observation_space.n,): + return False + elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n: + return True + else: + raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) + + "environment, please use ({},) or ".format(observation_space.n) + + "(n_env, {}) for the observation shape.".format(observation_space.n)) else: raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}." .format(observation_space)) @@ -336,7 +345,8 @@ def save(self, path: str) -> None: :param path: (str) """ - th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path) + th.save({'state_dict': self.state_dict(), + 'data': self._get_data()}, path) @classmethod def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BasePolicy': @@ -362,7 +372,8 @@ def load_from_vector(self, vector: np.ndarray): :param vector: (np.ndarray) """ - th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters()) + th.nn.utils.vector_to_parameters(th.FloatTensor( + vector).to(self.device), self.parameters()) def parameters_to_vector(self) -> np.ndarray: """ @@ -426,13 +437,16 @@ def create_sde_features_extractor(features_dim: int, # Special case: when using states as features (i.e. sde_net_arch is an empty list) # don't use any activation function sde_activation = activation_fn if len(sde_net_arch) > 0 else None - latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_output=False) - latent_sde_dim = sde_net_arch[-1] if len(sde_net_arch) > 0 else features_dim + latent_sde_net = create_mlp( + features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_output=False) + latent_sde_dim = sde_net_arch[-1] if len( + sde_net_arch) > 0 else features_dim sde_features_extractor = nn.Sequential(*latent_sde_net) return sde_features_extractor, latent_sde_dim -_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]] +# type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]] +_policy_registry = dict() def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]: @@ -444,7 +458,8 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[ :return: (Type[BasePolicy]) the policy """ if base_policy_type not in _policy_registry: - raise ValueError(f"Error: the policy type {base_policy_type} is not registered!") + raise ValueError( + f"Error: the policy type {base_policy_type} is not registered!") if name not in _policy_registry[base_policy_type]: raise ValueError(f"Error: unknown policy type {name}," "the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!") @@ -469,12 +484,14 @@ def register_policy(name: str, policy: Type[BasePolicy]) -> None: except AttributeError: sub_class = str(th.random.randint(100)) if sub_class is None: - raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!") + raise ValueError( + f"Error: the policy {policy} is not of any known subclasses of BasePolicy!") if sub_class not in _policy_registry: _policy_registry[sub_class] = {} if name in _policy_registry[sub_class]: - raise ValueError(f"Error: the name {name} is alreay registered for a different policy, will not override.") + raise ValueError( + f"Error: the name {name} is alreay registered for a different policy, will not override.") _policy_registry[sub_class][name] = policy @@ -513,8 +530,10 @@ def __init__(self, feature_dim: int, device = get_device(device) shared_net, policy_net, value_net = [], [], [] - policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network - value_only_layers = [] # Layer sizes of the network that only belongs to the value network + # Layer sizes of the network that only belongs to the policy network + policy_only_layers = [] + # Layer sizes of the network that only belongs to the value network + value_only_layers = [] last_layer_dim_shared = feature_dim # Iterate through the shared layers and build the shared parts of the network @@ -526,13 +545,16 @@ def __init__(self, feature_dim: int, shared_net.append(activation_fn()) last_layer_dim_shared = layer_size else: - assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts" + assert isinstance( + layer, dict), "Error: the net_arch list can only contain ints and dicts" if 'pi' in layer: - assert isinstance(layer['pi'], list), "Error: net_arch[-1]['pi'] must contain a list of integers." + assert isinstance( + layer['pi'], list), "Error: net_arch[-1]['pi'] must contain a list of integers." policy_only_layers = layer['pi'] if 'vf' in layer: - assert isinstance(layer['vf'], list), "Error: net_arch[-1]['vf'] must contain a list of integers." + assert isinstance( + layer['vf'], list), "Error: net_arch[-1]['vf'] must contain a list of integers." value_only_layers = layer['vf'] break # From here on the network splits up in policy and value network @@ -542,13 +564,15 @@ def __init__(self, feature_dim: int, # Build the non-shared part of the network for idx, (pi_layer_size, vf_layer_size) in enumerate(zip_longest(policy_only_layers, value_only_layers)): if pi_layer_size is not None: - assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers." + assert isinstance( + pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers." policy_net.append(nn.Linear(last_layer_dim_pi, pi_layer_size)) policy_net.append(activation_fn()) last_layer_dim_pi = pi_layer_size if vf_layer_size is not None: - assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers." + assert isinstance( + vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers." value_net.append(nn.Linear(last_layer_dim_vf, vf_layer_size)) value_net.append(activation_fn()) last_layer_dim_vf = vf_layer_size diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 460561682..3c055787d 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -44,7 +44,7 @@ def is_image_space(observation_space: spaces.Space, return n_channels in [1, 3, 4] return False -#TODO: test + def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, normalize_images: bool = True) -> th.Tensor: """ @@ -66,16 +66,16 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, # One hot encoding and convert to float to avoid errors return F.one_hot(obs.long(), num_classes=observation_space.n).float() elif isinstance(observation_space, spaces.MultiDiscrete): - return th.cat([ - F.one_hot(split.long(), num_classes=len(split)).float() - for split in enumerate( - th.split(obs, observation_space.nvec, dim=1))], dim=1) + # Tensor concatination of one hot encodings of each Categorical sub-space + x = th.cat([F.one_hot(o.long(), num_classes=n).float() + for o, n in zip(obs, observation_space.nvec)], dim=1) + return x.t() elif isinstance(observation_space, spaces.MultiBinary): return obs.float() else: raise NotImplementedError() -#TODO: test + def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: """ Get the shape of the observation (useful for the buffers). @@ -86,12 +86,14 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: if isinstance(observation_space, spaces.Box): return observation_space.shape elif isinstance(observation_space, spaces.Discrete): - # Observation is an int - return 1 + # One observation + return 1, elif isinstance(observation_space, spaces.MultiDiscrete): - return observation_space.shape + # Observation is the number of discrete spaces + return int(len(observation_space.nvec)), elif isinstance(observation_space, spaces.MultiBinary): - return observation_space.shape + # Observation is the number of binary spaces + return int(observation_space.n), else: raise NotImplementedError() @@ -107,7 +109,7 @@ def get_flattened_obs_dim(observation_space: spaces.Space) -> int: # Use Gym internal method return spaces.utils.flatdim(observation_space) -#TODO: test + def get_action_dim(action_space: spaces.Space) -> int: """ Get the dimension of the action space. @@ -118,11 +120,13 @@ def get_action_dim(action_space: spaces.Space) -> int: if isinstance(action_space, spaces.Box): return int(np.prod(action_space.shape)) elif isinstance(action_space, spaces.Discrete): - # Action is an int + # One action return 1 elif isinstance(action_space, spaces.MultiDiscrete): - return int(len(action_space.shape)) + # Action is the number of discrete spaces + return int(len(action_space.nvec)) elif isinstance(action_space, spaces.MultiBinary): - return int(len(action_space.shape)) + # Action is the number of binary spaces + return int(action_space.n) else: raise NotImplementedError() diff --git a/tests/test_distributions.py b/tests/test_distributions.py index b657980ef..1dd3d3796 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -3,11 +3,11 @@ from stable_baselines3 import A2C, PPO from stable_baselines3.common.distributions import (DiagGaussianDistribution, TanhBijector, - SquashedDiagGaussianDistribution, - CategoricalDistribution, - MultiCategoricalDistribution, - BernoulliDistribution, - StateDependentNoiseDistribution) + SquashedDiagGaussianDistribution, + CategoricalDistribution, + MultiCategoricalDistribution, + BernoulliDistribution, + StateDependentNoiseDistribution) from stable_baselines3.common.utils import set_random_seed @@ -15,8 +15,6 @@ N_FEATURES = 3 N_SAMPLES = int(5e6) -N_ACTIONS_MULTI = [4,3,2] - def test_bijector(): """ @@ -37,7 +35,8 @@ def test_squashed_gaussian(model_class): """ Test run with squashed Gaussian (notably entropy computation) """ - model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True)) + model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, + n_steps=100, policy_kwargs=dict(squash_output=True)) model.learn(500) gaussian_mean = th.rand(N_SAMPLES, N_ACTIONS) @@ -47,11 +46,13 @@ def test_squashed_gaussian(model_class): actions = dist.get_actions() assert th.max(th.abs(actions)) <= 1.0 + def test_sde_distribution(): n_actions = 1 deterministic_actions = th.ones(N_SAMPLES, n_actions) * 0.1 state = th.ones(N_SAMPLES, N_FEATURES) * 0.3 - dist = StateDependentNoiseDistribution(n_actions, full_std=True, squash_output=False) + dist = StateDependentNoiseDistribution( + n_actions, full_std=True, squash_output=False) set_random_seed(1) _, log_std = dist.proba_distribution_net(N_FEATURES) @@ -60,8 +61,10 @@ def test_sde_distribution(): dist = dist.proba_distribution(deterministic_actions, log_std, state) actions = dist.get_actions() - assert th.allclose(actions.mean(), dist.distribution.mean.mean(), rtol=2e-3) - assert th.allclose(actions.std(), dist.distribution.scale.mean(), rtol=2e-3) + assert th.allclose( + actions.mean(), dist.distribution.mean.mean(), rtol=2e-3) + assert th.allclose( + actions.std(), dist.distribution.scale.mean(), rtol=2e-3) # TODO: analytical form for squashed Gaussian? @@ -75,7 +78,8 @@ def test_entropy(dist): set_random_seed(1) state = th.rand(N_SAMPLES, N_FEATURES) deterministic_actions = th.rand(N_SAMPLES, N_ACTIONS) - _, log_std = dist.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2))) + _, log_std = dist.proba_distribution_net( + N_FEATURES, log_std_init=th.log(th.tensor(0.2))) if isinstance(dist, DiagGaussianDistribution): dist = dist.proba_distribution(deterministic_actions, log_std) @@ -89,31 +93,26 @@ def test_entropy(dist): assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) -def test_categorical(): +categorical_param = [ + (CategoricalDistribution(2), 2), + (MultiCategoricalDistribution([4, 3, 2]), sum([4, 3, 2])) +] +@pytest.mark.parametrize("dist, N_ACTIONS", categorical_param) +def test_categorical(dist, N_ACTIONS): # The entropy can be approximated by averaging the negative log likelihood # mean negative log likelihood == entropy - dist = CategoricalDistribution(N_ACTIONS) set_random_seed(1) action_logits = th.rand(N_SAMPLES, N_ACTIONS) dist = dist.proba_distribution(action_logits) - actions = dist.get_actions() entropy = dist.entropy() log_prob = dist.log_prob(actions) assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=1e-4) -def test_multicategorical(): - dist = MultiCategoricalDistribution(N_ACTIONS_MULTI) - set_random_seed(1) - action_logits = th.rand(N_SAMPLES, sum(N_ACTIONS_MULTI)) - dist = dist.proba_distribution(action_logits) - - actions = dist.get_actions() - entropy = dist.entropy() - log_prob = dist.log_prob(actions) - assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=1e-4) def test_bernoulli(): + # The entropy can be approximated by averaging the negative log likelihood + # mean negative log likelihood == entropy dist = BernoulliDistribution(N_ACTIONS) set_random_seed(1) action_logits = th.rand(N_SAMPLES, N_ACTIONS) @@ -122,4 +121,4 @@ def test_bernoulli(): actions = dist.get_actions() entropy = dist.entropy() log_prob = dist.log_prob(actions) - assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=1e-4) \ No newline at end of file + assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=1e-4) From 394a9bef7b7c68d280942c8055144476681eacc5 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Tue, 12 May 2020 21:38:16 +0100 Subject: [PATCH 06/22] added changelog and fixed ppo policy --- CHANGELOG.md | 12 ++++++++++++ stable_baselines3/ppo/policies.py | 13 ++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..99aafef60 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,12 @@ +Changelog +========== + +Support for MultiBinary / MultiDiscrete spaces +--------------------------- + +New Features: +^^^^^^^^^^^^^ +- Implemented MultiCategorical & Bernoulli distributions in `common/distributions.py` +- Added support for MultiCategorial & Bernoulli observation / action spaces in `preprocessing.py`, `ppo/policies.py` +- Merged the Categorical and MultiCategorical tests, added Bernoulli test in `test_distributions.py` + diff --git a/stable_baselines3/ppo/policies.py b/stable_baselines3/ppo/policies.py index 874222060..617bac3cb 100644 --- a/stable_baselines3/ppo/policies.py +++ b/stable_baselines3/ppo/policies.py @@ -11,7 +11,8 @@ BaseFeaturesExtractor, FlattenExtractor) from stable_baselines3.common.distributions import (make_proba_distribution, Distribution, DiagGaussianDistribution, CategoricalDistribution, - StateDependentNoiseDistribution) + StateDependentNoiseDistribution, MultiCategoricalDistribution, + BernoulliDistribution) class PPOPolicy(BasePolicy): @@ -178,6 +179,10 @@ def _build(self, lr_schedule: Callable) -> None: log_std_init=self.log_std_init) elif isinstance(self.action_dist, CategoricalDistribution): self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + elif isinstance(self.action_dist, MultiCategoricalDistribution): + self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + elif isinstance(self.action_dist, BernoulliDistribution): + self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1) # Init weights: use orthogonal initialization @@ -249,6 +254,12 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor, elif isinstance(self.action_dist, CategoricalDistribution): # Here mean_actions are the logits before the softmax return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, MultiCategoricalDistribution): + # Here mean_actions are the logits before the softmax + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, BernoulliDistribution): + # Here mean_actions are the logits before the softmax + return self.action_dist.proba_distribution(action_logits=mean_actions) elif isinstance(self.action_dist, StateDependentNoiseDistribution): return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde) From 236153ad2be5d3edcc1cb4bc06f25ec73a11a871 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Tue, 12 May 2020 21:43:52 +0100 Subject: [PATCH 07/22] minor fix --- stable_baselines3/common/preprocessing.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 3c055787d..b340dcd21 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -67,9 +67,8 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, return F.one_hot(obs.long(), num_classes=observation_space.n).float() elif isinstance(observation_space, spaces.MultiDiscrete): # Tensor concatination of one hot encodings of each Categorical sub-space - x = th.cat([F.one_hot(o.long(), num_classes=n).float() - for o, n in zip(obs, observation_space.nvec)], dim=1) - return x.t() + return th.cat([F.one_hot(o.long(), num_classes=n).float() + for o, n in zip(obs, observation_space.nvec)], dim=1).t() elif isinstance(observation_space, spaces.MultiBinary): return obs.float() else: From f8518f2fb03f56bf7c0403ede778bbbb1b47d2d3 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Wed, 13 May 2020 21:26:17 +0100 Subject: [PATCH 08/22] dist fixes, added test_spaces --- CHANGELOG.md | 12 --- stable_baselines3/common/buffers.py | 2 +- stable_baselines3/common/distributions.py | 38 ++++----- stable_baselines3/common/policies.py | 99 +++++++++-------------- stable_baselines3/common/preprocessing.py | 17 ++-- stable_baselines3/ppo/policies.py | 7 +- stable_baselines3/ppo/ppo.py | 2 - tests/test_distributions.py | 35 ++++---- tests/test_spaces.py | 55 +++++++++++++ 9 files changed, 137 insertions(+), 130 deletions(-) delete mode 100644 CHANGELOG.md create mode 100644 tests/test_spaces.py diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index 99aafef60..000000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,12 +0,0 @@ -Changelog -========== - -Support for MultiBinary / MultiDiscrete spaces ---------------------------- - -New Features: -^^^^^^^^^^^^^ -- Implemented MultiCategorical & Bernoulli distributions in `common/distributions.py` -- Added support for MultiCategorial & Bernoulli observation / action spaces in `preprocessing.py`, `ppo/policies.py` -- Merged the Categorical and MultiCategorical tests, added Bernoulli test in `test_distributions.py` - diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 4fb4422b2..a21d63d1e 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -303,7 +303,7 @@ def add(self, if len(log_prob.shape) == 0: # Reshape 0-d tensor to avoid error log_prob = log_prob.reshape(-1, 1) - + self.observations[self.pos] = np.array(obs).copy() self.actions[self.pos] = np.array(action).copy() self.rewards[self.pos] = np.array(reward).copy() diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 74bc57100..c7efdec97 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -1,5 +1,6 @@ from typing import Optional, Tuple, Dict, Any, List - +import time +from functools import partial import gym import torch as th import torch.nn as nn @@ -89,7 +90,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: :return: (th.Tensor) shape: (n_batch,) """ if len(tensor.shape) > 1: - tensor = tensor.sum(dim=1) + tensor = tensor.sum(axis=1) else: tensor = tensor.sum() return tensor @@ -123,8 +124,7 @@ def proba_distribution_net(self, latent_dim: int, """ mean_actions = nn.Linear(latent_dim, self.action_dim) # TODO: allow action dependent std - log_std = nn.Parameter(th.ones(self.action_dim) - * log_std_init, requires_grad=True) + log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True) return mean_actions, log_std def proba_distribution(self, mean_actions: th.Tensor, @@ -200,8 +200,7 @@ def __init__(self, action_dim: int, epsilon: float = 1e-6): def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> 'SquashedDiagGaussianDistribution': - super(SquashedDiagGaussianDistribution, - self).proba_distribution(mean_actions, log_std) + super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std) return self def mode(self) -> th.Tensor: @@ -235,8 +234,7 @@ def log_prob(self, actions: th.Tensor, gaussian_actions = TanhBijector.inverse(actions) # Log likelihood for a Gaussian distribution - log_prob = super(SquashedDiagGaussianDistribution, - self).log_prob(gaussian_actions) + log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_actions) # Squash correction (from original SAC implementation) # this comes from the fact that tanh is bijective and differentiable log_prob -= th.sum(th.log(1 - actions ** 2 + self.epsilon), dim=1) @@ -318,19 +316,18 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module: of the policy network (before the action layer) :return: (nn.Linear) """ - action_logits = nn.Linear(latent_dim, np.sum(self.action_dims)) + action_logits = nn.Linear(latent_dim, sum(self.action_dims)) return action_logits def proba_distribution(self, action_logits: th.Tensor) -> 'MultiCategoricalDistribution': - reshaped_logits = th.split(action_logits, tuple(self.action_dims), dim=1) - self.distributions = [Categorical(logits=l) for l in reshaped_logits] + self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)] return self def mode(self) -> th.Tensor: - return th.stack([th.argmax(d.probs, dim=1) for d in self.distributions]) + return th.stack([th.argmax(d.probs, dim=1) for d in self.distributions], dim=1) def sample(self) -> th.Tensor: - return th.stack([d.sample() for d in self.distributions]) + return th.stack([d.sample() for d in self.distributions], dim=1) def entropy(self) -> th.Tensor: return th.stack([d.entropy() for d in self.distributions], dim=1).sum(dim=1) @@ -349,7 +346,7 @@ def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th. def log_prob(self, actions: th.Tensor) -> th.Tensor: return th.stack([d.log_prob(x) for d, x in zip(self.distributions, - th.unbind(actions))], dim=1).sum(dim=1) + th.unbind(actions, dim=1))], dim=1).sum(dim=1) class BernoulliDistribution(Distribution): @@ -387,7 +384,7 @@ def sample(self) -> th.Tensor: return self.distribution.sample() def entropy(self) -> th.Tensor: - return self.distribution.entropy() + return self.distribution.entropy().sum(dim=1) def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: @@ -401,7 +398,7 @@ def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th. return actions, log_prob def log_prob(self, actions: th.Tensor) -> th.Tensor: - return self.distribution.log_prob(actions) + return self.distribution.log_prob(actions).sum(dim=1) class StateDependentNoiseDistribution(Distribution): @@ -508,8 +505,7 @@ def proba_distribution_net(self, latent_dim: int, log_std_init: float = -2.0, # can be different between the policy and the noise network self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim # Reduce the number of parameters if needed - log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones( - self.latent_sde_dim, 1) + log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones(self.latent_sde_dim, 1) # Transform it to a parameter so it can be optimized log_std = nn.Parameter(log_std * log_std_init, requires_grad=True) # Sample an exploration matrix @@ -530,8 +526,7 @@ def proba_distribution(self, mean_actions: th.Tensor, # Stop gradient if we don't want to influence the features self._latent_sde = latent_sde if self.learn_features else latent_sde.detach() variance = th.mm(latent_sde ** 2, self.get_std(log_std) ** 2) - self.distribution = Normal( - mean_actions, th.sqrt(variance + self.epsilon)) + self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon)) return self def mode(self) -> th.Tensor: @@ -657,8 +652,7 @@ def make_proba_distribution(action_space: gym.spaces.Space, dist_kwargs = {} if isinstance(action_space, spaces.Box): - assert len( - action_space.shape) == 1, "Error: the action space must be a vector" + assert len(action_space.shape) == 1, "Error: the action space must be a vector" if use_sde: return StateDependentNoiseDistribution(get_action_dim(action_space), **dist_kwargs) return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 1394eda87..28e182e00 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -43,8 +43,7 @@ class FlattenExtractor(BaseFeaturesExtractor): """ def __init__(self, observation_space: gym.Space): - super(FlattenExtractor, self).__init__( - observation_space, get_flattened_obs_dim(observation_space)) + super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space)) self.flatten = nn.Flatten() def forward(self, observations: th.Tensor) -> th.Tensor: @@ -71,21 +70,17 @@ def __init__(self, observation_space: gym.spaces.Box, n_input_channels = observation_space.shape[0] self.cnn = nn.Sequential(nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0), nn.ReLU(), - nn.Conv2d(32, 64, kernel_size=4, - stride=2, padding=0), + nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), nn.ReLU(), - nn.Conv2d(64, 32, kernel_size=3, - stride=1, padding=0), + nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0), nn.ReLU(), nn.Flatten()) # Compute shape by doing one forward pass with th.no_grad(): - n_flatten = self.cnn(th.as_tensor( - observation_space.sample()[None]).float()).shape[1] + n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1] - self.linear = nn.Sequential( - nn.Linear(n_flatten, features_dim), nn.ReLU()) + self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU()) def forward(self, observations: th.Tensor) -> th.Tensor: return self.linear(self.cnn(observations)) @@ -153,8 +148,7 @@ def extract_features(self, obs: th.Tensor) -> th.Tensor: :return: (th.Tensor) """ assert self.features_extractor is not None, 'No feature extractor was set' - preprocessed_obs = preprocess_obs( - obs, self.observation_space, normalize_images=self.normalize_images) + preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) return self.features_extractor(preprocessed_obs) @property @@ -212,8 +206,8 @@ def predict(self, observation: np.ndarray, # Handle the different cases for images # as PyTorch use channel first format if is_image_space(self.observation_space): - if (observation.shape == self.observation_space.shape or - observation.shape[1:] == self.observation_space.shape): + if (observation.shape == self.observation_space.shape + or observation.shape[1:] == self.observation_space.shape): pass else: # Try to re-order the channels @@ -222,8 +216,7 @@ def predict(self, observation: np.ndarray, or transpose_obs.shape[1:] == self.observation_space.shape): observation = transpose_obs - vectorized_env = self._is_vectorized_observation( - observation, self.observation_space) + vectorized_env = self._is_vectorized_observation(observation, self.observation_space) observation = observation.reshape((-1,) + self.observation_space.shape) @@ -240,13 +233,11 @@ def predict(self, observation: np.ndarray, clipped_actions = actions # Clip the actions to avoid out of bound error when using gaussian distribution if isinstance(self.action_space, gym.spaces.Box) and not self.squash_output: - clipped_actions = np.clip( - actions, self.action_space.low, self.action_space.high) + clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) if not vectorized_env: if state is not None: - raise ValueError( - "Error: The environment must be vectorized when using recurrent policies.") + raise ValueError("Error: The environment must be vectorized when using recurrent policies.") clipped_actions = clipped_actions[0] return clipped_actions, state @@ -288,9 +279,9 @@ def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.s elif observation.shape[1:] == observation_space.shape: return True else: - raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) + - "Box environment, please use {} ".format(observation_space.shape) + - "or (n_env, {}) for the observation shape." + raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) + + "Box environment, please use {} ".format(observation_space.shape) + + "or (n_env, {}) for the observation shape." .format(", ".join(map(str, observation_space.shape)))) elif isinstance(observation_space, gym.spaces.Discrete): if observation.shape == (): # A numpy array of a number, has shape empty tuple '()' @@ -298,8 +289,8 @@ def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.s elif len(observation.shape) == 1: return True else: - raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) + - "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.") + raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) + + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.") elif isinstance(observation_space, gym.spaces.MultiDiscrete): if observation.shape == (len(observation_space.nvec),): @@ -307,18 +298,18 @@ def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.s elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec): return True else: - raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) + - "environment, please use ({},) or ".format(len(observation_space.nvec)) + - "(n_env, {}) for the observation shape.".format(len(observation_space.nvec))) + raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) + + "environment, please use ({},) or ".format(len(observation_space.nvec)) + + "(n_env, {}) for the observation shape.".format(len(observation_space.nvec))) elif isinstance(observation_space, gym.spaces.MultiBinary): if observation.shape == (observation_space.n,): return False elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n: return True else: - raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) + - "environment, please use ({},) or ".format(observation_space.n) + - "(n_env, {}) for the observation shape.".format(observation_space.n)) + raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) + + "environment, please use ({},) or ".format(observation_space.n) + + "(n_env, {}) for the observation shape.".format(observation_space.n)) else: raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}." .format(observation_space)) @@ -345,8 +336,7 @@ def save(self, path: str) -> None: :param path: (str) """ - th.save({'state_dict': self.state_dict(), - 'data': self._get_data()}, path) + th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path) @classmethod def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BasePolicy': @@ -372,8 +362,7 @@ def load_from_vector(self, vector: np.ndarray): :param vector: (np.ndarray) """ - th.nn.utils.vector_to_parameters(th.FloatTensor( - vector).to(self.device), self.parameters()) + th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters()) def parameters_to_vector(self) -> np.ndarray: """ @@ -437,16 +426,13 @@ def create_sde_features_extractor(features_dim: int, # Special case: when using states as features (i.e. sde_net_arch is an empty list) # don't use any activation function sde_activation = activation_fn if len(sde_net_arch) > 0 else None - latent_sde_net = create_mlp( - features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_output=False) - latent_sde_dim = sde_net_arch[-1] if len( - sde_net_arch) > 0 else features_dim + latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_output=False) + latent_sde_dim = sde_net_arch[-1] if len(sde_net_arch) > 0 else features_dim sde_features_extractor = nn.Sequential(*latent_sde_net) return sde_features_extractor, latent_sde_dim -# type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]] -_policy_registry = dict() +_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]] def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]: @@ -458,8 +444,7 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[ :return: (Type[BasePolicy]) the policy """ if base_policy_type not in _policy_registry: - raise ValueError( - f"Error: the policy type {base_policy_type} is not registered!") + raise ValueError(f"Error: the policy type {base_policy_type} is not registered!") if name not in _policy_registry[base_policy_type]: raise ValueError(f"Error: unknown policy type {name}," "the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!") @@ -484,14 +469,12 @@ def register_policy(name: str, policy: Type[BasePolicy]) -> None: except AttributeError: sub_class = str(th.random.randint(100)) if sub_class is None: - raise ValueError( - f"Error: the policy {policy} is not of any known subclasses of BasePolicy!") + raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!") if sub_class not in _policy_registry: _policy_registry[sub_class] = {} if name in _policy_registry[sub_class]: - raise ValueError( - f"Error: the name {name} is alreay registered for a different policy, will not override.") + raise ValueError(f"Error: the name {name} is alreay registered for a different policy, will not override.") _policy_registry[sub_class][name] = policy @@ -528,12 +511,9 @@ def __init__(self, feature_dim: int, device: Union[th.device, str] = 'auto'): super(MlpExtractor, self).__init__() device = get_device(device) - shared_net, policy_net, value_net = [], [], [] - # Layer sizes of the network that only belongs to the policy network - policy_only_layers = [] - # Layer sizes of the network that only belongs to the value network - value_only_layers = [] + policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network + value_only_layers = [] # Layer sizes of the network that only belongs to the value network last_layer_dim_shared = feature_dim # Iterate through the shared layers and build the shared parts of the network @@ -545,16 +525,13 @@ def __init__(self, feature_dim: int, shared_net.append(activation_fn()) last_layer_dim_shared = layer_size else: - assert isinstance( - layer, dict), "Error: the net_arch list can only contain ints and dicts" + assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts" if 'pi' in layer: - assert isinstance( - layer['pi'], list), "Error: net_arch[-1]['pi'] must contain a list of integers." + assert isinstance(layer['pi'], list), "Error: net_arch[-1]['pi'] must contain a list of integers." policy_only_layers = layer['pi'] if 'vf' in layer: - assert isinstance( - layer['vf'], list), "Error: net_arch[-1]['vf'] must contain a list of integers." + assert isinstance(layer['vf'], list), "Error: net_arch[-1]['vf'] must contain a list of integers." value_only_layers = layer['vf'] break # From here on the network splits up in policy and value network @@ -564,15 +541,13 @@ def __init__(self, feature_dim: int, # Build the non-shared part of the network for idx, (pi_layer_size, vf_layer_size) in enumerate(zip_longest(policy_only_layers, value_only_layers)): if pi_layer_size is not None: - assert isinstance( - pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers." + assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers." policy_net.append(nn.Linear(last_layer_dim_pi, pi_layer_size)) policy_net.append(activation_fn()) last_layer_dim_pi = pi_layer_size if vf_layer_size is not None: - assert isinstance( - vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers." + assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers." value_net.append(nn.Linear(last_layer_dim_vf, vf_layer_size)) value_net.append(activation_fn()) last_layer_dim_vf = vf_layer_size diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index b340dcd21..7f71c2cc4 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -67,8 +67,9 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, return F.one_hot(obs.long(), num_classes=observation_space.n).float() elif isinstance(observation_space, spaces.MultiDiscrete): # Tensor concatination of one hot encodings of each Categorical sub-space - return th.cat([F.one_hot(o.long(), num_classes=n).float() - for o, n in zip(obs, observation_space.nvec)], dim=1).t() + x = th.cat([F.one_hot(o.long(), num_classes=n).float() + for o, n in zip(obs[0], observation_space.nvec)]) + return x.view(1, -1) elif isinstance(observation_space, spaces.MultiBinary): return obs.float() else: @@ -85,7 +86,7 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: if isinstance(observation_space, spaces.Box): return observation_space.shape elif isinstance(observation_space, spaces.Discrete): - # One observation + # Observation is an int return 1, elif isinstance(observation_space, spaces.MultiDiscrete): # Observation is the number of discrete spaces @@ -94,6 +95,7 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: # Observation is the number of binary spaces return int(observation_space.n), else: + # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict raise NotImplementedError() @@ -105,8 +107,11 @@ def get_flattened_obs_dim(observation_space: spaces.Space) -> int: :param observation_space: (spaces.Space) :return: (int) """ - # Use Gym internal method - return spaces.utils.flatdim(observation_space) + if isinstance(observation_space, spaces.MultiDiscrete): + return (sum(observation_space.nvec)) + else: + # Use Gym internal method + return spaces.utils.flatdim(observation_space) def get_action_dim(action_space: spaces.Space) -> int: @@ -119,7 +124,7 @@ def get_action_dim(action_space: spaces.Space) -> int: if isinstance(action_space, spaces.Box): return int(np.prod(action_space.shape)) elif isinstance(action_space, spaces.Discrete): - # One action + # Action is an int return 1 elif isinstance(action_space, spaces.MultiDiscrete): # Action is the number of discrete spaces diff --git a/stable_baselines3/ppo/policies.py b/stable_baselines3/ppo/policies.py index 617bac3cb..203ecaa75 100644 --- a/stable_baselines3/ppo/policies.py +++ b/stable_baselines3/ppo/policies.py @@ -11,8 +11,8 @@ BaseFeaturesExtractor, FlattenExtractor) from stable_baselines3.common.distributions import (make_proba_distribution, Distribution, DiagGaussianDistribution, CategoricalDistribution, - StateDependentNoiseDistribution, MultiCategoricalDistribution, - BernoulliDistribution) + MultiCategoricalDistribution, BernoulliDistribution, + StateDependentNoiseDistribution) class PPOPolicy(BasePolicy): @@ -231,6 +231,7 @@ def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: # Preprocess the observation if needed features = self.extract_features(obs) latent_pi, latent_vf = self.mlp_extractor(features) + # Features for sde latent_sde = latent_pi if self.sde_features_extractor is not None: @@ -250,7 +251,6 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor, if isinstance(self.action_dist, DiagGaussianDistribution): return self.action_dist.proba_distribution(mean_actions, self.log_std) - elif isinstance(self.action_dist, CategoricalDistribution): # Here mean_actions are the logits before the softmax return self.action_dist.proba_distribution(action_logits=mean_actions) @@ -260,7 +260,6 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor, elif isinstance(self.action_dist, BernoulliDistribution): # Here mean_actions are the logits before the softmax return self.action_dist.proba_distribution(action_logits=mean_actions) - elif isinstance(self.action_dist, StateDependentNoiseDistribution): return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde) else: diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 3e81fdb77..b5c6a4fd5 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -157,7 +157,6 @@ def collect_rollouts(self, callback.on_rollout_start() while n_steps < n_rollout_steps: - if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: # Sample a new noise matrix self.policy.reset_noise(env.num_envs) @@ -213,7 +212,6 @@ def train(self, n_epochs: int, batch_size: int = 64) -> None: approx_kl_divs = [] # Do a complete pass on the rollout buffer for rollout_data in self.rollout_buffer.get(batch_size): - actions = rollout_data.actions if isinstance(self.action_space, spaces.Discrete): # Convert discrete action from float to long diff --git a/tests/test_distributions.py b/tests/test_distributions.py index f3ea056a9..4acd65377 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -3,11 +3,9 @@ from stable_baselines3 import A2C, PPO from stable_baselines3.common.distributions import (DiagGaussianDistribution, TanhBijector, - SquashedDiagGaussianDistribution, - CategoricalDistribution, - MultiCategoricalDistribution, - BernoulliDistribution, - StateDependentNoiseDistribution) + StateDependentNoiseDistribution, + CategoricalDistribution, SquashedDiagGaussianDistribution, + MultiCategoricalDistribution, BernoulliDistribution) from stable_baselines3.common.utils import set_random_seed @@ -35,8 +33,7 @@ def test_squashed_gaussian(model_class): """ Test run with squashed Gaussian (notably entropy computation) """ - model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, - n_steps=100, policy_kwargs=dict(squash_output=True)) + model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True)) model.learn(500) gaussian_mean = th.rand(N_SAMPLES, N_ACTIONS) @@ -51,8 +48,7 @@ def test_sde_distribution(): n_actions = 1 deterministic_actions = th.ones(N_SAMPLES, n_actions) * 0.1 state = th.ones(N_SAMPLES, N_FEATURES) * 0.3 - dist = StateDependentNoiseDistribution( - n_actions, full_std=True, squash_output=False) + dist = StateDependentNoiseDistribution(n_actions, full_std=True, squash_output=False) set_random_seed(1) _, log_std = dist.proba_distribution_net(N_FEATURES) @@ -61,10 +57,8 @@ def test_sde_distribution(): dist = dist.proba_distribution(deterministic_actions, log_std, state) actions = dist.get_actions() - assert th.allclose( - actions.mean(), dist.distribution.mean.mean(), rtol=2e-3) - assert th.allclose( - actions.std(), dist.distribution.scale.mean(), rtol=2e-3) + assert th.allclose(actions.mean(), dist.distribution.mean.mean(), rtol=2e-3) + assert th.allclose(actions.std(), dist.distribution.scale.mean(), rtol=2e-3) # TODO: analytical form for squashed Gaussian? @@ -78,8 +72,7 @@ def test_entropy(dist): set_random_seed(1) state = th.rand(N_SAMPLES, N_FEATURES) deterministic_actions = th.rand(N_SAMPLES, N_ACTIONS) - _, log_std = dist.proba_distribution_net( - N_FEATURES, log_std_init=th.log(th.tensor(0.2))) + _, log_std = dist.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2))) if isinstance(dist, DiagGaussianDistribution): dist = dist.proba_distribution(deterministic_actions, log_std) @@ -94,15 +87,15 @@ def test_entropy(dist): categorical_param = [ - (CategoricalDistribution(2), 2), - (MultiCategoricalDistribution([4, 3, 2]), sum([4, 3, 2])) + (CategoricalDistribution(N_ACTIONS), N_ACTIONS), + (MultiCategoricalDistribution([2, 3]), sum([2, 3])) ] -@pytest.mark.parametrize("dist, N_ACTIONS", categorical_param) -def test_categorical(dist, N_ACTIONS): +@pytest.mark.parametrize("dist, CAT_ACTIONS", categorical_param) +def test_categorical(dist, CAT_ACTIONS): # The entropy can be approximated by averaging the negative log likelihood # mean negative log likelihood == entropy set_random_seed(1) - action_logits = th.rand(N_SAMPLES, N_ACTIONS) + action_logits = th.rand(N_SAMPLES, CAT_ACTIONS) dist = dist.proba_distribution(action_logits) actions = dist.get_actions() entropy = dist.entropy() @@ -121,4 +114,4 @@ def test_bernoulli(): actions = dist.get_actions() entropy = dist.entropy() log_prob = dist.log_prob(actions) - assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=2e-4) + assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) diff --git a/tests/test_spaces.py b/tests/test_spaces.py new file mode 100644 index 000000000..f64e30f82 --- /dev/null +++ b/tests/test_spaces.py @@ -0,0 +1,55 @@ +import numpy as np +import pytest + +from stable_baselines3 import A2C, PPO +from stable_baselines3.common.identity_env import IdentityEnvMultiBinary, IdentityEnvMultiDiscrete +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.noise import NormalActionNoise +from stable_baselines3.common.vec_env import DummyVecEnv + + +MODEL_LIST = [A2C, PPO] + + +@pytest.mark.slow +@pytest.mark.parametrize("model_class", MODEL_LIST) +def test_identity_multidiscrete(model_class): + """ + Test if the algorithm (with a given policy) + can learn an identity transformation (i.e. return observation as an action) + with a multidiscrete action space + :param model_class: (BaseRLModel) A RL Model + """ + env = DummyVecEnv([lambda: IdentityEnvMultiDiscrete(2)]) + + model = model_class("MlpPolicy", env, gamma=0.5, seed=0) + model.learn(total_timesteps=1000) + evaluate_policy(model, env, n_eval_episodes=5) + obs = env.reset() + + evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=70) + + assert np.array(model.predict(obs)).shape == (2,), \ + "Error: predict not returning correct shape" + + +@pytest.mark.slow +@pytest.mark.parametrize("model_class", MODEL_LIST) +def test_identity_multibinary(model_class): + """ + Test if the algorithm (with a given policy) + can learn an identity transformation (i.e. return observation as an action) + with a multibinary action space + :param model_class: (BaseRLModel) A RL Model + """ + env = DummyVecEnv([lambda: IdentityEnvMultiBinary(2)]) + + model = model_class("MlpPolicy", env, gamma=0.7, seed=0) + model.learn(total_timesteps=1000) + evaluate_policy(model, env, n_eval_episodes=5) + obs = env.reset() + + evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=49) + + assert np.array(model.predict(obs)).shape == (2,), \ + "Error: predict not returning correct shape" From 6c8dfa5581b78679c07b1ef65f41d3c4018767b8 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Wed, 13 May 2020 21:37:11 +0100 Subject: [PATCH 09/22] clean up --- stable_baselines3/common/distributions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index c7efdec97..185a2a3c6 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -1,5 +1,4 @@ from typing import Optional, Tuple, Dict, Any, List -import time from functools import partial import gym import torch as th @@ -90,7 +89,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: :return: (th.Tensor) shape: (n_batch,) """ if len(tensor.shape) > 1: - tensor = tensor.sum(axis=1) + tensor = tensor.sum(dim=1) else: tensor = tensor.sum() return tensor From 4f9b455b3467e1038c8beed366ec95f4e1e188d5 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Wed, 13 May 2020 21:45:19 +0100 Subject: [PATCH 10/22] modified changelog --- docs/misc/changelog.rst | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 32fcaa0ef..9e2ec30a1 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,10 +3,10 @@ Changelog ========== + Pre-Release 0.6.0a7 (WIP) ------------------------------ - Breaking Changes: ^^^^^^^^^^^^^^^^^ - Remove State-Dependent Exploration (SDE) support for ``TD3`` @@ -16,7 +16,10 @@ New Features: - Added env checker (Sync with Stable Baselines) - Added ``VecCheckNan`` and ``VecVideoRecorder`` (Sync with Stable Baselines) - Added determinism tests -- Added ``cmd_utils`` and ``atari_wrappers`` +- Added ``cmd_utils`` and ``atari_wrappers`` +- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation / action spaces +- Added ``MultiCategorical`` and ``Bernoulli`` distributions +- Added ``test_bernoulli``, modified ``test_categorical`` and created ``test_spaces.py`` Bug Fixes: ^^^^^^^^^^ From a20e9a758861d27f2230ecef4df43624781b0e46 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Wed, 13 May 2020 21:57:45 +0100 Subject: [PATCH 11/22] additional fixes --- stable_baselines3/common/preprocessing.py | 2 +- tests/test_spaces.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 7f71c2cc4..71342e632 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -67,7 +67,7 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, return F.one_hot(obs.long(), num_classes=observation_space.n).float() elif isinstance(observation_space, spaces.MultiDiscrete): # Tensor concatination of one hot encodings of each Categorical sub-space - x = th.cat([F.one_hot(o.long(), num_classes=n).float() + x = th.cat([F.one_hot(o.long(), num_classes=int(n)).float() for o, n in zip(obs[0], observation_space.nvec)]) return x.view(1, -1) elif isinstance(observation_space, spaces.MultiBinary): diff --git a/tests/test_spaces.py b/tests/test_spaces.py index f64e30f82..dbde20901 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -11,7 +11,6 @@ MODEL_LIST = [A2C, PPO] -@pytest.mark.slow @pytest.mark.parametrize("model_class", MODEL_LIST) def test_identity_multidiscrete(model_class): """ @@ -33,7 +32,6 @@ def test_identity_multidiscrete(model_class): "Error: predict not returning correct shape" -@pytest.mark.slow @pytest.mark.parametrize("model_class", MODEL_LIST) def test_identity_multibinary(model_class): """ From 65dd1477accc886d98fd2045f79d7bb2df9f4992 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Wed, 13 May 2020 23:34:07 +0100 Subject: [PATCH 12/22] minor changelog mod --- docs/misc/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9e2ec30a1..7a21f9fe8 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -17,7 +17,7 @@ New Features: - Added ``VecCheckNan`` and ``VecVideoRecorder`` (Sync with Stable Baselines) - Added determinism tests - Added ``cmd_utils`` and ``atari_wrappers`` -- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation / action spaces +- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation / action spaces for PPO and A2C - Added ``MultiCategorical`` and ``Bernoulli`` distributions - Added ``test_bernoulli``, modified ``test_categorical`` and created ``test_spaces.py`` From ca6824d197ac55c9a6b304acfe6bb2794cfbf541 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Thu, 14 May 2020 12:57:15 +0100 Subject: [PATCH 13/22] hot encoding fix, flake8 clean up --- stable_baselines3/common/buffers.py | 3 ++- stable_baselines3/common/distributions.py | 5 ++--- stable_baselines3/common/preprocessing.py | 14 +++++++++----- stable_baselines3/ppo/ppo.py | 4 ++++ tests/test_spaces.py | 12 ++++++------ 5 files changed, 23 insertions(+), 15 deletions(-) diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index a21d63d1e..e88110f3d 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -20,6 +20,7 @@ class BaseBuffer(object): to which the values will be converted :param n_envs: (int) Number of parallel environments """ + def __init__(self, buffer_size: int, observation_space: spaces.Space, @@ -303,7 +304,7 @@ def add(self, if len(log_prob.shape) == 0: # Reshape 0-d tensor to avoid error log_prob = log_prob.reshape(-1, 1) - + self.observations[self.pos] = np.array(obs).copy() self.actions[self.pos] = np.array(action).copy() self.rewards[self.pos] = np.array(reward).copy() diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 185a2a3c6..63737898d 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -1,11 +1,9 @@ from typing import Optional, Tuple, Dict, Any, List -from functools import partial import gym import torch as th import torch.nn as nn from torch.distributions import Normal, Categorical, Bernoulli from gym import spaces -import numpy as np from stable_baselines3.common.preprocessing import get_action_dim @@ -315,6 +313,7 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module: of the policy network (before the action layer) :return: (nn.Linear) """ + action_logits = nn.Linear(latent_dim, sum(self.action_dims)) return action_logits @@ -332,7 +331,7 @@ def entropy(self) -> th.Tensor: return th.stack([d.entropy() for d in self.distributions], dim=1).sum(dim=1) def actions_from_params(self, action_logits: th.Tensor, - deterministic: bool = True) -> th.Tensor: + deterministic: bool = False) -> th.Tensor: # Update the proba distribution self.proba_distribution(action_logits) return self.get_actions(deterministic=deterministic) diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 71342e632..5fac60c35 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -1,9 +1,9 @@ from typing import Tuple -import numpy as np import torch as th import torch.nn.functional as F from gym import spaces +import numpy as np def is_image_space(observation_space: spaces.Space, @@ -62,16 +62,20 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, if is_image_space(observation_space) and normalize_images: return obs.float() / 255.0 return obs.float() + elif isinstance(observation_space, spaces.Discrete): # One hot encoding and convert to float to avoid errors return F.one_hot(obs.long(), num_classes=observation_space.n).float() + elif isinstance(observation_space, spaces.MultiDiscrete): # Tensor concatination of one hot encodings of each Categorical sub-space - x = th.cat([F.one_hot(o.long(), num_classes=int(n)).float() - for o, n in zip(obs[0], observation_space.nvec)]) - return x.view(1, -1) + return th.cat([F.one_hot(o.long(), num_classes=observation_space.nvec[i]).float() + for i, o in enumerate(th.split(obs.to(th.int64), 1, dim=1))], + dim=-1).view(obs.shape[0], sum(observation_space.nvec)) + elif isinstance(observation_space, spaces.MultiBinary): return obs.float() + else: raise NotImplementedError() @@ -108,7 +112,7 @@ def get_flattened_obs_dim(observation_space: spaces.Space) -> int: :return: (int) """ if isinstance(observation_space, spaces.MultiDiscrete): - return (sum(observation_space.nvec)) + return sum(observation_space.nvec) else: # Use Gym internal method return spaces.utils.flatdim(observation_space) diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index b5c6a4fd5..976f0fd48 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -250,6 +250,10 @@ def train(self, n_epochs: int, batch_size: int = 64) -> None: values_pred = rollout_data.old_values + th.clamp(values - rollout_data.old_values, -clip_range_vf, clip_range_vf) # Value loss using the TD(gae_lambda) target + # print('old', rollout_data.old_values) + # print('roll', rollout_data.returns) + # print('values', values) + value_loss = F.mse_loss(rollout_data.returns, values_pred) value_losses.append(value_loss.item()) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index dbde20901..3ee42e301 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -19,17 +19,17 @@ def test_identity_multidiscrete(model_class): with a multidiscrete action space :param model_class: (BaseRLModel) A RL Model """ - env = DummyVecEnv([lambda: IdentityEnvMultiDiscrete(2)]) + env = DummyVecEnv([lambda: IdentityEnvMultiDiscrete(3)]) model = model_class("MlpPolicy", env, gamma=0.5, seed=0) model.learn(total_timesteps=1000) evaluate_policy(model, env, n_eval_episodes=5) obs = env.reset() - evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=70) + evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=80) - assert np.array(model.predict(obs)).shape == (2,), \ - "Error: predict not returning correct shape" + assert np.shape(model.predict(obs)[0]) == np.shape(obs) + "Error: predict not returning the same shape as observations" @pytest.mark.parametrize("model_class", MODEL_LIST) @@ -49,5 +49,5 @@ def test_identity_multibinary(model_class): evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=49) - assert np.array(model.predict(obs)).shape == (2,), \ - "Error: predict not returning correct shape" + assert np.shape(model.predict(obs)[0]) == np.shape(obs) + "Error: predict not returning the same shape as observations" From 43ae65a470049240e914e490dfed51769e49eaff Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Thu, 14 May 2020 13:06:38 +0100 Subject: [PATCH 14/22] lint tests --- tests/test_distributions.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 4acd65377..f88ef761a 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -3,9 +3,9 @@ from stable_baselines3 import A2C, PPO from stable_baselines3.common.distributions import (DiagGaussianDistribution, TanhBijector, - StateDependentNoiseDistribution, - CategoricalDistribution, SquashedDiagGaussianDistribution, - MultiCategoricalDistribution, BernoulliDistribution) + StateDependentNoiseDistribution, + CategoricalDistribution, SquashedDiagGaussianDistribution, + MultiCategoricalDistribution, BernoulliDistribution) from stable_baselines3.common.utils import set_random_seed @@ -90,6 +90,8 @@ def test_entropy(dist): (CategoricalDistribution(N_ACTIONS), N_ACTIONS), (MultiCategoricalDistribution([2, 3]), sum([2, 3])) ] + + @pytest.mark.parametrize("dist, CAT_ACTIONS", categorical_param) def test_categorical(dist, CAT_ACTIONS): # The entropy can be approximated by averaging the negative log likelihood From 3c39b72fb4d03f057605b5f460942d6059832623 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Thu, 14 May 2020 13:24:30 +0100 Subject: [PATCH 15/22] preprocessing fix --- stable_baselines3/common/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 5fac60c35..03281475c 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -69,7 +69,7 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, elif isinstance(observation_space, spaces.MultiDiscrete): # Tensor concatination of one hot encodings of each Categorical sub-space - return th.cat([F.one_hot(o.long(), num_classes=observation_space.nvec[i]).float() + return th.cat([F.one_hot(o.long(), num_classes=int(observation_space.nvec[i])).float() for i, o in enumerate(th.split(obs.to(th.int64), 1, dim=1))], dim=-1).view(obs.shape[0], sum(observation_space.nvec)) From 2e0b2b1e44c3aaeafa80218ec6c6c6a19ffb6a8a Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Thu, 14 May 2020 14:58:13 +0100 Subject: [PATCH 16/22] fixed bernoulli bug --- stable_baselines3/common/distributions.py | 2 +- tests/test_spaces.py | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 63737898d..81dd26f0e 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -376,7 +376,7 @@ def proba_distribution(self, action_logits: th.Tensor) -> 'BernoulliDistribution return self def mode(self) -> th.Tensor: - return th.argmax(self.distribution.probs, dim=1) + return th.round(self.distribution.probs) def sample(self) -> th.Tensor: return self.distribution.sample() diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 3ee42e301..1455d2d8f 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -4,11 +4,11 @@ from stable_baselines3 import A2C, PPO from stable_baselines3.common.identity_env import IdentityEnvMultiBinary, IdentityEnvMultiDiscrete from stable_baselines3.common.evaluation import evaluate_policy -from stable_baselines3.common.noise import NormalActionNoise from stable_baselines3.common.vec_env import DummyVecEnv MODEL_LIST = [A2C, PPO] +DIM = 3 @pytest.mark.parametrize("model_class", MODEL_LIST) @@ -19,7 +19,7 @@ def test_identity_multidiscrete(model_class): with a multidiscrete action space :param model_class: (BaseRLModel) A RL Model """ - env = DummyVecEnv([lambda: IdentityEnvMultiDiscrete(3)]) + env = DummyVecEnv([lambda: IdentityEnvMultiDiscrete(DIM)]) model = model_class("MlpPolicy", env, gamma=0.5, seed=0) model.learn(total_timesteps=1000) @@ -29,7 +29,6 @@ def test_identity_multidiscrete(model_class): evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=80) assert np.shape(model.predict(obs)[0]) == np.shape(obs) - "Error: predict not returning the same shape as observations" @pytest.mark.parametrize("model_class", MODEL_LIST) @@ -40,14 +39,13 @@ def test_identity_multibinary(model_class): with a multibinary action space :param model_class: (BaseRLModel) A RL Model """ - env = DummyVecEnv([lambda: IdentityEnvMultiBinary(2)]) + env = DummyVecEnv([lambda: IdentityEnvMultiBinary(DIM)]) - model = model_class("MlpPolicy", env, gamma=0.7, seed=0) + model = model_class("MlpPolicy", env, gamma=0.5, seed=0) model.learn(total_timesteps=1000) evaluate_policy(model, env, n_eval_episodes=5) obs = env.reset() - evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=49) + evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=80) assert np.shape(model.predict(obs)[0]) == np.shape(obs) - "Error: predict not returning the same shape as observations" From ef4874bb54a8524ec9abe23db0a9b0142e1dcbe9 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Thu, 14 May 2020 16:41:22 +0100 Subject: [PATCH 17/22] removed commented prints --- stable_baselines3/ppo/ppo.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 976f0fd48..b5c6a4fd5 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -250,10 +250,6 @@ def train(self, n_epochs: int, batch_size: int = 64) -> None: values_pred = rollout_data.old_values + th.clamp(values - rollout_data.old_values, -clip_range_vf, clip_range_vf) # Value loss using the TD(gae_lambda) target - # print('old', rollout_data.old_values) - # print('roll', rollout_data.returns) - # print('values', values) - value_loss = F.mse_loss(rollout_data.returns, values_pred) value_losses.append(value_loss.item()) From 3c506bb07bce41033cb65b7c7ef6da38998a6355 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 15 May 2020 15:36:26 +0200 Subject: [PATCH 18/22] Update changelog.rst --- docs/misc/changelog.rst | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 7d6787487..8d7290da4 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,6 @@ Changelog ========== - Pre-Release 0.6.0a8 (WIP) ------------------------------ @@ -17,9 +16,9 @@ New Features: - Added ``VecCheckNan`` and ``VecVideoRecorder`` (Sync with Stable Baselines) - Added determinism tests - Added ``cmd_utils`` and ``atari_wrappers`` -- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation / action spaces for PPO and A2C -- Added ``MultiCategorical`` and ``Bernoulli`` distributions -- Added ``test_bernoulli``, modified ``test_categorical`` and created ``test_spaces.py`` +- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation / action spaces for PPO and A2C (@rolandgvc) +- Added ``MultiCategorical`` and ``Bernoulli`` distributions (@rolandgvc) +- Added ``test_bernoulli``, modified ``test_categorical`` and created ``test_spaces.py`` (@rolandgvc) Bug Fixes: ^^^^^^^^^^ @@ -230,4 +229,4 @@ And all the contributors: @XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching -@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta +@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta @rolandgvc From 72cd10e8a0aab8a43dc769b9a08b6622dc331808 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Fri, 15 May 2020 15:50:57 +0100 Subject: [PATCH 19/22] included suggested modifications --- stable_baselines3/common/distributions.py | 14 +++++++------- stable_baselines3/common/preprocessing.py | 15 +++++++-------- stable_baselines3/ppo/policies.py | 4 ++-- tests/test_distributions.py | 23 +++++------------------ tests/test_spaces.py | 14 ++++++-------- 5 files changed, 27 insertions(+), 43 deletions(-) diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 81dd26f0e..bb42a729f 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -295,7 +295,7 @@ class MultiCategoricalDistribution(Distribution): """ MultiCategorical distribution for multi discrete actions. - :param action_dims: ([int]) List of sizes of discrete action spaces + :param action_dims: (List[int]) List of sizes of discrete action spaces """ def __init__(self, action_dims: List[int]): @@ -306,7 +306,7 @@ def __init__(self, action_dims: List[int]): def proba_distribution_net(self, latent_dim: int) -> nn.Module: """ Create the layer that represents the distribution: - it will be the logits (flattend) of the MultiCategorical distribution. + it will be the logits (flattened) of the MultiCategorical distribution. You can then get probabilities using a softmax on each sub-space. :param latent_dim: (int) Dimension of the last layer @@ -322,13 +322,13 @@ def proba_distribution(self, action_logits: th.Tensor) -> 'MultiCategoricalDistr return self def mode(self) -> th.Tensor: - return th.stack([th.argmax(d.probs, dim=1) for d in self.distributions], dim=1) + return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1) def sample(self) -> th.Tensor: - return th.stack([d.sample() for d in self.distributions], dim=1) + return th.stack([dist.sample() for dist in self.distributions], dim=1) def entropy(self) -> th.Tensor: - return th.stack([d.entropy() for d in self.distributions], dim=1).sum(dim=1) + return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1) def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: @@ -338,12 +338,12 @@ def actions_from_params(self, action_logits: th.Tensor, def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: actions = self.actions_from_params(action_logits) - log_prob = self.log_prob(actions) return actions, log_prob def log_prob(self, actions: th.Tensor) -> th.Tensor: - return th.stack([d.log_prob(x) for d, x in zip(self.distributions, + # Extract each discrete action and compute log prob for their respective distributions + return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions, th.unbind(actions, dim=1))], dim=1).sum(dim=1) diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 03281475c..f3caa94e3 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -68,9 +68,9 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, return F.one_hot(obs.long(), num_classes=observation_space.n).float() elif isinstance(observation_space, spaces.MultiDiscrete): - # Tensor concatination of one hot encodings of each Categorical sub-space - return th.cat([F.one_hot(o.long(), num_classes=int(observation_space.nvec[i])).float() - for i, o in enumerate(th.split(obs.to(th.int64), 1, dim=1))], + # Tensor concatenation of one hot encodings of each Categorical sub-space + return th.cat([F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float() + for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1))], dim=-1).view(obs.shape[0], sum(observation_space.nvec)) elif isinstance(observation_space, spaces.MultiBinary): @@ -93,13 +93,12 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: # Observation is an int return 1, elif isinstance(observation_space, spaces.MultiDiscrete): - # Observation is the number of discrete spaces + # Number of discrete features return int(len(observation_space.nvec)), elif isinstance(observation_space, spaces.MultiBinary): - # Observation is the number of binary spaces + # Number of binary features return int(observation_space.n), else: - # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict raise NotImplementedError() @@ -131,10 +130,10 @@ def get_action_dim(action_space: spaces.Space) -> int: # Action is an int return 1 elif isinstance(action_space, spaces.MultiDiscrete): - # Action is the number of discrete spaces + # Number of discrete actions return int(len(action_space.nvec)) elif isinstance(action_space, spaces.MultiBinary): - # Action is the number of binary spaces + # Number of binary actions return int(action_space.n) else: raise NotImplementedError() diff --git a/stable_baselines3/ppo/policies.py b/stable_baselines3/ppo/policies.py index 203ecaa75..8d994b820 100644 --- a/stable_baselines3/ppo/policies.py +++ b/stable_baselines3/ppo/policies.py @@ -255,10 +255,10 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor, # Here mean_actions are the logits before the softmax return self.action_dist.proba_distribution(action_logits=mean_actions) elif isinstance(self.action_dist, MultiCategoricalDistribution): - # Here mean_actions are the logits before the softmax + # Here mean_actions are the flattened logits return self.action_dist.proba_distribution(action_logits=mean_actions) elif isinstance(self.action_dist, BernoulliDistribution): - # Here mean_actions are the logits before the softmax + # Here mean_actions are the logits (before rounding to get the binary actions) return self.action_dist.proba_distribution(action_logits=mean_actions) elif isinstance(self.action_dist, StateDependentNoiseDistribution): return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde) diff --git a/tests/test_distributions.py b/tests/test_distributions.py index f88ef761a..18d584506 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -86,13 +86,14 @@ def test_entropy(dist): assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) -categorical_param = [ +categorical_params = [ (CategoricalDistribution(N_ACTIONS), N_ACTIONS), - (MultiCategoricalDistribution([2, 3]), sum([2, 3])) + (MultiCategoricalDistribution([2, 3]), sum([2, 3])), + (BernoulliDistribution(N_ACTIONS), N_ACTIONS) ] -@pytest.mark.parametrize("dist, CAT_ACTIONS", categorical_param) +@pytest.mark.parametrize("dist, CAT_ACTIONS", categorical_params) def test_categorical(dist, CAT_ACTIONS): # The entropy can be approximated by averaging the negative log likelihood # mean negative log likelihood == entropy @@ -102,18 +103,4 @@ def test_categorical(dist, CAT_ACTIONS): actions = dist.get_actions() entropy = dist.entropy() log_prob = dist.log_prob(actions) - assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=2e-4) - - -def test_bernoulli(): - # The entropy can be approximated by averaging the negative log likelihood - # mean negative log likelihood == entropy - dist = BernoulliDistribution(N_ACTIONS) - set_random_seed(1) - action_logits = th.rand(N_SAMPLES, N_ACTIONS) - dist = dist.proba_distribution(action_logits) - - actions = dist.get_actions() - entropy = dist.entropy() - log_prob = dist.log_prob(actions) - assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) + assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) \ No newline at end of file diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 1455d2d8f..0b052744c 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -21,12 +21,11 @@ def test_identity_multidiscrete(model_class): """ env = DummyVecEnv([lambda: IdentityEnvMultiDiscrete(DIM)]) - model = model_class("MlpPolicy", env, gamma=0.5, seed=0) - model.learn(total_timesteps=1000) - evaluate_policy(model, env, n_eval_episodes=5) + model = model_class("MlpPolicy", env, gamma=0.5, seed=1) + model.learn(total_timesteps=3000) obs = env.reset() - evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=80) + evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90) assert np.shape(model.predict(obs)[0]) == np.shape(obs) @@ -41,11 +40,10 @@ def test_identity_multibinary(model_class): """ env = DummyVecEnv([lambda: IdentityEnvMultiBinary(DIM)]) - model = model_class("MlpPolicy", env, gamma=0.5, seed=0) - model.learn(total_timesteps=1000) - evaluate_policy(model, env, n_eval_episodes=5) + model = model_class("MlpPolicy", env, gamma=0.5, seed=1) + model.learn(total_timesteps=3000) obs = env.reset() - evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=80) + evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90) assert np.shape(model.predict(obs)[0]) == np.shape(obs) From 1669c8c196cc4e59f8e2d816293cede863122278 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Sun, 17 May 2020 13:43:07 +0100 Subject: [PATCH 20/22] linting fix --- stable_baselines3/common/distributions.py | 2 +- tests/test_distributions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index b7f25cb4f..f9bb16c9c 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -344,7 +344,7 @@ def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th. def log_prob(self, actions: th.Tensor) -> th.Tensor: # Extract each discrete action and compute log prob for their respective distributions return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions, - th.unbind(actions, dim=1))], dim=1).sum(dim=1) + th.unbind(actions, dim=1))], dim=1).sum(dim=1) class BernoulliDistribution(Distribution): diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 18d584506..0461e17f4 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -103,4 +103,4 @@ def test_categorical(dist, CAT_ACTIONS): actions = dist.get_actions() entropy = dist.entropy() log_prob = dist.log_prob(actions) - assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) \ No newline at end of file + assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3) From d347ae53dc785552a608a931c044c0c068f9f655 Mon Sep 17 00:00:00 2001 From: Roland Gavrilescu Date: Sun, 17 May 2020 13:44:39 +0100 Subject: [PATCH 21/22] increased space dim --- tests/test_spaces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 0b052744c..668ac4a3e 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -8,7 +8,7 @@ MODEL_LIST = [A2C, PPO] -DIM = 3 +DIM = 4 @pytest.mark.parametrize("model_class", MODEL_LIST) From 648e00cc3f99ff134d319ea35655dc7444e87bb7 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 18 May 2020 12:55:07 +0200 Subject: [PATCH 22/22] Update doc and tests --- README.md | 4 +- docs/guide/algos.rst | 16 +++--- docs/misc/changelog.rst | 9 ++-- docs/modules/a2c.rst | 6 +-- docs/modules/ppo.rst | 6 +-- docs/modules/sac.rst | 6 +-- docs/modules/td3.rst | 6 +-- stable_baselines3/common/policies.py | 36 ++++++------- stable_baselines3/common/preprocessing.py | 2 + stable_baselines3/version.txt | 2 +- tests/test_identity.py | 18 +++++-- tests/test_spaces.py | 62 +++++++++++------------ 12 files changed, 89 insertions(+), 84 deletions(-) diff --git a/README.md b/README.md index 309d524a2..7031d9fe8 100644 --- a/README.md +++ b/README.md @@ -162,8 +162,8 @@ All the following examples can be executed online using Google colab notebooks: | **Name** | **Recurrent** | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | **Multi Processing** | | ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- | -| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | -| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | +| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: | | TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: | diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 34b97ec1a..94dc43b23 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -5,14 +5,14 @@ This table displays the rl algorithms that are implemented in the Stable Baselin along with some useful characteristics: support for discrete/continuous actions, multiprocessing. -============ =========== ============ ================ -Name ``Box`` ``Discrete`` Multi Processing -============ =========== ============ ================ -A2C ✔️ ✔️ ✔️ -PPO ✔️ ✔️ ✔️ -SAC ✔️ ❌ ❌ -TD3 ✔️ ❌ ❌ -============ =========== ============ ================ +============ =========== ============ ================= =============== ================ +Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing +============ =========== ============ ================= =============== ================ +A2C ✔️ ✔️ ✔️ ✔️ ✔️ +PPO ✔️ ✔️ ✔️ ✔️ ✔️ +SAC ✔️ ❌ ❌ ❌ ❌ +TD3 ✔️ ❌ ❌ ❌ ❌ +============ =========== ============ ================= =============== ================ .. note:: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 8d7290da4..cbdf36f9d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Pre-Release 0.6.0a8 (WIP) +Pre-Release 0.6.0a9 (WIP) ------------------------------ Breaking Changes: @@ -15,10 +15,9 @@ New Features: - Added env checker (Sync with Stable Baselines) - Added ``VecCheckNan`` and ``VecVideoRecorder`` (Sync with Stable Baselines) - Added determinism tests -- Added ``cmd_utils`` and ``atari_wrappers`` -- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation / action spaces for PPO and A2C (@rolandgvc) -- Added ``MultiCategorical`` and ``Bernoulli`` distributions (@rolandgvc) -- Added ``test_bernoulli``, modified ``test_categorical`` and created ``test_spaces.py`` (@rolandgvc) +- Added ``cmd_utils`` and ``atari_wrappers`` +- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation spaces (@rolandgvc) +- Added ``MultiCategorical`` and ``Bernoulli`` distributions for PPO/A2C (@rolandgvc) Bug Fixes: ^^^^^^^^^^ diff --git a/docs/modules/a2c.rst b/docs/modules/a2c.rst index 38374f7e5..096778ba1 100644 --- a/docs/modules/a2c.rst +++ b/docs/modules/a2c.rst @@ -28,10 +28,10 @@ Can I use? ============= ====== =========== Space Action Observation ============= ====== =========== -Discrete ❌ ❌ +Discrete ✔️ ✔️ Box ✔️ ✔️ -MultiDiscrete ❌ ❌ -MultiBinary ❌ ❌ +MultiDiscrete ✔️ ✔️ +MultiBinary ✔️ ✔️ ============= ====== =========== diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index fb83c8985..22fdf150b 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -38,10 +38,10 @@ Can I use? ============= ====== =========== Space Action Observation ============= ====== =========== -Discrete ❌ ❌ +Discrete ✔️ ✔️ Box ✔️ ✔️ -MultiDiscrete ❌ ❌ -MultiBinary ❌ ❌ +MultiDiscrete ✔️ ✔️ +MultiBinary ✔️ ✔️ ============= ====== =========== Example diff --git a/docs/modules/sac.rst b/docs/modules/sac.rst index 359df4bde..4e777886f 100644 --- a/docs/modules/sac.rst +++ b/docs/modules/sac.rst @@ -58,10 +58,10 @@ Can I use? ============= ====== =========== Space Action Observation ============= ====== =========== -Discrete ❌ ❌ +Discrete ❌ ✔️ Box ✔️ ✔️ -MultiDiscrete ❌ ❌ -MultiBinary ❌ ❌ +MultiDiscrete ❌ ✔️ +MultiBinary ❌ ✔️ ============= ====== =========== diff --git a/docs/modules/td3.rst b/docs/modules/td3.rst index 02ae39184..86a939dee 100644 --- a/docs/modules/td3.rst +++ b/docs/modules/td3.rst @@ -50,10 +50,10 @@ Can I use? ============= ====== =========== Space Action Observation ============= ====== =========== -Discrete ❌ ❌ +Discrete ❌ ✔️ Box ✔️ ✔️ -MultiDiscrete ❌ ❌ -MultiBinary ❌ ❌ +MultiDiscrete ❌ ✔️ +MultiBinary ❌ ✔️ ============= ====== =========== diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 437d1cc1c..187ac8b12 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -279,8 +279,8 @@ def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.s elif observation.shape[1:] == observation_space.shape: return True else: - raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) - + "Box environment, please use {} ".format(observation_space.shape) + raise ValueError(f"Error: Unexpected observation shape {observation.shape} for " + + f"Box environment, please use {observation_space.shape} " + "or (n_env, {}) for the observation shape." .format(", ".join(map(str, observation_space.shape)))) elif isinstance(observation_space, gym.spaces.Discrete): @@ -289,7 +289,7 @@ def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.s elif len(observation.shape) == 1: return True else: - raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) + raise ValueError(f"Error: Unexpected observation shape {observation.shape} for " + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.") elif isinstance(observation_space, gym.spaces.MultiDiscrete): @@ -298,21 +298,21 @@ def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.s elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec): return True else: - raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) - + "environment, please use ({},) or ".format(len(observation_space.nvec)) - + "(n_env, {}) for the observation shape.".format(len(observation_space.nvec))) + raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete " + + f"environment, please use ({len(observation_space.nvec)},) or " + + f"(n_env, {len(observation_space.nvec)}) for the observation shape.") elif isinstance(observation_space, gym.spaces.MultiBinary): if observation.shape == (observation_space.n,): return False elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n: return True else: - raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) - + "environment, please use ({},) or ".format(observation_space.n) - + "(n_env, {}) for the observation shape.".format(observation_space.n)) + raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiBinary " + + f"environment, please use ({observation_space.n},) or " + + f"(n_env, {observation_space.n}) for the observation shape.") else: - raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}." - .format(observation_space)) + raise ValueError("Error: Cannot determine if the observation is vectorized " + + f" with the space type {observation_space}.") def _get_data(self) -> Dict[str, Any]: """ @@ -447,7 +447,7 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[ raise ValueError(f"Error: the policy type {base_policy_type} is not registered!") if name not in _policy_registry[base_policy_type]: raise ValueError(f"Error: unknown policy type {name}," - "the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!") + f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!") return _policy_registry[base_policy_type][name] @@ -460,14 +460,10 @@ def register_policy(name: str, policy: Type[BasePolicy]) -> None: :param policy: (Type[BasePolicy]) the policy class """ sub_class = None - # For building the doc - try: - for cls in BasePolicy.__subclasses__(): - if issubclass(policy, cls): - sub_class = cls - break - except AttributeError: - sub_class = str(th.random.randint(100)) + for cls in BasePolicy.__subclasses__(): + if issubclass(policy, cls): + sub_class = cls + break if sub_class is None: raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!") diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index f3caa94e3..849756f17 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -110,6 +110,8 @@ def get_flattened_obs_dim(observation_space: spaces.Space) -> int: :param observation_space: (spaces.Space) :return: (int) """ + # See issue https://github.com/openai/gym/issues/1915 + # it may be a problem for Dict/Tuple spaces too... if isinstance(observation_space, spaces.MultiDiscrete): return sum(observation_space.nvec) else: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index df3ddb5f4..21c95036e 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -0.6.0a8 +0.6.0a9 diff --git a/tests/test_identity.py b/tests/test_identity.py index d937c7e96..b41b70c2f 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -2,17 +2,27 @@ import pytest from stable_baselines3 import A2C, PPO, SAC, TD3 -from stable_baselines3.common.identity_env import IdentityEnvBox, IdentityEnv +from stable_baselines3.common.identity_env import (IdentityEnvBox, IdentityEnv, + IdentityEnvMultiBinary, IdentityEnvMultiDiscrete) + +from stable_baselines3.common.vec_env import DummyVecEnv from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.noise import NormalActionNoise +DIM = 4 + + @pytest.mark.parametrize("model_class", [A2C, PPO]) -def test_discrete(model_class): - env = IdentityEnv(10) - model = model_class('MlpPolicy', env, gamma=0.5, seed=0).learn(3000) +@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)]) +def test_discrete(model_class, env): + env = DummyVecEnv([lambda: env]) + model = model_class('MlpPolicy', env, gamma=0.5, seed=1).learn(3000) evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90) + obs = env.reset() + + assert np.shape(model.predict(obs)[0]) == np.shape(obs) @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3]) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 668ac4a3e..dfd4a60e3 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -1,49 +1,47 @@ import numpy as np import pytest +import gym -from stable_baselines3 import A2C, PPO -from stable_baselines3.common.identity_env import IdentityEnvMultiBinary, IdentityEnvMultiDiscrete +from stable_baselines3 import SAC, TD3 from stable_baselines3.common.evaluation import evaluate_policy -from stable_baselines3.common.vec_env import DummyVecEnv -MODEL_LIST = [A2C, PPO] -DIM = 4 +class DummyMultiDiscreteSpace(gym.Env): + def __init__(self, nvec): + super(DummyMultiDiscreteSpace, self).__init__() + self.observation_space = gym.spaces.MultiDiscrete(nvec) + self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + def reset(self): + return self.observation_space.sample() -@pytest.mark.parametrize("model_class", MODEL_LIST) -def test_identity_multidiscrete(model_class): - """ - Test if the algorithm (with a given policy) - can learn an identity transformation (i.e. return observation as an action) - with a multidiscrete action space - :param model_class: (BaseRLModel) A RL Model - """ - env = DummyVecEnv([lambda: IdentityEnvMultiDiscrete(DIM)]) + def step(self, action): + return self.observation_space.sample(), 0.0, False, {} - model = model_class("MlpPolicy", env, gamma=0.5, seed=1) - model.learn(total_timesteps=3000) - obs = env.reset() - evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90) +class DummyMultiBinary(gym.Env): + def __init__(self, n): + super(DummyMultiBinary, self).__init__() + self.observation_space = gym.spaces.MultiBinary(n) + self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - assert np.shape(model.predict(obs)[0]) == np.shape(obs) + def reset(self): + return self.observation_space.sample() + def step(self, action): + return self.observation_space.sample(), 0.0, False, {} -@pytest.mark.parametrize("model_class", MODEL_LIST) -def test_identity_multibinary(model_class): + +@pytest.mark.parametrize("model_class", [SAC, TD3]) +@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)]) +def test_identity_spaces(model_class, env): """ - Test if the algorithm (with a given policy) - can learn an identity transformation (i.e. return observation as an action) - with a multibinary action space - :param model_class: (BaseRLModel) A RL Model + Additional tests for SAC/TD3 to check observation space support + for MultiDiscrete and MultiBinary. """ - env = DummyVecEnv([lambda: IdentityEnvMultiBinary(DIM)]) - - model = model_class("MlpPolicy", env, gamma=0.5, seed=1) - model.learn(total_timesteps=3000) - obs = env.reset() + env = gym.wrappers.TimeLimit(env, max_episode_steps=100) - evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90) + model = model_class("MlpPolicy", env, gamma=0.5, seed=1, policy_kwargs=dict(net_arch=[64])) + model.learn(total_timesteps=500) - assert np.shape(model.predict(obs)[0]) == np.shape(obs) + evaluate_policy(model, env, n_eval_episodes=5)