[FSDP] Can we access parameter views when using flatten_parameters=True? #430
Comments
Hmm, this is a bit tricky. For the first part, this works (the params will get re-flattened after the context manager exits):
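A minimal sketch of that pattern (hedged: the Model class, the FSDP wrapping call, and the import path are illustrative assumptions, and FSDP needs torch.distributed initialized first):

    import torch
    from fairscale.nn import FullyShardedDataParallel as FSDP

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(5, 5)

        def forward(self, x):
            return self.layer(x)

    # Assumes a process group is already initialized, e.g. a single-process gloo group.
    model = FSDP(Model(), flatten_parameters=True)

    # Inside the context manager the original per-module parameters are visible again;
    # they are re-flattened as soon as the context exits.
    with model.unflatten_params():
        layer_norm = torch.norm(torch.stack([p.norm() for p in model.layer.parameters()]))
    print("layer param norm", layer_norm)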
The optimizer use case is harder though. The optimizer won't be happy with the re-flattening that happens when the context manager exits:

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(5, 5)

        def forward(self, x):
            return self.layer(x)

    (...)

    print("param norm (before)", torch.norm(torch.stack([p.norm() for p in model.parameters()])))

    with model.unflatten_params():
        optimizer = torch.optim.SGD(model.layer.parameters(), lr=0.1)

    optimizer.zero_grad()
    loss = model(torch.rand(8, 5)).sum()
    loss.backward()
    optimizer.step()
    print("param norm (after, broken)", torch.norm(torch.stack([p.norm() for p in model.parameters()])))

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    optimizer.zero_grad()
    loss = model(torch.rand(8, 5)).sum()
    loss.backward()
    optimizer.step()
    print("param norm (after, works)", torch.norm(torch.stack([p.norm() for p in model.parameters()])))

Setting …
Thanks @myleott, that was helpful! I think from a Lightning perspective we'll make it clear in the docs to only use the FSDP model when setting up optimizers for now, whilst we figure out a longer-term solution in the FSDP code. Does that sound reasonable? Regarding your comment about using communication hooks in the future (#413 (comment)), this would technically fix the issue, right? Obviously it's quite a bit of work to move to using comms hooks / redo the bucketing!
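For the interim guidance, a rough sketch of how that reads on the Lightning side (hypothetical: it assumes the trainer exposes the FSDP-wrapped model as self.trainer.model by the time optimizers are set up; the key point is only that the optimizer is built from the wrapped model's parameters rather than from an inner submodule):

    import torch
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(5, 5)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            return self(batch).sum()

        def configure_optimizers(self):
            # With flatten_parameters=True, build the optimizer from the wrapped
            # model's parameters, not from self.layer.parameters().
            return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)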
Yes, right now flattening is essential for performance, but once we speed up our bucketing solution the gap should be smaller, making it practical to set flatten_parameters=False. To give you a sense of the speed difference, we benchmarked WMT'16 En-De translation in fairseq on 8xV100s.
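For completeness, a hedged sketch of the slower-but-flexible configuration (reusing the toy Model class from the snippets above and again assuming an initialized process group; with flattening off, each original nn.Parameter object survives on its module, holding its shard, so per-module optimizers can be built directly):

    import torch
    from fairscale.nn import FullyShardedDataParallel as FSDP

    # Trade speed for flexibility: no flat buffer, so submodule .parameters() still works.
    model = FSDP(Model(), flatten_parameters=False)
    optimizer = torch.optim.SGD(model.layer.parameters(), lr=0.1)

    optimizer.zero_grad()
    loss = model(torch.rand(8, 5)).sum()
    loss.backward()
    optimizer.step()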
BTW, might the new summon_full_params API be relevant here?
Sorry, I think I misunderstood the issue at first; please ignore the comment about summon_full_params above. However, what's left to do in this issue? I am a bit unclear.
Hey @min-xu-ai! Unfortunately this still won't work, because the original weights have been bucketed in place, removing pointers to the original weights without replacement, I think. module.layer.parameters() will not work unless we set flatten_parameters=False.
Let me understand a bit more. Will bucketing really help here? Bucketing will help performance since FSDP will loop over fewer params internally. However, for your use case, do you need module.layer.parameters() to return the original full unsharded params or the sharded ones? Is the optimizer assumed to be point-wise, like the one used with FSDP? When will the optimizer call on module.layer.parameters() be made: before or after the model is wrapped? Sorry if I missed the context from reading this thread.
@SeanNaren @min-xu-ai Is there an action item here to follow up on? From my understanding, flatten_parameters=False plus an enhanced bucketing strategy would allow for this feature. In the meantime, is there something we can do, and what is the priority for the suggested change?
I think the new FSDP code will likely address this by adding an API that returns views into the flattened param for all of the original params. The views can be partial or even empty.
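A rough illustration of what such an API could boil down to (purely hypothetical helper, not an existing fairscale function; under real sharding some views would be partial or empty, which this sketch ignores):

    import torch

    def param_views(flat_param, shapes):
        # Hypothetical helper: carve views with the original shapes out of the flat buffer.
        views, offset = [], 0
        for shape in shapes:
            numel = int(torch.Size(shape).numel())
            views.append(flat_param[offset:offset + numel].view(shape))
            offset += numel
        return views

    flat = torch.arange(30, dtype=torch.float32)
    weight, bias = param_views(flat, [(5, 5), (5,)])
    print(weight.shape, bias.shape)  # torch.Size([5, 5]) torch.Size([5])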
❓ Questions and Help
This should explain the case: when flatten_parameters=True we remove the parameter, since it has been migrated into a contiguous buffer, but this means that calling .parameters() on specific modules (in the case where we only want to wrap certain parts of the model with optimizers) cannot be done.

Any remedy to this problem? We were experimenting with the possibility of using views to replace this functionality, however a view doesn't give back a Parameter, I think. Alternatively, we could tell users who run into issues like the above to turn off flatten_parameters.
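To make the "views don't give you a Parameter" point concrete, a small standalone sketch (no FSDP involved; the names here are purely illustrative):

    import torch

    flat = torch.nn.Parameter(torch.randn(30))  # stand-in for the flattened buffer
    weight_view = flat[:25].view(5, 5)          # a view into that buffer

    print(isinstance(flat, torch.nn.Parameter))         # True
    print(isinstance(weight_view, torch.nn.Parameter))  # False: a plain, non-leaf Tensor
    # Because the view is not an nn.Parameter, it can't be registered on a module,
    # module.parameters() will never yield it, and torch.optim rejects non-leaf
    # tensors, so views alone don't restore the per-module optimizer workflow.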