Allow to set a device when loading a model #154
Changes from 6 commits
ae63d1c
196f220
33adda4
1799c6a
d85d88d
000f917
d7b3329
004307a
a4ad856
@@ -13,6 +13,7 @@
 from stable_baselines3.common.base_class import BaseAlgorithm
 from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
 from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
+from stable_baselines3.common.utils import get_device
 from stable_baselines3.common.vec_env import DummyVecEnv
 
 MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG]
@@ -70,21 +71,33 @@ def test_save_load(tmp_path, model_class):
     # Check
     model.save(tmp_path / "test_save.zip")
     del model
-    model = model_class.load(str(tmp_path / "test_save.zip"), env=env)
 
-    # check if params are still the same after load
-    new_params = model.policy.state_dict()
+    # Check if the model loads as expected for every possible choice of device:
+    for device in ["auto", "cpu", "cuda"]:
I noticed that the git code comparison looks quite messy. I'm describing the changes I've made here to ease the review process for you:

It seems that you are not actually testing that the device parameter was successfully used.

In my test I used the utils.get_device() function (which is also used inside the constructor) to determine the device. This way, if the behavior of get_device changes, for example, the test won't break.
+        model = model_class.load(str(tmp_path / "test_save.zip"), env=env, device=device)
 
-    # Check that all params are the same as before save load procedure now
-    for key in params:
-        assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."
+        # check if the model was loaded to the correct device
+        assert model.device.type == get_device(device).type
leor-c marked this conversation as resolved.
+        assert model.policy.device.type == get_device(device).type
 
-    # check if model still selects the same actions
-    new_selected_actions, _ = model.predict(observations, deterministic=True)
-    assert np.allclose(selected_actions, new_selected_actions, 1e-4)
+        # check if params are still the same after load
+        new_params = model.policy.state_dict()
 
-    # check if learn still works
-    model.learn(total_timesteps=1000, eval_freq=500)
+        # Check that all params are the same as before save load procedure now
+        for key in params:
+            assert new_params[key].device.type == get_device(device).type
+            assert th.allclose(
+                params[key].to("cpu"), new_params[key].to("cpu")
+            ), "Model parameters not the same after save and load."
+
+        # check if model still selects the same actions
+        new_selected_actions, _ = model.predict(observations, deterministic=True)
+        assert np.allclose(selected_actions, new_selected_actions, 1e-4)
+
+        # check if learn still works
+        model.learn(total_timesteps=1000, eval_freq=500)
+
+        del model
 
     # clear file from os
     os.remove(tmp_path / "test_save.zip")
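The review thread in this diff asks whether the test really checks the device, and the reply argues for comparing against the resolved device rather than the raw string. A minimal sketch of that reasoning in plain Python follows. `resolve_device` is a hypothetical stand-in, not SB3's `get_device`, and its fallback rule ("auto" prefers CUDA, CUDA requests fall back to CPU when no GPU is present) is an assumption made for illustration.

```python
def resolve_device(requested: str, cuda_available: bool) -> str:
    """Hypothetical stand-in for a device-resolution helper (not SB3's get_device).

    Assumption: "auto" prefers CUDA, and any CUDA request falls back to CPU
    when no GPU is available.
    """
    if requested == "auto":
        requested = "cuda"
    if requested == "cuda" and not cuda_available:
        return "cpu"
    return requested


def device_matches(loaded: str, requested: str, cuda_available: bool) -> bool:
    """Compare the loaded device with the resolved request, not the raw string."""
    return loaded == resolve_device(requested, cuda_available)


REQUESTS = ["auto", "cpu", "cuda"]

# On a CPU-only machine the model ends up on "cpu" for every requested value,
# and the resolved comparison accepts all three cases.
resolved_checks = [device_matches("cpu", r, cuda_available=False) for r in REQUESTS]

# Comparing against the raw string would reject "auto" and "cuda" on that machine.
naive_checks = ["cpu" == r for r in REQUESTS]
```

Under these assumptions the resolved comparison passes on machines with and without a GPU, which is presumably why the test calls the same helper the constructor uses instead of hard-coding expected device names.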
Does the order of the arguments here make sense? I'm not sure whether I should have added the new argument last, for cases where users don't pass arguments by keyword.
On the other hand, I think it makes more sense for it to come before 'verbose'...
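To make the trade-off in this question concrete, here is a small sketch of how the position of a new parameter affects callers who pass arguments positionally. The `make_model_*` functions and their parameters are hypothetical and do not reproduce any SB3 signature. The keyword-only variant is one possible compromise, not something the PR is known to have adopted.

```python
def make_model_old(policy, env, verbose=0):
    """Hypothetical signature before the change."""
    return {"policy": policy, "env": env, "device": "auto", "verbose": verbose}


def make_model_device_first(policy, env, device="auto", verbose=0):
    """New argument placed before 'verbose'."""
    return {"policy": policy, "env": env, "device": device, "verbose": verbose}


def make_model_device_last(policy, env, verbose=0, device="auto"):
    """New argument appended at the end."""
    return {"policy": policy, "env": env, "device": device, "verbose": verbose}


def make_model_kwonly(policy, env, verbose=0, *, device="auto"):
    """New argument made keyword-only, so it can never be bound positionally."""
    return {"policy": policy, "env": env, "device": device, "verbose": verbose}


# An existing caller that passes verbose=1 positionally.
old = make_model_old("MlpPolicy", "env", 1)

# Inserting 'device' before 'verbose' silently rebinds that 1 to 'device'.
first = make_model_device_first("MlpPolicy", "env", 1)

# Appending 'device' keeps the old positional call working.
last = make_model_device_last("MlpPolicy", "env", 1)

# Keyword-only keeps old calls working and forces new calls to be explicit.
kwonly = make_model_kwonly("MlpPolicy", "env", 1, device="cpu")
```

Placing the argument before `verbose` changes the meaning of existing positional calls without raising an error, while appending it or making it keyword-only leaves those calls untouched.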