[FSDPv1] Only perform cat() during last microbatch backward() within FlattenParamsWrapper #1180
Draft
chrisxcai wants to merge 27 commits into ngoyal_changes_for_pp_fp8 from chriscai_ngoyal_changes_for_pp_fp8_v1
Changes shown below are from 18 of the 27 commits.

Commits
d1102ce  use torch.no_grad() to avoid calling cat() during FSDP backward excep… (chrisxcai)
9a22628  remove logging (chrisxcai)
f787532  logging (chrisxcai)
3429f33  logging (chrisxcai)
4b5abe2  use new field to accumulate per-parameter grads in fp32 and copy into… (chrisxcai)
c97bfd9  clean up accumulated fp32 grads between data batches (chrisxcai)
d2a88b7  logging (chrisxcai)
901fb86  logging (chrisxcai)
ad40f24  return grad in post_backward_hook() (chrisxcai)
14499fe  correct param_index (chrisxcai)
ad7aa1f  logging (chrisxcai)
b835770  add torch.testing.assert_allclose() to compare baseline and new grads (chrisxcai)
d689f38  logging (chrisxcai)
e8df583  logging (chrisxcai)
5926a79  honor optimize_backward_concat flag (chrisxcai)
5d08aa3  documentation (chrisxcai)
c91cb72  update documentation (chrisxcai)
fd3f3fc  update documentation (chrisxcai)
7678503  use grad instead of grad.data (chrisxcai)
c55a0d1  clean up (chrisxcai)
688b902  Added reshard hook for frozen params in backward (awgu)
a3ff5c4  Avoid calling _free_fp16_param_shard() too early with PR 1159 (jiecaoyu)
9d0e41e  Added requires_grad check for params_with_grad method (#1171) (whbldhwj)
e43a22f  Changed to only run reshard hook if all gradients computed (#1166) (awgu)
f039a3a  Add cast input argument (#1175) (whbldhwj)
5299982  honor optimize_backward_concat flag (chrisxcai)
b5e138f  use grad instead of grad.data (chrisxcai)
@@ -7,6 +7,7 @@
 # Licensed under the MIT License.

 from contextlib import contextmanager
+import functools
 from itertools import chain
 import tempfile
 import typing
@@ -37,7 +38,6 @@
 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401

-
 class FlatParameter(nn.Parameter):
     """A parameter that is initialized from a list of parameters and can be
     turned into a list of views as needed.
@@ -79,7 +79,7 @@ def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
         self._param_infos: List[Tuple[str, nn.Module, str]] = []
         self._shared_param_infos: List[Tuple[str, str, nn.Module, str, nn.Module, str]] = []

-    def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
+    def get_param_views(self, require_backward_grad_sync, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
         """Return a generator of views that map to the original parameters."""
         # Note, self.data could be sharded, so its numel is <= to the sum.
         assert self.data.numel() <= sum(
@@ -90,7 +90,9 @@ def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
             raise ValueError(
                 f"Incorrect numel of supplied data: got {data.numel()} but expected {sum(self._param_numels)}"
             )
-        return (t.view(s) for (t, s) in zip(data.split(self._param_numels), self._param_shapes))
+
+        split_outputs = data.split(self._param_numels)
+        return (t.view(s) for (t, s) in zip(split_outputs, self._param_shapes))

     def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
         """Return tuple of (names, shapes, numels) metadata for this flat parameter."""
@@ -148,6 +150,11 @@ class FlattenParamsWrapper(nn.Module):
         flat_param_names (Optional[List[str]]):
             originally, give each flat_param a unique name. Note a "flat_param_"
             prefix will be added to those names.
+        optimize_backward_concat (bool):
+            If True, only let backward pass propagate to the corresponding FSDP.params, which will
+            invoke the FSDP._post_backward_hook() and concat() op, when _require_backward_grad_sync
+            is True (e.g. last microbatch)
+            NOTE: this likely will incur more GPU memory usage
     """

     def __init__(
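The new docstring ties the flag to FSDP's gradient-accumulation pattern: cat() and the post-backward reduction are meant to run only on the microbatch where _require_backward_grad_sync is True. A rough sketch of that loop, assuming the flag is threaded from the FSDP constructor down to FlattenParamsWrapper (not shown in this file's diff) and using a single-rank gloo group on CPU purely so the snippet can run outside a multi-GPU job:

```python
import os
import torch
import torch.distributed as dist
from fairscale.nn import FullyShardedDataParallel as FSDP

# Single-process setup for illustration only; real use is a multi-GPU NCCL job.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = FSDP(torch.nn.Linear(8, 8))  # optimize_backward_concat plumbing assumed, see lead-in
optim = torch.optim.SGD(model.parameters(), lr=0.1)
microbatches = [torch.randn(4, 8) for _ in range(4)]

optim.zero_grad()
for i, mb in enumerate(microbatches):
    if i < len(microbatches) - 1:
        # no_sync() clears _require_backward_grad_sync: backward stops at the
        # per-parameter views and grads accumulate in fp32, with no cat().
        with model.no_sync():
            model(mb).sum().backward()
    else:
        # Last microbatch: backward reaches the flat_param, FSDP's post-backward
        # hook fires, and the single cat()/gradient reduction happens here.
        model(mb).sum().backward()
optim.step()
dist.destroy_process_group()
```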
@@ -157,10 +164,18 @@ def __init__(
         flat_param_names: Optional[List[str]] = None,
         ssd_offload: bool = False,
         ssd_directory: str = "",
+        optimize_backward_concat: bool = False,
     ):
         super().__init__()
         self._fpw_module = module
         self.is_flattened = False
+        self.optimize_backward_concat = optimize_backward_concat
+        # If optimize_backward_concat == True, used to propagate the
+        # corresponding FSDP module's _require_backward_grad_sync flag
+        self._require_backward_grad_sync = True
+        # If optimize_backward_concat == True, used to accumulate the
+        # fp32 gradients for the flattened parameters
+        self.fp32_grads = []

         # Handle param_list being None.
         if param_list is None:
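For orientation, a minimal construction showing the new keyword argument and the state it initializes; the import path and direct use of FlattenParamsWrapper here are assumptions for illustration (in training, FSDP builds the wrapper internally):

```python
import torch.nn as nn
from fairscale.nn.misc import FlattenParamsWrapper  # import path assumed

module = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
fpw = FlattenParamsWrapper(module, optimize_backward_concat=True)

print(fpw.optimize_backward_concat)     # True
print(fpw._require_backward_grad_sync)  # True until the owning FSDP instance clears it
print(fpw.fp32_grads)                   # [] -- allocated lazily on the first forward
```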
@@ -364,18 +379,60 @@ def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = None):
             delattr(self, n)
         self.flat_params = []

+    # The post backward hook used to accumulate fp32 gradients
+    def _grad_accumulation_hook(
+        self,
+        grad,
+        param_index,
+    ):
+        if self.fp32_grads[param_index] is None:
+            self.fp32_grads[param_index] = grad.to(torch.float32)
+        else:
+            self.fp32_grads[param_index].add_(grad.data)
+        return grad
+
     def _unflatten_params_as_views(self) -> None:
         """Unlike ``_unflatten_params``, this function unflatten into views and keep
         self.flat_param unchanged.
         """
         assert self.is_flattened
-        ps = self.get_param_views()
+        if self.optimize_backward_concat:
+            # If self._require_backward_grad_sync == True (e.g. last microbatch),
+            # we use the original flat_params as autograd leaf nodes and backward
+            # pass should propagate all the way back to FSDP module and thus invoke
+            # FSDP post_backward() hook and concat() op
+            # Otherwise we stop the backward propagation before FSDP module to avoid
+            # invoking concat() and store the accumulated fp32 grads
+            if self._require_backward_grad_sync:
+                ps = self.get_param_views()
+            else:
+                with torch.no_grad():
+                    ps = self.get_param_views()
+        else:
+            ps = self.get_param_views()

         param_views = []
         for (_, m, n), p in zip(self._param_infos, ps):
             setattr(p, '_fsdp_weight', True)
             setattr(m, n, p)  # This will set as plain attr
+            # The param_index of p used to accumulate the corresponding
+            # gradients in self.fp32_grads
+            param_index = len(param_views)
+            if self.optimize_backward_concat:
+                # Register post backward hook to accumulate the gradients
+                # in self.fp32_grads
+                p.register_hook(
+                    functools.partial(
+                        self._grad_accumulation_hook,
+                        param_index=param_index
+                    )
+                )
             param_views.append(p)

+        if self.optimize_backward_concat and len(self.fp32_grads) == 0:
+            # Allocate self.fp32_grads at the beginning of each data batch's forward()
+            self.fp32_grads = [None] * len(param_views)
+
         # Save param views for easy access if anyone still wants to access
         # parameters of the module.
         setattr(self._fpw_module, "_unflattened_param_views", param_views)
awgu: Could you explain why there will be more GPU memory usage?
Hi @awgu, current test results show that the GPU memory overhead can be non-trivial (around 20% of 80 GB); we will follow up on reducing the memory usage.
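For a rough sense of where that overhead comes from (a back-of-the-envelope sketch under assumed numbers, not a measurement): self.fp32_grads keeps a full fp32 copy of every locally held gradient on top of the usual low-precision gradients, i.e. roughly 4 extra bytes per local parameter element, before any follow-up optimization:

```python
# Hypothetical example: 5B parameter elements' gradients held on one 80 GB GPU.
params_on_this_rank = 5e9
extra_bytes = params_on_this_rank * 4          # fp32 accumulation buffers
print(f"{extra_bytes / 2**30:.1f} GiB extra")  # ~18.6 GiB, the same order as "20% of 80G"
```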