[Improve] FLAVR demo #954

Merged · 3 commits · Jul 18, 2022
@@ -2,7 +2,7 @@

 # model settings
 model = dict(
-    type='BasicInterpolator',
+    type='FLAVR',
     generator=dict(
         type='FLAVRNet',
         num_input_frames=4,
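For reference, the change above swaps the demo config's wrapper model from BasicInterpolator to FLAVR. A minimal sketch of the resulting model dict follows; only type='FLAVR' and the generator fields are confirmed by the diff, while num_output_frames and the pixel_loss entry are assumptions taken from the test case added later in this PR.

# Sketch of the updated demo model config (fields beyond the diff are
# assumptions based on the test case in this PR, not the actual config file).
model = dict(
    type='FLAVR',
    generator=dict(
        type='FLAVRNet',
        num_input_frames=4,
        num_output_frames=1),
    pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'))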
20 changes: 12 additions & 8 deletions mmedit/models/video_interpolators/basic_interpolator.py
@@ -28,6 +28,8 @@ class BasicInterpolator(BaseModel):
         pixel_loss (dict): Config for pixel-wise loss.
         train_cfg (dict): Config for training. Default: None.
         test_cfg (dict): Config for testing. Default: None.
+        required_frames (int): Required frames in each process. Default: 2.
+        step_frames (int): Step size of video frame interpolation. Default: 1.
         pretrained (str): Path for pretrained model. Default: None.
     """
     allowed_metrics = {'PSNR': psnr, 'SSIM': ssim}
@@ -37,6 +39,8 @@ def __init__(self,
                  pixel_loss,
                  train_cfg=None,
                  test_cfg=None,
+                 required_frames=2,
+                 step_frames=1,
                  pretrained=None):
         super().__init__()

@@ -54,9 +58,9 @@ def __init__(self,
         self.pixel_loss = build_loss(pixel_loss)

         # Required frames in each process
-        self.required_frames = 2
+        self.required_frames = required_frames
         # Step size of video frame interpolation
-        self.step_frames = 1
+        self.step_frames = step_frames

     def init_weights(self, pretrained=None):
         """Init weights for models.
@@ -266,13 +270,9 @@ def val_step(self, data_batch, **kwargs):
         output = self.forward_test(**data_batch, **kwargs)
         return output

-    @staticmethod
-    def split_frames(input_tensors):
+    def split_frames(self, input_tensors):
         """Split input tensors for inference.

-        This is a basic function, interpolate a frame between the given two
-        frames.
-
         Args:
             input_tensors (Tensor): Tensor of input frames with shape
                 [1, t, c, h, w]
@@ -283,7 +283,11 @@ def split_frames(input_tensors):

         num_frames = input_tensors.shape[1]

-        result = [input_tensors[:, i:i + 2] for i in range(0, num_frames - 1)]
+        result = [
+            input_tensors[:, i:i + self.required_frames]
+            for i in range(0, num_frames - self.required_frames +
+                           1, self.step_frames)
+        ]
         result = torch.cat(result, dim=0)

         return result
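As a sanity check of the new windowing logic (a standalone sketch, not part of the PR): with required_frames=4 and step_frames=1, seven input frames yield four overlapping four-frame windows.

import torch

# Standalone re-implementation of the list comprehension above,
# for illustration only.
def split_frames(input_tensors, required_frames=4, step_frames=1):
    num_frames = input_tensors.shape[1]
    result = [
        input_tensors[:, i:i + required_frames]
        for i in range(0, num_frames - required_frames + 1, step_frames)
    ]
    return torch.cat(result, dim=0)

frames = torch.rand(1, 7, 3, 16, 16)       # [1, t, c, h, w]
windows = split_frames(frames)
assert windows.shape == (4, 4, 3, 16, 16)  # windows [0:4], [1:5], [2:6], [3:7]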
69 changes: 69 additions & 0 deletions mmedit/models/video_interpolators/flavr.py
@@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmedit.core import tensor2img
from ..registry import MODELS
from .basic_interpolator import BasicInterpolator


@MODELS.register_module()
class FLAVR(BasicInterpolator):
    """FLAVR model for video interpolation.

    It must contain a generator that takes four frames as input and outputs
    one interpolated frame. It also has a pixel-wise loss for training.

    Args:
        generator (dict): Config for the generator structure.
        pixel_loss (dict): Config for pixel-wise loss.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """

    def __init__(self,
                 generator,
                 pixel_loss,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__(
            generator=generator,
            pixel_loss=pixel_loss,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            required_frames=4,
            step_frames=1,
            pretrained=pretrained)

    @staticmethod
    def merge_frames(input_tensors, output_tensors):
        """Merge input frames and output frames.

        Each output frame is interpolated between the middle two frames of
        its four-frame input window.

        Merged from
            [[in1, in2, in3, in4], [in2, in3, in4, in5], ...]
            [[out1], [out2], [out3], ...]
        to
            [in1, in2, out1, in3, out2, ..., in(-3), out(-1), in(-2), in(-1)]

        Args:
            input_tensors (Tensor): The input frames with shape
                [n, 4, c, h, w].
            output_tensors (Tensor): The output frames with shape
                [n, 1, c, h, w].

        Returns:
            list[np.array]: The final frames.
        """

        num_frames = input_tensors.shape[0]
        result = [tensor2img(input_tensors[0, 0])]
        for i in range(num_frames):
            result.append(tensor2img(input_tensors[i, 1]))
            result.append(tensor2img(output_tensors[i, 0]))
        result.append(tensor2img(input_tensors[-1, 2]))
        result.append(tensor2img(input_tensors[-1, 3]))

        return result
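Concretely, for n = 3 windows [in1..in4], [in2..in5], [in3..in6] with outputs [out1], [out2], [out3], the merged sequence is [in1, in2, out1, in3, out2, in4, out3, in5, in6]: one leading frame, n (input, output) pairs, and two trailing frames, i.e. 2n + 3 frames in total. This is exactly the count the new test below asserts.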
18 changes: 18 additions & 0 deletions tests/test_models/test_video_interpolator/test_flavr.py
@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmedit.models.video_interpolators.flavr import FLAVR


def test_flavr():

    model = FLAVR(
        generator=dict(
            type='FLAVRNet', num_input_frames=4, num_output_frames=1),
        pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'))

    input_tensors = torch.rand(3, 4, 3, 16, 16)
    output_tensors = torch.rand(3, 1, 3, 16, 16)
    result = model.merge_frames(input_tensors, output_tensors)
    assert len(result) == 9
    assert result[0].shape == (16, 16, 3)
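Taken together, the two helpers compose end to end. The sketch below makes the same registry assumptions as the test above (FLAVRNet and L1Loss registered), with a random tensor standing in for the generator's actual output:

import torch

from mmedit.models.video_interpolators.flavr import FLAVR

model = FLAVR(
    generator=dict(type='FLAVRNet', num_input_frames=4, num_output_frames=1),
    pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'))

video = torch.rand(1, 7, 3, 16, 16)     # [1, t, c, h, w], seven input frames
windows = model.split_frames(video)     # [4, 4, 3, 16, 16]
fake_out = torch.rand(4, 1, 3, 16, 16)  # stand-in for the FLAVRNet output
frames = model.merge_frames(windows, fake_out)
assert len(frames) == 11                # 7 originals + 4 interpolated frames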