From 0ad743c85d3c2dbdfab8f18b6c519a45751b4a08 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 25 Oct 2019 10:59:15 +0200 Subject: [PATCH 1/7] Add A2C --- README.md | 8 +-- tests/test_run.py | 7 +- torchy_baselines/__init__.py | 3 +- torchy_baselines/a2c/__init__.py | 2 + torchy_baselines/a2c/a2c.py | 109 +++++++++++++++++++++++++++++++ 5 files changed, 120 insertions(+), 9 deletions(-) create mode 100644 torchy_baselines/a2c/__init__.py create mode 100644 torchy_baselines/a2c/a2c.py diff --git a/README.md b/README.md index a73f69091..b5624ec71 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ PyTorch version of [Stable Baselines](https://github.com/hill-a/stable-baselines ## Implemented Algorithms +- A2C - CEM-RL (with TD3) - PPO - SAC @@ -18,11 +19,8 @@ PyTorch version of [Stable Baselines](https://github.com/hill-a/stable-baselines TODO: - save/load -- predict -- flexible mlp -- logger -- better monitor wrapper? -- A2C +- better predict +- complete logger Later: - get_parameters / set_parameters diff --git a/tests/test_run.py b/tests/test_run.py index 974092157..32a4b30f5 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -3,7 +3,7 @@ import pytest import numpy as np -from torchy_baselines import TD3, CEMRL, PPO, SAC +from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3 from torchy_baselines.common.noise import NormalActionNoise @@ -28,9 +28,10 @@ def test_cemrl(): os.remove("test_save.pth") +@pytest.mark.parametrize("model_class", [A2C, PPO]) @pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0']) -def test_ppo(env_id): - model = PPO('MlpPolicy', env_id, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) +def test_onpolicy(model_class, env_id): + model = model_class('MlpPolicy', env_id, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) model.learn(total_timesteps=1000, eval_freq=500) # model.save("test_save") # model.load("test_save") diff --git a/torchy_baselines/__init__.py b/torchy_baselines/__init__.py index b9dabaabb..a5896e6f9 100644 --- a/torchy_baselines/__init__.py +++ b/torchy_baselines/__init__.py @@ -1,6 +1,7 @@ +from torchy_baselines.a2c import A2C from torchy_baselines.cem_rl import CEMRL from torchy_baselines.ppo import PPO from torchy_baselines.sac import SAC from torchy_baselines.td3 import TD3 -__version__ = "0.0.4" +__version__ = "0.0.5a" diff --git a/torchy_baselines/a2c/__init__.py b/torchy_baselines/a2c/__init__.py new file mode 100644 index 000000000..0cc4be01e --- /dev/null +++ b/torchy_baselines/a2c/__init__.py @@ -0,0 +1,2 @@ +from torchy_baselines.a2c.a2c import A2C +from torchy_baselines.ppo.policies import MlpPolicy diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py new file mode 100644 index 000000000..4de140b98 --- /dev/null +++ b/torchy_baselines/a2c/a2c.py @@ -0,0 +1,109 @@ +from gym import spaces +import torch as th +import torch.nn.functional as F + +from torchy_baselines.common.utils import explained_variance +from torchy_baselines.ppo.ppo import PPO +from torchy_baselines.ppo.policies import PPOPolicy + + +class A2C(PPO): + """ + Advantage Actor Critic (A2C) + + Paper: https://arxiv.org/abs/1602.01783 + Code: This implementation borrows code from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and + and Stable Baselines (https://github.com/hill-a/stable-baselines) + + Introduction to A2C: https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752 + + :param policy: (PPOPolicy or str) The policy model to use (MlpPolicy, 
CnnPolicy, ...) + :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) + :param learning_rate: (float or callable) The learning rate, it can be a function + :param n_steps: (int) The number of steps to run for each environment per update + (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) + :param batch_size: (int) Minibatch size + :param n_epochs: (int) Number of epoch when optimizing the surrogate loss + :param gamma: (float) Discount factor + :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator + :param ent_coef: (float) Entropy coefficient for the loss calculation + :param vf_coef: (float) Value function coefficient for the loss calculation + :param max_grad_norm: (float) The maximum value for the gradient clipping + :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) + :param create_eval_env: (bool) Whether to create a second environment that will be + used for evaluating the agent periodically. (Only available when passing string for the environment) + :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation + :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug + :param seed: (int) Seed for the pseudo random generators + :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. + :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + """ + + def __init__(self, policy, env, learning_rate=3e-4, + n_steps=2048, batch_size=64, n_epochs=1, + gamma=0.99, gae_lambda=0.95, + ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5, + tensorboard_log=None, create_eval_env=False, + policy_kwargs=None, verbose=0, seed=0, device='auto', + _init_setup_model=True): + + super(A2C, self).__init__(policy, env, learning_rate=learning_rate, + n_steps=n_steps, batch_size=batch_size, n_epochs=n_epochs, + gamma=gamma, gae_lambda=gae_lambda, ent_coef=ent_coef, + vf_coef=vf_coef, max_grad_norm=max_grad_norm, + tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, + verbose=verbose, device=device, create_eval_env=create_eval_env, + seed=seed, _init_setup_model=False) + + self.batch_size = n_steps + + if _init_setup_model: + self._setup_model() + + def train(self, gradient_steps, batch_size=64): + + for gradient_step in range(gradient_steps): + # approx_kl_divs = [] + # Sample replay buffer + for replay_data in self.rollout_buffer.get(batch_size): + # Unpack + obs, action, _, _, advantage, return_batch = replay_data + + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action for float to long + action = action.long().flatten() + + values, log_prob, entropy = self.policy.get_policy_stats(obs, action) + values = values.flatten() + # Normalize advantage + # TODO: check without + advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8) + + policy_loss = -(advantage * log_prob).mean() + + # Value loss using the TD(gae_lambda) target + value_loss = F.mse_loss(return_batch, values) + + # Entropy loss favor exploration + entropy_loss = th.mean(entropy) + + loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + + # Optimization step + self.policy.optimizer.zero_grad() + loss.backward() + # Clip grad norm + th.nn.utils.clip_grad_norm_(self.policy.parameters(), 
self.max_grad_norm) + self.policy.optimizer.step() + # approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy()) + + # print(explained_variance(self.rollout_buffer.returns.flatten().cpu().numpy(), + # self.rollout_buffer.values.flatten().cpu().numpy())) + + def learn(self, total_timesteps, callback=None, log_interval=100, + eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="A2C", reset_num_timesteps=True): + + return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, + eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, reset_num_timesteps=reset_num_timesteps) From f8bcb8ee16817dc62cbc052b8d9d810f726b80d4 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 25 Oct 2019 11:31:20 +0200 Subject: [PATCH 2/7] Update A2C params --- torchy_baselines/a2c/a2c.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index 4de140b98..709f67b20 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -22,13 +22,12 @@ class A2C(PPO): :param learning_rate: (float or callable) The learning rate, it can be a function :param n_steps: (int) The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) - :param batch_size: (int) Minibatch size - :param n_epochs: (int) Number of epoch when optimizing the surrogate loss :param gamma: (float) Discount factor :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator :param ent_coef: (float) Entropy coefficient for the loss calculation :param vf_coef: (float) Value function coefficient for the loss calculation :param max_grad_norm: (float) The maximum value for the gradient clipping + :param normalize_advantage: (bool) Whether to normalize or not the advantage :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param create_eval_env: (bool) Whether to create a second environment that will be used for evaluating the agent periodically. 
(Only available when passing string for the environment) @@ -41,23 +40,22 @@ class A2C(PPO): """ def __init__(self, policy, env, learning_rate=3e-4, - n_steps=2048, batch_size=64, n_epochs=1, - gamma=0.99, gae_lambda=0.95, + n_steps=5, gamma=0.99, gae_lambda=0.95, ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5, - tensorboard_log=None, create_eval_env=False, + normalize_advantage=True, tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0, seed=0, device='auto', _init_setup_model=True): super(A2C, self).__init__(policy, env, learning_rate=learning_rate, - n_steps=n_steps, batch_size=batch_size, n_epochs=n_epochs, + n_steps=n_steps, batch_size=n_steps, n_epochs=1, gamma=gamma, gae_lambda=gae_lambda, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, verbose=verbose, device=device, create_eval_env=create_eval_env, seed=seed, _init_setup_model=False) - self.batch_size = n_steps - + # Note: in the original implementation, this is RMSProp that is used + self.normalize_advantage = normalize_advantage if _init_setup_model: self._setup_model() @@ -76,9 +74,9 @@ def train(self, gradient_steps, batch_size=64): values, log_prob, entropy = self.policy.get_policy_stats(obs, action) values = values.flatten() - # Normalize advantage - # TODO: check without - advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8) + # Normalize advantage (not present in the original implementation) + if self.normalize_advantage: + advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8) policy_loss = -(advantage * log_prob).mean() From 584f549fa15e7c2713e8ab0a0937e6e844e241db Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 25 Oct 2019 12:00:37 +0200 Subject: [PATCH 3/7] Bug fix for discrete actions --- torchy_baselines/common/buffers.py | 1 + torchy_baselines/ppo/ppo.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/torchy_baselines/common/buffers.py b/torchy_baselines/common/buffers.py index 565117c3a..b169dd334 100644 --- a/torchy_baselines/common/buffers.py +++ b/torchy_baselines/common/buffers.py @@ -134,6 +134,7 @@ def add(self, obs, action, reward, done, value, log_prob): if len(log_prob.shape) == 0: # Reshape 0-d tensor to avoid error log_prob = log_prob.reshape(-1, 1) + self.observations[self.pos] = th.FloatTensor(np.array(obs).copy()) self.actions[self.pos] = th.FloatTensor(np.array(action).copy()) self.rewards[self.pos] = th.FloatTensor(np.array(reward).copy()) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 3d748fb63..b6584dd40 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -158,6 +158,9 @@ def collect_rollouts(self, env, rollout_buffer, n_rollout_steps=256, callback=No self._update_info_buffer(infos) n_steps += 1 + if isinstance(self.action_space, gym.spaces.Discrete): + # Reshape in case of discrete action + actions = actions.reshape(-1, 1) rollout_buffer.add(obs, actions, rewards, dones, values, log_probs) obs = new_obs From b150167bdd22c53dfac184a4ace266a2d90e9919 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 25 Oct 2019 13:01:00 +0200 Subject: [PATCH 4/7] Update default hyperparams --- torchy_baselines/a2c/a2c.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index 709f67b20..f08a352b8 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -27,6 +27,9 @@ class A2C(PPO): :param 
ent_coef: (float) Entropy coefficient for the loss calculation :param vf_coef: (float) Value function coefficient for the loss calculation :param max_grad_norm: (float) The maximum value for the gradient clipping + :param rms_prop_eps: (float) RMSProp epsilon. It stabilizes square root computation in denominator + of RMSProp update + :param use_rms_prop: (bool) Whether to use RMSprop (default) or Adam as optimizer :param normalize_advantage: (bool) Whether to normalize or not the advantage :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param create_eval_env: (bool) Whether to create a second environment that will be @@ -39,10 +42,11 @@ class A2C(PPO): :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance """ - def __init__(self, policy, env, learning_rate=3e-4, - n_steps=5, gamma=0.99, gae_lambda=0.95, + def __init__(self, policy, env, learning_rate=7e-4, + n_steps=5, gamma=0.99, gae_lambda=1.0, ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5, - normalize_advantage=True, tensorboard_log=None, create_eval_env=False, + rms_prop_eps=1e-5, use_rms_prop=True, + normalize_advantage=False, tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0, seed=0, device='auto', _init_setup_model=True): @@ -54,11 +58,20 @@ def __init__(self, policy, env, learning_rate=3e-4, verbose=verbose, device=device, create_eval_env=create_eval_env, seed=seed, _init_setup_model=False) - # Note: in the original implementation, this is RMSProp that is used self.normalize_advantage = normalize_advantage + self.rms_prop_eps = rms_prop_eps + self.use_rms_prop = use_rms_prop + if _init_setup_model: self._setup_model() + def _setup_model(self): + super(A2C, self)._setup_model() + if self.use_rms_prop: + self.policy.optimizer = th.optim.RMSprop(self.policy.parameters(), + lr=self.learning_rate, alpha=0.99, + eps=self.rms_prop_eps, weight_decay=0) + def train(self, gradient_steps, batch_size=64): for gradient_step in range(gradient_steps): From 799e30ff3d575015866a1a48da7b4641aa56eed2 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 28 Oct 2019 14:27:32 +0100 Subject: [PATCH 5/7] Bug fixes for A2C and PPO --- torchy_baselines/a2c/a2c.py | 62 ++++++++++++++++-------------- torchy_baselines/common/buffers.py | 6 ++- torchy_baselines/ppo/ppo.py | 2 +- 3 files changed, 39 insertions(+), 31 deletions(-) diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index f08a352b8..c426552c9 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -51,7 +51,7 @@ def __init__(self, policy, env, learning_rate=7e-4, _init_setup_model=True): super(A2C, self).__init__(policy, env, learning_rate=learning_rate, - n_steps=n_steps, batch_size=n_steps, n_epochs=1, + n_steps=n_steps, batch_size=None, n_epochs=1, gamma=gamma, gae_lambda=gae_lambda, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, @@ -72,42 +72,46 @@ def _setup_model(self): lr=self.learning_rate, alpha=0.99, eps=self.rms_prop_eps, weight_decay=0) - def train(self, gradient_steps, batch_size=64): + def train(self, gradient_steps, batch_size=None): - for gradient_step in range(gradient_steps): - # approx_kl_divs = [] - # Sample replay buffer - for replay_data in self.rollout_buffer.get(batch_size): - # Unpack - obs, action, _, _, advantage, return_batch = replay_data + # A2C with gradient_steps > 1 does not make sense + assert gradient_steps == 1 + # We do not use 
minibatches for A2C + assert batch_size is None - if isinstance(self.action_space, spaces.Discrete): - # Convert discrete action for float to long - action = action.long().flatten() + for rollout_data in self.rollout_buffer.get(batch_size=None): + # Unpack + obs, action, _, _, advantage, return_batch = rollout_data - values, log_prob, entropy = self.policy.get_policy_stats(obs, action) - values = values.flatten() - # Normalize advantage (not present in the original implementation) - if self.normalize_advantage: - advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8) + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action for float to long + action = action.long().flatten() - policy_loss = -(advantage * log_prob).mean() + # TODO: avoid second computation of everything because of the gradient + values, log_prob, entropy = self.policy.get_policy_stats(obs, action) + values = values.flatten() - # Value loss using the TD(gae_lambda) target - value_loss = F.mse_loss(return_batch, values) + # Normalize advantage (not present in the original implementation) + if self.normalize_advantage: + advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8) - # Entropy loss favor exploration - entropy_loss = th.mean(entropy) + policy_loss = -(advantage * log_prob).mean() - loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + # Value loss using the TD(gae_lambda) target + value_loss = F.mse_loss(return_batch, values) - # Optimization step - self.policy.optimizer.zero_grad() - loss.backward() - # Clip grad norm - th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) - self.policy.optimizer.step() - # approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy()) + # Entropy loss favor exploration + entropy_loss = -th.mean(entropy) + + loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + + # Optimization step + self.policy.optimizer.zero_grad() + loss.backward() + # Clip grad norm + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + # approx_kl_divs.append(th.mean(old_log_prob - log_prob).detach().cpu().numpy()) # print(explained_variance(self.rollout_buffer.returns.flatten().cpu().numpy(), # self.rollout_buffer.values.flatten().cpu().numpy())) diff --git a/torchy_baselines/common/buffers.py b/torchy_baselines/common/buffers.py index b169dd334..34a709841 100644 --- a/torchy_baselines/common/buffers.py +++ b/torchy_baselines/common/buffers.py @@ -145,7 +145,7 @@ def add(self, obs, action, reward, done, value, log_prob): if self.pos == self.buffer_size: self.full = True - def get(self, batch_size): + def get(self, batch_size=None): assert self.full indices = th.randperm(self.buffer_size * self.n_envs) # Prepare the data @@ -155,6 +155,10 @@ def get(self, batch_size): self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) self.generator_ready = True + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + start_idx = 0 while start_idx < self.buffer_size * self.n_envs: yield self._get_samples(indices[start_idx:start_idx + batch_size]) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index b6584dd40..cf367701d 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -205,7 +205,7 @@ def train(self, gradient_steps, batch_size=64): # Entropy loss favor exploration - entropy_loss = th.mean(entropy) + 
entropy_loss = -th.mean(entropy) loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss From d67822718c00f685709c6b25ae7333cbe7524c40 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 28 Oct 2019 16:47:13 +0100 Subject: [PATCH 6/7] Add learning rate schedule --- torchy_baselines/a2c/a2c.py | 4 ++- torchy_baselines/cem_rl/cem_rl.py | 3 +- torchy_baselines/common/base_class.py | 26 +++++++++++++++- torchy_baselines/common/utils.py | 45 +++++++++++++++++++++++++++ torchy_baselines/ppo/policies.py | 4 +-- torchy_baselines/ppo/ppo.py | 29 +++++++++++++---- torchy_baselines/sac/policies.py | 6 ++-- torchy_baselines/sac/sac.py | 11 ++++++- torchy_baselines/td3/policies.py | 6 ++-- torchy_baselines/td3/td3.py | 6 ++++ 10 files changed, 122 insertions(+), 18 deletions(-) diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index c426552c9..6ee6f4a77 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -69,11 +69,13 @@ def _setup_model(self): super(A2C, self)._setup_model() if self.use_rms_prop: self.policy.optimizer = th.optim.RMSprop(self.policy.parameters(), - lr=self.learning_rate, alpha=0.99, + lr=self.learning_rate(1), alpha=0.99, eps=self.rms_prop_eps, weight_decay=0) def train(self, gradient_steps, batch_size=None): + # Update optimizer learning rate + self._update_learning_rate(self.policy.optimizer) # A2C with gradient_steps > 1 does not make sense assert gradient_steps == 1 # We do not use minibatches for A2C diff --git a/torchy_baselines/cem_rl/cem_rl.py b/torchy_baselines/cem_rl/cem_rl.py index 2116d3f48..ad798ba41 100644 --- a/torchy_baselines/cem_rl/cem_rl.py +++ b/torchy_baselines/cem_rl/cem_rl.py @@ -78,7 +78,7 @@ def learn(self, total_timesteps, callback=None, log_interval=4, # set params self.actor.load_from_vector(self.es_params[i]) self.actor_target.load_from_vector(self.es_params[i]) - self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=self.learning_rate) + self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=self.learning_rate(self._current_progress)) # In the paper: 2 * actor_steps // self.n_grad # In the original implementation: actor_steps // self.n_grad @@ -153,6 +153,7 @@ def learn(self, total_timesteps, callback=None, log_interval=4, print("Total T: {} Episode Num: {} Episode T: {} Reward: {}".format( self.num_timesteps, episode_num, episode_timesteps, episode_reward)) + self._update_current_progress(self.num_timesteps, total_timesteps) self.es.tell(self.es_params, self.fitnesses) timesteps_since_eval += actor_steps return self diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 8feec0186..0904932f6 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -7,7 +7,7 @@ import numpy as np from torchy_baselines.common.policies import get_policy_from_name -from torchy_baselines.common.utils import set_random_seed +from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, update_learning_rate from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv from torchy_baselines.common.monitor import Monitor from torchy_baselines.common import logger @@ -57,6 +57,9 @@ def __init__(self, policy, env, policy_base, policy_kwargs=None, self.replay_buffer = None self.seed = seed self.action_noise = None + # Track the training progress (from 1 to 0) + # this is used to update the learning rate + self._current_progress = 1 if env is not None: if isinstance(env, str): @@ -112,6 +115,27 
@@ def unscale_action(self, scaled_action): low, high = self.action_space.low, self.action_space.high return low + (0.5 * (scaled_action + 1.0) * (high - low)) + def _setup_learning_rate(self): + """Transform to callable if needed.""" + self.learning_rate = get_schedule_fn(self.learning_rate) + + def _update_current_progress(self, num_timesteps, total_timesteps): + """ + Compute current progress (from 1 to 0) + + :param num_timesteps: (int) + :param total_timesteps: (int) + """ + self._current_progress = 1.0 - float(num_timesteps) / float(total_timesteps) + + def _update_learning_rate(self, optimizers): + logger.logkv("learning_rate", self.learning_rate(self._current_progress)) + + if not isinstance(optimizers, list): + optimizers = [optimizers] + for optimizer in optimizers: + update_learning_rate(optimizer, self.learning_rate(self._current_progress)) + @staticmethod def safe_mean(arr): """ diff --git a/torchy_baselines/common/utils.py b/torchy_baselines/common/utils.py index b4887a711..151333bb7 100644 --- a/torchy_baselines/common/utils.py +++ b/torchy_baselines/common/utils.py @@ -38,3 +38,48 @@ def explained_variance(y_pred, y_true): assert y_true.ndim == 1 and y_pred.ndim == 1 var_y = np.var(y_true) return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + +def update_learning_rate(optimizer, learning_rate): + """ + Update the learning rate for a given optimizer. + Useful when doing linear schedule. + + :param optimizer: (th.optim.Optimizer) + :param learning_rate: (float) + """ + for param_group in optimizer.param_groups: + param_group['lr'] = learning_rate + + +def get_schedule_fn(value_schedule): + """ + Transform (if needed) learning rate and clip range (for PPO) + to callable. + + :param value_schedule: (callable or float) + :return: (function) + """ + # If the passed schedule is a float + # create a constant function + if isinstance(value_schedule, (float, int)): + # Cast to float to avoid errors + value_schedule = constant_fn(float(value_schedule)) + else: + assert callable(value_schedule) + return value_schedule + + +def constant_fn(val): + """ + Create a function that returns a constant + It is useful for learning rate schedule (to avoid code duplication) + + :param val: (float) + :return: (function) + """ + + def func(_): + return val + + return func diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index 0a8f38ec7..c967962de 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -100,7 +100,7 @@ def forward(self, features): class PPOPolicy(BasePolicy): def __init__(self, observation_space, action_space, - learning_rate=1e-3, net_arch=None, device='cpu', + learning_rate, net_arch=None, device='cpu', activation_fn=nn.Tanh, adam_epsilon=1e-5, ortho_init=True): super(PPOPolicy, self).__init__(observation_space, action_space, device) self.obs_dim = self.observation_space.shape[0] @@ -149,7 +149,7 @@ def _build(self, learning_rate): }[module] module.apply(partial(self.init_weights, gain=gain)) # TODO: support linear decay of the learning rate - self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate, eps=self.adam_epsilon) + self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate(1), eps=self.adam_epsilon) def forward(self, obs, deterministic=False): if not isinstance(obs, th.Tensor): diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index cf367701d..13a163434 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -16,7 +16,7 @@ from 
torchy_baselines.common.base_class import BaseRLModel from torchy_baselines.common.evaluation import evaluate_policy from torchy_baselines.common.buffers import RolloutBuffer -from torchy_baselines.common.utils import explained_variance +from torchy_baselines.common.utils import explained_variance, get_schedule_fn from torchy_baselines.common.vec_env import VecNormalize from torchy_baselines.common import logger from torchy_baselines.ppo.policies import PPOPolicy @@ -36,14 +36,16 @@ class PPO(BaseRLModel): :param policy: (PPOPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...) :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) :param learning_rate: (float or callable) The learning rate, it can be a function + of the current progress (from 1 to 0) :param n_steps: (int) The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) :param batch_size: (int) Minibatch size :param n_epochs: (int) Number of epoch when optimizing the surrogate loss :param gamma: (float) Discount factor :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator - :param clip_range: (float or callable) Clipping parameter, it can be a function - :param clip_range_vf: (float or callable) Clipping parameter for the value function, it can be a function. + :param clip_range: (float or callable) Clipping parameter, it can be a function of the current progress (from 1 to 0). + :param clip_range_vf: (float or callable) Clipping parameter for the value function, + it can be a function of the current progress (from 1 to 0). This is a parameter specific to the OpenAI implementation. If None is passed (default), no clipping will be done on the value function. IMPORTANT: this clipping depends on the reward scaling. 
@@ -84,12 +86,12 @@ def __init__(self, policy, env, learning_rate=3e-4, self.gamma = gamma self.gae_lambda = gae_lambda self.clip_range = clip_range + self.clip_range_vf = clip_range_vf self.ent_coef = ent_coef self.vf_coef = vf_coef self.max_grad_norm = max_grad_norm self.rollout_buffer = None self.target_kl = target_kl - self.clip_range_vf = clip_range_vf self.tensorboard_log = tensorboard_log self.tb_writer = None @@ -97,6 +99,7 @@ def __init__(self, policy, env, learning_rate=3e-4, self._setup_model() def _setup_model(self): + self._setup_learning_rate() # TODO: preprocessing: one hot vector for obs discrete state_dim = self.observation_space.shape[0] if isinstance(self.action_space, spaces.Box): @@ -116,6 +119,10 @@ def _setup_model(self): self.learning_rate, device=self.device, **self.policy_kwargs) self.policy = self.policy.to(self.device) + self.clip_range = get_schedule_fn(self.clip_range) + if self.clip_range_vf is not None: + self.clip_range_vf = get_schedule_fn(self.clip_range_vf) + def select_action(self, observation): # Normally not needed observation = np.array(observation) @@ -169,6 +176,15 @@ def collect_rollouts(self, env, rollout_buffer, n_rollout_steps=256, callback=No return obs def train(self, gradient_steps, batch_size=64): + # Update optimizer learning rate + self._update_learning_rate(self.policy.optimizer) + # Compute current clip range + clip_range = self.clip_range(self._current_progress) + logger.logkv("clip_range", clip_range) + if self.clip_range_vf is not None: + clip_range_vf = self.clip_range_vf(self._current_progress) + logger.logkv("clip_range_vf", clip_range_vf) + for gradient_step in range(gradient_steps): approx_kl_divs = [] @@ -190,7 +206,7 @@ def train(self, gradient_steps, batch_size=64): ratio = th.exp(log_prob - old_log_prob) # clipped surrogate loss policy_loss_1 = advantage * ratio - policy_loss_2 = advantage * th.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range) + policy_loss_2 = advantage * th.clamp(ratio, 1 - clip_range, 1 + clip_range) policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() if self.clip_range_vf is None: @@ -199,7 +215,7 @@ def train(self, gradient_steps, batch_size=64): else: # Clip the different between old and new value # NOTE: this depends on the reward scaling - values_pred = old_values + th.clamp(values - old_values, -self.clip_range_vf, self.clip_range_vf) + values_pred = old_values + th.clamp(values - old_values, -clip_range_vf, clip_range_vf) # Value loss using the TD(gae_lambda) target value_loss = F.mse_loss(return_batch, values_pred) @@ -244,6 +260,7 @@ def learn(self, total_timesteps, callback=None, log_interval=1, iteration += 1 self.num_timesteps += self.n_steps * self.n_envs timesteps_since_eval += self.n_steps * self.n_envs + self._update_current_progress(self.num_timesteps, total_timesteps) # Display training infos if self.verbose >= 1 and log_interval is not None and iteration % log_interval == 0: diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py index c1f89a4b2..b61ae502e 100644 --- a/torchy_baselines/sac/policies.py +++ b/torchy_baselines/sac/policies.py @@ -63,7 +63,7 @@ def q1_forward(self, obs, action): class SACPolicy(BasePolicy): def __init__(self, observation_space, action_space, - learning_rate=3e-4, net_arch=None, device='cpu', + learning_rate, net_arch=None, device='cpu', activation_fn=nn.ReLU): super(SACPolicy, self).__init__(observation_space, action_space, device) @@ -87,12 +87,12 @@ def __init__(self, observation_space, action_space, def _build(self, 
learning_rate): self.actor = self.make_actor() - self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=learning_rate) + self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=learning_rate(1)) self.critic = self.make_critic() self.critic_target = self.make_critic() self.critic_target.load_state_dict(self.critic.state_dict()) - self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate) + self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate(1)) def make_actor(self): return Actor(**self.net_args).to(self.device) diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index dca15f858..bad470a91 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -88,6 +88,7 @@ def __init__(self, policy, env, learning_rate=3e-4, buffer_size=int(1e6), self._setup_model() def _setup_model(self): + self._setup_learning_rate() obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] if self.seed is not None: self.set_random_seed(self.seed) @@ -114,7 +115,7 @@ def _setup_model(self): # Note: we optimize the log of the entropy coeff which is slightly different from the paper # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True) - self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.learning_rate) + self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.learning_rate(1)) else: # Force conversion to float # this will throw an error if a malformed string (different from 'auto') @@ -152,6 +153,13 @@ def predict(self, observation, state=None, mask=None, deterministic=True): return self.unscale_action(self.select_action(observation)) def train(self, gradient_steps, batch_size=64): + # Update optimizers learning rate + optimizers = [self.actor.optimizer, self.critic.optimizer] + if self.ent_coef_optimizer is not None: + optimizers += [self.ent_coef_optimizer] + + self._update_learning_rate(optimizers) + for gradient_step in range(gradient_steps): # Sample replay buffer replay_data = self.replay_buffer.sample(batch_size) @@ -245,6 +253,7 @@ def learn(self, total_timesteps, callback=None, log_interval=4, self.num_timesteps += episode_timesteps episode_num += n_episodes timesteps_since_eval += episode_timesteps + self._update_current_progress(self.num_timesteps, total_timesteps) if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts: if self.verbose > 1: diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index b85da81cd..4cf32f319 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -39,7 +39,7 @@ def q1_forward(self, obs, action): class TD3Policy(BasePolicy): def __init__(self, observation_space, action_space, - learning_rate=1e-3, net_arch=None, device='cpu', + learning_rate, net_arch=None, device='cpu', activation_fn=nn.ReLU): super(TD3Policy, self).__init__(observation_space, action_space, device) @@ -64,12 +64,12 @@ def _build(self, learning_rate): self.actor = self.make_actor() self.actor_target = self.make_actor() self.actor_target.load_state_dict(self.actor.state_dict()) - self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=learning_rate) + self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=learning_rate(1)) self.critic = self.make_critic() self.critic_target = self.make_critic() 
self.critic_target.load_state_dict(self.critic.state_dict()) - self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate) + self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate(1)) def make_actor(self): return Actor(**self.net_args).to(self.device) diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 7a13c0ac1..49cf16ed7 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -75,6 +75,7 @@ def __init__(self, policy, env, buffer_size=int(1e6), learning_rate=1e-3, self._setup_model() def _setup_model(self): + self._setup_learning_rate() obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] self.set_random_seed(self.seed) self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) @@ -109,6 +110,8 @@ def predict(self, observation, state=None, mask=None, deterministic=True): return self.unscale_action(self.select_action(observation)) def train_critic(self, gradient_steps=1, batch_size=100, replay_data=None, tau=0.0): + # Update optimizer learning rate + self._update_learning_rate(self.critic.optimizer) for gradient_step in range(gradient_steps): # Sample replay buffer @@ -146,6 +149,8 @@ def train_critic(self, gradient_steps=1, batch_size=100, replay_data=None, tau=0 target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) def train_actor(self, gradient_steps=1, batch_size=100, tau_actor=0.005, tau_critic=0.005, replay_data=None): + # Update optimizer learning rate + self._update_learning_rate(self.actor.optimizer) for gradient_step in range(gradient_steps): # Sample replay buffer @@ -208,6 +213,7 @@ def learn(self, total_timesteps, callback=None, log_interval=4, episode_num += n_episodes self.num_timesteps += episode_timesteps timesteps_since_eval += episode_timesteps + self._update_current_progress(self.num_timesteps, total_timesteps) if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts: if self.verbose > 1: From df1e7aa0002b0c380d71d8978b63569a93061c53 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 28 Oct 2019 17:42:39 +0100 Subject: [PATCH 7/7] Add docstring --- torchy_baselines/common/base_class.py | 8 ++++++++ torchy_baselines/ppo/policies.py | 1 - 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 0904932f6..6adc45ce8 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -129,6 +129,14 @@ def _update_current_progress(self, num_timesteps, total_timesteps): self._current_progress = 1.0 - float(num_timesteps) / float(total_timesteps) def _update_learning_rate(self, optimizers): + """ + Update the optimizers learning rate using the current learning rate schedule + and the current progress (from 1 to 0). + + :param optimizers: ([th.optim.Optimizer] or Optimizer) An optimizer + or a list of optimizer. 
+ """ + # Log the current learning rate logger.logkv("learning_rate", self.learning_rate(self._current_progress)) if not isinstance(optimizers, list): diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index c967962de..e9738584d 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -148,7 +148,6 @@ def _build(self, learning_rate): self.value_net: 1 }[module] module.apply(partial(self.init_weights, gain=gain)) - # TODO: support linear decay of the learning rate self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate(1), eps=self.adam_epsilon) def forward(self, obs, deterministic=False):