Optimizer state sharding - Fairscale (#584)
Summary: Pull Request resolved: #584

Brings fairscale into Classy to provide an optional state-sharded optimizer, which should help in situations bounded by memory pressure. No new communication backend is needed; this uses vanilla torch.distributed. See ZeRO for more context: https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/

Known TODOs:
- [x] Huge memory discrepancy between the two runs (FIXED)
- [x] Huge speed discrepancy (broadcast related) -> FIXED for one node, still needs to be investigated for multi-node
- [x] Final accuracy in the same ballpark but with very different behaviors; could have been some settings not properly passed down, an issue with LARC, or the parameter scheduling -> this was due to the LR not being properly adjusted (FIXED)
- [x] Sync with min-xu-ai to use a proper gradient dispatch before landing anything -> done by min-xu-ai on the fairscale side; needs benchmarking, but should not be related to this diff (no interface consequence, hopefully)

Reviewed By: mannatsingh

Differential Revision: D22518768

fbshipit-source-id: 8103a15c164a9f39443b574d34282f6ff70ba3b1
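As a rough illustration of the sharding idea referenced above (not the actual fairscale or Classy implementation), a ZeRO-style state-sharded optimizer first partitions the model's parameters across ranks so that each rank keeps optimizer state (momentum buffers, etc.) only for its own shard. The function name `partition_parameters` and the greedy largest-first strategy below are assumptions for the sketch:

```python
def partition_parameters(param_sizes, world_size):
    """Greedily assign each parameter (identified by index, weighted by its
    element count) to the rank that currently owns the fewest elements.
    Returns a per-rank list of parameter indices."""
    shards = [[] for _ in range(world_size)]
    loads = [0] * world_size
    # Assigning the largest tensors first keeps the shards roughly balanced,
    # so each rank holds about 1/world_size of the optimizer state.
    for idx in sorted(range(len(param_sizes)), key=lambda i: -param_sizes[i]):
        rank = loads.index(min(loads))
        shards[rank].append(idx)
        loads[rank] += param_sizes[idx]
    return shards

# Example: six tensors partitioned across two ranks.
shards = partition_parameters([1000, 10, 500, 500, 200, 10], world_size=2)
```

After a step, each rank updates only its shard and then broadcasts the updated parameters to the other ranks, which is why the commit notes the speed issue was "broadcast related".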