
Allow to set a device when loading a model #154

Merged
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
@@ -14,6 +14,7 @@ New Features:
^^^^^^^^^^^^^
- Added ``unwrap_vec_wrapper()`` to ``common.vec_env`` to extract ``VecEnvWrapper`` if needed
- Added ``StopTrainingOnMaxEpisodes`` to callback collection (@xicocaio)
- Added ``device`` keyword argument to ``BaseAlgorithm.load()`` (@liorcohen5)

Bug Fixes:
^^^^^^^^^^
@@ -399,4 +400,4 @@ And all the contributors:
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273
@diditforlulz273 @liorcohen5
9 changes: 6 additions & 3 deletions stable_baselines3/common/base_class.py
@@ -316,16 +316,19 @@ def predict(
return self.policy.predict(observation, state, mask, deterministic)

@classmethod
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> "BaseAlgorithm":
def load(
cls, load_path: str, env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", **kwargs
) -> "BaseAlgorithm":
"""
Load the model from a zip-file

:param load_path: the location of the saved data
:param env: the new environment to run the loaded model on
(can be None if you only need prediction from a trained model) has priority over any saved environment
:param device: (Union[th.device, str]) Device on which the code should run.
:param kwargs: extra arguments to change the model when loading
"""
data, params, tensors = load_from_zip_file(load_path)
data, params, tensors = load_from_zip_file(load_path, device=device)

if "policy_kwargs" in data:
for arg_to_remove in ["device"]:
@@ -352,7 +355,7 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> "BaseAl
model = cls(
policy=data["policy_class"],
env=env,
device="auto",
device=device,
_init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args
)
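The load flow above discards any `device` saved in `policy_kwargs` (the `arg_to_remove` loop) so that the `device` argument passed to `load()` takes effect when the model is reconstructed. A minimal pure-Python sketch of that resolution step (`resolve_load_device` is a hypothetical name for illustration, not an SB3 function):

```python
def resolve_load_device(data, device="auto"):
    """Hypothetical sketch: strip the saved 'device' entry from
    policy_kwargs so the device passed to load() takes effect."""
    policy_kwargs = dict(data.get("policy_kwargs", {}))
    for arg_to_remove in ["device"]:
        policy_kwargs.pop(arg_to_remove, None)
    # the model is then constructed with the requested device
    return policy_kwargs, device
```

With this, a model saved on `"cuda"` can be reconstructed on `"cpu"` simply by passing `device="cpu"` to `load()`.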

4 changes: 3 additions & 1 deletion stable_baselines3/common/save_util.py
@@ -352,6 +352,7 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0)
def load_from_zip_file(
load_path: Union[str, pathlib.Path, io.BufferedIOBase],
load_data: bool = True,
device: Union[th.device, str] = "auto",
Comment from the PR author:
Does the order of the arguments here make sense? I'm not sure whether I should have added the new argument last, for cases where users didn't use explicit keyword arguments.
On the other hand, I think it makes more sense for it to come before 'verbose'...
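The concern in this comment can be made concrete: inserting the new parameter before `verbose` silently changes the meaning of existing positional calls. A sketch, where `load_old` and `load_new` are hypothetical stand-ins for the signature before and after this change:

```python
def load_old(load_path, load_data=True, verbose=0):
    """Stand-in for the previous signature."""
    return {"load_path": load_path, "load_data": load_data, "verbose": verbose}

def load_new(load_path, load_data=True, device="auto", verbose=0):
    """Stand-in for the new signature, with device inserted before verbose."""
    return {"load_path": load_path, "load_data": load_data,
            "device": device, "verbose": verbose}

# A caller that previously passed verbose positionally now sets device instead:
load_old("model.zip", True, 1)  # verbose == 1, as intended
load_new("model.zip", True, 1)  # device == 1 and verbose == 0 (!)
```

Appending the new parameter last (or making it keyword-only) would avoid this silent breakage for positional callers.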

verbose=0,
) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
"""
@@ -360,13 +361,14 @@ def load_from_zip_file(
:param load_path: (str, pathlib.Path, io.BufferedIOBase) Where to load the model from
:param load_data: Whether we should load and return data
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
:param device: (Union[th.device, str]) Device on which the code should run.
:return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
and dict of extra tensors
"""
load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")

# set device to cpu if cuda is not available
device = get_device()
device = get_device(device=device)

# Open the zip archive and load data
try:
35 changes: 24 additions & 11 deletions tests/test_save_load.py
@@ -13,6 +13,7 @@
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv

MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG]
@@ -70,21 +71,33 @@ def test_save_load(tmp_path, model_class):
# Check
model.save(tmp_path / "test_save.zip")
del model
model = model_class.load(str(tmp_path / "test_save.zip"), env=env)

# check if params are still the same after load
new_params = model.policy.state_dict()
# Check if the model loads as expected for every possible choice of device:
for device in ["auto", "cpu", "cuda"]:
@leor-c (Contributor, Author) commented on Sep 2, 2020:
I noticed that the git code comparison looks quite messy, so I'm elaborating on my changes here to ease the review process for you:
The actual change I made here is the added 'for' loop that goes over all possible devices; at each iteration the device parameter is passed to the call to 'load' (line 76). At the end of each iteration I delete the model (line 92) so it can be loaded cleanly in the next iteration.
Everything else is the same as before, i.e., I used the exact same test (inside the new 'for' loop) to ensure proper loading, and tested all possible values of the new argument 'device'.

A maintainer (Member) replied:
It seems that you are actually not testing that the device parameter was successfully used.
Also, you should skip the cuda device if no GPU is available.
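One way to follow this suggestion is to derive the device list from GPU availability instead of hard-coding all three values. A sketch, where `devices_to_test` is a hypothetical helper and `cuda_available` stands in for `th.cuda.is_available()` in the real test:

```python
def devices_to_test(cuda_available):
    """Hypothetical helper: which devices the save/load test should cover.
    In the real test, cuda_available would be th.cuda.is_available()."""
    devices = ["auto", "cpu"]
    if cuda_available:
        devices.append("cuda")
    return devices
```

Alternatively, with pytest parametrization, `pytest.skip(...)` inside the test body achieves the same effect on machines without a GPU.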

@leor-c (Contributor, Author) replied on Sep 2, 2020:
  1. You're right, I will work on improving the test.
  2. What should be the expected behavior when a user passes "device='cuda'" on a machine with no GPU?
    I noticed that the constructor defaults to using the CPU in that case without notifying the user.
    In any case, I think the test should cover all possible inputs while verifying that the outcome matches your expectations. Do you agree?

@leor-c (Contributor, Author) added on Sep 2, 2020:
In my test I've used the utils.get_device() function (which is also used inside the constructor) to determine the expected device. This way, if, for example, the behavior of get_device changes, the test won't break.
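The fallback behavior discussed in this thread can be sketched in pure Python. This is a simplified, hypothetical stand-in for SB3's `utils.get_device()` (which actually returns a `th.device`); the `cuda_available` flag replaces `th.cuda.is_available()`:

```python
def get_device_sketch(device="auto", cuda_available=False):
    """Hypothetical sketch of get_device(): 'auto' means 'cuda',
    and 'cuda' silently falls back to 'cpu' when no GPU is present."""
    if device == "auto":
        device = "cuda"
    if device == "cuda" and not cuda_available:
        return "cpu"
    return device
```

Comparing `model.device.type` against this resolved value (rather than against the raw string the user passed) is what keeps the test valid on both GPU and CPU-only machines.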

model = model_class.load(str(tmp_path / "test_save.zip"), env=env, device=device)

# Check that all params are the same as before save load procedure now
for key in params:
assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."
# check if the model was loaded to the correct device
assert model.device.type == get_device(device).type
[leor-c marked this conversation as resolved.]
assert model.policy.device.type == get_device(device).type

# check if model still selects the same actions
new_selected_actions, _ = model.predict(observations, deterministic=True)
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
# check if params are still the same after load
new_params = model.policy.state_dict()

# check if learn still works
model.learn(total_timesteps=1000, eval_freq=500)
# Check that all params are the same as before save load procedure now
for key in params:
assert new_params[key].device.type == get_device(device).type
assert th.allclose(
params[key].to("cpu"), new_params[key].to("cpu")
), "Model parameters not the same after save and load."

# check if model still selects the same actions
new_selected_actions, _ = model.predict(observations, deterministic=True)
assert np.allclose(selected_actions, new_selected_actions, 1e-4)

# check if learn still works
model.learn(total_timesteps=1000, eval_freq=500)

del model

# clear file from os
os.remove(tmp_path / "test_save.zip")