From 3e0ff29865ef79df2cb21a304844775372037eca Mon Sep 17 00:00:00 2001
From: belerico_t
Date: Wed, 29 Nov 2023 15:43:20 +0100
Subject: [PATCH 1/2] torch>=2.0

---
 pyproject.toml                     | 2 +-
 sheeprl/algos/ppo/ppo.py           | 7 ++++---
 sheeprl/algos/ppo/ppo_decoupled.py | 7 ++++---
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1ddeb825..b0d92da0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@ dependencies = [
     "torchmetrics",
     "rich==13.5.*",
     "opencv-python==4.8.0.*",
-    "torch==2.0.*"
+    "torch>=2.0"
 ]
 dynamic = ["version"]
 
diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py
index 67ea2d8f..26f2ec14 100644
--- a/sheeprl/algos/ppo/ppo.py
+++ b/sheeprl/algos/ppo/ppo.py
@@ -305,7 +305,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                         rewards[truncated_envs] += vals.reshape(rewards[truncated_envs].shape)
                 dones = np.logical_or(dones, truncated)
                 dones = torch.as_tensor(dones, dtype=torch.float32, device=device).view(cfg.env.num_envs, -1)
-                rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device).view(cfg.env.num_envs, -1)
+                rewards = torch.as_tensor(rewards, dtype=torch.float64, device=device).view(cfg.env.num_envs, -1)
 
             # Update the step data
             step_data["dones"] = dones
@@ -314,8 +314,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             step_data["logprobs"] = logprobs
             step_data["rewards"] = rewards
             if cfg.buffer.memmap:
-                step_data["returns"] = torch.zeros_like(rewards)
-                step_data["advantages"] = torch.zeros_like(rewards)
+                step_data["returns"] = torch.zeros_like(rewards, dtype=torch.float32)
+                step_data["advantages"] = torch.zeros_like(rewards, dtype=torch.float32)
 
             # Append data to buffer
             rb.add(step_data.unsqueeze(0))
@@ -359,6 +359,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         # Add returns and advantages to the buffer
         rb["returns"] = returns.float()
         rb["advantages"] = advantages.float()
+        rb["rewards"] = rb["rewards"].float()
 
         # Flatten the batch
         local_data = rb.buffer.view(-1)
diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py
index c0331515..d67d0092 100644
--- a/sheeprl/algos/ppo/ppo_decoupled.py
+++ b/sheeprl/algos/ppo/ppo_decoupled.py
@@ -227,7 +227,7 @@ def player(
                         rewards[truncated_envs] += vals.reshape(rewards[truncated_envs].shape)
                 dones = np.logical_or(dones, truncated)
                 dones = torch.as_tensor(dones, dtype=torch.float32, device=device).view(cfg.env.num_envs, -1)
-                rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device).view(cfg.env.num_envs, -1)
+                rewards = torch.as_tensor(rewards, dtype=torch.float64, device=device).view(cfg.env.num_envs, -1)
 
             # Update the step data
             step_data["dones"] = dones
@@ -236,8 +236,8 @@ def player(
             step_data["logprobs"] = logprobs
             step_data["rewards"] = rewards
             if cfg.buffer.memmap:
-                step_data["returns"] = torch.zeros_like(rewards)
-                step_data["advantages"] = torch.zeros_like(rewards)
+                step_data["returns"] = torch.zeros_like(rewards, dtype=torch.float32)
+                step_data["advantages"] = torch.zeros_like(rewards, dtype=torch.float32)
 
             # Append data to buffer
             rb.add(step_data.unsqueeze(0))
@@ -279,6 +279,7 @@ def player(
         # Add returns and advantages to the buffer
         rb["returns"] = returns.float()
         rb["advantages"] = advantages.float()
+        rb["rewards"] = rb["rewards"].float()
 
         # Flatten the batch
         local_data = rb.buffer.view(-1)

From 16711f3641b33408d270f45529765dc63ba21b8e Mon Sep 17 00:00:00 2001
From: belerico_t
Date: Wed, 29 Nov 2023 17:12:27 +0100
Subject: [PATCH 2/2] Fix memmapped rewards

---
 sheeprl/algos/ppo/ppo.py                     | 4 ++--
 sheeprl/algos/ppo/ppo_decoupled.py           | 4 ++--
 sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py
index 26f2ec14..c22f6d34 100644
--- a/sheeprl/algos/ppo/ppo.py
+++ b/sheeprl/algos/ppo/ppo.py
@@ -305,7 +305,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                         rewards[truncated_envs] += vals.reshape(rewards[truncated_envs].shape)
                 dones = np.logical_or(dones, truncated)
                 dones = torch.as_tensor(dones, dtype=torch.float32, device=device).view(cfg.env.num_envs, -1)
-                rewards = torch.as_tensor(rewards, dtype=torch.float64, device=device).view(cfg.env.num_envs, -1)
+                rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device).view(cfg.env.num_envs, -1)
 
             # Update the step data
             step_data["dones"] = dones
@@ -347,7 +347,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys)
             next_values = agent.module.get_value(normalized_obs)
             returns, advantages = gae(
-                rb["rewards"],
+                rb["rewards"].to(torch.float64),
                 rb["values"],
                 rb["dones"],
                 next_values,
diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py
index d67d0092..4742bdb0 100644
--- a/sheeprl/algos/ppo/ppo_decoupled.py
+++ b/sheeprl/algos/ppo/ppo_decoupled.py
@@ -227,7 +227,7 @@ def player(
                         rewards[truncated_envs] += vals.reshape(rewards[truncated_envs].shape)
                 dones = np.logical_or(dones, truncated)
                 dones = torch.as_tensor(dones, dtype=torch.float32, device=device).view(cfg.env.num_envs, -1)
-                rewards = torch.as_tensor(rewards, dtype=torch.float64, device=device).view(cfg.env.num_envs, -1)
+                rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device).view(cfg.env.num_envs, -1)
 
             # Update the step data
             step_data["dones"] = dones
@@ -267,7 +267,7 @@ def player(
             normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys)
             next_values = agent.get_value(normalized_obs)
             returns, advantages = gae(
-                rb["rewards"],
+                rb["rewards"].to(torch.float64),
                 rb["values"],
                 rb["dones"],
                 next_values,
diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
index 6adede10..00e116c4 100644
--- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
+++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -376,7 +376,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             rnn_out, _ = agent.module.rnn(torch.cat((feat, actions), dim=-1), states)
             next_values = agent.module.get_values(rnn_out)
             returns, advantages = gae(
-                rb["rewards"],
+                rb["rewards"].to(torch.float64),
                 rb["values"],
                 rb["dones"],
                 next_values,
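
Editor's note (not part of the patch): taken together, the two commits keep rewards stored as float32 in the (possibly memmapped) rollout buffer and upcast them to float64 only while computing returns and advantages via GAE, casting the results back to float32 before writing them into the buffer. The sketch below illustrates that dtype pattern in isolation; the gae helper, tensor shapes, and hyperparameter values are simplified assumptions for illustration, not sheeprl's actual implementation.

# Editor's sketch: float32 storage, float64 GAE math. The gae() below is a
# simplified stand-in for sheeprl's utility; shapes assumed (num_steps, num_envs, 1).
import torch


def gae(rewards, values, dones, next_values, num_steps, gamma=0.99, gae_lambda=0.95):
    advantages = torch.zeros_like(rewards)  # float64 when rewards are float64
    lastgaelam = torch.zeros_like(next_values)
    not_dones = 1.0 - dones
    for t in reversed(range(num_steps)):
        next_vals = next_values if t == num_steps - 1 else values[t + 1]
        delta = rewards[t] + gamma * next_vals * not_dones[t] - values[t]
        lastgaelam = delta + gamma * gae_lambda * not_dones[t] * lastgaelam
        advantages[t] = lastgaelam
    return advantages + values, advantages


num_steps, num_envs = 128, 4
# Rollout buffer kept entirely in float32 (what the memmap would hold on disk).
rb = {
    "rewards": torch.randn(num_steps, num_envs, 1, dtype=torch.float32),
    "values": torch.randn(num_steps, num_envs, 1, dtype=torch.float32),
    "dones": torch.zeros(num_steps, num_envs, 1, dtype=torch.float32),
}
next_values = torch.randn(num_envs, 1)

# Upcast rewards to float64 only for the return/advantage computation ...
returns, advantages = gae(
    rb["rewards"].to(torch.float64), rb["values"], rb["dones"], next_values, num_steps
)
# ... then cast the results back to float32 before storing them, mirroring what the
# patch does with rb["returns"], rb["advantages"], and rb["rewards"].
rb["returns"] = returns.float()
rb["advantages"] = advantages.float()
assert rb["returns"].dtype == rb["rewards"].dtype == torch.float32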