Misc comments from @anj-s (#43)
myleott authored Feb 1, 2021
1 parent 92c550b commit 5bb212f
Showing 1 changed file with 25 additions and 24 deletions.
fairscale/nn/data_parallel/shard_params_data_parallel.py (49 changes: 25 additions & 24 deletions)
@@ -31,7 +31,7 @@ class ShardParamsDataParallel(nn.Module):
     Usage::
 
-        sharded_module = ShardParamsDistributedWrapper(my_module)
+        sharded_module = ShardParamsDataParallel(my_module)
         optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
         x = sharded_module(x, y=3, z=torch.Tensor([1]))
         loss = x.sum()
@@ -43,11 +43,11 @@ class ShardParamsDataParallel(nn.Module):
     reduce memory usage and to improve training speed by distributing the
     unsharding (all-gather) across the forward pass. For example::
 
-        sharded_model = ShardParamsDistributedWrapper(
+        sharded_model = ShardParamsDataParallel(
            nn.Sequential(
                nn.Linear(5, 100),
-               ShardParamsDistributedWrapper(nn.Linear(100, 100)),
-               ShardParamsDistributedWrapper(nn.Linear(100, 100)),
+               ShardParamsDataParallel(nn.Linear(100, 100)),
+               ShardParamsDataParallel(nn.Linear(100, 100)),
                nn.Linear(100, 5),
            )
        )
@@ -186,7 +186,7 @@ def __getstate__(self) -> Dict[str, str]:
         state = copy.copy(self.__dict__)
         state["orig_sizes"] = [p._orig_size for p in self.params]
         if state["process_group"] is not None:
-            state["process_group"] = "MISSING"  # raise error if used
+            state["process_group"] = "MISSING"  # process_group isn't pickleable
         if "_fp32_to_fp16_stream" in state:
             del state["_fp32_to_fp16_stream"]
         return state
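The updated comment points out that the process group object cannot be pickled, which is why __getstate__ swaps it for a placeholder. A minimal standalone sketch of that pattern (the _Wrapper class below is a hypothetical stand-in for illustration, not the actual module):

import copy
import pickle

class _Wrapper:
    def __init__(self, process_group=None):
        self.process_group = process_group
        self.world_size = 2

    def __getstate__(self):
        # Replace unpicklable attributes with a placeholder before pickling.
        state = copy.copy(self.__dict__)
        if state["process_group"] is not None:
            state["process_group"] = "MISSING"  # process groups aren't pickleable
        return state

w = _Wrapper(process_group=object())
restored = pickle.loads(pickle.dumps(w))
print(restored.process_group)  # "MISSING" -- the real group was never serialized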
@@ -216,7 +216,7 @@ def state_dict(self, *args, **kwargs): # type: ignore
         torch.cuda.synchronize()
         self._rebuild_full_params()
         # We don't free the params after generating the state dict, since
-        # freeing is done in-place (via the Storagee) and would corrupt the
+        # freeing is done in-place (via the Storage) and would corrupt the
         # returned state dict.
         return self.module.state_dict(*args, **kwargs)
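The corrected comment refers to freeing parameters by shrinking their Storage in place. A small sketch of why doing that after state_dict() would corrupt the returned dict (simplified and standalone; not the wrapper's code):

import torch

weight = torch.ones(4)
snapshot = {"weight": weight}               # state_dict-style reference, not a copy

weight.storage().resize_(0)                 # in-place "free" of the shared Storage
print(snapshot["weight"].storage().size())  # 0 -> the snapshot is corrupted too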

@@ -225,7 +225,7 @@ def local_state_dict(self, *args, **kwargs): # type: ignore
         """
         Returns the local (sharded) state of the module. Parameters are sharded,
         so the resulting state_dict can only be loaded after the Module has been
-        wrapped with ShardParamsDistributedWrapper.
+        wrapped with ShardParamsDataParallel.
         """
         if self.flatten_parameters:
             kwargs["unflatten_params"] = False
@@ -313,26 +313,28 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         # initialized with the correct dtype and size.
         self._use_fp32_param_shard()
 
+        if not torch.is_grad_enabled():
+            return outputs
+
         # Register pre-backward hook to run before the wrapped module's backward.
-        if torch.is_grad_enabled():
-            pre_backward_hook_has_run = [False]
+        pre_backward_hook_has_run = [False]
 
-            def _pre_backward_hook(*unused: Any) -> None:
-                if pre_backward_hook_has_run[0]:
-                    return  # only run once
-                pre_backward_hook_has_run[0] = True
+        def _pre_backward_hook(*unused: Any) -> None:
+            if pre_backward_hook_has_run[0]:
+                return  # only run once
+            pre_backward_hook_has_run[0] = True
 
-                if self.reshard_after_forward:
-                    self._rebuild_full_params()
-                else:
-                    self._use_full_params()
+            if self.reshard_after_forward:
+                self._rebuild_full_params()
+            else:
+                self._use_full_params()
 
-            def _register_hook(t: torch.Tensor) -> torch.Tensor:
-                t.register_hook(_pre_backward_hook)
-                return t
+        def _register_hook(t: torch.Tensor) -> torch.Tensor:
+            t.register_hook(_pre_backward_hook)
+            return t
 
-            # Attach hooks to Tensor outputs.
-            outputs = apply_to_tensors(_register_hook, outputs)
+        # Attach hooks to Tensor outputs.
+        outputs = apply_to_tensors(_register_hook, outputs)
 
         return outputs
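The reworked block above attaches _pre_backward_hook to every tensor output so that full parameters are restored just before the wrapped module's backward runs. A standalone sketch of the same idea with a toy module (illustrative only, not the wrapper itself):

import torch

linear = torch.nn.Linear(3, 3)
pre_backward_has_run = [False]

def _pre_backward_hook(*unused):
    if pre_backward_has_run[0]:
        return  # only run once per backward
    pre_backward_has_run[0] = True
    # A sharded wrapper would rebuild or re-register the full params here.
    print("pre-backward: about to run the module's backward")

outputs = linear(torch.randn(2, 3))
if torch.is_grad_enabled():
    outputs.register_hook(_pre_backward_hook)  # fires when backward reaches the outputs

outputs.sum().backward()  # hook runs before the Linear layer's weight grads are computed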

@@ -354,7 +356,7 @@ def _post_backward_hook(self, param: torch.nn.Parameter, *unused: Any) -> None:
         if param.grad is None:
             return
         if param.grad.requires_grad:
-            raise RuntimeError("ShardParamsDistributedWrapper only works with gradients that don't require grad")
+            raise RuntimeError("ShardParamsDataParallel only works with gradients that don't require grad")
 
         # Free full params and switch to FP32 shard after backward.
         self._free_full_params([param])
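This hook runs per parameter once its gradient is ready. The registration code is not part of this hunk; the sketch below shows one common way such a post-backward hook can be attached, via the parameter's AccumulateGrad node (an assumption for illustration, not necessarily how this module wires it up):

import torch

lin = torch.nn.Linear(4, 4)
param = lin.weight

def _post_backward_hook(*unused):
    # param.grad is populated at this point; a sharded wrapper would
    # reduce-scatter it and free the full parameter here.
    print("post-backward: grad ready with shape", tuple(param.grad.shape))

p_tmp = param.expand_as(param)                 # throwaway non-leaf view of the parameter
grad_acc = p_tmp.grad_fn.next_functions[0][0]  # AccumulateGrad node for `param`
handle = grad_acc.register_hook(_post_backward_hook)  # keep grad_acc and handle alive

lin(torch.randn(2, 4)).sum().backward()        # hook fires during the backward pass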
@@ -501,4 +503,3 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
     """Allocate storage for a tensor."""
     assert data.storage().size() == 0
     data.storage().resize_(size.numel())
-    # data.set_(size=size)
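alloc_storage_ re-acquires memory by resizing the tensor's Storage back to the original element count, mirroring a free that resized it to zero. A standalone sketch of that round trip (simplified):

import torch

data = torch.randn(2, 3)
orig_size = data.size()

data.storage().resize_(0)                   # free: the Storage now holds zero elements
assert data.storage().size() == 0

data.storage().resize_(orig_size.numel())   # alloc: room for 6 elements again
assert data.storage().size() == orig_size.numel()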
