Reset PPO policy.log_std when loading previously saved model #155

Closed
kyrwilliams opened this issue Sep 3, 2020 · 7 comments
Labels
question (Further information is requested)

Comments

@kyrwilliams

When performing curriculum learning, being able to reset the ppo policy.log_std between training cycles would be nice. The following code will produce an error:

# Define RL model: randomized init
policy_kwargs = dict(log_std_init=0)
model = PPO(MlpPolicy, env=env1, policy_kwargs=policy_kwargs) 

# Learn and save
model.learn(total_timesteps=500000, tb_log_name='ppo')
model.save("ppo_model")

# Define RL model: preload network parameters
policy_kwargs = dict(log_std_init=-0.5)    
model = PPO.load(load_path="log_dir\ppo_1\ppo_model", env=env2, policy_kwargs=policy_kwargs)

# Learn and save again
model.learn(total_timesteps=500000, tb_log_name='ppo')
model.save("ppo_model")

Describe the bug
ValueError: The specified policy kwargs do not equal the stored policy kwargs.Stored kwargs

This error is thrown because log_std_init differs between the two training cycles.
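
For illustration, here is a minimal sketch (not SB3's actual implementation) of the kind of equality check that produces this error when policy_kwargs are passed to load:

# Hypothetical illustration of the check behind the error, not SB3's real code
stored_policy_kwargs = {"log_std_init": 0}       # saved alongside the model
provided_policy_kwargs = {"log_std_init": -0.5}  # passed to PPO.load

if provided_policy_kwargs and provided_policy_kwargs != stored_policy_kwargs:
    raise ValueError(
        "The specified policy kwargs do not equal the stored policy kwargs."
        f"Stored kwargs: {stored_policy_kwargs}, "
        f"specified kwargs: {provided_policy_kwargs}"
    )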

@Miffyli added the enhancement (New feature or request) label on Sep 3, 2020
@Miffyli
Collaborator

Miffyli commented Sep 3, 2020

@araffin
Probably just me being dumdum again, but why exactly does the code enforce that the provided policy_kwargs must match the saved ones when they are provided? I do not understand the argument's purpose if it can only be None or whatever is stored in the file. Allowing it to change parameters would allow this kind of tinkering, but it should still come with a big warning: "Using different policy_kwargs than stored! This may crash the code or result in undefined behaviour. Tread carefully!"

@kyrwilliams
#138 is working on simplifying getting/setting network parameters, but currently bit on hold as I am busy with deadlines.
Meanwhile you could try storing the network parameters with model.policy.state_dict() and then loading them as with any PyTorch model.
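
A minimal sketch of that workaround, assuming the standard deviation parameter is stored under the key "log_std" in the policy state dict (check model.policy.state_dict().keys()); dropping that key before loading keeps the fresh log_std_init value (note that this does not carry over the optimizer state):

import torch as th

# Save only the network weights from the trained model
th.save(model.policy.state_dict(), "ppo_policy_params.pth")

# Rebuild the model with the new log_std_init, then load every saved
# parameter except log_std so the fresh initialization is kept
new_model = PPO(MlpPolicy, env=env2, policy_kwargs=dict(log_std_init=-0.5))
saved_params = th.load("ppo_policy_params.pth")
saved_params.pop("log_std", None)
new_model.policy.load_state_dict(saved_params, strict=False)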

@kyrwilliams
Author

kyrwilliams commented Sep 8, 2020

Thanks @Miffyli! So, a couple things:

(1) I attempted the PyTorch save/load methods, manually resetting the log_std values with:

model.policy.log_std = th.nn.Parameter(th.tensor([-0.5, -0.5, -0.5], device='cuda:0', requires_grad=True))
model.learn(total_timesteps=500000, tb_log_name='ppo')

But unfortunately this just locked model.policy.log_std at -0.5 throughout the entire training.

(2) I found the following crude method DID work, since it uses the PPO class's .load method, which apparently updates the model in a specific way:

model = PPO.load(load_path="log_dir\ppo_1\ppo_model", env=env2) # load saved model
model.policy.log_std=th.nn.Parameter(th.tensor([-0.5, -0.5, -0.5], device='cuda:0', requires_grad=True)) # reset log_std
model.save("ppo_model_temp") # save this adjusted model
model = PPO.load(load_path="ppo_model_temp", env=env2) # load the adjusted model
model.learn(total_timesteps=500000, tb_log_name='ppo') # learn

This approach successfully reset the log_std to -0.5 and allowed the optimizer to adjust it during training.

@araffin
Member

araffin commented Sep 10, 2020

Hello,

Probably just me being dumdum again, but why exactly does the code enforce that the provided policy_kwargs must match the saved ones when they are provided?

We do that because you really need to know what is happening when you change those arguments between saving and loading.
This prevents most users from running into unexpected behavior.
Users that want to change those anyway can do so after loading, as @kyrwilliams mentioned, but it requires a good understanding of each RL algorithm.
You can find an example with SAC (when using gSDE) here: https://github.com/DLR-RM/rl-baselines3-zoo/blob/e12a3019b57e11c876b6f875c5ff8c79a168c187/train.py#L569
Also, changing log_std in the policy kwargs won't work, as the value will be overwritten when loading the saved state dict.

(1) I attempted the PyTorch save/load methods, manually resetting the log_std values with:

You may need to register that parameter too and also check whether it is present in the optimizer (which I assume is not the case, given the result).
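
For example, a rough sketch of what that could look like (hypothetical; it assumes a Box action space and that the policy exposes its optimizer as model.policy.optimizer):

import torch as th

# Replace log_std with a fresh parameter; assigning an nn.Parameter to a
# module attribute re-registers it under the same name
new_log_std = th.nn.Parameter(
    th.full((env2.action_space.shape[0],), -0.5, device=model.device)
)
model.policy.log_std = new_log_std

# The existing optimizer still only tracks the old tensor, so hand the
# new parameter to it explicitly before continuing training
model.policy.optimizer.add_param_group({"params": [new_log_std]})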

@araffin added the question (Further information is requested) label and removed the enhancement (New feature or request) label on Sep 10, 2020
@Miffyli
Collaborator

Miffyli commented Sep 12, 2020

We do that because you really need to know what is happening when you change those arguments between saving and loading.
This prevents most users from running into unexpected behavior.

Hmm, ok. I think we could remove the parameter altogether in that case. I do not see why you would want to provide the same parameters again when they are already stored, and the only other option is to provide None. Another option would be to raise warnings when the parameters change, but then again, doing the modifications the way shown here is not that difficult either.

@araffin
Member

araffin commented Sep 15, 2020

I think we could remove the parameter altogether in that case.

Not sure I understand what you mean...
How would you do it then when you have a custom policy architecture?

@Miffyli
Collaborator

Miffyli commented Sep 15, 2020

Wouldn't that information (the custom policy pickled) be stored in the saved model as well? Or does it skip saving policy_kwargs to the file when using a custom policy?

@araffin
Member

araffin commented Sep 20, 2020

Wouldn't that information (the custom policy pickled) be stored in the saved model as well?

Yes, it is. However, if you want to continue training (with the zoo, for instance) and you have tried multiple configurations, checking the kwargs allows you to know whether the saved model has the network architecture that you expect.
