Test save/load state_dict V2 (#51)
sshleifer authored Feb 4, 2021
1 parent e139857 commit 6f153b0
Showing 2 changed files with 110 additions and 6 deletions.
2 changes: 2 additions & 0 deletions requirements-test.txt
@@ -13,3 +13,5 @@ pytest-cov == 2.10.0
pytest-mpi == 0.4
pytest-timeout == 1.4.2
mpi4py == 3.0.3
remote-pdb >= 2.1.0
parameterized >= 0.8.1
114 changes: 108 additions & 6 deletions tests/nn/data_parallel/test_shard_params_data_parallel.py
@@ -18,6 +18,8 @@
from fairscale.nn.data_parallel import ShardParamsDataParallel
from fairscale.utils.testing import DeviceAndTypeCheckModule, get_cycles_per_ms, objects_are_equal

# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
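For reference (not part of this diff), a minimal remote-pdb sketch for pausing a single spawned worker might look like the following; the helper name, host, and port are illustrative only, and you attach from another shell with: telnet 127.0.0.1 4444

from remote_pdb import RemotePdb

def _debug_this_rank(rank, port=4444):
    # Hypothetical helper: pause this worker and wait for a telnet client.
    # Attach with: telnet 127.0.0.1 <port + rank>
    RemotePdb("127.0.0.1", port + rank).set_trace()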


class DistributedTest(unittest.TestCase):
def setUp(self):
@@ -36,6 +38,7 @@ def setUp(self):
def _train_for_several_steps(model, num_steps, autocast):
model_device = next(model.parameters()).device
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
# If you set this higher, the implementation differs from DDP in the 5th decimal place
for _ in range(num_steps):
optim.zero_grad()
with torch.cuda.amp.autocast(enabled=autocast):
@@ -157,9 +160,7 @@ def test_transformer(self):
keys = ["reshard_after_forward", "mixed_precision", "flatten_parameters"]
for config in itertools.product([True, False], repeat=len(keys)):
config = dict(zip(keys, config))
-            spawn_and_init(
-                functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config), world_size=2,
-            )
+            spawn_and_init(functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config),)

def test_cpu_offload_and_cpu_grads(self):
for move_grads_choice in (True, None):
@@ -233,6 +234,109 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3
raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}" "\n\n{e}")


class TestSaveLoadLocalStateDict(DistributedTest):
def test_load_local_state_dict(self):
test_fn = functools.partial(self._load_local_and_train, {"flatten_parameters": False})
spawn_and_init(test_fn)

def test_local_state_dict_flatten_params_breaks(self):
test_fn_broken = functools.partial(self._load_local_and_train, {"flatten_parameters": True})
with self.assertRaises(Exception):
spawn_and_init(test_fn_broken)
# RuntimeError: Traceback [1]
# [1] https://gist.github.com/sshleifer/612d8eb02dbbf357d6133b2700e02f5e

def test_local_state_dict_odd_vocab_shape_breaks(self):
test_fn = functools.partial(self._load_local_and_train, {"flatten_parameters": False}, d_model=16, d_vocab=37)
with self.assertRaises(Exception):
spawn_and_init(test_fn)

@classmethod
def _load_local_and_train(self, config, rank, group, d_model=32, d_vocab=32):
"""Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
model = ShardParamsDataParallel(
TransformerWithSharedParams(d_model=d_model, d_vocab=d_vocab), group, **config
).cuda()
state_1 = model.local_state_dict()
state_before_training = {k: v.cpu().clone() for k, v in state_1.items()}
model.load_local_state_dict(state_1)
state_1_weight = state_1["embed_tokens.weight"]

# This weight will be sharded since we access module.state_dict directly
state_1_module_weight = model.module.state_dict()["embed_tokens.weight"]
torch.testing.assert_allclose(state_1_weight, state_1_module_weight)
torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight)
self._train_for_several_steps(model, 4, False)

state_2 = model.local_state_dict()
state_after_training = {k: v.cpu().clone() for k, v in state_2.items()}
model.load_local_state_dict(state_2)

assert state_1.keys() == state_2.keys()

# Assert that parameters were updated since before training
unchanged = []
for k in state_1:
if (state_before_training[k] == state_after_training[k]).all():
unchanged.append(k)
if unchanged:
raise AssertionError(f"params {unchanged} not changed after training")

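As a usage note (not part of this commit), the local_state_dict / load_local_state_dict calls exercised above map onto per-rank checkpointing, where each worker saves and restores only its own shard. A minimal sketch; the helper names and the path_prefix argument are hypothetical:

import torch

def save_local_checkpoint(model, rank, path_prefix="ckpt"):
    # Each rank writes only its own (sharded, possibly flattened) view of the weights.
    torch.save(model.local_state_dict(), f"{path_prefix}_rank{rank}.pt")

def load_local_checkpoint(model, rank, path_prefix="ckpt"):
    # Each rank reads back the shard it wrote; keys must match the current wrapping config.
    model.load_local_state_dict(torch.load(f"{path_prefix}_rank{rank}.pt"))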

class TestSaveLoadStateDict(DistributedTest):
def test_calling_state_dict_twice_breaks(self):
test_fn = functools.partial(self._test_calling_state_dict_twice_breaks, {"flatten_parameters": False})
spawn_and_init(test_fn)

@classmethod
def _test_calling_state_dict_twice_breaks(self, config, rank, group):
ddp_model = self.get_wrapped_model(group, cuda_first=False, config=config)
self._train_for_several_steps(ddp_model, 1, False)
ddp_model.state_dict() # Succeeds
try:
ddp_model.state_dict()
assert False, "Second state_dict call succeeded"
except Exception:
pass

def test_state_dict_after_forward(self):
test_fn = functools.partial(self._test_module_state_dict, {"flatten_parameters": False})
spawn_and_init(test_fn)

@classmethod
def _test_module_state_dict(cls, config, rank, group):
ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
try:
ddp_model.state_dict()
assert False, "Calling state_dict before forward succeeded"
except Exception:
pass
cls._train_for_several_steps(ddp_model, 2, False)
state_1 = ddp_model.state_dict()
# You must make a new ShardParamsDataParallel instance to use module.load_state_dict
unwrapped_model = TransformerWithSharedParams()
unwrapped_model.load_state_dict(state_1)
new_ddp_model = ShardParamsDataParallel(unwrapped_model, group, **config).cuda()
cls._train_for_several_steps(new_ddp_model, 2, False)
try:
ddp_model.load_state_dict(new_ddp_model.state_dict())
assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded"
except Exception:
pass

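For contrast with the local variant above, a hedged sketch of the full round trip that _test_module_state_dict exercises: after at least one forward pass, state_dict() on the wrapped model yields keys that a plain TransformerWithSharedParams accepts. Saving on rank 0 only and the helper names are assumptions, not part of the commit:

import torch

def save_full_checkpoint(ddp_model, rank, path="full_ckpt.pt"):
    # state_dict() here returns the full, unsharded weights, so one rank's copy is enough.
    full_state = ddp_model.state_dict()
    if rank == 0:
        torch.save(full_state, path)

def load_unwrapped(path="full_ckpt.pt"):
    # Load into a plain module, mirroring the test; rewrap with ShardParamsDataParallel if needed.
    model = TransformerWithSharedParams()
    model.load_state_dict(torch.load(path))
    return model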

def get_sharded_model():
sharded_model = ShardParamsDataParallel(
nn.Sequential(
nn.Linear(8, 100),
ShardParamsDataParallel(nn.Linear(100, 100)),
ShardParamsDataParallel(nn.Linear(100, 100)),
nn.Linear(100, 8),
)
)
return sharded_model
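For orientation (not part of the diff), a small sketch of how the nested model above could be exercised; it assumes CUDA is available and a process group has already been initialized, neither of which get_sharded_model sets up:

import torch

def run_one_step(sharded_model):
    model = sharded_model.cuda()
    x = torch.rand(4, 8).cuda()  # input dim 8 matches the first nn.Linear(8, 100)
    loss = model(x).sum()
    loss.backward()  # runs backward through both the inner and outer sharded wrappers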


class TestHooks(DistributedTest):
# Feel free to modify these tests as the implementation changes.
# They aspire to make sure that backward hooks are registered and used
@@ -279,11 +383,9 @@ def _test_register_functions_called(self, rank, group, cuda_first=False):


class TransformerWithSharedParams(nn.Module):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *unused_args, d_vocab=32, d_model=16, **unused_kwargs):
super().__init__()
torch.manual_seed(0) # keep everything deterministic
-        d_model = 16
-        d_vocab = 32
self.embed_tokens = nn.Embedding(d_vocab, d_model)
self.transformer = nn.Transformer(
d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=8, dropout=0.1,
