Sharded DDP: test cpu_offload arg (#40)
* Test CPU offload

* remove dead code
sshleifer authored Feb 1, 2021
1 parent b35a28a commit 92c550b
Showing 3 changed files with 142 additions and 119 deletions.
60 changes: 33 additions & 27 deletions fairscale/nn/data_parallel/shard_params_data_parallel.py
@@ -250,36 +250,42 @@ def load_local_state_dict(
     def _pre_forward_init(self) -> None:
         did_init = False
         for p in self.params:
-            if not hasattr(p, "_full_param"):
-                did_init = True
-                assert p._is_sharded
-
-                p._fp32_shard = p.data
-
-                if self.mixed_precision:
-                    assert p._fp32_shard.dtype == torch.float32
-                    if self.cpu_offload:
-                        p._fp32_shard = p._fp32_shard.pin_memory()
-                    p._fp16_shard = torch.zeros_like(
-                        p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype,
-                    )
-                    free_storage_(p._fp16_shard)
-                    p._full_param = torch.zeros(p._orig_size, device=self.compute_device, dtype=self.compute_dtype)
-                else:
-                    p._fp16_shard = None  # use _fp32_shard
-                    p._full_param = p._fp32_shard.new_empty(p._orig_size)
+            if hasattr(p, "_full_param"):
+                continue
+            did_init = True
+            assert p._is_sharded
+
+            p._fp32_shard = p.data
+
+            if self.mixed_precision:
+                assert p._fp32_shard.dtype == torch.float32
+
+                if self.cpu_offload:
+                    assert p._fp32_shard.device == torch.device("cpu")
+                    p._fp32_shard = p._fp32_shard.pin_memory()
+
+                p._fp16_shard = torch.zeros_like(
+                    p._fp32_shard,
+                    device=self.compute_device,
+                    dtype=self.compute_dtype,
+                )
+                free_storage_(p._fp16_shard)
+                p._full_param = torch.zeros(p._orig_size, device=self.compute_device, dtype=self.compute_dtype)
+            else:
+                p._fp16_shard = None  # use _fp32_shard
+                p._full_param = p._fp32_shard.new_empty(p._orig_size)
 
-                p.data = p._fp32_shard
-                p._full_param = p._full_param.to(dtype=self.compute_dtype, device=self.compute_device)
-                free_storage_(p._full_param)
+            p._full_param = p._full_param.to(dtype=self.compute_dtype, device=self.compute_device)
+            free_storage_(p._full_param)
 
-                if self.move_grads_to_cpu:
-                    if self.mixed_precision and not self.fp32_reduce_scatter:
-                        grad_dtype = torch.float16
-                    else:
-                        grad_dtype = torch.float32
-                    p._cpu_grad = torch.zeros_like(p.data, dtype=grad_dtype, device="cpu").pin_memory()
+            p.data = p._fp32_shard
+
+            if self.move_grads_to_cpu:
+                if self.mixed_precision and not self.fp32_reduce_scatter:
+                    grad_dtype = torch.float16
+                else:
+                    grad_dtype = torch.float32
+                p._cpu_grad = torch.zeros_like(p.data, dtype=grad_dtype, device="cpu").pin_memory()
 
         if did_init:
             self._fp32_to_fp16_stream = torch.cuda.Stream()
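Note: the reworked _pre_forward_init leans on two details that are easy to miss in the diff. The fp32 shard is pinned in page-locked CPU memory so later host-to-device copies can be non-blocking, and p._full_param is kept as a zero-storage tensor that is re-materialized only when needed. Below is a minimal sketch of that pattern, not code from this commit; the free_storage_ body reflects what fairscale's helper is assumed to do, and the re-materialization step is illustrative:

import torch

def free_storage_(data: torch.Tensor) -> None:
    # Free the underlying storage but keep the tensor's shape metadata, so the
    # same handle can be re-materialized later (assumed semantics of
    # fairscale's helper of the same name).
    if data.storage().size() > 0:
        assert data.storage_offset() == 0
        data.storage().resize_(0)

shard = torch.randn(1024).pin_memory()  # fp32 shard stays on CPU (cpu_offload=True)
full = torch.zeros(1024, device="cuda", dtype=torch.float16)
free_storage_(full)  # placeholder: zero bytes allocated until needed

# Illustrative re-materialization before the forward pass:
full.storage().resize_(full.numel())  # re-allocate the storage
full.copy_(shard, non_blocking=True)  # pinned CPU -> GPU copy can overlap compute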
6 changes: 5 additions & 1 deletion fairscale/utils/testing.py
@@ -294,7 +294,11 @@ def __init__(self, embed_dim: int, num_heads: int) -> None:
         self.ln_1 = nn.LayerNorm(embed_dim)
         self.ln_2 = nn.LayerNorm(embed_dim)
         self.attn = nn.MultiheadAttention(embed_dim, num_heads)  # type: ignore
-        self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim * 4), nn.GELU(), nn.Linear(embed_dim * 4, embed_dim),)
+        self.mlp = nn.Sequential(
+            nn.Linear(embed_dim, embed_dim * 4),
+            nn.GELU(),
+            nn.Linear(embed_dim * 4, embed_dim),
+        )
 
     def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
         x = inputs[0]
195 changes: 104 additions & 91 deletions tests/nn/data_parallel/test_shard_params_data_parallel.py
@@ -15,6 +15,7 @@

 from fairscale.nn.data_parallel import ShardParamsDataParallel
 from fairscale.utils.testing import DeviceAndTypeCheckModule, objects_are_equal
+from typing import Dict
 
 
 class DistributedTest(unittest.TestCase):
@@ -30,99 +31,100 @@ def setUp(self):
         if torch.cuda.device_count() < 2:
             raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")
 
+    @staticmethod
+    def _train_for_several_steps(model, num_steps, autocast):
+        model_device = next(model.parameters()).device
+        optim = torch.optim.Adam(model.parameters(), lr=0.0001)
+        for _ in range(num_steps):
+            optim.zero_grad()
+            with torch.cuda.amp.autocast(enabled=autocast):
+                # Inputs always cuda regardless of move_grads_cpu, or model.device
+                input = model.module.get_input(torch.device("cuda"))
+                output = model(*input)
+                loss = model.module.get_loss(input, output).to(model_device)
+            print(f'loss device: {loss.device}')
+            assert loss.dtype == torch.float32
+            loss.backward()
+            optim.step()
+        return loss.detach()


 class TestMixedPrecision(DistributedTest):
     def test_all_fp32(self):
-        spawn_and_init(
-            functools.partial(
-                self.__class__._test_dtypes,
-                {"mixed_precision": False},
-                False,  # autocast enabled
-                torch.float32,  # expected_input_dtype
-                torch.float32,  # expected_param_dtype
-                torch.float32,  # expected_loss_dtype
-                torch.float32,  # expected_reduce_dtype
-            ),
-            world_size=2,
+        self._spawn_test_case(
+            {"mixed_precision": False},
+            False,  # autocast enabled
+            torch.float32,  # expected_input_dtype
+            torch.float32,  # expected_param_dtype
+            torch.float32,  # expected_loss_dtype
+            torch.float32,  # expected_reduce_dtype
         )

     def test_mixed_precision(self):
-        spawn_and_init(
-            functools.partial(
-                self.__class__._test_dtypes,
-                {"mixed_precision": True},
-                False,  # autocast enabled
-                torch.float16,  # expected_input_dtype
-                torch.float16,  # expected_param_dtype
-                torch.float16,  # expected_loss_dtype
-                torch.float16,  # expected_reduce_dtype
-            ),
-            world_size=2,
+        self._spawn_test_case(
+            {"mixed_precision": True},
+            False,  # autocast enabled
+            torch.float16,  # expected_input_dtype
+            torch.float16,  # expected_param_dtype
+            torch.float16,  # expected_loss_dtype
+            torch.float16,  # expected_reduce_dtype
         )

     def test_mixed_precision_autocast(self):
-        spawn_and_init(
-            functools.partial(
-                self.__class__._test_dtypes,
-                {"mixed_precision": True},
-                True,  # autocast enabled
-                torch.float16,  # expected_input_dtype
-                torch.float16,  # expected_param_dtype
-                torch.float32,  # expected_loss_dtype
-                torch.float16,  # expected_reduce_dtype
-            ),
-            world_size=2,
+        """If autocast enabled, loss should be fp32."""
+        self._spawn_test_case(
+            {"mixed_precision": True},
+            True,  # autocast enabled
+            torch.float16,  # expected_input_dtype
+            torch.float16,  # expected_param_dtype
+            torch.float32,  # expected_loss_dtype
+            torch.float16,  # expected_reduce_dtype
         )

     def test_mixed_precision_autocast_fp32_compute(self):
-        spawn_and_init(
-            functools.partial(
-                self.__class__._test_dtypes,
-                {"mixed_precision": True, "compute_dtype": torch.float32},
-                True,  # autocast enabled
-                torch.float16,  # expected_input_dtype
-                torch.float32,  # expected_param_dtype
-                torch.float32,  # expected_loss_dtype
-                torch.float32,  # expected_reduce_dtype
-            ),
-            world_size=2,
+        self._spawn_test_case(
+            {"mixed_precision": True, "compute_dtype": torch.float32},
+            True,  # autocast enabled
+            torch.float16,  # expected_input_dtype
+            torch.float32,  # expected_param_dtype
+            torch.float32,  # expected_loss_dtype
+            torch.float32,  # expected_reduce_dtype
         )

     def test_fp32_reduce_scatter(self):
-        spawn_and_init(
-            functools.partial(
-                self.__class__._test_dtypes,
-                {"mixed_precision": True, "fp32_reduce_scatter": True},
-                False,  # autocast enabled
-                torch.float16,  # expected_input_dtype
-                torch.float16,  # expected_param_dtype
-                torch.float16,  # expected_loss_dtype
-                torch.float32,  # expected_reduce_dtype
-            ),
-            world_size=2,
+        self._spawn_test_case(
+            {"mixed_precision": True, "fp32_reduce_scatter": True},
+            False,  # autocast enabled
+            torch.float16,  # expected_input_dtype
+            torch.float16,  # expected_param_dtype
+            torch.float16,  # expected_loss_dtype
+            torch.float32,  # expected_reduce_dtype
         )

     def test_fp32_reduce_scatter_autocast(self):
-        spawn_and_init(
-            functools.partial(
-                self.__class__._test_dtypes,
-                {"mixed_precision": True, "fp32_reduce_scatter": True},
-                True,  # autocast enabled
-                torch.float16,  # expected_input_dtype
-                torch.float16,  # expected_param_dtype
-                torch.float32,  # expected_loss_dtype
-                torch.float32,  # expected_reduce_dtype
-            ),
-            world_size=2,
+        self._spawn_test_case(
+            {"mixed_precision": True, "fp32_reduce_scatter": True},
+            True,  # autocast enabled
+            torch.float16,  # expected_input_dtype
+            torch.float16,  # expected_param_dtype
+            torch.float32,  # expected_loss_dtype
+            torch.float32,  # expected_reduce_dtype
         )

+    def _spawn_test_case(self, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype, world_size=2):
+        """Call test_dtypes inside of torch.multiprocessing.spawn"""
+        fn = functools.partial(self._test_dtypes, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype)
+        spawn_and_init(fn, world_size=world_size)

     @staticmethod
-    def _test_dtypes(cfg, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group):
+    def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group):
         # Patch _reduce_scatter op to check the dtype of the reduction
         orig_reduce_scatter = ShardParamsDataParallel._reduce_scatter
 
         model = DeviceAndTypeCheckModule(
-            expected_input_dtype=in_dtype, expected_param_dtype=p_dtype, expected_loss_dtype=loss_dtype,
+            expected_input_dtype=in_dtype,
+            expected_param_dtype=p_dtype,
+            expected_loss_dtype=loss_dtype,
         )
 
         def _reduce_scatter(self, tensor):
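The patched body of _reduce_scatter is collapsed in this view. A hedged sketch of the pattern the test applies at this point; the body and the restore step are assumptions built from the visible names orig_reduce_scatter and reduce_dtype, not the commit's exact code:

def _reduce_scatter(self, tensor):
    # Check the dtype actually handed to the collective, then defer to the
    # original op so training proceeds normally. (hypothetical body)
    assert tensor.dtype == reduce_dtype
    return orig_reduce_scatter(self, tensor)

ShardParamsDataParallel._reduce_scatter = _reduce_scatter
try:
    pass  # wrap the model and run a training step here
finally:
    ShardParamsDataParallel._reduce_scatter = orig_reduce_scatter  # restore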
@@ -150,11 +152,24 @@ def test_transformer(self):
         for config in itertools.product([True, False], repeat=len(keys)):
             config = dict(zip(keys, config))
             spawn_and_init(
-                functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config), world_size=2,
+                functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config),
+                world_size=2,
             )
 
+    def test_cpu_offload_and_cpu_grads(self):
+        config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": True}
+        test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False)
+        spawn_and_init(test_fn)
+
+    def test_cpu_offload_and_cuda_grads(self):
+        # If grads are on gpu, but model and optimizer are on cpu, backward breaks.
+        config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": False}
+        with self.assertRaises(Exception):  # RuntimeError inside spawn
+            test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False)
+            spawn_and_init(test_fn)

     @classmethod
-    def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3):
+    def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3, use_cuda=True):
         if config["mixed_precision"]:
             autocast = True
             # Force the compute dtype to be torch.float32 so that we get
@@ -173,7 +188,11 @@ def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3):
         ref_state_dict = model.module.state_dict()
 
         # Confirm we get the same behavior using ShardParamsDataParallel.
-        model = ShardParamsDataParallel(model_cls(), group, **config).cuda()
+        model = ShardParamsDataParallel(model_cls(), group, **config)
+        if use_cuda:
+            model = model.cuda()
+        else:
+            assert next(model.parameters()).device == torch.device("cpu")
         shard_loss = cls._train_for_several_steps(model, num_steps, autocast)
         shard_state_dict = model.state_dict()

@@ -183,31 +202,22 @@ def _test_identical_outputs(cls, model_cls, config, rank, group, num_steps=3):
         except (AssertionError, RuntimeError) as e:
             raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}" "\n\n{e}")
 
-    @classmethod
-    def _train_for_several_steps(cls, model, num_steps, autocast):
-        optim = torch.optim.Adam(model.parameters(), lr=0.0001)
-        for _ in range(num_steps):
-            optim.zero_grad()
-            with torch.cuda.amp.autocast(enabled=autocast):
-                device = next(model.parameters()).device
-                input = model.module.get_input(device)
-                output = model(*input)
-                loss = model.module.get_loss(input, output)
-            assert loss.dtype == torch.float32
-            loss.backward()
-            optim.step()
-        return loss.detach()
-
 
 class TransformerWithSharedParams(nn.Module):
     def __init__(self):
         super().__init__()
         torch.manual_seed(0)  # keep everything deterministic
-        self.embed_tokens = nn.Embedding(50, 16)
+        d_model = 16
+        d_vocab = 32
+        self.embed_tokens = nn.Embedding(d_vocab, d_model)
         self.transformer = nn.Transformer(
-            d_model=16, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=32, dropout=0.1,
+            d_model=d_model,
+            num_encoder_layers=2,
+            num_decoder_layers=2,
+            dim_feedforward=8,
+            dropout=0.1,
         )
-        self.output_proj = nn.Linear(16, 50)
+        self.output_proj = nn.Linear(d_model, d_vocab)
         # share the embedding and output projection weights
         self.output_proj.weight = self.embed_tokens.weight

@@ -228,7 +238,7 @@ def get_loss(self, input, output):
         return nn.functional.cross_entropy(output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum")
 
 
-def spawn_and_init(fn, world_size, args=None):
+def spawn_and_init(fn, world_size=2, args=None):
     if args is None:
         args = ()
     with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
@@ -242,7 +252,10 @@

 def distributed_init(rank, world_size, tmp_file):
     torch.distributed.init_process_group(
-        backend="nccl", init_method="file://{}".format(tmp_file), world_size=world_size, rank=rank,
+        backend="nccl",
+        init_method="file://{}".format(tmp_file),
+        world_size=world_size,
+        rank=rank,
     )
     torch.cuda.set_device(rank)

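Taken together, the new tests pin down the supported CPU-offload configuration: cpu_offload=True keeps the fp32 shards, and therefore the Adam state, on CPU, so gradients have to travel back with move_grads_to_cpu=True, and the wrapper is never moved with .cuda(). A usage sketch of that configuration (the keyword arguments are from the diff; the model and process-group setup are assumed):

model = ShardParamsDataParallel(
    TransformerWithSharedParams(),
    group,                    # an initialized process group (assumed available)
    mixed_precision=True,     # the tests exercise cpu_offload with mixed precision
    cpu_offload=True,         # fp32 shards (and optimizer state) stay on CPU
    move_grads_to_cpu=True,   # grads return to CPU, where the optimizer lives
)
# The wrapper is NOT moved with .cuda(); inputs are still created on CUDA
# and compute runs on the GPU through the fp16 shards.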