Commit
Fix bug in gradient accumulation
SimonLarsen committed Jun 26, 2024
1 parent 37a45b5 commit d58dd6d
Showing 2 changed files with 27 additions and 27 deletions.
2 changes: 1 addition & 1 deletion frogbox/engines/common.py
@@ -13,7 +13,7 @@ def _backward_with_scaler(
 ):
     if scaler:
         scaler.scale(loss).backward()
-        if iteration & gradient_accumulation_steps == 0:
+        if iteration % gradient_accumulation_steps == 0:
            scaler.unscale_(optimizer)
            if clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(
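The fix is a single character: the old condition used `&` (bitwise AND) where `%` (modulo) was intended. Because `&` binds tighter than `==` in Python, the old line tested `(iteration & gradient_accumulation_steps) == 0`, which is true whenever the two values share no set bits, not once per accumulation window. A standalone illustration (not code from the repository):

    # Compare the buggy bitwise test against the corrected modulo test
    # for gradient_accumulation_steps = 4.
    steps = 4
    for iteration in range(1, 13):
        buggy = (iteration & steps) == 0  # true whenever bit 2 of iteration is clear
        fixed = iteration % steps == 0    # true exactly once every 4 iterations
        print(f"iter {iteration:2d}: buggy={buggy} fixed={fixed}")

    # buggy fires on iterations 1, 2, 3, 8, 9, 10, 11; fixed fires on 4, 8, 12.
    # The old code therefore ran the unscale/clip branch on almost every
    # iteration instead of once per accumulation window.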
52 changes: 26 additions & 26 deletions frogbox/engines/gan.py
@@ -110,22 +110,21 @@ def _update(
         x, y = input_transform(x, y)

         # Update discriminator
-        disc_optimizer.zero_grad()
+        disc_loss = torch.tensor(0.0)
+        if (engine.state.iteration - 1) % disc_update_interval == 0:
+            disc_optimizer.zero_grad()

-        with torch.autocast(device_type=device.type, enabled=amp):
-            y_pred = model_transform(model(x)).detach()
-            disc_pred_real = disc_model_transform(disc_model(y))
-            disc_pred_fake = disc_model_transform(disc_model(y_pred))
-            disc_loss = disc_loss_fn(
-                y_pred,
-                y,
-                disc_real=disc_pred_real,
-                disc_fake=disc_pred_fake,
-            )
+            with torch.autocast(device_type=device.type, enabled=amp):
+                y_pred = model_transform(model(x))
+                disc_pred_real = disc_model_transform(disc_model(y))
+                disc_pred_fake = disc_model_transform(
+                    disc_model(y_pred.detach())
+                )
+                disc_loss = disc_loss_fn(
+                    y_pred,
+                    y,
+                    disc_real=disc_pred_real,
+                    disc_fake=disc_pred_fake,
+                )

-        if engine.state.iteration % disc_update_interval == 0:
             _backward_with_scaler(
                 model=disc_model,
                 optimizer=disc_optimizer,
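Beyond the interval logic, this hunk also moves the `.detach()`: the old code detached `y_pred` itself and later reused that detached tensor for the generator loss, so no gradient could flow back into the generator; the new code keeps `y_pred` attached and detaches only the copy fed to the discriminator. A minimal sketch of that pattern with toy stand-in models (not frogbox code):

    import torch

    gen = torch.nn.Linear(8, 8)   # stand-in generator
    disc = torch.nn.Linear(8, 1)  # stand-in discriminator
    x, y = torch.randn(4, 8), torch.randn(4, 8)

    y_pred = gen(x)                    # stays attached to the generator graph
    disc_fake = disc(y_pred.detach())  # detached copy: no grads reach gen
    disc_real = disc(y)
    disc_loss = (disc_fake - disc_real).mean()  # stand-in for the real GAN loss
    disc_loss.backward()               # updates the discriminator only ...
    assert gen.weight.grad is None     # ... the generator is untouched

    gen_loss = ((y_pred - y) ** 2).mean()  # y_pred is still attached,
    gen_loss.backward()                    # so this reaches gen.weight.grad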
@@ -136,18 +135,19 @@ def _update(
             )

         # Update generator
-        optimizer.zero_grad()
+        loss = torch.tensor(0.0)
+        if (engine.state.iteration - 1) % update_interval == 0:
+            optimizer.zero_grad()

-        with torch.autocast(device_type=device.type, enabled=amp):
-            disc_pred_fake = disc_model_transform(disc_model(y_pred))
-            loss = loss_fn(
-                y_pred,
-                y,
-                disc_real=disc_pred_real,
-                disc_fake=disc_pred_fake,
-            )
+            with torch.autocast(device_type=device.type, enabled=amp):
+                y_pred = model_transform(model(x))
+                disc_pred_fake = disc_model_transform(disc_model(y_pred))
+                loss = loss_fn(
+                    y_pred,
+                    y,
+                    disc_fake=disc_pred_fake,
+                )

-        if engine.state.iteration % update_interval == 0:
             _backward_with_scaler(
                 model=model,
                 optimizer=optimizer,
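Both update blocks are restructured the same way: the forward pass, loss, and backward now all sit under a single `(engine.state.iteration - 1) % interval == 0` guard, and the `torch.tensor(0.0)` defaults keep `disc_loss` and `loss` defined on the iterations where the update is skipped. Since pytorch-ignite's `engine.state.iteration` counts from 1, the `- 1` makes the guard fire on the very first iteration, where the old `iteration % interval == 0` form would not. A standalone illustration (not repository code):

    # With a 1-based iteration counter, the two guards fire on different steps.
    interval = 3
    old = [i for i in range(1, 10) if i % interval == 0]        # [3, 6, 9]
    new = [i for i in range(1, 10) if (i - 1) % interval == 0]  # [1, 4, 7]
    print(old, new)  # the new guard includes the first iteration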
