Optimizer state sharding - Fairscale (#584)
Summary:
Pull Request resolved: #584

Brings in fairscale to provide an optional state-sharded optimizer in Classy, which should help in situations bounded by memory pressure.
No new communication backend is introduced; this uses vanilla torch.distributed.

See ZeRO for more context: https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/
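As a rough sketch of what the sharded optimizer does (plain fairscale OSS wrapping a regular torch optimizer; the Classy integration builds this from a config instead, and the distributed setup is only assumed here):

import torch
from fairscale.optim.oss import OSS

# Assumes torch.distributed has already been initialized (e.g. via init_process_group).
model = torch.nn.Linear(8, 8)

# Each rank keeps only its shard of the optimizer state (momentum buffers, etc.);
# after the local step, updated parameters are broadcast between ranks.
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1, momentum=0.9)

model(torch.randn(4, 8)).sum().backward()
optimizer.step()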

KNOWN TODOs:
[x] Huge memory discrepancy between the two runs (FIXED)
[~x] Huge speed discrepancy (broadcast related) -> FIXED for a single node; still needs to be investigated for multi-node

[x] Final accuracy in the same ballpark but with very different behaviours; could have been some settings not properly passed down, an issue with LARC, or the parameter scheduling
-> this was due to the LR not being properly adjusted (FIXED)

[x] Sync with min-xu-ai to use a proper gradient dispatch in the end; not landing anything before that
-> done by min-xu-ai on the fairscale side; needs benchmarking, but should not be related to this diff (no interface consequence, hopefully)

Reviewed By: mannatsingh

Differential Revision: D22518768

fbshipit-source-id: 8103a15c164a9f39443b574d34282f6ff70ba3b1
blefaudeux authored and facebook-github-bot committed Sep 12, 2020
1 parent 6150e78 commit ac2993d
Showing 2 changed files with 17 additions and 0 deletions.
4 changes: 4 additions & 0 deletions classy_vision/generic/distributed_util.py
@@ -196,6 +196,10 @@ def get_rank() -> int:
    )


def get_primary_rank() -> int:
    return _PRIMARY_RANK


def set_cuda_device_index(idx: int) -> None:
    global _cuda_device_index
    _cuda_device_index = idx
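get_primary_rank() simply exposes the rank that sharded state gets gathered onto. A minimal sketch of how it could be used together with the new optimizer (save_sharded_checkpoint is a hypothetical helper, not part of this diff; consolidate_state_dict is the fairscale OSS call exercised in the tests below):

import torch

from classy_vision.generic.distributed_util import get_primary_rank, get_rank


def save_sharded_checkpoint(optimizer, path):
    # Gather the optimizer state shards onto one rank before serializing.
    optimizer.consolidate_state_dict()
    if get_rank() == get_primary_rank():
        torch.save(optimizer.state_dict(), path)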
13 changes: 13 additions & 0 deletions test/generic/optim_test_util.py
@@ -67,6 +67,10 @@ def _get_set_state(self, grad_values):

        self._set_gradient(self._parameters(), grad_values)
        opt1.step(where=0)

        if config["name"] == "zero":
            opt1.consolidate_state_dict()

        state = opt1.get_classy_state()

        opt2 = build_optimizer(config)
@@ -83,6 +87,10 @@ def _get_set_state(self, grad_values):
                    opt2.optimizer.param_groups[0]["params"][i],
                )
            )

        if config["name"] == "zero":
            opt2.consolidate_state_dict()

        self._compare_momentum_values(
            opt1.get_classy_state()["optim"], opt2.get_classy_state()["optim"]
        )
@@ -106,6 +114,11 @@ def _get_set_state(self, grad_values):
                    opt2.optimizer.param_groups[0]["params"][i],
                )
            )

        if config["name"] == "zero":
            opt1.consolidate_state_dict()
            opt2.consolidate_state_dict()

        self._compare_momentum_values(
            opt1.get_classy_state()["optim"], opt2.get_classy_state()["optim"]
        )
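The consolidate_state_dict() calls above are needed because a sharded optimizer only holds its local shard on each rank; the full state has to be gathered onto one rank before get_classy_state() can return something comparable. The underlying pattern, sketched at the fairscale level (single-process distributed group assumed purely for illustration):

import torch
from fairscale.optim.oss import OSS

# Assumes torch.distributed has been initialized (world size may be 1).
model = torch.nn.Linear(4, 4)
opt = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1, momentum=0.9)

model(torch.randn(2, 4)).sum().backward()
opt.step()

opt.consolidate_state_dict()   # gather every shard onto the recipient rank (0 by default)
full_state = opt.state_dict()  # the complete state is now available on that rank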
