Support for MultiBinary / MultiDiscrete spaces #13

Merged
merged 29 commits into from
May 18, 2020
8903862
multicategorical dist and test
rolandgvc May 9, 2020
edef83b
fixed List annotation
rolandgvc May 10, 2020
f25b9f3
bernoulli dist and test
rolandgvc May 10, 2020
7f8dc90
added distributions to preprocessing (needs testing)
rolandgvc May 10, 2020
0a42a69
Merge branch 'master' into master
araffin May 11, 2020
b73b67f
Merge branch 'master' into master
araffin May 11, 2020
97db699
Merge branch 'master' into master
araffin May 12, 2020
06276a7
fixed and tested distributions
rolandgvc May 12, 2020
b1886f2
distributions implemented and tested on ppo
rolandgvc May 12, 2020
394a9be
added changelog and fixed ppo policy
rolandgvc May 12, 2020
236153a
minor fix
rolandgvc May 12, 2020
f8518f2
dist fixes, added test_spaces
rolandgvc May 13, 2020
6c8dfa5
clean up
rolandgvc May 13, 2020
4f9b455
modified changelog
rolandgvc May 13, 2020
a20e9a7
additional fixes
rolandgvc May 13, 2020
65dd147
minor changelog mod
rolandgvc May 13, 2020
ca6824d
hot encoding fix, flake8 clean up
rolandgvc May 14, 2020
43ae65a
lint tests
rolandgvc May 14, 2020
3c39b72
preprocessing fix
rolandgvc May 14, 2020
2e0b2b1
fixed bernoulli bug
rolandgvc May 14, 2020
ef4874b
removed commented prints
rolandgvc May 14, 2020
a021434
Merge branch 'master' into master
araffin May 15, 2020
3c506bb
Update changelog.rst
araffin May 15, 2020
72cd10e
included suggested modifications
rolandgvc May 15, 2020
2f372ee
pulled
rolandgvc May 15, 2020
1669c8c
linting fix
rolandgvc May 17, 2020
d347ae5
increased space dim
rolandgvc May 17, 2020
648e00c
Update doc and tests
araffin May 18, 2020
805a87e
Merge pull request #1 from DLR-RM/pull_13
rolandgvc May 18, 2020
4 changes: 2 additions & 2 deletions README.md
@@ -162,8 +162,8 @@ All the following examples can be executed online using Google colab notebooks:

| **Name** | **Recurrent** | `Box` | `Discrete` | `MultiDiscrete` | `MultiBinary` | **Multi Processing** |
| ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| PPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| SAC | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
| TD3 | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |

16 changes: 8 additions & 8 deletions docs/guide/algos.rst
@@ -5,14 +5,14 @@ This table displays the rl algorithms that are implemented in the Stable Baselines
along with some useful characteristics: support for discrete/continuous actions, multiprocessing.


============ =========== ============ ================
Name         ``Box``     ``Discrete`` Multi Processing
============ =========== ============ ================
A2C          ✔️          ✔️           ✔️
PPO          ✔️          ✔️           ✔️
SAC          ✔️          ❌            ❌
TD3          ✔️          ❌            ❌
============ =========== ============ ================
============ =========== ============ ================= =============== ================
Name         ``Box``     ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
A2C          ✔️          ✔️           ✔️                ✔️              ✔️
PPO          ✔️          ✔️           ✔️                ✔️              ✔️
SAC          ✔️          ❌            ❌                 ❌               ❌
TD3          ✔️          ❌            ❌                 ❌               ❌
============ =========== ============ ================= =============== ================


.. note::
7 changes: 4 additions & 3 deletions docs/misc/changelog.rst
@@ -3,10 +3,9 @@
Changelog
==========

Pre-Release 0.6.0a8 (WIP)
Pre-Release 0.6.0a9 (WIP)
------------------------------


Breaking Changes:
^^^^^^^^^^^^^^^^^
- Remove State-Dependent Exploration (SDE) support for ``TD3``
@@ -17,6 +16,8 @@ New Features:
- Added ``VecCheckNan`` and ``VecVideoRecorder`` (Sync with Stable Baselines)
- Added determinism tests
- Added ``cmd_utils`` and ``atari_wrappers``
- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation spaces (@rolandgvc)
- Added ``MultiCategorical`` and ``Bernoulli`` distributions for PPO/A2C (@rolandgvc)

Bug Fixes:
^^^^^^^^^^
@@ -227,4 +228,4 @@ And all the contributors:
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta
@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta @rolandgvc
6 changes: 3 additions & 3 deletions docs/modules/a2c.rst
@@ -28,10 +28,10 @@ Can I use?
============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete
Discrete      ✔️     ✔️
Box           ✔️     ✔️
MultiDiscrete
MultiBinary
MultiDiscrete ✔️     ✔️
MultiBinary   ✔️     ✔️
============= ====== ===========


6 changes: 3 additions & 3 deletions docs/modules/ppo.rst
@@ -38,10 +38,10 @@ Can I use?
============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete
Discrete      ✔️     ✔️
Box           ✔️     ✔️
MultiDiscrete
MultiBinary
MultiDiscrete ✔️     ✔️
MultiBinary   ✔️     ✔️
============= ====== ===========

Example
6 changes: 3 additions & 3 deletions docs/modules/sac.rst
@@ -58,10 +58,10 @@ Can I use?
============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌
Discrete      ❌      ✔️
Box           ✔️     ✔️
MultiDiscrete ❌
MultiBinary   ❌
MultiDiscrete ❌      ✔️
MultiBinary   ❌      ✔️
============= ====== ===========


6 changes: 3 additions & 3 deletions docs/modules/td3.rst
@@ -50,10 +50,10 @@ Can I use?
============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌
Discrete      ❌      ✔️
Box           ✔️     ✔️
MultiDiscrete ❌
MultiBinary   ❌
MultiDiscrete ❌      ✔️
MultiBinary   ❌      ✔️
============= ====== ===========


1 change: 1 addition & 0 deletions stable_baselines3/common/buffers.py
@@ -20,6 +20,7 @@ class BaseBuffer(object):
to which the values will be converted
:param n_envs: (int) Number of parallel environments
"""

def __init__(self,
buffer_size: int,
observation_space: spaces.Space,
123 changes: 115 additions & 8 deletions stable_baselines3/common/distributions.py
@@ -1,9 +1,8 @@
from typing import Optional, Tuple, Dict, Any

from typing import Optional, Tuple, Dict, Any, List
import gym
import torch as th
import torch.nn as nn
from torch.distributions import Normal, Categorical
from torch.distributions import Normal, Categorical, Bernoulli
from gym import spaces

from stable_baselines3.common.preprocessing import get_action_dim
@@ -88,7 +87,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
:return: (th.Tensor) shape: (n_batch,)
"""
if len(tensor.shape) > 1:
tensor = tensor.sum(axis=1)
tensor = tensor.sum(dim=1)
else:
tensor = tensor.sum()
return tensor
@@ -292,6 +291,114 @@ def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions)


class MultiCategoricalDistribution(Distribution):
"""
MultiCategorical distribution for multi discrete actions.

:param action_dims: (List[int]) List of sizes of discrete action spaces
"""

def __init__(self, action_dims: List[int]):
super(MultiCategoricalDistribution, self).__init__()
self.action_dims = action_dims
self.distributions = None

def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits (flattened) of the MultiCategorical distribution.
You can then get probabilities using a softmax on each sub-space.

:param latent_dim: (int) Dimension of the last layer
of the policy network (before the action layer)
:return: (nn.Linear)
"""

action_logits = nn.Linear(latent_dim, sum(self.action_dims))
return action_logits

def proba_distribution(self, action_logits: th.Tensor) -> 'MultiCategoricalDistribution':
self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
return self

def mode(self) -> th.Tensor:
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1)

def sample(self) -> th.Tensor:
return th.stack([dist.sample() for dist in self.distributions], dim=1)

def entropy(self) -> th.Tensor:
return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1)

def actions_from_params(self, action_logits: th.Tensor,
deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)

def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob

def log_prob(self, actions: th.Tensor) -> th.Tensor:
# Extract each discrete action and compute log prob for their respective distributions
return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions,
th.unbind(actions, dim=1))], dim=1).sum(dim=1)

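The arithmetic behind `MultiCategoricalDistribution` can be followed without PyTorch. The sketch below is a dependency-free illustration, not the code from this PR: it splits one flat logit vector by `action_dims`, applies a log-softmax to each chunk, and sums the per-sub-action terms as `log_prob` and `entropy` do above. The helper names are hypothetical, and it handles a single sample rather than a batch.

```python
import math
from typing import List, Sequence


def split_logits(flat_logits: Sequence[float], action_dims: Sequence[int]) -> List[List[float]]:
    """Cut one flat logit vector into one chunk per discrete sub-action."""
    assert len(flat_logits) == sum(action_dims)
    chunks, start = [], 0
    for size in action_dims:
        chunks.append(list(flat_logits[start:start + size]))
        start += size
    return chunks


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over a single chunk."""
    top = max(logits)
    log_norm = top + math.log(sum(math.exp(x - top) for x in logits))
    return [x - log_norm for x in logits]


def mode(flat_logits: Sequence[float], action_dims: Sequence[int]) -> List[int]:
    """Most likely index in each sub-space (argmax per chunk)."""
    return [max(range(len(c)), key=c.__getitem__)
            for c in split_logits(flat_logits, action_dims)]


def log_prob(flat_logits: Sequence[float], action_dims: Sequence[int],
             actions: Sequence[int]) -> float:
    """Joint log-probability: the sub-actions are independent, so the terms add."""
    chunks = split_logits(flat_logits, action_dims)
    return sum(log_softmax(c)[a] for c, a in zip(chunks, actions))


def entropy(flat_logits: Sequence[float], action_dims: Sequence[int]) -> float:
    """Joint entropy: sum of the entropies of the independent sub-distributions."""
    total = 0.0
    for chunk in split_logits(flat_logits, action_dims):
        logp = log_softmax(chunk)
        total -= sum(math.exp(lp) * lp for lp in logp)
    return total
```

With `action_dims = [2, 3]` and all-zero logits, every joint action has probability 1/6, so both the negative log-probability and the entropy equal `log(6)`.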

class BernoulliDistribution(Distribution):
"""
Bernoulli distribution for MultiBinary action spaces.

:param action_dims: (int) Number of binary actions
"""

def __init__(self, action_dims: int):
super(BernoulliDistribution, self).__init__()
self.distribution = None
self.action_dims = action_dims

def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits of the Bernoulli distribution.

:param latent_dim: (int) Dimension of the last layer
of the policy network (before the action layer)
:return: (nn.Linear)
"""
action_logits = nn.Linear(latent_dim, self.action_dims)
return action_logits

def proba_distribution(self, action_logits: th.Tensor) -> 'BernoulliDistribution':
self.distribution = Bernoulli(logits=action_logits)
return self

def mode(self) -> th.Tensor:
return th.round(self.distribution.probs)

def sample(self) -> th.Tensor:
return self.distribution.sample()

def entropy(self) -> th.Tensor:
return self.distribution.entropy().sum(dim=1)

def actions_from_params(self, action_logits: th.Tensor,
deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)

def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob

def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions).sum(dim=1)

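`BernoulliDistribution` treats each binary action as an independent coin flip parameterised by one logit. The following is a minimal pure-Python sketch of the same quantities for a single sample, assuming moderate logit values so that the naive `log(sigmoid(x))` form stays finite. The function names are hypothetical and are not part of the PR.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    """Map a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def bernoulli_mode(logits: Sequence[float]) -> List[int]:
    """Round each probability; a probability of exactly 0.5 maps to 0,
    which matches round-half-to-even rounding."""
    return [1 if sigmoid(x) > 0.5 else 0 for x in logits]


def bernoulli_log_prob(logits: Sequence[float], actions: Sequence[int]) -> float:
    """Sum of per-bit log-probabilities (independent bits)."""
    total = 0.0
    for x, a in zip(logits, actions):
        p = sigmoid(x)
        total += math.log(p) if a == 1 else math.log(1.0 - p)
    return total


def bernoulli_entropy(logits: Sequence[float]) -> float:
    """Sum of per-bit binary entropies."""
    total = 0.0
    for x in logits:
        p = sigmoid(x)
        total -= p * math.log(p) + (1.0 - p) * math.log(1.0 - p)
    return total
```

For three zero logits every bit has probability 0.5, so any action vector has log-probability `3 * log(0.5)` and the entropy is `3 * log(2)`.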

class StateDependentNoiseDistribution(Distribution):
"""
Distribution class for using generalized State Dependent Exploration (gSDE).
@@ -551,10 +658,10 @@ def make_proba_distribution(action_space: gym.spaces.Space,
return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(action_space.n, **dist_kwargs)
# elif isinstance(action_space, spaces.MultiDiscrete):
# return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
# elif isinstance(action_space, spaces.MultiBinary):
# return BernoulliDistribution(action_space.n, **dist_kwargs)
elif isinstance(action_space, spaces.MultiDiscrete):
return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
elif isinstance(action_space, spaces.MultiBinary):
return BernoulliDistribution(action_space.n, **dist_kwargs)
else:
raise NotImplementedError("Error: probability distribution, not implemented for action space"
f"of type {type(action_space)}."
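The dispatch in this hunk picks a distribution per action space, and each distribution sizes its final linear layer differently: `sum(action_dims)` logits for `MultiDiscrete`, one logit per bit for `MultiBinary`. A rough sketch of that sizing rule follows. The dataclasses are hypothetical stand-ins for the `gym.spaces` types, and the `Discrete` case assumes the existing categorical head emits one logit per action, which this diff does not show.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class Discrete:  # hypothetical stand-in for gym.spaces.Discrete
    n: int


@dataclass
class MultiDiscrete:  # hypothetical stand-in for gym.spaces.MultiDiscrete
    nvec: Tuple[int, ...]


@dataclass
class MultiBinary:  # hypothetical stand-in for gym.spaces.MultiBinary
    n: int


def n_logits(space: object) -> int:
    """Number of outputs the action layer needs for a given space."""
    if isinstance(space, Discrete):
        return space.n  # assumption: one logit per action
    elif isinstance(space, MultiDiscrete):
        return sum(space.nvec)  # flattened logits of every sub-space
    elif isinstance(space, MultiBinary):
        return space.n  # one logit per binary action
    raise NotImplementedError(f"No distribution for space of type {type(space)}.")
```

For example, `MultiDiscrete((3, 2, 5))` needs 10 logits, which are later split back into chunks of 3, 2 and 5.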
71 changes: 33 additions & 38 deletions stable_baselines3/common/policies.py
@@ -206,8 +206,8 @@ def predict(self, observation: np.ndarray,
# Handle the different cases for images
# as PyTorch use channel first format
if is_image_space(self.observation_space):
if (observation.shape == self.observation_space.shape or
observation.shape[1:] == self.observation_space.shape):
if (observation.shape == self.observation_space.shape
or observation.shape[1:] == self.observation_space.shape):
pass
else:
# Try to re-order the channels
@@ -279,40 +279,40 @@ def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.s
elif observation.shape[1:] == observation_space.shape:
return True
else:
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
"Box environment, please use {} ".format(observation_space.shape) +
"or (n_env, {}) for the observation shape."
raise ValueError(f"Error: Unexpected observation shape {observation.shape} for "
+ f"Box environment, please use {observation_space.shape} "
+ "or (n_env, {}) for the observation shape."
.format(", ".join(map(str, observation_space.shape))))
elif isinstance(observation_space, gym.spaces.Discrete):
if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
return False
elif len(observation.shape) == 1:
return True
else:
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
"Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
# TODO: add support for MultiDiscrete and MultiBinary observation spaces
# elif isinstance(observation_space, gym.spaces.MultiDiscrete):
# if observation.shape == (len(observation_space.nvec),):
# return False
# elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
# return True
# else:
# raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) +
# "environment, please use ({},) or ".format(len(observation_space.nvec)) +
# "(n_env, {}) for the observation shape.".format(len(observation_space.nvec)))
# elif isinstance(observation_space, gym.spaces.MultiBinary):
# if observation.shape == (observation_space.n,):
# return False
# elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
# return True
# else:
# raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) +
# "environment, please use ({},) or ".format(observation_space.n) +
# "(n_env, {}) for the observation shape.".format(observation_space.n))
raise ValueError(f"Error: Unexpected observation shape {observation.shape} for "
+ "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")

elif isinstance(observation_space, gym.spaces.MultiDiscrete):
if observation.shape == (len(observation_space.nvec),):
return False
elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
return True
else:
raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
+ f"environment, please use ({len(observation_space.nvec)},) or "
+ f"(n_env, {len(observation_space.nvec)}) for the observation shape.")
elif isinstance(observation_space, gym.spaces.MultiBinary):
if observation.shape == (observation_space.n,):
return False
elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
return True
else:
raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
+ f"environment, please use ({observation_space.n},) or "
+ f"(n_env, {observation_space.n}) for the observation shape.")
else:
raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
.format(observation_space))
raise ValueError("Error: Cannot determine if the observation is vectorized "
+ f"with the space type {observation_space}.")

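The new `MultiDiscrete` and `MultiBinary` branches apply the same shape rule: a flat `(n_features,)` observation is a single sample, a `(n_env, n_features)` observation is a batch, and anything else is rejected. Here is a minimal sketch of that rule on plain shape tuples. The function name is hypothetical, and `n_features` stands for `len(observation_space.nvec)` or `observation_space.n` respectively.

```python
from typing import Tuple


def is_vectorized(shape: Tuple[int, ...], n_features: int) -> bool:
    """Return False for a single observation, True for a batch of them."""
    if shape == (n_features,):
        return False  # one observation, no leading env dimension
    if len(shape) == 2 and shape[1] == n_features:
        return True  # (n_env, n_features)
    raise ValueError(
        f"Unexpected observation shape {shape}, please use ({n_features},) "
        f"or (n_env, {n_features}) for the observation shape."
    )
```

For a `MultiBinary(3)` space, `(3,)` is treated as a single observation and `(8, 3)` as a batch of eight, while `(3, 1)` raises.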
def _get_data(self) -> Dict[str, Any]:
"""
@@ -447,7 +447,7 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[
raise ValueError(f"Error: the policy type {base_policy_type} is not registered!")
if name not in _policy_registry[base_policy_type]:
raise ValueError(f"Error: unknown policy type {name},"
"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
f" the only registered policy types are: {list(_policy_registry[base_policy_type].keys())}!")
return _policy_registry[base_policy_type][name]


@@ -460,14 +460,10 @@ def register_policy(name: str, policy: Type[BasePolicy]) -> None:
:param policy: (Type[BasePolicy]) the policy class
"""
sub_class = None
# For building the doc
try:
for cls in BasePolicy.__subclasses__():
if issubclass(policy, cls):
sub_class = cls
break
except AttributeError:
sub_class = str(th.random.randint(100))
for cls in BasePolicy.__subclasses__():
if issubclass(policy, cls):
sub_class = cls
break
if sub_class is None:
raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!")

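After this change `register_policy` relies directly on `BasePolicy.__subclasses__()`, which returns only the direct subclasses of a class. The sketch below shows how that lookup resolves a policy to its family. The class names are hypothetical stand-ins, not the library's real policy classes.

```python
class BasePolicy:  # hypothetical stand-in for the real base class
    pass


class ActorCriticLike(BasePolicy):  # a direct subclass, i.e. one policy family
    pass


class QLike(BasePolicy):  # another direct subclass
    pass


class MyCustomPolicy(ActorCriticLike):  # user policy, two levels below BasePolicy
    pass


class Unrelated:  # does not derive from BasePolicy at all
    pass


def find_family(policy: type) -> type:
    """Return the direct BasePolicy subclass that `policy` derives from."""
    for cls in BasePolicy.__subclasses__():
        if issubclass(policy, cls):
            return cls
    raise ValueError(f"{policy} is not of any known subclasses of BasePolicy!")
```

Here `find_family(MyCustomPolicy)` returns `ActorCriticLike`, while `find_family(Unrelated)` raises `ValueError`, mirroring the error path in the hunk above.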
@@ -511,7 +507,6 @@ def __init__(self, feature_dim: int,
device: Union[th.device, str] = 'auto'):
super(MlpExtractor, self).__init__()
device = get_device(device)

shared_net, policy_net, value_net = [], [], []
policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network
value_only_layers = [] # Layer sizes of the network that only belongs to the value network