Skip to content

Commit

Permalink
[Fix] fix realesrgan ema (#1341)
Browse files Browse the repository at this point in the history
* [Fix] fix realesrgan ema

* fix

* fix ut

* fix config

* fix ut
  • Loading branch information
Z-Fran authored Oct 21, 2022
1 parent 20e5850 commit b90ac53
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@
is_use_sharpened_gt_in_pixel=True,
is_use_sharpened_gt_in_percep=True,
is_use_sharpened_gt_in_gan=False,
is_use_ema=True,
train_cfg=dict(start_iter=1000000),
test_cfg=dict(),
data_preprocessor=dict(
type='EditDataPreprocessor',
mean=[0, 0, 0],
std=[1, 1, 1],
mean=[0., 0., 0.],
std=[255., 255., 255.],
))

train_cfg = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
upscale_factor=scale),
pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
is_use_sharpened_gt_in_pixel=True,
is_use_ema=True,
train_cfg=dict(),
test_cfg=dict(),
data_preprocessor=dict(
type='EditDataPreprocessor',
mean=[0, 0, 0],
std=[1, 1, 1],
mean=[0., 0., 0.],
std=[255., 255., 255.],
))

train_pipeline = [
Expand Down Expand Up @@ -177,7 +178,6 @@
val_pipeline = [
dict(type='LoadImageFromFile', key='img', channel_order='rgb'),
dict(type='LoadImageFromFile', key='gt', channel_order='rgb'),
dict(type='RescaleToZeroOne', keys=['img', 'gt']),
dict(type='PackEditInputs')
]

Expand Down Expand Up @@ -205,14 +205,13 @@
dataset=dict(
type=dataset_type,
metainfo=dict(dataset_type='set5', task_name='real_sr'),
data_root='data/set5',
data_prefix=dict(gt='HR', img='bicLRx4'),
data_root='data/Set5',
data_prefix=dict(gt='GTmod12', img='LRbicx4'),
pipeline=val_pipeline))

test_dataloader = val_dataloader

val_evaluator = [
dict(type='MAE'),
dict(type='PSNR'),
dict(type='SSIM'),
]
Expand Down Expand Up @@ -253,7 +252,7 @@
vis_backends=vis_backends,
fn_key='gt_path',
img_keys=['gt_img', 'input', 'pred_img'],
bgr2rgb=True)
bgr2rgb=False)
custom_hooks = [
dict(type='BasicVisualizationHook', interval=1),
dict(
Expand Down
2 changes: 2 additions & 0 deletions mmedit/models/editors/real_basicvsr/real_basicvsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(self,
is_use_sharpened_gt_in_pixel=False,
is_use_sharpened_gt_in_percep=False,
is_use_sharpened_gt_in_gan=False,
is_use_ema=False,
train_cfg=None,
test_cfg=None,
init_cfg=None,
Expand All @@ -72,6 +73,7 @@ def __init__(self,
is_use_sharpened_gt_in_pixel=is_use_sharpened_gt_in_pixel,
is_use_sharpened_gt_in_percep=is_use_sharpened_gt_in_percep,
is_use_sharpened_gt_in_gan=is_use_sharpened_gt_in_gan,
is_use_ema=is_use_ema,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg,
Expand Down
25 changes: 25 additions & 0 deletions mmedit/models/editors/real_esrgan/real_esrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class RealESRGAN(SRGAN):
is_use_sharpened_gt_in_gan (bool, optional): Whether to use the
image sharpened by unsharp masking as the GT for adversarial loss.
Default: False.
is_use_ema (bool, optional): Whether to apply exponential moving average
on the network weights. Default: True.
train_cfg (dict): Config for training. Default: None.
You may change the training of gan by setting:
`disc_steps`: how many discriminator updates after one generate
Expand All @@ -56,6 +58,7 @@ def __init__(self,
is_use_sharpened_gt_in_pixel=False,
is_use_sharpened_gt_in_percep=False,
is_use_sharpened_gt_in_gan=False,
is_use_ema=True,
train_cfg=None,
test_cfg=None,
init_cfg=None,
Expand All @@ -75,12 +78,34 @@ def __init__(self,
self.is_use_sharpened_gt_in_pixel = is_use_sharpened_gt_in_pixel
self.is_use_sharpened_gt_in_percep = is_use_sharpened_gt_in_percep
self.is_use_sharpened_gt_in_gan = is_use_sharpened_gt_in_gan
self.is_use_ema = is_use_ema

if train_cfg is not None: # used for initializing from ema model
self.start_iter = train_cfg.get('start_iter', -1)
else:
self.start_iter = -1

def forward_tensor(self, inputs, data_samples=None, training=False):
    """Run a plain forward pass and return the network output.

    Args:
        inputs (torch.Tensor): batch input tensor collated by
            :attr:`data_preprocessor`.
        data_samples (List[BaseDataElement], optional):
            data samples collated by :attr:`data_preprocessor`.
            Not read by this method.
        training (bool): Whether is training. Default: False.

    Returns:
        Tensor: output of ``self.generator`` when ``training`` is true or
        ``self.is_use_ema`` is false; otherwise output of
        ``self.generator_ema``.
    """
    # The EMA copy is consulted only outside training, and only when the
    # model was configured to use it.
    use_ema_weights = not training and self.is_use_ema
    network = self.generator_ema if use_ema_weights else self.generator
    return network(inputs)

def g_step(self, batch_outputs, batch_gt_data):
"""G step of GAN: Calculate losses of generator.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def test_real_esrgan(init_weights):
is_use_sharpened_gt_in_pixel=False,
is_use_sharpened_gt_in_percep=False,
is_use_sharpened_gt_in_gan=False,
is_use_ema=False,
train_cfg=None,
test_cfg=None,
data_preprocessor=EditDataPreprocessor())
Expand Down Expand Up @@ -88,6 +89,12 @@ def test_real_esrgan(init_weights):
output = model.val_step(data)
assert output[0].output.pred_img.data.shape == (3, 128, 128)

# val_ema
model.generator_ema = model.generator
model.is_use_ema = True
output = model.val_step(data)
assert output[0].output.pred_img.data.shape == (3, 128, 128)

# feat
output = model(torch.rand(1, 3, 32, 32), mode='tensor')
assert output.shape == (1, 3, 128, 128)
Expand Down

0 comments on commit b90ac53

Please sign in to comment.