Get/set parameters and review of saving and loading #138

Merged: 29 commits merged into master on Sep 24, 2020

Conversation

@Miffyli (Collaborator) commented Aug 10, 2020

Closes #116, closes #70

Review of saving and loading of models, as well as (possibly) implementing get_parameters and set_parameters akin to stable-baselines2.

Changelog

  • Rename BaseClass get_torch_variables -> _get_torch_save_params, and include the docstring only in the original implementation.
  • Rename BaseClass excluded_save_params -> _excluded_save_params, and include the docstring only in the original implementation.
  • Reorganize functions for clarity (save/load functions are closer to each other, private functions come first in classes).
  • Clarify documentation on the different objects stored (data, params and tensors).
  • Rename the saved item tensors to pytorch_variables for clarity.
  • Simplify save_to_zip_file by combining duplicate code.
  • Add get/set_parameters, which use _get_torch_save_params to gather/set the parameters of the different objects (a usage sketch follows this list).
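A minimal usage sketch of the new API (the PPO/CartPole setup is only illustrative, and the keyword arguments are assumed to follow the descriptions in this PR):

    from stable_baselines3 import PPO

    # Illustrative setup only.
    model = PPO("MlpPolicy", "CartPole-v1")

    # Gather the state dicts of every object listed by _get_torch_save_params
    # (e.g. "policy", "policy.optimizer") as a dictionary keyed by object name.
    params = model.get_parameters()

    # ... modify, average or store the parameters here ...

    # Load them back; exact_match=True requires every expected key to be present.
    model.set_parameters(params, exact_match=True)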

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Suggestions and TODOs

  • Document the use of recursive_set/getattr in e.g. get_torch_variables (see the sketch after this list).
  • Clarify documentation of "excluded_save_params". Main purpose is to avoid storing pytorch stuff with standard pickle.
  • Merge (if possible) "tensors" with "data" or "params", or at least document well the differences between the three. Merging would simplify the code all-around.
  • Consider renaming "tensors" to e.g. "torch_variables" or "extra_data". Now it sounds like these could be the PyTorch parameters generally.
  • Remove "cf base class" docstrings and replace with the full docstring (or remove completely).
  • Rename "get_torch_variables" function to something more private. It sounds like it returns the agent parameters, but in reality it returns library-specific stuff.
  • Add get/set_parameters.
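For the first TODO above, a rough illustration of what the dotted-path helpers do (a sketch of the idea, not the library's actual implementation):

    import functools

    def recursive_getattr(obj, name):
        # Resolve a dotted attribute path such as "policy.optimizer".
        return functools.reduce(getattr, name.split("."), obj)

    def recursive_setattr(obj, name, value):
        # Set the last attribute of a dotted path on its parent object.
        parent, _, attr = name.rpartition(".")
        setattr(recursive_getattr(obj, parent) if parent else obj, attr, value)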

Things to think about

  • Since PyTorch stores variables with pickle, and since v1.6 stores them as a zip file of pickles, the benefit of using zip files here disappears. It even gets a bit silly, since we end up with a zip file of zip files of pickles for the parameters. Maybe consider storing the parameters with np.savez as with TF?
    • Update: Leave this for later.

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)

@araffin (Member) commented Aug 11, 2020

Document the use of recursive_set/getattr in e.g. get_torch_variables

Clarify documentation of "excluded_save_params". Main purpose is to avoid storing pytorch stuff with standard pickle.

The main purpose is to avoid saving things that cannot or should not be pickled (e.g. the environment, the replay buffer, pytorch variables, ...).

Merge (if possible) "tensors" with "data" or "params", or at least document well the differences between the three. Merging would simplify the code all-around.

It would be nice to reduce the number of variables, but I am not sure it is possible. For info, it was mainly @Artemis-Skade who worked on the saving/loading part. We also had some trouble with saving on GPU and loading on CPU...

Consider renaming "tensors" to e.g. "torch_variables" or "extra_data". Now it sounds like these could be the PyTorch parameters generally.

Yes, anything that can be saved with th.save but is not in the policy. I think for now this is only the entropy temperature in SAC.

Remove "cf base class" docstrings and replace with the full docstring (or remove completely).

Yep, not sure about that, you choose ;)

Rename "get_torch_variables" function to something more private. It sounds like it returns the agent parameters, but in reality it returns library-specific stuff.

agree ;)

Since PyTorch stores variables with pickle, and since v1.6 stores them as a zip file of pickles, the benefit of using zip files here disappears. It even gets a bit silly, since we end up with a zip file of zip files of pickles for the parameters. Maybe consider storing the parameters with np.savez as with TF?

I did not know that... but that would mean breaking changes for PyTorch users on versions 1.4.x and 1.5.x.
Also, I don't mind having a zip of zips.

@araffin self-requested a review August 22, 2020 10:10
@araffin (Member) commented Aug 24, 2020

For get/set parameters, I think we maybe don't need them as we can access model.policy easily, and therefore have access to model.policy.named_parameters(), model.policy.state_dict() and model.policy.load_state_dict().
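For reference, that direct access looks like this with standard PyTorch calls on the policy (the A2C/CartPole setup is only an example):

    from stable_baselines3 import A2C

    # Example setup only.
    model = A2C("MlpPolicy", "CartPole-v1")

    # Inspect individual parameters.
    for name, param in model.policy.named_parameters():
        print(name, param.shape)

    # Copy the parameters out and load them back.
    state = model.policy.state_dict()
    model.policy.load_state_dict(state)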

Review comment on the diff:

    if tensors is not None:
        for name in tensors:
            recursive_setattr(model, name, tensors[name])
    # py other pytorch variables back in place

araffin (Member): typo?

@Miffyli (Collaborator, Author) commented Aug 29, 2020

For get/set parameters, I think we maybe don't need them as we can access model.policy easily, and therefore have access to model.policy.named_parameters(), model.policy.state_dict() and model.policy.load_state_dict().

The same thoughts occurred to me: state_dict is already very convenient to use. Two things worry me, though: should we allow loading parameters from numpy arrays directly (load_state_dict does not support this, as far as I am aware), and how should we handle agents with multiple networks (e.g. value functions)? I think we should store and load the policy and the associated q/v-functions together, which could be wrapped under get/set_parameters.

@araffin (Member) commented Aug 31, 2020

Should we allow loading parameters from numpy arrays directly

Doing params = th.as_tensor(params) for each parameter does not seem to be a big deal. So I would keep it as is.
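In other words, something along these lines (a sketch; the dictionary below is invented purely for illustration):

    import numpy as np
    import torch as th

    # Convert each NumPy array to a tensor before handing it to load_state_dict.
    numpy_params = {"weight": np.zeros((64, 4), dtype=np.float32)}
    torch_params = {name: th.as_tensor(value) for name, value in numpy_params.items()}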

and how to work with agents with multiple networks (e.g. value functions)? I think we should store and load policy and associated q/v-functions together,

Hmm, model.policy.state_dict() already contains all the networks. Or did you mean something else?

@Miffyli (Collaborator, Author) commented Sep 5, 2020

Doing params = th.as_tensor(params) for each parameter does not seem to be a big deal. So I would keep it as is.

I will be that guy and argue it would be equally easy to include this in get_parameters :). This way the behaviour matches older stable_baselines and users do not have to play with torch to use sb3.

Hmm, model.policy.state_dict() already contains all the networks. Or did you mean something else?

I think the get/set_params should follow the state dicts of objects specified by get_torch_variables, because e.g. SAC has more stuff than just model.policy there:

state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]

@araffin (Member) commented Sep 15, 2020

I will be that guy and argue it would be equally easy to include this in get_parameters :). This way the behaviour matches older stable_baselines and users do not have to play with torch to use sb3.

Fair enough ;)

I think the get/set_params should follow the state dicts of objects specified by get_torch_variables, because e.g. SAC has more stuff than just model.policy there:

Oh, true. But for now, this is only valid for SAC and it corresponds to a very special variable (the entropy temperature).

@Miffyli (Collaborator, Author) commented Sep 22, 2020

Oh, true. But for now, this is only valid for SAC and it corresponds to a very special variable (the entropy temperature).

Hmm, how about the other nn.Modules included in this list? In the SAC example above it has policy, actor and critic, all with potentially different parameters. OnPolicyAlgorithms have policy and its optimizer (granted, including all parameters of the policy would likely include the optimizer too):

    def get_torch_variables(self) -> Tuple[List[str], List[str]]:
        """
        cf base class
        """
        state_dicts = ["policy", "policy.optimizer"]

@Miffyli (Collaborator, Author) commented Sep 23, 2020

Added the (long delayed...) functions. If they look ok, I will check over the docs (something seems to be failing there) and update where necessary.

Review comment on the diff:

    for name in params:
        attr = recursive_getattr(model, name)
        attr.load_state_dict(params[name])
    model.set_parameters(params, exact_match=True, device=device)

araffin (Member): Why is exact_match hardcoded?

Miffyli (Collaborator, Author): Using exact_match=False would mean some parameters were missing from the saved model file, which should not happen unless someone modifies the file. Spelling out the hardcoded argument like this signals that we want to make sure every parameter is updated exactly as it was saved, and that nothing is missing.
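As an illustration of the difference (a sketch, assuming get/set_parameters behave as described in this PR; the PPO/CartPole setup is only illustrative):

    from stable_baselines3 import PPO

    model = PPO("MlpPolicy", "CartPole-v1")
    saved_params = model.get_parameters()

    # Full restore: every object listed by _get_torch_save_params must be present.
    model.set_parameters(saved_params, exact_match=True)

    # Partial update: load only the policy weights, leaving the optimizers untouched.
    model.set_parameters({"policy": saved_params["policy"]}, exact_match=False)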

araffin (Member): Ok, fair enough.
What can also happen is that parameters are renamed between versions (it happened to me after refactoring the continuous critic; the parameter names were not the same).

@araffin (Member) commented Sep 23, 2020

Minor changes required, otherwise LGTM ;)

Maybe one thing to add to the saved model: the SB3 version (could be checked later if needed, but it is good to have it anyway)

@araffin (Member) commented Sep 23, 2020

Is it ready for review now?

@Miffyli marked this pull request as ready for review September 23, 2020 20:44

Review comment on the diff:

    """
    if seed is None:
        return
    set_random_seed(seed, using_cuda=self.device == th.device("cuda"))

araffin (Member): Since #154, I realized that we should check self.device.type. Should we fix it here or in a separate PR?
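For context, a small sketch of why comparing self.device directly can be misleading while self.device.type is robust (plain PyTorch behaviour, independent of this PR):

    import torch as th

    device = th.device("cuda:0")

    # Comparing full device objects is sensitive to the device index.
    print(device == th.device("cuda"))  # False, because the indices differ
    # Comparing the type string works for any index.
    print(device.type == "cuda")        # True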

@Miffyli (Collaborator, Author) Sep 23, 2020: Sounds like something that could be included here, although would you be able to add it if it is a quick thing? I am reading through the PR but only slowly digesting how it works.

Also, any ideas what exactly could be causing the flake8 linting error? My flake8 is not catching it :/ Fixed.

@araffin (Member) left a review: LGTM =) (apart from a minor remark)

@araffin (Member) commented Sep 23, 2020

I think the get/set_params should follow the state dicts of objects specified by get_torch_variables, because e.g. SAC has more stuff than just model.policy there:

So in the end, get/set_parameters take care of everything but the custom pytorch variables (here only the log_ent_coef), which are set/saved separately?

@Miffyli (Collaborator, Author) commented Sep 23, 2020

So in the end, get/set_parameters take care of everything but the custom pytorch variables (here only the log_ent_coef), which are set/saved separately?

Yeap, with the reasoning that these are rarely used, although if we keep using torch_variables they should also be included. This is because plain tensors lack state_dict operations, but support could of course be added with a check for standard tensors. I could do it now, too.
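A rough illustration of the kind of check described (a hypothetical helper, not the library's code):

    import torch as th

    def get_module_or_tensor_state(obj):
        # Modules and optimizers expose state_dict(); bare tensors do not,
        # so fall back to returning the tensor itself (illustrative check only).
        if isinstance(obj, th.Tensor):
            return obj
        return obj.state_dict()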

Also, the flake8 check pointed out an error I still need to fix, so hold the brakes on merging :)

@araffin merged commit 9855486 into master Sep 24, 2020
@araffin deleted the review/save_load_params branch September 24, 2020 12:28
@araffin (Member) commented Sep 24, 2020

@Miffyli a bit late... this introduces a breaking change for previously saved policies, unless an extra check (or file_path == "tensors.pth") is added at line 407 of the save utils.
I think I will do a quick PR for that.
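Presumably the quick fix amounts to accepting both archive entry names when loading, along these lines (a hypothetical helper for illustration only; the actual utility and line numbers may differ):

    def is_pytorch_variables_entry(file_path: str) -> bool:
        # Models saved before the rename stored the extra variables as "tensors.pth";
        # accept both names so older archives keep loading.
        return file_path in ("pytorch_variables.pth", "tensors.pth")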
