diff --git a/benchmarks/oss.py b/benchmarks/oss.py
index 281239a3d..63905e725 100755
--- a/benchmarks/oss.py
+++ b/benchmarks/oss.py
@@ -5,7 +5,7 @@
 import math
 import os
 import time
-from typing import Any, List
+from typing import Any, List, Union, cast

 import torch
 import torch.distributed as dist
@@ -19,6 +19,7 @@ from fairscale.optim.oss import OSS

 BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
+OPTIM = torch.optim.RMSprop


 def dist_init(rank, world_size):
@@ -36,7 +37,9 @@ def train(
     use_oss: bool = True,
     check_regression: bool = True,
     reference_speed: float = -1.0,
+    reference_memory: float = -1.0,
 ):
+    # DDP
     dist_init(rank, world_size)

@@ -50,21 +53,18 @@ def collate(inputs: List[Any]):
             "label": torch.stack([i[1] for i in inputs]).to(rank),
         }

-    def print_(msg):
-        if dist.get_rank() == 0:
-            print(msg)
-
     dataloader = DataLoader(
         dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
     )
     loss_fn = nn.CrossEntropyLoss()

+    # Reset the peak memory counter before measuring
+    torch.cuda.reset_peak_memory_stats(rank)
+
     # Shard the optimizer
-    optimizer = (
-        OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.9)
-        if use_oss
-        else torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
-    )
+    optimizer: Union[OSS, OPTIM] = OSS(
+        params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9
+    ) if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)

     # Dummy training loop
     torch.cuda.synchronize(rank)
@@ -90,8 +90,19 @@ def closure():
             optimizer.step(closure)

         epoch_end = time.monotonic()
+
+        if use_oss:
+            # Check the checkpointing in the case of the OSS optimizer,
+            # since memory usage could spill over from there
+            optimizer = cast(OSS, optimizer)
+            optimizer.consolidate_state_dict()
+            if dist.get_rank() == 0:
+                _ = optimizer.state_dict()
+                print("... State dict collected")
+
         measurements.append(data_size / (epoch_end - epoch_start))
-        print_(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
+        if dist.get_rank() == 0:
+            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")

     torch.cuda.synchronize(rank)
     training_stop = time.monotonic()
@@ -101,13 +112,15 @@ def closure():
     print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
     print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

+    # Compute the mean and standard deviation of the measured speed (img per sec)
+    mean = sum(measurements) / len(measurements)
+    diff = map(lambda x: pow(x - mean, 2.0), measurements)
+    std = math.sqrt(sum(diff) / (len(measurements) - 1))
+    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")
+
     if use_oss and check_regression and dist.get_rank() == 0:
-        # Compute the mean and average img per second
-        mean = sum(measurements) / len(measurements)
-        diff = map(lambda x: pow(x - mean, 2.0), measurements)
-        std = math.sqrt(sum(diff) / (len(measurements) - 1))
-        print(f"[Regression Test] Mean: {mean:.2f} +/- {std:.2f}")
-        assert (mean - 3.0 * std) < reference_speed, "Regression detected"
+        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
+        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"

         print("[Regression Test] VALID")

@@ -122,10 +135,11 @@ def closure():
     parser.add_argument("--data_size", action="store", default=512, type=int)
     parser.add_argument("--check_regression", action="store", default=True, type=bool)
     parser.add_argument("--reference_speed", action="store", default=39.82, type=float)
+    parser.add_argument("--reference_memory", action="store", default=4475, type=float)

     args = parser.parse_args()

-    print("\nBenchmark vanilla SGD")
+    print("\nBenchmark vanilla optimizer")
     mp.spawn(
         train,
         args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
@@ -144,6 +158,7 @@ def closure():
             True,
             args.check_regression,
             args.reference_speed,
+            args.reference_memory,
         ),
         nprocs=args.world_size,
         join=True,
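
For context, the core pattern this benchmark toggles is swapping a vanilla optimizer for fairscale's `OSS` wrapper, which shards the optimizer state across ranks. Below is a minimal sketch of that pattern, assuming a process group is already initialized and `model` lives on the local device; the helper name `build_optimizer` is illustrative and not part of the benchmark:

```python
import torch
from fairscale.optim.oss import OSS

OPTIM = torch.optim.RMSprop  # same base optimizer the benchmark selects


def build_optimizer(model: torch.nn.Module, use_oss: bool):
    # Illustrative helper mirroring the benchmark's optimizer selection
    if use_oss:
        # OSS shards the optimizer state: each rank only allocates and
        # updates the state for its own partition of the parameters
        return OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
    # Vanilla path: every rank keeps a full replica of the optimizer state
    return OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
```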
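The per-epoch checkpoint probe relies on `OSS.consolidate_state_dict()` being a collective: every rank must call it, after which the full state dict is only available on the recipient rank (rank 0 by default). A sketch of that flow as an actual checkpoint, assuming `optimizer` is an `OSS` instance; the helper name and path are placeholders:

```python
import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS


def checkpoint_oss(optimizer: OSS, path: str = "oss_checkpoint.pt"):
    # Collective call: every rank ships its shard to the recipient rank,
    # so this must run on all ranks, not just rank 0
    optimizer.consolidate_state_dict()
    if dist.get_rank() == 0:
        # Only rank 0 holds the consolidated state at this point
        torch.save(optimizer.state_dict(), path)
```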
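Assuming the world-size/epoch flags defined earlier in the parser keep their defaults, a regression run with the new memory reference would look like:

```
python benchmarks/oss.py --reference_speed 39.82 --reference_memory 4475
```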