[feat] batch broadcast requests into a configurable buffer #43

Closed: wants to merge 51 commits
Changes from 10 commits

Commits (51)
86628e1
first take at comms batching when broadcasting the state
blefaudeux Aug 14, 2020
4dfc71c
sorting imports..
blefaudeux Aug 14, 2020
5b00c3b
nit
blefaudeux Aug 14, 2020
df15aa7
new machine means new linting..
blefaudeux Aug 14, 2020
3565391
better unit testing
blefaudeux Aug 15, 2020
f0e6814
hotfix, dimension
blefaudeux Aug 15, 2020
b7f7802
better unit testing still, preemptive bugfix
blefaudeux Aug 17, 2020
8084f98
linting
blefaudeux Aug 17, 2020
510f773
unit test fix, a little prettier, should be gtg
blefaudeux Aug 17, 2020
582134f
annoying, remove coverage check on the type hints
blefaudeux Aug 18, 2020
4ed074b
initial commit, dummy training loop, pure pytorch but not DDP
blefaudeux Aug 20, 2020
a167289
probably slightly broken, but rough DDP benchmark run
blefaudeux Aug 20, 2020
20b981d
adding the torchvision requirement for testing
blefaudeux Aug 20, 2020
8a2377c
brainfart
blefaudeux Aug 20, 2020
41dcf69
reduce the loss, do something slightly distributed
blefaudeux Aug 20, 2020
b212dee
Some cleanup, distributing the training on two GPUs
blefaudeux Aug 20, 2020
b149113
Merge remote-tracking branch 'upstream/master' into oss_benchmark
blefaudeux Aug 20, 2020
b5cacbd
some cleanup + adding a vanilla run, still not good to go
blefaudeux Aug 20, 2020
928791e
less silly defaults, gtg for a start I think
blefaudeux Aug 20, 2020
e6a4756
smaller batch to fit the smaller gpus used in the circleci rigs
blefaudeux Aug 20, 2020
e01a60a
Merge commit 'c2d6f4b68e9c24d05a3eb5da4f60431d9e5c86d8' into oss_batc…
blefaudeux Aug 20, 2020
ab79ddc
better device/buffer allocation
blefaudeux Aug 20, 2020
906d740
Merge commit 'e6a4756c1c2927d35af2148dbfb8d0e1f3bff797' into oss_batc…
blefaudeux Aug 20, 2020
d599f4c
WIP, some type hint cleaning, speed deficit for now
blefaudeux Aug 20, 2020
78fc476
lint + double buffering setting
blefaudeux Aug 20, 2020
4560a0c
fix some lazy programming when running on cpu
blefaudeux Aug 20, 2020
56974ed
tighter OSS input type
blefaudeux Aug 20, 2020
8aa48f2
Merge branch 'master' into oss_batch_broadcast
blefaudeux Sep 2, 2020
24d619d
fixing botched merge
blefaudeux Sep 2, 2020
11811f7
Merge remote-tracking branch 'upstream/master' into oss_batch_broadcast
blefaudeux Sep 3, 2020
2209cce
default the buffer to None, check for device locality
blefaudeux Sep 3, 2020
47da0b5
linting + tweak the broadcast buffer settings
blefaudeux Sep 3, 2020
4d4b8cf
minor tweak to the oss benchmark CLI, smaller param buffer
blefaudeux Sep 3, 2020
5cbe21e
adjust speed for RMSProp
blefaudeux Sep 3, 2020
b3aad66
bugfix
blefaudeux Sep 3, 2020
07d6626
back to 100% code coverage, slightly cleaner unit test
blefaudeux Sep 3, 2020
b09d2a1
WIP
blefaudeux Sep 4, 2020
3739ee0
Merge branch 'master' into oss_batch_broadcast
blefaudeux Sep 8, 2020
5f78ccf
better bucketing, across devices and ranks. credits to oss_ddp. WIP i…
blefaudeux Sep 8, 2020
c711b73
cosmetics
blefaudeux Sep 9, 2020
af9dc13
WIP
blefaudeux Sep 9, 2020
3dd3c27
Merge remote-tracking branch 'upstream/master' into oss_batch_broadcast
blefaudeux Sep 9, 2020
8916c50
merge fixes + tentative perf improvements
blefaudeux Sep 9, 2020
68a67fb
allocate per-device broadcast buffer once and for all, at constructio…
blefaudeux Sep 9, 2020
4cadd58
deduplicate oss_ddp/oss
blefaudeux Sep 10, 2020
07c20a9
merge with upstream master, could still be optimized
blefaudeux Sep 10, 2020
5decde1
Merge remote-tracking branch 'upstream/master' into oss_batch_broadcast
blefaudeux Sep 15, 2020
a8f601c
Merge remote-tracking branch 'upstream/master' into oss_batch_broadcast
blefaudeux Sep 29, 2020
b34bedd
restoring working state, nccl deadlocking unfortunately
blefaudeux Sep 29, 2020
08ce45d
wip
blefaudeux Sep 29, 2020
26308b4
in working order, but unbearably slow
blefaudeux Sep 29, 2020
41 changes: 38 additions & 3 deletions fairscale/optim/oss.py
@@ -11,10 +11,11 @@
import torch.distributed as dist
from torch.optim import SGD, Optimizer

from .utils import broadcast_object, recursive_copy_to_device
from .utils import batch_broadcast, broadcast_object, recursive_copy_to_device

if TYPE_CHECKING: # pragma: no cover
from torch.optim.optimizer import _params_t
from torch.nn import Parameter
else:
_params_t = Any

@@ -43,12 +44,22 @@ class OSS(Optimizer):
optimizer to shard (default: SGD)
group (group):
torch.distributed group (default: group.WORLD)
buffer_size (int, optional): number of elements to buffer before
Contributor:

What does optional mean in this context? The parameter does not look like an optional.

Contributor Author:

I meant to write that people are free to pass it in or not; there's a default provided.

performing reduce (default: 32M). Used to reduce multiple small
params to avoid communication overhead.
"""

optim: Optimizer
in_super_constructor: bool

def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any = dist.group.WORLD, **defaults: Any):
def __init__(
self,
params: _params_t,
optim: Type[Optimizer] = SGD,
group: Any = dist.group.WORLD,
buffer_size: int = 2 ** 25,
**defaults: Any
):
# Hold all the model params in the root .param_groups
self.in_super_constructor = True
super().__init__(params, defaults)
@@ -65,6 +76,7 @@ def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any =

# Current device is set by the parameters allocated to this rank
self._device = self.partition_parameters()[self.rank][0]["params"][0].device
self._buffer = torch.zeros(buffer_size).to(self._device)

def partition_parameters(self) -> List[List[dict]]:
"""Partitions parameters across distributed ranks.
@@ -97,9 +109,32 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]

# Sync all the states
for rank, param_groups in enumerate(self.partition_parameters()):
# Batch smaller params in a broadcast buffer to speed up the communication
buffered_params: List[Parameter] = []
buffered_elements = 0

for param_group in param_groups:
for param in param_group["params"]:
dist.broadcast(tensor=param, src=rank, group=self.group)
if param.numel() >= self._buffer.numel():
# Big param block, broadcast directly
dist.broadcast(tensor=param, src=rank, group=self.group)
else:
if (buffered_elements + param.numel()) >= self._buffer.numel():
# Batch buffer is full, sync
batch_broadcast(
buffered_params, source_rank=rank, buffer=self._buffer, process_group=self.group
)
buffered_params.clear()
buffered_elements = 0

# Keep async and batch sync later
buffered_params.append(param)
buffered_elements += param.numel()

# Sync whatever is left in the batch buffer before moving to the next rank
if buffered_elements > 0:
batch_broadcast(buffered_params, source_rank=rank, buffer=self._buffer, process_group=self.group)

return loss

def local_state_dict(self) -> dict:
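
To make the new buffer_size argument concrete (and the review question above about it being optional), here is a minimal usage sketch. It assumes torch.distributed has already been initialized by the launcher, that OSS is exported from fairscale.optim, and uses a toy Linear model; only the buffer_size keyword comes from this diff, the rest is illustrative.

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# Assumes dist.init_process_group(...) has already been called on this rank.
model = torch.nn.Linear(16, 16)

# Default: a flat broadcast buffer of 2 ** 25 elements (the ~32M from the docstring)
# is allocated once on this rank's device.
opt_default = OSS(model.parameters(), lr=0.1)

# Tiny buffer, as in the new unit test: every param exceeds it,
# so each one is broadcast directly instead of being batched.
opt_small = OSS(model.parameters(), lr=0.1, buffer_size=8)
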
32 changes: 31 additions & 1 deletion fairscale/optim/utils.py
@@ -4,12 +4,18 @@
# LICENSE file in the root directory of this source tree.

import io
from typing import Any, Dict
from typing import TYPE_CHECKING, Any, Dict, List

import torch
from torch._six import container_abcs
import torch.distributed as dist

if TYPE_CHECKING: # pragma: no cover
from torch import Tensor
from torch.nn import Parameter
else:
Tensor = Any
Parameter = Any

# Credits: classy_vision/generic/distributed_util.py
def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
@@ -68,3 +74,27 @@ def broadcast_object(
buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
obj = torch.load(buffer, map_location=dist_device)
return obj


def batch_broadcast(
buffered_params: List[Parameter], source_rank: int, buffer: Tensor, process_group: Any = None
) -> None:
""" Helper to broadcast a list of params batched into a bigger buffer.
NOTE: This skips the grads on purpose, only broadcasts the tensor parameters.
NOTE: This also asserts that the parameters will fit in the buffer """

offset = 0
for p in buffered_params:
sz = p.numel()
buffer[offset : offset + p.numel()].copy_(p.data.view(-1)) # type: ignore
offset += sz
assert offset < buffer.numel()

dist.broadcast(tensor=buffer, src=source_rank, group=process_group)

# copy the broadcasted params back into their original place
offset = 0
for p in buffered_params:
sz = p.numel()
p.data.copy_(buffer[offset : offset + sz].view_as(p)) # type: ignore
offset += sz
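
For reference, a rough sketch of exercising the new batch_broadcast helper on its own; the tensor shapes and buffer length below are made up, and it assumes torch.distributed is already initialized (for instance inside an mp.spawn worker like the tests that follow), with the params and buffer living on that rank's device.

import torch
import torch.distributed as dist
from torch import nn
from fairscale.optim.utils import batch_broadcast

# Assumes dist.init_process_group(...) has already run on every rank (CPU/gloo shown here).
params = [nn.Parameter(torch.randn(4, 4)), nn.Parameter(torch.randn(8))]  # 16 + 8 = 24 elements
buffer = torch.zeros(32)  # flat staging buffer, large enough for the 24 buffered elements

# Flattens every param into `buffer`, issues a single dist.broadcast from rank 0,
# then copies the received values back into each param in place (grads are untouched).
batch_broadcast(params, source_rank=0, buffer=buffer, process_group=dist.group.WORLD)

Packing many small tensors into one collective call trades a couple of extra copies for far fewer latency-bound broadcasts, which is the communication overhead the buffer_size docstring refers to.
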
63 changes: 63 additions & 0 deletions tests/optim/test_oss.py
@@ -7,6 +7,7 @@
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

import math
import os

import pytest
@@ -140,6 +141,68 @@ def test_step():
mp.spawn(run_test_step, args=(world_size,), nprocs=world_size, join=True)


def run_test_batch_broadcast(rank, world_size):
dist_init(rank, world_size)
width_multiplier = 3
batch_size = 10

x = torch.ones([batch_size, world_size], device=rank)
target = torch.zeros([batch_size, width_multiplier * world_size], device=rank)
error = math.factorial(width_multiplier * world_size - 1)

def get_model():
layers = [torch.nn.Linear(i, i + 1) for i in range(world_size, width_multiplier * world_size)]
for l in layers:
l.weight.data.fill_(1.0)
l.bias.data.fill_(0.0)

m = torch.nn.Sequential(*layers)
m.to(rank)
return m

# Set a very small buffer size so that every param exceeds it and is broadcast directly (no batching)
m_small = get_model()
o = optim.OSS(m_small.parameters(), lr=0.1, buffer_size=8)
loss_fn = torch.nn.L1Loss().to(device=rank)

def closure():
o.zero_grad()
output = m_small(x)
loss = loss_fn(output, target)
loss.backward()
return loss

loss = o.step(closure=closure)
assert round(loss.item()) == error, f"{loss} vs. expected: {error}"

loss_update = o.step(closure=closure)
assert loss_update.item() < loss.item(), f"{loss.item()} vs {loss_update.item()} loss should decrease"

# Set a very big buffer size to force all the params to be packed
m_large = get_model()
o = optim.OSS(m_large.parameters(), lr=0.1, buffer_size=2 ** 26)
loss_fn = torch.nn.L1Loss().to(device=rank)

def closure():
o.zero_grad()
output = m_large(x)
loss = loss_fn(output, target)
loss.backward()
return loss

loss = o.step(closure=closure)
assert round(loss.item()) == error, f"{loss} vs. expected: {error}"

loss_update = o.step(closure=closure)
assert loss_update.item() < loss.item(), f"{loss.item()} vs {loss_update.item()} loss should decrease"


@skip_if_no_cuda
def test_batch_broadcast():
world_size = min(2, torch.cuda.device_count())
mp.spawn(run_test_batch_broadcast, args=(world_size,), nprocs=world_size, join=True)


def run_test_step_with_closure(rank, world_size, optimizer=None):
dist_init(rank, world_size)
