Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FSDPv1] Only perform cat() during last microbatch backward() within FlattenParamsWrapper #1180

Draft
wants to merge 27 commits into
base: ngoyal_changes_for_pp_fp8
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d1102ce
use torch.no_grad() to avoid calling cat() during FSDP backward excep…
chrisxcai Apr 29, 2024
9a22628
remove logging
chrisxcai Apr 29, 2024
f787532
logging
chrisxcai Apr 30, 2024
3429f33
logging
chrisxcai May 1, 2024
4b5abe2
use new field to accumulate per-parameter grads in fp32 and copy into…
chrisxcai May 2, 2024
c97bfd9
clean up accumulated fp32 grads between data batches
chrisxcai May 2, 2024
d2a88b7
logging
chrisxcai May 2, 2024
901fb86
logging
chrisxcai May 6, 2024
ad40f24
return grad in post_backward_hook()
chrisxcai May 8, 2024
14499fe
correct param_index
chrisxcai May 9, 2024
ad7aa1f
logging
chrisxcai May 9, 2024
b835770
add torch.testing.assert_allclose() to compare baseline and new grads
chrisxcai May 9, 2024
d689f38
logging
chrisxcai May 9, 2024
e8df583
logging
chrisxcai May 13, 2024
5926a79
honor optimize_backward_concat flag
chrisxcai May 15, 2024
5d08aa3
documentation
chrisxcai May 15, 2024
c91cb72
update documentation
chrisxcai May 15, 2024
fd3f3fc
update documentation
chrisxcai May 15, 2024
7678503
use grad instead of grad.data
chrisxcai May 15, 2024
c55a0d1
clean up
chrisxcai May 15, 2024
688b902
Added reshard hook for frozen params in backward
awgu Jan 12, 2024
a3ff5c4
Avoid calling _free_fp16_param_shard() too early with PR 1159
jiecaoyu Feb 21, 2024
9d0e41e
Added requires_grad check for params_with_grad method (#1171)
whbldhwj Mar 25, 2024
e43a22f
Changed to only run reshard hook if all gradients computed (#1166)
awgu Apr 1, 2024
f039a3a
Add cast input argument (#1175)
whbldhwj Apr 5, 2024
5299982
honor optimize_backward_concat flag
chrisxcai May 15, 2024
b5e138f
use grad instead of grad.data
chrisxcai May 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 50 additions & 7 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,11 @@ class FullyShardedDataParallel(nn.Module):
rank 0 and return empty dict non-rank 0, which allow FullyShardedDataParallel to
skip the GPU -> CPU copy on non-rank 0 altogether and prevent OOM.
Default: False
optimize_backward_concat (bool):
If True, only let backward pass propagate to self.params, which will
invoke the _post_backward_hook() and concat() op, when self._require_backward_grad_sync
is True (e.g. last microbatch)
NOTE: this likely will incur more GPU memory usage
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why there will be more GPU memory usage?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @awgu, currently by testing results it shows the GPU memory overhead could be non-trivial (20% of 80G), we will follow up on reducing the memory usage
Screenshot 2024-05-15 at 10 39 19 AM
Screenshot 2024-05-15 at 10 40 18 AM

"""

def __init__(
Expand Down Expand Up @@ -369,6 +374,7 @@ def __init__(
limit_all_gather_events: bool = False,
limit_reduce_scatter_events: bool = False,
should_validate_process_group: bool = True,
optimize_backward_concat: bool = False,
):
try:
import torch._C
Expand Down Expand Up @@ -493,8 +499,12 @@ def __init__(
param_name_groups = [param_names]
del param_names

self.optimize_backward_concat = optimize_backward_concat
if self.optimize_backward_concat:
assert self.fp32_reduce_scatter, f"{optimize_backward_concat=} requires self.fp32_reduce_scatter=True"

self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory
module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory, optimize_backward_concat=self.optimize_backward_concat,
)
del module # free original module in case it helps garbage collection

Expand Down Expand Up @@ -851,6 +861,7 @@ def extra_repr(self) -> str:
f"bucket_cap_mb={self.bucket_cap_mb}, "
f"clear_autocast_cache={self.clear_autocast_cache}"
f"force_input_to_fp32={self.force_input_to_fp32}"
f"optimize_backward_concat={self.optimize_backward_concat}"
)
return repr

Expand Down Expand Up @@ -1099,12 +1110,20 @@ def no_sync(self) -> Generator:
if isinstance(m, FullyShardedDataParallel):
old_flags.append((m, m._require_backward_grad_sync))
m._require_backward_grad_sync = False
if self.optimize_backward_concat:
# Set the flag on the wrapped FlattenParamsWrapper module as well,
# so that FlattenParamsWrapper could accumulate grads at corresponding
# leaf nodes without triggering concat operations when gradient
# synchronization is not needed.
m._fsdp_wrapped_module._require_backward_grad_sync = False
try:
yield
finally:
for m, old_flag in old_flags:
assert m._require_backward_grad_sync is False
m._require_backward_grad_sync = old_flag
if self.optimize_backward_concat:
m._fsdp_wrapped_module._require_backward_grad_sync = old_flag

@contextlib.contextmanager
def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
Expand Down Expand Up @@ -1458,7 +1477,6 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
# Register backward hooks to reshard params and reduce-scatter grads.
# These need to be re-registered every forward pass.
self._register_post_backward_hooks()

outputs = self.module(*args, **kwargs)

if self.reshard_after_forward:
Expand Down Expand Up @@ -1716,10 +1734,17 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
self._use_fp32_param_shard([param])

if self.fp32_reduce_scatter:
if getattr(param, "unsharded_main_grad", None) is None:
param.unsharded_main_grad = param.grad.to(torch.float32)
if self.optimize_backward_concat:
# Flatten and concat the accumulated fp32 grads
# and assign them to param.unsharded_main_grad
param.unsharded_main_grad = torch.cat([grad.flatten() for grad in self._fsdp_wrapped_module.fp32_grads])
# Clean up accumulated grads between data batches
self._fsdp_wrapped_module.fp32_grads = []
else:
param.unsharded_main_grad.add_(param.grad.data)
if getattr(param, "unsharded_main_grad", None) is None:
param.unsharded_main_grad = param.grad.to(torch.float32)
else:
param.unsharded_main_grad.add_(param.grad.data)

param.grad = None

Expand Down Expand Up @@ -1852,7 +1877,16 @@ def _wait_for_post_backward(self) -> None:
# all the params, the post_backward hook will not fire and the
# state will remain in `TrainingState.BACKWARD_PRE`.
if any([p.requires_grad for p in self.params]):
self.assert_state(TrainingState.BACKWARD_POST)
if self.optimize_backward_concat:
# If self.optimize_backward_concat==True, FSDP backward should
# only be triggered (which will invoke concat())
# when self._fsdp_wrapped_module._require_backward_grad_sync = True
if self._fsdp_wrapped_module._require_backward_grad_sync:
self.assert_state(TrainingState.BACKWARD_POST)
else:
self.assert_state(TrainingState.BACKWARD_PRE)
else:
self.assert_state(TrainingState.BACKWARD_POST)
else:
self.assert_state(TrainingState.BACKWARD_PRE)

Expand Down Expand Up @@ -1929,7 +1963,16 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
# all the params, the post_backward hook will not fire and the
# state will remain in `TrainingState.BACKWARD_PRE`.
if any([p.requires_grad for p in m.params]):
m.assert_state(TrainingState.BACKWARD_POST)
if self.optimize_backward_concat:
# If self.optimize_backward_concat==True, FSDP backward should
# only be triggered (which will invoke concat())
# when self._fsdp_wrapped_module._require_backward_grad_sync = True
if self._fsdp_wrapped_module._require_backward_grad_sync:
m.assert_state(TrainingState.BACKWARD_POST)
else:
m.assert_state(TrainingState.BACKWARD_PRE)
else:
m.assert_state(TrainingState.BACKWARD_POST)
else:
m.assert_state(TrainingState.BACKWARD_PRE)
else:
Expand Down
65 changes: 61 additions & 4 deletions fairscale/nn/misc/flatten_params_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# Licensed under the MIT License.

from contextlib import contextmanager
import functools
from itertools import chain
import tempfile
import typing
Expand Down Expand Up @@ -37,7 +38,6 @@
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401


class FlatParameter(nn.Parameter):
"""A parameter that is initialized from a list of parameters and can be
turned into a list of views as needed.
Expand Down Expand Up @@ -79,7 +79,7 @@ def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
self._param_infos: List[Tuple[str, nn.Module, str]] = []
self._shared_param_infos: List[Tuple[str, str, nn.Module, str, nn.Module, str]] = []

def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
def get_param_views(self, require_backward_grad_sync, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
"""Return a generator of views that map to the original parameters."""
# Note, self.data could be sharded, so its numel is <= to the sum.
assert self.data.numel() <= sum(
Expand All @@ -90,7 +90,9 @@ def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Te
raise ValueError(
f"Incorrect numel of supplied data: got {data.numel()} but expected {sum(self._param_numels)}"
)
return (t.view(s) for (t, s) in zip(data.split(self._param_numels), self._param_shapes))

split_outputs = data.split(self._param_numels)
return (t.view(s) for (t, s) in zip(split_outputs, self._param_shapes))

def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
"""Return tuple of (names, shapes, numels) metadata for this flat parameter."""
Expand Down Expand Up @@ -148,6 +150,11 @@ class FlattenParamsWrapper(nn.Module):
flat_param_names (Optional[List[str]]):
originally, give each flat_param a unique name. Note a "flat_param_"
prefix will be added to those names.
optimize_backward_concat (bool):
If True, only let backward pass propagate to the corresponding FSDP.params, which will
invoke the FSDP._post_backward_hook() and concat() op, when _require_backward_grad_sync
is True (e.g. last microbatch)
NOTE: this likely will incur more GPU memory usage
"""

def __init__(
Expand All @@ -157,10 +164,18 @@ def __init__(
flat_param_names: Optional[List[str]] = None,
ssd_offload: bool = False,
ssd_directory: str = "",
optimize_backward_concat: bool = False,
):
super().__init__()
self._fpw_module = module
self.is_flattened = False
self.optimize_backward_concat = optimize_backward_concat
# If optimize_backward_concat == True, used to propagate the
# corresponding FSDP modules's _require_backward_grad_sync flag
self._require_backward_grad_sync = True
# If optimize_backward_concat == True, used to accumulate the
# fp32 gradients for the flattened parameters
self.fp32_grads = []

# Handle param_list being None.
if param_list is None:
Expand Down Expand Up @@ -364,18 +379,60 @@ def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = No
delattr(self, n)
self.flat_params = []

# The post backward hook used to accumulate fp32 gradients
def _grad_accumulation_hook(
self,
grad,
param_index,
):
if self.fp32_grads[param_index] is None:
self.fp32_grads[param_index] = grad.to(torch.float32)
else:
self.fp32_grads[param_index].add_(grad.data)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think grad.data can just be grad (save one aten.detach call)

return grad

def _unflatten_params_as_views(self) -> None:
"""Unlike ``_unflatten_params``, this function unflatten into views and keep
self.flat_param unchanged.
"""
assert self.is_flattened
ps = self.get_param_views()
if self.optimize_backward_concat:
# If self._require_backward_grad_sync == True (e.g. last microbatch),
# we use the original flat_params as autograd leaf nodes and backward
# pass should propagate all the way back to FSDP module and thus invoke
# FSDP post_backward() hook and concat() op
# Otherwise we stop the backward propagation before FSDP module to avoid
# invoking concat() and store the accumulated fp32 grads
if self._require_backward_grad_sync:
ps = self.get_param_views()
else:
with torch.no_grad():
ps = self.get_param_views()
else:
ps = self.get_param_views()

param_views = []
for (_, m, n), p in zip(self._param_infos, ps):
setattr(p, '_fsdp_weight', True)
setattr(m, n, p) # This will set as plain attr
# The param_index of p used to accumulate the correspnding
# gradients in self.fp32_grads
param_index = len(param_views)
if self.optimize_backward_concat:
# Register post backward hook to accumulate the gradients
# in self.fp32_grads
p.register_hook(
functools.partial(
self._grad_accumulation_hook,
param_index=param_index
)
)
param_views.append(p)

if self.optimize_backward_concat and len(self.fp32_grads) == 0:
# Allocate self.fp32_grads at the beginning of each data batch's forward()
self.fp32_grads = [None] * len(param_views)

# Save param views for easy access if anyone still wants to access
# parameters of the module.
setattr(self._fpw_module, "_unflattened_param_views", param_views)
Expand Down
Loading