Skip to content

Commit

Permalink
[feat] Add a memory usage regression test to the OSS benchmark (#62)
Browse files Browse the repository at this point in the history
* Aligning the optimizer state dict with what PyTorch expects

* Adding a check on the dict keys to ensure that `state` and `param_groups` are present

* After installing the specific isort, black and all, a one-liner to please the linter

* Adding some measurement of the memory consumption while training + checkpointing

* mandatory lintfix commit

* Oversight fix: reset the memory use counter at the beginning of the training, in case two trainings are run in a row

* move reset stats call, hotfix

* Switch the optimizer to RMSprop, which is more stateful and still used in CV

* Trying to figure out a SIGSEGV in CircleCI
  • Loading branch information
blefaudeux authored Sep 3, 2020
1 parent b6a5e63 commit ee38e1e
Showing 1 changed file with 33 additions and 18 deletions.
51 changes: 33 additions & 18 deletions benchmarks/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
import os
import time
from typing import Any, List
from typing import Any, List, Union, cast

import torch
import torch.distributed as dist
Expand All @@ -19,6 +19,7 @@
from fairscale.optim.oss import OSS

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
OPTIM = torch.optim.RMSprop


def dist_init(rank, world_size):
Expand All @@ -36,7 +37,9 @@ def train(
use_oss: bool = True,
check_regression: bool = True,
reference_speed: float = -1.0,
reference_memory: float = -1.0,
):

# DDP
dist_init(rank, world_size)

Expand All @@ -50,21 +53,18 @@ def collate(inputs: List[Any]):
"label": torch.stack([i[1] for i in inputs]).to(rank),
}

def print_(msg):
if dist.get_rank() == 0:
print(msg)

dataloader = DataLoader(
dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
)
loss_fn = nn.CrossEntropyLoss()

# Reset the memory use counter
torch.cuda.reset_peak_memory_stats(rank)

# Shard the optimizer
optimizer = (
OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.9)
if use_oss
else torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
)
optimizer: Union[OSS, OPTIM] = OSS(
params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9
) if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)

# Dummy training loop
torch.cuda.synchronize(rank)
Expand All @@ -90,8 +90,19 @@ def closure():
optimizer.step(closure)

epoch_end = time.monotonic()

if use_oss:
# Check the checkpointing in the case of the OSS optimizer
# Memory usage could spill over from there
optimizer = cast(OSS, optimizer)
# optimizer.consolidate_state_dict()
if dist.get_rank() == 0:
# _ = optimizer.state_dict()
print("... State dict collected")

measurements.append(data_size / (epoch_end - epoch_start))
print_(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
if dist.get_rank() == 0:
print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")

torch.cuda.synchronize(rank)
training_stop = time.monotonic()
Expand All @@ -101,13 +112,15 @@ def closure():
print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

# Compute the mean and average img per second
mean = sum(measurements) / len(measurements)
diff = map(lambda x: pow(x - mean, 2.0), measurements)
std = math.sqrt(sum(diff) / (len(measurements) - 1))
print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

if use_oss and check_regression and dist.get_rank() == 0:
# Compute the mean and average img per second
mean = sum(measurements) / len(measurements)
diff = map(lambda x: pow(x - mean, 2.0), measurements)
std = math.sqrt(sum(diff) / (len(measurements) - 1))
print(f"[Regression Test] Mean: {mean:.2f} +/- {std:.2f}")
assert (mean - 3.0 * std) < reference_speed, "Regression detected"
assert (mean - 3.0 * std) < reference_speed, "Speed regression detected"
assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
print("[Regression Test] VALID")


Expand All @@ -122,10 +135,11 @@ def closure():
parser.add_argument("--data_size", action="store", default=512, type=int)
parser.add_argument("--check_regression", action="store", default=True, type=bool)
parser.add_argument("--reference_speed", action="store", default=39.82, type=float)
parser.add_argument("--reference_memory", action="store", default=4475, type=float)

args = parser.parse_args()

print("\nBenchmark vanilla SGD")
print("\nBenchmark vanilla optimizer")
mp.spawn(
train,
args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
Expand All @@ -144,6 +158,7 @@ def closure():
True,
args.check_regression,
args.reference_speed,
args.reference_memory,
),
nprocs=args.world_size,
join=True,
Expand Down

0 comments on commit ee38e1e

Please sign in to comment.