From 8d47540bbfcca1b826644b8194a8c76b78534c2f Mon Sep 17 00:00:00 2001 From: ruoning Date: Thu, 29 Sep 2022 16:37:22 +0800 Subject: [PATCH 01/32] [feat]: add Instance-aware Image Colorization --- configs/insta/README.md | 65 + configs/insta/insta_full_cocostuff_256x256.py | 141 ++ .../insta/insta_fusion_cocostuff_256x256.py | 0 .../insta/insta_instance_cocostuff_256x256.py | 94 ++ demo/colorization_demo.py | 45 + mmedit/apis/__init__.py | 2 + mmedit/apis/colorization_inference.py | 26 + mmedit/datasets/coco.py | 84 + mmedit/datasets/transforms/__init__.py | 5 +- .../datasets/transforms/get_gray_color_pil.py | 29 + .../datasets/transforms/get_maskrcnn_bbox.py | 210 +++ mmedit/models/base_models/__init__.py | 4 +- .../models/base_models/base_colorization.py | 57 + mmedit/models/editors/__init__.py | 1 + mmedit/models/editors/insta/__init__.py | 6 + mmedit/models/editors/insta/insta.py | 454 +++++ mmedit/models/editors/insta/insta_net.py | 1497 +++++++++++++++++ mmedit/models/editors/insta/util.py | 261 +++ mmedit/models/losses/__init__.py | 5 +- mmedit/models/losses/pixelwise_loss.py | 17 + 20 files changed, 2999 insertions(+), 4 deletions(-) create mode 100644 configs/insta/README.md create mode 100644 configs/insta/insta_full_cocostuff_256x256.py create mode 100644 configs/insta/insta_fusion_cocostuff_256x256.py create mode 100644 configs/insta/insta_instance_cocostuff_256x256.py create mode 100644 demo/colorization_demo.py create mode 100644 mmedit/apis/colorization_inference.py create mode 100644 mmedit/datasets/coco.py create mode 100644 mmedit/datasets/transforms/get_gray_color_pil.py create mode 100644 mmedit/datasets/transforms/get_maskrcnn_bbox.py create mode 100644 mmedit/models/base_models/base_colorization.py create mode 100644 mmedit/models/editors/insta/__init__.py create mode 100644 mmedit/models/editors/insta/insta.py create mode 100644 mmedit/models/editors/insta/insta_net.py create mode 100644 mmedit/models/editors/insta/util.py diff --git a/configs/insta/README.md b/configs/insta/README.md new file mode 100644 index 0000000000..860d8a0e5f --- /dev/null +++ b/configs/insta/README.md @@ -0,0 +1,65 @@ +# Instance-aware Image Colorization (CVPR'2020) + +> **任务**: 图像上色 + +## 快速开始 + +**训练** + +
+训练说明 + +您可以使用以下命令来训练模型。 + +```shell +# CPU上训练 +CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/insta/insta_full_cocostuff_256x256.py + +# 单个GPU上训练 +python tools/train.py configs/insta/insta_full_cocostuff_256x256.py + +# 多个GPU上训练 +./tools/dist_train.sh configs/insta/insta_full_cocostuff_256x256.py 8 +``` + +更多细节可以参考 [train_test.md](/docs/zh_cn/user_guides/train_test.md) 中的 **Train a model** 部分。 + +
+ +**测试** + +
+测试说明 + +您可以使用以下命令来测试模型。 + +```shell +# CPU上测试 +CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/insta/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth + +# 单个GPU上测试 +python demo/colorization_demo.py configs/insta/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth + +# 多个GPU上测试 +./tools/dist_test.sh configs/insta/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth 8 +``` + +更多细节可以参考 [train_test.md](/docs/zh_cn/user_guides/train_test.md) 中的 **Test a pre-trained model** 部分。 + +
+ + +
+Instance-aware Image Colorization (CVPR'2020) + +```bibtex +@inproceedings{Su-CVPR-2020, + author = {Su, Jheng-Wei and Chu, Hung-Kuo and Huang, Jia-Bin}, + title = {Instance-aware Image Colorization}, + booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2020} +} +``` + +
+ diff --git a/configs/insta/insta_full_cocostuff_256x256.py b/configs/insta/insta_full_cocostuff_256x256.py new file mode 100644 index 0000000000..e25f2836c9 --- /dev/null +++ b/configs/insta/insta_full_cocostuff_256x256.py @@ -0,0 +1,141 @@ +_base_ = [ + '../_base_/default_runtime.py' +] + +exp_name = 'Instance-aware_full' +save_dir = './' +work_dir = '..' + +model = dict( + type='FusionModel', + data_preprocessor=dict( + type='EditDataPreprocessor', + mean=[127.5], + std=[127.5], + ), + instance_model=dict( + type='SIGGRAPHGenerator', + input_nc=4, + output_nc=2, + norm_type='batch' + ), + stage='full', + ngf=64, + output_nc=2, + avg_loss_alpha=.986, + ab_norm=110., + l_norm=100., + l_cent=50., + sample_Ps=[1, 2, 3, 4, 5, 6, 7, 8, 9], + mask_cent=.5, + init_type='normal', + which_direction='AtoB', + loss=dict(type='HuberLoss', delta=.01), + pretrained='./checkpoints/pytorch_trained.pth' +) + +input_shape = (256, 256) + +train_pipeline = [ + dict(type='LoadImageFromFile', key='gt_img'), + dict(type='GenGrayColorPil', stage='full', keys=['rgb_img', 'gray_img']), + dict( + type='Resize', + keys=['rgb_img', 'gray_img'], + scale=input_shape, + keep_ratio=False, + interpolation='nearest'), + dict(type='RescaleToZeroOne', keys=['rgb_img', 'gray_img']), + dict(type='PackEditInputs') +] + +test_pipeline = [ + dict(type='LoadImageFromFile', key='gt'), + dict(type='GenMaskRCNNBbox', stage='test_fusion', finesize=256), + dict(type='Resize', + keys=['gt'], + scale=(256, 256), + keep_ratio=False + ), + dict(type='PackEditInputs'), +] + +dataset_type = 'CocoDataset' +data_root = '/mnt/j/DataSet/cocostuff/train2017' +ann_file_path = '/mnt/j/DataSet/cocostuff/' + +train_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=False, + sampler=dict(shuffle=False), + workers_per_gpu=1, + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(gt='data_large'), + ann_file=f'{ann_file_path}/img_list.txt', + pipeline=train_pipeline, + test_mode=False)) + +test_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=False, + sampler=dict(type='DefaultSampler', shuffle=False), + workers_per_gpu=1, + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(gt='data_large'), + ann_file=f'{ann_file_path}/img_list.txt', + pipeline=test_pipeline, + test_mode=False)) + + +test_evaluator = [dict(type='PSNR'), dict(type='SSIM')] + +train_cfg = dict( + type='IterBasedTrainLoop', + max_iters=500002, + val_interval=50000, +) + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + + +# optimizer +optim_wrapper = dict( + constructor='DefaultOptimWrapperConstructor', + generator=dict( + type='OptimWrapper', + optimizer=dict(type='Adam', lr=0.0001, betas=(0.0, 0.9))), + disc=dict( + type='OptimWrapper', + optimizer=dict(type='Adam', lr=0.0001, betas=(0.0, 0.9)))) + +param_scheduler = dict( + # todo engine中暂时还没有这个 + type='LambdaLR', + by_epoch=False, +) + +vis_backends = [dict(type='LocalVisBackend')] + +visualizer = dict( + type='ConcatImageVisualizer', + vis_backends=vis_backends, + fn_key='gt_path', + img_keys=[ + 'gray', 'real', 'fake_reg', 'hint', 'real_ab', 'fake_ab_reg' + ], + bgr2rgb=False) + +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) + + diff --git a/configs/insta/insta_fusion_cocostuff_256x256.py b/configs/insta/insta_fusion_cocostuff_256x256.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/configs/insta/insta_instance_cocostuff_256x256.py b/configs/insta/insta_instance_cocostuff_256x256.py new file mode 100644 index 0000000000..f7712efdf9 --- /dev/null +++ b/configs/insta/insta_instance_cocostuff_256x256.py @@ -0,0 +1,94 @@ +ab_norm = 110. +model = dict( + type='FusionModel', + stage='instance', + ngf=64, + output_nc=2, + # avg_loss_alpha=.986, + ab_norm=ab_norm, + l_norm=100., + l_cent=50., + sample_Ps=[1, 2, 3, 4, 5, 6, 7, 8, 9], + mask_cent=.5, + init_type='normal', + fusion_weight_path='../checkpoints/coco_finetuned_mask_256_ffs', + which_direction='AtoB', + loss=dict(type='HuberLoss', delta=1. / ab_norm), + instance_model=dict( + type='SIGGRAPHGenerator', + input_nc=4, + output_nc=2, + )) + +train_cfg = dict(disc_step=1) +test_cfg = dict(metrics=['psnr', 'ssim']) +input_shape = (256, 256) + +train_pipeline = [ + dict(type='LoadImageFromFile', key='gt_img'), + dict(type='LoadBboxFromFile', key='instance', stage='instance'), + dict( + type='GenGrayColorPil', stage='instance', keys=['rgb_img', + 'gray_img']), + dict( + type='Resize', + keys=['rgb_img', 'gray_img'], + scale=input_shape, + keep_ratio=False, + interpolation='nearest'), + dict( + type='Collect', + keys=['instance', 'rgb_img', 'gray_img'], + meta_keys=['gt_img_path']), + dict(type='ImageToTensor', keys=['instance', 'rgb_img', 'gray_img']) +] + +dataset_type = 'COCOStuff_Instance_Dataset' +data_root = '/mnt/cache/share_data/zhangwenwei/data/coco/train2017' + +npz_root = '/mnt/cache/yuruoning.vendor/data' + +data = dict( + workers_per_gpu=2, + train_dataloader=dict(samples_per_gpu=1, drop_last=True), + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), + train=dict( + type=dataset_type, + ann_file=f'{npz_root}/img_list.txt', + data_prefix=data_root, + npz_prefix=f'{npz_root}/train_bbox/train2017_bbox', + pipeline=train_pipeline, + test_mode=False)) + +optimizers = dict(generator=dict(type='Adam', lr=0.0001, betas=(0.9, 0.999)), ) +lr_config = dict(policy='Fixed', by_epoch=False) + +checkpoint_config = dict(by_epoch=False, interval=10000) + +log_config = dict( + interval=100, + hooks=[ + dict(type='TextLoggerHook', by_epoch=False), + dict(type='TensorboardLoggerHook'), + ]) + +visual_config = dict( + type='VisualizationHook', + output_dir='visual', + interval=100, + bgr2rgb=False, + res_name_list=[ + 'gray', 'real', 'fake_reg', 'hint', 'real_ab', 'fake_ab_reg' + ], +) + +total_iters = 500002 +dist_params = dict(backend='nccl') +load_from = None +resume_from = None +work_dir = '..' +log_level = 'INFO' +workflow = [('train', 10000)] +exp_name = 'Instance-aware' +find_unused_parameters = True diff --git a/demo/colorization_demo.py b/demo/colorization_demo.py new file mode 100644 index 0000000000..1654895be2 --- /dev/null +++ b/demo/colorization_demo.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import mmcv +import torch + +from mmedit.apis import colorization_inference, init_model +from mmedit.utils import modify_args + + +def parse_args(): + modify_args() + parser = argparse.ArgumentParser(description='Colorzation demo') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoints', help='checkpoints file path') + parser.add_argument('img_path', help='path to input image file') + # parser.add_argument('bbox_path', help='path to input image bbox file') + parser.add_argument('save_path', help='path to save generation result') + parser.add_argument( + '--unpaired-path', default=None, help='path to unpaired image file') + parser.add_argument( + '--imshow', action='store_true', help='whether show image with opencv') + parser.add_argument('--device', type=int, default=0, help='CUDA device id') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.device < 0 or not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda', args.device) + + # + model = init_model(args.config, args.checkpoints, device=device) + output = colorization_inference(model, args.img_path, args.bbox_path) + + if args.imshow: + mmcv.imshow(output, 'predicted generation result') + + +if __name__ == '__main__': + main() diff --git a/mmedit/apis/__init__.py b/mmedit/apis/__init__.py index 63989da131..2f23b65cb1 100644 --- a/mmedit/apis/__init__.py +++ b/mmedit/apis/__init__.py @@ -8,6 +8,7 @@ from .restoration_video_inference import restoration_video_inference from .translation_inference import sample_img2img_model from .video_interpolation_inference import video_interpolation_inference +from .colorization_inference import colorization_inference __all__ = [ 'init_model', @@ -22,4 +23,5 @@ 'sample_conditional_model', 'sample_unconditional_model', 'sample_img2img_model', + 'colorization_inference' ] diff --git a/mmedit/apis/colorization_inference.py b/mmedit/apis/colorization_inference.py new file mode 100644 index 0000000000..8f1b05da67 --- /dev/null +++ b/mmedit/apis/colorization_inference.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmengine.dataset import Compose +from mmengine.dataset.utils import default_collate as collate +from torch.nn.parallel import scatter + + +def colorization_inference(model, img, bbox): + + device = next(model.parameters()).device + + # build the data pipeline + test_pipeline = Compose(model.cfg.test_pipeline) + # prepare data + data = dict(gt_path=img, bbox_path=bbox) + data = test_pipeline(data) + data = collate([data]) + + if 'cuda' in str(device): + data = scatter(data, [device])[0] + # forward the model + with torch.no_grad(): + result = model(mode='predict', **data) + + return result['fake_img'] diff --git a/mmedit/datasets/coco.py b/mmedit/datasets/coco.py new file mode 100644 index 0000000000..4390c43875 --- /dev/null +++ b/mmedit/datasets/coco.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Callable, Optional, Union +from pathlib import Path + +from mmengine.dataset import BaseDataset +from mmengine.fileio import load + +from mmedit.registry import DATASETS + + +@DATASETS.register_module() +class CocoDataset: + """Dataset for COCO.""" + + METAINFO = { + 'CLASSES': + ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', + 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', + 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', + 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', + 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush'), + # PALETTE is a list of color tuples, which is used for visualization. + 'PALETTE': + [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), + (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30), + (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), + (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255), + (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255), + (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92), + (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164), + (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0), + (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174), + (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54), + (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51), + (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65), + (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0), + (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161), + (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120), + (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133), + (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62), + (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45), + (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1), + (246, 0, 122), (191, 162, 208)] + } + + METAINFO = dict(dataset_type='colorization_dataset', task_name='colorization') + + def __init__( + self, + ann_file: str, + data_prefix + ): + + self.ann_file = str(ann_file) + self.data_prefix = data_prefix + self.data_infos = self.load_annotations() + + def load_annotations(self): + """Load annotations for dataset. + + Returns: + list[dict]: Contain dataset annotations. + """ + with open(self.ann_file, 'r') as f: + img_infos = [] + for idx, line in enumerate(f): + line = line.strip() + _info = dict() + img_path = line.split(' ')[0].split('/')[1] + _info = dict( + gt_img_path=Path( + self.data_prefix).joinpath(img_path).as_posix(), + gt_img_idx=idx) + img_infos.append(_info) + + return img_infos diff --git a/mmedit/datasets/transforms/__init__.py b/mmedit/datasets/transforms/__init__.py index f5eb2a02c6..2bb927de7a 100644 --- a/mmedit/datasets/transforms/__init__.py +++ b/mmedit/datasets/transforms/__init__.py @@ -28,6 +28,8 @@ from .trimap import (FormatTrimap, GenerateTrimap, GenerateTrimapWithDistTransform, TransformTrimap) from .values import CopyValues, SetValues +from .get_maskrcnn_bbox import GenMaskRCNNBbox +from .get_gray_color_pil import GenGrayColorPil __all__ = [ 'BinarizeImage', 'Clip', 'ColorJitter', 'CopyValues', 'Crop', 'CropLike', @@ -45,5 +47,6 @@ 'GenerateSoftSeg', 'FormatTrimap', 'TransformTrimap', 'GenerateTrimap', 'GenerateTrimapWithDistTransform', 'CompositeFg', 'RandomLoadResizeBg', 'MergeFgAndBg', 'PerturbBg', 'RandomJitter', 'LoadPairedImageFromFile', - 'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad' + 'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad', 'GenMaskRCNNBbox', + 'GenGrayColorPil' ] diff --git a/mmedit/datasets/transforms/get_gray_color_pil.py b/mmedit/datasets/transforms/get_gray_color_pil.py new file mode 100644 index 0000000000..91dd763c35 --- /dev/null +++ b/mmedit/datasets/transforms/get_gray_color_pil.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import cv2 +from mmcv.transforms.base import BaseTransform + +from mmedit.registry import TRANSFORMS + +@TRANSFORMS.register_module() +class GenGrayColorPil(BaseTransform): + + def __init__(self, stage, keys): + self.stage = stage + self.keys = keys + + def transform(self, results): + + if self.stage == 'instance': + rgb_img = results['instance'] + else: + rgb_img = results['gt_img'] + if len(rgb_img.shape) == 2: + rgb_img = np.stack([rgb_img, rgb_img, rgb_img], 2) + gray_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY) + gray_img = np.stack([gray_img, gray_img, gray_img], -1) + + results[self.keys[0]] = rgb_img + results[self.keys[1]] = gray_img + + return results \ No newline at end of file diff --git a/mmedit/datasets/transforms/get_maskrcnn_bbox.py b/mmedit/datasets/transforms/get_maskrcnn_bbox.py new file mode 100644 index 0000000000..493caf51ab --- /dev/null +++ b/mmedit/datasets/transforms/get_maskrcnn_bbox.py @@ -0,0 +1,210 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from random import sample + +import cv2 as cv +import numpy as np +import torch +import torchvision.transforms as transforms +from detectron2 import model_zoo +from detectron2.config import get_cfg +from detectron2.engine import DefaultPredictor +from PIL import Image +from skimage import color + +from mmedit.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class GenMaskRCNNBbox: + + def __init__(self, key='gt', stage='test_fusion', finesize=256): + self.key = key + self.predictor = self.detectron() + self.stage = stage + self.final_size = finesize + self.transforms = transforms.Compose([ + transforms.Resize((self.final_size, self.final_size), + interpolation=2), + transforms.ToTensor() + ]) + + def gen_maskrcnn_bbox_fromPred(self, + img, + bbox_path=None, + box_num_upbound=8): + ''' + ## Arguments: + - pred_data_path: Detectron2 predict results + - box_num_upbound: object bounding boxes number. + Default: -1 means use all the instances. + ''' + if bbox_path: + pred_data = np.load(bbox_path) + pred_bbox = pred_data['bbox'].astype(np.int32) + pred_scores = pred_data['scores'] + else: + lab_image = cv.cvtColor(img, cv.COLOR_BGR2LAB) + l_channel, a_channel, b_channel = cv.split(lab_image) + l_stack = np.stack([l_channel, l_channel, l_channel], axis=2) + outputs = self.predictor(l_stack) + pred_bbox = outputs['instances'].pred_boxes.to( + torch.device('cpu')).tensor.numpy() + pred_scores = outputs['instances'].scores.cpu().data.numpy() + + pred_bbox = pred_bbox.astype(np.int32) + if 0 < box_num_upbound < pred_bbox.shape[0]: + index_mask = np.argsort( + pred_scores, axis=0)[pred_scores.shape[0] - + box_num_upbound:pred_scores.shape[0]] + pred_bbox = pred_bbox[index_mask] + + return pred_bbox + + @staticmethod + def gen_gray_color_pil(rgb_img): + ''' + return: RGB and GRAY pillow image object + ''' + if len(np.asarray(rgb_img).shape) == 2: + rgb_img = np.stack([ + np.asarray(rgb_img), + np.asarray(rgb_img), + np.asarray(rgb_img) + ], 2) + rgb_img = Image.fromarray(rgb_img) + gray_img = np.round(color.rgb2gray(np.asarray(rgb_img)) * + 255.0).astype(np.uint8) + gray_img = np.stack([gray_img, gray_img, gray_img], -1) + gray_img = Image.fromarray(gray_img) + return rgb_img, gray_img + + @staticmethod + def read_to_pil(out_img): + ''' + return: pillow image object HxWx3 + ''' + out_img = Image.fromarray(out_img) + if len(np.asarray(out_img).shape) == 2: + out_img = np.stack([ + np.asarray(out_img), + np.asarray(out_img), + np.asarray(out_img) + ], 2) + out_img = Image.fromarray(out_img) + return out_img + + @staticmethod + def get_box_info(pred_bbox, original_shape, final_size): + assert len(pred_bbox) == 4 + resize_startx = int(pred_bbox[0] / original_shape[0] * final_size) + resize_starty = int(pred_bbox[1] / original_shape[1] * final_size) + resize_endx = int(pred_bbox[2] / original_shape[0] * final_size) + resize_endy = int(pred_bbox[3] / original_shape[1] * final_size) + rh = resize_endx - resize_startx + rw = resize_endy - resize_starty + if rh < 1: + if final_size - resize_endx > 1: + resize_endx += 1 + else: + resize_startx -= 1 + rh = 1 + if rw < 1: + if final_size - resize_endy > 1: + resize_endy += 1 + else: + resize_starty -= 1 + rw = 1 + L_pad = resize_startx + R_pad = final_size - resize_endx + T_pad = resize_starty + B_pad = final_size - resize_endy + return [L_pad, R_pad, T_pad, B_pad, rh, rw] + + def test_fusion(self, results): + img = results['gt'] + pil_img = self.read_to_pil(img) + if results['bbox_path']: + pred_bbox = self.gen_maskrcnn_bbox_fromPred( + img, results['bbox_path'], box_num_upbound=8) + else: + pred_bbox = self.gen_maskrcnn_bbox_fromPred(img, box_num_upbound=8) + + img_list = [self.transforms(pil_img)] # 这里删除了一个transform + + cropped_img_list = [] + index_list = range(len(pred_bbox)) + box_info, box_info_2x, box_info_4x, box_info_8x = np.zeros( + (4, len(index_list), 6)) + for i in index_list: + startx, starty, endx, endy = pred_bbox[i] + box_info[i] = np.array( + self.get_box_info(pred_bbox[i], pil_img.size, self.final_size)) + box_info_2x[i] = np.array( + self.get_box_info(pred_bbox[i], pil_img.size, + self.final_size // 2)) + box_info_4x[i] = np.array( + self.get_box_info(pred_bbox[i], pil_img.size, + self.final_size // 4)) + box_info_8x[i] = np.array( + self.get_box_info(pred_bbox[i], pil_img.size, + self.final_size // 8)) + cropped_img = self.transforms( + pil_img.crop((startx, starty, endx, endy))) + cropped_img_list.append(cropped_img) + + results['full_img'] = torch.stack(img_list) + # output['file_id'] = self.IMAGE_ID_LIST[index].split('.')[0] + if len(pred_bbox) > 0: + results['cropped_img'] = torch.stack(cropped_img_list) + results['box_info'] = torch.from_numpy(box_info).type(torch.long) + results['box_info_2x'] = torch.from_numpy(box_info_2x).type( + torch.long) + results['box_info_4x'] = torch.from_numpy(box_info_4x).type( + torch.long) + results['box_info_8x'] = torch.from_numpy(box_info_8x).type( + torch.long) + results['empty_box'] = False + else: + results['empty_box'] = True + print('full_img:', results['full_img'].size) + # print("cropped_img:", results['cropped_img'].size) + return results + + def train(self, results): + img = results[self.key] + if results['bbox_path']: + pred_bbox = self.gen_maskrcnn_bbox_fromPred( + img, results['bbox_path']) + else: + pred_bbox = self.gen_maskrcnn_bbox_fromPred(img) + rgb_img, gray_img = self.gen_gray_color_pil(img) + index_list = range(len(pred_bbox)) + index_list = sample(index_list, 1) + startx, starty, endx, endy = pred_bbox[index_list[0]] + + results['rgb_img'] = self.transforms( + rgb_img.crop((startx, starty, endx, endy))) + results['gray_img'] = self.transforms( + gray_img.crop((startx, starty, endx, endy))) + + return results + + def __call__(self, results): + + if self.stage == 'test_fusion': + results = self.test_fusion(results) + + if self.stage == 'train': + results = self.train(results) + + return results + + def detectron(self): + cfg = get_cfg() + cfg.merge_from_file( + model_zoo.get_config_file( + 'COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml')) + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 + cfg.MODEL.WEIGHTS = '/mnt/d/code/MMEditing/model_final_2d9806.pkl' + predictor = DefaultPredictor(cfg) + return predictor diff --git a/mmedit/models/base_models/__init__.py b/mmedit/models/base_models/__init__.py index ff1107fb67..60bcdafad3 100644 --- a/mmedit/models/base_models/__init__.py +++ b/mmedit/models/base_models/__init__.py @@ -8,9 +8,11 @@ from .basic_interpolator import BasicInterpolator from .one_stage import OneStageInpaintor from .two_stage import TwoStageInpaintor +from .base_colorization import BaseColorization __all__ = [ 'BaseEditModel', 'BaseGAN', 'BaseConditionalGAN', 'BaseMattor', 'BasicInterpolator', 'BaseTranslationModel', 'OneStageInpaintor', - 'TwoStageInpaintor', 'ExponentialMovingAverage', 'RampUpEMA' + 'TwoStageInpaintor', 'ExponentialMovingAverage', 'RampUpEMA', + 'BaseColorization' ] diff --git a/mmedit/models/base_models/base_colorization.py b/mmedit/models/base_models/base_colorization.py new file mode 100644 index 0000000000..37183bc18f --- /dev/null +++ b/mmedit/models/base_models/base_colorization.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torchvision.utils import save_image +from mmengine.model import BaseModel +from mmengine.config import Config, ConfigDict + +from mmedit.registry import MODELS + + +class BaseColorization(BaseModel, metaclass=ABCMeta): + + def __init__(self, + data_preprocessor: Union[dict, Config], + loss, + init_cfg: Optional[dict] = None, + train_cfg: Optional[dict] = None, + test_cfg: Optional[dict] = None): + + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + self.loss = MODELS.build(loss) + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[Union[list, torch.Tensor]] = None, + mode: str = 'tensor', + **kwargs): + + if mode == 'tensor': + return self.forward_tensor(inputs, data_samples, **kwargs) + + elif mode == 'predict': + predictions = self.forward_test(inputs, data_samples, **kwargs) + predictions = self.convert_to_datasample(data_samples, predictions) + return predictions + + elif mode == 'loss': + return self.forward_train(inputs, data_samples, **kwargs) + + def forward_train(self, *args, **kwargs): + pass + + def forward_test(self, input, data_samples, **kwargs): + pass + + def train_step(self, data_batch, optimizer): + pass + + def init_weights(self): + pass + + def save_visualization(self, img, filename): + save_image(img, filename) diff --git a/mmedit/models/editors/__init__.py b/mmedit/models/editors/__init__.py index a5dd2252b2..442538599b 100644 --- a/mmedit/models/editors/__init__.py +++ b/mmedit/models/editors/__init__.py @@ -50,6 +50,7 @@ from .tof import TOFlowVFINet, TOFlowVSRNet, ToFResBlock from .ttsr import LTE, TTSR, SearchTransformer, TTSRDiscriminator, TTSRNet from .wgan_gp import WGANGP +from .insta import INSTA __all__ = [ 'AOTEncoderDecoder', 'AOTBlockNeck', 'AOTInpaintor', diff --git a/mmedit/models/editors/insta/__init__.py b/mmedit/models/editors/insta/__init__.py new file mode 100644 index 0000000000..10286302b3 --- /dev/null +++ b/mmedit/models/editors/insta/__init__.py @@ -0,0 +1,6 @@ +from .insta import INSTA +from .insta_net import (SIGGRAPHGenerator, InstanceGenerator, FusionGenerator) + +__all__ = [ + 'INSTA', 'SIGGRAPHGenerator', 'InstanceGenerator', 'FusionGenerator' +] \ No newline at end of file diff --git a/mmedit/models/editors/insta/insta.py b/mmedit/models/editors/insta/insta.py new file mode 100644 index 0000000000..180dcd5e4e --- /dev/null +++ b/mmedit/models/editors/insta/insta.py @@ -0,0 +1,454 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from collections import OrderedDict +from typing import Union + +import torch +from mmengine.config import Config + +from mmedit.models.utils import generation_init_weights +from mmedit.models.base_models import BaseColorization +from mmedit.registry import BACKBONES, COMPONENTS + +from .util import encode_ab_ind, get_colorization_data, lab2rgb + + +@BACKBONES.register_module() +class INSTA(BaseColorization): + + def __init__(self, + data_preprocessor: Union[dict, Config], + ngf, + output_nc, + avg_loss_alpha, + ab_norm, + l_norm, + l_cent, + sample_Ps, + mask_cent, + stage=None, + which_direction='AtoB', + instance_model=None, + full_model=None, + fusion_model=None, + loss=None, + init_cfg=None, + train_cfg=None, + test_cfg=None): + super(INSTA, self).__init__( + data_preprocessor=data_preprocessor, + loss=loss, + init_cfg=init_cfg, + train_cfg=train_cfg, + test_cfg=test_cfg + ) + + self.ngf = ngf + self.output_nc = output_nc + self.avg_loss_alpha = avg_loss_alpha + self.ab_norm = ab_norm + self.l_norm = l_norm + self.l_cent = l_cent + self.sample_Ps = sample_Ps + self.mask_cent = mask_cent + self.which_direction = which_direction + + self.device = torch.device('cuda:{}'.format( + self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + + self.instance_model = instance_model + self.full_model = full_model + self.fusion_model = fusion_model + + self.stage = stage + + if self.stage is not None: + self.training = False + self.setup_to_train() + else: + self.setup_to_test() + + def set_input(self, input): + AtoB = self.which_direction == 'AtoB' + self.real_A = input['A' if AtoB else 'B'].to(self.device) + self.real_B = input['B' if AtoB else 'A'].to(self.device) + self.hint_B = input['hint_B'].to(self.device) + + self.mask_B = input['mask_B'].to(self.device) + self.mask_B_nc = self.mask_B + self.mask_cent + + self.real_B_enc = encode_ab_ind( + self.real_B[:, :, ::4, ::4], + ab_norm=self.ab_norm, + ab_max=self.ab_max, + ab_quant=self.ab_quant) + + def set_fusion_input(self, input, box_info): + AtoB = self.which_direction == 'AtoB' + self.full_real_A = input['A' if AtoB else 'B'].to(self.device) + self.full_real_B = input['B' if AtoB else 'A'].to(self.device) + + self.full_hint_B = input['hint_B'].to(self.device) + self.full_mask_B = input['mask_B'].to(self.device) + + self.full_mask_B_nc = self.full_mask_B + self.mask_cent + self.full_real_B_enc = encode_ab_ind( + self.full_real_B[:, :, ::4, ::4], + ab_norm=self.ab_norm, + ab_max=self.ab_max, + ab_quant=self.ab_quant) + self.box_info_list = box_info + + def set_forward_without_box(self, input): + AtoB = self.which_direction == 'AtoB' + self.full_real_A = input['A' if AtoB else 'B'].to(self.device) + self.full_real_B = input['B' if AtoB else 'A'].to(self.device) + # self.image_paths = input['A_paths' if AtoB else 'B_paths'] + self.full_hint_B = input['hint_B'].to(self.device) + self.full_mask_B = input['mask_B'].to(self.device) + self.full_mask_B_nc = self.full_mask_B + self.mask_cent + self.full_real_B_enc = encode_ab_ind(self.full_real_B[:, :, ::4, ::4], + self) + + (_, self.comp_B_reg) = self.netGComp(self.full_real_A, + self.full_hint_B, + self.full_mask_B) + self.fake_B_reg = self.comp_B_reg + + def generator_loss(self): + if self.stage == 'full' or self.stage == 'instance': + self.loss_L1 = torch.mean( + self.criterionL1( + self.fake_B_reg.type(torch.cuda.FloatTensor), + self.real_B.type(torch.cuda.FloatTensor))) + self.loss_G = 10 * torch.mean( + self.criterionL1( + self.fake_B_reg.type(torch.cuda.FloatTensor), + self.real_B.type(torch.cuda.FloatTensor))) + + elif self.stage == 'fusion': + self.loss_L1 = torch.mean( + self.criterionL1( + self.fake_B_reg.type(torch.cuda.FloatTensor), + self.full_real_B.type(torch.cuda.FloatTensor))) + self.loss_G = 10 * torch.mean( + self.criterionL1( + self.fake_B_reg.type(torch.cuda.FloatTensor), + self.full_real_B.type(torch.cuda.FloatTensor))) + else: + print('Error! Wrong stage selection!') + exit() + + self.error_cnt += 1 + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + # float(...) works for both scalar tensor and float number + self.avg_losses[name] = float(getattr( + self, 'loss_' + + name)) + self.avg_loss_alpha * self.avg_losses[name] + errors_ret[name] = (1 - self.avg_loss_alpha) / ( + 1 - self.avg_loss_alpha** + self.error_cnt) * self.avg_losses[name] + + return errors_ret + + def train_step(self, data_batch, optimizer): + + log_vars = {} + + colorization_data_opt = dict( + ab_thresh=0, + ab_norm=self.ab_norm, + l_norm=self.l_norm, + l_cent=self.l_cent, + sample_PS=self.sample_Ps, + mask_cent=self.mask_cent, + ) + + if self.stage == 'full' or self.stage == 'instance': + data_batch['rgb_img'] = [data_batch['rgb_img']] + data_batch['gray_img'] = [data_batch['gray_img']] + + input_data = get_colorization_data(data_batch['gray_img'], + **colorization_data_opt) + + gt_data = get_colorization_data(data_batch['rgb_img'], + **colorization_data_opt) + + input_data['B'] = gt_data['B'] + input_data['hint_B'] = gt_data['hint_B'] + input_data['mask_B'] = gt_data['mask_B'] + self.set_input(input_data) + (_, self.fake_B_reg) = self.netG(self.real_A, self.hint_B, + self.mask_B) + + elif self.stage == 'fusion': + + data_batch['cropped_rgb'] = torch.stack( + data_batch['cropped_rgb_list']) + data_batch['cropped_gray'] = torch.stack( + data_batch['cropped_gray_list']) + data_batch['full_rgb'] = torch.stack(data_batch['full_rgb_list']) + data_batch['full_gray'] = torch.stack(data_batch['full_gray_list']) + data_batch['box_info'] = torch.from_numpy( + data_batch['box_info']).type(torch.long) + data_batch['box_info_2x'] = torch.from_numpy( + data_batch['box_info_2x']).type(torch.long) + data_batch['box_info_4x'] = torch.from_numpy( + data_batch['box_info_4x']).type(torch.long) + data_batch['box_info_8x'] = torch.from_numpy( + data_batch['box_info_8x']).type(torch.long) + + box_info = data_batch['box_info'][0] + box_info_2x = data_batch['box_info_2x'][0] + box_info_4x = data_batch['box_info_4x'][0] + box_info_8x = data_batch['box_info_8x'][0] + + cropped_input_data = get_colorization_data( + data_batch['cropped_gray'], **colorization_data_opt) + cropped_gt_data = get_colorization_data(data_batch['cropped_rgb'], + **colorization_data_opt) + full_input_data = get_colorization_data(data_batch['full_gray'], + **colorization_data_opt) + full_gt_data = get_colorization_data(data_batch['full_rgb'], + **colorization_data_opt) + + cropped_input_data['B'] = cropped_gt_data['B'] + full_input_data['B'] = full_gt_data['B'] + + self.set_input(cropped_input_data) + self.set_fusion_input( + full_input_data, + [box_info, box_info_2x, box_info_4x, box_info_8x]) + + (_, self.comp_B_reg) = self.netGComp(self.full_real_A, + self.full_hint_B, + self.full_mask_B) + (_, feature_map) = self.netG(self.real_A, self.hint_B, self.mask_B) + self.fake_B_reg = self.netGF(self.full_real_A, self.full_hint_B, + self.full_mask_B, feature_map, + self.box_info_list) + + optimizer['generator'].zero_grad() + + loss = self.generator_loss() + + loss_d, log_vars_d = self.parse_losses(loss) + log_vars.update(log_vars_d) + + loss_d.backward() + + optimizer['generator'].step() + + results = self.get_current_visuals() + + output = dict( + log_vars=log_vars, + num_samples=len(data_batch['rgb_img']), + results=results) + + return output + + def setup_to_train(self): + + self.loss_names = ['G', 'L1'] + + if self.stage == 'full' or self.stage == 'instance': + self.model_names = ['G'] + self.netG = COMPONENTS.build(self.instance_model) + generation_init_weights(self.netG) + self.generator = self.netG + + elif self.stage == 'fusion': + self.model_names = ['G', 'GF', 'GComp'] + self.netG = COMPONENTS.build(self.instance_model) + generation_init_weights(self.netG) + self.netG.eval() + + self.netGF = COMPONENTS.build(self.fusion_model) + generation_init_weights(self.netGF) + self.netGF.eval() + + self.netGComp = COMPONENTS.build(self.full_model) + generation_init_weights(self.netGComp) + self.netGComp.eval() + + self.generator = \ + list(self.netGF.module.weight_layer.parameters()) + \ + list(self.netGF.module.weight_layer2.parameters()) + \ + list(self.netGF.module.weight_layer3.parameters()) + \ + list(self.netGF.module.weight_layer4.parameters()) + \ + list(self.netGF.module.weight_layer5.parameters()) + \ + list(self.netGF.module.weight_layer6.parameters()) + \ + list(self.netGF.module.weight_layer7.parameters()) + \ + list(self.netGF.module.weight_layer8_1.parameters()) + \ + list(self.netGF.module.weight_layer8_2.parameters()) + \ + list(self.netGF.module.weight_layer9_1.parameters()) + \ + list(self.netGF.module.weight_layer9_2.parameters()) + \ + list(self.netGF.module.weight_layer10_1.parameters()) + \ + list(self.netGF.module.weight_layer10_2.parameters()) + \ + list(self.netGF.module.model10.parameters()) + \ + list(self.netGF.module.model_out.parameters()) + + else: + print('Error Stage!') + exit() + + self.criterionL1 = self.loss + + # initialize average loss values + self.avg_losses = OrderedDict() + # self.avg_loss_alpha = self.avg_loss_alpha + self.error_cnt = 0 + for loss_name in self.loss_names: + self.avg_losses[loss_name] = 0 + + def get_current_visuals(self): + from collections import OrderedDict + visual_ret = OrderedDict() + opt = dict( + ab_norm=self.ab_norm, l_norm=self.l_norm, l_cent=self.l_cent) + if self.stage == 'full' or self.stage == 'instance': + + visual_ret['gray'] = lab2rgb( + torch.cat((self.real_A.type( + torch.cuda.FloatTensor), torch.zeros_like( + self.real_B).type(torch.cuda.FloatTensor)), + dim=1), **opt) + visual_ret['real'] = lab2rgb( + torch.cat((self.real_A.type(torch.cuda.FloatTensor), + self.real_B.type(torch.cuda.FloatTensor)), + dim=1), **opt) + visual_ret['fake_reg'] = lab2rgb( + torch.cat((self.real_A.type(torch.cuda.FloatTensor), + self.fake_B_reg.type(torch.cuda.FloatTensor)), + dim=1), **opt) + + visual_ret['hint'] = lab2rgb( + torch.cat((self.real_A.type(torch.cuda.FloatTensor), + self.hint_B.type(torch.cuda.FloatTensor)), + dim=1), **opt) + visual_ret['real_ab'] = lab2rgb( + torch.cat((torch.zeros_like( + self.real_A.type(torch.cuda.FloatTensor)), + self.real_B.type(torch.cuda.FloatTensor)), + dim=1), **opt) + visual_ret['fake_ab_reg'] = lab2rgb( + torch.cat((torch.zeros_like( + self.real_A.type(torch.cuda.FloatTensor)), + self.fake_B_reg.type(torch.cuda.FloatTensor)), + dim=1), **opt) + + elif self.stage == 'fusion': + visual_ret['gray'] = lab2rgb( + torch.cat((self.full_real_A.type( + torch.cuda.FloatTensor), torch.zeros_like( + self.full_real_B).type(torch.cuda.FloatTensor)), + dim=1), **opt) + visual_ret['real'] = lab2rgb( + torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), + self.full_real_B.type(torch.cuda.FloatTensor)), + dim=1), **opt) + visual_ret['comp_reg'] = lab2rgb( + torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), + self.comp_B_reg.type(torch.cuda.FloatTensor)), + dim=1), + ab_norm=self.ab_norm, + l_norm=self.l_norm, + l_cent=self.l_cent) + visual_ret['fake_reg'] = lab2rgb( + torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), + self.fake_B_reg.type(torch.cuda.FloatTensor)), + dim=1), **opt) + + self.instance_mask = torch.nn.functional.interpolate( + torch.zeros([1, 1, 176, 176]), + size=visual_ret['gray'].shape[2:], + mode='bilinear').type(torch.cuda.FloatTensor) + visual_ret['box_mask'] = torch.cat( + (self.instance_mask, self.instance_mask, self.instance_mask), + 1) + visual_ret['real_ab'] = lab2rgb( + torch.cat((torch.zeros_like( + self.full_real_A.type(torch.cuda.FloatTensor)), + self.full_real_B.type(torch.cuda.FloatTensor)), + dim=1), **opt) + visual_ret['comp_ab_reg'] = lab2rgb( + torch.cat((torch.zeros_like( + self.full_real_A.type(torch.cuda.FloatTensor)), + self.comp_B_reg.type(torch.cuda.FloatTensor)), + dim=1), **opt) + visual_ret['fake_ab_reg'] = lab2rgb( + torch.cat((torch.zeros_like( + self.full_real_A.type(torch.cuda.FloatTensor)), + self.fake_B_reg.type(torch.cuda.FloatTensor)), + dim=1), **opt) + else: + print('Error! Wrong stage selection!') + exit() + return visual_ret + + def forward_test(self, **kwargs): + output = dict() + kwargs['full_img'][0] = kwargs['full_img'][0].cuda() + if not kwargs['empty_box']: + kwargs['cropped_img'][0] = kwargs['cropped_img'][0].cuda() + box_info = kwargs['box_info'][0] + box_info_2x = kwargs['box_info_2x'][0] + box_info_4x = kwargs['box_info_4x'][0] + box_info_8x = kwargs['box_info_8x'][0] + cropped_data = get_colorization_data( + kwargs['cropped_img'], + ab_thresh=0, + ab_norm=self.ab_norm, + l_norm=self.l_norm, + l_cent=self.l_cent, + sample_PS=self.sample_Ps, + mask_cent=self.mask_cent, + ) + full_img_data = get_colorization_data( + kwargs['full_img'], + ab_thresh=0, + ab_norm=self.ab_norm, + l_norm=self.l_norm, + l_cent=self.l_cent, + sample_PS=self.sample_Ps, + mask_cent=self.mask_cent, + ) + self.set_input(cropped_data) + self.set_fusion_input( + full_img_data, + [box_info, box_info_2x, box_info_4x, box_info_8x]) + else: + full_img_data = get_colorization_data( + kwargs['full_img'], ab_thresh=0) + self.set_forward_without_box(full_img_data) + + (_, feature_map) = self.netG(self.real_A, self.hint_B, self.mask_B) + self.fake_B_reg = self.netGF(self.full_real_A, self.full_hint_B, + self.full_mask_B, feature_map, + self.box_info_list) + + out_img = torch.clamp( + lab2rgb( + torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), + self.fake_B_reg.type(torch.cuda.FloatTensor)), + dim=1), + ab_norm=self.ab_norm, + l_norm=self.l_norm, + l_cent=self.l_cent), 0.0, 1.0) + + output['fake_img'] = out_img + output['meta'] = None if 'meta' not in kwargs else kwargs['meta'][0] + + self.save_visualization(out_img, + '/mnt/ruoning/results/output_mmedit11.png') + return output + + def setup_to_test(self): + self.netG = COMPONENTS.build(self.instance_model) + generation_init_weights(self.netG, self.init_type) + + self.netGF = COMPONENTS.build(self.fusion_model) + generation_init_weights(self.netGF, self.init_type) diff --git a/mmedit/models/editors/insta/insta_net.py b/mmedit/models/editors/insta/insta_net.py new file mode 100644 index 0000000000..2dc207165a --- /dev/null +++ b/mmedit/models/editors/insta/insta_net.py @@ -0,0 +1,1497 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import torch +import torch.nn as nn + +from mmedit.registry import BACKBONES + + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + elif norm_type == 'none': + norm_layer = None + else: + raise NotImplementedError('normalization layer [%s] is not found' % + norm_type) + return norm_layer + + +@BACKBONES.register_module() +class SIGGRAPHGenerator(nn.Module): + + def __init__(self, + input_nc, + output_nc, + norm_type, + use_tanh=True, + classification=True): + super(SIGGRAPHGenerator, self).__init__() + self.input_nc = input_nc + self.output_nc = output_nc + self.classification = classification + + norm_layer = get_norm_layer(norm_type) + + use_bias = True + + # Conv1 + # model1=[nn.ReflectionPad2d(1),] + model1 = [ + nn.Conv2d( + input_nc, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias), + ] + # model1+=[norm_layer(64),] + model1 += [ + nn.ReLU(True), + ] + # model1+=[nn.ReflectionPad2d(1),] + model1 += [ + nn.Conv2d( + 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model1 += [ + nn.ReLU(True), + ] + model1 += [ + norm_layer(64), + ] + # add a subsampling operation + + # Conv2 + # model2=[nn.ReflectionPad2d(1),] + model2 = [ + nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model2+=[norm_layer(128),] + model2 += [ + nn.ReLU(True), + ] + # model2+=[nn.ReflectionPad2d(1),] + model2 += [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model2 += [ + nn.ReLU(True), + ] + model2 += [ + norm_layer(128), + ] + # add a subsampling layer operation + + # Conv3 + # model3=[nn.ReflectionPad2d(1),] + model3 = [ + nn.Conv2d( + 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model3+=[norm_layer(256),] + model3 += [ + nn.ReLU(True), + ] + # model3+=[nn.ReflectionPad2d(1),] + model3 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model3+=[norm_layer(256),] + model3 += [ + nn.ReLU(True), + ] + # model3+=[nn.ReflectionPad2d(1),] + model3 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model3 += [ + nn.ReLU(True), + ] + model3 += [ + norm_layer(256), + ] + # add a subsampling layer operation + + # Conv4 + # model47=[nn.ReflectionPad2d(1),] + model4 = [ + nn.Conv2d( + 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model4+=[norm_layer(512),] + model4 += [ + nn.ReLU(True), + ] + # model4+=[nn.ReflectionPad2d(1),] + model4 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model4+=[norm_layer(512),] + model4 += [ + nn.ReLU(True), + ] + # model4+=[nn.ReflectionPad2d(1),] + model4 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model4 += [ + nn.ReLU(True), + ] + model4 += [ + norm_layer(512), + ] + + # Conv5 + # model47+=[nn.ReflectionPad2d(2),] + model5 = [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model5+=[norm_layer(512),] + model5 += [ + nn.ReLU(True), + ] + # model5+=[nn.ReflectionPad2d(2),] + model5 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model5+=[norm_layer(512),] + model5 += [ + nn.ReLU(True), + ] + # model5+=[nn.ReflectionPad2d(2),] + model5 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + model5 += [ + nn.ReLU(True), + ] + model5 += [ + norm_layer(512), + ] + + # Conv6 + # model6+=[nn.ReflectionPad2d(2),] + model6 = [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model6+=[norm_layer(512),] + model6 += [ + nn.ReLU(True), + ] + # model6+=[nn.ReflectionPad2d(2),] + model6 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model6+=[norm_layer(512),] + model6 += [ + nn.ReLU(True), + ] + # model6+=[nn.ReflectionPad2d(2),] + model6 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + model6 += [ + nn.ReLU(True), + ] + model6 += [ + norm_layer(512), + ] + + # Conv7 + # model47+=[nn.ReflectionPad2d(1),] + model7 = [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model7+=[norm_layer(512),] + model7 += [ + nn.ReLU(True), + ] + # model7+=[nn.ReflectionPad2d(1),] + model7 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model7+=[norm_layer(512),] + model7 += [ + nn.ReLU(True), + ] + # model7+=[nn.ReflectionPad2d(1),] + model7 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model7 += [ + nn.ReLU(True), + ] + model7 += [ + norm_layer(512), + ] + + # Conv7 + model8up = [ + nn.ConvTranspose2d( + 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) + ] + + # model3short8=[nn.ReflectionPad2d(1),] + model3short8 = [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + + # model47+=[norm_layer(256),] + model8 = [ + nn.ReLU(True), + ] + # model8+=[nn.ReflectionPad2d(1),] + model8 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model8+=[norm_layer(256),] + model8 += [ + nn.ReLU(True), + ] + # model8+=[nn.ReflectionPad2d(1),] + model8 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model8 += [ + nn.ReLU(True), + ] + model8 += [ + norm_layer(256), + ] + + # Conv9 + model9up = [ + nn.ConvTranspose2d( + 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), + ] + + # model2short9=[nn.ReflectionPad2d(1),] + model2short9 = [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # add the two feature maps above + + # model9=[norm_layer(128),] + model9 = [ + nn.ReLU(True), + ] + # model9+=[nn.ReflectionPad2d(1),] + model9 += [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model9 += [ + nn.ReLU(True), + ] + model9 += [ + norm_layer(128), + ] + + # Conv10 + model10up = [ + nn.ConvTranspose2d( + 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), + ] + + # model1short10=[nn.ReflectionPad2d(1),] + model1short10 = [ + nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # add the two feature maps above + + # model10=[norm_layer(128),] + model10 = [ + nn.ReLU(True), + ] + # model10+=[nn.ReflectionPad2d(1),] + model10 += [ + nn.Conv2d( + 128, + 128, + kernel_size=3, + dilation=1, + stride=1, + padding=1, + bias=use_bias), + ] + model10 += [ + nn.LeakyReLU(negative_slope=.2), + ] + + # classification output + model_class = [ + nn.Conv2d( + 256, + 529, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias), + ] + + # regression output + model_out = [ + nn.Conv2d( + 128, + 2, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias), + ] + if (use_tanh): + model_out += [nn.Tanh()] + + self.model1 = nn.Sequential(*model1) + self.model2 = nn.Sequential(*model2) + self.model3 = nn.Sequential(*model3) + self.model4 = nn.Sequential(*model4) + self.model5 = nn.Sequential(*model5) + self.model6 = nn.Sequential(*model6) + self.model7 = nn.Sequential(*model7) + self.model8up = nn.Sequential(*model8up) + self.model8 = nn.Sequential(*model8) + self.model9up = nn.Sequential(*model9up) + self.model9 = nn.Sequential(*model9) + self.model10up = nn.Sequential(*model10up) + self.model10 = nn.Sequential(*model10) + self.model3short8 = nn.Sequential(*model3short8) + self.model2short9 = nn.Sequential(*model2short9) + self.model1short10 = nn.Sequential(*model1short10) + + self.model_class = nn.Sequential(*model_class) + self.model_out = nn.Sequential(*model_out) + + self.upsample4 = nn.Sequential(*[ + nn.Upsample(scale_factor=4, mode='nearest'), + ]) + self.softmax = nn.Sequential(*[ + nn.Softmax(dim=1), + ]) + + def forward(self, input_A, input_B, mask_B): + conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1)) + conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) + conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) + conv4_3 = self.model4(conv3_3[:, :, ::2, ::2]) + conv5_3 = self.model5(conv4_3) + conv6_3 = self.model6(conv5_3) + conv7_3 = self.model7(conv6_3) + conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3) + conv8_3 = self.model8(conv8_up) + + if (self.classification): + out_class = self.model_class(conv8_3) + conv9_up = self.model9up(conv8_3.detach()) + self.model2short9( + conv2_2.detach()) + conv9_3 = self.model9(conv9_up) + conv10_up = self.model10up(conv9_3) + self.model1short10( + conv1_2.detach()) + conv10_2 = self.model10(conv10_up) + out_reg = self.model_out(conv10_2) + else: + out_class = self.model_class(conv8_3.detach()) + + conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) + conv9_3 = self.model9(conv9_up) + conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) + conv10_2 = self.model10(conv10_up) + out_reg = self.model_out(conv10_2) + + return (out_class, out_reg) + + +@BACKBONES.register_module() +class FusionGenerator(nn.Module): + + def __init__(self, + input_nc, + output_nc, + norm_type, + use_tanh=True, + classification=True): + super(FusionGenerator, self).__init__() + self.input_nc = input_nc + self.output_nc = output_nc + self.classification = classification + use_bias = True + + norm_layer = get_norm_layer(norm_type) + + # Conv1 + # model1=[nn.ReflectionPad2d(1),] + model1 = [ + nn.Conv2d( + input_nc, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias), + ] + # model1+=[norm_layer(64),] + model1 += [ + nn.ReLU(True), + ] + # model1+=[nn.ReflectionPad2d(1),] + model1 += [ + nn.Conv2d( + 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model1 += [ + nn.ReLU(True), + ] + model1 += [ + norm_layer(64), + ] + # add a subsampling operation + + self.weight_layer = WeightGenerator(64) + + # Conv2 + # model2=[nn.ReflectionPad2d(1),] + model2 = [ + nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model2+=[norm_layer(128),] + model2 += [ + nn.ReLU(True), + ] + # model2+=[nn.ReflectionPad2d(1),] + model2 += [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model2 += [ + nn.ReLU(True), + ] + model2 += [ + norm_layer(128), + ] + # add a subsampling layer operation + + self.weight_layer2 = WeightGenerator(128) + + # Conv3 + # model3=[nn.ReflectionPad2d(1),] + model3 = [ + nn.Conv2d( + 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model3+=[norm_layer(256),] + model3 += [ + nn.ReLU(True), + ] + # model3+=[nn.ReflectionPad2d(1),] + model3 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model3+=[norm_layer(256),] + model3 += [ + nn.ReLU(True), + ] + # model3+=[nn.ReflectionPad2d(1),] + model3 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model3 += [ + nn.ReLU(True), + ] + model3 += [ + norm_layer(256), + ] + # add a subsampling layer operation + + self.weight_layer3 = WeightGenerator(256) + + # Conv4 + # model47=[nn.ReflectionPad2d(1),] + model4 = [ + nn.Conv2d( + 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model4+=[norm_layer(512),] + model4 += [ + nn.ReLU(True), + ] + # model4+=[nn.ReflectionPad2d(1),] + model4 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model4+=[norm_layer(512),] + model4 += [ + nn.ReLU(True), + ] + # model4+=[nn.ReflectionPad2d(1),] + model4 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model4 += [ + nn.ReLU(True), + ] + model4 += [ + norm_layer(512), + ] + + self.weight_layer4 = WeightGenerator(512) + + # Conv5 + # model47+=[nn.ReflectionPad2d(2),] + model5 = [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model5+=[norm_layer(512),] + model5 += [ + nn.ReLU(True), + ] + # model5+=[nn.ReflectionPad2d(2),] + model5 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model5+=[norm_layer(512),] + model5 += [ + nn.ReLU(True), + ] + # model5+=[nn.ReflectionPad2d(2),] + model5 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + model5 += [ + nn.ReLU(True), + ] + model5 += [ + norm_layer(512), + ] + + self.weight_layer5 = WeightGenerator(512) + + # Conv6 + # model6+=[nn.ReflectionPad2d(2),] + model6 = [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model6+=[norm_layer(512),] + model6 += [ + nn.ReLU(True), + ] + # model6+=[nn.ReflectionPad2d(2),] + model6 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model6+=[norm_layer(512),] + model6 += [ + nn.ReLU(True), + ] + # model6+=[nn.ReflectionPad2d(2),] + model6 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + model6 += [ + nn.ReLU(True), + ] + model6 += [ + norm_layer(512), + ] + + self.weight_layer6 = WeightGenerator(512) + + # Conv7 + # model47+=[nn.ReflectionPad2d(1),] + model7 = [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model7+=[norm_layer(512),] + model7 += [ + nn.ReLU(True), + ] + # model7+=[nn.ReflectionPad2d(1),] + model7 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model7+=[norm_layer(512),] + model7 += [ + nn.ReLU(True), + ] + # model7+=[nn.ReflectionPad2d(1),] + model7 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model7 += [ + nn.ReLU(True), + ] + model7 += [ + norm_layer(512), + ] + + self.weight_layer7 = WeightGenerator(512) + + # Conv7 + model8up = [ + nn.ConvTranspose2d( + 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) + ] + + # model3short8=[nn.ReflectionPad2d(1),] + model3short8 = [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + + self.weight_layer8_1 = WeightGenerator(256) + + # model47+=[norm_layer(256),] + model8 = [ + nn.ReLU(True), + ] + # model8+=[nn.ReflectionPad2d(1),] + model8 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model8+=[norm_layer(256),] + model8 += [ + nn.ReLU(True), + ] + # model8+=[nn.ReflectionPad2d(1),] + model8 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model8 += [ + nn.ReLU(True), + ] + model8 += [ + norm_layer(256), + ] + + self.weight_layer8_2 = WeightGenerator(256) + + # Conv9 + model9up = [ + nn.ConvTranspose2d( + 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), + ] + + # model2short9=[nn.ReflectionPad2d(1),] + model2short9 = [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # add the two feature maps above + + self.weight_layer9_1 = WeightGenerator(128) + + # model9=[norm_layer(128),] + model9 = [ + nn.ReLU(True), + ] + # model9+=[nn.ReflectionPad2d(1),] + model9 += [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model9 += [ + nn.ReLU(True), + ] + model9 += [ + norm_layer(128), + ] + + self.weight_layer9_2 = WeightGenerator(128) + + # Conv10 + model10up = [ + nn.ConvTranspose2d( + 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), + ] + + # model1short10=[nn.ReflectionPad2d(1),] + model1short10 = [ + nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # add the two feature maps above + + self.weight_layer10_1 = WeightGenerator(128) + + # model10=[norm_layer(128),] + model10 = [ + nn.ReLU(True), + ] + # model10+=[nn.ReflectionPad2d(1),] + model10 += [ + nn.Conv2d( + 128, + 128, + kernel_size=3, + dilation=1, + stride=1, + padding=1, + bias=use_bias), + ] + model10 += [ + nn.LeakyReLU(negative_slope=.2), + ] + + self.weight_layer10_2 = WeightGenerator(128) + + # classification output + model_class = [ + nn.Conv2d( + 256, + 529, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias), + ] + + # regression output + model_out = [ + nn.Conv2d( + 128, + 2, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias), + ] + if (use_tanh): + model_out += [nn.Tanh()] + + self.weight_layerout = WeightGenerator(2) + + self.model1 = nn.Sequential(*model1) + self.model2 = nn.Sequential(*model2) + self.model3 = nn.Sequential(*model3) + self.model4 = nn.Sequential(*model4) + self.model5 = nn.Sequential(*model5) + self.model6 = nn.Sequential(*model6) + self.model7 = nn.Sequential(*model7) + self.model8up = nn.Sequential(*model8up) + self.model8 = nn.Sequential(*model8) + self.model9up = nn.Sequential(*model9up) + self.model9 = nn.Sequential(*model9) + self.model10up = nn.Sequential(*model10up) + self.model10 = nn.Sequential(*model10) + self.model3short8 = nn.Sequential(*model3short8) + self.model2short9 = nn.Sequential(*model2short9) + self.model1short10 = nn.Sequential(*model1short10) + + self.model_class = nn.Sequential(*model_class) + self.model_out = nn.Sequential(*model_out) + + self.upsample4 = nn.Sequential(*[ + nn.Upsample(scale_factor=4, mode='nearest'), + ]) + self.softmax = nn.Sequential(*[ + nn.Softmax(dim=1), + ]) + + def forward(self, input_A, input_B, mask_B, instance_feature, + box_info_list): + conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1)) + conv1_2 = self.weight_layer(instance_feature['conv1_2'], conv1_2, + box_info_list[0]) + + conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) + conv2_2 = self.weight_layer2(instance_feature['conv2_2'], conv2_2, + box_info_list[1]) + + conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) + conv3_3 = self.weight_layer3(instance_feature['conv3_3'], conv3_3, + box_info_list[2]) + + conv4_3 = self.model4(conv3_3[:, :, ::2, ::2]) + conv4_3 = self.weight_layer4(instance_feature['conv4_3'], conv4_3, + box_info_list[3]) + + conv5_3 = self.model5(conv4_3) + conv5_3 = self.weight_layer5(instance_feature['conv5_3'], conv5_3, + box_info_list[3]) + + conv6_3 = self.model6(conv5_3) + conv6_3 = self.weight_layer6(instance_feature['conv6_3'], conv6_3, + box_info_list[3]) + + conv7_3 = self.model7(conv6_3) + conv7_3 = self.weight_layer7(instance_feature['conv7_3'], conv7_3, + box_info_list[3]) + + conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3) + conv8_up = self.weight_layer8_1(instance_feature['conv8_up'], conv8_up, + box_info_list[2]) + + conv8_3 = self.model8(conv8_up) + conv8_3 = self.weight_layer8_2(instance_feature['conv8_3'], conv8_3, + box_info_list[2]) + + conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) + conv9_up = self.weight_layer9_1(instance_feature['conv9_up'], conv9_up, + box_info_list[1]) + + conv9_3 = self.model9(conv9_up) + conv9_3 = self.weight_layer9_2(instance_feature['conv9_3'], conv9_3, + box_info_list[1]) + + conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) + conv10_up = self.weight_layer10_1(instance_feature['conv10_up'], + conv10_up, box_info_list[0]) + + conv10_2 = self.model10(conv10_up) + conv10_2 = self.weight_layer10_2(instance_feature['conv10_2'], + conv10_2, box_info_list[0]) + + out_reg = self.model_out(conv10_2) + return out_reg + + +class WeightGenerator(nn.Module): + + def __init__(self, input_ch, inner_ch=16): + super(WeightGenerator, self).__init__() + self.simple_instance_conv = nn.Sequential( + nn.Conv2d(input_ch, inner_ch, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(inner_ch, inner_ch, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(inner_ch, 1, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + ) + + self.simple_bg_conv = nn.Sequential( + nn.Conv2d(input_ch, inner_ch, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(inner_ch, inner_ch, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(inner_ch, 1, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + ) + + self.normalize = nn.Softmax(1) + + def resize_and_pad(self, feauture_maps, info_array): + feauture_maps = torch.nn.functional.interpolate( + feauture_maps, + size=(info_array[5], info_array[4]), + mode='bilinear') + feauture_maps = torch.nn.functional.pad(feauture_maps, + (info_array[0], info_array[1], + info_array[2], info_array[3]), + 'constant', 0) + return feauture_maps + + def forward(self, instance_feature, bg_feature, box_info): + mask_list = [] + featur_map_list = [] + mask_sum_for_pred = torch.zeros_like(bg_feature)[:1, :1] + for i in range(instance_feature.shape[0]): + tmp_crop = torch.unsqueeze(instance_feature[i], 0) + conv_tmp_crop = self.simple_instance_conv(tmp_crop) + pred_mask = self.resize_and_pad(conv_tmp_crop, box_info[i]) + + tmp_crop = self.resize_and_pad(tmp_crop, box_info[i]) + + mask = torch.zeros_like(bg_feature)[:1, :1] + mask[0, 0, box_info[i][2]:box_info[i][2] + box_info[i][5], + box_info[i][0]:box_info[i][0] + box_info[i][4]] = 1.0 + device = mask.device + mask = mask.type(torch.FloatTensor).to(device) + + mask_sum_for_pred = torch.clamp(mask_sum_for_pred + mask, 0.0, 1.0) + + mask_list.append(pred_mask) + featur_map_list.append(tmp_crop) + + pred_bg_mask = self.simple_bg_conv(bg_feature) + mask_list.append(pred_bg_mask + (1 - mask_sum_for_pred) * 100000.0) + mask_list = self.normalize(torch.cat(mask_list, 1)) + + mask_list_maskout = mask_list.clone() + + # instance_mask = torch.clamp( + # torch.sum( + # mask_list_maskout[:, :instance_feature.shape[0]], + # 1, + # keepdim=True), 0.0, 1.0) + + featur_map_list.append(bg_feature) + featur_map_list = torch.cat(featur_map_list, 0) + mask_list_maskout = mask_list_maskout.permute(1, 0, 2, 3).contiguous() + out = featur_map_list * mask_list_maskout + out = torch.sum(out, 0, keepdim=True) + return out # , instance_mask, torch.clamp(mask_list, 0.0, 1.0) + + +@BACKBONES.register_module() +class InstanceGenerator(nn.Module): + + def __init__(self, + input_nc, + output_nc, + norm_type, + use_tanh=True, + classification=True): + super(InstanceGenerator, self).__init__() + self.input_nc = input_nc + self.output_nc = output_nc + self.classification = classification + use_bias = True + + norm_layer = get_norm_layer(norm_type) + + # Conv1 + # model1=[nn.ReflectionPad2d(1),] + model1 = [ + nn.Conv2d( + input_nc, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias), + ] + # model1+=[norm_layer(64),] + model1 += [ + nn.ReLU(True), + ] + # model1+=[nn.ReflectionPad2d(1),] + model1 += [ + nn.Conv2d( + 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model1 += [ + nn.ReLU(True), + ] + model1 += [ + norm_layer(64), + ] + # add a subsampling operation + + # Conv2 + # model2=[nn.ReflectionPad2d(1),] + model2 = [ + nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model2+=[norm_layer(128),] + model2 += [ + nn.ReLU(True), + ] + # model2+=[nn.ReflectionPad2d(1),] + model2 += [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model2 += [ + nn.ReLU(True), + ] + model2 += [ + norm_layer(128), + ] + # add a subsampling layer operation + + # Conv3 + # model3=[nn.ReflectionPad2d(1),] + model3 = [ + nn.Conv2d( + 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model3+=[norm_layer(256),] + model3 += [ + nn.ReLU(True), + ] + # model3+=[nn.ReflectionPad2d(1),] + model3 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model3+=[norm_layer(256),] + model3 += [ + nn.ReLU(True), + ] + # model3+=[nn.ReflectionPad2d(1),] + model3 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model3 += [ + nn.ReLU(True), + ] + model3 += [ + norm_layer(256), + ] + # add a subsampling layer operation + + # Conv4 + # model47=[nn.ReflectionPad2d(1),] + model4 = [ + nn.Conv2d( + 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model4+=[norm_layer(512),] + model4 += [ + nn.ReLU(True), + ] + # model4+=[nn.ReflectionPad2d(1),] + model4 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model4+=[norm_layer(512),] + model4 += [ + nn.ReLU(True), + ] + # model4+=[nn.ReflectionPad2d(1),] + model4 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model4 += [ + nn.ReLU(True), + ] + model4 += [ + norm_layer(512), + ] + + # Conv5 + # model47+=[nn.ReflectionPad2d(2),] + model5 = [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model5+=[norm_layer(512),] + model5 += [ + nn.ReLU(True), + ] + # model5+=[nn.ReflectionPad2d(2),] + model5 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model5+=[norm_layer(512),] + model5 += [ + nn.ReLU(True), + ] + # model5+=[nn.ReflectionPad2d(2),] + model5 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + model5 += [ + nn.ReLU(True), + ] + model5 += [ + norm_layer(512), + ] + + # Conv6 + # model6+=[nn.ReflectionPad2d(2),] + model6 = [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model6+=[norm_layer(512),] + model6 += [ + nn.ReLU(True), + ] + # model6+=[nn.ReflectionPad2d(2),] + model6 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model6+=[norm_layer(512),] + model6 += [ + nn.ReLU(True), + ] + # model6+=[nn.ReflectionPad2d(2),] + model6 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + model6 += [ + nn.ReLU(True), + ] + model6 += [ + norm_layer(512), + ] + + # Conv7 + # model47+=[nn.ReflectionPad2d(1),] + model7 = [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model7+=[norm_layer(512),] + model7 += [ + nn.ReLU(True), + ] + # model7+=[nn.ReflectionPad2d(1),] + model7 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model7+=[norm_layer(512),] + model7 += [ + nn.ReLU(True), + ] + # model7+=[nn.ReflectionPad2d(1),] + model7 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model7 += [ + nn.ReLU(True), + ] + model7 += [ + norm_layer(512), + ] + + # Conv7 + model8up = [ + nn.ConvTranspose2d( + 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) + ] + + # model3short8=[nn.ReflectionPad2d(1),] + model3short8 = [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + + # model47+=[norm_layer(256),] + model8 = [ + nn.ReLU(True), + ] + # model8+=[nn.ReflectionPad2d(1),] + model8 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model8+=[norm_layer(256),] + model8 += [ + nn.ReLU(True), + ] + # model8+=[nn.ReflectionPad2d(1),] + model8 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model8 += [ + nn.ReLU(True), + ] + model8 += [ + norm_layer(256), + ] + + # Conv9 + model9up = [ + nn.ConvTranspose2d( + 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), + ] + + # model2short9=[nn.ReflectionPad2d(1),] + model2short9 = [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # add the two feature maps above + + # model9=[norm_layer(128),] + model9 = [ + nn.ReLU(True), + ] + # model9+=[nn.ReflectionPad2d(1),] + model9 += [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model9 += [ + nn.ReLU(True), + ] + model9 += [ + norm_layer(128), + ] + + # Conv10 + model10up = [ + nn.ConvTranspose2d( + 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), + ] + + # model1short10=[nn.ReflectionPad2d(1),] + model1short10 = [ + nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # add the two feature maps above + + # model10=[norm_layer(128),] + model10 = [ + nn.ReLU(True), + ] + # model10+=[nn.ReflectionPad2d(1),] + model10 += [ + nn.Conv2d( + 128, + 128, + kernel_size=3, + dilation=1, + stride=1, + padding=1, + bias=use_bias), + ] + model10 += [ + nn.LeakyReLU(negative_slope=.2), + ] + + # classification output + model_class = [ + nn.Conv2d( + 256, + 529, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias), + ] + + # regression output + model_out = [ + nn.Conv2d( + 128, + 2, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias), + ] + if (use_tanh): + model_out += [nn.Tanh()] + + self.model1 = nn.Sequential(*model1) + self.model2 = nn.Sequential(*model2) + self.model3 = nn.Sequential(*model3) + self.model4 = nn.Sequential(*model4) + self.model5 = nn.Sequential(*model5) + self.model6 = nn.Sequential(*model6) + self.model7 = nn.Sequential(*model7) + self.model8up = nn.Sequential(*model8up) + self.model8 = nn.Sequential(*model8) + self.model9up = nn.Sequential(*model9up) + self.model9 = nn.Sequential(*model9) + self.model10up = nn.Sequential(*model10up) + self.model10 = nn.Sequential(*model10) + self.model3short8 = nn.Sequential(*model3short8) + self.model2short9 = nn.Sequential(*model2short9) + self.model1short10 = nn.Sequential(*model1short10) + + self.model_class = nn.Sequential(*model_class) + self.model_out = nn.Sequential(*model_out) + + self.upsample4 = nn.Sequential(*[ + nn.Upsample(scale_factor=4, mode='nearest'), + ]) + self.softmax = nn.Sequential(*[ + nn.Softmax(dim=1), + ]) + + def forward(self, input_A, input_B, mask_B): + conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1)) + conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) + conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) + conv4_3 = self.model4(conv3_3[:, :, ::2, ::2]) + conv5_3 = self.model5(conv4_3) + conv6_3 = self.model6(conv5_3) + conv7_3 = self.model7(conv6_3) + conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3) + conv8_3 = self.model8(conv8_up) + + if (self.classification): + # out_class = self.model_class(conv8_3) + conv9_up = self.model9up(conv8_3.detach()) + self.model2short9( + conv2_2.detach()) + conv9_3 = self.model9(conv9_up) + conv10_up = self.model10up(conv9_3) + self.model1short10( + conv1_2.detach()) + conv10_2 = self.model10(conv10_up) + out_reg = self.model_out(conv10_2) + else: + # out_class = self.model_class(conv8_3.detach()) + + conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) + conv9_3 = self.model9(conv9_up) + conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) + conv10_2 = self.model10(conv10_up) + out_reg = self.model_out(conv10_2) + + feature_map = {} + feature_map['conv1_2'] = conv1_2 + feature_map['conv2_2'] = conv2_2 + feature_map['conv3_3'] = conv3_3 + feature_map['conv4_3'] = conv4_3 + feature_map['conv5_3'] = conv5_3 + feature_map['conv6_3'] = conv6_3 + feature_map['conv7_3'] = conv7_3 + feature_map['conv8_up'] = conv8_up + feature_map['conv8_3'] = conv8_3 + feature_map['conv9_up'] = conv9_up + feature_map['conv9_3'] = conv9_3 + feature_map['conv10_up'] = conv10_up + feature_map['conv10_2'] = conv10_2 + feature_map['out_reg'] = out_reg + + return (out_reg, feature_map) diff --git a/mmedit/models/editors/insta/util.py b/mmedit/models/editors/insta/util.py new file mode 100644 index 0000000000..9e859f3aa9 --- /dev/null +++ b/mmedit/models/editors/insta/util.py @@ -0,0 +1,261 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from __future__ import print_function +import os +from collections import OrderedDict + +import numpy as np +import torch +from PIL import Image + + +# Color conversion code +def rgb2xyz(rgb): # rgb from [0,1] + # xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], + # [0.212671, 0.715160, 0.072169], + # [0.019334, 0.119193, 0.950227]]) + + mask = (rgb > .04045).type(torch.FloatTensor) + if (rgb.is_cuda): + mask = mask.cuda() + + rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask) + + x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] \ + + .180423 * rgb[:, 2, :, :] + y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] \ + + .072169 * rgb[:, 2, :, :] + z = .019334 * rgb[:, 0, :, :] + .119193 * rgb[:, 1, :, :] \ + + .950227 * rgb[:, 2, :, :] + out = torch.cat((x[:, None, :, :], y[:, None, :, :], z[:, None, :, :]), + dim=1) + + # if(torch.sum(torch.isnan(out))>0): + # print('rgb2xyz') + # embed() + return out + + +def xyz2rgb(xyz): + # array([[ 3.24048134, -1.53715152, -0.49853633], + # [-0.96925495, 1.87599 , 0.04155593], + # [ 0.05564664, -0.20404134, 1.05731107]]) + + r = 3.24048134 * xyz[:, 0, :, :] - 1.53715152 * xyz[:, 1, :, :] \ + - 0.49853633 * xyz[:, 2, :, :] + g = -0.96925495 * xyz[:, 0, :, :] + 1.87599 * xyz[:, 1, :, :] \ + + .04155593 * xyz[:, 2, :, :] + b = .05564664 * xyz[:, 0, :, :] - .20404134 * xyz[:, 1, :, :] \ + + 1.05731107 * xyz[:, 2, :, :] + + rgb = torch.cat((r[:, None, :, :], g[:, None, :, :], b[:, None, :, :]), + dim=1) + rgb = torch.max(rgb, torch.zeros_like(rgb)) + # sometimes reaches a small negative number, which causes NaNs + + mask = (rgb > .0031308).type(torch.FloatTensor) + if rgb.is_cuda: + mask = mask.cuda() + + rgb = (1.055 * (rgb**(1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask) + # if(torch.sum(torch.isnan(rgb))>0): + # print('xyz2rgb') + # embed() + return rgb + + +def xyz2lab(xyz): + # 0.95047, 1., 1.08883 # white + sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None] + if (xyz.is_cuda): + sc = sc.cuda() + + xyz_scale = xyz / sc + + mask = (xyz_scale > .008856).type(torch.FloatTensor) + if (xyz_scale.is_cuda): + mask = mask.cuda() + + xyz_int = xyz_scale**(1 / 3.) * mask + (7.787 * xyz_scale + + 16. / 116.) * (1 - mask) + + L = 116. * xyz_int[:, 1, :, :] - 16. + a = 500. * (xyz_int[:, 0, :, :] - xyz_int[:, 1, :, :]) + b = 200. * (xyz_int[:, 1, :, :] - xyz_int[:, 2, :, :]) + out = torch.cat((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]), + dim=1) + + # if(torch.sum(torch.isnan(out))>0): + # print('xyz2lab') + # embed() + + return out + + +def lab2xyz(lab): + y_int = (lab[:, 0, :, :] + 16.) / 116. + x_int = (lab[:, 1, :, :] / 500.) + y_int + z_int = y_int - (lab[:, 2, :, :] / 200.) + if (z_int.is_cuda): + z_int = torch.max(torch.Tensor((0, )).cuda(), z_int) + else: + z_int = torch.max(torch.Tensor((0, )), z_int) + + out = torch.cat( + (x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]), + dim=1) + mask = (out > .2068966).type(torch.FloatTensor) + if (out.is_cuda): + mask = mask.cuda() + + out = (out**3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask) + + sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None] + sc = sc.to(out.device) + + out = out * sc + + # if(torch.sum(torch.isnan(out))>0): + # print('lab2xyz') + # embed() + + return out + + +def rgb2lab(rgb, **kwargs): + lab = xyz2lab(rgb2xyz(rgb)) + # print(lab[0, 0, 0, 0]) + lab_0 = lab[:, [0], :, :] + l_rs = (lab[:, [0], :, :] - kwargs['l_cent']) / kwargs['l_norm'] + # print(l_rs[0, 0, 0, 0]) + ab_rs = lab[:, 1:, :, :] / kwargs['ab_norm'] + out = torch.cat((l_rs, ab_rs), dim=1) + # if(torch.sum(torch.isnan(out))>0): + # print('rgb2lab') + # embed() + return out + + +def lab2rgb(lab_rs, **kwargs): + L = lab_rs[:, [0], :, :] * kwargs['l_norm'] + kwargs['l_cent'] + AB = lab_rs[:, 1:, :, :] * kwargs['ab_norm'] + lab = torch.cat((L, AB), dim=1) + out = xyz2rgb(lab2xyz(lab)) + # if(torch.sum(torch.isnan(out))>0): + # print('lab2rgb') + # embed() + return out + + +def get_colorization_data(data_raw, + ab_thresh=5., + p=.125, + num_points=None, + **kwargs): + data = {} + + data_lab = rgb2lab(data_raw[0], **kwargs) + data['A'] = data_lab[:, [ + 0, + ], :, :] + data['B'] = data_lab[:, 1:, :, :] + + if ab_thresh > 0: # mask out grayscale images + thresh = 1. * ab_thresh / kwargs['ab_norm'] + mask = torch.sum( + torch.abs( + torch.max(torch.max(data['B'], dim=3)[0], dim=2)[0] - + torch.min(torch.min(data['B'], dim=3)[0], dim=2)[0]), + dim=1) >= thresh + data['A'] = data['A'][mask, :, :, :] + data['B'] = data['B'][mask, :, :, :] + # print('Removed %i points'%torch.sum(mask==0).numpy()) + if torch.sum(mask) == 0: + return None + + return add_color_patches_rand_gt( + data, p=p, num_points=num_points, **kwargs) + + +def add_color_patches_rand_gt(data, + p=.125, + num_points=None, + use_avg=True, + samp='normal', + **kwargs): + # Add random color points sampled from ground truth based on: + # Number of points + # - if num_points is 0, then sample from geometric distribution, + # drawn from probability p + # - if num_points > 0, then sample that number of points + # Location of points + # - if samp is 'normal', draw from N(0.5, 0.25) of image + # - otherwise, draw from U[0, 1] of image + N, C, H, W = data['B'].shape + + data['hint_B'] = torch.zeros_like(data['B']) + data['mask_B'] = torch.zeros_like(data['A']) + + for nn in range(N): + pp = 0 + cont_cond = True + while cont_cond: + if num_points is None: # draw from geometric + # embed() + cont_cond = np.random.rand() < (1 - p) + else: # add certain number of points + cont_cond = pp < num_points + if not cont_cond: # skip out of loop if condition not met + continue + + P = np.random.choice(kwargs['sample_PS']) # patch size + + # sample location + if samp == 'normal': # geometric distribution + h = int( + np.clip( + np.random.normal((H - P + 1) / 2., (H - P + 1) / 4.), + 0, H - P)) + w = int( + np.clip( + np.random.normal((W - P + 1) / 2., (W - P + 1) / 4.), + 0, W - P)) + else: # uniform distribution + h = np.random.randint(H - P + 1) + w = np.random.randint(W - P + 1) + + # add color point + if use_avg: + # embed() + data['hint_B'][nn, :, h:h + P, w:w + P] = torch.mean( + torch.mean( + data['B'][nn, :, h:h + P, w:w + P], + dim=2, + keepdim=True), + dim=1, + keepdim=True).view(1, C, 1, 1) + else: + data['hint_B'][nn, :, h:h + P, w:w + P] = \ + data['B'][nn, :, h:h + P, w:w + P] + + data['mask_B'][nn, :, h:h + P, w:w + P] = 1 + + # increment counter + pp += 1 + + data['mask_B'] -= kwargs['mask_cent'] + + return data + + +def encode_ab_ind(data_ab, **kwargs): + # Encode ab value into an index + # INPUTS + # data_ab Nx2xHxW \in [-1,1] + # OUTPUTS + # data_q Nx1xHxW \in [0,Q) + A = 2 * kwargs['ab_max'] / kwargs['ab_quant'] + 1 + data_ab_rs = torch.round((data_ab * kwargs['ab_norm'] + kwargs['ab_max']) / + kwargs['ab_quant']) # normalized bin number + data_q = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :] + return data_q + diff --git a/mmedit/models/losses/__init__.py b/mmedit/models/losses/__init__.py index 4027c72013..96955b0fe7 100644 --- a/mmedit/models/losses/__init__.py +++ b/mmedit/models/losses/__init__.py @@ -15,7 +15,7 @@ from .loss_wrapper import mask_reduce_loss, reduce_loss from .perceptual_loss import (PerceptualLoss, PerceptualVGG, TransferalPerceptualLoss) -from .pixelwise_loss import CharbonnierLoss, L1Loss, MaskedTVLoss, MSELoss +from .pixelwise_loss import CharbonnierLoss, L1Loss, MaskedTVLoss, MSELoss, HuberLoss __all__ = [ 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'L1CompositionLoss', @@ -26,5 +26,6 @@ 'r1_gradient_penalty_loss', 'gen_path_regularizer', 'FaceIdLoss', 'CLIPLoss', 'CLIPLossComps', 'DiscShiftLossComps', 'FaceIdLossComps', 'GANLossComps', 'GeneratorPathRegularizerComps', - 'GradientPenaltyLossComps', 'R1GradientPenaltyComps', 'disc_shift_loss' + 'GradientPenaltyLossComps', 'R1GradientPenaltyComps', 'disc_shift_loss', + 'HuberLoss' ] diff --git a/mmedit/models/losses/pixelwise_loss.py b/mmedit/models/losses/pixelwise_loss.py index ad41fcd731..1425f5570a 100644 --- a/mmedit/models/losses/pixelwise_loss.py +++ b/mmedit/models/losses/pixelwise_loss.py @@ -221,3 +221,20 @@ def forward(self, pred, mask=None): loss = x_diff + y_diff return loss + + +@LOSSES.register_module() +class HuberLoss(nn.Module): + + def __init__(self, delta=.01): + super(HuberLoss, self).__init__() + self.delta = delta + + def __call__(self, in0, in1): + mask = torch.zeros_like(in0) + mann = torch.abs(in0 - in1) + eucl = .5 * (mann**2) + mask[...] = mann < self.delta + + loss = eucl * mask / self.delta + (mann - .5 * self.delta) * (1 - mask) + return torch.sum(loss, dim=1, keepdim=True) \ No newline at end of file From ca983ce9e7f6cffbc0b9c6960eb8cbbe6b62d519 Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Mon, 10 Oct 2022 16:24:16 +0800 Subject: [PATCH 02/32] refactor folder --- configs/inst_colorization/README.md | 75 ++++++++++++++++++ .../README_zh-CN.md} | 14 +++- .../insta_full_cocostuff_256x256.py | 27 ++----- .../insta_fusion_cocostuff_256x256.py | 0 .../insta_instance_cocostuff_256x256.py | 0 configs/inst_colorization/metafile.yml | 9 +++ mmedit/apis/__init__.py | 19 ++--- mmedit/apis/colorization_inference.py | 1 - mmedit/datasets/coco.py | 79 +++++++++---------- mmedit/datasets/transforms/__init__.py | 4 +- .../datasets/transforms/get_gray_color_pil.py | 5 +- .../datasets/transforms/get_maskrcnn_bbox.py | 6 +- mmedit/models/base_models/__init__.py | 2 +- .../models/base_models/base_colorization.py | 57 ------------- mmedit/models/editors/__init__.py | 4 +- .../editors/inst_colorization/__init__.py | 7 ++ .../{insta => inst_colorization}/insta.py | 8 +- .../{insta => inst_colorization}/insta_net.py | 0 .../{insta => inst_colorization}/util.py | 6 +- mmedit/models/editors/insta/__init__.py | 6 -- mmedit/models/losses/__init__.py | 3 +- mmedit/models/losses/huber_loss.py | 22 ++++++ mmedit/models/losses/pixelwise_loss.py | 2 +- model-index.yml | 1 + .../test_apis/test_colorization_inference.py | 1 + tests/test_datasets/test_coco.py | 1 + .../test_get_gray_color_pil.py | 1 + .../test_transforms/test_get_maskrcnn_bbox.py | 1 + .../test_inst_colorization/test_insta.py | 1 + .../test_inst_colorization/test_insta_net.py | 1 + .../test_inst_colorization/test_util.py | 1 + .../test_losses/test_huber_loss.py | 1 + 32 files changed, 199 insertions(+), 166 deletions(-) create mode 100644 configs/inst_colorization/README.md rename configs/{insta/README.md => inst_colorization/README_zh-CN.md} (53%) rename configs/{insta => inst_colorization}/insta_full_cocostuff_256x256.py (87%) rename configs/{insta => inst_colorization}/insta_fusion_cocostuff_256x256.py (100%) rename configs/{insta => inst_colorization}/insta_instance_cocostuff_256x256.py (100%) create mode 100644 configs/inst_colorization/metafile.yml delete mode 100644 mmedit/models/base_models/base_colorization.py create mode 100644 mmedit/models/editors/inst_colorization/__init__.py rename mmedit/models/editors/{insta => inst_colorization}/insta.py (99%) rename mmedit/models/editors/{insta => inst_colorization}/insta_net.py (100%) rename mmedit/models/editors/{insta => inst_colorization}/util.py (98%) delete mode 100644 mmedit/models/editors/insta/__init__.py create mode 100644 mmedit/models/losses/huber_loss.py create mode 100644 tests/test_apis/test_colorization_inference.py create mode 100644 tests/test_datasets/test_coco.py create mode 100644 tests/test_datasets/test_transforms/test_get_gray_color_pil.py create mode 100644 tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py create mode 100644 tests/test_models/test_editors/test_inst_colorization/test_insta.py create mode 100644 tests/test_models/test_editors/test_inst_colorization/test_insta_net.py create mode 100644 tests/test_models/test_editors/test_inst_colorization/test_util.py create mode 100644 tests/test_models/test_losses/test_huber_loss.py diff --git a/configs/inst_colorization/README.md b/configs/inst_colorization/README.md new file mode 100644 index 0000000000..7d2675134a --- /dev/null +++ b/configs/inst_colorization/README.md @@ -0,0 +1,75 @@ +# Instance-aware Image Colorization (CVPR'2020) + +> [Instance-Aware Image Colorization](https://openaccess.thecvf.com/content_CVPR_2020/html/Su_Instance-Aware_Image_Colorization_CVPR_2020_paper.html) + +> **Task**: Colorization + + + +## Abstract + + + +Image colorization is inherently an ill-posed problem with multi-modal uncertainty. Previous methods leverage the deep neural network to map input grayscale images to plausible color outputs directly. Although these learning-based methods have shown impressive performance, they usually fail on the input images that contain multiple objects. The leading cause is that existing models perform learning and colorization on the entire image. In the absence of a clear figure-ground separation, these models cannot effectively locate and learn meaningful object-level semantics. In this paper, we propose a method for achieving instance-aware colorization. Our network architecture leverages an off-the-shelf object detector to obtain cropped object images and uses an instance colorization network to extract object-level features. We use a similar network to extract the full-image features and apply a fusion module to full object-level and image-level features to predict the final colors. Both colorization networks and fusion modules are learned from a large-scale dataset. Experimental results show that our work outperforms existing methods on different quality metrics and achieves state-of-the-art performance on image colorization. + +## Results and models + +## Quick Start + +**Train** + +
+Train Instructions + +You can use the following commands to train a model with cpu or single/multiple GPUs. + +```shell +# CPU train +CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/insta/insta_full_cocostuff_256x256.py + +# single-gpu train +python tools/train.py configs/insta/insta_full_cocostuff_256x256.py + +# multi-gpu train +./tools/dist_train.sh configs/insta/insta_full_cocostuff_256x256.py 8 +``` + +For more details, you can refer to **Train a model** part in [train_test.md](/docs/en/user_guides/train_test.md#Train-a-model-in-MMEditing). + +
+ +**Test** + +
+Test Instructions + +You can use the following commands to test a model with cpu or single/multiple GPUs. + +```shell +# CPU test +CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/insta/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth + +# single-gpu test +python demo/colorization_demo.py configs/insta/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth + +# multi-gpu test +./tools/dist_test.sh configs/insta/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth 8 +``` + +For more details, you can refer to **Test a pre-trained model** part in [train_test.md](/docs/en/user_guides/train_test.md#Test-a-pre-trained-model-in-MMEditing). + +
+ +
+Instance-aware Image Colorization (CVPR'2020) + +```bibtex +@inproceedings{Su-CVPR-2020, + author = {Su, Jheng-Wei and Chu, Hung-Kuo and Huang, Jia-Bin}, + title = {Instance-aware Image Colorization}, + booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2020} +} +``` + +
diff --git a/configs/insta/README.md b/configs/inst_colorization/README_zh-CN.md similarity index 53% rename from configs/insta/README.md rename to configs/inst_colorization/README_zh-CN.md index 860d8a0e5f..157966bb44 100644 --- a/configs/insta/README.md +++ b/configs/inst_colorization/README_zh-CN.md @@ -1,7 +1,19 @@ # Instance-aware Image Colorization (CVPR'2020) +> [Instance-Aware Image Colorization](https://openaccess.thecvf.com/content_CVPR_2020/html/Su_Instance-Aware_Image_Colorization_CVPR_2020_paper.html) + > **任务**: 图像上色 + + +## 摘要 + + + +Image colorization is inherently an ill-posed problem with multi-modal uncertainty. Previous methods leverage the deep neural network to map input grayscale images to plausible color outputs directly. Although these learning-based methods have shown impressive performance, they usually fail on the input images that contain multiple objects. The leading cause is that existing models perform learning and colorization on the entire image. In the absence of a clear figure-ground separation, these models cannot effectively locate and learn meaningful object-level semantics. In this paper, we propose a method for achieving instance-aware colorization. Our network architecture leverages an off-the-shelf object detector to obtain cropped object images and uses an instance colorization network to extract object-level features. We use a similar network to extract the full-image features and apply a fusion module to full object-level and image-level features to predict the final colors. Both colorization networks and fusion modules are learned from a large-scale dataset. Experimental results show that our work outperforms existing methods on different quality metrics and achieves state-of-the-art performance on image colorization. + +## 结果和模型 + ## 快速开始 **训练** @@ -48,7 +60,6 @@ python demo/colorization_demo.py configs/insta/insta_full_cocostuff_256x256.py . -
Instance-aware Image Colorization (CVPR'2020) @@ -62,4 +73,3 @@ python demo/colorization_demo.py configs/insta/insta_full_cocostuff_256x256.py . ```
- diff --git a/configs/insta/insta_full_cocostuff_256x256.py b/configs/inst_colorization/insta_full_cocostuff_256x256.py similarity index 87% rename from configs/insta/insta_full_cocostuff_256x256.py rename to configs/inst_colorization/insta_full_cocostuff_256x256.py index e25f2836c9..fa24f2cdab 100644 --- a/configs/insta/insta_full_cocostuff_256x256.py +++ b/configs/inst_colorization/insta_full_cocostuff_256x256.py @@ -1,6 +1,4 @@ -_base_ = [ - '../_base_/default_runtime.py' -] +_base_ = ['../_base_/default_runtime.py'] exp_name = 'Instance-aware_full' save_dir = './' @@ -14,11 +12,7 @@ std=[127.5], ), instance_model=dict( - type='SIGGRAPHGenerator', - input_nc=4, - output_nc=2, - norm_type='batch' - ), + type='SIGGRAPHGenerator', input_nc=4, output_nc=2, norm_type='batch'), stage='full', ngf=64, output_nc=2, @@ -31,8 +25,7 @@ init_type='normal', which_direction='AtoB', loss=dict(type='HuberLoss', delta=.01), - pretrained='./checkpoints/pytorch_trained.pth' -) + pretrained='./checkpoints/pytorch_trained.pth') input_shape = (256, 256) @@ -52,11 +45,7 @@ test_pipeline = [ dict(type='LoadImageFromFile', key='gt'), dict(type='GenMaskRCNNBbox', stage='test_fusion', finesize=256), - dict(type='Resize', - keys=['gt'], - scale=(256, 256), - keep_ratio=False - ), + dict(type='Resize', keys=['gt'], scale=(256, 256), keep_ratio=False), dict(type='PackEditInputs'), ] @@ -92,7 +81,6 @@ pipeline=test_pipeline, test_mode=False)) - test_evaluator = [dict(type='PSNR'), dict(type='SSIM')] train_cfg = dict( @@ -104,7 +92,6 @@ val_cfg = dict(type='ValLoop') test_cfg = dict(type='TestLoop') - # optimizer optim_wrapper = dict( constructor='DefaultOptimWrapperConstructor', @@ -127,9 +114,7 @@ type='ConcatImageVisualizer', vis_backends=vis_backends, fn_key='gt_path', - img_keys=[ - 'gray', 'real', 'fake_reg', 'hint', 'real_ab', 'fake_ab_reg' - ], + img_keys=['gray', 'real', 'fake_reg', 'hint', 'real_ab', 'fake_ab_reg'], bgr2rgb=False) env_cfg = dict( @@ -137,5 +122,3 @@ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), dist_cfg=dict(backend='nccl'), ) - - diff --git a/configs/insta/insta_fusion_cocostuff_256x256.py b/configs/inst_colorization/insta_fusion_cocostuff_256x256.py similarity index 100% rename from configs/insta/insta_fusion_cocostuff_256x256.py rename to configs/inst_colorization/insta_fusion_cocostuff_256x256.py diff --git a/configs/insta/insta_instance_cocostuff_256x256.py b/configs/inst_colorization/insta_instance_cocostuff_256x256.py similarity index 100% rename from configs/insta/insta_instance_cocostuff_256x256.py rename to configs/inst_colorization/insta_instance_cocostuff_256x256.py diff --git a/configs/inst_colorization/metafile.yml b/configs/inst_colorization/metafile.yml new file mode 100644 index 0000000000..54bf9ccebc --- /dev/null +++ b/configs/inst_colorization/metafile.yml @@ -0,0 +1,9 @@ +Collections: +- Metadata: + Architecture: + - Instance-aware Image Colorization + Name: Instance-aware Image Colorization + Paper: + - https://openaccess.thecvf.com/content_CVPR_2020/html/Su_Instance-Aware_Image_Colorization_CVPR_2020_paper.html + README: configs/inst_colorization/README.md +Models: [] diff --git a/mmedit/apis/__init__.py b/mmedit/apis/__init__.py index 2f23b65cb1..75033721b0 100644 --- a/mmedit/apis/__init__.py +++ b/mmedit/apis/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .colorization_inference import colorization_inference from .gan_inference import sample_conditional_model, sample_unconditional_model from .inference import delete_cfg, init_model, set_random_seed from .inpainting_inference import inpainting_inference @@ -8,20 +9,12 @@ from .restoration_video_inference import restoration_video_inference from .translation_inference import sample_img2img_model from .video_interpolation_inference import video_interpolation_inference -from .colorization_inference import colorization_inference __all__ = [ - 'init_model', - 'delete_cfg', - 'set_random_seed', - 'matting_inference', - 'inpainting_inference', - 'restoration_inference', - 'restoration_video_inference', - 'restoration_face_inference', - 'video_interpolation_inference', - 'sample_conditional_model', - 'sample_unconditional_model', - 'sample_img2img_model', + 'init_model', 'delete_cfg', 'set_random_seed', 'matting_inference', + 'inpainting_inference', 'restoration_inference', + 'restoration_video_inference', 'restoration_face_inference', + 'video_interpolation_inference', 'sample_conditional_model', + 'sample_unconditional_model', 'sample_img2img_model', 'colorization_inference' ] diff --git a/mmedit/apis/colorization_inference.py b/mmedit/apis/colorization_inference.py index 8f1b05da67..2ddb4df599 100644 --- a/mmedit/apis/colorization_inference.py +++ b/mmedit/apis/colorization_inference.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch - from mmengine.dataset import Compose from mmengine.dataset.utils import default_collate as collate from torch.nn.parallel import scatter diff --git a/mmedit/datasets/coco.py b/mmedit/datasets/coco.py index 4390c43875..82a3b9b3a3 100644 --- a/mmedit/datasets/coco.py +++ b/mmedit/datasets/coco.py @@ -1,10 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Callable, Optional, Union from pathlib import Path -from mmengine.dataset import BaseDataset -from mmengine.fileio import load - from mmedit.registry import DATASETS @@ -14,50 +10,47 @@ class CocoDataset: METAINFO = { 'CLASSES': - ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', - 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', - 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', - 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', - 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', - 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', - 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', - 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', - 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', - 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', - 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', - 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', - 'scissors', 'teddy bear', 'hair drier', 'toothbrush'), + ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', + 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', + 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', + 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', + 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush'), # PALETTE is a list of color tuples, which is used for visualization. 'PALETTE': - [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), - (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30), - (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), - (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255), - (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255), - (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92), - (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164), - (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0), - (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174), - (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54), - (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51), - (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65), - (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0), - (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161), - (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120), - (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133), - (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62), - (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45), - (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1), - (246, 0, 122), (191, 162, 208)] + [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), + (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30), + (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), + (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255), + (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255), + (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92), + (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164), + (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0), + (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174), + (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54), + (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51), + (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65), + (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0), + (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161), + (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120), + (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133), + (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62), + (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45), + (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1), + (246, 0, 122), (191, 162, 208)] } - METAINFO = dict(dataset_type='colorization_dataset', task_name='colorization') + METAINFO = dict( + dataset_type='colorization_dataset', task_name='colorization') - def __init__( - self, - ann_file: str, - data_prefix - ): + def __init__(self, ann_file: str, data_prefix): self.ann_file = str(ann_file) self.data_prefix = data_prefix diff --git a/mmedit/datasets/transforms/__init__.py b/mmedit/datasets/transforms/__init__.py index 2bb927de7a..b5b5cb1bd8 100644 --- a/mmedit/datasets/transforms/__init__.py +++ b/mmedit/datasets/transforms/__init__.py @@ -16,7 +16,9 @@ from .generate_frame_indices import (GenerateFrameIndices, GenerateFrameIndiceswithPadding, GenerateSegmentIndices) +from .get_gray_color_pil import GenGrayColorPil from .get_masked_image import GetMaskedImage +from .get_maskrcnn_bbox import GenMaskRCNNBbox from .loading import (GetSpatialDiscountMask, LoadImageFromFile, LoadMask, LoadPairedImageFromFile) from .matlab_like_resize import MATLABLikeResize @@ -28,8 +30,6 @@ from .trimap import (FormatTrimap, GenerateTrimap, GenerateTrimapWithDistTransform, TransformTrimap) from .values import CopyValues, SetValues -from .get_maskrcnn_bbox import GenMaskRCNNBbox -from .get_gray_color_pil import GenGrayColorPil __all__ = [ 'BinarizeImage', 'Clip', 'ColorJitter', 'CopyValues', 'Crop', 'CropLike', diff --git a/mmedit/datasets/transforms/get_gray_color_pil.py b/mmedit/datasets/transforms/get_gray_color_pil.py index 91dd763c35..31107282cd 100644 --- a/mmedit/datasets/transforms/get_gray_color_pil.py +++ b/mmedit/datasets/transforms/get_gray_color_pil.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -import numpy as np import cv2 +import numpy as np from mmcv.transforms.base import BaseTransform from mmedit.registry import TRANSFORMS + @TRANSFORMS.register_module() class GenGrayColorPil(BaseTransform): @@ -26,4 +27,4 @@ def transform(self, results): results[self.keys[0]] = rgb_img results[self.keys[1]] = gray_img - return results \ No newline at end of file + return results diff --git a/mmedit/datasets/transforms/get_maskrcnn_bbox.py b/mmedit/datasets/transforms/get_maskrcnn_bbox.py index 493caf51ab..d768bdc9db 100644 --- a/mmedit/datasets/transforms/get_maskrcnn_bbox.py +++ b/mmedit/datasets/transforms/get_maskrcnn_bbox.py @@ -32,12 +32,12 @@ def gen_maskrcnn_bbox_fromPred(self, img, bbox_path=None, box_num_upbound=8): - ''' - ## Arguments: + """## Arguments: + - pred_data_path: Detectron2 predict results - box_num_upbound: object bounding boxes number. Default: -1 means use all the instances. - ''' + """ if bbox_path: pred_data = np.load(bbox_path) pred_bbox = pred_data['bbox'].astype(np.int32) diff --git a/mmedit/models/base_models/__init__.py b/mmedit/models/base_models/__init__.py index 60bcdafad3..3e85294c4a 100644 --- a/mmedit/models/base_models/__init__.py +++ b/mmedit/models/base_models/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .average_model import ExponentialMovingAverage, RampUpEMA +from .base_colorization import BaseColorization from .base_conditional_gan import BaseConditionalGAN from .base_edit_model import BaseEditModel from .base_gan import BaseGAN @@ -8,7 +9,6 @@ from .basic_interpolator import BasicInterpolator from .one_stage import OneStageInpaintor from .two_stage import TwoStageInpaintor -from .base_colorization import BaseColorization __all__ = [ 'BaseEditModel', 'BaseGAN', 'BaseConditionalGAN', 'BaseMattor', diff --git a/mmedit/models/base_models/base_colorization.py b/mmedit/models/base_models/base_colorization.py deleted file mode 100644 index 37183bc18f..0000000000 --- a/mmedit/models/base_models/base_colorization.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from abc import ABCMeta -from typing import Dict, List, Optional, Tuple, Union - -import torch -from torchvision.utils import save_image -from mmengine.model import BaseModel -from mmengine.config import Config, ConfigDict - -from mmedit.registry import MODELS - - -class BaseColorization(BaseModel, metaclass=ABCMeta): - - def __init__(self, - data_preprocessor: Union[dict, Config], - loss, - init_cfg: Optional[dict] = None, - train_cfg: Optional[dict] = None, - test_cfg: Optional[dict] = None): - - super().__init__( - data_preprocessor=data_preprocessor, init_cfg=init_cfg) - - self.loss = MODELS.build(loss) - - def forward(self, - inputs: torch.Tensor, - data_samples: Optional[Union[list, torch.Tensor]] = None, - mode: str = 'tensor', - **kwargs): - - if mode == 'tensor': - return self.forward_tensor(inputs, data_samples, **kwargs) - - elif mode == 'predict': - predictions = self.forward_test(inputs, data_samples, **kwargs) - predictions = self.convert_to_datasample(data_samples, predictions) - return predictions - - elif mode == 'loss': - return self.forward_train(inputs, data_samples, **kwargs) - - def forward_train(self, *args, **kwargs): - pass - - def forward_test(self, input, data_samples, **kwargs): - pass - - def train_step(self, data_batch, optimizer): - pass - - def init_weights(self): - pass - - def save_visualization(self, img, filename): - save_image(img, filename) diff --git a/mmedit/models/editors/__init__.py b/mmedit/models/editors/__init__.py index 442538599b..238dba9514 100644 --- a/mmedit/models/editors/__init__.py +++ b/mmedit/models/editors/__init__.py @@ -28,6 +28,7 @@ from .indexnet import (DepthwiseIndexBlock, HolisticIndexBlock, IndexedUpsample, IndexNet, IndexNetDecoder, IndexNetEncoder) +from .inst_colorization import InstColorization from .liif import LIIF, MLPRefiner from .lsgan import LSGAN from .mspie import MSPIEStyleGAN2, PESinGAN @@ -50,7 +51,6 @@ from .tof import TOFlowVFINet, TOFlowVSRNet, ToFResBlock from .ttsr import LTE, TTSR, SearchTransformer, TTSRDiscriminator, TTSRNet from .wgan_gp import WGANGP -from .insta import INSTA __all__ = [ 'AOTEncoderDecoder', 'AOTBlockNeck', 'AOTInpaintor', @@ -74,5 +74,5 @@ 'FBADecoder', 'WGANGP', 'CycleGAN', 'SAGAN', 'LSGAN', 'GGAN', 'Pix2Pix', 'StyleGAN1', 'StyleGAN2', 'StyleGAN3', 'BigGAN', 'DCGAN', 'ProgressiveGrowingGAN', 'SinGAN', 'IDLossModel', 'PESinGAN', - 'MSPIEStyleGAN2' + 'MSPIEStyleGAN2', 'InstColorization' ] diff --git a/mmedit/models/editors/inst_colorization/__init__.py b/mmedit/models/editors/inst_colorization/__init__.py new file mode 100644 index 0000000000..4b02c69ce4 --- /dev/null +++ b/mmedit/models/editors/inst_colorization/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .insta import INSTA +from .insta_net import FusionGenerator, InstanceGenerator, SIGGRAPHGenerator + +__all__ = [ + 'INSTA', 'SIGGRAPHGenerator', 'InstanceGenerator', 'FusionGenerator' +] diff --git a/mmedit/models/editors/insta/insta.py b/mmedit/models/editors/inst_colorization/insta.py similarity index 99% rename from mmedit/models/editors/insta/insta.py rename to mmedit/models/editors/inst_colorization/insta.py index 180dcd5e4e..c86f56cbb0 100644 --- a/mmedit/models/editors/insta/insta.py +++ b/mmedit/models/editors/inst_colorization/insta.py @@ -6,10 +6,9 @@ import torch from mmengine.config import Config -from mmedit.models.utils import generation_init_weights from mmedit.models.base_models import BaseColorization +from mmedit.models.utils import generation_init_weights from mmedit.registry import BACKBONES, COMPONENTS - from .util import encode_ab_ind, get_colorization_data, lab2rgb @@ -40,8 +39,7 @@ def __init__(self, loss=loss, init_cfg=init_cfg, train_cfg=train_cfg, - test_cfg=test_cfg - ) + test_cfg=test_cfg) self.ngf = ngf self.output_nc = output_nc @@ -148,7 +146,7 @@ def generator_loss(self): self, 'loss_' + name)) + self.avg_loss_alpha * self.avg_losses[name] errors_ret[name] = (1 - self.avg_loss_alpha) / ( - 1 - self.avg_loss_alpha** + 1 - self.avg_loss_alpha** # noqa self.error_cnt) * self.avg_losses[name] return errors_ret diff --git a/mmedit/models/editors/insta/insta_net.py b/mmedit/models/editors/inst_colorization/insta_net.py similarity index 100% rename from mmedit/models/editors/insta/insta_net.py rename to mmedit/models/editors/inst_colorization/insta_net.py diff --git a/mmedit/models/editors/insta/util.py b/mmedit/models/editors/inst_colorization/util.py similarity index 98% rename from mmedit/models/editors/insta/util.py rename to mmedit/models/editors/inst_colorization/util.py index 9e859f3aa9..2e2177ab3a 100644 --- a/mmedit/models/editors/insta/util.py +++ b/mmedit/models/editors/inst_colorization/util.py @@ -1,11 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from __future__ import print_function -import os -from collections import OrderedDict import numpy as np import torch -from PIL import Image # Color conversion code @@ -124,7 +121,7 @@ def lab2xyz(lab): def rgb2lab(rgb, **kwargs): lab = xyz2lab(rgb2xyz(rgb)) # print(lab[0, 0, 0, 0]) - lab_0 = lab[:, [0], :, :] + # lab_0 = lab[:, [0], :, :] l_rs = (lab[:, [0], :, :] - kwargs['l_cent']) / kwargs['l_norm'] # print(l_rs[0, 0, 0, 0]) ab_rs = lab[:, 1:, :, :] / kwargs['ab_norm'] @@ -258,4 +255,3 @@ def encode_ab_ind(data_ab, **kwargs): kwargs['ab_quant']) # normalized bin number data_q = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :] return data_q - diff --git a/mmedit/models/editors/insta/__init__.py b/mmedit/models/editors/insta/__init__.py deleted file mode 100644 index 10286302b3..0000000000 --- a/mmedit/models/editors/insta/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .insta import INSTA -from .insta_net import (SIGGRAPHGenerator, InstanceGenerator, FusionGenerator) - -__all__ = [ - 'INSTA', 'SIGGRAPHGenerator', 'InstanceGenerator', 'FusionGenerator' -] \ No newline at end of file diff --git a/mmedit/models/losses/__init__.py b/mmedit/models/losses/__init__.py index 96955b0fe7..4e388acc47 100644 --- a/mmedit/models/losses/__init__.py +++ b/mmedit/models/losses/__init__.py @@ -9,13 +9,14 @@ gen_path_regularizer, gradient_penalty_loss, r1_gradient_penalty_loss) from .gradient_loss import GradientLoss +from .huber_loss import HuberLoss from .loss_comps import (CLIPLossComps, DiscShiftLossComps, FaceIdLossComps, GANLossComps, GeneratorPathRegularizerComps, GradientPenaltyLossComps, R1GradientPenaltyComps) from .loss_wrapper import mask_reduce_loss, reduce_loss from .perceptual_loss import (PerceptualLoss, PerceptualVGG, TransferalPerceptualLoss) -from .pixelwise_loss import CharbonnierLoss, L1Loss, MaskedTVLoss, MSELoss, HuberLoss +from .pixelwise_loss import CharbonnierLoss, L1Loss, MaskedTVLoss, MSELoss __all__ = [ 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'L1CompositionLoss', diff --git a/mmedit/models/losses/huber_loss.py b/mmedit/models/losses/huber_loss.py new file mode 100644 index 0000000000..7b45f41571 --- /dev/null +++ b/mmedit/models/losses/huber_loss.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from mmedit.registry import LOSSES + + +@LOSSES.register_module() +class HuberLoss(nn.Module): + + def __init__(self, delta=.01): + super(HuberLoss, self).__init__() + self.delta = delta + + def __call__(self, in0, in1): + mask = torch.zeros_like(in0) + mann = torch.abs(in0 - in1) + eucl = .5 * (mann**2) + mask[...] = mann < self.delta + + loss = eucl * mask / self.delta + (mann - .5 * self.delta) * (1 - mask) + return torch.sum(loss, dim=1, keepdim=True) diff --git a/mmedit/models/losses/pixelwise_loss.py b/mmedit/models/losses/pixelwise_loss.py index 1425f5570a..f3c77675aa 100644 --- a/mmedit/models/losses/pixelwise_loss.py +++ b/mmedit/models/losses/pixelwise_loss.py @@ -237,4 +237,4 @@ def __call__(self, in0, in1): mask[...] = mann < self.delta loss = eucl * mask / self.delta + (mann - .5 * self.delta) * (1 - mask) - return torch.sum(loss, dim=1, keepdim=True) \ No newline at end of file + return torch.sum(loss, dim=1, keepdim=True) diff --git a/model-index.yml b/model-index.yml index 6373a6dc4d..b0ce511cac 100644 --- a/model-index.yml +++ b/model-index.yml @@ -20,6 +20,7 @@ Import: - configs/global_local/metafile.yml - configs/iconvsr/metafile.yml - configs/indexnet/metafile.yml +- configs/inst_colorization/metafile.yml - configs/liif/metafile.yml - configs/lsgan/metafile.yml - configs/partial_conv/metafile.yml diff --git a/tests/test_apis/test_colorization_inference.py b/tests/test_apis/test_colorization_inference.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/tests/test_apis/test_colorization_inference.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_datasets/test_coco.py b/tests/test_datasets/test_coco.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/tests/test_datasets/test_coco.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_datasets/test_transforms/test_get_gray_color_pil.py b/tests/test_datasets/test_transforms/test_get_gray_color_pil.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/tests/test_datasets/test_transforms/test_get_gray_color_pil.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py b/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_editors/test_inst_colorization/test_insta.py b/tests/test_models/test_editors/test_inst_colorization/test_insta.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/tests/test_models/test_editors/test_inst_colorization/test_insta.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_editors/test_inst_colorization/test_insta_net.py b/tests/test_models/test_editors/test_inst_colorization/test_insta_net.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/tests/test_models/test_editors/test_inst_colorization/test_insta_net.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_editors/test_inst_colorization/test_util.py b/tests/test_models/test_editors/test_inst_colorization/test_util.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/tests/test_models/test_editors/test_inst_colorization/test_util.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_losses/test_huber_loss.py b/tests/test_models/test_losses/test_huber_loss.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/tests/test_models/test_losses/test_huber_loss.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. From d7354527454587be3eada1b7c56a1e6d93aa285a Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Mon, 10 Oct 2022 16:46:12 +0800 Subject: [PATCH 03/32] refactor model implementation --- configs/inst_colorization/README.md | 12 ++++++------ demo/colorization_demo.py | 2 +- mmedit/models/base_models/__init__.py | 15 ++++++++++----- .../models/editors/inst_colorization/__init__.py | 5 +++-- .../{insta.py => inst_colorization.py} | 4 ++-- .../{insta_net.py => inst_colorization_net.py} | 0 .../util.py => utils/color_utils.py} | 2 -- 7 files changed, 22 insertions(+), 18 deletions(-) rename mmedit/models/editors/inst_colorization/{insta.py => inst_colorization.py} (99%) rename mmedit/models/editors/inst_colorization/{insta_net.py => inst_colorization_net.py} (100%) rename mmedit/models/{editors/inst_colorization/util.py => utils/color_utils.py} (99%) diff --git a/configs/inst_colorization/README.md b/configs/inst_colorization/README.md index 7d2675134a..a21df6a724 100644 --- a/configs/inst_colorization/README.md +++ b/configs/inst_colorization/README.md @@ -25,13 +25,13 @@ You can use the following commands to train a model with cpu or single/multiple ```shell # CPU train -CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/insta/insta_full_cocostuff_256x256.py +CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/inst_colorization/insta_full_cocostuff_256x256.py # single-gpu train -python tools/train.py configs/insta/insta_full_cocostuff_256x256.py +python tools/train.py configs/inst_colorization/insta_full_cocostuff_256x256.py # multi-gpu train -./tools/dist_train.sh configs/insta/insta_full_cocostuff_256x256.py 8 +./tools/dist_train.sh configs/inst_colorization/insta_full_cocostuff_256x256.py 8 ``` For more details, you can refer to **Train a model** part in [train_test.md](/docs/en/user_guides/train_test.md#Train-a-model-in-MMEditing). @@ -47,13 +47,13 @@ You can use the following commands to test a model with cpu or single/multiple G ```shell # CPU test -CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/insta/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth +CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/inst_colorization//insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth # single-gpu test -python demo/colorization_demo.py configs/insta/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth +python demo/colorization_demo.py configs/inst_colorization/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth # multi-gpu test -./tools/dist_test.sh configs/insta/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth 8 +./tools/dist_test.sh configs/inst_colorization/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth 8 ``` For more details, you can refer to **Test a pre-trained model** part in [train_test.md](/docs/en/user_guides/train_test.md#Test-a-pre-trained-model-in-MMEditing). diff --git a/demo/colorization_demo.py b/demo/colorization_demo.py index 1654895be2..2cc8b8f7b7 100644 --- a/demo/colorization_demo.py +++ b/demo/colorization_demo.py @@ -14,7 +14,6 @@ def parse_args(): parser.add_argument('config', help='test config file path') parser.add_argument('checkpoints', help='checkpoints file path') parser.add_argument('img_path', help='path to input image file') - # parser.add_argument('bbox_path', help='path to input image bbox file') parser.add_argument('save_path', help='path to save generation result') parser.add_argument( '--unpaired-path', default=None, help='path to unpaired image file') @@ -36,6 +35,7 @@ def main(): # model = init_model(args.config, args.checkpoints, device=device) output = colorization_inference(model, args.img_path, args.bbox_path) + mmcv.imwrite(output, args.save_path) if args.imshow: mmcv.imshow(output, 'predicted generation result') diff --git a/mmedit/models/base_models/__init__.py b/mmedit/models/base_models/__init__.py index 3e85294c4a..0ec81d6d5a 100644 --- a/mmedit/models/base_models/__init__.py +++ b/mmedit/models/base_models/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .average_model import ExponentialMovingAverage, RampUpEMA -from .base_colorization import BaseColorization from .base_conditional_gan import BaseConditionalGAN from .base_edit_model import BaseEditModel from .base_gan import BaseGAN @@ -11,8 +10,14 @@ from .two_stage import TwoStageInpaintor __all__ = [ - 'BaseEditModel', 'BaseGAN', 'BaseConditionalGAN', 'BaseMattor', - 'BasicInterpolator', 'BaseTranslationModel', 'OneStageInpaintor', - 'TwoStageInpaintor', 'ExponentialMovingAverage', 'RampUpEMA', - 'BaseColorization' + 'BaseEditModel', + 'BaseGAN', + 'BaseConditionalGAN', + 'BaseMattor', + 'BasicInterpolator', + 'BaseTranslationModel', + 'OneStageInpaintor', + 'TwoStageInpaintor', + 'ExponentialMovingAverage', + 'RampUpEMA', ] diff --git a/mmedit/models/editors/inst_colorization/__init__.py b/mmedit/models/editors/inst_colorization/__init__.py index 4b02c69ce4..4829262ff5 100644 --- a/mmedit/models/editors/inst_colorization/__init__.py +++ b/mmedit/models/editors/inst_colorization/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .insta import INSTA -from .insta_net import FusionGenerator, InstanceGenerator, SIGGRAPHGenerator +from .inst_colorization import INSTA +from .inst_colorizatiuion_net import (FusionGenerator, InstanceGenerator, + SIGGRAPHGenerator) __all__ = [ 'INSTA', 'SIGGRAPHGenerator', 'InstanceGenerator', 'FusionGenerator' diff --git a/mmedit/models/editors/inst_colorization/insta.py b/mmedit/models/editors/inst_colorization/inst_colorization.py similarity index 99% rename from mmedit/models/editors/inst_colorization/insta.py rename to mmedit/models/editors/inst_colorization/inst_colorization.py index c86f56cbb0..3a34785f55 100644 --- a/mmedit/models/editors/inst_colorization/insta.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -6,14 +6,14 @@ import torch from mmengine.config import Config -from mmedit.models.base_models import BaseColorization +from mmedit.models.editors import SRGAN from mmedit.models.utils import generation_init_weights from mmedit.registry import BACKBONES, COMPONENTS from .util import encode_ab_ind, get_colorization_data, lab2rgb @BACKBONES.register_module() -class INSTA(BaseColorization): +class INSTA(SRGAN): def __init__(self, data_preprocessor: Union[dict, Config], diff --git a/mmedit/models/editors/inst_colorization/insta_net.py b/mmedit/models/editors/inst_colorization/inst_colorization_net.py similarity index 100% rename from mmedit/models/editors/inst_colorization/insta_net.py rename to mmedit/models/editors/inst_colorization/inst_colorization_net.py diff --git a/mmedit/models/editors/inst_colorization/util.py b/mmedit/models/utils/color_utils.py similarity index 99% rename from mmedit/models/editors/inst_colorization/util.py rename to mmedit/models/utils/color_utils.py index 2e2177ab3a..f2819f71f5 100644 --- a/mmedit/models/editors/inst_colorization/util.py +++ b/mmedit/models/utils/color_utils.py @@ -1,6 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from __future__ import print_function - import numpy as np import torch From e319851fca26d9ffc1406fed7d0eaa9567e5ff8f Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Mon, 10 Oct 2022 21:09:28 +0800 Subject: [PATCH 04/32] refactor demo --- .../inst-colorizatioon_cocostuff_256x256.py | 124 ++++++++++++++++++ .../editors/inst_colorization/__init__.py | 9 +- .../inst_colorization/inst_colorization.py | 10 +- mmedit/models/losses/pixelwise_loss.py | 17 --- mmedit/models/utils/__init__.py | 20 +-- 5 files changed, 141 insertions(+), 39 deletions(-) create mode 100644 configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py new file mode 100644 index 0000000000..fa24f2cdab --- /dev/null +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py @@ -0,0 +1,124 @@ +_base_ = ['../_base_/default_runtime.py'] + +exp_name = 'Instance-aware_full' +save_dir = './' +work_dir = '..' + +model = dict( + type='FusionModel', + data_preprocessor=dict( + type='EditDataPreprocessor', + mean=[127.5], + std=[127.5], + ), + instance_model=dict( + type='SIGGRAPHGenerator', input_nc=4, output_nc=2, norm_type='batch'), + stage='full', + ngf=64, + output_nc=2, + avg_loss_alpha=.986, + ab_norm=110., + l_norm=100., + l_cent=50., + sample_Ps=[1, 2, 3, 4, 5, 6, 7, 8, 9], + mask_cent=.5, + init_type='normal', + which_direction='AtoB', + loss=dict(type='HuberLoss', delta=.01), + pretrained='./checkpoints/pytorch_trained.pth') + +input_shape = (256, 256) + +train_pipeline = [ + dict(type='LoadImageFromFile', key='gt_img'), + dict(type='GenGrayColorPil', stage='full', keys=['rgb_img', 'gray_img']), + dict( + type='Resize', + keys=['rgb_img', 'gray_img'], + scale=input_shape, + keep_ratio=False, + interpolation='nearest'), + dict(type='RescaleToZeroOne', keys=['rgb_img', 'gray_img']), + dict(type='PackEditInputs') +] + +test_pipeline = [ + dict(type='LoadImageFromFile', key='gt'), + dict(type='GenMaskRCNNBbox', stage='test_fusion', finesize=256), + dict(type='Resize', keys=['gt'], scale=(256, 256), keep_ratio=False), + dict(type='PackEditInputs'), +] + +dataset_type = 'CocoDataset' +data_root = '/mnt/j/DataSet/cocostuff/train2017' +ann_file_path = '/mnt/j/DataSet/cocostuff/' + +train_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=False, + sampler=dict(shuffle=False), + workers_per_gpu=1, + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(gt='data_large'), + ann_file=f'{ann_file_path}/img_list.txt', + pipeline=train_pipeline, + test_mode=False)) + +test_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=False, + sampler=dict(type='DefaultSampler', shuffle=False), + workers_per_gpu=1, + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(gt='data_large'), + ann_file=f'{ann_file_path}/img_list.txt', + pipeline=test_pipeline, + test_mode=False)) + +test_evaluator = [dict(type='PSNR'), dict(type='SSIM')] + +train_cfg = dict( + type='IterBasedTrainLoop', + max_iters=500002, + val_interval=50000, +) + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# optimizer +optim_wrapper = dict( + constructor='DefaultOptimWrapperConstructor', + generator=dict( + type='OptimWrapper', + optimizer=dict(type='Adam', lr=0.0001, betas=(0.0, 0.9))), + disc=dict( + type='OptimWrapper', + optimizer=dict(type='Adam', lr=0.0001, betas=(0.0, 0.9)))) + +param_scheduler = dict( + # todo engine中暂时还没有这个 + type='LambdaLR', + by_epoch=False, +) + +vis_backends = [dict(type='LocalVisBackend')] + +visualizer = dict( + type='ConcatImageVisualizer', + vis_backends=vis_backends, + fn_key='gt_path', + img_keys=['gray', 'real', 'fake_reg', 'hint', 'real_ab', 'fake_ab_reg'], + bgr2rgb=False) + +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) diff --git a/mmedit/models/editors/inst_colorization/__init__.py b/mmedit/models/editors/inst_colorization/__init__.py index 4829262ff5..faa56edca0 100644 --- a/mmedit/models/editors/inst_colorization/__init__.py +++ b/mmedit/models/editors/inst_colorization/__init__.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .inst_colorization import INSTA -from .inst_colorizatiuion_net import (FusionGenerator, InstanceGenerator, - SIGGRAPHGenerator) +from .inst_colorization import InstColorization +from .inst_colorization_net import (FusionGenerator, InstanceGenerator, + SIGGRAPHGenerator) __all__ = [ - 'INSTA', 'SIGGRAPHGenerator', 'InstanceGenerator', 'FusionGenerator' + 'InstColorization', 'SIGGRAPHGenerator', 'InstanceGenerator', + 'FusionGenerator' ] diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py index 3a34785f55..152e880c71 100644 --- a/mmedit/models/editors/inst_colorization/inst_colorization.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -6,14 +6,14 @@ import torch from mmengine.config import Config -from mmedit.models.editors import SRGAN -from mmedit.models.utils import generation_init_weights +from mmedit.models.utils import (encode_ab_ind, generation_init_weights, + get_colorization_data, lab2rgb) from mmedit.registry import BACKBONES, COMPONENTS -from .util import encode_ab_ind, get_colorization_data, lab2rgb +from ..srgan import SRGAN @BACKBONES.register_module() -class INSTA(SRGAN): +class InstColorization(SRGAN): def __init__(self, data_preprocessor: Union[dict, Config], @@ -34,7 +34,7 @@ def __init__(self, init_cfg=None, train_cfg=None, test_cfg=None): - super(INSTA, self).__init__( + super(InstColorization, self).__init__( data_preprocessor=data_preprocessor, loss=loss, init_cfg=init_cfg, diff --git a/mmedit/models/losses/pixelwise_loss.py b/mmedit/models/losses/pixelwise_loss.py index f3c77675aa..ad41fcd731 100644 --- a/mmedit/models/losses/pixelwise_loss.py +++ b/mmedit/models/losses/pixelwise_loss.py @@ -221,20 +221,3 @@ def forward(self, pred, mask=None): loss = x_diff + y_diff return loss - - -@LOSSES.register_module() -class HuberLoss(nn.Module): - - def __init__(self, delta=.01): - super(HuberLoss, self).__init__() - self.delta = delta - - def __call__(self, in0, in1): - mask = torch.zeros_like(in0) - mann = torch.abs(in0 - in1) - eucl = .5 * (mann**2) - mask[...] = mann < self.delta - - loss = eucl * mask / self.delta + (mann - .5 * self.delta) * (1 - mask) - return torch.sum(loss, dim=1, keepdim=True) diff --git a/mmedit/models/utils/__init__.py b/mmedit/models/utils/__init__.py index b579869d60..4f6c6e1baa 100644 --- a/mmedit/models/utils/__init__.py +++ b/mmedit/models/utils/__init__.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. + from .bbox_utils import extract_around_bbox, extract_bbox_patch +from .color_utils import encode_ab_ind, get_colorization_data, lab2rgb from .flow_warp import flow_warp from .model_utils import (default_init_weights, generation_init_weights, get_module_device, get_valid_noise_size, @@ -8,17 +10,9 @@ from .tensor_utils import get_unknown_tensor __all__ = [ - 'default_init_weights', - 'make_layer', - 'flow_warp', - 'generation_init_weights', - 'set_requires_grad', - 'extract_bbox_patch', - 'extract_around_bbox', - 'get_unknown_tensor', - 'noise_sample_fn', - 'label_sample_fn', - 'get_valid_num_batches', - 'get_valid_noise_size', - 'get_module_device', + 'default_init_weights', 'make_layer', 'flow_warp', + 'generation_init_weights', 'set_requires_grad', 'extract_bbox_patch', + 'extract_around_bbox', 'get_unknown_tensor', 'noise_sample_fn', + 'label_sample_fn', 'get_valid_num_batches', 'get_valid_noise_size', + 'get_module_device', 'encode_ab_ind', 'get_colorization_data', 'lab2rgb' ] From e22a343d3c0a8c18be34e1eb1cae596604958031 Mon Sep 17 00:00:00 2001 From: ruoning Date: Wed, 12 Oct 2022 18:27:30 +0800 Subject: [PATCH 05/32] [Enhancement]: add inference module for instance-aware Image Colorization --- configs/insta/insta_full_cocostuff_256x256.py | 37 ++++---- configs/insta/insta_inference.py | 54 +++++++++++ demo/colorization_demo.py | 3 +- mmedit/apis/colorization_inference.py | 4 +- mmedit/datasets/__init__.py | 2 + mmedit/datasets/coco.py | 95 +++++++------------ mmedit/datasets/transforms/formatting.py | 12 ++- .../datasets/transforms/get_gray_color_pil.py | 2 +- .../datasets/transforms/get_maskrcnn_bbox.py | 19 ++-- .../models/base_models/base_colorization.py | 1 + mmedit/models/editors/insta/insta.py | 62 ++++++------ mmedit/models/editors/insta/util.py | 2 +- 12 files changed, 169 insertions(+), 124 deletions(-) create mode 100644 configs/insta/insta_inference.py diff --git a/configs/insta/insta_full_cocostuff_256x256.py b/configs/insta/insta_full_cocostuff_256x256.py index e25f2836c9..f30c272005 100644 --- a/configs/insta/insta_full_cocostuff_256x256.py +++ b/configs/insta/insta_full_cocostuff_256x256.py @@ -7,7 +7,7 @@ work_dir = '..' model = dict( - type='FusionModel', + type='INSTA', data_preprocessor=dict( type='EditDataPreprocessor', mean=[127.5], @@ -19,25 +19,25 @@ output_nc=2, norm_type='batch' ), - stage='full', + insta_stage='full', ngf=64, output_nc=2, avg_loss_alpha=.986, ab_norm=110., + ab_max=110., + ab_quant=10., l_norm=100., l_cent=50., sample_Ps=[1, 2, 3, 4, 5, 6, 7, 8, 9], mask_cent=.5, - init_type='normal', which_direction='AtoB', loss=dict(type='HuberLoss', delta=.01), - pretrained='./checkpoints/pytorch_trained.pth' ) input_shape = (256, 256) train_pipeline = [ - dict(type='LoadImageFromFile', key='gt_img'), + dict(type='LoadImageFromFile', key='img'), dict(type='GenGrayColorPil', stage='full', keys=['rgb_img', 'gray_img']), dict( type='Resize', @@ -50,10 +50,10 @@ ] test_pipeline = [ - dict(type='LoadImageFromFile', key='gt'), + dict(type='LoadImageFromFile', key='img'), dict(type='GenMaskRCNNBbox', stage='test_fusion', finesize=256), dict(type='Resize', - keys=['gt'], + keys=['img'], scale=(256, 256), keep_ratio=False ), @@ -61,18 +61,17 @@ ] dataset_type = 'CocoDataset' -data_root = '/mnt/j/DataSet/cocostuff/train2017' -ann_file_path = '/mnt/j/DataSet/cocostuff/' +data_root = '/mnt/j/DataSet/cocostuff' +ann_file_path = '/mnt/j/DataSet/cocostuff' train_dataloader = dict( batch_size=4, num_workers=4, persistent_workers=False, sampler=dict(shuffle=False), - workers_per_gpu=1, dataset=dict( type=dataset_type, - data_root=data_root, + data_root=data_root + '/train2017', data_prefix=dict(gt='data_large'), ann_file=f'{ann_file_path}/img_list.txt', pipeline=train_pipeline, @@ -80,18 +79,16 @@ test_dataloader = dict( batch_size=1, - num_workers=4, + num_workers=1, persistent_workers=False, sampler=dict(type='DefaultSampler', shuffle=False), - workers_per_gpu=1, dataset=dict( type=dataset_type, - data_root=data_root, + data_root=data_root + '/train2017', data_prefix=dict(gt='data_large'), - ann_file=f'{ann_file_path}/img_list.txt', + ann_file=f'{ann_file_path}/train_annotation.json', pipeline=test_pipeline, - test_mode=False)) - + test_mode=True)) test_evaluator = [dict(type='PSNR'), dict(type='SSIM')] @@ -101,10 +98,12 @@ val_interval=50000, ) +val_dataloader = test_dataloader +val_evaluator = test_evaluator + val_cfg = dict(type='ValLoop') test_cfg = dict(type='TestLoop') - # optimizer optim_wrapper = dict( constructor='DefaultOptimWrapperConstructor', @@ -137,5 +136,3 @@ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), dist_cfg=dict(backend='nccl'), ) - - diff --git a/configs/insta/insta_inference.py b/configs/insta/insta_inference.py new file mode 100644 index 0000000000..4e3972cd33 --- /dev/null +++ b/configs/insta/insta_inference.py @@ -0,0 +1,54 @@ +_base_ = [ + '../_base_/default_runtime.py' +] + +exp_name = 'Instance-aware_full' +save_dir = './' +work_dir = '..' + +model = dict( + type='INSTA', + data_preprocessor=dict( + type='EditDataPreprocessor', + mean=[127.5], + std=[127.5], + ), + instance_model=dict( + type='SIGGRAPHGenerator', + input_nc=4, + output_nc=2, + norm_type='batch' + ), + fusion_model=dict( + type='FusionGenerator', + input_nc=4, + output_nc=2, + norm_type='batch' + ), + insta_stage='test', + ngf=64, + output_nc=2, + avg_loss_alpha=.986, + ab_norm=110., + ab_max=110., + ab_quant=10., + l_norm=100., + l_cent=50., + sample_Ps=[1, 2, 3, 4, 5, 6, 7, 8, 9], + mask_cent=.5, + which_direction='AtoB', + loss=dict(type='HuberLoss', delta=.01), +) + +input_shape = (256, 256) + +test_pipeline = [ + dict(type='LoadImageFromFile', key='img'), + dict(type='GenMaskRCNNBbox', stage='test_fusion', finesize=256), + dict(type='Resize', + keys=['img'], + scale=(256, 256), + keep_ratio=False + ), + dict(type='PackEditInputs'), +] \ No newline at end of file diff --git a/demo/colorization_demo.py b/demo/colorization_demo.py index 1654895be2..e9096be308 100644 --- a/demo/colorization_demo.py +++ b/demo/colorization_demo.py @@ -14,7 +14,6 @@ def parse_args(): parser.add_argument('config', help='test config file path') parser.add_argument('checkpoints', help='checkpoints file path') parser.add_argument('img_path', help='path to input image file') - # parser.add_argument('bbox_path', help='path to input image bbox file') parser.add_argument('save_path', help='path to save generation result') parser.add_argument( '--unpaired-path', default=None, help='path to unpaired image file') @@ -35,7 +34,7 @@ def main(): # model = init_model(args.config, args.checkpoints, device=device) - output = colorization_inference(model, args.img_path, args.bbox_path) + output = colorization_inference(model, args.img_path) if args.imshow: mmcv.imshow(output, 'predicted generation result') diff --git a/mmedit/apis/colorization_inference.py b/mmedit/apis/colorization_inference.py index 8f1b05da67..4cd5c53440 100644 --- a/mmedit/apis/colorization_inference.py +++ b/mmedit/apis/colorization_inference.py @@ -6,14 +6,14 @@ from torch.nn.parallel import scatter -def colorization_inference(model, img, bbox): +def colorization_inference(model, img): device = next(model.parameters()).device # build the data pipeline test_pipeline = Compose(model.cfg.test_pipeline) # prepare data - data = dict(gt_path=img, bbox_path=bbox) + data = dict(img_path=img) data = test_pipeline(data) data = collate([data]) diff --git a/mmedit/datasets/__init__.py b/mmedit/datasets/__init__.py index 0744816ca6..850f7787b7 100644 --- a/mmedit/datasets/__init__.py +++ b/mmedit/datasets/__init__.py @@ -8,6 +8,7 @@ from .imagenet_dataset import ImageNet from .paired_image_dataset import PairedImageDataset from .unpaired_image_dataset import UnpairedImageDataset +from .coco import CocoDataset __all__ = [ 'AdobeComp1kDataset', @@ -19,4 +20,5 @@ 'ImageNet', 'CIFAR10', 'GrowScaleImgDataset', + 'CocoDataset' ] diff --git a/mmedit/datasets/coco.py b/mmedit/datasets/coco.py index 4390c43875..7296a0d506 100644 --- a/mmedit/datasets/coco.py +++ b/mmedit/datasets/coco.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp from typing import List, Callable, Optional, Union from pathlib import Path +from typing import List, Union, Dict from mmengine.dataset import BaseDataset from mmengine.fileio import load @@ -9,76 +11,43 @@ @DATASETS.register_module() -class CocoDataset: +class CocoDataset(BaseDataset): """Dataset for COCO.""" METAINFO = { - 'CLASSES': - ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', - 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', - 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', - 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', - 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', - 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', - 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', - 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', - 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', - 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', - 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', - 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', - 'scissors', 'teddy bear', 'hair drier', 'toothbrush'), - # PALETTE is a list of color tuples, which is used for visualization. - 'PALETTE': - [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), - (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30), - (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), - (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255), - (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255), - (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92), - (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164), - (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0), - (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174), - (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54), - (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51), - (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65), - (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0), - (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161), - (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120), - (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133), - (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62), - (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45), - (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1), - (246, 0, 122), (191, 162, 208)] + 'dataset_type': 'colorization_dataset', + 'task_name': 'colorization', } - METAINFO = dict(dataset_type='colorization_dataset', task_name='colorization') + def load_data_list(self) -> List[dict]: - def __init__( - self, - ann_file: str, - data_prefix - ): + annotations = load(self.ann_file) - self.ann_file = str(ann_file) - self.data_prefix = data_prefix - self.data_infos = self.load_annotations() + assert annotations, f'annotation file "{self.ann_file}" is empty.' - def load_annotations(self): - """Load annotations for dataset. + metainfo = annotations['metainfo'] + raw_data_list = annotations['data_list'] - Returns: - list[dict]: Contain dataset annotations. - """ - with open(self.ann_file, 'r') as f: - img_infos = [] - for idx, line in enumerate(f): - line = line.strip() - _info = dict() - img_path = line.split(' ')[0].split('/')[1] - _info = dict( - gt_img_path=Path( - self.data_prefix).joinpath(img_path).as_posix(), - gt_img_idx=idx) - img_infos.append(_info) + for k, v in metainfo.items(): + self._metainfo.setdefault(k, v) - return img_infos + data_list = [] + for raw_data_info in raw_data_list: + data_info = self.parse_data_info(raw_data_info) + if isinstance(data_info, dict): + data_list.append(data_info) + else: + raise TypeError('data_info should be a dict or list of dict, ' + f'but got {type(data_info)}') + + return data_list + + def parse_data_info(self, raw_data_info: dict) -> dict: + """Join data_root to each path in data_info.""" + + data_info = raw_data_info.copy() + for key in raw_data_info: + if 'path' in key: + data_info['gt_img_path'] = osp.join(self.data_root, data_info[key]) + + return data_info diff --git a/mmedit/datasets/transforms/formatting.py b/mmedit/datasets/transforms/formatting.py index f271c31bca..727ef2fc2f 100644 --- a/mmedit/datasets/transforms/formatting.py +++ b/mmedit/datasets/transforms/formatting.py @@ -200,6 +200,16 @@ def transform(self, results: dict) -> dict: gt_bg = results.pop('bg') gt_bg_tensor = images_to_tensor(gt_bg) data_sample.gt_bg = PixelData(data=gt_bg_tensor) + + if 'rgb_img' in results: + gt_rgb = results.pop('rgb_img') + gt_rgb_tensor = images_to_tensor(gt_rgb) + data_sample.gt_rgb = PixelData(data=gt_rgb_tensor) + + if 'gray_img' in results: + gray = results.pop('gray_img') + gray_tensor = images_to_tensor(gray) + data_sample.gray = PixelData(data=gray_tensor) metainfo = dict() for key in results: @@ -233,7 +243,7 @@ def __init__(self, keys, to_float32=True): self.keys = keys self.to_float32 = to_float32 - + def _data_to_tensor(self, value): """Convert the value to tensor.""" is_image = check_if_image(value) diff --git a/mmedit/datasets/transforms/get_gray_color_pil.py b/mmedit/datasets/transforms/get_gray_color_pil.py index 91dd763c35..3c59f72ebe 100644 --- a/mmedit/datasets/transforms/get_gray_color_pil.py +++ b/mmedit/datasets/transforms/get_gray_color_pil.py @@ -17,7 +17,7 @@ def transform(self, results): if self.stage == 'instance': rgb_img = results['instance'] else: - rgb_img = results['gt_img'] + rgb_img = results['img'] if len(rgb_img.shape) == 2: rgb_img = np.stack([rgb_img, rgb_img, rgb_img], 2) gray_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY) diff --git a/mmedit/datasets/transforms/get_maskrcnn_bbox.py b/mmedit/datasets/transforms/get_maskrcnn_bbox.py index 493caf51ab..f9f360b299 100644 --- a/mmedit/datasets/transforms/get_maskrcnn_bbox.py +++ b/mmedit/datasets/transforms/get_maskrcnn_bbox.py @@ -17,7 +17,7 @@ @TRANSFORMS.register_module() class GenMaskRCNNBbox: - def __init__(self, key='gt', stage='test_fusion', finesize=256): + def __init__(self, key='img', stage='test_fusion', finesize=256): self.key = key self.predictor = self.detectron() self.stage = stage @@ -121,13 +121,16 @@ def get_box_info(pred_bbox, original_shape, final_size): return [L_pad, R_pad, T_pad, B_pad, rh, rw] def test_fusion(self, results): - img = results['gt'] + img = results['img'] pil_img = self.read_to_pil(img) - if results['bbox_path']: + + if 'bbox_path' in results.keys(): pred_bbox = self.gen_maskrcnn_bbox_fromPred( - img, results['bbox_path'], box_num_upbound=8) + img, results['bbox_path']) + elif 'instance' in results.keys(): + pred_bbox = results['instance'] else: - pred_bbox = self.gen_maskrcnn_bbox_fromPred(img, box_num_upbound=8) + pred_bbox = self.gen_maskrcnn_bbox_fromPred(img) img_list = [self.transforms(pil_img)] # 这里删除了一个transform @@ -172,11 +175,15 @@ def test_fusion(self, results): def train(self, results): img = results[self.key] - if results['bbox_path']: + + if 'bbox_path' in results.keys(): pred_bbox = self.gen_maskrcnn_bbox_fromPred( img, results['bbox_path']) + elif 'instance' in results.keys(): + pred_bbox = results['instance'][0]['bbox'] else: pred_bbox = self.gen_maskrcnn_bbox_fromPred(img) + rgb_img, gray_img = self.gen_gray_color_pil(img) index_list = range(len(pred_bbox)) index_list = sample(index_list, 1) diff --git a/mmedit/models/base_models/base_colorization.py b/mmedit/models/base_models/base_colorization.py index 37183bc18f..d4cf89a8de 100644 --- a/mmedit/models/base_models/base_colorization.py +++ b/mmedit/models/base_models/base_colorization.py @@ -7,6 +7,7 @@ from mmengine.model import BaseModel from mmengine.config import Config, ConfigDict +from mmedit.structures import EditDataSample, PixelData from mmedit.registry import MODELS diff --git a/mmedit/models/editors/insta/insta.py b/mmedit/models/editors/insta/insta.py index 180dcd5e4e..499ce21d64 100644 --- a/mmedit/models/editors/insta/insta.py +++ b/mmedit/models/editors/insta/insta.py @@ -22,11 +22,13 @@ def __init__(self, output_nc, avg_loss_alpha, ab_norm, + ab_max, + ab_quant, l_norm, l_cent, sample_Ps, mask_cent, - stage=None, + insta_stage=None, which_direction='AtoB', instance_model=None, full_model=None, @@ -47,22 +49,23 @@ def __init__(self, self.output_nc = output_nc self.avg_loss_alpha = avg_loss_alpha self.ab_norm = ab_norm + self.ab_max = ab_max + self.ab_quant = ab_quant self.l_norm = l_norm self.l_cent = l_cent self.sample_Ps = sample_Ps self.mask_cent = mask_cent self.which_direction = which_direction - self.device = torch.device('cuda:{}'.format( - self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + self.device = torch.device('cuda:{}'.format(0)) self.instance_model = instance_model self.full_model = full_model self.fusion_model = fusion_model - self.stage = stage + self.insta_stage = insta_stage - if self.stage is not None: + if self.insta_stage == 'full' or self.insta_stage == 'instance': self.training = False self.setup_to_train() else: @@ -116,7 +119,7 @@ def set_forward_without_box(self, input): self.fake_B_reg = self.comp_B_reg def generator_loss(self): - if self.stage == 'full' or self.stage == 'instance': + if self.insta_stage == 'full' or self.insta_stage == 'instance': self.loss_L1 = torch.mean( self.criterionL1( self.fake_B_reg.type(torch.cuda.FloatTensor), @@ -126,7 +129,7 @@ def generator_loss(self): self.fake_B_reg.type(torch.cuda.FloatTensor), self.real_B.type(torch.cuda.FloatTensor))) - elif self.stage == 'fusion': + elif self.insta_stage == 'fusion': self.loss_L1 = torch.mean( self.criterionL1( self.fake_B_reg.type(torch.cuda.FloatTensor), @@ -166,7 +169,7 @@ def train_step(self, data_batch, optimizer): mask_cent=self.mask_cent, ) - if self.stage == 'full' or self.stage == 'instance': + if self.insta_stage == 'full' or self.insta_stage == 'instance': data_batch['rgb_img'] = [data_batch['rgb_img']] data_batch['gray_img'] = [data_batch['gray_img']] @@ -183,7 +186,7 @@ def train_step(self, data_batch, optimizer): (_, self.fake_B_reg) = self.netG(self.real_A, self.hint_B, self.mask_B) - elif self.stage == 'fusion': + elif self.insta_stage == 'fusion': data_batch['cropped_rgb'] = torch.stack( data_batch['cropped_rgb_list']) @@ -254,13 +257,13 @@ def setup_to_train(self): self.loss_names = ['G', 'L1'] - if self.stage == 'full' or self.stage == 'instance': + if self.insta_stage == 'full' or self.insta_stage == 'instance': self.model_names = ['G'] self.netG = COMPONENTS.build(self.instance_model) generation_init_weights(self.netG) self.generator = self.netG - elif self.stage == 'fusion': + elif self.insta_stage == 'fusion': self.model_names = ['G', 'GF', 'GComp'] self.netG = COMPONENTS.build(self.instance_model) generation_init_weights(self.netG) @@ -292,8 +295,9 @@ def setup_to_train(self): list(self.netGF.module.model_out.parameters()) else: - print('Error Stage!') - exit() + # print('Error Stage!') + # exit() + pass self.criterionL1 = self.loss @@ -309,7 +313,7 @@ def get_current_visuals(self): visual_ret = OrderedDict() opt = dict( ab_norm=self.ab_norm, l_norm=self.l_norm, l_cent=self.l_cent) - if self.stage == 'full' or self.stage == 'instance': + if self.insta_stage == 'full' or self.insta_stage == 'instance': visual_ret['gray'] = lab2rgb( torch.cat((self.real_A.type( @@ -340,7 +344,7 @@ def get_current_visuals(self): self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), **opt) - elif self.stage == 'fusion': + elif self.insta_stage == 'fusion': visual_ret['gray'] = lab2rgb( torch.cat((self.full_real_A.type( torch.cuda.FloatTensor), torch.zeros_like( @@ -389,17 +393,19 @@ def get_current_visuals(self): exit() return visual_ret - def forward_test(self, **kwargs): + def forward_test(self, inputs, data_samples, **kwargs): + output = dict() - kwargs['full_img'][0] = kwargs['full_img'][0].cuda() - if not kwargs['empty_box']: - kwargs['cropped_img'][0] = kwargs['cropped_img'][0].cuda() - box_info = kwargs['box_info'][0] - box_info_2x = kwargs['box_info_2x'][0] - box_info_4x = kwargs['box_info_4x'][0] - box_info_8x = kwargs['box_info_8x'][0] + data = data_samples[0] + full_img= data.full_img + if not data.empty_box: + cropped_img = data.cropped_img + box_info = data.box_info + box_info_2x = data.box_info_2x + box_info_4x = data.box_info_4x + box_info_8x = data.box_info_8x cropped_data = get_colorization_data( - kwargs['cropped_img'], + cropped_img, ab_thresh=0, ab_norm=self.ab_norm, l_norm=self.l_norm, @@ -408,7 +414,7 @@ def forward_test(self, **kwargs): mask_cent=self.mask_cent, ) full_img_data = get_colorization_data( - kwargs['full_img'], + full_img, ab_thresh=0, ab_norm=self.ab_norm, l_norm=self.l_norm, @@ -422,7 +428,7 @@ def forward_test(self, **kwargs): [box_info, box_info_2x, box_info_4x, box_info_8x]) else: full_img_data = get_colorization_data( - kwargs['full_img'], ab_thresh=0) + full_img, ab_thresh=0) self.set_forward_without_box(full_img_data) (_, feature_map) = self.netG(self.real_A, self.hint_B, self.mask_B) @@ -448,7 +454,7 @@ def forward_test(self, **kwargs): def setup_to_test(self): self.netG = COMPONENTS.build(self.instance_model) - generation_init_weights(self.netG, self.init_type) + generation_init_weights(self.netG) self.netGF = COMPONENTS.build(self.fusion_model) - generation_init_weights(self.netGF, self.init_type) + generation_init_weights(self.netGF) diff --git a/mmedit/models/editors/insta/util.py b/mmedit/models/editors/insta/util.py index 9e859f3aa9..be0de4aa1d 100644 --- a/mmedit/models/editors/insta/util.py +++ b/mmedit/models/editors/insta/util.py @@ -153,7 +153,7 @@ def get_colorization_data(data_raw, **kwargs): data = {} - data_lab = rgb2lab(data_raw[0], **kwargs) + data_lab = rgb2lab(data_raw, **kwargs) data['A'] = data_lab[:, [ 0, ], :, :] From 1aaf2c2e4061d90c11a3570d8dbffd32efb62b20 Mon Sep 17 00:00:00 2001 From: ruoning Date: Sat, 15 Oct 2022 20:56:33 +0800 Subject: [PATCH 06/32] [Fix]: fix inference module of Instance-aware Image Colorization --- configs/inst_colorization/README.md | 12 +- configs/inst_colorization/README_zh-CN.md | 12 +- .../inst-colorizatioon_cocostuff_256x256.py | 114 ++------ ...t-colorizatioon_cocostuff_full_256x256.py} | 37 ++- ...colorizatioon_cocostuff_fusion_256x256.py} | 0 ...lorizatioon_cocostuff_instance_256x256.py} | 0 demo/colorization_demo.py | 7 +- mmedit/apis/colorization_inference.py | 4 +- mmedit/datasets/coco.py | 96 +++---- .../datasets/transforms/get_maskrcnn_bbox.py | 2 +- .../editors/inst_colorization/__init__.py | 3 +- .../inst_colorization/inst_colorization.py | 258 +++++++----------- .../inst_colorization_generator.py | 99 +++++++ 13 files changed, 297 insertions(+), 347 deletions(-) rename configs/inst_colorization/{insta_full_cocostuff_256x256.py => inst-colorizatioon_cocostuff_full_256x256.py} (79%) rename configs/inst_colorization/{insta_fusion_cocostuff_256x256.py => inst-colorizatioon_cocostuff_fusion_256x256.py} (100%) rename configs/inst_colorization/{insta_instance_cocostuff_256x256.py => inst-colorizatioon_cocostuff_instance_256x256.py} (100%) create mode 100644 mmedit/models/editors/inst_colorization/inst_colorization_generator.py diff --git a/configs/inst_colorization/README.md b/configs/inst_colorization/README.md index a21df6a724..ab147570e7 100644 --- a/configs/inst_colorization/README.md +++ b/configs/inst_colorization/README.md @@ -25,13 +25,13 @@ You can use the following commands to train a model with cpu or single/multiple ```shell # CPU train -CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/inst_colorization/insta_full_cocostuff_256x256.py +CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py # single-gpu train -python tools/train.py configs/inst_colorization/insta_full_cocostuff_256x256.py +python tools/train.py configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py # multi-gpu train -./tools/dist_train.sh configs/inst_colorization/insta_full_cocostuff_256x256.py 8 +./tools/dist_train.sh configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py 8 ``` For more details, you can refer to **Train a model** part in [train_test.md](/docs/en/user_guides/train_test.md#Train-a-model-in-MMEditing). @@ -47,13 +47,13 @@ You can use the following commands to test a model with cpu or single/multiple G ```shell # CPU test -CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/inst_colorization//insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth +CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/inst_colorization//inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth # single-gpu test -python demo/colorization_demo.py configs/inst_colorization/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth +python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth # multi-gpu test -./tools/dist_test.sh configs/inst_colorization/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth 8 +./tools/dist_test.sh configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth 8 ``` For more details, you can refer to **Test a pre-trained model** part in [train_test.md](/docs/en/user_guides/train_test.md#Test-a-pre-trained-model-in-MMEditing). diff --git a/configs/inst_colorization/README_zh-CN.md b/configs/inst_colorization/README_zh-CN.md index 157966bb44..752872abcb 100644 --- a/configs/inst_colorization/README_zh-CN.md +++ b/configs/inst_colorization/README_zh-CN.md @@ -25,13 +25,13 @@ Image colorization is inherently an ill-posed problem with multi-modal uncertain ```shell # CPU上训练 -CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/insta/insta_full_cocostuff_256x256.py +CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/insta/inst-colorizatioon_cocostuff_full_256x256.py # 单个GPU上训练 -python tools/train.py configs/insta/insta_full_cocostuff_256x256.py +python tools/train.py configs/insta/inst-colorizatioon_cocostuff_full_256x256.py # 多个GPU上训练 -./tools/dist_train.sh configs/insta/insta_full_cocostuff_256x256.py 8 +./tools/dist_train.sh configs/insta/inst-colorizatioon_cocostuff_full_256x256.py 8 ``` 更多细节可以参考 [train_test.md](/docs/zh_cn/user_guides/train_test.md) 中的 **Train a model** 部分。 @@ -47,13 +47,13 @@ python tools/train.py configs/insta/insta_full_cocostuff_256x256.py ```shell # CPU上测试 -CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/insta/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth +CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/insta/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth # 单个GPU上测试 -python demo/colorization_demo.py configs/insta/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth +python demo/colorization_demo.py configs/insta/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth # 多个GPU上测试 -./tools/dist_test.sh configs/insta/insta_full_cocostuff_256x256.py ../checkpoints/instance_aware_cocostuff.pth 8 +./tools/dist_test.sh configs/insta/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth 8 ``` 更多细节可以参考 [train_test.md](/docs/zh_cn/user_guides/train_test.md) 中的 **Test a pre-trained model** 部分。 diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py index fa24f2cdab..a96830c1e3 100644 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py @@ -4,121 +4,39 @@ save_dir = './' work_dir = '..' +stage = 'test' model = dict( - type='FusionModel', + type='InstColorization', data_preprocessor=dict( type='EditDataPreprocessor', mean=[127.5], std=[127.5], ), - instance_model=dict( - type='SIGGRAPHGenerator', input_nc=4, output_nc=2, norm_type='batch'), - stage='full', + generator=dict( + type='InstColorizationGenerator', + stage=stage, + instance_model=dict( + type='InstanceGenerator', input_nc=4, output_nc=2, norm_type='batch'), + fusion_model=dict( + type='FusionGenerator', input_nc=4, output_nc=2, norm_type='batch') + ), + insta_stage=stage, ngf=64, output_nc=2, avg_loss_alpha=.986, ab_norm=110., + ab_max=110., + ab_quant=10., l_norm=100., l_cent=50., sample_Ps=[1, 2, 3, 4, 5, 6, 7, 8, 9], mask_cent=.5, - init_type='normal', which_direction='AtoB', - loss=dict(type='HuberLoss', delta=.01), - pretrained='./checkpoints/pytorch_trained.pth') - -input_shape = (256, 256) - -train_pipeline = [ - dict(type='LoadImageFromFile', key='gt_img'), - dict(type='GenGrayColorPil', stage='full', keys=['rgb_img', 'gray_img']), - dict( - type='Resize', - keys=['rgb_img', 'gray_img'], - scale=input_shape, - keep_ratio=False, - interpolation='nearest'), - dict(type='RescaleToZeroOne', keys=['rgb_img', 'gray_img']), - dict(type='PackEditInputs') -] + loss=dict(type='HuberLoss', delta=.01)) test_pipeline = [ - dict(type='LoadImageFromFile', key='gt'), + dict(type='LoadImageFromFile', key='img'), dict(type='GenMaskRCNNBbox', stage='test_fusion', finesize=256), - dict(type='Resize', keys=['gt'], scale=(256, 256), keep_ratio=False), + dict(type='Resize', keys=['img'], scale=(256, 256), keep_ratio=False), dict(type='PackEditInputs'), ] - -dataset_type = 'CocoDataset' -data_root = '/mnt/j/DataSet/cocostuff/train2017' -ann_file_path = '/mnt/j/DataSet/cocostuff/' - -train_dataloader = dict( - batch_size=4, - num_workers=4, - persistent_workers=False, - sampler=dict(shuffle=False), - workers_per_gpu=1, - dataset=dict( - type=dataset_type, - data_root=data_root, - data_prefix=dict(gt='data_large'), - ann_file=f'{ann_file_path}/img_list.txt', - pipeline=train_pipeline, - test_mode=False)) - -test_dataloader = dict( - batch_size=1, - num_workers=4, - persistent_workers=False, - sampler=dict(type='DefaultSampler', shuffle=False), - workers_per_gpu=1, - dataset=dict( - type=dataset_type, - data_root=data_root, - data_prefix=dict(gt='data_large'), - ann_file=f'{ann_file_path}/img_list.txt', - pipeline=test_pipeline, - test_mode=False)) - -test_evaluator = [dict(type='PSNR'), dict(type='SSIM')] - -train_cfg = dict( - type='IterBasedTrainLoop', - max_iters=500002, - val_interval=50000, -) - -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') - -# optimizer -optim_wrapper = dict( - constructor='DefaultOptimWrapperConstructor', - generator=dict( - type='OptimWrapper', - optimizer=dict(type='Adam', lr=0.0001, betas=(0.0, 0.9))), - disc=dict( - type='OptimWrapper', - optimizer=dict(type='Adam', lr=0.0001, betas=(0.0, 0.9)))) - -param_scheduler = dict( - # todo engine中暂时还没有这个 - type='LambdaLR', - by_epoch=False, -) - -vis_backends = [dict(type='LocalVisBackend')] - -visualizer = dict( - type='ConcatImageVisualizer', - vis_backends=vis_backends, - fn_key='gt_path', - img_keys=['gray', 'real', 'fake_reg', 'hint', 'real_ab', 'fake_ab_reg'], - bgr2rgb=False) - -env_cfg = dict( - cudnn_benchmark=False, - mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), - dist_cfg=dict(backend='nccl'), -) diff --git a/configs/inst_colorization/insta_full_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py similarity index 79% rename from configs/inst_colorization/insta_full_cocostuff_256x256.py rename to configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py index fa24f2cdab..6e6be6db6b 100644 --- a/configs/inst_colorization/insta_full_cocostuff_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py @@ -4,25 +4,31 @@ save_dir = './' work_dir = '..' +stage = 'full' model = dict( - type='FusionModel', + type='INSTA', data_preprocessor=dict( type='EditDataPreprocessor', mean=[127.5], std=[127.5], ), - instance_model=dict( - type='SIGGRAPHGenerator', input_nc=4, output_nc=2, norm_type='batch'), - stage='full', + generator=dict( + type='InstColorizationGenerator', + stage=stage, + instance_model=dict( + type='SIGGRAPHGenerator', input_nc=4, output_nc=2, norm_type='batch'), + ), + insta_stage=stage, ngf=64, output_nc=2, avg_loss_alpha=.986, ab_norm=110., + ab_max=110., + ab_quant=10., l_norm=100., l_cent=50., sample_Ps=[1, 2, 3, 4, 5, 6, 7, 8, 9], mask_cent=.5, - init_type='normal', which_direction='AtoB', loss=dict(type='HuberLoss', delta=.01), pretrained='./checkpoints/pytorch_trained.pth') @@ -30,7 +36,7 @@ input_shape = (256, 256) train_pipeline = [ - dict(type='LoadImageFromFile', key='gt_img'), + dict(type='LoadImageFromFile', key='img'), dict(type='GenGrayColorPil', stage='full', keys=['rgb_img', 'gray_img']), dict( type='Resize', @@ -43,25 +49,24 @@ ] test_pipeline = [ - dict(type='LoadImageFromFile', key='gt'), + dict(type='LoadImageFromFile', key='img'), dict(type='GenMaskRCNNBbox', stage='test_fusion', finesize=256), dict(type='Resize', keys=['gt'], scale=(256, 256), keep_ratio=False), dict(type='PackEditInputs'), ] dataset_type = 'CocoDataset' -data_root = '/mnt/j/DataSet/cocostuff/train2017' -ann_file_path = '/mnt/j/DataSet/cocostuff/' +data_root = '/mnt/j/DataSet/cocostuff' +ann_file_path = '/mnt/j/DataSet/cocostuff' train_dataloader = dict( batch_size=4, num_workers=4, persistent_workers=False, sampler=dict(shuffle=False), - workers_per_gpu=1, dataset=dict( type=dataset_type, - data_root=data_root, + data_root=data_root + '/train2017', data_prefix=dict(gt='data_large'), ann_file=f'{ann_file_path}/img_list.txt', pipeline=train_pipeline, @@ -69,15 +74,14 @@ test_dataloader = dict( batch_size=1, - num_workers=4, + num_workers=1, persistent_workers=False, sampler=dict(type='DefaultSampler', shuffle=False), - workers_per_gpu=1, dataset=dict( type=dataset_type, - data_root=data_root, + data_root=data_root + '/train2017', data_prefix=dict(gt='data_large'), - ann_file=f'{ann_file_path}/img_list.txt', + ann_file=f'{ann_file_path}/train_annotation.json', pipeline=test_pipeline, test_mode=False)) @@ -89,6 +93,9 @@ val_interval=50000, ) +val_dataloader = test_dataloader +val_evaluator = test_evaluator + val_cfg = dict(type='ValLoop') test_cfg = dict(type='TestLoop') diff --git a/configs/inst_colorization/insta_fusion_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_fusion_256x256.py similarity index 100% rename from configs/inst_colorization/insta_fusion_cocostuff_256x256.py rename to configs/inst_colorization/inst-colorizatioon_cocostuff_fusion_256x256.py diff --git a/configs/inst_colorization/insta_instance_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_instance_256x256.py similarity index 100% rename from configs/inst_colorization/insta_instance_cocostuff_256x256.py rename to configs/inst_colorization/inst-colorizatioon_cocostuff_instance_256x256.py diff --git a/demo/colorization_demo.py b/demo/colorization_demo.py index 2cc8b8f7b7..cbfa00fe66 100644 --- a/demo/colorization_demo.py +++ b/demo/colorization_demo.py @@ -5,7 +5,7 @@ import torch from mmedit.apis import colorization_inference, init_model -from mmedit.utils import modify_args +from mmedit.utils import modify_args, tensor2img def parse_args(): @@ -34,8 +34,9 @@ def main(): # model = init_model(args.config, args.checkpoints, device=device) - output = colorization_inference(model, args.img_path, args.bbox_path) - mmcv.imwrite(output, args.save_path) + output = colorization_inference(model, args.img_path) + result = tensor2img(output)[..., ::-1] + mmcv.imwrite(result, args.save_path) if args.imshow: mmcv.imshow(output, 'predicted generation result') diff --git a/mmedit/apis/colorization_inference.py b/mmedit/apis/colorization_inference.py index 1cf0f9676c..352d32f6ce 100644 --- a/mmedit/apis/colorization_inference.py +++ b/mmedit/apis/colorization_inference.py @@ -20,6 +20,6 @@ def colorization_inference(model, img): data = scatter(data, [device])[0] # forward the model with torch.no_grad(): - result = model(mode='predict', **data) + result = model(mode='tensor', **data) - return result['fake_img'] + return result diff --git a/mmedit/datasets/coco.py b/mmedit/datasets/coco.py index 82a3b9b3a3..7296a0d506 100644 --- a/mmedit/datasets/coco.py +++ b/mmedit/datasets/coco.py @@ -1,77 +1,53 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import List, Callable, Optional, Union from pathlib import Path +from typing import List, Union, Dict + +from mmengine.dataset import BaseDataset +from mmengine.fileio import load from mmedit.registry import DATASETS @DATASETS.register_module() -class CocoDataset: +class CocoDataset(BaseDataset): """Dataset for COCO.""" METAINFO = { - 'CLASSES': - ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', - 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', - 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', - 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', - 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', - 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', - 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', - 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', - 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', - 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', - 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', - 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', - 'scissors', 'teddy bear', 'hair drier', 'toothbrush'), - # PALETTE is a list of color tuples, which is used for visualization. - 'PALETTE': - [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), - (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30), - (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), - (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255), - (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255), - (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92), - (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164), - (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0), - (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174), - (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54), - (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51), - (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65), - (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0), - (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161), - (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120), - (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133), - (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62), - (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45), - (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1), - (246, 0, 122), (191, 162, 208)] + 'dataset_type': 'colorization_dataset', + 'task_name': 'colorization', } - METAINFO = dict( - dataset_type='colorization_dataset', task_name='colorization') + def load_data_list(self) -> List[dict]: + + annotations = load(self.ann_file) + + assert annotations, f'annotation file "{self.ann_file}" is empty.' + + metainfo = annotations['metainfo'] + raw_data_list = annotations['data_list'] + + for k, v in metainfo.items(): + self._metainfo.setdefault(k, v) - def __init__(self, ann_file: str, data_prefix): + data_list = [] + for raw_data_info in raw_data_list: + data_info = self.parse_data_info(raw_data_info) + if isinstance(data_info, dict): + data_list.append(data_info) + else: + raise TypeError('data_info should be a dict or list of dict, ' + f'but got {type(data_info)}') - self.ann_file = str(ann_file) - self.data_prefix = data_prefix - self.data_infos = self.load_annotations() + return data_list - def load_annotations(self): - """Load annotations for dataset. + def parse_data_info(self, raw_data_info: dict) -> dict: + """Join data_root to each path in data_info.""" - Returns: - list[dict]: Contain dataset annotations. - """ - with open(self.ann_file, 'r') as f: - img_infos = [] - for idx, line in enumerate(f): - line = line.strip() - _info = dict() - img_path = line.split(' ')[0].split('/')[1] - _info = dict( - gt_img_path=Path( - self.data_prefix).joinpath(img_path).as_posix(), - gt_img_idx=idx) - img_infos.append(_info) + data_info = raw_data_info.copy() + for key in raw_data_info: + if 'path' in key: + data_info['gt_img_path'] = osp.join(self.data_root, data_info[key]) - return img_infos + return data_info diff --git a/mmedit/datasets/transforms/get_maskrcnn_bbox.py b/mmedit/datasets/transforms/get_maskrcnn_bbox.py index ca5df05f29..517feed7f4 100644 --- a/mmedit/datasets/transforms/get_maskrcnn_bbox.py +++ b/mmedit/datasets/transforms/get_maskrcnn_bbox.py @@ -212,6 +212,6 @@ def detectron(self): model_zoo.get_config_file( 'COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml')) cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 - cfg.MODEL.WEIGHTS = '/mnt/d/code/MMEditing/model_final_2d9806.pkl' + cfg.MODEL.WEIGHTS = '/mnt/ruoning/model_final_2d9806.pkl' predictor = DefaultPredictor(cfg) return predictor diff --git a/mmedit/models/editors/inst_colorization/__init__.py b/mmedit/models/editors/inst_colorization/__init__.py index faa56edca0..dcf30b8d9e 100644 --- a/mmedit/models/editors/inst_colorization/__init__.py +++ b/mmedit/models/editors/inst_colorization/__init__.py @@ -2,8 +2,9 @@ from .inst_colorization import InstColorization from .inst_colorization_net import (FusionGenerator, InstanceGenerator, SIGGRAPHGenerator) +from .inst_colorization_generator import InstColorizationGenerator __all__ = [ 'InstColorization', 'SIGGRAPHGenerator', 'InstanceGenerator', - 'FusionGenerator' + 'FusionGenerator', 'InstColorizationGenerator' ] diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py index 7f113f6361..65aa178ccb 100644 --- a/mmedit/models/editors/inst_colorization/inst_colorization.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -1,13 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. - from collections import OrderedDict -from typing import Union +from typing import Union, List, Dict import torch from mmengine.config import Config +from mmengine.optim import OptimWrapperDict from mmedit.models.utils import (encode_ab_ind, generation_init_weights, get_colorization_data, lab2rgb) +from mmedit.structures import EditDataSample, PixelData from mmedit.registry import BACKBONES, COMPONENTS from ..srgan import SRGAN @@ -29,16 +30,16 @@ def __init__(self, mask_cent, insta_stage=None, which_direction='AtoB', - instance_model=None, - full_model=None, - fusion_model=None, + generator=None, loss=None, init_cfg=None, train_cfg=None, test_cfg=None): + super(InstColorization, self).__init__( + generator=generator, data_preprocessor=data_preprocessor, - loss=loss, + pixel_loss=loss, init_cfg=init_cfg, train_cfg=train_cfg, test_cfg=test_cfg) @@ -57,19 +58,19 @@ def __init__(self, self.device = torch.device('cuda:{}'.format(0)) - self.instance_model = instance_model - self.full_model = full_model - self.fusion_model = fusion_model - self.insta_stage = insta_stage if self.insta_stage == 'full' or self.insta_stage == 'instance': self.training = False self.setup_to_train() - else: - self.setup_to_test() def set_input(self, input): + + self.encode_ab_opt = dict( + ab_norm=self.ab_norm, + ab_max=self.ab_max, + ab_quant=self.ab_quant + ) AtoB = self.which_direction == 'AtoB' self.real_A = input['A' if AtoB else 'B'].to(self.device) self.real_B = input['B' if AtoB else 'A'].to(self.device) @@ -79,12 +80,10 @@ def set_input(self, input): self.mask_B_nc = self.mask_B + self.mask_cent self.real_B_enc = encode_ab_ind( - self.real_B[:, :, ::4, ::4], - ab_norm=self.ab_norm, - ab_max=self.ab_max, - ab_quant=self.ab_quant) + self.real_B[:, :, ::4, ::4], **self.encode_ab_opt) def set_fusion_input(self, input, box_info): + AtoB = self.which_direction == 'AtoB' self.full_real_A = input['A' if AtoB else 'B'].to(self.device) self.full_real_B = input['B' if AtoB else 'A'].to(self.device) @@ -94,13 +93,11 @@ def set_fusion_input(self, input, box_info): self.full_mask_B_nc = self.full_mask_B + self.mask_cent self.full_real_B_enc = encode_ab_ind( - self.full_real_B[:, :, ::4, ::4], - ab_norm=self.ab_norm, - ab_max=self.ab_max, - ab_quant=self.ab_quant) + self.full_real_B[:, :, ::4, ::4], **self.encode_ab_opt) self.box_info_list = box_info def set_forward_without_box(self, input): + AtoB = self.which_direction == 'AtoB' self.full_real_A = input['A' if AtoB else 'B'].to(self.device) self.full_real_B = input['B' if AtoB else 'A'].to(self.device) @@ -109,7 +106,7 @@ def set_forward_without_box(self, input): self.full_mask_B = input['mask_B'].to(self.device) self.full_mask_B_nc = self.full_mask_B + self.mask_cent self.full_real_B_enc = encode_ab_ind(self.full_real_B[:, :, ::4, ::4], - self) + **self.encode_ab_opt) (_, self.comp_B_reg) = self.netGComp(self.full_real_A, self.full_hint_B, @@ -117,25 +114,21 @@ def set_forward_without_box(self, input): self.fake_B_reg = self.comp_B_reg def generator_loss(self): + if self.insta_stage == 'full' or self.insta_stage == 'instance': self.loss_L1 = torch.mean( self.criterionL1( self.fake_B_reg.type(torch.cuda.FloatTensor), self.real_B.type(torch.cuda.FloatTensor))) - self.loss_G = 10 * torch.mean( - self.criterionL1( - self.fake_B_reg.type(torch.cuda.FloatTensor), - self.real_B.type(torch.cuda.FloatTensor))) + self.loss_G = 10 * self.loss_L1 elif self.insta_stage == 'fusion': self.loss_L1 = torch.mean( self.criterionL1( self.fake_B_reg.type(torch.cuda.FloatTensor), self.full_real_B.type(torch.cuda.FloatTensor))) - self.loss_G = 10 * torch.mean( - self.criterionL1( - self.fake_B_reg.type(torch.cuda.FloatTensor), - self.full_real_B.type(torch.cuda.FloatTensor))) + self.loss_G = 10 * self.loss_L1 + else: print('Error! Wrong stage selection!') exit() @@ -147,14 +140,19 @@ def generator_loss(self): # float(...) works for both scalar tensor and float number self.avg_losses[name] = float(getattr( self, 'loss_' + - name)) + self.avg_loss_alpha * self.avg_losses[name] + name)) + self.avg_loss_alpha * self.avg_losses[name] errors_ret[name] = (1 - self.avg_loss_alpha) / ( - 1 - self.avg_loss_alpha** # noqa - self.error_cnt) * self.avg_losses[name] + 1 - self.avg_loss_alpha ** # noqa + self.error_cnt) * self.avg_losses[name] return errors_ret - def train_step(self, data_batch, optimizer): + def train_step(self, data: List[dict], + optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]: + + g_optim_wrapper = optim_wrapper['generator'] + data = self.data_preprocessor(data, True) + batch_inputs, data_samples = data['inputs'], data['data_samples'] log_vars = {} @@ -168,51 +166,50 @@ def train_step(self, data_batch, optimizer): ) if self.insta_stage == 'full' or self.insta_stage == 'instance': - data_batch['rgb_img'] = [data_batch['rgb_img']] - data_batch['gray_img'] = [data_batch['gray_img']] + data_samples['rgb_img'] = [data_samples['rgb_img']] + data_samples['gray_img'] = [data_samples['gray_img']] - input_data = get_colorization_data(data_batch['gray_img'], + input_data = get_colorization_data(data_samples['gray_img'], **colorization_data_opt) - gt_data = get_colorization_data(data_batch['rgb_img'], + gt_data = get_colorization_data(data_samples['rgb_img'], **colorization_data_opt) input_data['B'] = gt_data['B'] input_data['hint_B'] = gt_data['hint_B'] input_data['mask_B'] = gt_data['mask_B'] self.set_input(input_data) - (_, self.fake_B_reg) = self.netG(self.real_A, self.hint_B, - self.mask_B) + self.fake_B_reg = self.generator(self.real_A, self.hint_B, self.mask_B) elif self.insta_stage == 'fusion': - data_batch['cropped_rgb'] = torch.stack( - data_batch['cropped_rgb_list']) - data_batch['cropped_gray'] = torch.stack( - data_batch['cropped_gray_list']) - data_batch['full_rgb'] = torch.stack(data_batch['full_rgb_list']) - data_batch['full_gray'] = torch.stack(data_batch['full_gray_list']) - data_batch['box_info'] = torch.from_numpy( - data_batch['box_info']).type(torch.long) - data_batch['box_info_2x'] = torch.from_numpy( - data_batch['box_info_2x']).type(torch.long) - data_batch['box_info_4x'] = torch.from_numpy( - data_batch['box_info_4x']).type(torch.long) - data_batch['box_info_8x'] = torch.from_numpy( - data_batch['box_info_8x']).type(torch.long) - - box_info = data_batch['box_info'][0] - box_info_2x = data_batch['box_info_2x'][0] - box_info_4x = data_batch['box_info_4x'][0] - box_info_8x = data_batch['box_info_8x'][0] + data_samples['cropped_rgb'] = torch.stack( + data_samples['cropped_rgb_list']) + data_samples['cropped_gray'] = torch.stack( + data_samples['cropped_gray_list']) + data_samples['full_rgb'] = torch.stack(data_samples['full_rgb_list']) + data_samples['full_gray'] = torch.stack(data_samples['full_gray_list']) + data_samples['box_info'] = torch.from_numpy( + data_samples['box_info']).type(torch.long) + data_samples['box_info_2x'] = torch.from_numpy( + data_samples['box_info_2x']).type(torch.long) + data_samples['box_info_4x'] = torch.from_numpy( + data_samples['box_info_4x']).type(torch.long) + data_samples['box_info_8x'] = torch.from_numpy( + data_samples['box_info_8x']).type(torch.long) + + box_info = data_samples['box_info'][0] + box_info_2x = data_samples['box_info_2x'][0] + box_info_4x = data_samples['box_info_4x'][0] + box_info_8x = data_samples['box_info_8x'][0] cropped_input_data = get_colorization_data( - data_batch['cropped_gray'], **colorization_data_opt) - cropped_gt_data = get_colorization_data(data_batch['cropped_rgb'], + data_samples['cropped_gray'], **colorization_data_opt) + cropped_gt_data = get_colorization_data(data_samples['cropped_rgb'], **colorization_data_opt) - full_input_data = get_colorization_data(data_batch['full_gray'], + full_input_data = get_colorization_data(data_samples['full_gray'], **colorization_data_opt) - full_gt_data = get_colorization_data(data_batch['full_rgb'], + full_gt_data = get_colorization_data(data_samples['full_rgb'], **colorization_data_opt) cropped_input_data['B'] = cropped_gt_data['B'] @@ -223,13 +220,10 @@ def train_step(self, data_batch, optimizer): full_input_data, [box_info, box_info_2x, box_info_4x, box_info_8x]) - (_, self.comp_B_reg) = self.netGComp(self.full_real_A, - self.full_hint_B, - self.full_mask_B) - (_, feature_map) = self.netG(self.real_A, self.hint_B, self.mask_B) - self.fake_B_reg = self.netGF(self.full_real_A, self.full_hint_B, - self.full_mask_B, feature_map, - self.box_info_list) + self.fake_B_reg = self.generator( + self.real_A, self.hint_B, self.mask_B, self.full_real_A, self.full_hint_B, + self.full_mask_B, self.box_info_list + ) optimizer['generator'].zero_grad() @@ -255,48 +249,6 @@ def setup_to_train(self): self.loss_names = ['G', 'L1'] - if self.insta_stage == 'full' or self.insta_stage == 'instance': - self.model_names = ['G'] - self.netG = COMPONENTS.build(self.instance_model) - generation_init_weights(self.netG) - self.generator = self.netG - - elif self.insta_stage == 'fusion': - self.model_names = ['G', 'GF', 'GComp'] - self.netG = COMPONENTS.build(self.instance_model) - generation_init_weights(self.netG) - self.netG.eval() - - self.netGF = COMPONENTS.build(self.fusion_model) - generation_init_weights(self.netGF) - self.netGF.eval() - - self.netGComp = COMPONENTS.build(self.full_model) - generation_init_weights(self.netGComp) - self.netGComp.eval() - - self.generator = \ - list(self.netGF.module.weight_layer.parameters()) + \ - list(self.netGF.module.weight_layer2.parameters()) + \ - list(self.netGF.module.weight_layer3.parameters()) + \ - list(self.netGF.module.weight_layer4.parameters()) + \ - list(self.netGF.module.weight_layer5.parameters()) + \ - list(self.netGF.module.weight_layer6.parameters()) + \ - list(self.netGF.module.weight_layer7.parameters()) + \ - list(self.netGF.module.weight_layer8_1.parameters()) + \ - list(self.netGF.module.weight_layer8_2.parameters()) + \ - list(self.netGF.module.weight_layer9_1.parameters()) + \ - list(self.netGF.module.weight_layer9_2.parameters()) + \ - list(self.netGF.module.weight_layer10_1.parameters()) + \ - list(self.netGF.module.weight_layer10_2.parameters()) + \ - list(self.netGF.module.model10.parameters()) + \ - list(self.netGF.module.model_out.parameters()) - - else: - # print('Error Stage!') - # exit() - pass - self.criterionL1 = self.loss # initialize average loss values @@ -307,7 +259,7 @@ def setup_to_train(self): self.avg_losses[loss_name] = 0 def get_current_visuals(self): - from collections import OrderedDict + visual_ret = OrderedDict() opt = dict( ab_norm=self.ab_norm, l_norm=self.l_norm, l_cent=self.l_cent) @@ -316,8 +268,8 @@ def get_current_visuals(self): visual_ret['gray'] = lab2rgb( torch.cat((self.real_A.type( torch.cuda.FloatTensor), torch.zeros_like( - self.real_B).type(torch.cuda.FloatTensor)), - dim=1), **opt) + self.real_B).type(torch.cuda.FloatTensor)), + dim=1), **opt) visual_ret['real'] = lab2rgb( torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.real_B.type(torch.cuda.FloatTensor)), @@ -335,19 +287,19 @@ def get_current_visuals(self): torch.cat((torch.zeros_like( self.real_A.type(torch.cuda.FloatTensor)), self.real_B.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **opt) visual_ret['fake_ab_reg'] = lab2rgb( torch.cat((torch.zeros_like( self.real_A.type(torch.cuda.FloatTensor)), self.fake_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **opt) elif self.insta_stage == 'fusion': visual_ret['gray'] = lab2rgb( torch.cat((self.full_real_A.type( torch.cuda.FloatTensor), torch.zeros_like( - self.full_real_B).type(torch.cuda.FloatTensor)), - dim=1), **opt) + self.full_real_B).type(torch.cuda.FloatTensor)), + dim=1), **opt) visual_ret['real'] = lab2rgb( torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.full_real_B.type(torch.cuda.FloatTensor)), @@ -355,10 +307,7 @@ def get_current_visuals(self): visual_ret['comp_reg'] = lab2rgb( torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.comp_B_reg.type(torch.cuda.FloatTensor)), - dim=1), - ab_norm=self.ab_norm, - l_norm=self.l_norm, - l_cent=self.l_cent) + dim=1), **opt) visual_ret['fake_reg'] = lab2rgb( torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), @@ -375,27 +324,36 @@ def get_current_visuals(self): torch.cat((torch.zeros_like( self.full_real_A.type(torch.cuda.FloatTensor)), self.full_real_B.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **opt) visual_ret['comp_ab_reg'] = lab2rgb( torch.cat((torch.zeros_like( self.full_real_A.type(torch.cuda.FloatTensor)), self.comp_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **opt) visual_ret['fake_ab_reg'] = lab2rgb( torch.cat((torch.zeros_like( self.full_real_A.type(torch.cuda.FloatTensor)), self.fake_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **opt) else: print('Error! Wrong stage selection!') exit() return visual_ret - def forward_test(self, inputs, data_samples, **kwargs): + def forward_tensor(self, inputs, data_samples, **kwargs): - output = dict() data = data_samples[0] - full_img= data.full_img + full_img = data.full_img + + convert_params = dict( + ab_thresh=0, + ab_norm=self.ab_norm, + l_norm=self.l_norm, + l_cent=self.l_cent, + sample_PS=self.sample_Ps, + mask_cent=self.mask_cent, + ) + if not data.empty_box: cropped_img = data.cropped_img box_info = data.box_info @@ -404,21 +362,11 @@ def forward_test(self, inputs, data_samples, **kwargs): box_info_8x = data.box_info_8x cropped_data = get_colorization_data( cropped_img, - ab_thresh=0, - ab_norm=self.ab_norm, - l_norm=self.l_norm, - l_cent=self.l_cent, - sample_PS=self.sample_Ps, - mask_cent=self.mask_cent, + **convert_params ) full_img_data = get_colorization_data( full_img, - ab_thresh=0, - ab_norm=self.ab_norm, - l_norm=self.l_norm, - l_cent=self.l_cent, - sample_PS=self.sample_Ps, - mask_cent=self.mask_cent, + **convert_params ) self.set_input(cropped_data) self.set_fusion_input( @@ -429,10 +377,9 @@ def forward_test(self, inputs, data_samples, **kwargs): full_img, ab_thresh=0) self.set_forward_without_box(full_img_data) - (_, feature_map) = self.netG(self.real_A, self.hint_B, self.mask_B) - self.fake_B_reg = self.netGF(self.full_real_A, self.full_hint_B, - self.full_mask_B, feature_map, - self.box_info_list) + self.fake_B_reg = self.generator( + self.real_A, self.hint_B, self.mask_B, self.full_real_A, + self.full_hint_B, self.full_mask_B, self.box_info_list) out_img = torch.clamp( lab2rgb( @@ -443,16 +390,17 @@ def forward_test(self, inputs, data_samples, **kwargs): l_norm=self.l_norm, l_cent=self.l_cent), 0.0, 1.0) - output['fake_img'] = out_img - output['meta'] = None if 'meta' not in kwargs else kwargs['meta'][0] - - self.save_visualization(out_img, - '/mnt/ruoning/results/output_mmedit11.png') - return output + return out_img - def setup_to_test(self): - self.netG = COMPONENTS.build(self.instance_model) - generation_init_weights(self.netG) + def forward_inference(self, inputs, data_samples=None, **kwargs): + feats = self.forward_tensor(inputs, data_samples, **kwargs) + predictions = [] + for idx in range(feats.shape[0]): + batch_tensor = feats[idx] * 127.5 + 127.5 + pred_img = PixelData(data=batch_tensor.to('cpu')) + predictions.append( + EditDataSample( + pred_img=pred_img, + metainfo=data_samples[idx].metainfo)) - self.netGF = COMPONENTS.build(self.fusion_model) - generation_init_weights(self.netGF) + return predictions diff --git a/mmedit/models/editors/inst_colorization/inst_colorization_generator.py b/mmedit/models/editors/inst_colorization/inst_colorization_generator.py new file mode 100644 index 0000000000..5b9c3850cd --- /dev/null +++ b/mmedit/models/editors/inst_colorization/inst_colorization_generator.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from mmedit.registry import BACKBONES, COMPONENTS +from mmedit.models.utils import generation_init_weights + + +@BACKBONES.register_module() +class InstColorizationGenerator(nn.Module): + + def __init__(self, + stage, + instance_model=None, + full_model=None, + fusion_model=None, + ): + + super(InstColorizationGenerator, self).__init__() + + self.stage = stage + + if self.stage == "test": + self.netG = COMPONENTS.build(instance_model) + generation_init_weights(self.netG) + + self.netGF = COMPONENTS.build(fusion_model) + generation_init_weights(self.netGF) + + elif self.stage == "instance" or stage == 'full': + self.netG = COMPONENTS.build(instance_model) + generation_init_weights(self.netG) + + elif self.stage == "fusion": + self.netG = COMPONENTS.build(instance_model) + generation_init_weights(self.netG) + self.netG.eval() + + self.netGF = COMPONENTS.build(fusion_model) + generation_init_weights(self.netGF) + self.netGF.eval() + + self.netGComp = COMPONENTS.build(full_model) + generation_init_weights(self.netGComp) + self.netGComp.eval() + + self.generator = \ + list(self.netGF.module.weight_layer.parameters()) + \ + list(self.netGF.module.weight_layer2.parameters()) + \ + list(self.netGF.module.weight_layer3.parameters()) + \ + list(self.netGF.module.weight_layer4.parameters()) + \ + list(self.netGF.module.weight_layer5.parameters()) + \ + list(self.netGF.module.weight_layer6.parameters()) + \ + list(self.netGF.module.weight_layer7.parameters()) + \ + list(self.netGF.module.weight_layer8_1.parameters()) + \ + list(self.netGF.module.weight_layer8_2.parameters()) + \ + list(self.netGF.module.weight_layer9_1.parameters()) + \ + list(self.netGF.module.weight_layer9_2.parameters()) + \ + list(self.netGF.module.weight_layer10_1.parameters()) + \ + list(self.netGF.module.weight_layer10_2.parameters()) + \ + list(self.netGF.module.model10.parameters()) + \ + list(self.netGF.module.model_out.parameters()) + else: + print('Error! Wrong stage selection!') + exit() + + def forward(self, + real_A, + hint_B, + mask_B, + full_real_A=None, + full_hint_B=None, + full_mask_B=None, + box_info_list=None + ): + if self.stage == 'test': + (_, feature_map) = self.netG(real_A, hint_B, mask_B) + fake_B_reg = self.netGF( + full_real_A, full_hint_B, full_mask_B, + feature_map, box_info_list + ) + + return fake_B_reg + + elif self.stage == 'full' or self.stage == 'instance': + (_, fake_B_reg) = self.netG(real_A, hint_B, mask_B) + + return fake_B_reg + + elif self.stage == 'fusion': + (_, self.comp_B_reg) = self.netGComp( + full_real_A, full_hint_B, full_mask_B) + + (_, feature_map) = self.netG(real_A, hint_B, mask_B) + + fake_B_reg = self.netGF( + full_real_A, full_hint_B, full_mask_B, + feature_map, box_info_list) + + return fake_B_reg \ No newline at end of file From 5719356485a4a603b658832a32812a2809707689 Mon Sep 17 00:00:00 2001 From: ruoning Date: Sun, 16 Oct 2022 13:59:02 +0800 Subject: [PATCH 07/32] [Refactor]: refactor get_maskrcnn_bbox.py and inst_colorization.py --- .../inst-colorizatioon_cocostuff_256x256.py | 2 +- .../datasets/transforms/get_maskrcnn_bbox.py | 105 +++++---- .../inst_colorization/inst_colorization.py | 205 +++++++++--------- 3 files changed, 149 insertions(+), 163 deletions(-) diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py index a96830c1e3..c6d2f3b087 100644 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py @@ -36,7 +36,7 @@ test_pipeline = [ dict(type='LoadImageFromFile', key='img'), - dict(type='GenMaskRCNNBbox', stage='test_fusion', finesize=256), + dict(type='GenMaskRCNNBbox', stage=stage, finesize=256), dict(type='Resize', keys=['img'], scale=(256, 256), keep_ratio=False), dict(type='PackEditInputs'), ] diff --git a/mmedit/datasets/transforms/get_maskrcnn_bbox.py b/mmedit/datasets/transforms/get_maskrcnn_bbox.py index 517feed7f4..f49f595973 100644 --- a/mmedit/datasets/transforms/get_maskrcnn_bbox.py +++ b/mmedit/datasets/transforms/get_maskrcnn_bbox.py @@ -60,24 +60,6 @@ def gen_maskrcnn_bbox_fromPred(self, return pred_bbox - @staticmethod - def gen_gray_color_pil(rgb_img): - ''' - return: RGB and GRAY pillow image object - ''' - if len(np.asarray(rgb_img).shape) == 2: - rgb_img = np.stack([ - np.asarray(rgb_img), - np.asarray(rgb_img), - np.asarray(rgb_img) - ], 2) - rgb_img = Image.fromarray(rgb_img) - gray_img = np.round(color.rgb2gray(np.asarray(rgb_img)) * - 255.0).astype(np.uint8) - gray_img = np.stack([gray_img, gray_img, gray_img], -1) - gray_img = Image.fromarray(gray_img) - return rgb_img, gray_img - @staticmethod def read_to_pil(out_img): ''' @@ -120,45 +102,57 @@ def get_box_info(pred_bbox, original_shape, final_size): B_pad = final_size - resize_endy return [L_pad, R_pad, T_pad, B_pad, rh, rw] - def test_fusion(self, results): - img = results['img'] - pil_img = self.read_to_pil(img) + def fusion(self, results, img, pred_bbox): + if self.stage == 'test': + gray_img = self.read_to_pil(img) + return self.get_instance_info(results, pred_bbox, gray_img) - if 'bbox_path' in results.keys(): - pred_bbox = self.gen_maskrcnn_bbox_fromPred( - img, results['bbox_path']) - elif 'instance' in results.keys(): - pred_bbox = results['instance'] - else: - pred_bbox = self.gen_maskrcnn_bbox_fromPred(img) + if self.stage == 'fusion': + rgb_img, gray_img = results['rgb_img'], results['gray_img'] + return self.get_instance_info(results, pred_bbox, gray_img, rgb_img) - img_list = [self.transforms(pil_img)] # 这里删除了一个transform + def get_instance_info(self, results, pred_bbox, gray_img, rgb_img=None): + + if not rgb_img: + full_gray_list = [self.transforms(gray_img)] + cropped_gray_list = [] + else: + full_rgb_list = [self.transforms(rgb_img)] + full_gray_list = [self.transforms(gray_img)] + cropped_rgb_list = [] + cropped_gray_list = [] - cropped_img_list = [] index_list = range(len(pred_bbox)) box_info, box_info_2x, box_info_4x, box_info_8x = np.zeros( (4, len(index_list), 6)) for i in index_list: startx, starty, endx, endy = pred_bbox[i] box_info[i] = np.array( - self.get_box_info(pred_bbox[i], pil_img.size, self.final_size)) + self.get_box_info(pred_bbox[i], gray_img.size, self.final_size)) box_info_2x[i] = np.array( - self.get_box_info(pred_bbox[i], pil_img.size, + self.get_box_info(pred_bbox[i], gray_img.size, self.final_size // 2)) box_info_4x[i] = np.array( - self.get_box_info(pred_bbox[i], pil_img.size, + self.get_box_info(pred_bbox[i], gray_img.size, self.final_size // 4)) box_info_8x[i] = np.array( - self.get_box_info(pred_bbox[i], pil_img.size, + self.get_box_info(pred_bbox[i], gray_img.size, self.final_size // 8)) cropped_img = self.transforms( - pil_img.crop((startx, starty, endx, endy))) - cropped_img_list.append(cropped_img) + gray_img.crop((startx, starty, endx, endy))) + cropped_gray_list.append(cropped_img) + if rgb_img: + cropped_rgb_list.append(self.transforms(rgb_img.crop((startx, starty, endx, endy)))) + cropped_gray_list.append(self.transforms(gray_img.crop((startx, starty, endx, endy)))) + + results['full_gray'] = torch.stack(full_gray_list) + if rgb_img: + results['full_rgb'] = torch.stack(full_rgb_list) - results['full_img'] = torch.stack(img_list) - # output['file_id'] = self.IMAGE_ID_LIST[index].split('.')[0] if len(pred_bbox) > 0: - results['cropped_img'] = torch.stack(cropped_img_list) + results['cropped_gray'] = torch.stack(cropped_gray_list) + if rgb_img: + results['cropped_rgb'] = torch.stack(cropped_rgb_list) results['box_info'] = torch.from_numpy(box_info).type(torch.long) results['box_info_2x'] = torch.from_numpy(box_info_2x).type( torch.long) @@ -169,22 +163,12 @@ def test_fusion(self, results): results['empty_box'] = False else: results['empty_box'] = True - print('full_img:', results['full_img'].size) - # print("cropped_img:", results['cropped_img'].size) - return results - def train(self, results): - img = results[self.key] + return results - if 'bbox_path' in results.keys(): - pred_bbox = self.gen_maskrcnn_bbox_fromPred( - img, results['bbox_path']) - elif 'instance' in results.keys(): - pred_bbox = results['instance'][0]['bbox'] - else: - pred_bbox = self.gen_maskrcnn_bbox_fromPred(img) + def train(self, results, pred_bbox): - rgb_img, gray_img = self.gen_gray_color_pil(img) + rgb_img, gray_img = results['rgb_img'], results['gray_img'] index_list = range(len(pred_bbox)) index_list = sample(index_list, 1) startx, starty, endx, endy = pred_bbox[index_list[0]] @@ -197,12 +181,21 @@ def train(self, results): return results def __call__(self, results): + img = results['img'] + + if 'bbox_path' in results.keys(): + pred_bbox = self.gen_maskrcnn_bbox_fromPred( + img, results['bbox_path']) + elif 'instance' in results.keys(): + pred_bbox = results['instance'] + else: + pred_bbox = self.gen_maskrcnn_bbox_fromPred(img) - if self.stage == 'test_fusion': - results = self.test_fusion(results) + if self.stage == 'fusion' or self.stage == 'test': + results = self.fusion(results, img, pred_bbox) - if self.stage == 'train': - results = self.train(results) + if self.stage == 'full' or self.stage == 'instance': + results = self.train(results, pred_bbox) return results diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py index 65aa178ccb..1e888df1da 100644 --- a/mmedit/models/editors/inst_colorization/inst_colorization.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -47,15 +47,35 @@ def __init__(self, self.ngf = ngf self.output_nc = output_nc self.avg_loss_alpha = avg_loss_alpha - self.ab_norm = ab_norm - self.ab_max = ab_max - self.ab_quant = ab_quant - self.l_norm = l_norm - self.l_cent = l_cent - self.sample_Ps = sample_Ps self.mask_cent = mask_cent self.which_direction = which_direction + self.encode_ab_opt = dict( + ab_norm=ab_norm, + ab_max=ab_max, + ab_quant=ab_quant) + + self.colorization_data_opt = dict( + ab_thresh=0, + ab_norm=ab_norm, + l_norm=l_norm, + l_cent=l_cent, + sample_PS=sample_Ps, + mask_cent=mask_cent, + ) + + self.lab2rgb_opt = dict( + ab_norm=ab_norm, l_norm=l_norm, l_cent=l_cent) + + self.convert_params = dict( + ab_thresh=0, + ab_norm=ab_norm, + l_norm=l_norm, + l_cent=l_cent, + sample_PS=sample_Ps, + mask_cent=mask_cent, + ) + self.device = torch.device('cuda:{}'.format(0)) self.insta_stage = insta_stage @@ -66,11 +86,6 @@ def __init__(self, def set_input(self, input): - self.encode_ab_opt = dict( - ab_norm=self.ab_norm, - ab_max=self.ab_max, - ab_quant=self.ab_quant - ) AtoB = self.which_direction == 'AtoB' self.real_A = input['A' if AtoB else 'B'].to(self.device) self.real_B = input['B' if AtoB else 'A'].to(self.device) @@ -156,24 +171,15 @@ def train_step(self, data: List[dict], log_vars = {} - colorization_data_opt = dict( - ab_thresh=0, - ab_norm=self.ab_norm, - l_norm=self.l_norm, - l_cent=self.l_cent, - sample_PS=self.sample_Ps, - mask_cent=self.mask_cent, - ) - if self.insta_stage == 'full' or self.insta_stage == 'instance': data_samples['rgb_img'] = [data_samples['rgb_img']] data_samples['gray_img'] = [data_samples['gray_img']] input_data = get_colorization_data(data_samples['gray_img'], - **colorization_data_opt) + **self.colorization_data_opt) gt_data = get_colorization_data(data_samples['rgb_img'], - **colorization_data_opt) + **self.colorization_data_opt) input_data['B'] = gt_data['B'] input_data['hint_B'] = gt_data['hint_B'] @@ -204,13 +210,13 @@ def train_step(self, data: List[dict], box_info_8x = data_samples['box_info_8x'][0] cropped_input_data = get_colorization_data( - data_samples['cropped_gray'], **colorization_data_opt) + data_samples['cropped_gray'], **self.colorization_data_opt) cropped_gt_data = get_colorization_data(data_samples['cropped_rgb'], - **colorization_data_opt) + **self.colorization_data_opt) full_input_data = get_colorization_data(data_samples['full_gray'], - **colorization_data_opt) + **self.colorization_data_opt) full_gt_data = get_colorization_data(data_samples['full_rgb'], - **colorization_data_opt) + **self.colorization_data_opt) cropped_input_data['B'] = cropped_gt_data['B'] full_input_data['B'] = full_gt_data['B'] @@ -258,60 +264,112 @@ def setup_to_train(self): for loss_name in self.loss_names: self.avg_losses[loss_name] = 0 + def forward_tensor(self, inputs, data_samples, **kwargs): + + data = data_samples[0] + full_img = data.full_gray + + if not data.empty_box: + cropped_img = data.cropped_gray + box_info = data.box_info + box_info_2x = data.box_info_2x + box_info_4x = data.box_info_4x + box_info_8x = data.box_info_8x + cropped_data = get_colorization_data( + cropped_img, + **self.convert_params + ) + full_img_data = get_colorization_data( + full_img, + **self.convert_params + ) + self.set_input(cropped_data) + self.set_fusion_input( + full_img_data, + [box_info, box_info_2x, box_info_4x, box_info_8x]) + else: + full_img_data = get_colorization_data( + full_img, ab_thresh=0) + self.set_forward_without_box(full_img_data) + + self.fake_B_reg = self.generator( + self.real_A, self.hint_B, self.mask_B, self.full_real_A, + self.full_hint_B, self.full_mask_B, self.box_info_list) + + out_img = torch.clamp( + lab2rgb( + torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), + self.fake_B_reg.type(torch.cuda.FloatTensor)), + dim=1), **self.lab2rgb_opt), 0.0, 1.0) + + return out_img + + def forward_inference(self, inputs, data_samples=None, **kwargs): + feats = self.forward_tensor(inputs, data_samples, **kwargs) + predictions = [] + for idx in range(feats.shape[0]): + batch_tensor = feats[idx] * 127.5 + 127.5 + pred_img = PixelData(data=batch_tensor.to('cpu')) + predictions.append( + EditDataSample( + pred_img=pred_img, + metainfo=data_samples[idx].metainfo)) + + return predictions + def get_current_visuals(self): visual_ret = OrderedDict() - opt = dict( - ab_norm=self.ab_norm, l_norm=self.l_norm, l_cent=self.l_cent) + if self.insta_stage == 'full' or self.insta_stage == 'instance': visual_ret['gray'] = lab2rgb( torch.cat((self.real_A.type( torch.cuda.FloatTensor), torch.zeros_like( self.real_B).type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **self.lab2rgb_opt) visual_ret['real'] = lab2rgb( torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.real_B.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **self.lab2rgb_opt) visual_ret['fake_reg'] = lab2rgb( torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **self.lab2rgb_opt) visual_ret['hint'] = lab2rgb( torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.hint_B.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **self.lab2rgb_opt) visual_ret['real_ab'] = lab2rgb( torch.cat((torch.zeros_like( self.real_A.type(torch.cuda.FloatTensor)), self.real_B.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **self.lab2rgb_opt) visual_ret['fake_ab_reg'] = lab2rgb( torch.cat((torch.zeros_like( self.real_A.type(torch.cuda.FloatTensor)), self.fake_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **self.lab2rgb_opt) elif self.insta_stage == 'fusion': visual_ret['gray'] = lab2rgb( torch.cat((self.full_real_A.type( torch.cuda.FloatTensor), torch.zeros_like( self.full_real_B).type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **self.lab2rgb_opt) visual_ret['real'] = lab2rgb( torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.full_real_B.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **self.lab2rgb_opt) visual_ret['comp_reg'] = lab2rgb( torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.comp_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **self.lab2rgb_opt) visual_ret['fake_reg'] = lab2rgb( torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **self.lab2rgb_opt) self.instance_mask = torch.nn.functional.interpolate( torch.zeros([1, 1, 176, 176]), @@ -324,83 +382,18 @@ def get_current_visuals(self): torch.cat((torch.zeros_like( self.full_real_A.type(torch.cuda.FloatTensor)), self.full_real_B.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **self.lab2rgb_opt) visual_ret['comp_ab_reg'] = lab2rgb( torch.cat((torch.zeros_like( self.full_real_A.type(torch.cuda.FloatTensor)), self.comp_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **self.lab2rgb_opt) visual_ret['fake_ab_reg'] = lab2rgb( torch.cat((torch.zeros_like( self.full_real_A.type(torch.cuda.FloatTensor)), self.fake_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **opt) + dim=1), **self.lab2rgb_opt) else: print('Error! Wrong stage selection!') exit() return visual_ret - - def forward_tensor(self, inputs, data_samples, **kwargs): - - data = data_samples[0] - full_img = data.full_img - - convert_params = dict( - ab_thresh=0, - ab_norm=self.ab_norm, - l_norm=self.l_norm, - l_cent=self.l_cent, - sample_PS=self.sample_Ps, - mask_cent=self.mask_cent, - ) - - if not data.empty_box: - cropped_img = data.cropped_img - box_info = data.box_info - box_info_2x = data.box_info_2x - box_info_4x = data.box_info_4x - box_info_8x = data.box_info_8x - cropped_data = get_colorization_data( - cropped_img, - **convert_params - ) - full_img_data = get_colorization_data( - full_img, - **convert_params - ) - self.set_input(cropped_data) - self.set_fusion_input( - full_img_data, - [box_info, box_info_2x, box_info_4x, box_info_8x]) - else: - full_img_data = get_colorization_data( - full_img, ab_thresh=0) - self.set_forward_without_box(full_img_data) - - self.fake_B_reg = self.generator( - self.real_A, self.hint_B, self.mask_B, self.full_real_A, - self.full_hint_B, self.full_mask_B, self.box_info_list) - - out_img = torch.clamp( - lab2rgb( - torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), - self.fake_B_reg.type(torch.cuda.FloatTensor)), - dim=1), - ab_norm=self.ab_norm, - l_norm=self.l_norm, - l_cent=self.l_cent), 0.0, 1.0) - - return out_img - - def forward_inference(self, inputs, data_samples=None, **kwargs): - feats = self.forward_tensor(inputs, data_samples, **kwargs) - predictions = [] - for idx in range(feats.shape[0]): - batch_tensor = feats[idx] * 127.5 + 127.5 - pred_img = PixelData(data=batch_tensor.to('cpu')) - predictions.append( - EditDataSample( - pred_img=pred_img, - metainfo=data_samples[idx].metainfo)) - - return predictions From 1d026526ddae2907475f880d32c92be1d4e8f2b4 Mon Sep 17 00:00:00 2001 From: ruoning Date: Sun, 16 Oct 2022 23:42:54 +0800 Subject: [PATCH 08/32] [Enhancement]: add unit test of Instance-aware Image Colorization --- ...st-colorizatioon_cocostuff_full_256x256.py | 4 +- .../datasets/transforms/get_maskrcnn_bbox.py | 21 +++--- .../inst_colorization/inst_colorization.py | 40 ++++------- mmedit/models/losses/huber_loss.py | 2 +- .../test_apis/test_colorization_inference.py | 38 ++++++++++ tests/test_datasets/test_coco.py | 34 +++++++++ .../test_get_gray_color_pil.py | 15 ++++ .../test_transforms/test_get_maskrcnn_bbox.py | 72 +++++++++++++++++++ .../test_inst_colorization/test_util.py | 2 + .../test_losses/test_huber_loss.py | 5 ++ 10 files changed, 191 insertions(+), 42 deletions(-) diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py index 6e6be6db6b..3d0dcc58c0 100644 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py @@ -56,8 +56,8 @@ ] dataset_type = 'CocoDataset' -data_root = '/mnt/j/DataSet/cocostuff' -ann_file_path = '/mnt/j/DataSet/cocostuff' +data_root = '/mnt/meng/cocos' +ann_file_path = '/mnt/ruoning/bbox' train_dataloader = dict( batch_size=4, diff --git a/mmedit/datasets/transforms/get_maskrcnn_bbox.py b/mmedit/datasets/transforms/get_maskrcnn_bbox.py index f49f595973..cfeb00cde6 100644 --- a/mmedit/datasets/transforms/get_maskrcnn_bbox.py +++ b/mmedit/datasets/transforms/get_maskrcnn_bbox.py @@ -17,7 +17,7 @@ @TRANSFORMS.register_module() class GenMaskRCNNBbox: - def __init__(self, key='img', stage='test_fusion', finesize=256): + def __init__(self, key='img', stage='test', finesize=256): self.key = key self.predictor = self.detectron() self.stage = stage @@ -75,31 +75,30 @@ def read_to_pil(out_img): out_img = Image.fromarray(out_img) return out_img - @staticmethod - def get_box_info(pred_bbox, original_shape, final_size): + def get_box_info(self, pred_bbox, original_shape): assert len(pred_bbox) == 4 - resize_startx = int(pred_bbox[0] / original_shape[0] * final_size) - resize_starty = int(pred_bbox[1] / original_shape[1] * final_size) - resize_endx = int(pred_bbox[2] / original_shape[0] * final_size) - resize_endy = int(pred_bbox[3] / original_shape[1] * final_size) + resize_startx = int(pred_bbox[0] / original_shape[0] * self.final_size) + resize_starty = int(pred_bbox[1] / original_shape[1] * self.final_size) + resize_endx = int(pred_bbox[2] / original_shape[0] * self.final_size) + resize_endy = int(pred_bbox[3] / original_shape[1] * self.final_size) rh = resize_endx - resize_startx rw = resize_endy - resize_starty if rh < 1: - if final_size - resize_endx > 1: + if self.final_size - resize_endx > 1: resize_endx += 1 else: resize_startx -= 1 rh = 1 if rw < 1: - if final_size - resize_endy > 1: + if self.final_size - resize_endy > 1: resize_endy += 1 else: resize_starty -= 1 rw = 1 L_pad = resize_startx - R_pad = final_size - resize_endx + R_pad = self.final_size - resize_endx T_pad = resize_starty - B_pad = final_size - resize_endy + B_pad = self.final_size - resize_endy return [L_pad, R_pad, T_pad, B_pad, rh, rw] def fusion(self, results, img, pred_bbox): diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py index 1e888df1da..5abb5a5ec7 100644 --- a/mmedit/models/editors/inst_colorization/inst_colorization.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -172,13 +172,13 @@ def train_step(self, data: List[dict], log_vars = {} if self.insta_stage == 'full' or self.insta_stage == 'instance': - data_samples['rgb_img'] = [data_samples['rgb_img']] - data_samples['gray_img'] = [data_samples['gray_img']] + rgb_img = [data_samples.rgb_img] + gray_img = [data_samples.gray_img] - input_data = get_colorization_data(data_samples['gray_img'], + input_data = get_colorization_data(gray_img, **self.colorization_data_opt) - gt_data = get_colorization_data(data_samples['rgb_img'], + gt_data = get_colorization_data(rgb_img, **self.colorization_data_opt) input_data['B'] = gt_data['B'] @@ -188,34 +188,18 @@ def train_step(self, data: List[dict], self.fake_B_reg = self.generator(self.real_A, self.hint_B, self.mask_B) elif self.insta_stage == 'fusion': - - data_samples['cropped_rgb'] = torch.stack( - data_samples['cropped_rgb_list']) - data_samples['cropped_gray'] = torch.stack( - data_samples['cropped_gray_list']) - data_samples['full_rgb'] = torch.stack(data_samples['full_rgb_list']) - data_samples['full_gray'] = torch.stack(data_samples['full_gray_list']) - data_samples['box_info'] = torch.from_numpy( - data_samples['box_info']).type(torch.long) - data_samples['box_info_2x'] = torch.from_numpy( - data_samples['box_info_2x']).type(torch.long) - data_samples['box_info_4x'] = torch.from_numpy( - data_samples['box_info_4x']).type(torch.long) - data_samples['box_info_8x'] = torch.from_numpy( - data_samples['box_info_8x']).type(torch.long) - - box_info = data_samples['box_info'][0] - box_info_2x = data_samples['box_info_2x'][0] - box_info_4x = data_samples['box_info_4x'][0] - box_info_8x = data_samples['box_info_8x'][0] + box_info = data_samples.box_info + box_info_2x = data_samples.box_info_2x + box_info_4x = data_samples.box_info_4x + box_info_8x = data_samples.box_info_8x cropped_input_data = get_colorization_data( - data_samples['cropped_gray'], **self.colorization_data_opt) - cropped_gt_data = get_colorization_data(data_samples['cropped_rgb'], + data_samples.cropped_gray, **self.colorization_data_opt) + cropped_gt_data = get_colorization_data(data_samples.cropped_rgb, **self.colorization_data_opt) - full_input_data = get_colorization_data(data_samples['full_gray'], + full_input_data = get_colorization_data(data_samples.full_gray, **self.colorization_data_opt) - full_gt_data = get_colorization_data(data_samples['full_rgb'], + full_gt_data = get_colorization_data(data_samples.full_rgb, **self.colorization_data_opt) cropped_input_data['B'] = cropped_gt_data['B'] diff --git a/mmedit/models/losses/huber_loss.py b/mmedit/models/losses/huber_loss.py index 7b45f41571..187a47bef9 100644 --- a/mmedit/models/losses/huber_loss.py +++ b/mmedit/models/losses/huber_loss.py @@ -12,7 +12,7 @@ def __init__(self, delta=.01): super(HuberLoss, self).__init__() self.delta = delta - def __call__(self, in0, in1): + def forward(self, in0, in1): mask = torch.zeros_like(in0) mann = torch.abs(in0 - in1) eucl = .5 * (mann**2) diff --git a/tests/test_apis/test_colorization_inference.py b/tests/test_apis/test_colorization_inference.py index ef101fec61..5c8545ef45 100644 --- a/tests/test_apis/test_colorization_inference.py +++ b/tests/test_apis/test_colorization_inference.py @@ -1 +1,39 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +import torch +from mmengine import Config +from mmengine.runner import load_checkpoint + +from mmedit.apis import colorization_inference +from mmedit.registry import MODELS +from mmedit.utils import register_all_modules, tensor2img + + +def test_colorization_inference(): + register_all_modules() + + if torch.cuda.is_available(): + device = torch.device('cuda', 0) + else: + device = torch.device('cpu') + + data_root = '../../' + config = data_root + 'configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py' + + checkpoint = None + + cfg = Config.fromfile(config) + model = MODELS.build(cfg.model) + + if checkpoint is not None: + checkpoint = load_checkpoint(model, checkpoint) + + model.cfg = cfg + model.to(device) + model.eval() + + img_path = '../data/image/gray/test.jpg' + + result = colorization_inference(model, img_path) + assert tensor2img(result)[..., ::-1].shape == (256, 256, 3) \ No newline at end of file diff --git a/tests/test_datasets/test_coco.py b/tests/test_datasets/test_coco.py index ef101fec61..08a5dc6436 100644 --- a/tests/test_datasets/test_coco.py +++ b/tests/test_datasets/test_coco.py @@ -1 +1,35 @@ # Copyright (c) OpenMMLab. All rights reserved. + +import os.path as osp +from pathlib import Path + +from mmedit.registry import DATASETS +from mmedit.datasets import CocoDataset + + +# todo 完成coco的单元测试编写 +class TestCOCOStuff: + DATASET_TYPE = 'CocoDataset' + + ann_file = 'test.json' + data_root = "../.." + + DEFAULT_ARGS = dict( + data_root=data_root + '/train2017', + data_prefix=dict(gt='data_large'), + ann_file=ann_file, + pipeline=[], + test_mode=False + ) + + def test_load_data_list(self): + dataset_class = DATASETS.get(self.DATASET_TYPE) + dataset = dataset_class(**self.DEFAULT_ARGS) + + assert dataset.mateinfo == { + 'dataset_type': 'colorization_dataset', + 'task_name': 'colorization', + } + + # 对拿到的数据列表和数据进行判断 + \ No newline at end of file diff --git a/tests/test_datasets/test_transforms/test_get_gray_color_pil.py b/tests/test_datasets/test_transforms/test_get_gray_color_pil.py index ef101fec61..53cabdb3a7 100644 --- a/tests/test_datasets/test_transforms/test_get_gray_color_pil.py +++ b/tests/test_datasets/test_transforms/test_get_gray_color_pil.py @@ -1 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. +import cv2 as cv + +from mmedit.datasets.transforms import GenGrayColorPil + + +def test_get_gray_color_pil(): + img = cv.imread("../../data/image/gt/baboon.png") + test_class = GenGrayColorPil( + stage='test', keys=['rgb_img', 'gray_img'] + ) + + results = test_class.transform(dict(img=img)) + + assert 'rgb_img' in results.keys() and 'gray_img' in results.keys() + assert results['gray_img'].shape == img.shape \ No newline at end of file diff --git a/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py b/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py index ef101fec61..a1ee22a2f1 100644 --- a/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py +++ b/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py @@ -1 +1,73 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os +import cv2 as cv +from mmedit.datasets.transforms import GenMaskRCNNBbox +from mmedit.utils import tensor2img + + +class TestMaskRCNNBbox: + + DEFAULT_ARGS = dict( + key='img', finesize=256 + ) + + def test_maskrcnn_bbox(self): + detectetor = GenMaskRCNNBbox(**self.DEFAULT_ARGS, stage='test') + data_root = ".." + img_path = "data/image/gray/test.jpg" + img = cv.imread(os.path.join(data_root, img_path)) + + data = dict(img=img) + + results = detectetor(data) + pred_bbox = results.pred_bbox + + assert len(pred_bbox) <= 8 + assert results['full_gray'] and results['box_info'] \ + and results['cropped_gray'] + + detectetor.stage = 'fusion' + results = detectetor(data) + index = len(results.pred_bbox) + assert results['full_rgb'] and results['cropped_rgb'] + assert results['cropped_gray_list'].shape == (index, 3, 256, 256) + + detectetor.stage = 'full' + results = detectetor(data) + assert results['rgb_img'] and results['gray_img'] + assert tensor2img(results['rgb_img']).shape == (3, 256, 256) + + def test_gen_maskrcnn_from_pred(self): + detectetor = GenMaskRCNNBbox(**self.DEFAULT_ARGS, stage='test') + data_root = ".." + img_path = "data/image/gray/test.jpg" + img = cv.imread(os.path.join(data_root, img_path)) + + box_num_upbound = 4 + pred_bbox = detectetor.gen_maskrcnn_bbox_fromPred(img) + + assert len(pred_bbox) <= box_num_upbound + assert pred_bbox.shape[-1] == 4 + + def test_get_box_info(self): + detectetor = GenMaskRCNNBbox(**self.DEFAULT_ARGS, stage='test') + data_root = ".." + img_path = "data/image/gray/test.jpg" + img = cv.imread(os.path.join(data_root, img_path)) + + pred_bbox = detectetor.gen_maskrcnn_bbox_fromPred(img) + + resize_startx = int(pred_bbox[0] / img.shape[0] * 256) + resize_starty = int(pred_bbox[1] / img.shape[1] * 256) + resize_endx = int(pred_bbox[2] / img.shape[0] * 256) + resize_endy = int(pred_bbox[3] / img.shape[1] * 256) + + box_info = detectetor.get_box_info(pred_bbox, img.shape) + + assert box_info[0] == resize_starty and box_info[1] == 256 - resize_endx \ + and box_info[2] == resize_starty and box_info[3] == 256 - resize_endy \ + and box_info[4] == resize_endx - resize_startx \ + and box_info[5] == resize_endy - resize_starty + + + diff --git a/tests/test_models/test_editors/test_inst_colorization/test_util.py b/tests/test_models/test_editors/test_inst_colorization/test_util.py index ef101fec61..876cffc38c 100644 --- a/tests/test_models/test_editors/test_inst_colorization/test_util.py +++ b/tests/test_models/test_editors/test_inst_colorization/test_util.py @@ -1 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. + +from mmedit.models.utils import color_utils diff --git a/tests/test_models/test_losses/test_huber_loss.py b/tests/test_models/test_losses/test_huber_loss.py index ef101fec61..125eb7801f 100644 --- a/tests/test_models/test_losses/test_huber_loss.py +++ b/tests/test_models/test_losses/test_huber_loss.py @@ -1 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from mmedit.models.losses import HuberLoss + + +def test_huber_loss(): + pass \ No newline at end of file From 7a475fefd1a3775c5e8e8a08fb9b5ec265f35645 Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Wed, 19 Oct 2022 10:25:59 +0800 Subject: [PATCH 09/32] update configs --- configs/inst_colorization/README.md | 2 +- configs/inst_colorization/README_zh-CN.md | 8 ++++---- .../inst-colorizatioon_cocostuff_256x256.py | 17 ++++++++++++++--- demo/colorization_demo.py | 2 -- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/configs/inst_colorization/README.md b/configs/inst_colorization/README.md index ab147570e7..29ec9dbc25 100644 --- a/configs/inst_colorization/README.md +++ b/configs/inst_colorization/README.md @@ -49,7 +49,7 @@ You can use the following commands to test a model with cpu or single/multiple G # CPU test CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/inst_colorization//inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth -# single-gpu test +# single-gpu demo python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth # multi-gpu test diff --git a/configs/inst_colorization/README_zh-CN.md b/configs/inst_colorization/README_zh-CN.md index 752872abcb..661dfe2442 100644 --- a/configs/inst_colorization/README_zh-CN.md +++ b/configs/inst_colorization/README_zh-CN.md @@ -47,13 +47,13 @@ python tools/train.py configs/insta/inst-colorizatioon_cocostuff_full_256x256.py ```shell # CPU上测试 -CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/insta/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth +CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth -# 单个GPU上测试 -python demo/colorization_demo.py configs/insta/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth +# 单个GPU上 demo +python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py work_dirs/checkpoints/instance_aware_cocostuff.pth work_dirs/colorization_example.jpg work_dirs/output_example.png # 多个GPU上测试 -./tools/dist_test.sh configs/insta/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth 8 +./tools/dist_test.sh configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth 8 ``` 更多细节可以参考 [train_test.md](/docs/zh_cn/user_guides/train_test.md) 中的 **Test a pre-trained model** 部分。 diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py index c6d2f3b087..c6f7ae2c19 100644 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py @@ -1,3 +1,5 @@ +from logging import PlaceHolder + _base_ = ['../_base_/default_runtime.py'] exp_name = 'Instance-aware_full' @@ -15,11 +17,20 @@ generator=dict( type='InstColorizationGenerator', stage=stage, + detector=PlaceHolder, + full_model=dict( + type='InstanceGenerator', + input_nc=4, + output_nc=2, + norm_type='batch'), instance_model=dict( - type='InstanceGenerator', input_nc=4, output_nc=2, norm_type='batch'), + type='InstanceGenerator', + input_nc=4, + output_nc=2, + norm_type='batch'), fusion_model=dict( - type='FusionGenerator', input_nc=4, output_nc=2, norm_type='batch') - ), + type='FusionGenerator', input_nc=4, output_nc=2, + norm_type='batch')), insta_stage=stage, ngf=64, output_nc=2, diff --git a/demo/colorization_demo.py b/demo/colorization_demo.py index cbfa00fe66..4cc3b0c4e7 100644 --- a/demo/colorization_demo.py +++ b/demo/colorization_demo.py @@ -15,8 +15,6 @@ def parse_args(): parser.add_argument('checkpoints', help='checkpoints file path') parser.add_argument('img_path', help='path to input image file') parser.add_argument('save_path', help='path to save generation result') - parser.add_argument( - '--unpaired-path', default=None, help='path to unpaired image file') parser.add_argument( '--imshow', action='store_true', help='whether show image with opencv') parser.add_argument('--device', type=int, default=0, help='CUDA device id') From 524f6f520474dc7633190fab3b80f2161f96a56b Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Wed, 19 Oct 2022 22:11:14 +0800 Subject: [PATCH 10/32] refactor networks --- ...colorizatioon_cocostuff-fusion_256x256.py} | 0 ...-colorizatioon_cocostuff-image_256x256.py} | 5 +- ...lorizatioon_cocostuff-instance_256x256.py} | 0 .../inst-colorizatioon_cocostuff_256x256.py | 18 +- .../editors/inst_colorization/__init__.py | 6 +- .../inst_colorization/colorization_net.py | 471 ++++++ .../editors/inst_colorization/fusion_net.py | 504 ++++++ .../inst_colorization/inst_colorization.py | 95 +- .../inst_colorization_generator.py | 99 -- .../inst_colorization_net.py | 1497 ----------------- .../editors/inst_colorization/weight_block.py | 99 ++ 11 files changed, 1132 insertions(+), 1662 deletions(-) rename configs/inst_colorization/{inst-colorizatioon_cocostuff_fusion_256x256.py => inst-colorizatioon_cocostuff-fusion_256x256.py} (100%) rename configs/inst_colorization/{inst-colorizatioon_cocostuff_full_256x256.py => inst-colorizatioon_cocostuff-image_256x256.py} (96%) rename configs/inst_colorization/{inst-colorizatioon_cocostuff_instance_256x256.py => inst-colorizatioon_cocostuff-instance_256x256.py} (100%) create mode 100644 mmedit/models/editors/inst_colorization/colorization_net.py create mode 100644 mmedit/models/editors/inst_colorization/fusion_net.py delete mode 100644 mmedit/models/editors/inst_colorization/inst_colorization_generator.py delete mode 100644 mmedit/models/editors/inst_colorization/inst_colorization_net.py create mode 100644 mmedit/models/editors/inst_colorization/weight_block.py diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_fusion_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff-fusion_256x256.py similarity index 100% rename from configs/inst_colorization/inst-colorizatioon_cocostuff_fusion_256x256.py rename to configs/inst_colorization/inst-colorizatioon_cocostuff-fusion_256x256.py diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff-image_256x256.py similarity index 96% rename from configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py rename to configs/inst_colorization/inst-colorizatioon_cocostuff-image_256x256.py index 3d0dcc58c0..88f7b65206 100644 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff-image_256x256.py @@ -16,7 +16,10 @@ type='InstColorizationGenerator', stage=stage, instance_model=dict( - type='SIGGRAPHGenerator', input_nc=4, output_nc=2, norm_type='batch'), + type='SIGGRAPHGenerator', + input_nc=4, + output_nc=2, + norm_type='batch'), ), insta_stage=stage, ngf=64, diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_instance_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff-instance_256x256.py similarity index 100% rename from configs/inst_colorization/inst-colorizatioon_cocostuff_instance_256x256.py rename to configs/inst_colorization/inst-colorizatioon_cocostuff-instance_256x256.py diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py index c6f7ae2c19..1b69951625 100644 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py @@ -2,11 +2,12 @@ _base_ = ['../_base_/default_runtime.py'] -exp_name = 'Instance-aware_full' +exp_name = 'inst-colorization_cocostuff_256x256' save_dir = './' work_dir = '..' stage = 'test' + model = dict( type='InstColorization', data_preprocessor=dict( @@ -15,22 +16,17 @@ std=[127.5], ), generator=dict( - type='InstColorizationGenerator', + type='InstColorization', stage=stage, detector=PlaceHolder, - full_model=dict( - type='InstanceGenerator', - input_nc=4, - output_nc=2, + image_model=dict( + type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'), instance_model=dict( - type='InstanceGenerator', - input_nc=4, - output_nc=2, + type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'), fusion_model=dict( - type='FusionGenerator', input_nc=4, output_nc=2, - norm_type='batch')), + type='FusionNet', input_nc=4, output_nc=2, norm_type='batch')), insta_stage=stage, ngf=64, output_nc=2, diff --git a/mmedit/models/editors/inst_colorization/__init__.py b/mmedit/models/editors/inst_colorization/__init__.py index dcf30b8d9e..3a3f5a43d5 100644 --- a/mmedit/models/editors/inst_colorization/__init__.py +++ b/mmedit/models/editors/inst_colorization/__init__.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .colorization_net import (FusionGenerator, InstanceGenerator, + SIGGRAPHGenerator) from .inst_colorization import InstColorization -from .inst_colorization_net import (FusionGenerator, InstanceGenerator, - SIGGRAPHGenerator) -from .inst_colorization_generator import InstColorizationGenerator +from .inst_colorization_net import InstColorizationGenerator __all__ = [ 'InstColorization', 'SIGGRAPHGenerator', 'InstanceGenerator', diff --git a/mmedit/models/editors/inst_colorization/colorization_net.py b/mmedit/models/editors/inst_colorization/colorization_net.py new file mode 100644 index 0000000000..9d6164903e --- /dev/null +++ b/mmedit/models/editors/inst_colorization/colorization_net.py @@ -0,0 +1,471 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch +import torch.nn as nn +from .weight_block import get_norm_layer + +from mmengine.model import BaseModule +from mmedit.registry import MODULES + + + +@MODULES.register_module() +class ColorizationNet(BaseModule): + + def __init__(self, + input_nc, + output_nc, + norm_type, + use_tanh=True, + classification=True): + super(ColorizationNet, self).__init__() + self.input_nc = input_nc + self.output_nc = output_nc + self.classification = classification + + norm_layer = get_norm_layer(norm_type) + + use_bias = True + + # Conv1 + # model1=[nn.ReflectionPad2d(1),] + model1 = [ + nn.Conv2d( + input_nc, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias), + ] + # model1+=[norm_layer(64),] + model1 += [ + nn.ReLU(True), + ] + # model1+=[nn.ReflectionPad2d(1),] + model1 += [ + nn.Conv2d( + 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model1 += [ + nn.ReLU(True), + ] + model1 += [ + norm_layer(64), + ] + # add a subsampling operation + + # Conv2 + # model2=[nn.ReflectionPad2d(1),] + model2 = [ + nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model2+=[norm_layer(128),] + model2 += [ + nn.ReLU(True), + ] + # model2+=[nn.ReflectionPad2d(1),] + model2 += [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model2 += [ + nn.ReLU(True), + ] + model2 += [ + norm_layer(128), + ] + # add a subsampling layer operation + + # Conv3 + # model3=[nn.ReflectionPad2d(1),] + model3 = [ + nn.Conv2d( + 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model3+=[norm_layer(256),] + model3 += [ + nn.ReLU(True), + ] + # model3+=[nn.ReflectionPad2d(1),] + model3 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model3+=[norm_layer(256),] + model3 += [ + nn.ReLU(True), + ] + # model3+=[nn.ReflectionPad2d(1),] + model3 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model3 += [ + nn.ReLU(True), + ] + model3 += [ + norm_layer(256), + ] + # add a subsampling layer operation + + # Conv4 + # model47=[nn.ReflectionPad2d(1),] + model4 = [ + nn.Conv2d( + 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model4+=[norm_layer(512),] + model4 += [ + nn.ReLU(True), + ] + # model4+=[nn.ReflectionPad2d(1),] + model4 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model4+=[norm_layer(512),] + model4 += [ + nn.ReLU(True), + ] + # model4+=[nn.ReflectionPad2d(1),] + model4 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model4 += [ + nn.ReLU(True), + ] + model4 += [ + norm_layer(512), + ] + + # Conv5 + # model47+=[nn.ReflectionPad2d(2),] + model5 = [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model5+=[norm_layer(512),] + model5 += [ + nn.ReLU(True), + ] + # model5+=[nn.ReflectionPad2d(2),] + model5 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model5+=[norm_layer(512),] + model5 += [ + nn.ReLU(True), + ] + # model5+=[nn.ReflectionPad2d(2),] + model5 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + model5 += [ + nn.ReLU(True), + ] + model5 += [ + norm_layer(512), + ] + + # Conv6 + # model6+=[nn.ReflectionPad2d(2),] + model6 = [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model6+=[norm_layer(512),] + model6 += [ + nn.ReLU(True), + ] + # model6+=[nn.ReflectionPad2d(2),] + model6 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model6+=[norm_layer(512),] + model6 += [ + nn.ReLU(True), + ] + # model6+=[nn.ReflectionPad2d(2),] + model6 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + model6 += [ + nn.ReLU(True), + ] + model6 += [ + norm_layer(512), + ] + + # Conv7 + # model47+=[nn.ReflectionPad2d(1),] + model7 = [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model7+=[norm_layer(512),] + model7 += [ + nn.ReLU(True), + ] + # model7+=[nn.ReflectionPad2d(1),] + model7 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model7+=[norm_layer(512),] + model7 += [ + nn.ReLU(True), + ] + # model7+=[nn.ReflectionPad2d(1),] + model7 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model7 += [ + nn.ReLU(True), + ] + model7 += [ + norm_layer(512), + ] + + # Conv7 + model8up = [ + nn.ConvTranspose2d( + 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) + ] + + # model3short8=[nn.ReflectionPad2d(1),] + model3short8 = [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + + # model47+=[norm_layer(256),] + model8 = [ + nn.ReLU(True), + ] + # model8+=[nn.ReflectionPad2d(1),] + model8 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model8+=[norm_layer(256),] + model8 += [ + nn.ReLU(True), + ] + # model8+=[nn.ReflectionPad2d(1),] + model8 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model8 += [ + nn.ReLU(True), + ] + model8 += [ + norm_layer(256), + ] + + # Conv9 + model9up = [ + nn.ConvTranspose2d( + 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), + ] + + # model2short9=[nn.ReflectionPad2d(1),] + model2short9 = [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # add the two feature maps above + + # model9=[norm_layer(128),] + model9 = [ + nn.ReLU(True), + ] + # model9+=[nn.ReflectionPad2d(1),] + model9 += [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model9 += [ + nn.ReLU(True), + ] + model9 += [ + norm_layer(128), + ] + + # Conv10 + model10up = [ + nn.ConvTranspose2d( + 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), + ] + + # model1short10=[nn.ReflectionPad2d(1),] + model1short10 = [ + nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # add the two feature maps above + + # model10=[norm_layer(128),] + model10 = [ + nn.ReLU(True), + ] + # model10+=[nn.ReflectionPad2d(1),] + model10 += [ + nn.Conv2d( + 128, + 128, + kernel_size=3, + dilation=1, + stride=1, + padding=1, + bias=use_bias), + ] + model10 += [ + nn.LeakyReLU(negative_slope=.2), + ] + + # classification output + model_class = [ + nn.Conv2d( + 256, + 529, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias), + ] + + # regression output + model_out = [ + nn.Conv2d( + 128, + 2, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias), + ] + if (use_tanh): + model_out += [nn.Tanh()] + + self.model1 = nn.Sequential(*model1) + self.model2 = nn.Sequential(*model2) + self.model3 = nn.Sequential(*model3) + self.model4 = nn.Sequential(*model4) + self.model5 = nn.Sequential(*model5) + self.model6 = nn.Sequential(*model6) + self.model7 = nn.Sequential(*model7) + self.model8up = nn.Sequential(*model8up) + self.model8 = nn.Sequential(*model8) + self.model9up = nn.Sequential(*model9up) + self.model9 = nn.Sequential(*model9) + self.model10up = nn.Sequential(*model10up) + self.model10 = nn.Sequential(*model10) + self.model3short8 = nn.Sequential(*model3short8) + self.model2short9 = nn.Sequential(*model2short9) + self.model1short10 = nn.Sequential(*model1short10) + + self.model_class = nn.Sequential(*model_class) + self.model_out = nn.Sequential(*model_out) + + self.upsample4 = nn.Sequential(*[ + nn.Upsample(scale_factor=4, mode='nearest'), + ]) + self.softmax = nn.Sequential(*[ + nn.Softmax(dim=1), + ]) + + def forward(self, input_A, input_B, mask_B): + conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1)) + conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) + conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) + conv4_3 = self.model4(conv3_3[:, :, ::2, ::2]) + conv5_3 = self.model5(conv4_3) + conv6_3 = self.model6(conv5_3) + conv7_3 = self.model7(conv6_3) + conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3) + conv8_3 = self.model8(conv8_up) + + if (self.classification): + out_class = self.model_class(conv8_3) + conv9_up = self.model9up(conv8_3.detach()) + self.model2short9( + conv2_2.detach()) + conv9_3 = self.model9(conv9_up) + conv10_up = self.model10up(conv9_3) + self.model1short10( + conv1_2.detach()) + conv10_2 = self.model10(conv10_up) + out_reg = self.model_out(conv10_2) + else: + out_class = self.model_class(conv8_3.detach()) + + conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) + conv9_3 = self.model9(conv9_up) + conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) + conv10_2 = self.model10(conv10_up) + out_reg = self.model_out(conv10_2) + + feature_map = {} + feature_map['conv1_2'] = conv1_2 + feature_map['conv2_2'] = conv2_2 + feature_map['conv3_3'] = conv3_3 + feature_map['conv4_3'] = conv4_3 + feature_map['conv5_3'] = conv5_3 + feature_map['conv6_3'] = conv6_3 + feature_map['conv7_3'] = conv7_3 + feature_map['conv8_up'] = conv8_up + feature_map['conv8_3'] = conv8_3 + feature_map['conv9_up'] = conv9_up + feature_map['conv9_3'] = conv9_3 + feature_map['conv10_up'] = conv10_up + feature_map['conv10_2'] = conv10_2 + feature_map['out_reg'] = out_reg + + return (out_reg, feature_map) + + # return (out_class, out_reg) diff --git a/mmedit/models/editors/inst_colorization/fusion_net.py b/mmedit/models/editors/inst_colorization/fusion_net.py new file mode 100644 index 0000000000..bed83d8f02 --- /dev/null +++ b/mmedit/models/editors/inst_colorization/fusion_net.py @@ -0,0 +1,504 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import torch +import torch.nn as nn + +from mmedit.registry import MODULES + + +@MODULES.register_module() +class FusionNet(nn.Module): + + def __init__(self, + input_nc, + output_nc, + norm_type, + use_tanh=True, + classification=True): + super(FusionNet, self).__init__() + self.input_nc = input_nc + self.output_nc = output_nc + self.classification = classification + use_bias = True + + norm_layer = get_norm_layer(norm_type) + + # Conv1 + # model1=[nn.ReflectionPad2d(1),] + model1 = [ + nn.Conv2d( + input_nc, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias), + ] + # model1+=[norm_layer(64),] + model1 += [ + nn.ReLU(True), + ] + # model1+=[nn.ReflectionPad2d(1),] + model1 += [ + nn.Conv2d( + 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model1 += [ + nn.ReLU(True), + ] + model1 += [ + norm_layer(64), + ] + # add a subsampling operation + + self.weight_layer = WeightGenerator(64) + + # Conv2 + # model2=[nn.ReflectionPad2d(1),] + model2 = [ + nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model2+=[norm_layer(128),] + model2 += [ + nn.ReLU(True), + ] + # model2+=[nn.ReflectionPad2d(1),] + model2 += [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model2 += [ + nn.ReLU(True), + ] + model2 += [ + norm_layer(128), + ] + # add a subsampling layer operation + + self.weight_layer2 = WeightGenerator(128) + + # Conv3 + # model3=[nn.ReflectionPad2d(1),] + model3 = [ + nn.Conv2d( + 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model3+=[norm_layer(256),] + model3 += [ + nn.ReLU(True), + ] + # model3+=[nn.ReflectionPad2d(1),] + model3 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model3+=[norm_layer(256),] + model3 += [ + nn.ReLU(True), + ] + # model3+=[nn.ReflectionPad2d(1),] + model3 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model3 += [ + nn.ReLU(True), + ] + model3 += [ + norm_layer(256), + ] + # add a subsampling layer operation + + self.weight_layer3 = WeightGenerator(256) + + # Conv4 + # model47=[nn.ReflectionPad2d(1),] + model4 = [ + nn.Conv2d( + 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model4+=[norm_layer(512),] + model4 += [ + nn.ReLU(True), + ] + # model4+=[nn.ReflectionPad2d(1),] + model4 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model4+=[norm_layer(512),] + model4 += [ + nn.ReLU(True), + ] + # model4+=[nn.ReflectionPad2d(1),] + model4 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model4 += [ + nn.ReLU(True), + ] + model4 += [ + norm_layer(512), + ] + + self.weight_layer4 = WeightGenerator(512) + + # Conv5 + # model47+=[nn.ReflectionPad2d(2),] + model5 = [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model5+=[norm_layer(512),] + model5 += [ + nn.ReLU(True), + ] + # model5+=[nn.ReflectionPad2d(2),] + model5 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model5+=[norm_layer(512),] + model5 += [ + nn.ReLU(True), + ] + # model5+=[nn.ReflectionPad2d(2),] + model5 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + model5 += [ + nn.ReLU(True), + ] + model5 += [ + norm_layer(512), + ] + + self.weight_layer5 = WeightGenerator(512) + + # Conv6 + # model6+=[nn.ReflectionPad2d(2),] + model6 = [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model6+=[norm_layer(512),] + model6 += [ + nn.ReLU(True), + ] + # model6+=[nn.ReflectionPad2d(2),] + model6 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + # model6+=[norm_layer(512),] + model6 += [ + nn.ReLU(True), + ] + # model6+=[nn.ReflectionPad2d(2),] + model6 += [ + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + ] + model6 += [ + nn.ReLU(True), + ] + model6 += [ + norm_layer(512), + ] + + self.weight_layer6 = WeightGenerator(512) + + # Conv7 + # model47+=[nn.ReflectionPad2d(1),] + model7 = [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model7+=[norm_layer(512),] + model7 += [ + nn.ReLU(True), + ] + # model7+=[nn.ReflectionPad2d(1),] + model7 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model7+=[norm_layer(512),] + model7 += [ + nn.ReLU(True), + ] + # model7+=[nn.ReflectionPad2d(1),] + model7 += [ + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model7 += [ + nn.ReLU(True), + ] + model7 += [ + norm_layer(512), + ] + + self.weight_layer7 = WeightGenerator(512) + + # Conv7 + model8up = [ + nn.ConvTranspose2d( + 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) + ] + + # model3short8=[nn.ReflectionPad2d(1),] + model3short8 = [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + + self.weight_layer8_1 = WeightGenerator(256) + + # model47+=[norm_layer(256),] + model8 = [ + nn.ReLU(True), + ] + # model8+=[nn.ReflectionPad2d(1),] + model8 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # model8+=[norm_layer(256),] + model8 += [ + nn.ReLU(True), + ] + # model8+=[nn.ReflectionPad2d(1),] + model8 += [ + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model8 += [ + nn.ReLU(True), + ] + model8 += [ + norm_layer(256), + ] + + self.weight_layer8_2 = WeightGenerator(256) + + # Conv9 + model9up = [ + nn.ConvTranspose2d( + 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), + ] + + # model2short9=[nn.ReflectionPad2d(1),] + model2short9 = [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # add the two feature maps above + + self.weight_layer9_1 = WeightGenerator(128) + + # model9=[norm_layer(128),] + model9 = [ + nn.ReLU(True), + ] + # model9+=[nn.ReflectionPad2d(1),] + model9 += [ + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + model9 += [ + nn.ReLU(True), + ] + model9 += [ + norm_layer(128), + ] + + self.weight_layer9_2 = WeightGenerator(128) + + # Conv10 + model10up = [ + nn.ConvTranspose2d( + 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), + ] + + # model1short10=[nn.ReflectionPad2d(1),] + model1short10 = [ + nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + ] + # add the two feature maps above + + self.weight_layer10_1 = WeightGenerator(128) + + # model10=[norm_layer(128),] + model10 = [ + nn.ReLU(True), + ] + # model10+=[nn.ReflectionPad2d(1),] + model10 += [ + nn.Conv2d( + 128, + 128, + kernel_size=3, + dilation=1, + stride=1, + padding=1, + bias=use_bias), + ] + model10 += [ + nn.LeakyReLU(negative_slope=.2), + ] + + self.weight_layer10_2 = WeightGenerator(128) + + # classification output + model_class = [ + nn.Conv2d( + 256, + 529, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias), + ] + + # regression output + model_out = [ + nn.Conv2d( + 128, + 2, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias), + ] + if (use_tanh): + model_out += [nn.Tanh()] + + self.weight_layerout = WeightGenerator(2) + + self.model1 = nn.Sequential(*model1) + self.model2 = nn.Sequential(*model2) + self.model3 = nn.Sequential(*model3) + self.model4 = nn.Sequential(*model4) + self.model5 = nn.Sequential(*model5) + self.model6 = nn.Sequential(*model6) + self.model7 = nn.Sequential(*model7) + self.model8up = nn.Sequential(*model8up) + self.model8 = nn.Sequential(*model8) + self.model9up = nn.Sequential(*model9up) + self.model9 = nn.Sequential(*model9) + self.model10up = nn.Sequential(*model10up) + self.model10 = nn.Sequential(*model10) + self.model3short8 = nn.Sequential(*model3short8) + self.model2short9 = nn.Sequential(*model2short9) + self.model1short10 = nn.Sequential(*model1short10) + + self.model_class = nn.Sequential(*model_class) + self.model_out = nn.Sequential(*model_out) + + self.upsample4 = nn.Sequential(*[ + nn.Upsample(scale_factor=4, mode='nearest'), + ]) + self.softmax = nn.Sequential(*[ + nn.Softmax(dim=1), + ]) + + def forward(self, input_A, input_B, mask_B, instance_feature, + box_info_list): + conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1)) + conv1_2 = self.weight_layer(instance_feature['conv1_2'], conv1_2, + box_info_list[0]) + + conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) + conv2_2 = self.weight_layer2(instance_feature['conv2_2'], conv2_2, + box_info_list[1]) + + conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) + conv3_3 = self.weight_layer3(instance_feature['conv3_3'], conv3_3, + box_info_list[2]) + + conv4_3 = self.model4(conv3_3[:, :, ::2, ::2]) + conv4_3 = self.weight_layer4(instance_feature['conv4_3'], conv4_3, + box_info_list[3]) + + conv5_3 = self.model5(conv4_3) + conv5_3 = self.weight_layer5(instance_feature['conv5_3'], conv5_3, + box_info_list[3]) + + conv6_3 = self.model6(conv5_3) + conv6_3 = self.weight_layer6(instance_feature['conv6_3'], conv6_3, + box_info_list[3]) + + conv7_3 = self.model7(conv6_3) + conv7_3 = self.weight_layer7(instance_feature['conv7_3'], conv7_3, + box_info_list[3]) + + conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3) + conv8_up = self.weight_layer8_1(instance_feature['conv8_up'], conv8_up, + box_info_list[2]) + + conv8_3 = self.model8(conv8_up) + conv8_3 = self.weight_layer8_2(instance_feature['conv8_3'], conv8_3, + box_info_list[2]) + + conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) + conv9_up = self.weight_layer9_1(instance_feature['conv9_up'], conv9_up, + box_info_list[1]) + + conv9_3 = self.model9(conv9_up) + conv9_3 = self.weight_layer9_2(instance_feature['conv9_3'], conv9_3, + box_info_list[1]) + + conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) + conv10_up = self.weight_layer10_1(instance_feature['conv10_up'], + conv10_up, box_info_list[0]) + + conv10_2 = self.model10(conv10_up) + conv10_2 = self.weight_layer10_2(instance_feature['conv10_2'], + conv10_2, box_info_list[0]) + + out_reg = self.model_out(conv10_2) + return out_reg diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py index 5abb5a5ec7..7cc242292b 100644 --- a/mmedit/models/editors/inst_colorization/inst_colorization.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -1,20 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. from collections import OrderedDict -from typing import Union, List, Dict +from typing import Dict, List, Union import torch from mmengine.config import Config from mmengine.optim import OptimWrapperDict +from mmedit.models import BaseEditModel from mmedit.models.utils import (encode_ab_ind, generation_init_weights, get_colorization_data, lab2rgb) -from mmedit.structures import EditDataSample, PixelData -from mmedit.registry import BACKBONES, COMPONENTS -from ..srgan import SRGAN +from mmedit.registry import MODULES +from mmedit.structures import EditDataSample, PixelData -@BACKBONES.register_module() -class InstColorization(SRGAN): +@MODULES.register_module() +class InstColorization(BaseEditModel): def __init__(self, data_preprocessor: Union[dict, Config], @@ -51,9 +51,7 @@ def __init__(self, self.which_direction = which_direction self.encode_ab_opt = dict( - ab_norm=ab_norm, - ab_max=ab_max, - ab_quant=ab_quant) + ab_norm=ab_norm, ab_max=ab_max, ab_quant=ab_quant) self.colorization_data_opt = dict( ab_thresh=0, @@ -64,8 +62,7 @@ def __init__(self, mask_cent=mask_cent, ) - self.lab2rgb_opt = dict( - ab_norm=ab_norm, l_norm=l_norm, l_cent=l_cent) + self.lab2rgb_opt = dict(ab_norm=ab_norm, l_norm=l_norm, l_cent=l_cent) self.convert_params = dict( ab_thresh=0, @@ -94,8 +91,8 @@ def set_input(self, input): self.mask_B = input['mask_B'].to(self.device) self.mask_B_nc = self.mask_B + self.mask_cent - self.real_B_enc = encode_ab_ind( - self.real_B[:, :, ::4, ::4], **self.encode_ab_opt) + self.real_B_enc = encode_ab_ind(self.real_B[:, :, ::4, ::4], + **self.encode_ab_opt) def set_fusion_input(self, input, box_info): @@ -107,8 +104,8 @@ def set_fusion_input(self, input, box_info): self.full_mask_B = input['mask_B'].to(self.device) self.full_mask_B_nc = self.full_mask_B + self.mask_cent - self.full_real_B_enc = encode_ab_ind( - self.full_real_B[:, :, ::4, ::4], **self.encode_ab_opt) + self.full_real_B_enc = encode_ab_ind(self.full_real_B[:, :, ::4, ::4], + **self.encode_ab_opt) self.box_info_list = box_info def set_forward_without_box(self, input): @@ -155,10 +152,10 @@ def generator_loss(self): # float(...) works for both scalar tensor and float number self.avg_losses[name] = float(getattr( self, 'loss_' + - name)) + self.avg_loss_alpha * self.avg_losses[name] + name)) + self.avg_loss_alpha * self.avg_losses[name] errors_ret[name] = (1 - self.avg_loss_alpha) / ( - 1 - self.avg_loss_alpha ** # noqa - self.error_cnt) * self.avg_losses[name] + 1 - self.avg_loss_alpha** # noqa + self.error_cnt) * self.avg_losses[name] return errors_ret @@ -185,7 +182,8 @@ def train_step(self, data: List[dict], input_data['hint_B'] = gt_data['hint_B'] input_data['mask_B'] = gt_data['mask_B'] self.set_input(input_data) - self.fake_B_reg = self.generator(self.real_A, self.hint_B, self.mask_B) + self.fake_B_reg = self.generator(self.real_A, self.hint_B, + self.mask_B) elif self.insta_stage == 'fusion': box_info = data_samples.box_info @@ -195,10 +193,10 @@ def train_step(self, data: List[dict], cropped_input_data = get_colorization_data( data_samples.cropped_gray, **self.colorization_data_opt) - cropped_gt_data = get_colorization_data(data_samples.cropped_rgb, - **self.colorization_data_opt) - full_input_data = get_colorization_data(data_samples.full_gray, - **self.colorization_data_opt) + cropped_gt_data = get_colorization_data( + data_samples.cropped_rgb, **self.colorization_data_opt) + full_input_data = get_colorization_data( + data_samples.full_gray, **self.colorization_data_opt) full_gt_data = get_colorization_data(data_samples.full_rgb, **self.colorization_data_opt) @@ -210,10 +208,11 @@ def train_step(self, data: List[dict], full_input_data, [box_info, box_info_2x, box_info_4x, box_info_8x]) - self.fake_B_reg = self.generator( - self.real_A, self.hint_B, self.mask_B, self.full_real_A, self.full_hint_B, - self.full_mask_B, self.box_info_list - ) + self.fake_B_reg = self.generator(self.real_A, self.hint_B, + self.mask_B, self.full_real_A, + self.full_hint_B, + self.full_mask_B, + self.box_info_list) optimizer['generator'].zero_grad() @@ -259,26 +258,21 @@ def forward_tensor(self, inputs, data_samples, **kwargs): box_info_2x = data.box_info_2x box_info_4x = data.box_info_4x box_info_8x = data.box_info_8x - cropped_data = get_colorization_data( - cropped_img, - **self.convert_params - ) - full_img_data = get_colorization_data( - full_img, - **self.convert_params - ) + cropped_data = get_colorization_data(cropped_img, + **self.convert_params) + full_img_data = get_colorization_data(full_img, + **self.convert_params) self.set_input(cropped_data) self.set_fusion_input( full_img_data, [box_info, box_info_2x, box_info_4x, box_info_8x]) else: - full_img_data = get_colorization_data( - full_img, ab_thresh=0) + full_img_data = get_colorization_data(full_img, ab_thresh=0) self.set_forward_without_box(full_img_data) - self.fake_B_reg = self.generator( - self.real_A, self.hint_B, self.mask_B, self.full_real_A, - self.full_hint_B, self.full_mask_B, self.box_info_list) + self.fake_B_reg = self.generator(self.real_A, self.hint_B, self.mask_B, + self.full_real_A, self.full_hint_B, + self.full_mask_B, self.box_info_list) out_img = torch.clamp( lab2rgb( @@ -296,8 +290,7 @@ def forward_inference(self, inputs, data_samples=None, **kwargs): pred_img = PixelData(data=batch_tensor.to('cpu')) predictions.append( EditDataSample( - pred_img=pred_img, - metainfo=data_samples[idx].metainfo)) + pred_img=pred_img, metainfo=data_samples[idx].metainfo)) return predictions @@ -310,8 +303,8 @@ def get_current_visuals(self): visual_ret['gray'] = lab2rgb( torch.cat((self.real_A.type( torch.cuda.FloatTensor), torch.zeros_like( - self.real_B).type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) + self.real_B).type(torch.cuda.FloatTensor)), + dim=1), **self.lab2rgb_opt) visual_ret['real'] = lab2rgb( torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.real_B.type(torch.cuda.FloatTensor)), @@ -329,19 +322,19 @@ def get_current_visuals(self): torch.cat((torch.zeros_like( self.real_A.type(torch.cuda.FloatTensor)), self.real_B.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) + dim=1), **self.lab2rgb_opt) visual_ret['fake_ab_reg'] = lab2rgb( torch.cat((torch.zeros_like( self.real_A.type(torch.cuda.FloatTensor)), self.fake_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) + dim=1), **self.lab2rgb_opt) elif self.insta_stage == 'fusion': visual_ret['gray'] = lab2rgb( torch.cat((self.full_real_A.type( torch.cuda.FloatTensor), torch.zeros_like( - self.full_real_B).type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) + self.full_real_B).type(torch.cuda.FloatTensor)), + dim=1), **self.lab2rgb_opt) visual_ret['real'] = lab2rgb( torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.full_real_B.type(torch.cuda.FloatTensor)), @@ -366,17 +359,17 @@ def get_current_visuals(self): torch.cat((torch.zeros_like( self.full_real_A.type(torch.cuda.FloatTensor)), self.full_real_B.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) + dim=1), **self.lab2rgb_opt) visual_ret['comp_ab_reg'] = lab2rgb( torch.cat((torch.zeros_like( self.full_real_A.type(torch.cuda.FloatTensor)), self.comp_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) + dim=1), **self.lab2rgb_opt) visual_ret['fake_ab_reg'] = lab2rgb( torch.cat((torch.zeros_like( self.full_real_A.type(torch.cuda.FloatTensor)), self.fake_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) + dim=1), **self.lab2rgb_opt) else: print('Error! Wrong stage selection!') exit() diff --git a/mmedit/models/editors/inst_colorization/inst_colorization_generator.py b/mmedit/models/editors/inst_colorization/inst_colorization_generator.py deleted file mode 100644 index 5b9c3850cd..0000000000 --- a/mmedit/models/editors/inst_colorization/inst_colorization_generator.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch.nn as nn - -from mmedit.registry import BACKBONES, COMPONENTS -from mmedit.models.utils import generation_init_weights - - -@BACKBONES.register_module() -class InstColorizationGenerator(nn.Module): - - def __init__(self, - stage, - instance_model=None, - full_model=None, - fusion_model=None, - ): - - super(InstColorizationGenerator, self).__init__() - - self.stage = stage - - if self.stage == "test": - self.netG = COMPONENTS.build(instance_model) - generation_init_weights(self.netG) - - self.netGF = COMPONENTS.build(fusion_model) - generation_init_weights(self.netGF) - - elif self.stage == "instance" or stage == 'full': - self.netG = COMPONENTS.build(instance_model) - generation_init_weights(self.netG) - - elif self.stage == "fusion": - self.netG = COMPONENTS.build(instance_model) - generation_init_weights(self.netG) - self.netG.eval() - - self.netGF = COMPONENTS.build(fusion_model) - generation_init_weights(self.netGF) - self.netGF.eval() - - self.netGComp = COMPONENTS.build(full_model) - generation_init_weights(self.netGComp) - self.netGComp.eval() - - self.generator = \ - list(self.netGF.module.weight_layer.parameters()) + \ - list(self.netGF.module.weight_layer2.parameters()) + \ - list(self.netGF.module.weight_layer3.parameters()) + \ - list(self.netGF.module.weight_layer4.parameters()) + \ - list(self.netGF.module.weight_layer5.parameters()) + \ - list(self.netGF.module.weight_layer6.parameters()) + \ - list(self.netGF.module.weight_layer7.parameters()) + \ - list(self.netGF.module.weight_layer8_1.parameters()) + \ - list(self.netGF.module.weight_layer8_2.parameters()) + \ - list(self.netGF.module.weight_layer9_1.parameters()) + \ - list(self.netGF.module.weight_layer9_2.parameters()) + \ - list(self.netGF.module.weight_layer10_1.parameters()) + \ - list(self.netGF.module.weight_layer10_2.parameters()) + \ - list(self.netGF.module.model10.parameters()) + \ - list(self.netGF.module.model_out.parameters()) - else: - print('Error! Wrong stage selection!') - exit() - - def forward(self, - real_A, - hint_B, - mask_B, - full_real_A=None, - full_hint_B=None, - full_mask_B=None, - box_info_list=None - ): - if self.stage == 'test': - (_, feature_map) = self.netG(real_A, hint_B, mask_B) - fake_B_reg = self.netGF( - full_real_A, full_hint_B, full_mask_B, - feature_map, box_info_list - ) - - return fake_B_reg - - elif self.stage == 'full' or self.stage == 'instance': - (_, fake_B_reg) = self.netG(real_A, hint_B, mask_B) - - return fake_B_reg - - elif self.stage == 'fusion': - (_, self.comp_B_reg) = self.netGComp( - full_real_A, full_hint_B, full_mask_B) - - (_, feature_map) = self.netG(real_A, hint_B, mask_B) - - fake_B_reg = self.netGF( - full_real_A, full_hint_B, full_mask_B, - feature_map, box_info_list) - - return fake_B_reg \ No newline at end of file diff --git a/mmedit/models/editors/inst_colorization/inst_colorization_net.py b/mmedit/models/editors/inst_colorization/inst_colorization_net.py deleted file mode 100644 index 2dc207165a..0000000000 --- a/mmedit/models/editors/inst_colorization/inst_colorization_net.py +++ /dev/null @@ -1,1497 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import functools - -import torch -import torch.nn as nn - -from mmedit.registry import BACKBONES - - -def get_norm_layer(norm_type='instance'): - if norm_type == 'batch': - norm_layer = functools.partial(nn.BatchNorm2d, affine=True) - elif norm_type == 'instance': - norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) - elif norm_type == 'none': - norm_layer = None - else: - raise NotImplementedError('normalization layer [%s] is not found' % - norm_type) - return norm_layer - - -@BACKBONES.register_module() -class SIGGRAPHGenerator(nn.Module): - - def __init__(self, - input_nc, - output_nc, - norm_type, - use_tanh=True, - classification=True): - super(SIGGRAPHGenerator, self).__init__() - self.input_nc = input_nc - self.output_nc = output_nc - self.classification = classification - - norm_layer = get_norm_layer(norm_type) - - use_bias = True - - # Conv1 - # model1=[nn.ReflectionPad2d(1),] - model1 = [ - nn.Conv2d( - input_nc, - 64, - kernel_size=3, - stride=1, - padding=1, - bias=use_bias), - ] - # model1+=[norm_layer(64),] - model1 += [ - nn.ReLU(True), - ] - # model1+=[nn.ReflectionPad2d(1),] - model1 += [ - nn.Conv2d( - 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model1 += [ - nn.ReLU(True), - ] - model1 += [ - norm_layer(64), - ] - # add a subsampling operation - - # Conv2 - # model2=[nn.ReflectionPad2d(1),] - model2 = [ - nn.Conv2d( - 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model2+=[norm_layer(128),] - model2 += [ - nn.ReLU(True), - ] - # model2+=[nn.ReflectionPad2d(1),] - model2 += [ - nn.Conv2d( - 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model2 += [ - nn.ReLU(True), - ] - model2 += [ - norm_layer(128), - ] - # add a subsampling layer operation - - # Conv3 - # model3=[nn.ReflectionPad2d(1),] - model3 = [ - nn.Conv2d( - 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model3+=[norm_layer(256),] - model3 += [ - nn.ReLU(True), - ] - # model3+=[nn.ReflectionPad2d(1),] - model3 += [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model3+=[norm_layer(256),] - model3 += [ - nn.ReLU(True), - ] - # model3+=[nn.ReflectionPad2d(1),] - model3 += [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model3 += [ - nn.ReLU(True), - ] - model3 += [ - norm_layer(256), - ] - # add a subsampling layer operation - - # Conv4 - # model47=[nn.ReflectionPad2d(1),] - model4 = [ - nn.Conv2d( - 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model4+=[norm_layer(512),] - model4 += [ - nn.ReLU(True), - ] - # model4+=[nn.ReflectionPad2d(1),] - model4 += [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model4+=[norm_layer(512),] - model4 += [ - nn.ReLU(True), - ] - # model4+=[nn.ReflectionPad2d(1),] - model4 += [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model4 += [ - nn.ReLU(True), - ] - model4 += [ - norm_layer(512), - ] - - # Conv5 - # model47+=[nn.ReflectionPad2d(2),] - model5 = [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - # model5+=[norm_layer(512),] - model5 += [ - nn.ReLU(True), - ] - # model5+=[nn.ReflectionPad2d(2),] - model5 += [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - # model5+=[norm_layer(512),] - model5 += [ - nn.ReLU(True), - ] - # model5+=[nn.ReflectionPad2d(2),] - model5 += [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - model5 += [ - nn.ReLU(True), - ] - model5 += [ - norm_layer(512), - ] - - # Conv6 - # model6+=[nn.ReflectionPad2d(2),] - model6 = [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - # model6+=[norm_layer(512),] - model6 += [ - nn.ReLU(True), - ] - # model6+=[nn.ReflectionPad2d(2),] - model6 += [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - # model6+=[norm_layer(512),] - model6 += [ - nn.ReLU(True), - ] - # model6+=[nn.ReflectionPad2d(2),] - model6 += [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - model6 += [ - nn.ReLU(True), - ] - model6 += [ - norm_layer(512), - ] - - # Conv7 - # model47+=[nn.ReflectionPad2d(1),] - model7 = [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model7+=[norm_layer(512),] - model7 += [ - nn.ReLU(True), - ] - # model7+=[nn.ReflectionPad2d(1),] - model7 += [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model7+=[norm_layer(512),] - model7 += [ - nn.ReLU(True), - ] - # model7+=[nn.ReflectionPad2d(1),] - model7 += [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model7 += [ - nn.ReLU(True), - ] - model7 += [ - norm_layer(512), - ] - - # Conv7 - model8up = [ - nn.ConvTranspose2d( - 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) - ] - - # model3short8=[nn.ReflectionPad2d(1),] - model3short8 = [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - - # model47+=[norm_layer(256),] - model8 = [ - nn.ReLU(True), - ] - # model8+=[nn.ReflectionPad2d(1),] - model8 += [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model8+=[norm_layer(256),] - model8 += [ - nn.ReLU(True), - ] - # model8+=[nn.ReflectionPad2d(1),] - model8 += [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model8 += [ - nn.ReLU(True), - ] - model8 += [ - norm_layer(256), - ] - - # Conv9 - model9up = [ - nn.ConvTranspose2d( - 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), - ] - - # model2short9=[nn.ReflectionPad2d(1),] - model2short9 = [ - nn.Conv2d( - 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # add the two feature maps above - - # model9=[norm_layer(128),] - model9 = [ - nn.ReLU(True), - ] - # model9+=[nn.ReflectionPad2d(1),] - model9 += [ - nn.Conv2d( - 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model9 += [ - nn.ReLU(True), - ] - model9 += [ - norm_layer(128), - ] - - # Conv10 - model10up = [ - nn.ConvTranspose2d( - 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), - ] - - # model1short10=[nn.ReflectionPad2d(1),] - model1short10 = [ - nn.Conv2d( - 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # add the two feature maps above - - # model10=[norm_layer(128),] - model10 = [ - nn.ReLU(True), - ] - # model10+=[nn.ReflectionPad2d(1),] - model10 += [ - nn.Conv2d( - 128, - 128, - kernel_size=3, - dilation=1, - stride=1, - padding=1, - bias=use_bias), - ] - model10 += [ - nn.LeakyReLU(negative_slope=.2), - ] - - # classification output - model_class = [ - nn.Conv2d( - 256, - 529, - kernel_size=1, - padding=0, - dilation=1, - stride=1, - bias=use_bias), - ] - - # regression output - model_out = [ - nn.Conv2d( - 128, - 2, - kernel_size=1, - padding=0, - dilation=1, - stride=1, - bias=use_bias), - ] - if (use_tanh): - model_out += [nn.Tanh()] - - self.model1 = nn.Sequential(*model1) - self.model2 = nn.Sequential(*model2) - self.model3 = nn.Sequential(*model3) - self.model4 = nn.Sequential(*model4) - self.model5 = nn.Sequential(*model5) - self.model6 = nn.Sequential(*model6) - self.model7 = nn.Sequential(*model7) - self.model8up = nn.Sequential(*model8up) - self.model8 = nn.Sequential(*model8) - self.model9up = nn.Sequential(*model9up) - self.model9 = nn.Sequential(*model9) - self.model10up = nn.Sequential(*model10up) - self.model10 = nn.Sequential(*model10) - self.model3short8 = nn.Sequential(*model3short8) - self.model2short9 = nn.Sequential(*model2short9) - self.model1short10 = nn.Sequential(*model1short10) - - self.model_class = nn.Sequential(*model_class) - self.model_out = nn.Sequential(*model_out) - - self.upsample4 = nn.Sequential(*[ - nn.Upsample(scale_factor=4, mode='nearest'), - ]) - self.softmax = nn.Sequential(*[ - nn.Softmax(dim=1), - ]) - - def forward(self, input_A, input_B, mask_B): - conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1)) - conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) - conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) - conv4_3 = self.model4(conv3_3[:, :, ::2, ::2]) - conv5_3 = self.model5(conv4_3) - conv6_3 = self.model6(conv5_3) - conv7_3 = self.model7(conv6_3) - conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3) - conv8_3 = self.model8(conv8_up) - - if (self.classification): - out_class = self.model_class(conv8_3) - conv9_up = self.model9up(conv8_3.detach()) + self.model2short9( - conv2_2.detach()) - conv9_3 = self.model9(conv9_up) - conv10_up = self.model10up(conv9_3) + self.model1short10( - conv1_2.detach()) - conv10_2 = self.model10(conv10_up) - out_reg = self.model_out(conv10_2) - else: - out_class = self.model_class(conv8_3.detach()) - - conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) - conv9_3 = self.model9(conv9_up) - conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) - conv10_2 = self.model10(conv10_up) - out_reg = self.model_out(conv10_2) - - return (out_class, out_reg) - - -@BACKBONES.register_module() -class FusionGenerator(nn.Module): - - def __init__(self, - input_nc, - output_nc, - norm_type, - use_tanh=True, - classification=True): - super(FusionGenerator, self).__init__() - self.input_nc = input_nc - self.output_nc = output_nc - self.classification = classification - use_bias = True - - norm_layer = get_norm_layer(norm_type) - - # Conv1 - # model1=[nn.ReflectionPad2d(1),] - model1 = [ - nn.Conv2d( - input_nc, - 64, - kernel_size=3, - stride=1, - padding=1, - bias=use_bias), - ] - # model1+=[norm_layer(64),] - model1 += [ - nn.ReLU(True), - ] - # model1+=[nn.ReflectionPad2d(1),] - model1 += [ - nn.Conv2d( - 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model1 += [ - nn.ReLU(True), - ] - model1 += [ - norm_layer(64), - ] - # add a subsampling operation - - self.weight_layer = WeightGenerator(64) - - # Conv2 - # model2=[nn.ReflectionPad2d(1),] - model2 = [ - nn.Conv2d( - 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model2+=[norm_layer(128),] - model2 += [ - nn.ReLU(True), - ] - # model2+=[nn.ReflectionPad2d(1),] - model2 += [ - nn.Conv2d( - 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model2 += [ - nn.ReLU(True), - ] - model2 += [ - norm_layer(128), - ] - # add a subsampling layer operation - - self.weight_layer2 = WeightGenerator(128) - - # Conv3 - # model3=[nn.ReflectionPad2d(1),] - model3 = [ - nn.Conv2d( - 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model3+=[norm_layer(256),] - model3 += [ - nn.ReLU(True), - ] - # model3+=[nn.ReflectionPad2d(1),] - model3 += [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model3+=[norm_layer(256),] - model3 += [ - nn.ReLU(True), - ] - # model3+=[nn.ReflectionPad2d(1),] - model3 += [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model3 += [ - nn.ReLU(True), - ] - model3 += [ - norm_layer(256), - ] - # add a subsampling layer operation - - self.weight_layer3 = WeightGenerator(256) - - # Conv4 - # model47=[nn.ReflectionPad2d(1),] - model4 = [ - nn.Conv2d( - 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model4+=[norm_layer(512),] - model4 += [ - nn.ReLU(True), - ] - # model4+=[nn.ReflectionPad2d(1),] - model4 += [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model4+=[norm_layer(512),] - model4 += [ - nn.ReLU(True), - ] - # model4+=[nn.ReflectionPad2d(1),] - model4 += [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model4 += [ - nn.ReLU(True), - ] - model4 += [ - norm_layer(512), - ] - - self.weight_layer4 = WeightGenerator(512) - - # Conv5 - # model47+=[nn.ReflectionPad2d(2),] - model5 = [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - # model5+=[norm_layer(512),] - model5 += [ - nn.ReLU(True), - ] - # model5+=[nn.ReflectionPad2d(2),] - model5 += [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - # model5+=[norm_layer(512),] - model5 += [ - nn.ReLU(True), - ] - # model5+=[nn.ReflectionPad2d(2),] - model5 += [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - model5 += [ - nn.ReLU(True), - ] - model5 += [ - norm_layer(512), - ] - - self.weight_layer5 = WeightGenerator(512) - - # Conv6 - # model6+=[nn.ReflectionPad2d(2),] - model6 = [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - # model6+=[norm_layer(512),] - model6 += [ - nn.ReLU(True), - ] - # model6+=[nn.ReflectionPad2d(2),] - model6 += [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - # model6+=[norm_layer(512),] - model6 += [ - nn.ReLU(True), - ] - # model6+=[nn.ReflectionPad2d(2),] - model6 += [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - model6 += [ - nn.ReLU(True), - ] - model6 += [ - norm_layer(512), - ] - - self.weight_layer6 = WeightGenerator(512) - - # Conv7 - # model47+=[nn.ReflectionPad2d(1),] - model7 = [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model7+=[norm_layer(512),] - model7 += [ - nn.ReLU(True), - ] - # model7+=[nn.ReflectionPad2d(1),] - model7 += [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model7+=[norm_layer(512),] - model7 += [ - nn.ReLU(True), - ] - # model7+=[nn.ReflectionPad2d(1),] - model7 += [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model7 += [ - nn.ReLU(True), - ] - model7 += [ - norm_layer(512), - ] - - self.weight_layer7 = WeightGenerator(512) - - # Conv7 - model8up = [ - nn.ConvTranspose2d( - 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) - ] - - # model3short8=[nn.ReflectionPad2d(1),] - model3short8 = [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - - self.weight_layer8_1 = WeightGenerator(256) - - # model47+=[norm_layer(256),] - model8 = [ - nn.ReLU(True), - ] - # model8+=[nn.ReflectionPad2d(1),] - model8 += [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model8+=[norm_layer(256),] - model8 += [ - nn.ReLU(True), - ] - # model8+=[nn.ReflectionPad2d(1),] - model8 += [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model8 += [ - nn.ReLU(True), - ] - model8 += [ - norm_layer(256), - ] - - self.weight_layer8_2 = WeightGenerator(256) - - # Conv9 - model9up = [ - nn.ConvTranspose2d( - 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), - ] - - # model2short9=[nn.ReflectionPad2d(1),] - model2short9 = [ - nn.Conv2d( - 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # add the two feature maps above - - self.weight_layer9_1 = WeightGenerator(128) - - # model9=[norm_layer(128),] - model9 = [ - nn.ReLU(True), - ] - # model9+=[nn.ReflectionPad2d(1),] - model9 += [ - nn.Conv2d( - 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model9 += [ - nn.ReLU(True), - ] - model9 += [ - norm_layer(128), - ] - - self.weight_layer9_2 = WeightGenerator(128) - - # Conv10 - model10up = [ - nn.ConvTranspose2d( - 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), - ] - - # model1short10=[nn.ReflectionPad2d(1),] - model1short10 = [ - nn.Conv2d( - 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # add the two feature maps above - - self.weight_layer10_1 = WeightGenerator(128) - - # model10=[norm_layer(128),] - model10 = [ - nn.ReLU(True), - ] - # model10+=[nn.ReflectionPad2d(1),] - model10 += [ - nn.Conv2d( - 128, - 128, - kernel_size=3, - dilation=1, - stride=1, - padding=1, - bias=use_bias), - ] - model10 += [ - nn.LeakyReLU(negative_slope=.2), - ] - - self.weight_layer10_2 = WeightGenerator(128) - - # classification output - model_class = [ - nn.Conv2d( - 256, - 529, - kernel_size=1, - padding=0, - dilation=1, - stride=1, - bias=use_bias), - ] - - # regression output - model_out = [ - nn.Conv2d( - 128, - 2, - kernel_size=1, - padding=0, - dilation=1, - stride=1, - bias=use_bias), - ] - if (use_tanh): - model_out += [nn.Tanh()] - - self.weight_layerout = WeightGenerator(2) - - self.model1 = nn.Sequential(*model1) - self.model2 = nn.Sequential(*model2) - self.model3 = nn.Sequential(*model3) - self.model4 = nn.Sequential(*model4) - self.model5 = nn.Sequential(*model5) - self.model6 = nn.Sequential(*model6) - self.model7 = nn.Sequential(*model7) - self.model8up = nn.Sequential(*model8up) - self.model8 = nn.Sequential(*model8) - self.model9up = nn.Sequential(*model9up) - self.model9 = nn.Sequential(*model9) - self.model10up = nn.Sequential(*model10up) - self.model10 = nn.Sequential(*model10) - self.model3short8 = nn.Sequential(*model3short8) - self.model2short9 = nn.Sequential(*model2short9) - self.model1short10 = nn.Sequential(*model1short10) - - self.model_class = nn.Sequential(*model_class) - self.model_out = nn.Sequential(*model_out) - - self.upsample4 = nn.Sequential(*[ - nn.Upsample(scale_factor=4, mode='nearest'), - ]) - self.softmax = nn.Sequential(*[ - nn.Softmax(dim=1), - ]) - - def forward(self, input_A, input_B, mask_B, instance_feature, - box_info_list): - conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1)) - conv1_2 = self.weight_layer(instance_feature['conv1_2'], conv1_2, - box_info_list[0]) - - conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) - conv2_2 = self.weight_layer2(instance_feature['conv2_2'], conv2_2, - box_info_list[1]) - - conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) - conv3_3 = self.weight_layer3(instance_feature['conv3_3'], conv3_3, - box_info_list[2]) - - conv4_3 = self.model4(conv3_3[:, :, ::2, ::2]) - conv4_3 = self.weight_layer4(instance_feature['conv4_3'], conv4_3, - box_info_list[3]) - - conv5_3 = self.model5(conv4_3) - conv5_3 = self.weight_layer5(instance_feature['conv5_3'], conv5_3, - box_info_list[3]) - - conv6_3 = self.model6(conv5_3) - conv6_3 = self.weight_layer6(instance_feature['conv6_3'], conv6_3, - box_info_list[3]) - - conv7_3 = self.model7(conv6_3) - conv7_3 = self.weight_layer7(instance_feature['conv7_3'], conv7_3, - box_info_list[3]) - - conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3) - conv8_up = self.weight_layer8_1(instance_feature['conv8_up'], conv8_up, - box_info_list[2]) - - conv8_3 = self.model8(conv8_up) - conv8_3 = self.weight_layer8_2(instance_feature['conv8_3'], conv8_3, - box_info_list[2]) - - conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) - conv9_up = self.weight_layer9_1(instance_feature['conv9_up'], conv9_up, - box_info_list[1]) - - conv9_3 = self.model9(conv9_up) - conv9_3 = self.weight_layer9_2(instance_feature['conv9_3'], conv9_3, - box_info_list[1]) - - conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) - conv10_up = self.weight_layer10_1(instance_feature['conv10_up'], - conv10_up, box_info_list[0]) - - conv10_2 = self.model10(conv10_up) - conv10_2 = self.weight_layer10_2(instance_feature['conv10_2'], - conv10_2, box_info_list[0]) - - out_reg = self.model_out(conv10_2) - return out_reg - - -class WeightGenerator(nn.Module): - - def __init__(self, input_ch, inner_ch=16): - super(WeightGenerator, self).__init__() - self.simple_instance_conv = nn.Sequential( - nn.Conv2d(input_ch, inner_ch, kernel_size=3, stride=1, padding=1), - nn.ReLU(True), - nn.Conv2d(inner_ch, inner_ch, kernel_size=3, stride=1, padding=1), - nn.ReLU(True), - nn.Conv2d(inner_ch, 1, kernel_size=3, stride=1, padding=1), - nn.ReLU(True), - ) - - self.simple_bg_conv = nn.Sequential( - nn.Conv2d(input_ch, inner_ch, kernel_size=3, stride=1, padding=1), - nn.ReLU(True), - nn.Conv2d(inner_ch, inner_ch, kernel_size=3, stride=1, padding=1), - nn.ReLU(True), - nn.Conv2d(inner_ch, 1, kernel_size=3, stride=1, padding=1), - nn.ReLU(True), - ) - - self.normalize = nn.Softmax(1) - - def resize_and_pad(self, feauture_maps, info_array): - feauture_maps = torch.nn.functional.interpolate( - feauture_maps, - size=(info_array[5], info_array[4]), - mode='bilinear') - feauture_maps = torch.nn.functional.pad(feauture_maps, - (info_array[0], info_array[1], - info_array[2], info_array[3]), - 'constant', 0) - return feauture_maps - - def forward(self, instance_feature, bg_feature, box_info): - mask_list = [] - featur_map_list = [] - mask_sum_for_pred = torch.zeros_like(bg_feature)[:1, :1] - for i in range(instance_feature.shape[0]): - tmp_crop = torch.unsqueeze(instance_feature[i], 0) - conv_tmp_crop = self.simple_instance_conv(tmp_crop) - pred_mask = self.resize_and_pad(conv_tmp_crop, box_info[i]) - - tmp_crop = self.resize_and_pad(tmp_crop, box_info[i]) - - mask = torch.zeros_like(bg_feature)[:1, :1] - mask[0, 0, box_info[i][2]:box_info[i][2] + box_info[i][5], - box_info[i][0]:box_info[i][0] + box_info[i][4]] = 1.0 - device = mask.device - mask = mask.type(torch.FloatTensor).to(device) - - mask_sum_for_pred = torch.clamp(mask_sum_for_pred + mask, 0.0, 1.0) - - mask_list.append(pred_mask) - featur_map_list.append(tmp_crop) - - pred_bg_mask = self.simple_bg_conv(bg_feature) - mask_list.append(pred_bg_mask + (1 - mask_sum_for_pred) * 100000.0) - mask_list = self.normalize(torch.cat(mask_list, 1)) - - mask_list_maskout = mask_list.clone() - - # instance_mask = torch.clamp( - # torch.sum( - # mask_list_maskout[:, :instance_feature.shape[0]], - # 1, - # keepdim=True), 0.0, 1.0) - - featur_map_list.append(bg_feature) - featur_map_list = torch.cat(featur_map_list, 0) - mask_list_maskout = mask_list_maskout.permute(1, 0, 2, 3).contiguous() - out = featur_map_list * mask_list_maskout - out = torch.sum(out, 0, keepdim=True) - return out # , instance_mask, torch.clamp(mask_list, 0.0, 1.0) - - -@BACKBONES.register_module() -class InstanceGenerator(nn.Module): - - def __init__(self, - input_nc, - output_nc, - norm_type, - use_tanh=True, - classification=True): - super(InstanceGenerator, self).__init__() - self.input_nc = input_nc - self.output_nc = output_nc - self.classification = classification - use_bias = True - - norm_layer = get_norm_layer(norm_type) - - # Conv1 - # model1=[nn.ReflectionPad2d(1),] - model1 = [ - nn.Conv2d( - input_nc, - 64, - kernel_size=3, - stride=1, - padding=1, - bias=use_bias), - ] - # model1+=[norm_layer(64),] - model1 += [ - nn.ReLU(True), - ] - # model1+=[nn.ReflectionPad2d(1),] - model1 += [ - nn.Conv2d( - 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model1 += [ - nn.ReLU(True), - ] - model1 += [ - norm_layer(64), - ] - # add a subsampling operation - - # Conv2 - # model2=[nn.ReflectionPad2d(1),] - model2 = [ - nn.Conv2d( - 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model2+=[norm_layer(128),] - model2 += [ - nn.ReLU(True), - ] - # model2+=[nn.ReflectionPad2d(1),] - model2 += [ - nn.Conv2d( - 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model2 += [ - nn.ReLU(True), - ] - model2 += [ - norm_layer(128), - ] - # add a subsampling layer operation - - # Conv3 - # model3=[nn.ReflectionPad2d(1),] - model3 = [ - nn.Conv2d( - 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model3+=[norm_layer(256),] - model3 += [ - nn.ReLU(True), - ] - # model3+=[nn.ReflectionPad2d(1),] - model3 += [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model3+=[norm_layer(256),] - model3 += [ - nn.ReLU(True), - ] - # model3+=[nn.ReflectionPad2d(1),] - model3 += [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model3 += [ - nn.ReLU(True), - ] - model3 += [ - norm_layer(256), - ] - # add a subsampling layer operation - - # Conv4 - # model47=[nn.ReflectionPad2d(1),] - model4 = [ - nn.Conv2d( - 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model4+=[norm_layer(512),] - model4 += [ - nn.ReLU(True), - ] - # model4+=[nn.ReflectionPad2d(1),] - model4 += [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model4+=[norm_layer(512),] - model4 += [ - nn.ReLU(True), - ] - # model4+=[nn.ReflectionPad2d(1),] - model4 += [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model4 += [ - nn.ReLU(True), - ] - model4 += [ - norm_layer(512), - ] - - # Conv5 - # model47+=[nn.ReflectionPad2d(2),] - model5 = [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - # model5+=[norm_layer(512),] - model5 += [ - nn.ReLU(True), - ] - # model5+=[nn.ReflectionPad2d(2),] - model5 += [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - # model5+=[norm_layer(512),] - model5 += [ - nn.ReLU(True), - ] - # model5+=[nn.ReflectionPad2d(2),] - model5 += [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - model5 += [ - nn.ReLU(True), - ] - model5 += [ - norm_layer(512), - ] - - # Conv6 - # model6+=[nn.ReflectionPad2d(2),] - model6 = [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - # model6+=[norm_layer(512),] - model6 += [ - nn.ReLU(True), - ] - # model6+=[nn.ReflectionPad2d(2),] - model6 += [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - # model6+=[norm_layer(512),] - model6 += [ - nn.ReLU(True), - ] - # model6+=[nn.ReflectionPad2d(2),] - model6 += [ - nn.Conv2d( - 512, - 512, - kernel_size=3, - dilation=2, - stride=1, - padding=2, - bias=use_bias), - ] - model6 += [ - nn.ReLU(True), - ] - model6 += [ - norm_layer(512), - ] - - # Conv7 - # model47+=[nn.ReflectionPad2d(1),] - model7 = [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model7+=[norm_layer(512),] - model7 += [ - nn.ReLU(True), - ] - # model7+=[nn.ReflectionPad2d(1),] - model7 += [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model7+=[norm_layer(512),] - model7 += [ - nn.ReLU(True), - ] - # model7+=[nn.ReflectionPad2d(1),] - model7 += [ - nn.Conv2d( - 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model7 += [ - nn.ReLU(True), - ] - model7 += [ - norm_layer(512), - ] - - # Conv7 - model8up = [ - nn.ConvTranspose2d( - 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) - ] - - # model3short8=[nn.ReflectionPad2d(1),] - model3short8 = [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - - # model47+=[norm_layer(256),] - model8 = [ - nn.ReLU(True), - ] - # model8+=[nn.ReflectionPad2d(1),] - model8 += [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model8+=[norm_layer(256),] - model8 += [ - nn.ReLU(True), - ] - # model8+=[nn.ReflectionPad2d(1),] - model8 += [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model8 += [ - nn.ReLU(True), - ] - model8 += [ - norm_layer(256), - ] - - # Conv9 - model9up = [ - nn.ConvTranspose2d( - 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), - ] - - # model2short9=[nn.ReflectionPad2d(1),] - model2short9 = [ - nn.Conv2d( - 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # add the two feature maps above - - # model9=[norm_layer(128),] - model9 = [ - nn.ReLU(True), - ] - # model9+=[nn.ReflectionPad2d(1),] - model9 += [ - nn.Conv2d( - 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model9 += [ - nn.ReLU(True), - ] - model9 += [ - norm_layer(128), - ] - - # Conv10 - model10up = [ - nn.ConvTranspose2d( - 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), - ] - - # model1short10=[nn.ReflectionPad2d(1),] - model1short10 = [ - nn.Conv2d( - 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # add the two feature maps above - - # model10=[norm_layer(128),] - model10 = [ - nn.ReLU(True), - ] - # model10+=[nn.ReflectionPad2d(1),] - model10 += [ - nn.Conv2d( - 128, - 128, - kernel_size=3, - dilation=1, - stride=1, - padding=1, - bias=use_bias), - ] - model10 += [ - nn.LeakyReLU(negative_slope=.2), - ] - - # classification output - model_class = [ - nn.Conv2d( - 256, - 529, - kernel_size=1, - padding=0, - dilation=1, - stride=1, - bias=use_bias), - ] - - # regression output - model_out = [ - nn.Conv2d( - 128, - 2, - kernel_size=1, - padding=0, - dilation=1, - stride=1, - bias=use_bias), - ] - if (use_tanh): - model_out += [nn.Tanh()] - - self.model1 = nn.Sequential(*model1) - self.model2 = nn.Sequential(*model2) - self.model3 = nn.Sequential(*model3) - self.model4 = nn.Sequential(*model4) - self.model5 = nn.Sequential(*model5) - self.model6 = nn.Sequential(*model6) - self.model7 = nn.Sequential(*model7) - self.model8up = nn.Sequential(*model8up) - self.model8 = nn.Sequential(*model8) - self.model9up = nn.Sequential(*model9up) - self.model9 = nn.Sequential(*model9) - self.model10up = nn.Sequential(*model10up) - self.model10 = nn.Sequential(*model10) - self.model3short8 = nn.Sequential(*model3short8) - self.model2short9 = nn.Sequential(*model2short9) - self.model1short10 = nn.Sequential(*model1short10) - - self.model_class = nn.Sequential(*model_class) - self.model_out = nn.Sequential(*model_out) - - self.upsample4 = nn.Sequential(*[ - nn.Upsample(scale_factor=4, mode='nearest'), - ]) - self.softmax = nn.Sequential(*[ - nn.Softmax(dim=1), - ]) - - def forward(self, input_A, input_B, mask_B): - conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1)) - conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) - conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) - conv4_3 = self.model4(conv3_3[:, :, ::2, ::2]) - conv5_3 = self.model5(conv4_3) - conv6_3 = self.model6(conv5_3) - conv7_3 = self.model7(conv6_3) - conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3) - conv8_3 = self.model8(conv8_up) - - if (self.classification): - # out_class = self.model_class(conv8_3) - conv9_up = self.model9up(conv8_3.detach()) + self.model2short9( - conv2_2.detach()) - conv9_3 = self.model9(conv9_up) - conv10_up = self.model10up(conv9_3) + self.model1short10( - conv1_2.detach()) - conv10_2 = self.model10(conv10_up) - out_reg = self.model_out(conv10_2) - else: - # out_class = self.model_class(conv8_3.detach()) - - conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) - conv9_3 = self.model9(conv9_up) - conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) - conv10_2 = self.model10(conv10_up) - out_reg = self.model_out(conv10_2) - - feature_map = {} - feature_map['conv1_2'] = conv1_2 - feature_map['conv2_2'] = conv2_2 - feature_map['conv3_3'] = conv3_3 - feature_map['conv4_3'] = conv4_3 - feature_map['conv5_3'] = conv5_3 - feature_map['conv6_3'] = conv6_3 - feature_map['conv7_3'] = conv7_3 - feature_map['conv8_up'] = conv8_up - feature_map['conv8_3'] = conv8_3 - feature_map['conv9_up'] = conv9_up - feature_map['conv9_3'] = conv9_3 - feature_map['conv10_up'] = conv10_up - feature_map['conv10_2'] = conv10_2 - feature_map['out_reg'] = out_reg - - return (out_reg, feature_map) diff --git a/mmedit/models/editors/inst_colorization/weight_block.py b/mmedit/models/editors/inst_colorization/weight_block.py new file mode 100644 index 0000000000..8e9d56eed4 --- /dev/null +++ b/mmedit/models/editors/inst_colorization/weight_block.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmedit.registry import MODULES + + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + elif norm_type == 'none': + norm_layer = None + else: + raise NotImplementedError('normalization layer [%s] is not found' % + norm_type) + return norm_layer + + +@MODULES.register_module() +class WeightBlock(BaseModule): + + def __init__(self, input_ch, inner_ch=16): + super(WeightBlock, self).__init__() + self.simple_instance_conv = nn.Sequential( + nn.Conv2d(input_ch, inner_ch, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(inner_ch, inner_ch, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(inner_ch, 1, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + ) + + self.simple_bg_conv = nn.Sequential( + nn.Conv2d(input_ch, inner_ch, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(inner_ch, inner_ch, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(inner_ch, 1, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + ) + + self.normalize = nn.Softmax(1) + + def resize_and_pad(self, feauture_maps, info_array): + feauture_maps = torch.nn.functional.interpolate( + feauture_maps, + size=(info_array[5], info_array[4]), + mode='bilinear') + feauture_maps = torch.nn.functional.pad(feauture_maps, + (info_array[0], info_array[1], + info_array[2], info_array[3]), + 'constant', 0) + return feauture_maps + + def forward(self, instance_feature, bg_feature, box_info): + mask_list = [] + featur_map_list = [] + mask_sum_for_pred = torch.zeros_like(bg_feature)[:1, :1] + for i in range(instance_feature.shape[0]): + tmp_crop = torch.unsqueeze(instance_feature[i], 0) + conv_tmp_crop = self.simple_instance_conv(tmp_crop) + pred_mask = self.resize_and_pad(conv_tmp_crop, box_info[i]) + + tmp_crop = self.resize_and_pad(tmp_crop, box_info[i]) + + mask = torch.zeros_like(bg_feature)[:1, :1] + mask[0, 0, box_info[i][2]:box_info[i][2] + box_info[i][5], + box_info[i][0]:box_info[i][0] + box_info[i][4]] = 1.0 + device = mask.device + mask = mask.type(torch.FloatTensor).to(device) + + mask_sum_for_pred = torch.clamp(mask_sum_for_pred + mask, 0.0, 1.0) + + mask_list.append(pred_mask) + featur_map_list.append(tmp_crop) + + pred_bg_mask = self.simple_bg_conv(bg_feature) + mask_list.append(pred_bg_mask + (1 - mask_sum_for_pred) * 100000.0) + mask_list = self.normalize(torch.cat(mask_list, 1)) + + mask_list_maskout = mask_list.clone() + + # instance_mask = torch.clamp( + # torch.sum( + # mask_list_maskout[:, :instance_feature.shape[0]], + # 1, + # keepdim=True), 0.0, 1.0) + + featur_map_list.append(bg_feature) + featur_map_list = torch.cat(featur_map_list, 0) + mask_list_maskout = mask_list_maskout.permute(1, 0, 2, 3).contiguous() + out = featur_map_list * mask_list_maskout + out = torch.sum(out, 0, keepdim=True) + return out # , instance_mask, torch.clamp(mask_list, 0.0, 1.0) From 60f19d63bcbc3b601f5921301a69c5447018529e Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Thu, 20 Oct 2022 16:36:43 +0800 Subject: [PATCH 11/32] fix siggraphgenerator, i.e., colorization_net --- .../inst-colorizatioon_cocostuff_256x256.py | 23 +- .../editors/inst_colorization/__init__.py | 10 +- .../inst_colorization/colorization_net.py | 274 ++++-------------- .../editors/inst_colorization/fusion_net.py | 31 +- .../inst_colorization/inst_colorization.py | 270 ++++++++++++----- .../{weight_block.py => weight_layer.py} | 4 +- 6 files changed, 290 insertions(+), 322 deletions(-) rename mmedit/models/editors/inst_colorization/{weight_block.py => weight_layer.py} (98%) diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py index 1b69951625..55b4700f09 100644 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py @@ -6,7 +6,7 @@ save_dir = './' work_dir = '..' -stage = 'test' +stage = 'full' model = dict( type='InstColorization', @@ -15,19 +15,14 @@ mean=[127.5], std=[127.5], ), - generator=dict( - type='InstColorization', - stage=stage, - detector=PlaceHolder, - image_model=dict( - type='ColorizationNet', input_nc=4, output_nc=2, - norm_type='batch'), - instance_model=dict( - type='ColorizationNet', input_nc=4, output_nc=2, - norm_type='batch'), - fusion_model=dict( - type='FusionNet', input_nc=4, output_nc=2, norm_type='batch')), - insta_stage=stage, + detector_cfg='COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml', + image_model=dict( + type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'), + instance_model=dict( + type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'), + fusion_model=dict( + type='FusionNet', input_nc=4, output_nc=2, norm_type='batch'), + stage=stage, ngf=64, output_nc=2, avg_loss_alpha=.986, diff --git a/mmedit/models/editors/inst_colorization/__init__.py b/mmedit/models/editors/inst_colorization/__init__.py index 3a3f5a43d5..434ebe14d0 100644 --- a/mmedit/models/editors/inst_colorization/__init__.py +++ b/mmedit/models/editors/inst_colorization/__init__.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .colorization_net import (FusionGenerator, InstanceGenerator, - SIGGRAPHGenerator) +from .colorization_net import ColorizationNet +from .fusion_net import FusionNet from .inst_colorization import InstColorization -from .inst_colorization_net import InstColorizationGenerator __all__ = [ - 'InstColorization', 'SIGGRAPHGenerator', 'InstanceGenerator', - 'FusionGenerator', 'InstColorizationGenerator' + 'InstColorization', + 'ColorizationNet', + 'FusionNet', ] diff --git a/mmedit/models/editors/inst_colorization/colorization_net.py b/mmedit/models/editors/inst_colorization/colorization_net.py index 9d6164903e..f0c6635edb 100644 --- a/mmedit/models/editors/inst_colorization/colorization_net.py +++ b/mmedit/models/editors/inst_colorization/colorization_net.py @@ -2,15 +2,28 @@ import torch import torch.nn as nn -from .weight_block import get_norm_layer - from mmengine.model import BaseModule -from mmedit.registry import MODULES +from mmedit.registry import MODULES +from .weight_layer import get_norm_layer @MODULES.register_module() class ColorizationNet(BaseModule): + """Real-Time User-Guided Image Colorization with Learned Deep Priors. + + https://arxiv.org/abs/1705.02999 + + Codes adapted from 'https://github.com/ericsujw/InstColorization.git' + 'InstColorization/blob/master/models/networks.py#L108' + + Args: + input_nc: + output_nc: + norm_type: + use_tanh: + classification: + """ def __init__(self, input_nc, @@ -28,8 +41,7 @@ def __init__(self, use_bias = True # Conv1 - # model1=[nn.ReflectionPad2d(1),] - model1 = [ + self.model1 = nn.Sequential([ nn.Conv2d( input_nc, 64, @@ -37,113 +49,54 @@ def __init__(self, stride=1, padding=1, bias=use_bias), - ] - # model1+=[norm_layer(64),] - model1 += [ nn.ReLU(True), - ] - # model1+=[nn.ReflectionPad2d(1),] - model1 += [ nn.Conv2d( 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model1 += [ nn.ReLU(True), - ] - model1 += [ norm_layer(64), - ] - # add a subsampling operation + ]) # Conv2 - # model2=[nn.ReflectionPad2d(1),] - model2 = [ + self.model2 = nn.Sequential([ nn.Conv2d( 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model2+=[norm_layer(128),] - model2 += [ nn.ReLU(True), - ] - # model2+=[nn.ReflectionPad2d(1),] - model2 += [ nn.Conv2d( 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model2 += [ nn.ReLU(True), - ] - model2 += [ norm_layer(128), - ] - # add a subsampling layer operation + ]) # Conv3 - # model3=[nn.ReflectionPad2d(1),] - model3 = [ + self.model3 = nn.Sequential([ nn.Conv2d( 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model3+=[norm_layer(256),] - model3 += [ nn.ReLU(True), - ] - # model3+=[nn.ReflectionPad2d(1),] - model3 += [ nn.Conv2d( 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model3+=[norm_layer(256),] - model3 += [ nn.ReLU(True), - ] - # model3+=[nn.ReflectionPad2d(1),] - model3 += [ nn.Conv2d( 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model3 += [ nn.ReLU(True), - ] - model3 += [ norm_layer(256), - ] - # add a subsampling layer operation + ]) # Conv4 - # model47=[nn.ReflectionPad2d(1),] - model4 = [ + self.model4 = nn.Sequential([ nn.Conv2d( 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model4+=[norm_layer(512),] - model4 += [ nn.ReLU(True), - ] - # model4+=[nn.ReflectionPad2d(1),] - model4 += [ nn.Conv2d( 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model4+=[norm_layer(512),] - model4 += [ nn.ReLU(True), - ] - # model4+=[nn.ReflectionPad2d(1),] - model4 += [ nn.Conv2d( 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model4 += [ nn.ReLU(True), - ] - model4 += [ norm_layer(512), - ] + ]) # Conv5 - # model47+=[nn.ReflectionPad2d(2),] - model5 = [ + self.model5 = nn.Sequential([ nn.Conv2d( 512, 512, @@ -152,13 +105,7 @@ def __init__(self, stride=1, padding=2, bias=use_bias), - ] - # model5+=[norm_layer(512),] - model5 += [ nn.ReLU(True), - ] - # model5+=[nn.ReflectionPad2d(2),] - model5 += [ nn.Conv2d( 512, 512, @@ -167,13 +114,7 @@ def __init__(self, stride=1, padding=2, bias=use_bias), - ] - # model5+=[norm_layer(512),] - model5 += [ nn.ReLU(True), - ] - # model5+=[nn.ReflectionPad2d(2),] - model5 += [ nn.Conv2d( 512, 512, @@ -182,17 +123,12 @@ def __init__(self, stride=1, padding=2, bias=use_bias), - ] - model5 += [ nn.ReLU(True), - ] - model5 += [ norm_layer(512), - ] + ]) # Conv6 - # model6+=[nn.ReflectionPad2d(2),] - model6 = [ + self.model6 = nn.Sequential([ nn.Conv2d( 512, 512, @@ -201,13 +137,7 @@ def __init__(self, stride=1, padding=2, bias=use_bias), - ] - # model6+=[norm_layer(512),] - model6 += [ nn.ReLU(True), - ] - # model6+=[nn.ReflectionPad2d(2),] - model6 += [ nn.Conv2d( 512, 512, @@ -216,13 +146,7 @@ def __init__(self, stride=1, padding=2, bias=use_bias), - ] - # model6+=[norm_layer(512),] - model6 += [ nn.ReLU(True), - ] - # model6+=[nn.ReflectionPad2d(2),] - model6 += [ nn.Conv2d( 512, 512, @@ -231,130 +155,77 @@ def __init__(self, stride=1, padding=2, bias=use_bias), - ] - model6 += [ nn.ReLU(True), - ] - model6 += [ norm_layer(512), - ] + ]) # Conv7 - # model47+=[nn.ReflectionPad2d(1),] - model7 = [ + self.model7 = nn.Sequential([ nn.Conv2d( 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model7+=[norm_layer(512),] - model7 += [ nn.ReLU(True), - ] - # model7+=[nn.ReflectionPad2d(1),] - model7 += [ nn.Conv2d( 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model7+=[norm_layer(512),] - model7 += [ nn.ReLU(True), - ] - # model7+=[nn.ReflectionPad2d(1),] - model7 += [ nn.Conv2d( 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model7 += [ nn.ReLU(True), - ] - model7 += [ norm_layer(512), - ] + ]) - # Conv7 - model8up = [ + # Conv8 + self.model8up = [ nn.ConvTranspose2d( 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) ] - # model3short8=[nn.ReflectionPad2d(1),] - model3short8 = [ + self.model3short8 = [ nn.Conv2d( 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ] - # model47+=[norm_layer(256),] - model8 = [ + self.model8 = nn.Sequential([ nn.ReLU(True), - ] - # model8+=[nn.ReflectionPad2d(1),] - model8 += [ nn.Conv2d( 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model8+=[norm_layer(256),] - model8 += [ nn.ReLU(True), - ] - # model8+=[nn.ReflectionPad2d(1),] - model8 += [ nn.Conv2d( 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model8 += [ nn.ReLU(True), - ] - model8 += [ norm_layer(256), - ] + ]) # Conv9 - model9up = [ + self.model9up = [ nn.ConvTranspose2d( 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), ] - # model2short9=[nn.ReflectionPad2d(1),] - model2short9 = [ + self.model2short9 = [ nn.Conv2d( 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ] - # add the two feature maps above - - # model9=[norm_layer(128),] - model9 = [ + self.model9 = nn.Sequential([ nn.ReLU(True), - ] - # model9+=[nn.ReflectionPad2d(1),] - model9 += [ nn.Conv2d( 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model9 += [ nn.ReLU(True), - ] - model9 += [ norm_layer(128), - ] + ]) # Conv10 - model10up = [ + self.model10up = [ nn.ConvTranspose2d( 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), ] - # model1short10=[nn.ReflectionPad2d(1),] - model1short10 = [ + self.model1short10 = [ nn.Conv2d( 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ] - # add the two feature maps above - # model10=[norm_layer(128),] - model10 = [ + self.model10 = nn.Sequential([ nn.ReLU(True), - ] - # model10+=[nn.ReflectionPad2d(1),] - model10 += [ nn.Conv2d( 128, 128, @@ -363,22 +234,18 @@ def __init__(self, stride=1, padding=1, bias=use_bias), - ] - model10 += [ nn.LeakyReLU(negative_slope=.2), - ] + ]) # classification output - model_class = [ - nn.Conv2d( - 256, - 529, - kernel_size=1, - padding=0, - dilation=1, - stride=1, - bias=use_bias), - ] + self.model_class = nn.Conv2d( + 256, + 529, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias) # regression output model_out = [ @@ -393,33 +260,10 @@ def __init__(self, ] if (use_tanh): model_out += [nn.Tanh()] - - self.model1 = nn.Sequential(*model1) - self.model2 = nn.Sequential(*model2) - self.model3 = nn.Sequential(*model3) - self.model4 = nn.Sequential(*model4) - self.model5 = nn.Sequential(*model5) - self.model6 = nn.Sequential(*model6) - self.model7 = nn.Sequential(*model7) - self.model8up = nn.Sequential(*model8up) - self.model8 = nn.Sequential(*model8) - self.model9up = nn.Sequential(*model9up) - self.model9 = nn.Sequential(*model9) - self.model10up = nn.Sequential(*model10up) - self.model10 = nn.Sequential(*model10) - self.model3short8 = nn.Sequential(*model3short8) - self.model2short9 = nn.Sequential(*model2short9) - self.model1short10 = nn.Sequential(*model1short10) - - self.model_class = nn.Sequential(*model_class) self.model_out = nn.Sequential(*model_out) - self.upsample4 = nn.Sequential(*[ - nn.Upsample(scale_factor=4, mode='nearest'), - ]) - self.softmax = nn.Sequential(*[ - nn.Softmax(dim=1), - ]) + self.upsample4 = nn.Upsample(scale_factor=4, mode='nearest') + self.softmax = nn.Softmax(dim=1) def forward(self, input_A, input_B, mask_B): conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1)) @@ -439,16 +283,14 @@ def forward(self, input_A, input_B, mask_B): conv9_3 = self.model9(conv9_up) conv10_up = self.model10up(conv9_3) + self.model1short10( conv1_2.detach()) - conv10_2 = self.model10(conv10_up) - out_reg = self.model_out(conv10_2) else: out_class = self.model_class(conv8_3.detach()) - conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) conv9_3 = self.model9(conv9_up) conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) - conv10_2 = self.model10(conv10_up) - out_reg = self.model_out(conv10_2) + + conv10_2 = self.model10(conv10_up) + out_reg = self.model_out(conv10_2) feature_map = {} feature_map['conv1_2'] = conv1_2 @@ -466,6 +308,4 @@ def forward(self, input_A, input_B, mask_B): feature_map['conv10_2'] = conv10_2 feature_map['out_reg'] = out_reg - return (out_reg, feature_map) - - # return (out_class, out_reg) + return (out_class, out_reg, feature_map) diff --git a/mmedit/models/editors/inst_colorization/fusion_net.py b/mmedit/models/editors/inst_colorization/fusion_net.py index bed83d8f02..9ac98bc4e2 100644 --- a/mmedit/models/editors/inst_colorization/fusion_net.py +++ b/mmedit/models/editors/inst_colorization/fusion_net.py @@ -1,10 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -import functools - import torch import torch.nn as nn from mmedit.registry import MODULES +from .weight_block import WeightBlock, get_norm_layer @MODULES.register_module() @@ -52,7 +51,7 @@ def __init__(self, ] # add a subsampling operation - self.weight_layer = WeightGenerator(64) + self.weight_layer = WeightBlock(64) # Conv2 # model2=[nn.ReflectionPad2d(1),] @@ -77,7 +76,7 @@ def __init__(self, ] # add a subsampling layer operation - self.weight_layer2 = WeightGenerator(128) + self.weight_layer2 = WeightBlock(128) # Conv3 # model3=[nn.ReflectionPad2d(1),] @@ -111,7 +110,7 @@ def __init__(self, ] # add a subsampling layer operation - self.weight_layer3 = WeightGenerator(256) + self.weight_layer3 = WeightBlock(256) # Conv4 # model47=[nn.ReflectionPad2d(1),] @@ -144,7 +143,7 @@ def __init__(self, norm_layer(512), ] - self.weight_layer4 = WeightGenerator(512) + self.weight_layer4 = WeightBlock(512) # Conv5 # model47+=[nn.ReflectionPad2d(2),] @@ -195,7 +194,7 @@ def __init__(self, norm_layer(512), ] - self.weight_layer5 = WeightGenerator(512) + self.weight_layer5 = WeightBlock(512) # Conv6 # model6+=[nn.ReflectionPad2d(2),] @@ -246,7 +245,7 @@ def __init__(self, norm_layer(512), ] - self.weight_layer6 = WeightGenerator(512) + self.weight_layer6 = WeightBlock(512) # Conv7 # model47+=[nn.ReflectionPad2d(1),] @@ -279,7 +278,7 @@ def __init__(self, norm_layer(512), ] - self.weight_layer7 = WeightGenerator(512) + self.weight_layer7 = WeightBlock(512) # Conv7 model8up = [ @@ -293,7 +292,7 @@ def __init__(self, 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ] - self.weight_layer8_1 = WeightGenerator(256) + self.weight_layer8_1 = WeightBlock(256) # model47+=[norm_layer(256),] model8 = [ @@ -320,7 +319,7 @@ def __init__(self, norm_layer(256), ] - self.weight_layer8_2 = WeightGenerator(256) + self.weight_layer8_2 = WeightBlock(256) # Conv9 model9up = [ @@ -335,7 +334,7 @@ def __init__(self, ] # add the two feature maps above - self.weight_layer9_1 = WeightGenerator(128) + self.weight_layer9_1 = WeightBlock(128) # model9=[norm_layer(128),] model9 = [ @@ -353,7 +352,7 @@ def __init__(self, norm_layer(128), ] - self.weight_layer9_2 = WeightGenerator(128) + self.weight_layer9_2 = WeightBlock(128) # Conv10 model10up = [ @@ -368,7 +367,7 @@ def __init__(self, ] # add the two feature maps above - self.weight_layer10_1 = WeightGenerator(128) + self.weight_layer10_1 = WeightBlock(128) # model10=[norm_layer(128),] model10 = [ @@ -389,7 +388,7 @@ def __init__(self, nn.LeakyReLU(negative_slope=.2), ] - self.weight_layer10_2 = WeightGenerator(128) + self.weight_layer10_2 = WeightBlock(128) # classification output model_class = [ @@ -417,7 +416,7 @@ def __init__(self, if (use_tanh): model_out += [nn.Tanh()] - self.weight_layerout = WeightGenerator(2) + self.weight_layerout = WeightBlock(2) self.model1 = nn.Sequential(*model1) self.model2 = nn.Sequential(*model2) diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py index 7cc242292b..deed640741 100644 --- a/mmedit/models/editors/inst_colorization/inst_colorization.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -1,12 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. from collections import OrderedDict -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union import torch +from detectron2 import model_zoo +from detectron2.config import get_cfg +from detectron2.engine import DefaultPredictor from mmengine.config import Config +from mmengine.model import BaseModel from mmengine.optim import OptimWrapperDict -from mmedit.models import BaseEditModel from mmedit.models.utils import (encode_ab_ind, generation_init_weights, get_colorization_data, lab2rgb) from mmedit.registry import MODULES @@ -14,10 +17,15 @@ @MODULES.register_module() -class InstColorization(BaseEditModel): +class InstColorization(BaseModel): def __init__(self, data_preprocessor: Union[dict, Config], + detector_cfg, + image_model, + instance_model, + fusion_model, + stage, ngf, output_nc, avg_loss_alpha, @@ -28,58 +36,197 @@ def __init__(self, l_cent, sample_Ps, mask_cent, - insta_stage=None, which_direction='AtoB', - generator=None, loss=None, init_cfg=None, train_cfg=None, test_cfg=None): - super(InstColorization, self).__init__( - generator=generator, - data_preprocessor=data_preprocessor, - pixel_loss=loss, - init_cfg=init_cfg, - train_cfg=train_cfg, - test_cfg=test_cfg) - - self.ngf = ngf - self.output_nc = output_nc - self.avg_loss_alpha = avg_loss_alpha - self.mask_cent = mask_cent - self.which_direction = which_direction - - self.encode_ab_opt = dict( - ab_norm=ab_norm, ab_max=ab_max, ab_quant=ab_quant) - - self.colorization_data_opt = dict( - ab_thresh=0, - ab_norm=ab_norm, - l_norm=l_norm, - l_cent=l_cent, - sample_PS=sample_Ps, - mask_cent=mask_cent, - ) - - self.lab2rgb_opt = dict(ab_norm=ab_norm, l_norm=l_norm, l_cent=l_cent) - - self.convert_params = dict( - ab_thresh=0, - ab_norm=ab_norm, - l_norm=l_norm, - l_cent=l_cent, - sample_PS=sample_Ps, - mask_cent=mask_cent, - ) - - self.device = torch.device('cuda:{}'.format(0)) - - self.insta_stage = insta_stage - - if self.insta_stage == 'full' or self.insta_stage == 'instance': - self.training = False - self.setup_to_train() + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + # colorization networks + self.image_model = MODULES.build(image_model) + self.instance_model = MODULES.build(instance_model) + self.fusion_model = MODULES.build(fusion_model) + + # detector + cfg = get_cfg() + cfg.merge_from_file(model_zoo.get_config_file(detector_cfg)) + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 + cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(detector_cfg) + self.detector = DefaultPredictor(cfg) + + self.stage = stage + + # self.train_cfg = train_cfg + # self.test_cfg = test_cfg + # self.ngf = ngf + # self.output_nc = output_nc + # self.avg_loss_alpha = avg_loss_alpha + # self.mask_cent = mask_cent + # self.which_direction = which_direction + + # self.encode_ab_opt = dict( + # ab_norm=ab_norm, ab_max=ab_max, ab_quant=ab_quant) + # self.colorization_data_opt = dict( + # ab_thresh=0, + # ab_norm=ab_norm, + # l_norm=l_norm, + # l_cent=l_cent, + # sample_PS=sample_Ps, + # mask_cent=mask_cent, + # ) + # self.lab2rgb_opt = dict(ab_norm=ab_norm, l_norm=l_norm, l_cent=l_cent) + # self.convert_params = dict( + # ab_thresh=0, + # ab_norm=ab_norm, + # l_norm=l_norm, + # l_cent=l_cent, + # sample_PS=sample_Ps, + # mask_cent=mask_cent, + # ) + + # # loss + # self.loss_names = ['G', 'L1'] + # self.criterionL1 = self.loss + # self.avg_losses = OrderedDict() + # self.error_cnt = 0 + # for loss_name in self.loss_names: + # self.avg_losses[loss_name] = 0 + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[EditDataSample]] = None, + mode: str = 'tensor', + **kwargs): + """Returns losses or predictions of training, validation, testing, and + simple inference process. + + ``forward`` method of BaseModel is an abstract method, its subclasses + must implement this method. + + Accepts ``inputs`` and ``data_samples`` processed by + :attr:`data_preprocessor`, and returns results according to mode + arguments. + + During non-distributed training, validation, and testing process, + ``forward`` will be called by ``BaseModel.train_step``, + ``BaseModel.val_step`` and ``BaseModel.val_step`` directly. + + During distributed data parallel training process, + ``MMSeparateDistributedDataParallel.train_step`` will first call + ``DistributedDataParallel.forward`` to enable automatic + gradient synchronization, and then call ``forward`` to get training + loss. + + Args: + inputs (torch.Tensor): batch input tensor collated by + :attr:`data_preprocessor`. + data_samples (List[BaseDataElement], optional): + data samples collated by :attr:`data_preprocessor`. + mode (str): mode should be one of ``loss``, ``predict`` and + ``tensor``. Default: 'tensor'. + + - ``loss``: Called by ``train_step`` and return loss ``dict`` + used for logging + - ``predict``: Called by ``val_step`` and ``test_step`` + and return list of ``BaseDataElement`` results used for + computing metric. + - ``tensor``: Called by custom use to get ``Tensor`` type + results. + + Returns: + ForwardResults: + + - If ``mode == loss``, return a ``dict`` of loss tensor used + for backward and logging. + - If ``mode == predict``, return a ``list`` of + :obj:`BaseDataElement` for computing metric + and getting inference result. + - If ``mode == tensor``, return a tensor or ``tuple`` of tensor + or ``dict`` or tensor for custom use. + """ + + if mode == 'tensor': + return self.forward_tensor(inputs, data_samples, **kwargs) + + elif mode == 'predict': + predictions = self.forward_inference(inputs, data_samples, + **kwargs) + predictions = self.convert_to_datasample(data_samples, predictions) + return predictions + + elif mode == 'loss': + return self.forward_train(inputs, data_samples, **kwargs) + + def convert_to_datasample(self, inputs, data_samples): + for data_sample, output in zip(inputs, data_samples): + data_sample.output = output + return inputs + + def forward_tensor(self, inputs, data_samples=None, **kwargs): + """Forward tensor. Returns result of simple forward. + + Args: + inputs (torch.Tensor): batch input tensor collated by + :attr:`data_preprocessor`. + data_samples (List[BaseDataElement], optional): + data samples collated by :attr:`data_preprocessor`. + + Returns: + Tensor: result of simple forward. + """ + + feats = self.generator(inputs, **kwargs) + + return feats + + def forward_inference(self, inputs, data_samples=None, **kwargs): + """Forward inference. Returns predictions of validation, testing, and + simple inference. + + Args: + inputs (torch.Tensor): batch input tensor collated by + :attr:`data_preprocessor`. + data_samples (List[BaseDataElement], optional): + data samples collated by :attr:`data_preprocessor`. + + Returns: + List[EditDataSample]: predictions. + """ + + feats = self.forward_tensor(inputs, data_samples, **kwargs) + feats = self.data_preprocessor.destructor(feats) + predictions = [] + for idx in range(feats.shape[0]): + predictions.append( + EditDataSample( + pred_img=PixelData(data=feats[idx].to('cpu')), + metainfo=data_samples[idx].metainfo)) + + return predictions + + def forward_train(self, inputs, data_samples=None, **kwargs): + """Forward training. Returns dict of losses of training. + + Args: + inputs (torch.Tensor): batch input tensor collated by + :attr:`data_preprocessor`. + data_samples (List[BaseDataElement], optional): + data samples collated by :attr:`data_preprocessor`. + + Returns: + dict: Dict of losses. + """ + + feats = self.forward_tensor(inputs, data_samples, **kwargs) + gt_imgs = [data_sample.gt_img.data for data_sample in data_samples] + batch_gt_data = torch.stack(gt_imgs) + + loss = self.pixel_loss(feats, batch_gt_data) + + return dict(loss=loss) def set_input(self, input): @@ -127,14 +274,14 @@ def set_forward_without_box(self, input): def generator_loss(self): - if self.insta_stage == 'full' or self.insta_stage == 'instance': + if self.stage == 'full' or self.stage == 'instance': self.loss_L1 = torch.mean( self.criterionL1( self.fake_B_reg.type(torch.cuda.FloatTensor), self.real_B.type(torch.cuda.FloatTensor))) self.loss_G = 10 * self.loss_L1 - elif self.insta_stage == 'fusion': + elif self.stage == 'fusion': self.loss_L1 = torch.mean( self.criterionL1( self.fake_B_reg.type(torch.cuda.FloatTensor), @@ -168,7 +315,7 @@ def train_step(self, data: List[dict], log_vars = {} - if self.insta_stage == 'full' or self.insta_stage == 'instance': + if self.stage == 'full' or self.stage == 'instance': rgb_img = [data_samples.rgb_img] gray_img = [data_samples.gray_img] @@ -185,7 +332,7 @@ def train_step(self, data: List[dict], self.fake_B_reg = self.generator(self.real_A, self.hint_B, self.mask_B) - elif self.insta_stage == 'fusion': + elif self.stage == 'fusion': box_info = data_samples.box_info box_info_2x = data_samples.box_info_2x box_info_4x = data_samples.box_info_4x @@ -234,19 +381,6 @@ def train_step(self, data: List[dict], return output - def setup_to_train(self): - - self.loss_names = ['G', 'L1'] - - self.criterionL1 = self.loss - - # initialize average loss values - self.avg_losses = OrderedDict() - # self.avg_loss_alpha = self.avg_loss_alpha - self.error_cnt = 0 - for loss_name in self.loss_names: - self.avg_losses[loss_name] = 0 - def forward_tensor(self, inputs, data_samples, **kwargs): data = data_samples[0] @@ -298,7 +432,7 @@ def get_current_visuals(self): visual_ret = OrderedDict() - if self.insta_stage == 'full' or self.insta_stage == 'instance': + if self.stage == 'full' or self.stage == 'instance': visual_ret['gray'] = lab2rgb( torch.cat((self.real_A.type( @@ -329,7 +463,7 @@ def get_current_visuals(self): self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), **self.lab2rgb_opt) - elif self.insta_stage == 'fusion': + elif self.stage == 'fusion': visual_ret['gray'] = lab2rgb( torch.cat((self.full_real_A.type( torch.cuda.FloatTensor), torch.zeros_like( diff --git a/mmedit/models/editors/inst_colorization/weight_block.py b/mmedit/models/editors/inst_colorization/weight_layer.py similarity index 98% rename from mmedit/models/editors/inst_colorization/weight_block.py rename to mmedit/models/editors/inst_colorization/weight_layer.py index 8e9d56eed4..72b8a44d3e 100644 --- a/mmedit/models/editors/inst_colorization/weight_block.py +++ b/mmedit/models/editors/inst_colorization/weight_layer.py @@ -22,10 +22,10 @@ def get_norm_layer(norm_type='instance'): @MODULES.register_module() -class WeightBlock(BaseModule): +class WeightLayer(BaseModule): def __init__(self, input_ch, inner_ch=16): - super(WeightBlock, self).__init__() + super(WeightLayer, self).__init__() self.simple_instance_conv = nn.Sequential( nn.Conv2d(input_ch, inner_ch, kernel_size=3, stride=1, padding=1), nn.ReLU(True), From 1fc4b74275983ef1ef7763d9be5f7508cebb2ed0 Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Thu, 20 Oct 2022 19:55:41 +0800 Subject: [PATCH 12/32] rfactor networks --- .../inst-colorizatioon_cocostuff_256x256.py | 8 +- demo/colorization_demo.py | 1 - .../editors/inst_colorization/fusion_net.py | 295 +++++------------- .../inst_colorization/inst_colorization.py | 96 +----- .../editors/inst_colorization/weight_layer.py | 8 +- 5 files changed, 88 insertions(+), 320 deletions(-) diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py index 55b4700f09..7f8f4b346e 100644 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py @@ -1,5 +1,3 @@ -from logging import PlaceHolder - _base_ = ['../_base_/default_runtime.py'] exp_name = 'inst-colorization_cocostuff_256x256' @@ -16,12 +14,10 @@ std=[127.5], ), detector_cfg='COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml', - image_model=dict( - type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'), + full_model=dict( + type='FusionNet', input_nc=4, output_nc=2, norm_type='batch'), instance_model=dict( type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'), - fusion_model=dict( - type='FusionNet', input_nc=4, output_nc=2, norm_type='batch'), stage=stage, ngf=64, output_nc=2, diff --git a/demo/colorization_demo.py b/demo/colorization_demo.py index 4cc3b0c4e7..037093c520 100644 --- a/demo/colorization_demo.py +++ b/demo/colorization_demo.py @@ -30,7 +30,6 @@ def main(): else: device = torch.device('cuda', args.device) - # model = init_model(args.config, args.checkpoints, device=device) output = colorization_inference(model, args.img_path) result = tensor2img(output)[..., ::-1] diff --git a/mmedit/models/editors/inst_colorization/fusion_net.py b/mmedit/models/editors/inst_colorization/fusion_net.py index 9ac98bc4e2..b73c2acaea 100644 --- a/mmedit/models/editors/inst_colorization/fusion_net.py +++ b/mmedit/models/editors/inst_colorization/fusion_net.py @@ -3,11 +3,25 @@ import torch.nn as nn from mmedit.registry import MODULES -from .weight_block import WeightBlock, get_norm_layer +from .weight_layer import get_norm_layer, WeightLayer @MODULES.register_module() class FusionNet(nn.Module): + """Instance-aware Image Colorization. + + https://arxiv.org/abs/2005.10825 + + Codes adapted from 'https://github.com/ericsujw/InstColorization.git' + 'InstColorization/blob/master/models/networks.py#L314' + + Args: + input_nc: + output_nc: + norm_type: + use_tanh: + classification: + """ def __init__(self, input_nc, @@ -19,13 +33,12 @@ def __init__(self, self.input_nc = input_nc self.output_nc = output_nc self.classification = classification - use_bias = True norm_layer = get_norm_layer(norm_type) - + use_bias = True + # Conv1 - # model1=[nn.ReflectionPad2d(1),] - model1 = [ + self.model1 = nn.Sequential([ nn.Conv2d( input_nc, 64, @@ -33,121 +46,63 @@ def __init__(self, stride=1, padding=1, bias=use_bias), - ] - # model1+=[norm_layer(64),] - model1 += [ nn.ReLU(True), - ] - # model1+=[nn.ReflectionPad2d(1),] - model1 += [ nn.Conv2d( 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model1 += [ nn.ReLU(True), - ] - model1 += [ norm_layer(64), - ] - # add a subsampling operation + ]) - self.weight_layer = WeightBlock(64) + self.weight_layer = WeightLayer(64) # Conv2 - # model2=[nn.ReflectionPad2d(1),] - model2 = [ + self.model2 = nn.Sequential([ nn.Conv2d( 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model2+=[norm_layer(128),] - model2 += [ nn.ReLU(True), - ] - # model2+=[nn.ReflectionPad2d(1),] - model2 += [ nn.Conv2d( 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model2 += [ nn.ReLU(True), - ] - model2 += [ norm_layer(128), - ] - # add a subsampling layer operation + ]) - self.weight_layer2 = WeightBlock(128) + self.weight_layer2 = WeightLayer(128) + # Conv3 - # model3=[nn.ReflectionPad2d(1),] - model3 = [ + self.model3 = nn.Sequential([ nn.Conv2d( 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model3+=[norm_layer(256),] - model3 += [ nn.ReLU(True), - ] - # model3+=[nn.ReflectionPad2d(1),] - model3 += [ nn.Conv2d( 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model3+=[norm_layer(256),] - model3 += [ nn.ReLU(True), - ] - # model3+=[nn.ReflectionPad2d(1),] - model3 += [ nn.Conv2d( 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model3 += [ nn.ReLU(True), - ] - model3 += [ norm_layer(256), - ] - # add a subsampling layer operation - - self.weight_layer3 = WeightBlock(256) + ]) + self.weight_layer3 = WeightLayer(256) + # Conv4 - # model47=[nn.ReflectionPad2d(1),] - model4 = [ + self.model4 = nn.Sequential([ nn.Conv2d( 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model4+=[norm_layer(512),] - model4 += [ nn.ReLU(True), - ] - # model4+=[nn.ReflectionPad2d(1),] - model4 += [ nn.Conv2d( 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model4+=[norm_layer(512),] - model4 += [ nn.ReLU(True), - ] - # model4+=[nn.ReflectionPad2d(1),] - model4 += [ nn.Conv2d( 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model4 += [ nn.ReLU(True), - ] - model4 += [ norm_layer(512), - ] + ]) - self.weight_layer4 = WeightBlock(512) + self.weight_layer4 = WeightLayer(512) # Conv5 - # model47+=[nn.ReflectionPad2d(2),] - model5 = [ + self.model5 = nn.Sequential([ nn.Conv2d( 512, 512, @@ -156,13 +111,7 @@ def __init__(self, stride=1, padding=2, bias=use_bias), - ] - # model5+=[norm_layer(512),] - model5 += [ nn.ReLU(True), - ] - # model5+=[nn.ReflectionPad2d(2),] - model5 += [ nn.Conv2d( 512, 512, @@ -171,13 +120,7 @@ def __init__(self, stride=1, padding=2, bias=use_bias), - ] - # model5+=[norm_layer(512),] - model5 += [ nn.ReLU(True), - ] - # model5+=[nn.ReflectionPad2d(2),] - model5 += [ nn.Conv2d( 512, 512, @@ -186,19 +129,14 @@ def __init__(self, stride=1, padding=2, bias=use_bias), - ] - model5 += [ nn.ReLU(True), - ] - model5 += [ norm_layer(512), - ] + ]) - self.weight_layer5 = WeightBlock(512) + self.weight_layer5 = WeightLayer(512) # Conv6 - # model6+=[nn.ReflectionPad2d(2),] - model6 = [ + self.model6 = nn.Sequential([ nn.Conv2d( 512, 512, @@ -207,13 +145,7 @@ def __init__(self, stride=1, padding=2, bias=use_bias), - ] - # model6+=[norm_layer(512),] - model6 += [ nn.ReLU(True), - ] - # model6+=[nn.ReflectionPad2d(2),] - model6 += [ nn.Conv2d( 512, 512, @@ -222,13 +154,7 @@ def __init__(self, stride=1, padding=2, bias=use_bias), - ] - # model6+=[norm_layer(512),] - model6 += [ nn.ReLU(True), - ] - # model6+=[nn.ReflectionPad2d(2),] - model6 += [ nn.Conv2d( 512, 512, @@ -237,144 +163,92 @@ def __init__(self, stride=1, padding=2, bias=use_bias), - ] - model6 += [ nn.ReLU(True), - ] - model6 += [ norm_layer(512), - ] + ]) - self.weight_layer6 = WeightBlock(512) + self.weight_layer6 = WeightLayer(512) # Conv7 - # model47+=[nn.ReflectionPad2d(1),] - model7 = [ + self.model7 = nn.Sequential([ nn.Conv2d( 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model7+=[norm_layer(512),] - model7 += [ nn.ReLU(True), - ] - # model7+=[nn.ReflectionPad2d(1),] - model7 += [ nn.Conv2d( 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model7+=[norm_layer(512),] - model7 += [ nn.ReLU(True), - ] - # model7+=[nn.ReflectionPad2d(1),] - model7 += [ nn.Conv2d( 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model7 += [ nn.ReLU(True), - ] - model7 += [ norm_layer(512), - ] + ]) - self.weight_layer7 = WeightBlock(512) + self.weight_layer7 = WeightLayer(512) - # Conv7 - model8up = [ + # Conv8 + self.model8up = [ nn.ConvTranspose2d( 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) ] - # model3short8=[nn.ReflectionPad2d(1),] - model3short8 = [ + self.model3short8 = [ nn.Conv2d( 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ] - self.weight_layer8_1 = WeightBlock(256) + self.weight_layer8_1 = WeightLayer(256) - # model47+=[norm_layer(256),] - model8 = [ + self.model8 = nn.Sequential([ nn.ReLU(True), - ] - # model8+=[nn.ReflectionPad2d(1),] - model8 += [ nn.Conv2d( 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - # model8+=[norm_layer(256),] - model8 += [ nn.ReLU(True), - ] - # model8+=[nn.ReflectionPad2d(1),] - model8 += [ nn.Conv2d( 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model8 += [ nn.ReLU(True), - ] - model8 += [ norm_layer(256), - ] + ]) - self.weight_layer8_2 = WeightBlock(256) + self.weight_layer8_2 = WeightLayer(256) # Conv9 - model9up = [ + self.model9up = [ nn.ConvTranspose2d( 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), ] - # model2short9=[nn.ReflectionPad2d(1),] - model2short9 = [ + self.model2short9 = [ nn.Conv2d( 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ] - # add the two feature maps above - self.weight_layer9_1 = WeightBlock(128) + self.weight_layer9_1 = WeightLayer(128) - # model9=[norm_layer(128),] - model9 = [ + self.model9 = nn.Sequential([ nn.ReLU(True), - ] - # model9+=[nn.ReflectionPad2d(1),] - model9 += [ nn.Conv2d( 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - model9 += [ nn.ReLU(True), - ] - model9 += [ norm_layer(128), - ] + ]) - self.weight_layer9_2 = WeightBlock(128) + self.weight_layer9_2 = WeightLayer(128) # Conv10 - model10up = [ + self.model10up = [ nn.ConvTranspose2d( 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), ] - # model1short10=[nn.ReflectionPad2d(1),] - model1short10 = [ + self.model1short10 = [ nn.Conv2d( 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ] - # add the two feature maps above - self.weight_layer10_1 = WeightBlock(128) + self.weight_layer10_1 = WeightLayer(128) - # model10=[norm_layer(128),] - model10 = [ + self.model10 = nn.Sequential([ nn.ReLU(True), - ] - # model10+=[nn.ReflectionPad2d(1),] - model10 += [ nn.Conv2d( 128, 128, @@ -383,24 +257,20 @@ def __init__(self, stride=1, padding=1, bias=use_bias), - ] - model10 += [ nn.LeakyReLU(negative_slope=.2), - ] + ]) - self.weight_layer10_2 = WeightBlock(128) + self.weight_layer10_2 = WeightLayer(128) # classification output - model_class = [ - nn.Conv2d( - 256, - 529, - kernel_size=1, - padding=0, - dilation=1, - stride=1, - bias=use_bias), - ] + self.model_class = nn.Conv2d( + 256, + 529, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias) # regression output model_out = [ @@ -415,35 +285,12 @@ def __init__(self, ] if (use_tanh): model_out += [nn.Tanh()] - - self.weight_layerout = WeightBlock(2) - - self.model1 = nn.Sequential(*model1) - self.model2 = nn.Sequential(*model2) - self.model3 = nn.Sequential(*model3) - self.model4 = nn.Sequential(*model4) - self.model5 = nn.Sequential(*model5) - self.model6 = nn.Sequential(*model6) - self.model7 = nn.Sequential(*model7) - self.model8up = nn.Sequential(*model8up) - self.model8 = nn.Sequential(*model8) - self.model9up = nn.Sequential(*model9up) - self.model9 = nn.Sequential(*model9) - self.model10up = nn.Sequential(*model10up) - self.model10 = nn.Sequential(*model10) - self.model3short8 = nn.Sequential(*model3short8) - self.model2short9 = nn.Sequential(*model2short9) - self.model1short10 = nn.Sequential(*model1short10) - - self.model_class = nn.Sequential(*model_class) self.model_out = nn.Sequential(*model_out) - self.upsample4 = nn.Sequential(*[ - nn.Upsample(scale_factor=4, mode='nearest'), - ]) - self.softmax = nn.Sequential(*[ - nn.Softmax(dim=1), - ]) + self.weight_layerout = WeightLayer(2) + + self.upsample4 = nn.Upsample(scale_factor=4, mode='nearest') + self.softmax = nn.Softmax(dim=1) def forward(self, input_A, input_B, mask_B, instance_feature, box_info_list): diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py index deed640741..77955d92b7 100644 --- a/mmedit/models/editors/inst_colorization/inst_colorization.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -10,8 +10,7 @@ from mmengine.model import BaseModel from mmengine.optim import OptimWrapperDict -from mmedit.models.utils import (encode_ab_ind, generation_init_weights, - get_colorization_data, lab2rgb) +from mmedit.models.utils import (encode_ab_ind, get_colorization_data, lab2rgb) from mmedit.registry import MODULES from mmedit.structures import EditDataSample, PixelData @@ -22,9 +21,8 @@ class InstColorization(BaseModel): def __init__(self, data_preprocessor: Union[dict, Config], detector_cfg, - image_model, + full_model, instance_model, - fusion_model, stage, ngf, output_nc, @@ -46,9 +44,11 @@ def __init__(self, init_cfg=init_cfg, data_preprocessor=data_preprocessor) # colorization networks - self.image_model = MODULES.build(image_model) + # Stage 1 & 3. fusion model intergrates the image model + self.full_model = MODULES.build(full_model) + + # Stage 2. instance model used for training instance colorization self.instance_model = MODULES.build(instance_model) - self.fusion_model = MODULES.build(fusion_model) # detector cfg = get_cfg() @@ -228,83 +228,7 @@ def forward_train(self, inputs, data_samples=None, **kwargs): return dict(loss=loss) - def set_input(self, input): - - AtoB = self.which_direction == 'AtoB' - self.real_A = input['A' if AtoB else 'B'].to(self.device) - self.real_B = input['B' if AtoB else 'A'].to(self.device) - self.hint_B = input['hint_B'].to(self.device) - - self.mask_B = input['mask_B'].to(self.device) - self.mask_B_nc = self.mask_B + self.mask_cent - - self.real_B_enc = encode_ab_ind(self.real_B[:, :, ::4, ::4], - **self.encode_ab_opt) - - def set_fusion_input(self, input, box_info): - - AtoB = self.which_direction == 'AtoB' - self.full_real_A = input['A' if AtoB else 'B'].to(self.device) - self.full_real_B = input['B' if AtoB else 'A'].to(self.device) - - self.full_hint_B = input['hint_B'].to(self.device) - self.full_mask_B = input['mask_B'].to(self.device) - - self.full_mask_B_nc = self.full_mask_B + self.mask_cent - self.full_real_B_enc = encode_ab_ind(self.full_real_B[:, :, ::4, ::4], - **self.encode_ab_opt) - self.box_info_list = box_info - - def set_forward_without_box(self, input): - - AtoB = self.which_direction == 'AtoB' - self.full_real_A = input['A' if AtoB else 'B'].to(self.device) - self.full_real_B = input['B' if AtoB else 'A'].to(self.device) - # self.image_paths = input['A_paths' if AtoB else 'B_paths'] - self.full_hint_B = input['hint_B'].to(self.device) - self.full_mask_B = input['mask_B'].to(self.device) - self.full_mask_B_nc = self.full_mask_B + self.mask_cent - self.full_real_B_enc = encode_ab_ind(self.full_real_B[:, :, ::4, ::4], - **self.encode_ab_opt) - - (_, self.comp_B_reg) = self.netGComp(self.full_real_A, - self.full_hint_B, - self.full_mask_B) - self.fake_B_reg = self.comp_B_reg - - def generator_loss(self): - - if self.stage == 'full' or self.stage == 'instance': - self.loss_L1 = torch.mean( - self.criterionL1( - self.fake_B_reg.type(torch.cuda.FloatTensor), - self.real_B.type(torch.cuda.FloatTensor))) - self.loss_G = 10 * self.loss_L1 - - elif self.stage == 'fusion': - self.loss_L1 = torch.mean( - self.criterionL1( - self.fake_B_reg.type(torch.cuda.FloatTensor), - self.full_real_B.type(torch.cuda.FloatTensor))) - self.loss_G = 10 * self.loss_L1 - - else: - print('Error! Wrong stage selection!') - exit() - - self.error_cnt += 1 - errors_ret = OrderedDict() - for name in self.loss_names: - if isinstance(name, str): - # float(...) works for both scalar tensor and float number - self.avg_losses[name] = float(getattr( - self, 'loss_' + - name)) + self.avg_loss_alpha * self.avg_losses[name] - errors_ret[name] = (1 - self.avg_loss_alpha) / ( - 1 - self.avg_loss_alpha** # noqa - self.error_cnt) * self.avg_losses[name] - return errors_ret def train_step(self, data: List[dict], optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]: @@ -428,6 +352,14 @@ def forward_inference(self, inputs, data_samples=None, **kwargs): return predictions + + + + + + + + def get_current_visuals(self): visual_ret = OrderedDict() diff --git a/mmedit/models/editors/inst_colorization/weight_layer.py b/mmedit/models/editors/inst_colorization/weight_layer.py index 72b8a44d3e..ba4d050c4f 100644 --- a/mmedit/models/editors/inst_colorization/weight_layer.py +++ b/mmedit/models/editors/inst_colorization/weight_layer.py @@ -85,15 +85,9 @@ def forward(self, instance_feature, bg_feature, box_info): mask_list_maskout = mask_list.clone() - # instance_mask = torch.clamp( - # torch.sum( - # mask_list_maskout[:, :instance_feature.shape[0]], - # 1, - # keepdim=True), 0.0, 1.0) - featur_map_list.append(bg_feature) featur_map_list = torch.cat(featur_map_list, 0) mask_list_maskout = mask_list_maskout.permute(1, 0, 2, 3).contiguous() out = featur_map_list * mask_list_maskout out = torch.sum(out, 0, keepdim=True) - return out # , instance_mask, torch.clamp(mask_list, 0.0, 1.0) + return out From 9825d12bf340f2586c68896a75e25ee41db20f30 Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Tue, 25 Oct 2022 11:56:07 +0800 Subject: [PATCH 13/32] fix loading weights --- .../inst-colorizatioon_cocostuff_256x256.py | 8 +- mmedit/datasets/__init__.py | 15 +--- mmedit/datasets/coco.py | 9 +- mmedit/datasets/transforms/formatting.py | 6 +- .../datasets/transforms/get_maskrcnn_bbox.py | 24 ++--- .../inst_colorization/colorization_net.py | 79 ++++++++--------- .../editors/inst_colorization/fusion_net.py | 87 ++++++++----------- .../inst_colorization/inst_colorization.py | 79 ++--------------- .../test_apis/test_colorization_inference.py | 6 +- tests/test_datasets/test_coco.py | 11 +-- .../test_get_gray_color_pil.py | 8 +- .../test_transforms/test_get_maskrcnn_bbox.py | 31 ++++--- .../test_inst_colorization/test_util.py | 2 - .../test_losses/test_huber_loss.py | 3 +- 14 files changed, 136 insertions(+), 232 deletions(-) diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py index 7f8f4b346e..0af72014e1 100644 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py @@ -13,7 +13,6 @@ mean=[127.5], std=[127.5], ), - detector_cfg='COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml', full_model=dict( type='FusionNet', input_nc=4, output_nc=2, norm_type='batch'), instance_model=dict( @@ -32,9 +31,14 @@ which_direction='AtoB', loss=dict(type='HuberLoss', delta=.01)) +# yapf: disable test_pipeline = [ dict(type='LoadImageFromFile', key='img'), - dict(type='GenMaskRCNNBbox', stage=stage, finesize=256), + dict( + type='GenMaskRCNNBbox', + config_file='COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml', # noqa + stage=stage, + finesize=256), dict(type='Resize', keys=['img'], scale=(256, 256), keep_ratio=False), dict(type='PackEditInputs'), ] diff --git a/mmedit/datasets/__init__.py b/mmedit/datasets/__init__.py index 850f7787b7..b8b4ec3289 100644 --- a/mmedit/datasets/__init__.py +++ b/mmedit/datasets/__init__.py @@ -3,22 +3,15 @@ from .basic_frames_dataset import BasicFramesDataset from .basic_image_dataset import BasicImageDataset from .cifar10_dataset import CIFAR10 +from .coco import CocoDataset from .comp1k_dataset import AdobeComp1kDataset from .grow_scale_image_dataset import GrowScaleImgDataset from .imagenet_dataset import ImageNet from .paired_image_dataset import PairedImageDataset from .unpaired_image_dataset import UnpairedImageDataset -from .coco import CocoDataset __all__ = [ - 'AdobeComp1kDataset', - 'BasicImageDataset', - 'BasicFramesDataset', - 'BasicConditionalDataset', - 'UnpairedImageDataset', - 'PairedImageDataset', - 'ImageNet', - 'CIFAR10', - 'GrowScaleImgDataset', - 'CocoDataset' + 'AdobeComp1kDataset', 'BasicImageDataset', 'BasicFramesDataset', + 'BasicConditionalDataset', 'UnpairedImageDataset', 'PairedImageDataset', + 'ImageNet', 'CIFAR10', 'GrowScaleImgDataset', 'CocoDataset' ] diff --git a/mmedit/datasets/coco.py b/mmedit/datasets/coco.py index 7296a0d506..0f6f66d7c8 100644 --- a/mmedit/datasets/coco.py +++ b/mmedit/datasets/coco.py @@ -1,8 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp -from typing import List, Callable, Optional, Union -from pathlib import Path -from typing import List, Union, Dict +from typing import List from mmengine.dataset import BaseDataset from mmengine.fileio import load @@ -38,7 +36,7 @@ def load_data_list(self) -> List[dict]: data_list.append(data_info) else: raise TypeError('data_info should be a dict or list of dict, ' - f'but got {type(data_info)}') + f'but got {type(data_info)}') return data_list @@ -48,6 +46,7 @@ def parse_data_info(self, raw_data_info: dict) -> dict: data_info = raw_data_info.copy() for key in raw_data_info: if 'path' in key: - data_info['gt_img_path'] = osp.join(self.data_root, data_info[key]) + data_info['gt_img_path'] = osp.join(self.data_root, + data_info[key]) return data_info diff --git a/mmedit/datasets/transforms/formatting.py b/mmedit/datasets/transforms/formatting.py index 727ef2fc2f..3a73adc643 100644 --- a/mmedit/datasets/transforms/formatting.py +++ b/mmedit/datasets/transforms/formatting.py @@ -200,12 +200,12 @@ def transform(self, results: dict) -> dict: gt_bg = results.pop('bg') gt_bg_tensor = images_to_tensor(gt_bg) data_sample.gt_bg = PixelData(data=gt_bg_tensor) - + if 'rgb_img' in results: gt_rgb = results.pop('rgb_img') gt_rgb_tensor = images_to_tensor(gt_rgb) data_sample.gt_rgb = PixelData(data=gt_rgb_tensor) - + if 'gray_img' in results: gray = results.pop('gray_img') gray_tensor = images_to_tensor(gray) @@ -243,7 +243,7 @@ def __init__(self, keys, to_float32=True): self.keys = keys self.to_float32 = to_float32 - + def _data_to_tensor(self, value): """Convert the value to tensor.""" is_image = check_if_image(value) diff --git a/mmedit/datasets/transforms/get_maskrcnn_bbox.py b/mmedit/datasets/transforms/get_maskrcnn_bbox.py index cfeb00cde6..4110e4c540 100644 --- a/mmedit/datasets/transforms/get_maskrcnn_bbox.py +++ b/mmedit/datasets/transforms/get_maskrcnn_bbox.py @@ -9,7 +9,6 @@ from detectron2.config import get_cfg from detectron2.engine import DefaultPredictor from PIL import Image -from skimage import color from mmedit.registry import TRANSFORMS @@ -17,7 +16,8 @@ @TRANSFORMS.register_module() class GenMaskRCNNBbox: - def __init__(self, key='img', stage='test', finesize=256): + def __init__(self, config_file, key='img', stage='test', finesize=256): + self.config_file = config_file self.key = key self.predictor = self.detectron() self.stage = stage @@ -108,7 +108,8 @@ def fusion(self, results, img, pred_bbox): if self.stage == 'fusion': rgb_img, gray_img = results['rgb_img'], results['gray_img'] - return self.get_instance_info(results, pred_bbox, gray_img, rgb_img) + return self.get_instance_info(results, pred_bbox, gray_img, + rgb_img) def get_instance_info(self, results, pred_bbox, gray_img, rgb_img=None): @@ -127,7 +128,8 @@ def get_instance_info(self, results, pred_bbox, gray_img, rgb_img=None): for i in index_list: startx, starty, endx, endy = pred_bbox[i] box_info[i] = np.array( - self.get_box_info(pred_bbox[i], gray_img.size, self.final_size)) + self.get_box_info(pred_bbox[i], gray_img.size, + self.final_size)) box_info_2x[i] = np.array( self.get_box_info(pred_bbox[i], gray_img.size, self.final_size // 2)) @@ -141,8 +143,12 @@ def get_instance_info(self, results, pred_bbox, gray_img, rgb_img=None): gray_img.crop((startx, starty, endx, endy))) cropped_gray_list.append(cropped_img) if rgb_img: - cropped_rgb_list.append(self.transforms(rgb_img.crop((startx, starty, endx, endy)))) - cropped_gray_list.append(self.transforms(gray_img.crop((startx, starty, endx, endy)))) + cropped_rgb_list.append( + self.transforms( + rgb_img.crop((startx, starty, endx, endy)))) + cropped_gray_list.append( + self.transforms( + gray_img.crop((startx, starty, endx, endy)))) results['full_gray'] = torch.stack(full_gray_list) if rgb_img: @@ -200,10 +206,8 @@ def __call__(self, results): def detectron(self): cfg = get_cfg() - cfg.merge_from_file( - model_zoo.get_config_file( - 'COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml')) + cfg.merge_from_file(model_zoo.get_config_file(self.config_file)) cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 - cfg.MODEL.WEIGHTS = '/mnt/ruoning/model_final_2d9806.pkl' + cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(self.config_file) predictor = DefaultPredictor(cfg) return predictor diff --git a/mmedit/models/editors/inst_colorization/colorization_net.py b/mmedit/models/editors/inst_colorization/colorization_net.py index f0c6635edb..f09db034c9 100644 --- a/mmedit/models/editors/inst_colorization/colorization_net.py +++ b/mmedit/models/editors/inst_colorization/colorization_net.py @@ -10,7 +10,8 @@ @MODULES.register_module() class ColorizationNet(BaseModule): - """Real-Time User-Guided Image Colorization with Learned Deep Priors. + """Real-Time User-Guided Image Colorization with Learned Deep Priors. The + backbone used for. https://arxiv.org/abs/1705.02999 @@ -41,7 +42,7 @@ def __init__(self, use_bias = True # Conv1 - self.model1 = nn.Sequential([ + self.model1 = nn.Sequential( nn.Conv2d( input_nc, 64, @@ -54,10 +55,10 @@ def __init__(self, 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(64), - ]) + ) # Conv2 - self.model2 = nn.Sequential([ + self.model2 = nn.Sequential( nn.Conv2d( 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), @@ -65,10 +66,10 @@ def __init__(self, 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(128), - ]) + ) # Conv3 - self.model3 = nn.Sequential([ + self.model3 = nn.Sequential( nn.Conv2d( 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), @@ -79,10 +80,10 @@ def __init__(self, 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(256), - ]) + ) # Conv4 - self.model4 = nn.Sequential([ + self.model4 = nn.Sequential( nn.Conv2d( 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), @@ -93,10 +94,10 @@ def __init__(self, 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(512), - ]) + ) # Conv5 - self.model5 = nn.Sequential([ + self.model5 = nn.Sequential( nn.Conv2d( 512, 512, @@ -125,10 +126,10 @@ def __init__(self, bias=use_bias), nn.ReLU(True), norm_layer(512), - ]) + ) # Conv6 - self.model6 = nn.Sequential([ + self.model6 = nn.Sequential( nn.Conv2d( 512, 512, @@ -157,10 +158,10 @@ def __init__(self, bias=use_bias), nn.ReLU(True), norm_layer(512), - ]) + ) # Conv7 - self.model7 = nn.Sequential([ + self.model7 = nn.Sequential( nn.Conv2d( 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), @@ -171,20 +172,16 @@ def __init__(self, 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(512), - ]) + ) # Conv8 - self.model8up = [ - nn.ConvTranspose2d( - 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) - ] + self.model8up = nn.ConvTranspose2d( + 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) - self.model3short8 = [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] + self.model3short8 = nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias) - self.model8 = nn.Sequential([ + self.model8 = nn.Sequential( nn.ReLU(True), nn.Conv2d( 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), @@ -193,38 +190,30 @@ def __init__(self, 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(256), - ]) + ) # Conv9 - self.model9up = [ - nn.ConvTranspose2d( - 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), - ] + self.model9up = nn.ConvTranspose2d( + 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias) - self.model2short9 = [ - nn.Conv2d( - 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] - self.model9 = nn.Sequential([ + self.model2short9 = nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias) + self.model9 = nn.Sequential( nn.ReLU(True), nn.Conv2d( 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(128), - ]) + ) # Conv10 - self.model10up = [ - nn.ConvTranspose2d( - 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), - ] + self.model10up = nn.ConvTranspose2d( + 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias) - self.model1short10 = [ - nn.Conv2d( - 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] + self.model1short10 = nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias) - self.model10 = nn.Sequential([ + self.model10 = nn.Sequential( nn.ReLU(True), nn.Conv2d( 128, @@ -235,7 +224,7 @@ def __init__(self, padding=1, bias=use_bias), nn.LeakyReLU(negative_slope=.2), - ]) + ) # classification output self.model_class = nn.Conv2d( diff --git a/mmedit/models/editors/inst_colorization/fusion_net.py b/mmedit/models/editors/inst_colorization/fusion_net.py index b73c2acaea..0ed6c5f733 100644 --- a/mmedit/models/editors/inst_colorization/fusion_net.py +++ b/mmedit/models/editors/inst_colorization/fusion_net.py @@ -1,19 +1,21 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch import torch.nn as nn +from mmengine.model import BaseModule from mmedit.registry import MODULES +from .weight_layer import WeightLayer, get_norm_layer -from .weight_layer import get_norm_layer, WeightLayer @MODULES.register_module() -class FusionNet(nn.Module): +class FusionNet(BaseModule): """Instance-aware Image Colorization. https://arxiv.org/abs/2005.10825 Codes adapted from 'https://github.com/ericsujw/InstColorization.git' 'InstColorization/blob/master/models/networks.py#L314' + FusionNet: the full image model with weight layer for fusion. Args: input_nc: @@ -36,9 +38,9 @@ def __init__(self, norm_layer = get_norm_layer(norm_type) use_bias = True - + # Conv1 - self.model1 = nn.Sequential([ + self.model1 = nn.Sequential( nn.Conv2d( input_nc, 64, @@ -51,12 +53,12 @@ def __init__(self, 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(64), - ]) + ) self.weight_layer = WeightLayer(64) # Conv2 - self.model2 = nn.Sequential([ + self.model2 = nn.Sequential( nn.Conv2d( 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), @@ -64,13 +66,12 @@ def __init__(self, 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(128), - ]) + ) self.weight_layer2 = WeightLayer(128) - # Conv3 - self.model3 = nn.Sequential([ + self.model3 = nn.Sequential( nn.Conv2d( 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), @@ -81,12 +82,12 @@ def __init__(self, 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(256), - ]) + ) self.weight_layer3 = WeightLayer(256) - + # Conv4 - self.model4 = nn.Sequential([ + self.model4 = nn.Sequential( nn.Conv2d( 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), @@ -97,12 +98,12 @@ def __init__(self, 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(512), - ]) + ) self.weight_layer4 = WeightLayer(512) # Conv5 - self.model5 = nn.Sequential([ + self.model5 = nn.Sequential( nn.Conv2d( 512, 512, @@ -131,12 +132,12 @@ def __init__(self, bias=use_bias), nn.ReLU(True), norm_layer(512), - ]) + ) self.weight_layer5 = WeightLayer(512) # Conv6 - self.model6 = nn.Sequential([ + self.model6 = nn.Sequential( nn.Conv2d( 512, 512, @@ -165,12 +166,12 @@ def __init__(self, bias=use_bias), nn.ReLU(True), norm_layer(512), - ]) + ) self.weight_layer6 = WeightLayer(512) # Conv7 - self.model7 = nn.Sequential([ + self.model7 = nn.Sequential( nn.Conv2d( 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), @@ -181,24 +182,20 @@ def __init__(self, 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(512), - ]) + ) self.weight_layer7 = WeightLayer(512) # Conv8 - self.model8up = [ - nn.ConvTranspose2d( - 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) - ] + self.model8up = nn.ConvTranspose2d( + 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) - self.model3short8 = [ - nn.Conv2d( - 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] + self.model3short8 = nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias) self.weight_layer8_1 = WeightLayer(256) - self.model8 = nn.Sequential([ + self.model8 = nn.Sequential( nn.ReLU(True), nn.Conv2d( 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), @@ -207,47 +204,39 @@ def __init__(self, 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(256), - ]) + ) self.weight_layer8_2 = WeightLayer(256) # Conv9 - self.model9up = [ - nn.ConvTranspose2d( - 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), - ] + self.model9up = nn.ConvTranspose2d( + 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias) - self.model2short9 = [ - nn.Conv2d( - 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] + self.model2short9 = nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias) self.weight_layer9_1 = WeightLayer(128) - self.model9 = nn.Sequential([ + self.model9 = nn.Sequential( nn.ReLU(True), nn.Conv2d( 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), nn.ReLU(True), norm_layer(128), - ]) + ) self.weight_layer9_2 = WeightLayer(128) # Conv10 - self.model10up = [ - nn.ConvTranspose2d( - 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), - ] + self.model10up = nn.ConvTranspose2d( + 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias) - self.model1short10 = [ - nn.Conv2d( - 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), - ] + self.model1short10 = nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias) self.weight_layer10_1 = WeightLayer(128) - self.model10 = nn.Sequential([ + self.model10 = nn.Sequential( nn.ReLU(True), nn.Conv2d( 128, @@ -258,7 +247,7 @@ def __init__(self, padding=1, bias=use_bias), nn.LeakyReLU(negative_slope=.2), - ]) + ) self.weight_layer10_2 = WeightLayer(128) diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py index 77955d92b7..2295de631e 100644 --- a/mmedit/models/editors/inst_colorization/inst_colorization.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -3,14 +3,11 @@ from typing import Dict, List, Optional, Union import torch -from detectron2 import model_zoo -from detectron2.config import get_cfg -from detectron2.engine import DefaultPredictor from mmengine.config import Config from mmengine.model import BaseModel from mmengine.optim import OptimWrapperDict -from mmedit.models.utils import (encode_ab_ind, get_colorization_data, lab2rgb) +from mmedit.models.utils import get_colorization_data, lab2rgb from mmedit.registry import MODULES from mmedit.structures import EditDataSample, PixelData @@ -20,7 +17,6 @@ class InstColorization(BaseModel): def __init__(self, data_preprocessor: Union[dict, Config], - detector_cfg, full_model, instance_model, stage, @@ -44,19 +40,12 @@ def __init__(self, init_cfg=init_cfg, data_preprocessor=data_preprocessor) # colorization networks - # Stage 1 & 3. fusion model intergrates the image model + # Stage 1 & 3. fusion model intergrates the image model self.full_model = MODULES.build(full_model) - # Stage 2. instance model used for training instance colorization + # Stage 2. instance model used for training instance colorization self.instance_model = MODULES.build(instance_model) - # detector - cfg = get_cfg() - cfg.merge_from_file(model_zoo.get_config_file(detector_cfg)) - cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 - cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(detector_cfg) - self.detector = DefaultPredictor(cfg) - self.stage = stage # self.train_cfg = train_cfg @@ -77,7 +66,8 @@ def __init__(self, # sample_PS=sample_Ps, # mask_cent=mask_cent, # ) - # self.lab2rgb_opt = dict(ab_norm=ab_norm, l_norm=l_norm, l_cent=l_cent) + # self.lab2rgb_opt = dict( + # ab_norm=ab_norm, l_norm=l_norm, l_cent=l_cent) # self.convert_params = dict( # ab_thresh=0, # ab_norm=ab_norm, @@ -165,48 +155,6 @@ def convert_to_datasample(self, inputs, data_samples): data_sample.output = output return inputs - def forward_tensor(self, inputs, data_samples=None, **kwargs): - """Forward tensor. Returns result of simple forward. - - Args: - inputs (torch.Tensor): batch input tensor collated by - :attr:`data_preprocessor`. - data_samples (List[BaseDataElement], optional): - data samples collated by :attr:`data_preprocessor`. - - Returns: - Tensor: result of simple forward. - """ - - feats = self.generator(inputs, **kwargs) - - return feats - - def forward_inference(self, inputs, data_samples=None, **kwargs): - """Forward inference. Returns predictions of validation, testing, and - simple inference. - - Args: - inputs (torch.Tensor): batch input tensor collated by - :attr:`data_preprocessor`. - data_samples (List[BaseDataElement], optional): - data samples collated by :attr:`data_preprocessor`. - - Returns: - List[EditDataSample]: predictions. - """ - - feats = self.forward_tensor(inputs, data_samples, **kwargs) - feats = self.data_preprocessor.destructor(feats) - predictions = [] - for idx in range(feats.shape[0]): - predictions.append( - EditDataSample( - pred_img=PixelData(data=feats[idx].to('cpu')), - metainfo=data_samples[idx].metainfo)) - - return predictions - def forward_train(self, inputs, data_samples=None, **kwargs): """Forward training. Returns dict of losses of training. @@ -228,14 +176,11 @@ def forward_train(self, inputs, data_samples=None, **kwargs): return dict(loss=loss) - - def train_step(self, data: List[dict], optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]: - g_optim_wrapper = optim_wrapper['generator'] data = self.data_preprocessor(data, True) - batch_inputs, data_samples = data['inputs'], data['data_samples'] + data_batch, data_samples = data['inputs'], data['data_samples'] log_vars = {} @@ -285,7 +230,7 @@ def train_step(self, data: List[dict], self.full_mask_B, self.box_info_list) - optimizer['generator'].zero_grad() + optim_wrapper['generator'].zero_grad() loss = self.generator_loss() @@ -294,7 +239,7 @@ def train_step(self, data: List[dict], loss_d.backward() - optimizer['generator'].step() + optim_wrapper['generator'].step() results = self.get_current_visuals() @@ -352,14 +297,6 @@ def forward_inference(self, inputs, data_samples=None, **kwargs): return predictions - - - - - - - - def get_current_visuals(self): visual_ret = OrderedDict() diff --git a/tests/test_apis/test_colorization_inference.py b/tests/test_apis/test_colorization_inference.py index 5c8545ef45..ad05a1d26b 100644 --- a/tests/test_apis/test_colorization_inference.py +++ b/tests/test_apis/test_colorization_inference.py @@ -19,7 +19,9 @@ def test_colorization_inference(): device = torch.device('cpu') data_root = '../../' - config = data_root + 'configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py' + config = osp.join( + data_root, + 'configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py') checkpoint = None @@ -36,4 +38,4 @@ def test_colorization_inference(): img_path = '../data/image/gray/test.jpg' result = colorization_inference(model, img_path) - assert tensor2img(result)[..., ::-1].shape == (256, 256, 3) \ No newline at end of file + assert tensor2img(result)[..., ::-1].shape == (256, 256, 3) diff --git a/tests/test_datasets/test_coco.py b/tests/test_datasets/test_coco.py index 08a5dc6436..7687af0db1 100644 --- a/tests/test_datasets/test_coco.py +++ b/tests/test_datasets/test_coco.py @@ -1,10 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. - -import os.path as osp -from pathlib import Path - from mmedit.registry import DATASETS -from mmedit.datasets import CocoDataset # todo 完成coco的单元测试编写 @@ -12,15 +7,14 @@ class TestCOCOStuff: DATASET_TYPE = 'CocoDataset' ann_file = 'test.json' - data_root = "../.." + data_root = '../..' DEFAULT_ARGS = dict( data_root=data_root + '/train2017', data_prefix=dict(gt='data_large'), ann_file=ann_file, pipeline=[], - test_mode=False - ) + test_mode=False) def test_load_data_list(self): dataset_class = DATASETS.get(self.DATASET_TYPE) @@ -32,4 +26,3 @@ def test_load_data_list(self): } # 对拿到的数据列表和数据进行判断 - \ No newline at end of file diff --git a/tests/test_datasets/test_transforms/test_get_gray_color_pil.py b/tests/test_datasets/test_transforms/test_get_gray_color_pil.py index 53cabdb3a7..04667a98a5 100644 --- a/tests/test_datasets/test_transforms/test_get_gray_color_pil.py +++ b/tests/test_datasets/test_transforms/test_get_gray_color_pil.py @@ -5,12 +5,10 @@ def test_get_gray_color_pil(): - img = cv.imread("../../data/image/gt/baboon.png") - test_class = GenGrayColorPil( - stage='test', keys=['rgb_img', 'gray_img'] - ) + img = cv.imread('../../data/image/gt/baboon.png') + test_class = GenGrayColorPil(stage='test', keys=['rgb_img', 'gray_img']) results = test_class.transform(dict(img=img)) assert 'rgb_img' in results.keys() and 'gray_img' in results.keys() - assert results['gray_img'].shape == img.shape \ No newline at end of file + assert results['gray_img'].shape == img.shape diff --git a/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py b/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py index a1ee22a2f1..a89ae3bd8a 100644 --- a/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py +++ b/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py @@ -1,20 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. import os + import cv2 as cv + from mmedit.datasets.transforms import GenMaskRCNNBbox from mmedit.utils import tensor2img class TestMaskRCNNBbox: - DEFAULT_ARGS = dict( - key='img', finesize=256 - ) + DEFAULT_ARGS = dict(key='img', finesize=256) def test_maskrcnn_bbox(self): detectetor = GenMaskRCNNBbox(**self.DEFAULT_ARGS, stage='test') - data_root = ".." - img_path = "data/image/gray/test.jpg" + data_root = '..' + img_path = 'data/image/gray/test.jpg' img = cv.imread(os.path.join(data_root, img_path)) data = dict(img=img) @@ -39,8 +39,8 @@ def test_maskrcnn_bbox(self): def test_gen_maskrcnn_from_pred(self): detectetor = GenMaskRCNNBbox(**self.DEFAULT_ARGS, stage='test') - data_root = ".." - img_path = "data/image/gray/test.jpg" + data_root = '..' + img_path = 'data/image/gray/test.jpg' img = cv.imread(os.path.join(data_root, img_path)) box_num_upbound = 4 @@ -51,8 +51,8 @@ def test_gen_maskrcnn_from_pred(self): def test_get_box_info(self): detectetor = GenMaskRCNNBbox(**self.DEFAULT_ARGS, stage='test') - data_root = ".." - img_path = "data/image/gray/test.jpg" + data_root = '..' + img_path = 'data/image/gray/test.jpg' img = cv.imread(os.path.join(data_root, img_path)) pred_bbox = detectetor.gen_maskrcnn_bbox_fromPred(img) @@ -64,10 +64,9 @@ def test_get_box_info(self): box_info = detectetor.get_box_info(pred_bbox, img.shape) - assert box_info[0] == resize_starty and box_info[1] == 256 - resize_endx \ - and box_info[2] == resize_starty and box_info[3] == 256 - resize_endy \ - and box_info[4] == resize_endx - resize_startx \ - and box_info[5] == resize_endy - resize_starty - - - + assert box_info[0] == resize_starty and \ + box_info[1] == 256 - resize_endx and \ + box_info[2] == resize_starty and \ + box_info[3] == 256 - resize_endy and \ + box_info[4] == resize_endx - resize_startx and \ + box_info[5] == resize_endy - resize_starty diff --git a/tests/test_models/test_editors/test_inst_colorization/test_util.py b/tests/test_models/test_editors/test_inst_colorization/test_util.py index 876cffc38c..ef101fec61 100644 --- a/tests/test_models/test_editors/test_inst_colorization/test_util.py +++ b/tests/test_models/test_editors/test_inst_colorization/test_util.py @@ -1,3 +1 @@ # Copyright (c) OpenMMLab. All rights reserved. - -from mmedit.models.utils import color_utils diff --git a/tests/test_models/test_losses/test_huber_loss.py b/tests/test_models/test_losses/test_huber_loss.py index 125eb7801f..675963331e 100644 --- a/tests/test_models/test_losses/test_huber_loss.py +++ b/tests/test_models/test_losses/test_huber_loss.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmedit.models.losses import HuberLoss def test_huber_loss(): - pass \ No newline at end of file + pass From ffba00a9ba6358337aec76b10e11642e3ad8941b Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Tue, 25 Oct 2022 18:46:00 +0800 Subject: [PATCH 14/32] refactoring model forward --- .../inst-colorizatioon_cocostuff_256x256.py | 9 +- mmedit/datasets/transforms/__init__.py | 4 +- mmedit/datasets/transforms/aug_shape.py | 48 ++-- .../datasets/transforms/get_maskrcnn_bbox.py | 250 ++++++------------ .../inst_colorization/inst_colorization.py | 214 ++------------- 5 files changed, 144 insertions(+), 381 deletions(-) diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py index 0af72014e1..9a872442f9 100644 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py @@ -35,10 +35,13 @@ test_pipeline = [ dict(type='LoadImageFromFile', key='img'), dict( - type='GenMaskRCNNBbox', + type='InstanceCrop', config_file='COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml', # noqa - stage=stage, finesize=256), - dict(type='Resize', keys=['img'], scale=(256, 256), keep_ratio=False), + dict( + type='Resize', + keys=['img', 'cropped_img'], + scale=(256, 256), + keep_ratio=False), dict(type='PackEditInputs'), ] diff --git a/mmedit/datasets/transforms/__init__.py b/mmedit/datasets/transforms/__init__.py index b5b5cb1bd8..6d7ff453be 100644 --- a/mmedit/datasets/transforms/__init__.py +++ b/mmedit/datasets/transforms/__init__.py @@ -18,7 +18,7 @@ GenerateSegmentIndices) from .get_gray_color_pil import GenGrayColorPil from .get_masked_image import GetMaskedImage -from .get_maskrcnn_bbox import GenMaskRCNNBbox +from .get_maskrcnn_bbox import InstanceCrop from .loading import (GetSpatialDiscountMask, LoadImageFromFile, LoadMask, LoadPairedImageFromFile) from .matlab_like_resize import MATLABLikeResize @@ -47,6 +47,6 @@ 'GenerateSoftSeg', 'FormatTrimap', 'TransformTrimap', 'GenerateTrimap', 'GenerateTrimapWithDistTransform', 'CompositeFg', 'RandomLoadResizeBg', 'MergeFgAndBg', 'PerturbBg', 'RandomJitter', 'LoadPairedImageFromFile', - 'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad', 'GenMaskRCNNBbox', + 'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad', 'InstanceCrop', 'GenGrayColorPil' ] diff --git a/mmedit/datasets/transforms/aug_shape.py b/mmedit/datasets/transforms/aug_shape.py index 6e62c3a4a9..61fa206d2e 100644 --- a/mmedit/datasets/transforms/aug_shape.py +++ b/mmedit/datasets/transforms/aug_shape.py @@ -319,24 +319,31 @@ def _resize(self, img): Returns: img (np.ndarray): Resized image. """ - - if self.keep_ratio: - img, self.scale_factor = mmcv.imrescale( - img, - self.scale, - return_scale=True, - interpolation=self.interpolation, - backend=self.backend) + if isinstance(img, list): + for i, image in enumerate(img): + size, img[i] = self._resize(image) + return size, img else: - img, w_scale, h_scale = mmcv.imresize( - img, - self.scale, - return_scale=True, - interpolation=self.interpolation, - backend=self.backend) - self.scale_factor = np.array((w_scale, h_scale), dtype=np.float32) - - return img + if self.keep_ratio: + img, self.scale_factor = mmcv.imrescale( + img, + self.scale, + return_scale=True, + interpolation=self.interpolation, + backend=self.backend) + else: + img, w_scale, h_scale = mmcv.imresize( + img, + self.scale, + return_scale=True, + interpolation=self.interpolation, + backend=self.backend) + self.scale_factor = np.array((w_scale, h_scale), + dtype=np.float32) + + if len(img.shape) == 2: + img = np.expand_dims(img, axis=2) + return img.shape, img def transform(self, results: Dict) -> Dict: """Transform function to resize images. @@ -358,11 +365,10 @@ def transform(self, results: Dict) -> Dict: new_w = min(self.max_size - (self.max_size % self.size_factor), new_w) self.scale = (new_w, new_h) + for key, out_key in zip(self.keys, self.output_keys): - results[out_key] = self._resize(results[key]) - if len(results[out_key].shape) == 2: - results[out_key] = np.expand_dims(results[out_key], axis=2) - results[f'{out_key}_shape'] = results[out_key].shape + size, results[out_key] = self._resize(results[key]) + results[f'{out_key}_shape'] = size results['scale_factor'] = self.scale_factor results['keep_ratio'] = self.keep_ratio diff --git a/mmedit/datasets/transforms/get_maskrcnn_bbox.py b/mmedit/datasets/transforms/get_maskrcnn_bbox.py index 4110e4c540..610db6217d 100644 --- a/mmedit/datasets/transforms/get_maskrcnn_bbox.py +++ b/mmedit/datasets/transforms/get_maskrcnn_bbox.py @@ -1,163 +1,88 @@ # Copyright (c) OpenMMLab. All rights reserved. -from random import sample - import cv2 as cv import numpy as np import torch -import torchvision.transforms as transforms from detectron2 import model_zoo from detectron2.config import get_cfg from detectron2.engine import DefaultPredictor -from PIL import Image +from mmcv.transforms import BaseTransform from mmedit.registry import TRANSFORMS @TRANSFORMS.register_module() -class GenMaskRCNNBbox: +class InstanceCrop(BaseTransform): + """## Arguments: + + - pred_data_path: Detectron2 predict results + - box_num_upbound: object bounding boxes number. + Default: -1 means use all the instances. + """ + + def __init__(self, + config_file, + key='img', + box_num_upbound=-1, + finesize=256): + # detector + cfg = get_cfg() + cfg.merge_from_file(model_zoo.get_config_file(config_file)) + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 + cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_file) + self.predictor = DefaultPredictor(cfg) - def __init__(self, config_file, key='img', stage='test', finesize=256): - self.config_file = config_file self.key = key - self.predictor = self.detectron() - self.stage = stage + self.box_num_upbound = box_num_upbound self.final_size = finesize - self.transforms = transforms.Compose([ - transforms.Resize((self.final_size, self.final_size), - interpolation=2), - transforms.ToTensor() - ]) - - def gen_maskrcnn_bbox_fromPred(self, - img, - bbox_path=None, - box_num_upbound=8): - """## Arguments: - - pred_data_path: Detectron2 predict results - - box_num_upbound: object bounding boxes number. - Default: -1 means use all the instances. - """ - if bbox_path: - pred_data = np.load(bbox_path) - pred_bbox = pred_data['bbox'].astype(np.int32) - pred_scores = pred_data['scores'] - else: - lab_image = cv.cvtColor(img, cv.COLOR_BGR2LAB) - l_channel, a_channel, b_channel = cv.split(lab_image) - l_stack = np.stack([l_channel, l_channel, l_channel], axis=2) - outputs = self.predictor(l_stack) - pred_bbox = outputs['instances'].pred_boxes.to( - torch.device('cpu')).tensor.numpy() - pred_scores = outputs['instances'].scores.cpu().data.numpy() - - pred_bbox = pred_bbox.astype(np.int32) - if 0 < box_num_upbound < pred_bbox.shape[0]: - index_mask = np.argsort( - pred_scores, axis=0)[pred_scores.shape[0] - - box_num_upbound:pred_scores.shape[0]] + def transform(self, results: dict) -> dict: + + # get consistent box prediction based on L channel + full_img = results['img'] + # cv.imwrite('full_img.jpg', full_img) + full_img_size = results['ori_img_shape'][:-1][::-1] + lab_image = cv.cvtColor(full_img, cv.COLOR_BGR2LAB) + l_channel, a_channel, b_channel = cv.split(lab_image) + l_stack = np.stack([l_channel, l_channel, l_channel], axis=2) + outputs = self.predictor(l_stack) + + # get the most confident boxes + pred_bbox = outputs['instances'].pred_boxes.to( + torch.device('cpu')).tensor.numpy() + pred_scores = outputs['instances'].scores.cpu().data.numpy() + pred_bbox = pred_bbox.astype(np.int32) + if self.box_num_upbound > 0 and pred_bbox.shape[ + 0] > self.box_num_upbound: + index_mask = np.argsort(pred_scores, axis=0) + index_mask = index_mask[pred_scores.shape[0] - + self.box_num_upbound:pred_scores.shape[0]] pred_bbox = pred_bbox[index_mask] - return pred_bbox - - @staticmethod - def read_to_pil(out_img): - ''' - return: pillow image object HxWx3 - ''' - out_img = Image.fromarray(out_img) - if len(np.asarray(out_img).shape) == 2: - out_img = np.stack([ - np.asarray(out_img), - np.asarray(out_img), - np.asarray(out_img) - ], 2) - out_img = Image.fromarray(out_img) - return out_img - - def get_box_info(self, pred_bbox, original_shape): - assert len(pred_bbox) == 4 - resize_startx = int(pred_bbox[0] / original_shape[0] * self.final_size) - resize_starty = int(pred_bbox[1] / original_shape[1] * self.final_size) - resize_endx = int(pred_bbox[2] / original_shape[0] * self.final_size) - resize_endy = int(pred_bbox[3] / original_shape[1] * self.final_size) - rh = resize_endx - resize_startx - rw = resize_endy - resize_starty - if rh < 1: - if self.final_size - resize_endx > 1: - resize_endx += 1 - else: - resize_startx -= 1 - rh = 1 - if rw < 1: - if self.final_size - resize_endy > 1: - resize_endy += 1 - else: - resize_starty -= 1 - rw = 1 - L_pad = resize_startx - R_pad = self.final_size - resize_endx - T_pad = resize_starty - B_pad = self.final_size - resize_endy - return [L_pad, R_pad, T_pad, B_pad, rh, rw] - - def fusion(self, results, img, pred_bbox): - if self.stage == 'test': - gray_img = self.read_to_pil(img) - return self.get_instance_info(results, pred_bbox, gray_img) - - if self.stage == 'fusion': - rgb_img, gray_img = results['rgb_img'], results['gray_img'] - return self.get_instance_info(results, pred_bbox, gray_img, - rgb_img) - - def get_instance_info(self, results, pred_bbox, gray_img, rgb_img=None): - - if not rgb_img: - full_gray_list = [self.transforms(gray_img)] - cropped_gray_list = [] - else: - full_rgb_list = [self.transforms(rgb_img)] - full_gray_list = [self.transforms(gray_img)] - cropped_rgb_list = [] - cropped_gray_list = [] - + # get cropped images and box info + cropped_img_list = [] index_list = range(len(pred_bbox)) box_info, box_info_2x, box_info_4x, box_info_8x = np.zeros( (4, len(index_list), 6)) for i in index_list: startx, starty, endx, endy = pred_bbox[i] + cropped_img = full_img[starty:endy, startx:endx, :] + # cv.imwrite(f"crop_{i}.jpg", cropped_img) + cropped_img_list.append(cropped_img) box_info[i] = np.array( - self.get_box_info(pred_bbox[i], gray_img.size, - self.final_size)) + get_box_info(pred_bbox[i], full_img_size, self.final_size)) box_info_2x[i] = np.array( - self.get_box_info(pred_bbox[i], gray_img.size, - self.final_size // 2)) + get_box_info(pred_bbox[i], full_img_size, + self.final_size // 2)) box_info_4x[i] = np.array( - self.get_box_info(pred_bbox[i], gray_img.size, - self.final_size // 4)) + get_box_info(pred_bbox[i], full_img_size, + self.final_size // 4)) box_info_8x[i] = np.array( - self.get_box_info(pred_bbox[i], gray_img.size, - self.final_size // 8)) - cropped_img = self.transforms( - gray_img.crop((startx, starty, endx, endy))) - cropped_gray_list.append(cropped_img) - if rgb_img: - cropped_rgb_list.append( - self.transforms( - rgb_img.crop((startx, starty, endx, endy)))) - cropped_gray_list.append( - self.transforms( - gray_img.crop((startx, starty, endx, endy)))) - - results['full_gray'] = torch.stack(full_gray_list) - if rgb_img: - results['full_rgb'] = torch.stack(full_rgb_list) + get_box_info(pred_bbox[i], full_img_size, + self.final_size // 8)) + # update results if len(pred_bbox) > 0: - results['cropped_gray'] = torch.stack(cropped_gray_list) - if rgb_img: - results['cropped_rgb'] = torch.stack(cropped_rgb_list) + results['cropped_img'] = cropped_img_list results['box_info'] = torch.from_numpy(box_info).type(torch.long) results['box_info_2x'] = torch.from_numpy(box_info_2x).type( torch.long) @@ -168,46 +93,31 @@ def get_instance_info(self, results, pred_bbox, gray_img, rgb_img=None): results['empty_box'] = False else: results['empty_box'] = True - - return results - - def train(self, results, pred_bbox): - - rgb_img, gray_img = results['rgb_img'], results['gray_img'] - index_list = range(len(pred_bbox)) - index_list = sample(index_list, 1) - startx, starty, endx, endy = pred_bbox[index_list[0]] - - results['rgb_img'] = self.transforms( - rgb_img.crop((startx, starty, endx, endy))) - results['gray_img'] = self.transforms( - gray_img.crop((startx, starty, endx, endy))) - return results - def __call__(self, results): - img = results['img'] - if 'bbox_path' in results.keys(): - pred_bbox = self.gen_maskrcnn_bbox_fromPred( - img, results['bbox_path']) - elif 'instance' in results.keys(): - pred_bbox = results['instance'] +def get_box_info(pred_bbox, original_shape, final_size): + assert len(pred_bbox) == 4 + resize_startx = int(pred_bbox[0] / original_shape[0] * final_size) + resize_starty = int(pred_bbox[1] / original_shape[1] * final_size) + resize_endx = int(pred_bbox[2] / original_shape[0] * final_size) + resize_endy = int(pred_bbox[3] / original_shape[1] * final_size) + rh = resize_endx - resize_startx + rw = resize_endy - resize_starty + if rh < 1: + if final_size - resize_endx > 1: + resize_endx += 1 else: - pred_bbox = self.gen_maskrcnn_bbox_fromPred(img) - - if self.stage == 'fusion' or self.stage == 'test': - results = self.fusion(results, img, pred_bbox) - - if self.stage == 'full' or self.stage == 'instance': - results = self.train(results, pred_bbox) - - return results - - def detectron(self): - cfg = get_cfg() - cfg.merge_from_file(model_zoo.get_config_file(self.config_file)) - cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 - cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(self.config_file) - predictor = DefaultPredictor(cfg) - return predictor + resize_startx -= 1 + rh = 1 + if rw < 1: + if final_size - resize_endy > 1: + resize_endy += 1 + else: + resize_starty -= 1 + rw = 1 + L_pad = resize_startx + R_pad = final_size - resize_endx + T_pad = resize_starty + B_pad = final_size - resize_endy + return [L_pad, R_pad, T_pad, B_pad, rh, rw] diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py index 2295de631e..23ebbf91b5 100644 --- a/mmedit/models/editors/inst_colorization/inst_colorization.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from collections import OrderedDict from typing import Dict, List, Optional, Union import torch @@ -150,107 +149,40 @@ def forward(self, elif mode == 'loss': return self.forward_train(inputs, data_samples, **kwargs) - def convert_to_datasample(self, inputs, data_samples): - for data_sample, output in zip(inputs, data_samples): - data_sample.output = output - return inputs - def forward_train(self, inputs, data_samples=None, **kwargs): - """Forward training. Returns dict of losses of training. - - Args: - inputs (torch.Tensor): batch input tensor collated by - :attr:`data_preprocessor`. - data_samples (List[BaseDataElement], optional): - data samples collated by :attr:`data_preprocessor`. - - Returns: - dict: Dict of losses. - """ - - feats = self.forward_tensor(inputs, data_samples, **kwargs) - gt_imgs = [data_sample.gt_img.data for data_sample in data_samples] - batch_gt_data = torch.stack(gt_imgs) - - loss = self.pixel_loss(feats, batch_gt_data) - - return dict(loss=loss) + raise NotImplementedError( + 'Instance Colorization has not supported training.') def train_step(self, data: List[dict], optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]: + raise NotImplementedError( + 'Instance Colorization has not supported training.') - data = self.data_preprocessor(data, True) - data_batch, data_samples = data['inputs'], data['data_samples'] - - log_vars = {} - - if self.stage == 'full' or self.stage == 'instance': - rgb_img = [data_samples.rgb_img] - gray_img = [data_samples.gray_img] - - input_data = get_colorization_data(gray_img, - **self.colorization_data_opt) - - gt_data = get_colorization_data(rgb_img, - **self.colorization_data_opt) - - input_data['B'] = gt_data['B'] - input_data['hint_B'] = gt_data['hint_B'] - input_data['mask_B'] = gt_data['mask_B'] - self.set_input(input_data) - self.fake_B_reg = self.generator(self.real_A, self.hint_B, - self.mask_B) - - elif self.stage == 'fusion': - box_info = data_samples.box_info - box_info_2x = data_samples.box_info_2x - box_info_4x = data_samples.box_info_4x - box_info_8x = data_samples.box_info_8x - - cropped_input_data = get_colorization_data( - data_samples.cropped_gray, **self.colorization_data_opt) - cropped_gt_data = get_colorization_data( - data_samples.cropped_rgb, **self.colorization_data_opt) - full_input_data = get_colorization_data( - data_samples.full_gray, **self.colorization_data_opt) - full_gt_data = get_colorization_data(data_samples.full_rgb, - **self.colorization_data_opt) - - cropped_input_data['B'] = cropped_gt_data['B'] - full_input_data['B'] = full_gt_data['B'] - - self.set_input(cropped_input_data) - self.set_fusion_input( - full_input_data, - [box_info, box_info_2x, box_info_4x, box_info_8x]) - - self.fake_B_reg = self.generator(self.real_A, self.hint_B, - self.mask_B, self.full_real_A, - self.full_hint_B, - self.full_mask_B, - self.box_info_list) - - optim_wrapper['generator'].zero_grad() - - loss = self.generator_loss() - - loss_d, log_vars_d = self.parse_losses(loss) - log_vars.update(log_vars_d) - - loss_d.backward() - - optim_wrapper['generator'].step() + def forward_inference(self, inputs, data_samples=None, **kwargs): + feats = self.forward_tensor(inputs, data_samples, **kwargs) + predictions = [] + for idx in range(feats.shape[0]): + batch_tensor = feats[idx] * 127.5 + 127.5 + pred_img = PixelData(data=batch_tensor.to('cpu')) + predictions.append( + EditDataSample( + pred_img=pred_img, metainfo=data_samples[idx].metainfo)) - results = self.get_current_visuals() + return predictions - output = dict( - log_vars=log_vars, - num_samples=len(data_batch['rgb_img']), - results=results) + def forward_tensor(self, inputs, data_samples): + """Forward function in tensor mode. - return output + Args: + inputs (torch.Tensor): Input tensor. + data_sample (dict): Dict contains data sample. - def forward_tensor(self, inputs, data_samples, **kwargs): + Returns: + dict: Dict contains output results. + """ + print(data_samples) + for dp in data_samples: + print(dp.keys()) data = data_samples[0] full_img = data.full_gray @@ -285,95 +217,7 @@ def forward_tensor(self, inputs, data_samples, **kwargs): return out_img - def forward_inference(self, inputs, data_samples=None, **kwargs): - feats = self.forward_tensor(inputs, data_samples, **kwargs) - predictions = [] - for idx in range(feats.shape[0]): - batch_tensor = feats[idx] * 127.5 + 127.5 - pred_img = PixelData(data=batch_tensor.to('cpu')) - predictions.append( - EditDataSample( - pred_img=pred_img, metainfo=data_samples[idx].metainfo)) - - return predictions - - def get_current_visuals(self): - - visual_ret = OrderedDict() - - if self.stage == 'full' or self.stage == 'instance': - - visual_ret['gray'] = lab2rgb( - torch.cat((self.real_A.type( - torch.cuda.FloatTensor), torch.zeros_like( - self.real_B).type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) - visual_ret['real'] = lab2rgb( - torch.cat((self.real_A.type(torch.cuda.FloatTensor), - self.real_B.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) - visual_ret['fake_reg'] = lab2rgb( - torch.cat((self.real_A.type(torch.cuda.FloatTensor), - self.fake_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) - - visual_ret['hint'] = lab2rgb( - torch.cat((self.real_A.type(torch.cuda.FloatTensor), - self.hint_B.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) - visual_ret['real_ab'] = lab2rgb( - torch.cat((torch.zeros_like( - self.real_A.type(torch.cuda.FloatTensor)), - self.real_B.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) - visual_ret['fake_ab_reg'] = lab2rgb( - torch.cat((torch.zeros_like( - self.real_A.type(torch.cuda.FloatTensor)), - self.fake_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) - - elif self.stage == 'fusion': - visual_ret['gray'] = lab2rgb( - torch.cat((self.full_real_A.type( - torch.cuda.FloatTensor), torch.zeros_like( - self.full_real_B).type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) - visual_ret['real'] = lab2rgb( - torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), - self.full_real_B.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) - visual_ret['comp_reg'] = lab2rgb( - torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), - self.comp_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) - visual_ret['fake_reg'] = lab2rgb( - torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), - self.fake_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) - - self.instance_mask = torch.nn.functional.interpolate( - torch.zeros([1, 1, 176, 176]), - size=visual_ret['gray'].shape[2:], - mode='bilinear').type(torch.cuda.FloatTensor) - visual_ret['box_mask'] = torch.cat( - (self.instance_mask, self.instance_mask, self.instance_mask), - 1) - visual_ret['real_ab'] = lab2rgb( - torch.cat((torch.zeros_like( - self.full_real_A.type(torch.cuda.FloatTensor)), - self.full_real_B.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) - visual_ret['comp_ab_reg'] = lab2rgb( - torch.cat((torch.zeros_like( - self.full_real_A.type(torch.cuda.FloatTensor)), - self.comp_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) - visual_ret['fake_ab_reg'] = lab2rgb( - torch.cat((torch.zeros_like( - self.full_real_A.type(torch.cuda.FloatTensor)), - self.fake_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt) - else: - print('Error! Wrong stage selection!') - exit() - return visual_ret + def convert_to_datasample(self, inputs, data_samples): + for data_sample, output in zip(inputs, data_samples): + data_sample.output = output + return inputs From a2fc3ceaa25119ca0ab3078c4b023bec8d445a56 Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Wed, 26 Oct 2022 20:30:50 +0800 Subject: [PATCH 15/32] refactor packedit --- mmedit/datasets/transforms/formatting.py | 5 ++ mmedit/models/utils/color_utils.py | 104 +++++++++++------------ 2 files changed, 57 insertions(+), 52 deletions(-) diff --git a/mmedit/datasets/transforms/formatting.py b/mmedit/datasets/transforms/formatting.py index 3a73adc643..cc5aafaa24 100644 --- a/mmedit/datasets/transforms/formatting.py +++ b/mmedit/datasets/transforms/formatting.py @@ -211,6 +211,11 @@ def transform(self, results: dict) -> dict: gray_tensor = images_to_tensor(gray) data_sample.gray = PixelData(data=gray_tensor) + if 'cropped_img' in results: + cropped_img = results.pop('cropped_img') + cropped_img = images_to_tensor(cropped_img) + data_sample.cropped_img = PixelData(data=cropped_img) + metainfo = dict() for key in results: metainfo[key] = results[key] diff --git a/mmedit/models/utils/color_utils.py b/mmedit/models/utils/color_utils.py index 8430994f24..21f613c546 100644 --- a/mmedit/models/utils/color_utils.py +++ b/mmedit/models/utils/color_utils.py @@ -58,34 +58,6 @@ def xyz2rgb(xyz): return rgb -def xyz2lab(xyz): - # 0.95047, 1., 1.08883 # white - sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None] - if (xyz.is_cuda): - sc = sc.cuda() - - xyz_scale = xyz / sc - - mask = (xyz_scale > .008856).type(torch.FloatTensor) - if (xyz_scale.is_cuda): - mask = mask.cuda() - - xyz_int = xyz_scale**(1 / 3.) * mask + (7.787 * xyz_scale + - 16. / 116.) * (1 - mask) - - L = 116. * xyz_int[:, 1, :, :] - 16. - a = 500. * (xyz_int[:, 0, :, :] - xyz_int[:, 1, :, :]) - b = 200. * (xyz_int[:, 1, :, :] - xyz_int[:, 2, :, :]) - out = torch.cat((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]), - dim=1) - - # if(torch.sum(torch.isnan(out))>0): - # print('xyz2lab') - # embed() - - return out - - def lab2xyz(lab): y_int = (lab[:, 0, :, :] + 16.) / 116. x_int = (lab[:, 1, :, :] / 500.) + y_int @@ -116,6 +88,58 @@ def lab2xyz(lab): return out +def lab2rgb(lab_rs, **kwargs): + L = lab_rs[:, [0], :, :] * kwargs['l_norm'] + kwargs['l_cent'] + AB = lab_rs[:, 1:, :, :] * kwargs['ab_norm'] + lab = torch.cat((L, AB), dim=1) + out = xyz2rgb(lab2xyz(lab)) + # if(torch.sum(torch.isnan(out))>0): + # print('lab2rgb') + # embed() + return out + + +def encode_ab_ind(data_ab, **kwargs): + # Encode ab value into an index + # INPUTS + # data_ab Nx2xHxW \in [-1,1] + # OUTPUTS + # data_q Nx1xHxW \in [0,Q) + A = 2 * kwargs['ab_max'] / kwargs['ab_quant'] + 1 + data_ab_rs = torch.round((data_ab * kwargs['ab_norm'] + kwargs['ab_max']) / + kwargs['ab_quant']) # normalized bin number + data_q = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :] + return data_q + + +def xyz2lab(xyz): + # 0.95047, 1., 1.08883 # white + sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None] + if (xyz.is_cuda): + sc = sc.cuda() + + xyz_scale = xyz / sc + + mask = (xyz_scale > .008856).type(torch.FloatTensor) + if (xyz_scale.is_cuda): + mask = mask.cuda() + + xyz_int = xyz_scale**(1 / 3.) * mask + (7.787 * xyz_scale + + 16. / 116.) * (1 - mask) + + L = 116. * xyz_int[:, 1, :, :] - 16. + a = 500. * (xyz_int[:, 0, :, :] - xyz_int[:, 1, :, :]) + b = 200. * (xyz_int[:, 1, :, :] - xyz_int[:, 2, :, :]) + out = torch.cat((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]), + dim=1) + + # if(torch.sum(torch.isnan(out))>0): + # print('xyz2lab') + # embed() + + return out + + def rgb2lab(rgb, **kwargs): lab = xyz2lab(rgb2xyz(rgb)) # print(lab[0, 0, 0, 0]) @@ -130,17 +154,6 @@ def rgb2lab(rgb, **kwargs): return out -def lab2rgb(lab_rs, **kwargs): - L = lab_rs[:, [0], :, :] * kwargs['l_norm'] + kwargs['l_cent'] - AB = lab_rs[:, 1:, :, :] * kwargs['ab_norm'] - lab = torch.cat((L, AB), dim=1) - out = xyz2rgb(lab2xyz(lab)) - # if(torch.sum(torch.isnan(out))>0): - # print('lab2rgb') - # embed() - return out - - def get_colorization_data(data_raw, ab_thresh=5., p=.125, @@ -240,16 +253,3 @@ def add_color_patches_rand_gt(data, data['mask_B'] -= kwargs['mask_cent'] return data - - -def encode_ab_ind(data_ab, **kwargs): - # Encode ab value into an index - # INPUTS - # data_ab Nx2xHxW \in [-1,1] - # OUTPUTS - # data_q Nx1xHxW \in [0,Q) - A = 2 * kwargs['ab_max'] / kwargs['ab_quant'] + 1 - data_ab_rs = torch.round((data_ab * kwargs['ab_norm'] + kwargs['ab_max']) / - kwargs['ab_quant']) # normalized bin number - data_q = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :] - return data_q From ed48fd86cb0e4d9b5a500fba1332c79237aa1494 Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Thu, 27 Oct 2022 15:23:18 +0800 Subject: [PATCH 16/32] fix color rendering --- .../inst-colorizatioon_cocostuff_256x256.py | 37 ++-- mmedit/apis/colorization_inference.py | 21 ++- .../inst_colorization}/color_utils.py | 140 ++++++--------- .../inst_colorization/inst_colorization.py | 166 ++++++++---------- mmedit/models/utils/__init__.py | 3 +- 5 files changed, 173 insertions(+), 194 deletions(-) rename mmedit/models/{utils => editors/inst_colorization}/color_utils.py (73%) diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py index 9a872442f9..9cc8dd55c0 100644 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py @@ -13,21 +13,32 @@ mean=[127.5], std=[127.5], ), - full_model=dict( - type='FusionNet', input_nc=4, output_nc=2, norm_type='batch'), + image_model=dict( + type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'), instance_model=dict( type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'), - stage=stage, - ngf=64, - output_nc=2, - avg_loss_alpha=.986, - ab_norm=110., - ab_max=110., - ab_quant=10., - l_norm=100., - l_cent=50., - sample_Ps=[1, 2, 3, 4, 5, 6, 7, 8, 9], - mask_cent=.5, + fusion_model=dict( + type='FusionNet', input_nc=4, output_nc=2, norm_type='batch'), + color_data_opt=dict( + ab_thresh=0, + p=1.0, + sample_PS=[ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + ], + ab_norm=110, + ab_max=110., + ab_quant=10., + l_norm=100., + l_cent=50., + mask_cent=0.5), which_direction='AtoB', loss=dict(type='HuberLoss', delta=.01)) diff --git a/mmedit/apis/colorization_inference.py b/mmedit/apis/colorization_inference.py index 352d32f6ce..c2271dd5b8 100644 --- a/mmedit/apis/colorization_inference.py +++ b/mmedit/apis/colorization_inference.py @@ -13,11 +13,28 @@ def colorization_inference(model, img): test_pipeline = Compose(model.cfg.test_pipeline) # prepare data data = dict(img_path=img) - data = test_pipeline(data) + _data = test_pipeline(data) + data = dict() + data['inputs'] = _data['inputs'] / 255.0 data = collate([data]) - + data['data_samples'] = [_data['data_samples']] if 'cuda' in str(device): data = scatter(data, [device])[0] + data['data_samples'][0].cropped_img.data = scatter( + data['data_samples'][0].cropped_img.data, [device])[0] / 255.0 + + data['data_samples'][0].box_info.data = scatter( + data['data_samples'][0].box_info.data, [device])[0] + + data['data_samples'][0].box_info_2x.data = scatter( + data['data_samples'][0].box_info_2x.data, [device])[0] + + data['data_samples'][0].box_info_4x.data = scatter( + data['data_samples'][0].box_info_4x.data, [device])[0] + + data['data_samples'][0].box_info_8x.data = scatter( + data['data_samples'][0].box_info_8x.data, [device])[0] + # forward the model with torch.no_grad(): result = model(mode='tensor', **data) diff --git a/mmedit/models/utils/color_utils.py b/mmedit/models/editors/inst_colorization/color_utils.py similarity index 73% rename from mmedit/models/utils/color_utils.py rename to mmedit/models/editors/inst_colorization/color_utils.py index 21f613c546..1ea8464715 100644 --- a/mmedit/models/utils/color_utils.py +++ b/mmedit/models/editors/inst_colorization/color_utils.py @@ -3,38 +3,7 @@ import torch -# Color conversion code -def rgb2xyz(rgb): # rgb from [0,1] - # xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], - # [0.212671, 0.715160, 0.072169], - # [0.019334, 0.119193, 0.950227]]) - - mask = (rgb > .04045).type(torch.FloatTensor) - if (rgb.is_cuda): - mask = mask.cuda() - - rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask) - - x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] \ - + .180423 * rgb[:, 2, :, :] - y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] \ - + .072169 * rgb[:, 2, :, :] - z = .019334 * rgb[:, 0, :, :] + .119193 * rgb[:, 1, :, :] \ - + .950227 * rgb[:, 2, :, :] - out = torch.cat((x[:, None, :, :], y[:, None, :, :], z[:, None, :, :]), - dim=1) - - # if(torch.sum(torch.isnan(out))>0): - # print('rgb2xyz') - # embed() - return out - - def xyz2rgb(xyz): - # array([[ 3.24048134, -1.53715152, -0.49853633], - # [-0.96925495, 1.87599 , 0.04155593], - # [ 0.05564664, -0.20404134, 1.05731107]]) - r = 3.24048134 * xyz[:, 0, :, :] - 1.53715152 * xyz[:, 1, :, :] \ - 0.49853633 * xyz[:, 2, :, :] g = -0.96925495 * xyz[:, 0, :, :] + 1.87599 * xyz[:, 1, :, :] \ @@ -42,19 +11,16 @@ def xyz2rgb(xyz): b = .05564664 * xyz[:, 0, :, :] - .20404134 * xyz[:, 1, :, :] \ + 1.05731107 * xyz[:, 2, :, :] + # sometimes reaches a small negative number, which causes NaNs rgb = torch.cat((r[:, None, :, :], g[:, None, :, :], b[:, None, :, :]), dim=1) rgb = torch.max(rgb, torch.zeros_like(rgb)) - # sometimes reaches a small negative number, which causes NaNs mask = (rgb > .0031308).type(torch.FloatTensor) if rgb.is_cuda: mask = mask.cuda() rgb = (1.055 * (rgb**(1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask) - # if(torch.sum(torch.isnan(rgb))>0): - # print('xyz2rgb') - # embed() return rgb @@ -80,38 +46,56 @@ def lab2xyz(lab): sc = sc.to(out.device) out = out * sc - - # if(torch.sum(torch.isnan(out))>0): - # print('lab2xyz') - # embed() - return out -def lab2rgb(lab_rs, **kwargs): - L = lab_rs[:, [0], :, :] * kwargs['l_norm'] + kwargs['l_cent'] - AB = lab_rs[:, 1:, :, :] * kwargs['ab_norm'] +def lab2rgb(lab_rs, color_data_opt): + L = lab_rs[:, + [0], :, :] * color_data_opt['l_norm'] + color_data_opt['l_cent'] + AB = lab_rs[:, 1:, :, :] * color_data_opt['ab_norm'] lab = torch.cat((L, AB), dim=1) out = xyz2rgb(lab2xyz(lab)) - # if(torch.sum(torch.isnan(out))>0): - # print('lab2rgb') - # embed() return out -def encode_ab_ind(data_ab, **kwargs): +def encode_ab_ind(data_ab, color_data_opt): # Encode ab value into an index # INPUTS # data_ab Nx2xHxW \in [-1,1] # OUTPUTS # data_q Nx1xHxW \in [0,Q) - A = 2 * kwargs['ab_max'] / kwargs['ab_quant'] + 1 - data_ab_rs = torch.round((data_ab * kwargs['ab_norm'] + kwargs['ab_max']) / - kwargs['ab_quant']) # normalized bin number + A = 2 * color_data_opt['ab_max'] / color_data_opt['ab_quant'] + 1 + data_ab_rs = torch.round( + (data_ab * color_data_opt['ab_norm'] + color_data_opt['ab_max']) / + color_data_opt['ab_quant']) # normalized bin number data_q = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :] return data_q +# Color conversion code +def rgb2xyz(rgb): # rgb from [0,1] + # xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], + # [0.212671, 0.715160, 0.072169], + # [0.019334, 0.119193, 0.950227]]) + + mask = (rgb > .04045).type(torch.FloatTensor) + if (rgb.is_cuda): + mask = mask.cuda() + + rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask) + + x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] \ + + .180423 * rgb[:, 2, :, :] + y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] \ + + .072169 * rgb[:, 2, :, :] + z = .019334 * rgb[:, 0, :, :] + .119193 * rgb[:, 1, :, :] \ + + .950227 * rgb[:, 2, :, :] + out = torch.cat((x[:, None, :, :], y[:, None, :, :], z[:, None, :, :]), + dim=1) + + return out + + def xyz2lab(xyz): # 0.95047, 1., 1.08883 # white sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None] @@ -133,42 +117,28 @@ def xyz2lab(xyz): out = torch.cat((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]), dim=1) - # if(torch.sum(torch.isnan(out))>0): - # print('xyz2lab') - # embed() - return out -def rgb2lab(rgb, **kwargs): +def rgb2lab(rgb, color_opt): lab = xyz2lab(rgb2xyz(rgb)) - # print(lab[0, 0, 0, 0]) - # lab_0 = lab[:, [0], :, :] - l_rs = (lab[:, [0], :, :] - kwargs['l_cent']) / kwargs['l_norm'] - # print(l_rs[0, 0, 0, 0]) - ab_rs = lab[:, 1:, :, :] / kwargs['ab_norm'] + l_rs = (lab[:, [0], :, :] - color_opt['l_cent']) / color_opt['l_norm'] + ab_rs = lab[:, 1:, :, :] / color_opt['ab_norm'] out = torch.cat((l_rs, ab_rs), dim=1) - # if(torch.sum(torch.isnan(out))>0): - # print('rgb2lab') - # embed() return out -def get_colorization_data(data_raw, - ab_thresh=5., - p=.125, - num_points=None, - **kwargs): +def get_colorization_data(data_raw, color_opt, num_points=None): data = {} - - data_lab = rgb2lab(data_raw, **kwargs) + data_lab = rgb2lab(data_raw, color_opt) data['A'] = data_lab[:, [ 0, ], :, :] data['B'] = data_lab[:, 1:, :, :] - if ab_thresh > 0: # mask out grayscale images - thresh = 1. * ab_thresh / kwargs['ab_norm'] + # mask out grayscale images + if color_opt['ab_thresh'] > 0: + thresh = 1. * color_opt['ab_thresh'] / color_opt['ab_norm'] mask = torch.sum( torch.abs( torch.max(torch.max(data['B'], dim=3)[0], dim=2)[0] - @@ -176,20 +146,19 @@ def get_colorization_data(data_raw, dim=1) >= thresh data['A'] = data['A'][mask, :, :, :] data['B'] = data['B'][mask, :, :, :] - # print('Removed %i points'%torch.sum(mask==0).numpy()) if torch.sum(mask) == 0: return None return add_color_patches_rand_gt( - data, p=p, num_points=num_points, **kwargs) + data, color_opt, p=color_opt['p'], num_points=num_points) def add_color_patches_rand_gt(data, + color_opt, p=.125, num_points=None, use_avg=True, - samp='normal', - **kwargs): + samp='normal'): # Add random color points sampled from ground truth based on: # Number of points # - if num_points is 0, then sample from geometric distribution, @@ -207,18 +176,21 @@ def add_color_patches_rand_gt(data, pp = 0 cont_cond = True while cont_cond: - if num_points is None: # draw from geometric - # embed() + # draw from geometric + if num_points is None: cont_cond = np.random.rand() < (1 - p) - else: # add certain number of points + else: + # add certain number of points cont_cond = pp < num_points - if not cont_cond: # skip out of loop if condition not met + # skip out of loop if condition not met + if not cont_cond: continue - P = np.random.choice(kwargs['sample_PS']) # patch size + # patch size + P = np.random.choice(color_opt['sample_PS']) - # sample location - if samp == 'normal': # geometric distribution + # sample location: geometric distribution + if samp == 'normal': h = int( np.clip( np.random.normal((H - P + 1) / 2., (H - P + 1) / 4.), @@ -250,6 +222,6 @@ def add_color_patches_rand_gt(data, # increment counter pp += 1 - data['mask_B'] -= kwargs['mask_cent'] + data['mask_B'] -= color_opt['mask_cent'] return data diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py index 23ebbf91b5..c737e9923f 100644 --- a/mmedit/models/editors/inst_colorization/inst_colorization.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -6,9 +6,9 @@ from mmengine.model import BaseModel from mmengine.optim import OptimWrapperDict -from mmedit.models.utils import get_colorization_data, lab2rgb from mmedit.registry import MODULES from mmedit.structures import EditDataSample, PixelData +from .color_utils import get_colorization_data, lab2rgb @MODULES.register_module() @@ -16,19 +16,10 @@ class InstColorization(BaseModel): def __init__(self, data_preprocessor: Union[dict, Config], - full_model, + image_model, instance_model, - stage, - ngf, - output_nc, - avg_loss_alpha, - ab_norm, - ab_max, - ab_quant, - l_norm, - l_cent, - sample_Ps, - mask_cent, + fusion_model, + color_data_opt, which_direction='AtoB', loss=None, init_cfg=None, @@ -39,50 +30,20 @@ def __init__(self, init_cfg=init_cfg, data_preprocessor=data_preprocessor) # colorization networks - # Stage 1 & 3. fusion model intergrates the image model - self.full_model = MODULES.build(full_model) + # image_model: used to colorize a single image + self.image_model = MODULES.build(image_model) - # Stage 2. instance model used for training instance colorization + # instance model: used to colorize cropped instance self.instance_model = MODULES.build(instance_model) - self.stage = stage - - # self.train_cfg = train_cfg - # self.test_cfg = test_cfg - # self.ngf = ngf - # self.output_nc = output_nc - # self.avg_loss_alpha = avg_loss_alpha - # self.mask_cent = mask_cent - # self.which_direction = which_direction - - # self.encode_ab_opt = dict( - # ab_norm=ab_norm, ab_max=ab_max, ab_quant=ab_quant) - # self.colorization_data_opt = dict( - # ab_thresh=0, - # ab_norm=ab_norm, - # l_norm=l_norm, - # l_cent=l_cent, - # sample_PS=sample_Ps, - # mask_cent=mask_cent, - # ) - # self.lab2rgb_opt = dict( - # ab_norm=ab_norm, l_norm=l_norm, l_cent=l_cent) - # self.convert_params = dict( - # ab_thresh=0, - # ab_norm=ab_norm, - # l_norm=l_norm, - # l_cent=l_cent, - # sample_PS=sample_Ps, - # mask_cent=mask_cent, - # ) - - # # loss - # self.loss_names = ['G', 'L1'] - # self.criterionL1 = self.loss - # self.avg_losses = OrderedDict() - # self.error_cnt = 0 - # for loss_name in self.loss_names: - # self.avg_losses[loss_name] = 0 + # fusion model: input a single image with related instance features + self.fusion_model = MODULES.build(fusion_model) + + self.color_data_opt = color_data_opt + self.which_direction = which_direction + + self.train_cfg = train_cfg + self.test_cfg = test_cfg def forward(self, inputs: torch.Tensor, @@ -149,6 +110,11 @@ def forward(self, elif mode == 'loss': return self.forward_train(inputs, data_samples, **kwargs) + def convert_to_datasample(self, inputs, data_samples): + for data_sample, output in zip(inputs, data_samples): + data_sample.output = output + return inputs + def forward_train(self, inputs, data_samples=None, **kwargs): raise NotImplementedError( 'Instance Colorization has not supported training.') @@ -180,44 +146,58 @@ def forward_tensor(self, inputs, data_samples): Returns: dict: Dict contains output results. """ - print(data_samples) - for dp in data_samples: - print(dp.keys()) - - data = data_samples[0] - full_img = data.full_gray - - if not data.empty_box: - cropped_img = data.cropped_gray - box_info = data.box_info - box_info_2x = data.box_info_2x - box_info_4x = data.box_info_4x - box_info_8x = data.box_info_8x - cropped_data = get_colorization_data(cropped_img, - **self.convert_params) - full_img_data = get_colorization_data(full_img, - **self.convert_params) - self.set_input(cropped_data) - self.set_fusion_input( - full_img_data, - [box_info, box_info_2x, box_info_4x, box_info_8x]) - else: - full_img_data = get_colorization_data(full_img, ab_thresh=0) - self.set_forward_without_box(full_img_data) - self.fake_B_reg = self.generator(self.real_A, self.hint_B, self.mask_B, - self.full_real_A, self.full_hint_B, - self.full_mask_B, self.box_info_list) + # prepare data + + assert len(data_samples), \ + 'fusion model supports only one image due to different numbers '\ + 'of instances of different images' + + cropped_img = data_samples[0].cropped_img.data + box_info_list = [ + data_samples[0].box_info, data_samples[0].box_info_2x, + data_samples[0].box_info_4x, data_samples[0].box_info_8x + ] + print('crop: ', torch.min(cropped_img), torch.max(cropped_img)) + print('full: ', torch.min(inputs), torch.max(inputs)) + cropped_data = get_colorization_data(cropped_img, self.color_data_opt) + full_img_data = get_colorization_data(inputs, self.color_data_opt) + AtoB = self.which_direction == 'AtoB' + + # preprocess input for a single image + full_real_A = full_img_data['A' if AtoB else 'B'] + # full_real_B = full_img_data['B' if AtoB else 'A'] + full_hint_B = full_img_data['hint_B'] + full_mask_B = full_img_data['mask_B'] + # full_mask_B_nc = full_mask_B + self.color_data_opt['mask_cent'] + # full_real_B_enc = encode_ab_ind(full_real_B[:, :, ::4, ::4], + # self.color_data_opt) + + if not data_samples[0].empty_box: + # preprocess instance input + real_A = cropped_data['A' if AtoB else 'B'] + # real_B = cropped_data['B' if AtoB else 'A'] + hint_B = cropped_data['hint_B'] + mask_B = cropped_data['mask_B'] + # mask_B_nc = mask_B + self.color_data_opt['mask_cent'] + # real_B_enc = encode_ab_ind(real_B[:, :, ::4, ::4], + # self.color_data_opt) + + # network forward + _, output, feature_map = self.instance_model( + real_A, hint_B, mask_B) + output = self.fusion_model(full_real_A, full_hint_B, full_mask_B, + feature_map, box_info_list) - out_img = torch.clamp( - lab2rgb( - torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), - self.fake_B_reg.type(torch.cuda.FloatTensor)), - dim=1), **self.lab2rgb_opt), 0.0, 1.0) - - return out_img - - def convert_to_datasample(self, inputs, data_samples): - for data_sample, output in zip(inputs, data_samples): - data_sample.output = output - return inputs + else: + _, output, _ = self.image_model(full_real_A, full_hint_B, + full_mask_B) + + output = [ + full_real_A.type(torch.cuda.FloatTensor), + output.type(torch.cuda.FloatTensor) + ] + output = torch.cat(output, dim=1) + print('output: ', torch.min(output), torch.max(output)) + output = torch.clamp(lab2rgb(output, self.color_data_opt), 0.0, 1.0) + return output diff --git a/mmedit/models/utils/__init__.py b/mmedit/models/utils/__init__.py index 4f6c6e1baa..6d98db512e 100644 --- a/mmedit/models/utils/__init__.py +++ b/mmedit/models/utils/__init__.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .bbox_utils import extract_around_bbox, extract_bbox_patch -from .color_utils import encode_ab_ind, get_colorization_data, lab2rgb from .flow_warp import flow_warp from .model_utils import (default_init_weights, generation_init_weights, get_module_device, get_valid_noise_size, @@ -14,5 +13,5 @@ 'generation_init_weights', 'set_requires_grad', 'extract_bbox_patch', 'extract_around_bbox', 'get_unknown_tensor', 'noise_sample_fn', 'label_sample_fn', 'get_valid_num_batches', 'get_valid_noise_size', - 'get_module_device', 'encode_ab_ind', 'get_colorization_data', 'lab2rgb' + 'get_module_device' ] From e61cc99937b54e14d7be7763dbefffda540c011e Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Thu, 27 Oct 2022 15:35:24 +0800 Subject: [PATCH 17/32] fix inference --- ...-colorizatioon_cocostuff-fusion_256x256.py | 0 ...t-colorizatioon_cocostuff-image_256x256.py | 134 ------------------ ...olorizatioon_cocostuff-instance_256x256.py | 94 ------------ .../inst-colorizatioon_cocostuff_256x256.py | 2 +- demo/colorization_demo.py | 2 +- .../inst_colorization/inst_colorization.py | 1 + 6 files changed, 3 insertions(+), 230 deletions(-) delete mode 100644 configs/inst_colorization/inst-colorizatioon_cocostuff-fusion_256x256.py delete mode 100644 configs/inst_colorization/inst-colorizatioon_cocostuff-image_256x256.py delete mode 100644 configs/inst_colorization/inst-colorizatioon_cocostuff-instance_256x256.py diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff-fusion_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff-fusion_256x256.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff-image_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff-image_256x256.py deleted file mode 100644 index 88f7b65206..0000000000 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff-image_256x256.py +++ /dev/null @@ -1,134 +0,0 @@ -_base_ = ['../_base_/default_runtime.py'] - -exp_name = 'Instance-aware_full' -save_dir = './' -work_dir = '..' - -stage = 'full' -model = dict( - type='INSTA', - data_preprocessor=dict( - type='EditDataPreprocessor', - mean=[127.5], - std=[127.5], - ), - generator=dict( - type='InstColorizationGenerator', - stage=stage, - instance_model=dict( - type='SIGGRAPHGenerator', - input_nc=4, - output_nc=2, - norm_type='batch'), - ), - insta_stage=stage, - ngf=64, - output_nc=2, - avg_loss_alpha=.986, - ab_norm=110., - ab_max=110., - ab_quant=10., - l_norm=100., - l_cent=50., - sample_Ps=[1, 2, 3, 4, 5, 6, 7, 8, 9], - mask_cent=.5, - which_direction='AtoB', - loss=dict(type='HuberLoss', delta=.01), - pretrained='./checkpoints/pytorch_trained.pth') - -input_shape = (256, 256) - -train_pipeline = [ - dict(type='LoadImageFromFile', key='img'), - dict(type='GenGrayColorPil', stage='full', keys=['rgb_img', 'gray_img']), - dict( - type='Resize', - keys=['rgb_img', 'gray_img'], - scale=input_shape, - keep_ratio=False, - interpolation='nearest'), - dict(type='RescaleToZeroOne', keys=['rgb_img', 'gray_img']), - dict(type='PackEditInputs') -] - -test_pipeline = [ - dict(type='LoadImageFromFile', key='img'), - dict(type='GenMaskRCNNBbox', stage='test_fusion', finesize=256), - dict(type='Resize', keys=['gt'], scale=(256, 256), keep_ratio=False), - dict(type='PackEditInputs'), -] - -dataset_type = 'CocoDataset' -data_root = '/mnt/meng/cocos' -ann_file_path = '/mnt/ruoning/bbox' - -train_dataloader = dict( - batch_size=4, - num_workers=4, - persistent_workers=False, - sampler=dict(shuffle=False), - dataset=dict( - type=dataset_type, - data_root=data_root + '/train2017', - data_prefix=dict(gt='data_large'), - ann_file=f'{ann_file_path}/img_list.txt', - pipeline=train_pipeline, - test_mode=False)) - -test_dataloader = dict( - batch_size=1, - num_workers=1, - persistent_workers=False, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=dict( - type=dataset_type, - data_root=data_root + '/train2017', - data_prefix=dict(gt='data_large'), - ann_file=f'{ann_file_path}/train_annotation.json', - pipeline=test_pipeline, - test_mode=False)) - -test_evaluator = [dict(type='PSNR'), dict(type='SSIM')] - -train_cfg = dict( - type='IterBasedTrainLoop', - max_iters=500002, - val_interval=50000, -) - -val_dataloader = test_dataloader -val_evaluator = test_evaluator - -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') - -# optimizer -optim_wrapper = dict( - constructor='DefaultOptimWrapperConstructor', - generator=dict( - type='OptimWrapper', - optimizer=dict(type='Adam', lr=0.0001, betas=(0.0, 0.9))), - disc=dict( - type='OptimWrapper', - optimizer=dict(type='Adam', lr=0.0001, betas=(0.0, 0.9)))) - -param_scheduler = dict( - # todo engine中暂时还没有这个 - type='LambdaLR', - by_epoch=False, -) - -vis_backends = [dict(type='LocalVisBackend')] - -visualizer = dict( - type='ConcatImageVisualizer', - vis_backends=vis_backends, - fn_key='gt_path', - img_keys=['gray', 'real', 'fake_reg', 'hint', 'real_ab', 'fake_ab_reg'], - bgr2rgb=False) - -env_cfg = dict( - cudnn_benchmark=False, - mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), - dist_cfg=dict(backend='nccl'), -) diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff-instance_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff-instance_256x256.py deleted file mode 100644 index f7712efdf9..0000000000 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff-instance_256x256.py +++ /dev/null @@ -1,94 +0,0 @@ -ab_norm = 110. -model = dict( - type='FusionModel', - stage='instance', - ngf=64, - output_nc=2, - # avg_loss_alpha=.986, - ab_norm=ab_norm, - l_norm=100., - l_cent=50., - sample_Ps=[1, 2, 3, 4, 5, 6, 7, 8, 9], - mask_cent=.5, - init_type='normal', - fusion_weight_path='../checkpoints/coco_finetuned_mask_256_ffs', - which_direction='AtoB', - loss=dict(type='HuberLoss', delta=1. / ab_norm), - instance_model=dict( - type='SIGGRAPHGenerator', - input_nc=4, - output_nc=2, - )) - -train_cfg = dict(disc_step=1) -test_cfg = dict(metrics=['psnr', 'ssim']) -input_shape = (256, 256) - -train_pipeline = [ - dict(type='LoadImageFromFile', key='gt_img'), - dict(type='LoadBboxFromFile', key='instance', stage='instance'), - dict( - type='GenGrayColorPil', stage='instance', keys=['rgb_img', - 'gray_img']), - dict( - type='Resize', - keys=['rgb_img', 'gray_img'], - scale=input_shape, - keep_ratio=False, - interpolation='nearest'), - dict( - type='Collect', - keys=['instance', 'rgb_img', 'gray_img'], - meta_keys=['gt_img_path']), - dict(type='ImageToTensor', keys=['instance', 'rgb_img', 'gray_img']) -] - -dataset_type = 'COCOStuff_Instance_Dataset' -data_root = '/mnt/cache/share_data/zhangwenwei/data/coco/train2017' - -npz_root = '/mnt/cache/yuruoning.vendor/data' - -data = dict( - workers_per_gpu=2, - train_dataloader=dict(samples_per_gpu=1, drop_last=True), - val_dataloader=dict(samples_per_gpu=1), - test_dataloader=dict(samples_per_gpu=1), - train=dict( - type=dataset_type, - ann_file=f'{npz_root}/img_list.txt', - data_prefix=data_root, - npz_prefix=f'{npz_root}/train_bbox/train2017_bbox', - pipeline=train_pipeline, - test_mode=False)) - -optimizers = dict(generator=dict(type='Adam', lr=0.0001, betas=(0.9, 0.999)), ) -lr_config = dict(policy='Fixed', by_epoch=False) - -checkpoint_config = dict(by_epoch=False, interval=10000) - -log_config = dict( - interval=100, - hooks=[ - dict(type='TextLoggerHook', by_epoch=False), - dict(type='TensorboardLoggerHook'), - ]) - -visual_config = dict( - type='VisualizationHook', - output_dir='visual', - interval=100, - bgr2rgb=False, - res_name_list=[ - 'gray', 'real', 'fake_reg', 'hint', 'real_ab', 'fake_ab_reg' - ], -) - -total_iters = 500002 -dist_params = dict(backend='nccl') -load_from = None -resume_from = None -work_dir = '..' -log_level = 'INFO' -workflow = [('train', 10000)] -exp_name = 'Instance-aware' -find_unused_parameters = True diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py index 9cc8dd55c0..4230f08192 100644 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py @@ -44,7 +44,7 @@ # yapf: disable test_pipeline = [ - dict(type='LoadImageFromFile', key='img'), + dict(type='LoadImageFromFile', key='img', channel_order='rgb'), dict( type='InstanceCrop', config_file='COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml', # noqa diff --git a/demo/colorization_demo.py b/demo/colorization_demo.py index 037093c520..926515f637 100644 --- a/demo/colorization_demo.py +++ b/demo/colorization_demo.py @@ -32,7 +32,7 @@ def main(): model = init_model(args.config, args.checkpoints, device=device) output = colorization_inference(model, args.img_path) - result = tensor2img(output)[..., ::-1] + result = tensor2img(output) mmcv.imwrite(result, args.save_path) if args.imshow: diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py index c737e9923f..7462934e4e 100644 --- a/mmedit/models/editors/inst_colorization/inst_colorization.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -158,6 +158,7 @@ def forward_tensor(self, inputs, data_samples): data_samples[0].box_info, data_samples[0].box_info_2x, data_samples[0].box_info_4x, data_samples[0].box_info_8x ] + print(data_samples[0]) print('crop: ', torch.min(cropped_img), torch.max(cropped_img)) print('full: ', torch.min(inputs), torch.max(inputs)) cropped_data = get_colorization_data(cropped_img, self.color_data_opt) From 20d14cb3d03f58c42259058c3218f6bf3c5282b3 Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Thu, 27 Oct 2022 15:54:22 +0800 Subject: [PATCH 18/32] remove undesired files --- mmedit/datasets/__init__.py | 3 +- mmedit/datasets/coco.py | 52 ------------------- mmedit/datasets/transforms/__init__.py | 4 +- .../datasets/transforms/get_gray_color_pil.py | 30 ----------- tests/test_datasets/test_coco.py | 28 ---------- .../test_get_gray_color_pil.py | 14 ----- .../{test_insta.py => test_color_utils.py} | 0 ..._insta_net.py => test_colorization_net.py} | 0 .../{test_util.py => test_fusion_net.py} | 0 .../test_inst_colorization.py | 1 + .../test_weight_layer.py | 1 + 11 files changed, 4 insertions(+), 129 deletions(-) delete mode 100644 mmedit/datasets/coco.py delete mode 100644 mmedit/datasets/transforms/get_gray_color_pil.py delete mode 100644 tests/test_datasets/test_coco.py delete mode 100644 tests/test_datasets/test_transforms/test_get_gray_color_pil.py rename tests/test_models/test_editors/test_inst_colorization/{test_insta.py => test_color_utils.py} (100%) rename tests/test_models/test_editors/test_inst_colorization/{test_insta_net.py => test_colorization_net.py} (100%) rename tests/test_models/test_editors/test_inst_colorization/{test_util.py => test_fusion_net.py} (100%) create mode 100644 tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py create mode 100644 tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py diff --git a/mmedit/datasets/__init__.py b/mmedit/datasets/__init__.py index b8b4ec3289..1ee9b06b1b 100644 --- a/mmedit/datasets/__init__.py +++ b/mmedit/datasets/__init__.py @@ -3,7 +3,6 @@ from .basic_frames_dataset import BasicFramesDataset from .basic_image_dataset import BasicImageDataset from .cifar10_dataset import CIFAR10 -from .coco import CocoDataset from .comp1k_dataset import AdobeComp1kDataset from .grow_scale_image_dataset import GrowScaleImgDataset from .imagenet_dataset import ImageNet @@ -13,5 +12,5 @@ __all__ = [ 'AdobeComp1kDataset', 'BasicImageDataset', 'BasicFramesDataset', 'BasicConditionalDataset', 'UnpairedImageDataset', 'PairedImageDataset', - 'ImageNet', 'CIFAR10', 'GrowScaleImgDataset', 'CocoDataset' + 'ImageNet', 'CIFAR10', 'GrowScaleImgDataset' ] diff --git a/mmedit/datasets/coco.py b/mmedit/datasets/coco.py deleted file mode 100644 index 0f6f66d7c8..0000000000 --- a/mmedit/datasets/coco.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os.path as osp -from typing import List - -from mmengine.dataset import BaseDataset -from mmengine.fileio import load - -from mmedit.registry import DATASETS - - -@DATASETS.register_module() -class CocoDataset(BaseDataset): - """Dataset for COCO.""" - - METAINFO = { - 'dataset_type': 'colorization_dataset', - 'task_name': 'colorization', - } - - def load_data_list(self) -> List[dict]: - - annotations = load(self.ann_file) - - assert annotations, f'annotation file "{self.ann_file}" is empty.' - - metainfo = annotations['metainfo'] - raw_data_list = annotations['data_list'] - - for k, v in metainfo.items(): - self._metainfo.setdefault(k, v) - - data_list = [] - for raw_data_info in raw_data_list: - data_info = self.parse_data_info(raw_data_info) - if isinstance(data_info, dict): - data_list.append(data_info) - else: - raise TypeError('data_info should be a dict or list of dict, ' - f'but got {type(data_info)}') - - return data_list - - def parse_data_info(self, raw_data_info: dict) -> dict: - """Join data_root to each path in data_info.""" - - data_info = raw_data_info.copy() - for key in raw_data_info: - if 'path' in key: - data_info['gt_img_path'] = osp.join(self.data_root, - data_info[key]) - - return data_info diff --git a/mmedit/datasets/transforms/__init__.py b/mmedit/datasets/transforms/__init__.py index 6d7ff453be..15a2ca0da4 100644 --- a/mmedit/datasets/transforms/__init__.py +++ b/mmedit/datasets/transforms/__init__.py @@ -16,7 +16,6 @@ from .generate_frame_indices import (GenerateFrameIndices, GenerateFrameIndiceswithPadding, GenerateSegmentIndices) -from .get_gray_color_pil import GenGrayColorPil from .get_masked_image import GetMaskedImage from .get_maskrcnn_bbox import InstanceCrop from .loading import (GetSpatialDiscountMask, LoadImageFromFile, LoadMask, @@ -47,6 +46,5 @@ 'GenerateSoftSeg', 'FormatTrimap', 'TransformTrimap', 'GenerateTrimap', 'GenerateTrimapWithDistTransform', 'CompositeFg', 'RandomLoadResizeBg', 'MergeFgAndBg', 'PerturbBg', 'RandomJitter', 'LoadPairedImageFromFile', - 'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad', 'InstanceCrop', - 'GenGrayColorPil' + 'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad', 'InstanceCrop' ] diff --git a/mmedit/datasets/transforms/get_gray_color_pil.py b/mmedit/datasets/transforms/get_gray_color_pil.py deleted file mode 100644 index b99aa01cec..0000000000 --- a/mmedit/datasets/transforms/get_gray_color_pil.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import cv2 -import numpy as np -from mmcv.transforms.base import BaseTransform - -from mmedit.registry import TRANSFORMS - - -@TRANSFORMS.register_module() -class GenGrayColorPil(BaseTransform): - - def __init__(self, stage, keys): - self.stage = stage - self.keys = keys - - def transform(self, results): - - if self.stage == 'instance': - rgb_img = results['instance'] - else: - rgb_img = results['img'] - if len(rgb_img.shape) == 2: - rgb_img = np.stack([rgb_img, rgb_img, rgb_img], 2) - gray_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY) - gray_img = np.stack([gray_img, gray_img, gray_img], -1) - - results[self.keys[0]] = rgb_img - results[self.keys[1]] = gray_img - - return results diff --git a/tests/test_datasets/test_coco.py b/tests/test_datasets/test_coco.py deleted file mode 100644 index 7687af0db1..0000000000 --- a/tests/test_datasets/test_coco.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from mmedit.registry import DATASETS - - -# todo 完成coco的单元测试编写 -class TestCOCOStuff: - DATASET_TYPE = 'CocoDataset' - - ann_file = 'test.json' - data_root = '../..' - - DEFAULT_ARGS = dict( - data_root=data_root + '/train2017', - data_prefix=dict(gt='data_large'), - ann_file=ann_file, - pipeline=[], - test_mode=False) - - def test_load_data_list(self): - dataset_class = DATASETS.get(self.DATASET_TYPE) - dataset = dataset_class(**self.DEFAULT_ARGS) - - assert dataset.mateinfo == { - 'dataset_type': 'colorization_dataset', - 'task_name': 'colorization', - } - - # 对拿到的数据列表和数据进行判断 diff --git a/tests/test_datasets/test_transforms/test_get_gray_color_pil.py b/tests/test_datasets/test_transforms/test_get_gray_color_pil.py deleted file mode 100644 index 04667a98a5..0000000000 --- a/tests/test_datasets/test_transforms/test_get_gray_color_pil.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import cv2 as cv - -from mmedit.datasets.transforms import GenGrayColorPil - - -def test_get_gray_color_pil(): - img = cv.imread('../../data/image/gt/baboon.png') - test_class = GenGrayColorPil(stage='test', keys=['rgb_img', 'gray_img']) - - results = test_class.transform(dict(img=img)) - - assert 'rgb_img' in results.keys() and 'gray_img' in results.keys() - assert results['gray_img'].shape == img.shape diff --git a/tests/test_models/test_editors/test_inst_colorization/test_insta.py b/tests/test_models/test_editors/test_inst_colorization/test_color_utils.py similarity index 100% rename from tests/test_models/test_editors/test_inst_colorization/test_insta.py rename to tests/test_models/test_editors/test_inst_colorization/test_color_utils.py diff --git a/tests/test_models/test_editors/test_inst_colorization/test_insta_net.py b/tests/test_models/test_editors/test_inst_colorization/test_colorization_net.py similarity index 100% rename from tests/test_models/test_editors/test_inst_colorization/test_insta_net.py rename to tests/test_models/test_editors/test_inst_colorization/test_colorization_net.py diff --git a/tests/test_models/test_editors/test_inst_colorization/test_util.py b/tests/test_models/test_editors/test_inst_colorization/test_fusion_net.py similarity index 100% rename from tests/test_models/test_editors/test_inst_colorization/test_util.py rename to tests/test_models/test_editors/test_inst_colorization/test_fusion_net.py diff --git a/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py b/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py b/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. From d6b0c400165e3e68f9e25d87f068d5b008535c29 Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Thu, 27 Oct 2022 20:24:55 +0800 Subject: [PATCH 19/32] clear code --- .../models/editors/inst_colorization/inst_colorization.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py index 7462934e4e..51308166eb 100644 --- a/mmedit/models/editors/inst_colorization/inst_colorization.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -149,7 +149,7 @@ def forward_tensor(self, inputs, data_samples): # prepare data - assert len(data_samples), \ + assert len(data_samples) == 1, \ 'fusion model supports only one image due to different numbers '\ 'of instances of different images' @@ -158,9 +158,6 @@ def forward_tensor(self, inputs, data_samples): data_samples[0].box_info, data_samples[0].box_info_2x, data_samples[0].box_info_4x, data_samples[0].box_info_8x ] - print(data_samples[0]) - print('crop: ', torch.min(cropped_img), torch.max(cropped_img)) - print('full: ', torch.min(inputs), torch.max(inputs)) cropped_data = get_colorization_data(cropped_img, self.color_data_opt) full_img_data = get_colorization_data(inputs, self.color_data_opt) AtoB = self.which_direction == 'AtoB' @@ -199,6 +196,5 @@ def forward_tensor(self, inputs, data_samples): output.type(torch.cuda.FloatTensor) ] output = torch.cat(output, dim=1) - print('output: ', torch.min(output), torch.max(output)) output = torch.clamp(lab2rgb(output, self.color_data_opt), 0.0, 1.0) return output From 67631ee27d20b79e9f8388d0863f411cdb8f7dee Mon Sep 17 00:00:00 2001 From: ruoning Date: Sat, 29 Oct 2022 21:39:55 +0800 Subject: [PATCH 20/32] [Doc]: update docstring if instance-aware image colorization --- mmedit/apis/colorization_inference.py | 8 ++ .../datasets/transforms/get_maskrcnn_bbox.py | 34 ++++- .../editors/inst_colorization/color_utils.py | 122 +++++++++++++++--- .../inst_colorization/colorization_net.py | 23 +++- .../editors/inst_colorization/fusion_net.py | 24 +++- .../inst_colorization/inst_colorization.py | 52 ++++++-- .../editors/inst_colorization/weight_layer.py | 39 ++++++ 7 files changed, 257 insertions(+), 45 deletions(-) diff --git a/mmedit/apis/colorization_inference.py b/mmedit/apis/colorization_inference.py index c2271dd5b8..f650dee925 100644 --- a/mmedit/apis/colorization_inference.py +++ b/mmedit/apis/colorization_inference.py @@ -6,7 +6,15 @@ def colorization_inference(model, img): + """Inference image with the model. + Args: + model (nn.Module): The loaded model. + img (str): Image file path. + + Returns: + Tensor: The predicted colorization result. + """ device = next(model.parameters()).device # build the data pipeline diff --git a/mmedit/datasets/transforms/get_maskrcnn_bbox.py b/mmedit/datasets/transforms/get_maskrcnn_bbox.py index 610db6217d..175bd1833d 100644 --- a/mmedit/datasets/transforms/get_maskrcnn_bbox.py +++ b/mmedit/datasets/transforms/get_maskrcnn_bbox.py @@ -12,11 +12,16 @@ @TRANSFORMS.register_module() class InstanceCrop(BaseTransform): - """## Arguments: + """Use maskrcnn to detect instances on image. - - pred_data_path: Detectron2 predict results - - box_num_upbound: object bounding boxes number. - Default: -1 means use all the instances. + Mask R-CNN is used to detect the instance on the image + pred_bbox is used to segment the instance on the image + + Args: + config_file (str): config file name relative to detectron2's "configs/" + key (str): Unused + box_num_upbound (int):The upper limit on the number of instances + in the figure """ def __init__(self, @@ -36,10 +41,18 @@ def __init__(self, self.final_size = finesize def transform(self, results: dict) -> dict: + """The transform function of InstanceCrop. + + Args: + results (dict): A dict containing the necessary information and + data for Conversion + Returns: + results (dict): A dict containing the processed data + and information. + """ # get consistent box prediction based on L channel full_img = results['img'] - # cv.imwrite('full_img.jpg', full_img) full_img_size = results['ori_img_shape'][:-1][::-1] lab_image = cv.cvtColor(full_img, cv.COLOR_BGR2LAB) l_channel, a_channel, b_channel = cv.split(lab_image) @@ -66,7 +79,6 @@ def transform(self, results: dict) -> dict: for i in index_list: startx, starty, endx, endy = pred_bbox[i] cropped_img = full_img[starty:endy, startx:endx, :] - # cv.imwrite(f"crop_{i}.jpg", cropped_img) cropped_img_list.append(cropped_img) box_info[i] = np.array( get_box_info(pred_bbox[i], full_img_size, self.final_size)) @@ -97,6 +109,16 @@ def transform(self, results: dict) -> dict: def get_box_info(pred_bbox, original_shape, final_size): + """ + + Args: + pred_bbox: The bounding box for the instance + original_shape: Original image shape + final_size: Size of the final output + + Returns: + List: [L_pad, R_pad, T_pad, B_pad, rh, rw] + """ assert len(pred_bbox) == 4 resize_startx = int(pred_bbox[0] / original_shape[0] * final_size) resize_starty = int(pred_bbox[1] / original_shape[1] * final_size) diff --git a/mmedit/models/editors/inst_colorization/color_utils.py b/mmedit/models/editors/inst_colorization/color_utils.py index 1ea8464715..78b536755f 100644 --- a/mmedit/models/editors/inst_colorization/color_utils.py +++ b/mmedit/models/editors/inst_colorization/color_utils.py @@ -4,6 +4,14 @@ def xyz2rgb(xyz): + """Conversion images from lab to xyz. + + Args: + xyz (tensor): The images to be conversion + + Returns: + out (tensor): The converted image + """ r = 3.24048134 * xyz[:, 0, :, :] - 1.53715152 * xyz[:, 1, :, :] \ - 0.49853633 * xyz[:, 2, :, :] g = -0.96925495 * xyz[:, 0, :, :] + 1.87599 * xyz[:, 1, :, :] \ @@ -25,6 +33,14 @@ def xyz2rgb(xyz): def lab2xyz(lab): + """Conversion images from lab to xyz. + + Args: + lab (tensor): The images to be conversion + + Returns: + out (tensor): The converted image + """ y_int = (lab[:, 0, :, :] + 16.) / 116. x_int = (lab[:, 1, :, :] / 500.) + y_int z_int = y_int - (lab[:, 2, :, :] / 200.) @@ -50,6 +66,16 @@ def lab2xyz(lab): def lab2rgb(lab_rs, color_data_opt): + """Conversion images from lab to rgb. + + Args: + lab_rs (tensor): The images to be conversion + color_data_opt (dict): Config for image colorspace transformation. + Include: l_norm, ab_norm, l_cent + + Returns: + out (tensor): The converted image + """ L = lab_rs[:, [0], :, :] * color_data_opt['l_norm'] + color_data_opt['l_cent'] AB = lab_rs[:, 1:, :, :] * color_data_opt['ab_norm'] @@ -59,11 +85,15 @@ def lab2rgb(lab_rs, color_data_opt): def encode_ab_ind(data_ab, color_data_opt): - # Encode ab value into an index - # INPUTS - # data_ab Nx2xHxW \in [-1,1] - # OUTPUTS - # data_q Nx1xHxW \in [0,Q) + """Encode ab value into an index. + + Args: + data_ab: Nx2xHxW from [-1,1] + color_data_opt: Config for image colorspace transformation. + ab_max, ab_quant, ab_norm, ab_quant + Returns: + Nx1xHxW from [0,Q) + """ A = 2 * color_data_opt['ab_max'] / color_data_opt['ab_quant'] + 1 data_ab_rs = torch.round( (data_ab * color_data_opt['ab_norm'] + color_data_opt['ab_max']) / @@ -72,12 +102,19 @@ def encode_ab_ind(data_ab, color_data_opt): return data_q -# Color conversion code -def rgb2xyz(rgb): # rgb from [0,1] - # xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], - # [0.212671, 0.715160, 0.072169], - # [0.019334, 0.119193, 0.950227]]) +def rgb2xyz(rgb): + """Conversion images from rgb to xyz + rgb from [0,1] + xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], + [0.212671, 0.715160, 0.072169], + [0.019334, 0.119193, 0.950227]]) + Args: + rgb (Tensor): image in rgb colorspace + + Returns: + xyz (Tensor): image in xyz colorspace + """ mask = (rgb > .04045).type(torch.FloatTensor) if (rgb.is_cuda): mask = mask.cuda() @@ -97,7 +134,16 @@ def rgb2xyz(rgb): # rgb from [0,1] def xyz2lab(xyz): - # 0.95047, 1., 1.08883 # white + """Conversion images from xyz to lab + xyz from [0,1] + factors: 0.95047, 1., 1.08883 + + Args: + xyz (Tensor): image in xyz colorspace + + Returns: + out (Tensor): Image in lab colorspace + """ sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None] if (xyz.is_cuda): sc = sc.cuda() @@ -121,6 +167,16 @@ def xyz2lab(xyz): def rgb2lab(rgb, color_opt): + """Conversion images from rgb to lab. + + Args: + data_raw (tensor): The images to be conversion + color_opt (dict): Config for image colorspace transformation. + Include: ab_thresh, ab_norm, sample_PS, mask_cent + + Returns: + out (tensor): The converted image + """ lab = xyz2lab(rgb2xyz(rgb)) l_rs = (lab[:, [0], :, :] - color_opt['l_cent']) / color_opt['l_norm'] ab_rs = lab[:, 1:, :, :] / color_opt['ab_norm'] @@ -129,6 +185,16 @@ def rgb2lab(rgb, color_opt): def get_colorization_data(data_raw, color_opt, num_points=None): + """Conversion images from rgb to lab. + + Args: + data_raw (tensor): The images to be conversion + color_opt (dict): Config for image colorspace transformation. + Include: ab_thresh, ab_norm, sample_PS, mask_cent + + Returns: + results (dict): Output in add_color_patches_rand_gt + """ data = {} data_lab = rgb2lab(data_raw, color_opt) data['A'] = data_lab[:, [ @@ -159,14 +225,30 @@ def add_color_patches_rand_gt(data, num_points=None, use_avg=True, samp='normal'): - # Add random color points sampled from ground truth based on: - # Number of points - # - if num_points is 0, then sample from geometric distribution, - # drawn from probability p - # - if num_points > 0, then sample that number of points - # Location of points - # - if samp is 'normal', draw from N(0.5, 0.25) of image - # - otherwise, draw from U[0, 1] of image + """Add random color points sampled from ground truth based on: Number of + points. + + - if num_points is 0, then sample from geometric distribution, + drawn from probability p + - if num_points > 0, then sample that number of points + Location of points + - if samp is 'normal', draw from N(0.5, 0.25) of image + - otherwise, draw from U[0, 1] of image + + Args: + data (tensor): The images to be conversion + color_opt (dict): Config for image colorspace transformation + Include: ab_thresh, ab_norm, sample_PS, mask_cent + p (float): Sampling geometric distribution, 1.0 means no hints + num_points (int): Certain number of points + use_avg (bool): Whether to use the mean when add color point + Default: True. + samp (str): Geometric distribution or uniform distribution when + sample location. Default: normal. + + Returns: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + """ N, C, H, W = data['B'].shape data['hint_B'] = torch.zeros_like(data['B']) @@ -188,7 +270,6 @@ def add_color_patches_rand_gt(data, # patch size P = np.random.choice(color_opt['sample_PS']) - # sample location: geometric distribution if samp == 'normal': h = int( @@ -205,7 +286,6 @@ def add_color_patches_rand_gt(data, # add color point if use_avg: - # embed() data['hint_B'][nn, :, h:h + P, w:w + P] = torch.mean( torch.mean( data['B'][nn, :, h:h + P, w:w + P], diff --git a/mmedit/models/editors/inst_colorization/colorization_net.py b/mmedit/models/editors/inst_colorization/colorization_net.py index f09db034c9..85d2d03039 100644 --- a/mmedit/models/editors/inst_colorization/colorization_net.py +++ b/mmedit/models/editors/inst_colorization/colorization_net.py @@ -19,11 +19,12 @@ class ColorizationNet(BaseModule): 'InstColorization/blob/master/models/networks.py#L108' Args: - input_nc: - output_nc: - norm_type: - use_tanh: - classification: + input_nc (int): input image channels + output_nc (int): output image channels + norm_type (str): instance normalization or batch normalization + use_tanh (bool): Whether to use nn.Tanh() Default: True. + classification (bool): backprop trunk using classification, + otherwise use regression. Default: True """ def __init__(self, @@ -255,6 +256,18 @@ def __init__(self, self.softmax = nn.Softmax(dim=1) def forward(self, input_A, input_B, mask_B): + """Forward function. + + Args: + input_A (tensor): Channel of the image in lab color space + input_B (tensor): Color patch + mask_B (tensor): Color patch mask + + Returns: + out_class (tensor): Classification output + out_reg (tensor): Regression output + feature_map (dict): The full-image feature + """ conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1)) conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) diff --git a/mmedit/models/editors/inst_colorization/fusion_net.py b/mmedit/models/editors/inst_colorization/fusion_net.py index 0ed6c5f733..6996cfe448 100644 --- a/mmedit/models/editors/inst_colorization/fusion_net.py +++ b/mmedit/models/editors/inst_colorization/fusion_net.py @@ -18,11 +18,12 @@ class FusionNet(BaseModule): FusionNet: the full image model with weight layer for fusion. Args: - input_nc: - output_nc: - norm_type: - use_tanh: - classification: + input_nc (int): input image channels + output_nc (int): output image channels + norm_type (str): instance normalization or batch normalization + use_tanh (bool): Whether to use nn.Tanh() Default: True. + classification (bool): backprop trunk using classification, + otherwise use regression. Default: True """ def __init__(self, @@ -283,6 +284,19 @@ def __init__(self, def forward(self, input_A, input_B, mask_B, instance_feature, box_info_list): + """Forward function. + + Args: + input_A (tensor): Channel of the image in lab color space + input_B (tensor): Color patch + mask_B (tensor): Color patch mask + instance_feature (dict): A bunch of instance features + box_info_list (list): Bounding box information corresponding + to the instance + + Returns: + out_reg (tensor): Regression output + """ conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1)) conv1_2 = self.weight_layer(instance_feature['conv1_2'], conv1_2, box_info_list[0]) diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py index 51308166eb..29734b1689 100644 --- a/mmedit/models/editors/inst_colorization/inst_colorization.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -13,6 +13,28 @@ @MODULES.register_module() class InstColorization(BaseModel): + """Colorization InstColorization method. + + This Colorization is implemented according to the paper: + Instance-aware Image Colorization, CVPR 2020 + + Adapted from 'https://github.com/ericsujw/InstColorization.git' + 'InstColorization/models/train_model' + Copyright (c) 2020, Su, under MIT License. + + Args: + data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. + image_model (dict): Config for single image model + instance_model (dict): Config for instance model + fusion_model (dict): Config for fusion model + color_data_opt (dict): Option for colorspace conversion + which_direction (str): AtoB or BtoA + loss (dict): Config for loss. + init_cfg (str): Initialization config dict. Default: None. + train_cfg (dict): Config for training. Default: None. + test_cfg (dict): Config for testing. Default: None. + """ def __init__(self, data_preprocessor: Union[dict, Config], @@ -116,15 +138,37 @@ def convert_to_datasample(self, inputs, data_samples): return inputs def forward_train(self, inputs, data_samples=None, **kwargs): + """Forward function for training.""" raise NotImplementedError( 'Instance Colorization has not supported training.') def train_step(self, data: List[dict], optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]: + """Train step function. + + Args: + data (List[dict]): Batch of data as input. + optim_wrapper (dict[torch.optim.Optimizer]): Dict with optimizers + for generator and discriminator (if have). + Returns: + dict: Dict with loss, information for logger, the number of + samples and results for visualization. + """ raise NotImplementedError( 'Instance Colorization has not supported training.') def forward_inference(self, inputs, data_samples=None, **kwargs): + """Forward inference. Returns predictions of validation, testing. + + Args: + inputs (torch.Tensor): batch input tensor collated by + :attr:`data_preprocessor`. + data_samples (List[BaseDataElement], optional): + data samples collated by :attr:`data_preprocessor`. + + Returns: + List[EditDataSample]: predictions. + """ feats = self.forward_tensor(inputs, data_samples, **kwargs) predictions = [] for idx in range(feats.shape[0]): @@ -164,22 +208,14 @@ def forward_tensor(self, inputs, data_samples): # preprocess input for a single image full_real_A = full_img_data['A' if AtoB else 'B'] - # full_real_B = full_img_data['B' if AtoB else 'A'] full_hint_B = full_img_data['hint_B'] full_mask_B = full_img_data['mask_B'] - # full_mask_B_nc = full_mask_B + self.color_data_opt['mask_cent'] - # full_real_B_enc = encode_ab_ind(full_real_B[:, :, ::4, ::4], - # self.color_data_opt) if not data_samples[0].empty_box: # preprocess instance input real_A = cropped_data['A' if AtoB else 'B'] - # real_B = cropped_data['B' if AtoB else 'A'] hint_B = cropped_data['hint_B'] mask_B = cropped_data['mask_B'] - # mask_B_nc = mask_B + self.color_data_opt['mask_cent'] - # real_B_enc = encode_ab_ind(real_B[:, :, ::4, ::4], - # self.color_data_opt) # network forward _, output, feature_map = self.instance_model( diff --git a/mmedit/models/editors/inst_colorization/weight_layer.py b/mmedit/models/editors/inst_colorization/weight_layer.py index ba4d050c4f..76f3c88243 100644 --- a/mmedit/models/editors/inst_colorization/weight_layer.py +++ b/mmedit/models/editors/inst_colorization/weight_layer.py @@ -9,6 +9,15 @@ def get_norm_layer(norm_type='instance'): + """Gets the normalization layer. + + Args: + norm_type (str): Type of the normalization layer. + + Returns: + norm_layer (BatchNorm2d or InstanceNorm2d or None): + normalization layer. Default: instance + """ if norm_type == 'batch': norm_layer = functools.partial(nn.BatchNorm2d, affine=True) elif norm_type == 'instance': @@ -23,6 +32,15 @@ def get_norm_layer(norm_type='instance'): @MODULES.register_module() class WeightLayer(BaseModule): + """Weight layer of the fusion_net. A small neural network with three + convolutional layers to predict full-image weight map and perinstance + weight map. + + Args: + input_ch (int): Number of channels in the input image. + inner_ch (int): Number of channels produced by the convolution. + Default: True + """ def __init__(self, input_ch, inner_ch=16): super(WeightLayer, self).__init__() @@ -47,6 +65,16 @@ def __init__(self, input_ch, inner_ch=16): self.normalize = nn.Softmax(1) def resize_and_pad(self, feauture_maps, info_array): + """Resize the instance feature as well as the weight map to match the + size of full-image and do zero padding on both of them. + + Args: + feauture_maps (tensor): Feature map + info_array (tensor): The bounding box + + Returns: + feauture_maps (tensor): Feature maps after resize and padding + """ feauture_maps = torch.nn.functional.interpolate( feauture_maps, size=(info_array[5], info_array[4]), @@ -58,6 +86,17 @@ def resize_and_pad(self, feauture_maps, info_array): return feauture_maps def forward(self, instance_feature, bg_feature, box_info): + """Forward function. + + Args: + instance_feature (tensor): Instance feature obtained from the + colorization_net + bg_feature (tensor): full-image feature + box_info (tensor): The bounding box corresponding to the instance + + Returns: + out (tensor): Fused feature + """ mask_list = [] featur_map_list = [] mask_sum_for_pred = torch.zeros_like(bg_feature)[:1, :1] From 0201cb6daea019506012fedc01494b30a52ce011 Mon Sep 17 00:00:00 2001 From: ruoning Date: Mon, 31 Oct 2022 05:26:36 +0800 Subject: [PATCH 21/32] [Enhancement]: add unit test of instance_aware_colorization --- .../test_transforms/test_get_maskrcnn_bbox.py | 8 +- .../test_color_utils.py | 152 ++++++++++++++++++ .../test_colorization_net.py | 41 +++++ .../test_inst_colorization/test_fusion_net.py | 80 +++++++++ .../test_inst_colorization.py | 110 +++++++++++++ .../test_weight_layer.py | 23 +++ 6 files changed, 410 insertions(+), 4 deletions(-) diff --git a/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py b/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py index a89ae3bd8a..e7c35951aa 100644 --- a/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py +++ b/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py @@ -3,7 +3,7 @@ import cv2 as cv -from mmedit.datasets.transforms import GenMaskRCNNBbox +from mmedit.datasets.transforms import InstanceCrop from mmedit.utils import tensor2img @@ -12,7 +12,7 @@ class TestMaskRCNNBbox: DEFAULT_ARGS = dict(key='img', finesize=256) def test_maskrcnn_bbox(self): - detectetor = GenMaskRCNNBbox(**self.DEFAULT_ARGS, stage='test') + detectetor = InstanceCrop(**self.DEFAULT_ARGS, stage='test') data_root = '..' img_path = 'data/image/gray/test.jpg' img = cv.imread(os.path.join(data_root, img_path)) @@ -38,7 +38,7 @@ def test_maskrcnn_bbox(self): assert tensor2img(results['rgb_img']).shape == (3, 256, 256) def test_gen_maskrcnn_from_pred(self): - detectetor = GenMaskRCNNBbox(**self.DEFAULT_ARGS, stage='test') + detectetor = InstanceCrop(**self.DEFAULT_ARGS, stage='test') data_root = '..' img_path = 'data/image/gray/test.jpg' img = cv.imread(os.path.join(data_root, img_path)) @@ -50,7 +50,7 @@ def test_gen_maskrcnn_from_pred(self): assert pred_bbox.shape[-1] == 4 def test_get_box_info(self): - detectetor = GenMaskRCNNBbox(**self.DEFAULT_ARGS, stage='test') + detectetor = InstanceCrop(**self.DEFAULT_ARGS, stage='test') data_root = '..' img_path = 'data/image/gray/test.jpg' img = cv.imread(os.path.join(data_root, img_path)) diff --git a/tests/test_models/test_editors/test_inst_colorization/test_color_utils.py b/tests/test_models/test_editors/test_inst_colorization/test_color_utils.py index ef101fec61..ff113d0917 100644 --- a/tests/test_models/test_editors/test_inst_colorization/test_color_utils.py +++ b/tests/test_models/test_editors/test_inst_colorization/test_color_utils.py @@ -1 +1,153 @@ # Copyright (c) OpenMMLab. All rights reserved. + +import torch + +from mmedit.models.editors.inst_colorization import color_utils + + +class TestColorUtils: + color_data_opt = dict( + ab_thresh=0, + p=1.0, + sample_PS=[ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + ], + ab_norm=110, + ab_max=110., + ab_quant=10., + l_norm=100., + l_cent=50., + mask_cent=0.5) + + def test_xyz2lab(self): + xyz = torch.rand(1, 3, 8, 8) + lab = color_utils.xyz2lab(xyz) + + sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None] + xyz_scale = xyz / sc + mask = (xyz_scale > .008856).type(torch.FloatTensor) + + xyz_int = xyz_scale**(1 / 3.) * mask + (7.787 * xyz_scale + + 16. / 116.) * (1 - mask) + L = 116. * xyz_int[:, 1, :, :] - 16. + a = 500. * (xyz_int[:, 0, :, :] - xyz_int[:, 1, :, :]) + b = 200. * (xyz_int[:, 1, :, :] - xyz_int[:, 2, :, :]) + + assert lab.shape == (1, 3, 8, 8) + assert lab.equal( + torch.cat((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]), + dim=1)) + + def test_rgb2xyz(self): + rgb = torch.rand(1, 3, 8, 8) + xyz = color_utils.rgb2xyz(rgb) + + mask = (rgb > .04045).type(torch.FloatTensor) + rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask) + + x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] \ + + .180423 * rgb[:, 2, :, :] + y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] \ + + .072169 * rgb[:, 2, :, :] + z = .019334 * rgb[:, 0, :, :] + .119193 * rgb[:, 1, :, :] \ + + .950227 * rgb[:, 2, :, :] + + assert xyz.shape == (1, 3, 8, 8) + assert xyz.equal( + torch.cat((x[:, None, :, :], y[:, None, :, :], z[:, None, :, :]), + dim=1)) + + def test_rgb2lab(self): + rgb = torch.rand(1, 3, 8, 8) + lab = color_utils.rgb2lab(rgb, self.color_data_opt) + _lab = color_utils.xyz2lab(color_utils.rgb2xyz(rgb)) + + l_rs = (_lab[:, [0], :, :] - + self.color_data_opt['l_cent']) / self.color_data_opt['l_norm'] + ab_rs = _lab[:, 1:, :, :] / self.color_data_opt['ab_norm'] + + assert lab.shape == (1, 3, 8, 8) + assert lab.equal(torch.cat((l_rs, ab_rs), dim=1)) + + def test_lab2rgb(self): + lab = torch.rand(1, 3, 8, 8) + rgb = color_utils.lab2rgb(lab, self.color_data_opt) + + L = lab[:, [0], :, :] * self.color_data_opt[ + 'l_norm'] + self.color_data_opt['l_cent'] + AB = lab[:, 1:, :, :] * self.color_data_opt['ab_norm'] + + lab = torch.cat((L, AB), dim=1) + + assert rgb.shape == (1, 3, 8, 8) + assert rgb.equal(color_utils.xyz2rgb(color_utils.lab2xyz(lab))) + + def test_lab2xyz(self): + lab = torch.rand(1, 3, 8, 8) + xyz = color_utils.lab2xyz(lab) + y_int = (lab[:, 0, :, :] + 16.) / 116. + x_int = (lab[:, 1, :, :] / 500.) + y_int + z_int = y_int - (lab[:, 2, :, :] / 200.) + z_int = torch.max(torch.Tensor((0, )), z_int) + + out = torch.cat( + (x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]), + dim=1) + mask = (out > .2068966).type(torch.FloatTensor) + sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None] + out = (out**3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask) + target = sc * out + assert xyz.shape == (1, 3, 8, 8) + assert xyz.equal(target) + + def test_xyz2rgb(self): + xyz = torch.rand(1, 3, 8, 8) + + rgb = color_utils.xyz2rgb(xyz) + + r = 3.24048134 * xyz[:, 0, :, :] - 1.53715152 * xyz[:, 1, :, :] \ + - 0.49853633 * xyz[:, 2, :, :] + g = -0.96925495 * xyz[:, 0, :, :] + 1.87599 * xyz[:, 1, :, :] \ + + .04155593 * xyz[:, 2, :, :] + b = .05564664 * xyz[:, 0, :, :] - .20404134 * xyz[:, 1, :, :] \ + + 1.05731107 * xyz[:, 2, :, :] + + _rgb = torch.cat( + (r[:, None, :, :], g[:, None, :, :], b[:, None, :, :]), dim=1) + _rgb = torch.max(_rgb, torch.zeros_like(_rgb)) + + mask = (_rgb > .0031308).type(torch.FloatTensor) + + assert rgb.shape == (1, 3, 8, 8) and mask.shape == (1, 3, 8, 8) + assert rgb.equal((1.055 * (_rgb**(1. / 2.4)) - 0.055) * mask + + 12.92 * _rgb * (1 - mask)) + + def test_get_colorization_data(self): + data_raw = torch.rand(1, 3, 8, 8) + + res = color_utils.get_colorization_data(data_raw, self.color_data_opt) + + assert isinstance(res, dict) + assert 'A' in res.keys() and 'B' in res.keys() \ + and 'hint_B' in res.keys() and 'mask_B' in res.keys() + assert res['A'].shape == res['mask_B'].shape == (1, 1, 8, 8) + assert res['hint_B'].shape == res['B'].shape == (1, 2, 8, 8) + + def test_encode_ab_ind(self): + data_ab = torch.rand(1, 2, 8, 8) + data_q = color_utils.encode_ab_ind(data_ab, self.color_data_opt) + A = 2 * 110. / 10. + 1 + + data_ab_rs = torch.round((data_ab * 110 + 110.) / 10.) + + assert data_q.shape == (1, 1, 8, 8) + assert data_q.equal(data_ab_rs[:, [0], :, :] * A + + data_ab_rs[:, [1], :, :]) diff --git a/tests/test_models/test_editors/test_inst_colorization/test_colorization_net.py b/tests/test_models/test_editors/test_inst_colorization/test_colorization_net.py index ef101fec61..c6d4454cab 100644 --- a/tests/test_models/test_editors/test_inst_colorization/test_colorization_net.py +++ b/tests/test_models/test_editors/test_inst_colorization/test_colorization_net.py @@ -1 +1,42 @@ # Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmedit.registry import MODULES + + +def test_colorization_net(): + + model_cfg = dict( + type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch') + + # build model + model = MODULES.build(model_cfg) + + # test attributes + assert model.__class__.__name__ == 'ColorizationNet' + + # prepare data + input_A = torch.rand(1, 1, 256, 256) + input_B = torch.rand(1, 2, 256, 256) + mask_B = torch.rand(1, 1, 256, 256) + + target_shape = (1, 2, 256, 256) + + # test on cpu + (out_class, out_reg, feature_map) = model(input_A, input_B, mask_B) + assert isinstance(feature_map, dict) + assert feature_map['conv1_2'].shape == (1, 64, 256, 256) \ + and feature_map['out_reg'].shape == target_shape + + # test on gpu + if torch.cuda.is_available(): + model = model.cuda() + input_A = input_A.cuda() + input_B = input_B.cuda() + mask_B = mask_B.cuda() + (out_class, out_reg, feature_map) = \ + model(input_A, input_B, mask_B) + + assert isinstance(feature_map, dict) + for item in feature_map.keys(): + assert torch.is_tensor(feature_map[item]) diff --git a/tests/test_models/test_editors/test_inst_colorization/test_fusion_net.py b/tests/test_models/test_editors/test_inst_colorization/test_fusion_net.py index ef101fec61..929c9c8eb1 100644 --- a/tests/test_models/test_editors/test_inst_colorization/test_fusion_net.py +++ b/tests/test_models/test_editors/test_inst_colorization/test_fusion_net.py @@ -1 +1,81 @@ # Copyright (c) OpenMMLab. All rights reserved. + +import torch + +from mmedit.registry import MODULES + + +def test_fusion_net(): + + model_cfg = dict( + type='FusionNet', input_nc=4, output_nc=2, norm_type='batch') + + # build model + model = MODULES.build(model_cfg) + + # test attributes + assert model.__class__.__name__ == 'FusionNet' + + # prepare data + input_A = torch.rand(1, 1, 256, 256) + input_B = torch.rand(1, 2, 256, 256) + mask_B = torch.rand(1, 1, 256, 256) + + instance_feature = dict( + conv1_2=torch.rand(1, 64, 256, 256), + conv2_2=torch.rand(1, 128, 128, 128), + conv3_3=torch.rand(1, 256, 64, 64), + conv4_3=torch.rand(1, 512, 32, 32), + conv5_3=torch.rand(1, 512, 32, 32), + conv6_3=torch.rand(1, 512, 32, 32), + conv7_3=torch.rand(1, 512, 32, 32), + conv8_up=torch.rand(1, 256, 64, 64), + conv8_3=torch.rand(1, 256, 64, 64), + conv9_up=torch.rand(1, 128, 128, 128), + conv9_3=torch.rand(1, 128, 128, 128), + conv10_up=torch.rand(1, 128, 256, 256), + conv10_2=torch.rand(1, 128, 256, 256), + ) + + target_shape = (1, 2, 256, 256) + + box_info_box = [ + torch.tensor([[175, 29, 96, 54, 52, 106], [14, 191, 84, 61, 51, 111], + [117, 64, 115, 46, 75, 95], [41, 165, 121, 47, 50, 88], + [46, 136, 94, 45, 74, 117], [79, 124, 62, 115, 53, 79], + [156, 64, 77, 138, 36, 41], [200, 48, 114, 131, 8, 11], + [115, 78, 92, 81, 63, 83]]), + torch.tensor([[87, 15, 48, 27, 26, 53], [7, 96, 42, 31, 25, 55], + [58, 32, 57, 23, 38, 48], [20, 83, 60, 24, 25, 44], + [23, 68, 47, 23, 37, 58], [39, 62, 31, 58, 27, 39], + [78, 32, 38, 69, 18, 21], [100, 24, 57, 66, 4, 5], + [57, 39, 46, 41, 32, 41]]), + torch.tensor([[43, 8, 24, 14, 13, 26], [3, 48, 21, 16, 13, 27], + [29, 16, 28, 12, 19, 24], [10, 42, 30, 12, 12, 22], + [11, 34, 23, 12, 19, 29], [19, 31, 15, 29, 14, 20], + [39, 16, 19, 35, 9, 10], [50, 12, 28, 33, 2, 3], + [28, 20, 23, 21, 16, 20]]), + torch.tensor([[21, 4, 12, 7, 7, 13], [1, 24, 10, 8, 7, 14], + [14, 8, 14, 6, 10, 12], [5, 21, 15, 6, 6, 11], + [5, 17, 11, 6, 10, 15], [9, 16, 7, 15, 7, 10], + [19, 8, 9, 18, 5, 5], [25, 6, 14, 17, 1, 1], + [14, 10, 11, 11, 8, 10]]) + ] + + # test on cpu + out = model(input_A, input_B, mask_B, instance_feature, box_info_box) + assert torch.is_tensor(out) + assert out.shape == target_shape + + # test on gpu + if torch.cuda.is_available(): + model = model.cuda() + input_A = input_A.cuda() + input_B = input_B.cuda() + mask_B = mask_B.cuda() + for item in instance_feature.keys(): + instance_feature[item] = instance_feature[item].cuda() + box_info_box = [i.cuda() for i in box_info_box] + output = model(input_A, input_B, mask_B, instance_feature, + box_info_box) + assert torch.is_tensor(output) diff --git a/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py b/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py index ef101fec61..1b77725604 100644 --- a/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py +++ b/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py @@ -1 +1,111 @@ # Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmedit.registry import BACKBONES +from mmedit.structures import EditDataSample, PixelData +from mmedit.utils import register_all_modules + + +class TestInstColorization: + + def test_inst_colorization(self): + register_all_modules() + model_cfg = dict( + type='InstColorization', + data_preprocessor=dict( + type='EditDataPreprocessor', + mean=[127.5], + std=[127.5], + ), + image_model=dict( + type='ColorizationNet', + input_nc=4, + output_nc=2, + norm_type='batch'), + instance_model=dict( + type='ColorizationNet', + input_nc=4, + output_nc=2, + norm_type='batch'), + fusion_model=dict( + type='FusionNet', input_nc=4, output_nc=2, norm_type='batch'), + color_data_opt=dict( + ab_thresh=0, + p=1.0, + sample_PS=[ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + ], + ab_norm=110, + ab_max=110., + ab_quant=10., + l_norm=100., + l_cent=50., + mask_cent=0.5), + which_direction='AtoB', + loss=dict(type='HuberLoss', delta=.01)) + + model = BACKBONES.build(model_cfg) + + # test attributes + assert model.__class__.__name__ == 'InstColorization' + + # prepare data + inputs = torch.rand(1, 3, 256, 256) + target_shape = (1, 3, 256, 256) + + data_sample = EditDataSample(gt_img=PixelData(data=inputs)) + metainfo = dict( + cropped_img=PixelData(data=torch.rand(9, 3, 256, 256)), + box_info=torch.tensor([[175, 29, 96, 54, 52, 106], + [14, 191, 84, 61, 51, 111], + [117, 64, 115, 46, 75, 95], + [41, 165, 121, 47, 50, 88], + [46, 136, 94, 45, 74, 117], + [79, 124, 62, 115, 53, 79], + [156, 64, 77, 138, 36, 41], + [200, 48, 114, 131, 8, 11], + [115, 78, 92, 81, 63, 83]]), + box_info_2x=torch.tensor([[87, 15, 48, 27, 26, 53], + [7, 96, 42, 31, 25, 55], + [58, 32, 57, 23, 38, 48], + [20, 83, 60, 24, 25, 44], + [23, 68, 47, 23, 37, 58], + [39, 62, 31, 58, 27, 39], + [78, 32, 38, 69, 18, 21], + [100, 24, 57, 66, 4, 5], + [57, 39, 46, 41, 32, 41]]), + box_info_4x=torch.tensor([[43, 8, 24, 14, 13, 26], + [3, 48, 21, 16, 13, 27], + [29, 16, 28, 12, 19, 24], + [10, 42, 30, 12, 12, 22], + [11, 34, 23, 12, 19, 29], + [19, 31, 15, 29, 14, 20], + [39, 16, 19, 35, 9, 10], + [50, 12, 28, 33, 2, 3], + [28, 20, 23, 21, 16, 20]]), + box_info_8x=torch.tensor([[21, 4, 12, 7, 7, 13], + [1, 24, 10, 8, 7, 14], + [14, 8, 14, 6, 10, 12], + [5, 21, 15, 6, 6, 11], + [5, 17, 11, 6, 10, 15], + [9, 16, 7, 15, 7, 10], + [19, 8, 9, 18, 5, 5], + [25, 6, 14, 17, 1, 1], + [14, 10, 11, 11, 8, 10]]), + empty_box=False) + data_sample.set_metainfo(metainfo=metainfo) + + data = dict(inputs=inputs, data_samples=[data_sample]) + + res = model(mode='tensor', **data) + + assert torch.is_tensor(res) + assert res.shape == target_shape diff --git a/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py b/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py index ef101fec61..e009af981f 100644 --- a/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py +++ b/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py @@ -1 +1,24 @@ # Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmedit.models.editors.inst_colorization.weight_layer import WeightLayer + + +def test_weight_layer(): + + weight_layer = WeightLayer(64) + + instance_feature_conv1_2 = torch.rand(1, 64, 256, 256) + conv1_2 = torch.rand(1, 64, 256, 256) + box_info = torch.tensor([[175, 29, 96, 54, 52, 106], + [14, 191, 84, 61, 51, 111], + [117, 64, 115, 46, 75, 95], + [41, 165, 121, 47, 50, 88], + [46, 136, 94, 45, 74, 117], + [79, 124, 62, 115, 53, 79], + [156, 64, 77, 138, 36, 41], + [200, 48, 114, 131, 8, 11], + [115, 78, 92, 81, 63, 83]]) + conv1_2 = weight_layer(instance_feature_conv1_2, conv1_2, box_info) + + assert conv1_2.shape == instance_feature_conv1_2.shape From 75804335f8f9aaff2c1d169f39615d71ba4d27ae Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Tue, 1 Nov 2022 21:08:15 +0800 Subject: [PATCH 22/32] use mmdet --- .../inst-colorizatioon_cocostuff_256x256.py | 2 +- mmedit/datasets/transforms/__init__.py | 6 +- mmedit/datasets/transforms/crop.py | 99 +++++++++++- .../datasets/transforms/get_maskrcnn_bbox.py | 145 ------------------ mmedit/utils/__init__.py | 4 +- mmedit/utils/img_utils.py | 37 +++++ 6 files changed, 141 insertions(+), 152 deletions(-) delete mode 100644 mmedit/datasets/transforms/get_maskrcnn_bbox.py diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py index 4230f08192..3cbaf2ad04 100644 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py @@ -47,7 +47,7 @@ dict(type='LoadImageFromFile', key='img', channel_order='rgb'), dict( type='InstanceCrop', - config_file='COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml', # noqa + config_file='mmdet::/mask_rcnn_X_101_32x8d_FPN_3x.yaml', # noqa finesize=256), dict( type='Resize', diff --git a/mmedit/datasets/transforms/__init__.py b/mmedit/datasets/transforms/__init__.py index 15a2ca0da4..39543085f1 100644 --- a/mmedit/datasets/transforms/__init__.py +++ b/mmedit/datasets/transforms/__init__.py @@ -6,8 +6,9 @@ from .aug_shape import (Flip, NumpyPad, RandomRotation, RandomTransposeHW, Resize) from .crop import (CenterCropLongEdge, Crop, CropAroundCenter, CropAroundFg, - CropAroundUnknown, CropLike, FixedCrop, ModCrop, - PairedRandomCrop, RandomCropLongEdge, RandomResizedCrop) + CropAroundUnknown, CropLike, FixedCrop, InstanceCrop, + ModCrop, PairedRandomCrop, RandomCropLongEdge, + RandomResizedCrop) from .fgbg import (CompositeFg, MergeFgAndBg, PerturbBg, RandomJitter, RandomLoadResizeBg) from .formatting import PackEditInputs, ToTensor @@ -17,7 +18,6 @@ GenerateFrameIndiceswithPadding, GenerateSegmentIndices) from .get_masked_image import GetMaskedImage -from .get_maskrcnn_bbox import InstanceCrop from .loading import (GetSpatialDiscountMask, LoadImageFromFile, LoadMask, LoadPairedImageFromFile) from .matlab_like_resize import MATLABLikeResize diff --git a/mmedit/datasets/transforms/crop.py b/mmedit/datasets/transforms/crop.py index 906955254f..d15043a134 100644 --- a/mmedit/datasets/transforms/crop.py +++ b/mmedit/datasets/transforms/crop.py @@ -2,14 +2,17 @@ import math import random +import cv2 as cv import mmcv import numpy as np +import torch from mmcv.transforms import BaseTransform +from mmengine.hub import get_model from mmengine.utils import is_list_of, is_tuple_of from torch.nn.modules.utils import _pair from mmedit.registry import TRANSFORMS -from mmedit.utils import random_choose_unknown +from mmedit.utils import get_box_info, random_choose_unknown @TRANSFORMS.register_module() @@ -916,3 +919,97 @@ def __repr__(self): repr_str = self.__class__.__name__ repr_str += (f'(keys={self.keys})') return repr_str + + +@TRANSFORMS.register_module() +class InstanceCrop(BaseTransform): + """Use maskrcnn to detect instances on image. + + Mask R-CNN is used to detect the instance on the image + pred_bbox is used to segment the instance on the image + + Args: + config_file (str): config file name relative to detectron2's "configs/" + key (str): Unused + box_num_upbound (int):The upper limit on the number of instances + in the figure + """ + + def __init__(self, + config_file, + key='img', + box_num_upbound=-1, + finesize=256): + # detector + self.predictor = get_model(config_file, pretrained=True) + + self.key = key + self.box_num_upbound = box_num_upbound + self.final_size = finesize + + def transform(self, results: dict) -> dict: + """The transform function of InstanceCrop. + + Args: + results (dict): A dict containing the necessary information and + data for Conversion + + Returns: + results (dict): A dict containing the processed data + and information. + """ + # get consistent box prediction based on L channel + full_img = results['img'] + full_img_size = results['ori_img_shape'][:-1][::-1] + lab_image = cv.cvtColor(full_img, cv.COLOR_BGR2LAB) + l_channel, a_channel, b_channel = cv.split(lab_image) + l_stack = np.stack([l_channel, l_channel, l_channel], axis=2) + outputs = self.predictor(l_stack) + + # get the most confident boxes + pred_bbox = outputs['instances'].pred_boxes.to( + torch.device('cpu')).tensor.numpy() + pred_scores = outputs['instances'].scores.cpu().data.numpy() + pred_bbox = pred_bbox.astype(np.int32) + if self.box_num_upbound > 0 and pred_bbox.shape[ + 0] > self.box_num_upbound: + index_mask = np.argsort(pred_scores, axis=0) + index_mask = index_mask[pred_scores.shape[0] - + self.box_num_upbound:pred_scores.shape[0]] + pred_bbox = pred_bbox[index_mask] + + # get cropped images and box info + cropped_img_list = [] + index_list = range(len(pred_bbox)) + box_info, box_info_2x, box_info_4x, box_info_8x = np.zeros( + (4, len(index_list), 6)) + for i in index_list: + startx, starty, endx, endy = pred_bbox[i] + cropped_img = full_img[starty:endy, startx:endx, :] + cropped_img_list.append(cropped_img) + box_info[i] = np.array( + get_box_info(pred_bbox[i], full_img_size, self.final_size)) + box_info_2x[i] = np.array( + get_box_info(pred_bbox[i], full_img_size, + self.final_size // 2)) + box_info_4x[i] = np.array( + get_box_info(pred_bbox[i], full_img_size, + self.final_size // 4)) + box_info_8x[i] = np.array( + get_box_info(pred_bbox[i], full_img_size, + self.final_size // 8)) + + # update results + if len(pred_bbox) > 0: + results['cropped_img'] = cropped_img_list + results['box_info'] = torch.from_numpy(box_info).type(torch.long) + results['box_info_2x'] = torch.from_numpy(box_info_2x).type( + torch.long) + results['box_info_4x'] = torch.from_numpy(box_info_4x).type( + torch.long) + results['box_info_8x'] = torch.from_numpy(box_info_8x).type( + torch.long) + results['empty_box'] = False + else: + results['empty_box'] = True + return results diff --git a/mmedit/datasets/transforms/get_maskrcnn_bbox.py b/mmedit/datasets/transforms/get_maskrcnn_bbox.py deleted file mode 100644 index 175bd1833d..0000000000 --- a/mmedit/datasets/transforms/get_maskrcnn_bbox.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import cv2 as cv -import numpy as np -import torch -from detectron2 import model_zoo -from detectron2.config import get_cfg -from detectron2.engine import DefaultPredictor -from mmcv.transforms import BaseTransform - -from mmedit.registry import TRANSFORMS - - -@TRANSFORMS.register_module() -class InstanceCrop(BaseTransform): - """Use maskrcnn to detect instances on image. - - Mask R-CNN is used to detect the instance on the image - pred_bbox is used to segment the instance on the image - - Args: - config_file (str): config file name relative to detectron2's "configs/" - key (str): Unused - box_num_upbound (int):The upper limit on the number of instances - in the figure - """ - - def __init__(self, - config_file, - key='img', - box_num_upbound=-1, - finesize=256): - # detector - cfg = get_cfg() - cfg.merge_from_file(model_zoo.get_config_file(config_file)) - cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 - cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_file) - self.predictor = DefaultPredictor(cfg) - - self.key = key - self.box_num_upbound = box_num_upbound - self.final_size = finesize - - def transform(self, results: dict) -> dict: - """The transform function of InstanceCrop. - - Args: - results (dict): A dict containing the necessary information and - data for Conversion - - Returns: - results (dict): A dict containing the processed data - and information. - """ - # get consistent box prediction based on L channel - full_img = results['img'] - full_img_size = results['ori_img_shape'][:-1][::-1] - lab_image = cv.cvtColor(full_img, cv.COLOR_BGR2LAB) - l_channel, a_channel, b_channel = cv.split(lab_image) - l_stack = np.stack([l_channel, l_channel, l_channel], axis=2) - outputs = self.predictor(l_stack) - - # get the most confident boxes - pred_bbox = outputs['instances'].pred_boxes.to( - torch.device('cpu')).tensor.numpy() - pred_scores = outputs['instances'].scores.cpu().data.numpy() - pred_bbox = pred_bbox.astype(np.int32) - if self.box_num_upbound > 0 and pred_bbox.shape[ - 0] > self.box_num_upbound: - index_mask = np.argsort(pred_scores, axis=0) - index_mask = index_mask[pred_scores.shape[0] - - self.box_num_upbound:pred_scores.shape[0]] - pred_bbox = pred_bbox[index_mask] - - # get cropped images and box info - cropped_img_list = [] - index_list = range(len(pred_bbox)) - box_info, box_info_2x, box_info_4x, box_info_8x = np.zeros( - (4, len(index_list), 6)) - for i in index_list: - startx, starty, endx, endy = pred_bbox[i] - cropped_img = full_img[starty:endy, startx:endx, :] - cropped_img_list.append(cropped_img) - box_info[i] = np.array( - get_box_info(pred_bbox[i], full_img_size, self.final_size)) - box_info_2x[i] = np.array( - get_box_info(pred_bbox[i], full_img_size, - self.final_size // 2)) - box_info_4x[i] = np.array( - get_box_info(pred_bbox[i], full_img_size, - self.final_size // 4)) - box_info_8x[i] = np.array( - get_box_info(pred_bbox[i], full_img_size, - self.final_size // 8)) - - # update results - if len(pred_bbox) > 0: - results['cropped_img'] = cropped_img_list - results['box_info'] = torch.from_numpy(box_info).type(torch.long) - results['box_info_2x'] = torch.from_numpy(box_info_2x).type( - torch.long) - results['box_info_4x'] = torch.from_numpy(box_info_4x).type( - torch.long) - results['box_info_8x'] = torch.from_numpy(box_info_8x).type( - torch.long) - results['empty_box'] = False - else: - results['empty_box'] = True - return results - - -def get_box_info(pred_bbox, original_shape, final_size): - """ - - Args: - pred_bbox: The bounding box for the instance - original_shape: Original image shape - final_size: Size of the final output - - Returns: - List: [L_pad, R_pad, T_pad, B_pad, rh, rw] - """ - assert len(pred_bbox) == 4 - resize_startx = int(pred_bbox[0] / original_shape[0] * final_size) - resize_starty = int(pred_bbox[1] / original_shape[1] * final_size) - resize_endx = int(pred_bbox[2] / original_shape[0] * final_size) - resize_endy = int(pred_bbox[3] / original_shape[1] * final_size) - rh = resize_endx - resize_startx - rw = resize_endy - resize_starty - if rh < 1: - if final_size - resize_endx > 1: - resize_endx += 1 - else: - resize_startx -= 1 - rh = 1 - if rw < 1: - if final_size - resize_endy > 1: - resize_endy += 1 - else: - resize_starty -= 1 - rw = 1 - L_pad = resize_startx - R_pad = final_size - resize_endx - T_pad = resize_starty - B_pad = final_size - resize_endy - return [L_pad, R_pad, T_pad, B_pad, rh, rw] diff --git a/mmedit/utils/__init__.py b/mmedit/utils/__init__.py index 3fdaa607f2..1533e91fc7 100644 --- a/mmedit/utils/__init__.py +++ b/mmedit/utils/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .cli import modify_args -from .img_utils import reorder_image, tensor2img, to_numpy +from .img_utils import get_box_info, reorder_image, tensor2img, to_numpy from .io_utils import MMEDIT_CACHE_DIR, download_from_url # TODO replace with engine's API from .logger import print_colored_log @@ -17,5 +17,5 @@ 'download_from_url', 'get_sampler', 'tensor2img', 'random_choose_unknown', 'add_gaussian_noise', 'adjust_gamma', 'make_coord', 'bbox2mask', 'brush_stroke_mask', 'get_irregular_mask', 'random_bbox', 'reorder_image', - 'to_numpy' + 'to_numpy', 'get_box_info' ] diff --git a/mmedit/utils/img_utils.py b/mmedit/utils/img_utils.py index bf420910b5..ff5b2c1f03 100644 --- a/mmedit/utils/img_utils.py +++ b/mmedit/utils/img_utils.py @@ -125,3 +125,40 @@ def to_numpy(img, dtype=np.float64): img = img.astype(dtype) return img + + +def get_box_info(pred_bbox, original_shape, final_size): + """ + + Args: + pred_bbox: The bounding box for the instance + original_shape: Original image shape + final_size: Size of the final output + + Returns: + List: [L_pad, R_pad, T_pad, B_pad, rh, rw] + """ + assert len(pred_bbox) == 4 + resize_startx = int(pred_bbox[0] / original_shape[0] * final_size) + resize_starty = int(pred_bbox[1] / original_shape[1] * final_size) + resize_endx = int(pred_bbox[2] / original_shape[0] * final_size) + resize_endy = int(pred_bbox[3] / original_shape[1] * final_size) + rh = resize_endx - resize_startx + rw = resize_endy - resize_starty + if rh < 1: + if final_size - resize_endx > 1: + resize_endx += 1 + else: + resize_startx -= 1 + rh = 1 + if rw < 1: + if final_size - resize_endy > 1: + resize_endy += 1 + else: + resize_starty -= 1 + rw = 1 + L_pad = resize_startx + R_pad = final_size - resize_endx + T_pad = resize_starty + B_pad = final_size - resize_endy + return [L_pad, R_pad, T_pad, B_pad, rh, rw] From a67a2c6473e61d873409ab8cca6ec960721e8815 Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Wed, 2 Nov 2022 00:10:23 +0800 Subject: [PATCH 23/32] merge get_mask_rcnn into crop as InstanceCrop --- ...atioon_full_official_cocostuff-256x256.py} | 2 +- mmedit/datasets/transforms/crop.py | 50 ++++++++++++++++--- 2 files changed, 43 insertions(+), 9 deletions(-) rename configs/inst_colorization/{inst-colorizatioon_cocostuff_256x256.py => inst-colorizatioon_full_official_cocostuff-256x256.py} (93%) diff --git a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py b/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py similarity index 93% rename from configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py rename to configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py index 3cbaf2ad04..30399b723e 100644 --- a/configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py @@ -47,7 +47,7 @@ dict(type='LoadImageFromFile', key='img', channel_order='rgb'), dict( type='InstanceCrop', - config_file='mmdet::/mask_rcnn_X_101_32x8d_FPN_3x.yaml', # noqa + config_file='mmdet::mask_rcnn/mask-rcnn_x101-32x8d_fpn_ms-poly-3x_coco.py', # noqa finesize=256), dict( type='Resize', diff --git a/mmedit/datasets/transforms/crop.py b/mmedit/datasets/transforms/crop.py index d15043a134..d9d6966c88 100644 --- a/mmedit/datasets/transforms/crop.py +++ b/mmedit/datasets/transforms/crop.py @@ -6,8 +6,10 @@ import mmcv import numpy as np import torch -from mmcv.transforms import BaseTransform -from mmengine.hub import get_model +from mmcv.ops import RoIPool +from mmcv.transforms import BaseTransform, Compose +from mmdet.utils import register_all_modules +from mmengine.hub import get_config, get_model from mmengine.utils import is_list_of, is_tuple_of from torch.nn.modules.utils import _pair @@ -940,13 +942,37 @@ def __init__(self, key='img', box_num_upbound=-1, finesize=256): - # detector - self.predictor = get_model(config_file, pretrained=True) + + self.predictor = self.set_model(config_file) + self.pipeline = self.set_pipeline(config_file) self.key = key self.box_num_upbound = box_num_upbound self.final_size = finesize + def set_model(self, config_file): + model = get_model(config_file, pretrained=True) + if model.data_preprocessor.device.type == 'cpu': + for m in model.modules(): + assert not isinstance( + m, RoIPool + ), 'CPU inference with RoIPool is not supported currently.' + return model + + def set_pipeline(self, config_file): + register_all_modules() + cfg = get_config(config_file) + test_pipeline = cfg.test_dataloader.dataset.pipeline + test_pipeline[0].type = 'mmdet.LoadImageFromNDArray' + new_test_pipeline = [] + for pipeline in test_pipeline: + if pipeline['type'] != 'LoadAnnotations' and pipeline[ + 'type'] != 'LoadPanopticAnnotations': + new_test_pipeline.append(pipeline) + + test_pipeline = Compose(new_test_pipeline) + return test_pipeline + def transform(self, results: dict) -> dict: """The transform function of InstanceCrop. @@ -961,10 +987,7 @@ def transform(self, results: dict) -> dict: # get consistent box prediction based on L channel full_img = results['img'] full_img_size = results['ori_img_shape'][:-1][::-1] - lab_image = cv.cvtColor(full_img, cv.COLOR_BGR2LAB) - l_channel, a_channel, b_channel = cv.split(lab_image) - l_stack = np.stack([l_channel, l_channel, l_channel], axis=2) - outputs = self.predictor(l_stack) + outputs = self.predict_bbox(full_img) # get the most confident boxes pred_bbox = outputs['instances'].pred_boxes.to( @@ -1013,3 +1036,14 @@ def transform(self, results: dict) -> dict: else: results['empty_box'] = True return results + + def predict_bbox(self, image): + lab_image = cv.cvtColor(image, cv.COLOR_BGR2LAB) + l_channel, a_channel, b_channel = cv.split(lab_image) + l_stack = np.stack([l_channel, l_channel, l_channel], axis=2) + data_ = dict(img=l_stack, img_id=0) + data_ = self.pipeline(data_) + data_['inputs'] = [data_['inputs']] + data_['data_samples'] = [data_['data_samples']] + results = self.predictor.test_step(data_)[0] + return results From 22982d5df497990258ddecd264311b23fc61676b Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Wed, 2 Nov 2022 16:58:55 +0800 Subject: [PATCH 24/32] update readme --- configs/inst_colorization/README.md | 46 ++++---------- configs/inst_colorization/README_zh-CN.md | 47 ++++---------- ...zatioon_full_official_cocostuff-256x256.py | 3 +- configs/inst_colorization/metafile.yml | 12 +++- mmedit/apis/colorization_inference.py | 21 ++++--- mmedit/datasets/transforms/aug_shape.py | 5 +- mmedit/datasets/transforms/crop.py | 62 ++++++------------- .../inst_colorization/inst_colorization.py | 14 +++-- 8 files changed, 81 insertions(+), 129 deletions(-) diff --git a/configs/inst_colorization/README.md b/configs/inst_colorization/README.md index 29ec9dbc25..fdfdfb9e50 100644 --- a/configs/inst_colorization/README.md +++ b/configs/inst_colorization/README.md @@ -12,51 +12,31 @@ Image colorization is inherently an ill-posed problem with multi-modal uncertainty. Previous methods leverage the deep neural network to map input grayscale images to plausible color outputs directly. Although these learning-based methods have shown impressive performance, they usually fail on the input images that contain multiple objects. The leading cause is that existing models perform learning and colorization on the entire image. In the absence of a clear figure-ground separation, these models cannot effectively locate and learn meaningful object-level semantics. In this paper, we propose a method for achieving instance-aware colorization. Our network architecture leverages an off-the-shelf object detector to obtain cropped object images and uses an instance colorization network to extract object-level features. We use a similar network to extract the full-image features and apply a fusion module to full object-level and image-level features to predict the final colors. Both colorization networks and fusion modules are learned from a large-scale dataset. Experimental results show that our work outperforms existing methods on different quality metrics and achieves state-of-the-art performance on image colorization. -## Results and models - -## Quick Start - -**Train** + -
-Train Instructions - -You can use the following commands to train a model with cpu or single/multiple GPUs. - -```shell -# CPU train -CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py +
+ +
-# single-gpu train -python tools/train.py configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py - -# multi-gpu train -./tools/dist_train.sh configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py 8 -``` +## Results and models -For more details, you can refer to **Train a model** part in [train_test.md](/docs/en/user_guides/train_test.md#Train-a-model-in-MMEditing). +| Method | Download | +| :-------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------: | +| [instance_aware_colorization_officiial](/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth) | -
- -**Test** +## Quick Start
-Test Instructions +Colorization demo -You can use the following commands to test a model with cpu or single/multiple GPUs. +You can use the following commands to colorize an image. ```shell -# CPU test -CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/inst_colorization//inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth - -# single-gpu demo -python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth -# multi-gpu test -./tools/dist_test.sh configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth 8 +python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth input.jpg output.jpg ``` -For more details, you can refer to **Test a pre-trained model** part in [train_test.md](/docs/en/user_guides/train_test.md#Test-a-pre-trained-model-in-MMEditing). +For more demos, you can refer to [Tutorial 3: inference with pre-trained models](https://mmediting.readthedocs.io/en/1.x/user_guides/3_inference.html).
diff --git a/configs/inst_colorization/README_zh-CN.md b/configs/inst_colorization/README_zh-CN.md index 661dfe2442..19e59c64fc 100644 --- a/configs/inst_colorization/README_zh-CN.md +++ b/configs/inst_colorization/README_zh-CN.md @@ -12,51 +12,30 @@ Image colorization is inherently an ill-posed problem with multi-modal uncertainty. Previous methods leverage the deep neural network to map input grayscale images to plausible color outputs directly. Although these learning-based methods have shown impressive performance, they usually fail on the input images that contain multiple objects. The leading cause is that existing models perform learning and colorization on the entire image. In the absence of a clear figure-ground separation, these models cannot effectively locate and learn meaningful object-level semantics. In this paper, we propose a method for achieving instance-aware colorization. Our network architecture leverages an off-the-shelf object detector to obtain cropped object images and uses an instance colorization network to extract object-level features. We use a similar network to extract the full-image features and apply a fusion module to full object-level and image-level features to predict the final colors. Both colorization networks and fusion modules are learned from a large-scale dataset. Experimental results show that our work outperforms existing methods on different quality metrics and achieves state-of-the-art performance on image colorization. -## 结果和模型 - -## 快速开始 - -**训练** - -
-训练说明 - -您可以使用以下命令来训练模型。 - -```shell -# CPU上训练 -CUDA_VISIBLE_DEVICES=-1 python tools/train.py configs/insta/inst-colorizatioon_cocostuff_full_256x256.py - -# 单个GPU上训练 -python tools/train.py configs/insta/inst-colorizatioon_cocostuff_full_256x256.py + -# 多个GPU上训练 -./tools/dist_train.sh configs/insta/inst-colorizatioon_cocostuff_full_256x256.py 8 -``` +
+ +
-更多细节可以参考 [train_test.md](/docs/zh_cn/user_guides/train_test.md) 中的 **Train a model** 部分。 +## 结果和模型 -
+| Method | Download | +| :-------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------: | +| [instance_aware_colorization_officiial](/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth) | -**测试** +## 快速开始
-测试说明 +图像上色模型 -您可以使用以下命令来测试模型。 +您可以使用以下命令来对一张图像进行上色。 ```shell -# CPU上测试 -CUDA_VISIBLE_DEVICES=-1 python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth - -# 单个GPU上 demo -python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py work_dirs/checkpoints/instance_aware_cocostuff.pth work_dirs/colorization_example.jpg work_dirs/output_example.png - -# 多个GPU上测试 -./tools/dist_test.sh configs/inst_colorization/inst-colorizatioon_cocostuff_full_256x256.py ../checkpoints/instance_aware_cocostuff.pth 8 +python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth input.jpg output.jpg ``` -更多细节可以参考 [train_test.md](/docs/zh_cn/user_guides/train_test.md) 中的 **Test a pre-trained model** 部分。 +更多细节可以参考 [Tutorial 3: inference with pre-trained models](https://mmediting.readthedocs.io/en/1.x/user_guides/3_inference.html)。
diff --git a/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py b/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py index 30399b723e..76ccf2bcaa 100644 --- a/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py @@ -48,7 +48,8 @@ dict( type='InstanceCrop', config_file='mmdet::mask_rcnn/mask-rcnn_x101-32x8d_fpn_ms-poly-3x_coco.py', # noqa - finesize=256), + finesize=256, + box_num_upbound=5), dict( type='Resize', keys=['img', 'cropped_img'], diff --git a/configs/inst_colorization/metafile.yml b/configs/inst_colorization/metafile.yml index 54bf9ccebc..c13dabfb11 100644 --- a/configs/inst_colorization/metafile.yml +++ b/configs/inst_colorization/metafile.yml @@ -6,4 +6,14 @@ Collections: Paper: - https://openaccess.thecvf.com/content_CVPR_2020/html/Su_Instance-Aware_Image_Colorization_CVPR_2020_paper.html README: configs/inst_colorization/README.md -Models: [] +Models: +- Config: configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py + In Collection: Instance-aware Image Colorization + Metadata: + Training Data: Others + Name: inst-colorizatioon_full_official_cocostuff-256x256 + Results: + - Dataset: Others + Metrics: {} + Task: Inst_colorization + Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth diff --git a/mmedit/apis/colorization_inference.py b/mmedit/apis/colorization_inference.py index f650dee925..ddef7ef587 100644 --- a/mmedit/apis/colorization_inference.py +++ b/mmedit/apis/colorization_inference.py @@ -28,20 +28,21 @@ def colorization_inference(model, img): data['data_samples'] = [_data['data_samples']] if 'cuda' in str(device): data = scatter(data, [device])[0] - data['data_samples'][0].cropped_img.data = scatter( - data['data_samples'][0].cropped_img.data, [device])[0] / 255.0 + if not data['data_samples'][0].empty_box: + data['data_samples'][0].cropped_img.data = scatter( + data['data_samples'][0].cropped_img.data, [device])[0] / 255.0 - data['data_samples'][0].box_info.data = scatter( - data['data_samples'][0].box_info.data, [device])[0] + data['data_samples'][0].box_info.data = scatter( + data['data_samples'][0].box_info.data, [device])[0] - data['data_samples'][0].box_info_2x.data = scatter( - data['data_samples'][0].box_info_2x.data, [device])[0] + data['data_samples'][0].box_info_2x.data = scatter( + data['data_samples'][0].box_info_2x.data, [device])[0] - data['data_samples'][0].box_info_4x.data = scatter( - data['data_samples'][0].box_info_4x.data, [device])[0] + data['data_samples'][0].box_info_4x.data = scatter( + data['data_samples'][0].box_info_4x.data, [device])[0] - data['data_samples'][0].box_info_8x.data = scatter( - data['data_samples'][0].box_info_8x.data, [device])[0] + data['data_samples'][0].box_info_8x.data = scatter( + data['data_samples'][0].box_info_8x.data, [device])[0] # forward the model with torch.no_grad(): diff --git a/mmedit/datasets/transforms/aug_shape.py b/mmedit/datasets/transforms/aug_shape.py index 61fa206d2e..fb50fe1134 100644 --- a/mmedit/datasets/transforms/aug_shape.py +++ b/mmedit/datasets/transforms/aug_shape.py @@ -367,8 +367,9 @@ def transform(self, results: Dict) -> Dict: self.scale = (new_w, new_h) for key, out_key in zip(self.keys, self.output_keys): - size, results[out_key] = self._resize(results[key]) - results[f'{out_key}_shape'] = size + if key in results: + size, results[out_key] = self._resize(results[key]) + results[f'{out_key}_shape'] = size results['scale_factor'] = self.scale_factor results['keep_ratio'] = self.keep_ratio diff --git a/mmedit/datasets/transforms/crop.py b/mmedit/datasets/transforms/crop.py index d9d6966c88..2bf8714739 100644 --- a/mmedit/datasets/transforms/crop.py +++ b/mmedit/datasets/transforms/crop.py @@ -6,10 +6,10 @@ import mmcv import numpy as np import torch -from mmcv.ops import RoIPool -from mmcv.transforms import BaseTransform, Compose -from mmdet.utils import register_all_modules -from mmengine.hub import get_config, get_model +from mmcv.transforms import BaseTransform +from mmdet.apis import inference_detector, init_detector +from mmengine.hub import get_config +from mmengine.registry import DefaultScope from mmengine.utils import is_list_of, is_tuple_of from torch.nn.modules.utils import _pair @@ -943,36 +943,14 @@ def __init__(self, box_num_upbound=-1, finesize=256): - self.predictor = self.set_model(config_file) - self.pipeline = self.set_pipeline(config_file) + cfg = get_config(config_file, pretrained=True) + with DefaultScope.overwrite_default_scope('mmdet'): + self.predictor = init_detector(cfg, cfg.model_path) self.key = key self.box_num_upbound = box_num_upbound self.final_size = finesize - def set_model(self, config_file): - model = get_model(config_file, pretrained=True) - if model.data_preprocessor.device.type == 'cpu': - for m in model.modules(): - assert not isinstance( - m, RoIPool - ), 'CPU inference with RoIPool is not supported currently.' - return model - - def set_pipeline(self, config_file): - register_all_modules() - cfg = get_config(config_file) - test_pipeline = cfg.test_dataloader.dataset.pipeline - test_pipeline[0].type = 'mmdet.LoadImageFromNDArray' - new_test_pipeline = [] - for pipeline in test_pipeline: - if pipeline['type'] != 'LoadAnnotations' and pipeline[ - 'type'] != 'LoadPanopticAnnotations': - new_test_pipeline.append(pipeline) - - test_pipeline = Compose(new_test_pipeline) - return test_pipeline - def transform(self, results: dict) -> dict: """The transform function of InstanceCrop. @@ -987,13 +965,8 @@ def transform(self, results: dict) -> dict: # get consistent box prediction based on L channel full_img = results['img'] full_img_size = results['ori_img_shape'][:-1][::-1] - outputs = self.predict_bbox(full_img) + pred_bbox, pred_scores = self.predict_bbox(full_img) - # get the most confident boxes - pred_bbox = outputs['instances'].pred_boxes.to( - torch.device('cpu')).tensor.numpy() - pred_scores = outputs['instances'].scores.cpu().data.numpy() - pred_bbox = pred_bbox.astype(np.int32) if self.box_num_upbound > 0 and pred_bbox.shape[ 0] > self.box_num_upbound: index_mask = np.argsort(pred_scores, axis=0) @@ -1039,11 +1012,16 @@ def transform(self, results: dict) -> dict: def predict_bbox(self, image): lab_image = cv.cvtColor(image, cv.COLOR_BGR2LAB) - l_channel, a_channel, b_channel = cv.split(lab_image) + l_channel, _, _ = cv.split(lab_image) l_stack = np.stack([l_channel, l_channel, l_channel], axis=2) - data_ = dict(img=l_stack, img_id=0) - data_ = self.pipeline(data_) - data_['inputs'] = [data_['inputs']] - data_['data_samples'] = [data_['data_samples']] - results = self.predictor.test_step(data_)[0] - return results + + with DefaultScope.overwrite_default_scope('mmdet'): + with torch.no_grad(): + results = inference_detector(self.predictor, l_stack) + + bboxes = results.pred_instances.bboxes.cpu().numpy().astype(np.int32) + scores = results.pred_instances.scores.cpu().numpy() + index_mask = [i for i, x in enumerate(scores) if x >= 0.7] + scores = np.array(scores[index_mask]) + bboxes = np.array(bboxes[index_mask]) + return bboxes, scores diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py index 29734b1689..4c63aac225 100644 --- a/mmedit/models/editors/inst_colorization/inst_colorization.py +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -197,12 +197,6 @@ def forward_tensor(self, inputs, data_samples): 'fusion model supports only one image due to different numbers '\ 'of instances of different images' - cropped_img = data_samples[0].cropped_img.data - box_info_list = [ - data_samples[0].box_info, data_samples[0].box_info_2x, - data_samples[0].box_info_4x, data_samples[0].box_info_8x - ] - cropped_data = get_colorization_data(cropped_img, self.color_data_opt) full_img_data = get_colorization_data(inputs, self.color_data_opt) AtoB = self.which_direction == 'AtoB' @@ -213,6 +207,14 @@ def forward_tensor(self, inputs, data_samples): if not data_samples[0].empty_box: # preprocess instance input + cropped_img = data_samples[0].cropped_img.data + box_info_list = [ + data_samples[0].box_info, data_samples[0].box_info_2x, + data_samples[0].box_info_4x, data_samples[0].box_info_8x + ] + cropped_data = get_colorization_data(cropped_img, + self.color_data_opt) + real_A = cropped_data['A' if AtoB else 'B'] hint_B = cropped_data['hint_B'] mask_B = cropped_data['mask_B'] From db0fed74fc26e24766423f5f2a287ed3ff7ed203 Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Wed, 2 Nov 2022 22:15:48 +0800 Subject: [PATCH 25/32] fix ut --- ...zatioon_full_official_cocostuff-256x256.py | 6 +- mmedit/datasets/transforms/crop.py | 1 + .../editors/inst_colorization/color_utils.py | 2 +- .../inst_colorization/colorization_net.py | 2 +- .../editors/inst_colorization/fusion_net.py | 2 +- .../editors/inst_colorization/weight_layer.py | 2 +- mmedit/models/losses/__init__.py | 42 ++++++++--- mmedit/models/losses/huber_loss.py | 22 ------ .../test_apis/test_colorization_inference.py | 11 +-- .../test_transforms/test_crop.py | 29 +++++++- .../test_transforms/test_get_maskrcnn_bbox.py | 72 ------------------- .../test_losses/test_huber_loss.py | 5 -- 12 files changed, 73 insertions(+), 123 deletions(-) delete mode 100644 mmedit/models/losses/huber_loss.py delete mode 100644 tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py delete mode 100644 tests/test_models/test_losses/test_huber_loss.py diff --git a/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py b/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py index 76ccf2bcaa..952bc74cda 100644 --- a/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py +++ b/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py @@ -1,8 +1,8 @@ _base_ = ['../_base_/default_runtime.py'] -exp_name = 'inst-colorization_cocostuff_256x256' -save_dir = './' -work_dir = '..' +experiment_name = 'inst-colorization_full_official_cocostuff_256x256' +work_dir = f'./work_dirs/{experiment_name}' +save_dir = './work_dirs/' stage = 'full' diff --git a/mmedit/datasets/transforms/crop.py b/mmedit/datasets/transforms/crop.py index 2bf8714739..8554b72018 100644 --- a/mmedit/datasets/transforms/crop.py +++ b/mmedit/datasets/transforms/crop.py @@ -963,6 +963,7 @@ def transform(self, results: dict) -> dict: and information. """ # get consistent box prediction based on L channel + full_img = results['img'] full_img_size = results['ori_img_shape'][:-1][::-1] pred_bbox, pred_scores = self.predict_bbox(full_img) diff --git a/mmedit/models/editors/inst_colorization/color_utils.py b/mmedit/models/editors/inst_colorization/color_utils.py index 78b536755f..6ecc57b72f 100644 --- a/mmedit/models/editors/inst_colorization/color_utils.py +++ b/mmedit/models/editors/inst_colorization/color_utils.py @@ -4,7 +4,7 @@ def xyz2rgb(xyz): - """Conversion images from lab to xyz. + """Conversion images from xyz to rgb. Args: xyz (tensor): The images to be conversion diff --git a/mmedit/models/editors/inst_colorization/colorization_net.py b/mmedit/models/editors/inst_colorization/colorization_net.py index 85d2d03039..6d62209e07 100644 --- a/mmedit/models/editors/inst_colorization/colorization_net.py +++ b/mmedit/models/editors/inst_colorization/colorization_net.py @@ -33,7 +33,7 @@ def __init__(self, norm_type, use_tanh=True, classification=True): - super(ColorizationNet, self).__init__() + super().__init__() self.input_nc = input_nc self.output_nc = output_nc self.classification = classification diff --git a/mmedit/models/editors/inst_colorization/fusion_net.py b/mmedit/models/editors/inst_colorization/fusion_net.py index 6996cfe448..10c5732680 100644 --- a/mmedit/models/editors/inst_colorization/fusion_net.py +++ b/mmedit/models/editors/inst_colorization/fusion_net.py @@ -32,7 +32,7 @@ def __init__(self, norm_type, use_tanh=True, classification=True): - super(FusionNet, self).__init__() + super().__init__() self.input_nc = input_nc self.output_nc = output_nc self.classification = classification diff --git a/mmedit/models/editors/inst_colorization/weight_layer.py b/mmedit/models/editors/inst_colorization/weight_layer.py index 76f3c88243..c2f05b34f0 100644 --- a/mmedit/models/editors/inst_colorization/weight_layer.py +++ b/mmedit/models/editors/inst_colorization/weight_layer.py @@ -43,7 +43,7 @@ class WeightLayer(BaseModule): """ def __init__(self, input_ch, inner_ch=16): - super(WeightLayer, self).__init__() + super().__init__() self.simple_instance_conv = nn.Sequential( nn.Conv2d(input_ch, inner_ch, kernel_size=3, stride=1, padding=1), nn.ReLU(True), diff --git a/mmedit/models/losses/__init__.py b/mmedit/models/losses/__init__.py index 4e388acc47..df66126d39 100644 --- a/mmedit/models/losses/__init__.py +++ b/mmedit/models/losses/__init__.py @@ -9,7 +9,6 @@ gen_path_regularizer, gradient_penalty_loss, r1_gradient_penalty_loss) from .gradient_loss import GradientLoss -from .huber_loss import HuberLoss from .loss_comps import (CLIPLossComps, DiscShiftLossComps, FaceIdLossComps, GANLossComps, GeneratorPathRegularizerComps, GradientPenaltyLossComps, R1GradientPenaltyComps) @@ -19,14 +18,35 @@ from .pixelwise_loss import CharbonnierLoss, L1Loss, MaskedTVLoss, MSELoss __all__ = [ - 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'L1CompositionLoss', - 'MSECompositionLoss', 'CharbonnierCompLoss', 'GANLoss', 'GaussianBlur', - 'GradientPenaltyLoss', 'PerceptualLoss', 'PerceptualVGG', 'reduce_loss', - 'mask_reduce_loss', 'DiscShiftLoss', 'MaskedTVLoss', 'GradientLoss', - 'TransferalPerceptualLoss', 'LightCNNFeatureLoss', 'gradient_penalty_loss', - 'r1_gradient_penalty_loss', 'gen_path_regularizer', 'FaceIdLoss', - 'CLIPLoss', 'CLIPLossComps', 'DiscShiftLossComps', 'FaceIdLossComps', - 'GANLossComps', 'GeneratorPathRegularizerComps', - 'GradientPenaltyLossComps', 'R1GradientPenaltyComps', 'disc_shift_loss', - 'HuberLoss' + 'L1Loss', + 'MSELoss', + 'CharbonnierLoss', + 'L1CompositionLoss', + 'MSECompositionLoss', + 'CharbonnierCompLoss', + 'GANLoss', + 'GaussianBlur', + 'GradientPenaltyLoss', + 'PerceptualLoss', + 'PerceptualVGG', + 'reduce_loss', + 'mask_reduce_loss', + 'DiscShiftLoss', + 'MaskedTVLoss', + 'GradientLoss', + 'TransferalPerceptualLoss', + 'LightCNNFeatureLoss', + 'gradient_penalty_loss', + 'r1_gradient_penalty_loss', + 'gen_path_regularizer', + 'FaceIdLoss', + 'CLIPLoss', + 'CLIPLossComps', + 'DiscShiftLossComps', + 'FaceIdLossComps', + 'GANLossComps', + 'GeneratorPathRegularizerComps', + 'GradientPenaltyLossComps', + 'R1GradientPenaltyComps', + 'disc_shift_loss', ] diff --git a/mmedit/models/losses/huber_loss.py b/mmedit/models/losses/huber_loss.py deleted file mode 100644 index 187a47bef9..0000000000 --- a/mmedit/models/losses/huber_loss.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -import torch.nn as nn - -from mmedit.registry import LOSSES - - -@LOSSES.register_module() -class HuberLoss(nn.Module): - - def __init__(self, delta=.01): - super(HuberLoss, self).__init__() - self.delta = delta - - def forward(self, in0, in1): - mask = torch.zeros_like(in0) - mann = torch.abs(in0 - in1) - eucl = .5 * (mann**2) - mask[...] = mann < self.delta - - loss = eucl * mask / self.delta + (mann - .5 * self.delta) * (1 - mask) - return torch.sum(loss, dim=1, keepdim=True) diff --git a/tests/test_apis/test_colorization_inference.py b/tests/test_apis/test_colorization_inference.py index ad05a1d26b..8772a31fc2 100644 --- a/tests/test_apis/test_colorization_inference.py +++ b/tests/test_apis/test_colorization_inference.py @@ -18,11 +18,11 @@ def test_colorization_inference(): else: device = torch.device('cpu') - data_root = '../../' config = osp.join( - data_root, - 'configs/inst_colorization/inst-colorizatioon_cocostuff_256x256.py') - + osp.dirname(__file__), + '../..', + 'configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py' # noqa + ) checkpoint = None cfg = Config.fromfile(config) @@ -35,7 +35,8 @@ def test_colorization_inference(): model.to(device) model.eval() - img_path = '../data/image/gray/test.jpg' + img_path = osp.join( + osp.dirname(__file__), '..', 'data/image/img_root/horse/horse.jpeg') result = colorization_inference(model, img_path) assert tensor2img(result)[..., ::-1].shape == (256, 256, 3) diff --git a/tests/test_datasets/test_transforms/test_crop.py b/tests/test_datasets/test_transforms/test_crop.py index b755e264ff..c763a679cc 100644 --- a/tests/test_datasets/test_transforms/test_crop.py +++ b/tests/test_datasets/test_transforms/test_crop.py @@ -1,10 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import os.path as osp +import cv2 import numpy as np import pytest -from mmedit.datasets.transforms import (Crop, CropLike, FixedCrop, ModCrop, +from mmedit.datasets.transforms import (Crop, CropLike, FixedCrop, + InstanceCrop, ModCrop, PairedRandomCrop, RandomResizedCrop) @@ -350,3 +353,27 @@ def test_crop_like(): assert results['gt'].shape == (512, 512) sum_diff = np.sum(abs(results['gt'][:480, :512] - img[:480, :512, 0])) assert sum_diff < 1e-6 + + +def test_instance_crop(): + + croper = InstanceCrop( + key='img', + finesize=256, + box_num_upbound=2, + config_file='mmdet::mask_rcnn/' + 'mask-rcnn_x101-32x8d_fpn_ms-poly-3x_coco.py') # noqa + + img_path = osp.join( + osp.dirname(__file__), '..', '..', + 'data/image/img_root/horse/horse.jpeg') + img = cv2.imread(img_path) + data = dict(img=img, ori_img_shape=img.shape, img_channel_order='rgb') + + results = croper(data) + + assert 'empty_box' in results + if results['empty_box']: + cropped_img = results['cropped_img'] + assert len(cropped_img) == 0 + assert len(cropped_img) <= 2 diff --git a/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py b/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py deleted file mode 100644 index e7c35951aa..0000000000 --- a/tests/test_datasets/test_transforms/test_get_maskrcnn_bbox.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os - -import cv2 as cv - -from mmedit.datasets.transforms import InstanceCrop -from mmedit.utils import tensor2img - - -class TestMaskRCNNBbox: - - DEFAULT_ARGS = dict(key='img', finesize=256) - - def test_maskrcnn_bbox(self): - detectetor = InstanceCrop(**self.DEFAULT_ARGS, stage='test') - data_root = '..' - img_path = 'data/image/gray/test.jpg' - img = cv.imread(os.path.join(data_root, img_path)) - - data = dict(img=img) - - results = detectetor(data) - pred_bbox = results.pred_bbox - - assert len(pred_bbox) <= 8 - assert results['full_gray'] and results['box_info'] \ - and results['cropped_gray'] - - detectetor.stage = 'fusion' - results = detectetor(data) - index = len(results.pred_bbox) - assert results['full_rgb'] and results['cropped_rgb'] - assert results['cropped_gray_list'].shape == (index, 3, 256, 256) - - detectetor.stage = 'full' - results = detectetor(data) - assert results['rgb_img'] and results['gray_img'] - assert tensor2img(results['rgb_img']).shape == (3, 256, 256) - - def test_gen_maskrcnn_from_pred(self): - detectetor = InstanceCrop(**self.DEFAULT_ARGS, stage='test') - data_root = '..' - img_path = 'data/image/gray/test.jpg' - img = cv.imread(os.path.join(data_root, img_path)) - - box_num_upbound = 4 - pred_bbox = detectetor.gen_maskrcnn_bbox_fromPred(img) - - assert len(pred_bbox) <= box_num_upbound - assert pred_bbox.shape[-1] == 4 - - def test_get_box_info(self): - detectetor = InstanceCrop(**self.DEFAULT_ARGS, stage='test') - data_root = '..' - img_path = 'data/image/gray/test.jpg' - img = cv.imread(os.path.join(data_root, img_path)) - - pred_bbox = detectetor.gen_maskrcnn_bbox_fromPred(img) - - resize_startx = int(pred_bbox[0] / img.shape[0] * 256) - resize_starty = int(pred_bbox[1] / img.shape[1] * 256) - resize_endx = int(pred_bbox[2] / img.shape[0] * 256) - resize_endy = int(pred_bbox[3] / img.shape[1] * 256) - - box_info = detectetor.get_box_info(pred_bbox, img.shape) - - assert box_info[0] == resize_starty and \ - box_info[1] == 256 - resize_endx and \ - box_info[2] == resize_starty and \ - box_info[3] == 256 - resize_endy and \ - box_info[4] == resize_endx - resize_startx and \ - box_info[5] == resize_endy - resize_starty diff --git a/tests/test_models/test_losses/test_huber_loss.py b/tests/test_models/test_losses/test_huber_loss.py deleted file mode 100644 index 675963331e..0000000000 --- a/tests/test_models/test_losses/test_huber_loss.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - - -def test_huber_loss(): - pass From 5e4b9d7291ae0e58b5413589df6ebc068eb4dfa2 Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Wed, 2 Nov 2022 22:28:06 +0800 Subject: [PATCH 26/32] update workflow --- .circleci/test.yml | 5 ++++- .github/workflows/merge_stage_test.yml | 5 +++++ .github/workflows/pr_stage_test.yml | 4 ++++ requirements/optional.txt | 1 + 4 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 requirements/optional.txt diff --git a/.circleci/test.yml b/.circleci/test.yml index d1a3489802..11e7189571 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -62,6 +62,7 @@ jobs: pip install git+https://github.com/open-mmlab/mmengine.git@main pip install -U openmim mim install 'mmcv >= 2.0.0rc1' + pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x pip install -r requirements/tests.txt - run: name: Build and install @@ -94,17 +95,19 @@ jobs: name: Clone Repos command: | git clone -b main --depth 1 https://github.com/open-mmlab/mmengine.git /home/circleci/mmengine + git clone -b dev-3.x --depth 1 https://github.com/open-mmlab/mmdetection.git /home/circleci/mmdetection - run: name: Build Docker image command: | docker build .circleci/docker -t mmedit:gpu --build-arg PYTORCH=<< parameters.torch >> --build-arg CUDA=<< parameters.cuda >> --build-arg CUDNN=<< parameters.cudnn >> - docker run --gpus all -t -d -v /home/circleci/project:/mmedit -v /home/circleci/mmengine:/mmengine -w /mmedit --name mmedit mmedit:gpu + docker run --gpus all -t -d -v /home/circleci/project:/mmedit -v /home/circleci/mmengine:/mmengine -v /home/circleci/mmdetection:/mmdetection -w /mmedit --name mmedit mmedit:gpu - run: name: Install mmedit dependencies command: | docker exec mmedit pip install -e /mmengine docker exec mmedit pip install -U openmim docker exec mmedit mim install 'mmcv >= 2.0.0rc1' + docker exec mmedit pip install -e /mmdetection docker exec mmedit pip install -r requirements/tests.txt - run: name: Build and install diff --git a/.github/workflows/merge_stage_test.yml b/.github/workflows/merge_stage_test.yml index e22dcb528a..be938eb1dd 100644 --- a/.github/workflows/merge_stage_test.yml +++ b/.github/workflows/merge_stage_test.yml @@ -45,6 +45,8 @@ jobs: run: | pip install -U openmim mim install 'mmcv >= 2.0.0rc1' + - name: Install MMDet + run: pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x - name: Install other dependencies run: pip install -r requirements/tests.txt - name: Build and install @@ -92,6 +94,7 @@ jobs: run: | pip install -U openmim mim install 'mmcv >= 2.0.0rc1' + pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x - name: Install other dependencies run: pip install -r requirements/tests.txt - name: Build and install @@ -145,6 +148,7 @@ jobs: - name: Install mmediting dependencies run: | pip install git+https://github.com/open-mmlab/mmengine.git@main + pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x pip install -U openmim mim install 'mmcv >= 2.0.0rc1' pip install -r requirements/tests.txt @@ -175,6 +179,7 @@ jobs: - name: Install mmediting dependencies run: | python -m pip install git+https://github.com/open-mmlab/mmengine.git@main + python -m pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x python -m pip install -U openmim mim install 'mmcv >= 2.0.0rc1' python -m pip install -r requirements/tests.txt diff --git a/.github/workflows/pr_stage_test.yml b/.github/workflows/pr_stage_test.yml index b55f810642..f7a4db6aa5 100644 --- a/.github/workflows/pr_stage_test.yml +++ b/.github/workflows/pr_stage_test.yml @@ -39,6 +39,8 @@ jobs: run: | pip install -U openmim mim install 'mmcv >= 2.0.0rc1' + - name: Install MMDet + run: pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x - name: Install other dependencies run: pip install -r requirements/tests.txt - name: Build and install @@ -92,6 +94,7 @@ jobs: - name: Install mmedit dependencies run: | pip install git+https://github.com/open-mmlab/mmengine.git@main + pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x pip install -U openmim mim install 'mmcv >= 2.0.0rc1' pip install -r requirements/tests.txt @@ -125,6 +128,7 @@ jobs: - name: Install mmedit dependencies run: | python -m pip install git+https://github.com/open-mmlab/mmengine.git@main + python -m pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x python -m pip install -U openmim mim install 'mmcv >= 2.0.0rc1' python -m pip install -r requirements/tests.txt diff --git a/requirements/optional.txt b/requirements/optional.txt new file mode 100644 index 0000000000..95688066cc --- /dev/null +++ b/requirements/optional.txt @@ -0,0 +1 @@ +mmdet From e1caef1560c66b420a269844da22b0aff2f937ab Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Wed, 2 Nov 2022 23:12:29 +0800 Subject: [PATCH 27/32] update changelog --- README.md | 43 +++++++++++++++++++------------ README_zh-CN.md | 49 ++++++++++++++++++++++-------------- docs/en/3_model_zoo.md | 29 +++++++++++++-------- docs/en/notes/3_changelog.md | 28 +++++++++++++++++++++ mmedit/version.py | 2 +- 5 files changed, 105 insertions(+), 46 deletions(-) diff --git a/README.md b/README.md index f79f5b4551..9ace05994b 100644 --- a/README.md +++ b/README.md @@ -104,22 +104,26 @@ hope MMEditing could provide better experience. ## What's New -- \[2022-09-13\] 🎉[MMGeneration](https://github.com/open-mmlab/mmgeneration/tree/1.x) was merged into MMEditing! And we are calling for your [suggestion](https://github.com/open-mmlab/mmediting/discussions/1108)! -- \[2022-08-31\] v1.0.0rc0 was released. This release introduced a brand new and flexible training & test engine, but it's still in progress. Welcome - to try according to [the documentation](https://mmediting.readthedocs.io/en/1.x/). -- \[2022-06-01\] v0.15.0 was released. - - Support FLAVR - - Support AOT-GAN - - Support CAIN with ReduceLROnPlateau Scheduler -- \[2022-04-01\] v0.14.0 was released. - - Support TOFlow in video frame interpolation -- \[2022-03-01\] v0.13.0 was released. - - Support CAIN - - Support EDVR-L - - Support running in Windows -- \[2022-02-11\] Switch to **PyTorch 1.5+**. The compatibility to earlier versions of PyTorch will no longer be guaranteed. - -Please refer to [changelog.md](docs/en/notes/3_changelog.md) for details and release history. +### 🌟 Preview of 1.x version + +A brand new version of [**MMEditing v1.0.0rc2**](https://github.com/open-mmlab/mmediting/releases/tag/v1.0.0rc2) was released in 02/11/2022: + +- Support all the tasks, models, metrics, and losses in [MMGeneration](https://github.com/open-mmlab/mmgeneration) 😍。 +- Unifies interfaces of all components based on [MMEngine](https://github.com/open-mmlab/mmengine). +- Support patch-based and slider-based image and video comparison viewer. +- Support image colorization. + +Find more new features in [1.x branch](https://github.com/open-mmlab/mmediting/tree/1.x). Issues and PRs are welcome! + +### 💎 Stable version + +**0.16.0** was released in 31/10/2022: + +- `VisualizationHook` is deprecated. Users should use `MMEditVisualizationHook` instead. +- Fix FLAVR register. +- Fix the number of channels in RDB. + +Please refer to [changelog.md](docs/en/changelog.md) for details and release history. ## Installation @@ -215,6 +219,13 @@ Supported algorithms: +
+Image Colorization + +- ✅ [InstColorization](configs/inst_colorization/README.md) (CVPR'2020) + +
+
Unconditional GANs diff --git a/README_zh-CN.md b/README_zh-CN.md index 7f820b5456..64fca3222b 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -37,7 +37,7 @@ [English](/README.md) | 简体中文 -## Introduction +## 介绍 MMEditing 是基于 PyTorch 的图像&视频编辑和生成开源工具箱。是 [OpenMMLab](https://openmmlab.com/) 项目的成员之一。 @@ -101,24 +101,28 @@ https://user-images.githubusercontent.com/12756472/158972813-d8d0f19c-f49c-4618- 需要注意的是 **MMSR** 已作为 MMEditing 的一部分并入本仓库。 MMEditing 缜密地设计新的框架并将其精心实现,希望能够为您带来更好的体验。 -## 最新消息 - -- \[2022-09-13\] 🎉 [MMGeneration](<(https://github.com/open-mmlab/mmgeneration/tree/1.x)>) 合入 MMEditing! 对于该合入计划,我们期待您的 [建议](https://github.com/open-mmlab/mmediting/discussions/1108)! -- \[2022-08-31\] v1.0.0rc0 版本发布 - 这个版本引入一个全新的,可扩展性强的训练和测试引擎,但目前仍在开发中。欢迎根据[文档](https://mmediting.readthedocs.io/en/1.x/)进行试用。 -- \[2022-06-01\] v0.15.0 版本发布 - - 支持 FLAVR - - 支持 AOT-GAN - - 新版 CAIN,支持 ReduceLROnPlateau 策略 -- \[2022-04-01\] v0.14.0 版本发布 - - 支持视频插帧算法 TOFlow -- \[2022-03-01\] v0.13.0 版本发布 - - 支持 CAIN - - 支持 EDVR-L - - 支持在 Windows 系统中运行 -- \[2022-02-11\] 切换到 **PyTorch 1.5+**. 将不再保证与早期版本的 PyTorch 的兼容性 - -请查看 [changelog.md](docs/zh_cn/notes/3_changelog.md) 以获取更多细节与发版记录 +## 最新进展 + +### 🌟 1.x 预览版本 + +全新的 [**MMEditing v1.0.0rc2**](https://github.com/open-mmlab/mmediting/releases/tag/v1.0.0rc2) 已经在 02/11/2022 发布: + +- 支持[MMGeneration](https://github.com/open-mmlab/mmgeneration)中的全量任务、模型、优化函数和评价指标 😍。 +- 基于[MMEngine](https://github.com/open-mmlab/mmengine)统一了各组件接口。 +- 支持基于图像子块以及滑动条的图像和视频比较可视化工具。 +- 支持图像上色任务。 + +在[1.x 分支](https://github.com/open-mmlab/mmediting/tree/1.x)中发现更多特性!欢迎提 Issues 和 PRs! + +### 💎 稳定版本 + +最新的 **0.16.0** 版本已经在 31/10/2022 发布: + +- `VisualizationHook` 将被启用,建议用户使用 `MMEditVisualizationHook`。 +- 修复 FLAVR 的注册问题。 +- 修正 RDB 模型中的通道数。 + +如果像了解更多版本更新细节和历史信息,请阅读[更新日志](docs/en/changelog.md)。 ## 安装 @@ -213,6 +217,13 @@ pip3 install -e .
+
+图像上色 + +- ✅ [InstColorization](configs/inst_colorization/README.md) (CVPR'2020) + +
+
Unconditional GANs diff --git a/docs/en/3_model_zoo.md b/docs/en/3_model_zoo.md index 7f2d8716b6..af3547569a 100644 --- a/docs/en/3_model_zoo.md +++ b/docs/en/3_model_zoo.md @@ -1,19 +1,20 @@ # Overview -- Number of checkpoints: 168 -- Number of configs: 168 -- Number of papers: 41 - - ALGORITHM: 42 +- Number of checkpoints: 169 +- Number of configs: 169 +- Number of papers: 42 + - ALGORITHM: 43 - Tasks: - - unconditional gans - image2image translation - - conditional gans - - matting + - video interpolation + - unconditional gans - image super-resolution - - video super-resolution - - inpainting - internal learning - - video interpolation + - conditional gans + - inpainting + - video super-resolution + - colorization + - matting For supported datasets, see [datasets overview](dataset_zoo/0_overview.md). @@ -185,6 +186,14 @@ For supported datasets, see [datasets overview](dataset_zoo/0_overview.md). - Number of papers: 1 - \[ALGORITHM\] Indices Matter: Learning to Index for Deep Image Matting ([⇨](https://github.com/open-mmlab/mmediting/blob/1.x/configs/indexnet/README.md#citation)) +## Instance-aware Image Colorization (CVPR'2020) + +- Tasks: colorization +- Number of checkpoints: 1 +- Number of configs: 1 +- Number of papers: 1 + - \[ALGORITHM\] Instance-Aware Image Colorization ([⇨](https://github.com/open-mmlab/mmediting/blob/1.x/configs/inst_colorization/README.md#quick-start)) + ## LIIF (CVPR'2021) - Tasks: image super-resolution diff --git a/docs/en/notes/3_changelog.md b/docs/en/notes/3_changelog.md index be177b9f25..258664940d 100644 --- a/docs/en/notes/3_changelog.md +++ b/docs/en/notes/3_changelog.md @@ -1,5 +1,33 @@ # Changelog +## v1.0.0rc2 (02/11/2022) + +**Highlights** +We are excited to announce the release of PyTorch 1.11. This release is composed of over 3,300 commits since 1.10, made by 434 contributors. Along with 1.11, we are releasing beta versions of TorchData and functorch. We want to sincerely thank our community for continuously improving PyTorch. + +- TorchData is a new library for common modular data loading primitives for easily constructing flexible and performant data pipelines. View it on GitHub. +- functorch, a library that adds composable function transforms to PyTorch, is now available in beta. View it on GitHub. +- Distributed Data Parallel (DDP) static graph optimizations available in stable. + +You can check the blogpost that shows the new features here. + +**New Features & Improvements** + +- Improve arguments type in `preprocess_div2k_dataset.py`. (#1381) +- Update docstring of RDN. (#1326) +- Update the introduction in readme. (#) + +**Bug Fixes** + +- Fix FLAVR register in `mmedit/models/video_interpolators` when importing `FLAVR`. (#1186) +- Fix data path processing in `restoration_video_inference.py`. (#1262) +- Fix the number of channels in RDB. (#1292, #1311) + +**Contributors** + +A total of 5 developers contributed to this release. +Thanks @LeoXing1996, @Z-Fran, @zengyh1900, @ryanxingql, @ruoningYu. + ## v1.0.0rc1(23/9/2022) MMEditing 1.0.0rc1 has merged MMGeneration 1.x. diff --git a/mmedit/version.py b/mmedit/version.py index b367889550..da34a5e738 100644 --- a/mmedit/version.py +++ b/mmedit/version.py @@ -1,6 +1,6 @@ # Copyright (c) Open-MMLab. All rights reserved. -__version__ = '1.0.0rc1' +__version__ = '1.0.0rc2' def parse_version_info(version_str): From 2a490a05d3117ebc7374a5f9b0c66d33c9f9fa90 Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Thu, 3 Nov 2022 10:00:16 +0800 Subject: [PATCH 28/32] update changelog --- docs/en/notes/3_changelog.md | 46 +++++++++++++------ .../test_inst_colorization.py | 6 +++ .../test_weight_layer.py | 6 +++ .../test_editors/test_liif/test_liif_net.py | 6 +++ 4 files changed, 51 insertions(+), 13 deletions(-) diff --git a/docs/en/notes/3_changelog.md b/docs/en/notes/3_changelog.md index 258664940d..189f8ed46a 100644 --- a/docs/en/notes/3_changelog.md +++ b/docs/en/notes/3_changelog.md @@ -3,30 +3,50 @@ ## v1.0.0rc2 (02/11/2022) **Highlights** -We are excited to announce the release of PyTorch 1.11. This release is composed of over 3,300 commits since 1.10, made by 434 contributors. Along with 1.11, we are releasing beta versions of TorchData and functorch. We want to sincerely thank our community for continuously improving PyTorch. -- TorchData is a new library for common modular data loading primitives for easily constructing flexible and performant data pipelines. View it on GitHub. -- functorch, a library that adds composable function transforms to PyTorch, is now available in beta. View it on GitHub. -- Distributed Data Parallel (DDP) static graph optimizations available in stable. +We are excited to announce the release of MMEditing 1.0.0rc2. This release supports 43+ models, 170+ configs and 169+ checkpoints in MMGeneration and MMEditing. We highlight the following new features -You can check the blogpost that shows the new features here. +- patch-based and slider-based image and video comparison viewer. +- image colorization. + +We want to sincerely thank our community for continuously improving MMEditing. **New Features & Improvements** -- Improve arguments type in `preprocess_div2k_dataset.py`. (#1381) -- Update docstring of RDN. (#1326) -- Update the introduction in readme. (#) +- Support qualitative comparison tools. (#1303) +- Support instance aware colorization. (#1370) +- Support multi-metrics with different sample-model. (#1171) +- Improve the implementation + - refactoring evaluation metrics. (#1164) + - Save gt images in PGGAN's `forward`. (#1332) + - Improve type and change default number of `preprocess_div2k_dataset.py`. (#1380) + - Support pixel value clip in visualizer. (#1365) + - Support SinGAN Dataset and SinGAN demo. (#1363) + - Avoid cast int and float in GenDataPreprocessor. (#1385) +- Improve the documentation + - Update a menu switcher. (#1162) + - Fix TTSR's README. (#1325) + - Revise docs (change `PackGenInputs` and `GenDataSample`). (#1382) **Bug Fixes** -- Fix FLAVR register in `mmedit/models/video_interpolators` when importing `FLAVR`. (#1186) -- Fix data path processing in `restoration_video_inference.py`. (#1262) -- Fix the number of channels in RDB. (#1292, #1311) +- Fix PPL bug. (#1172) +- Fix RDN number of channels. (#1328) +- Fix types of exceptions in demos. (#1372) +- Fix realesrgan ema. (#1341) +- Improve the assertion to ensuer `GenerateFacialHeatmap` as `np.float32`. (#1310) +- Fix sampling behavior of `unpaired_dataset.py` and urls in cyclegan's README. (#1308) +- Fix vsr models in pytorch2onnx. (#1300) +- Fix incorrect settings in configs. (#1167,#1200,#1236,#1293,#1302,#1304,#1319,#1331,#1336,#1349,#1352,#1353,#1358,#1364,#1367,#1384,#1386,#1391,#1392,#1393) + +**New Contributors** + +- @gaoyang07 made their first contribution in https://github.com/open-mmlab/mmediting/pull/1372 **Contributors** -A total of 5 developers contributed to this release. -Thanks @LeoXing1996, @Z-Fran, @zengyh1900, @ryanxingql, @ruoningYu. +A total of 7 developers contributed to this release. +Thanks @LeoXing1996, @Z-Fran, @zengyh1900, @plyfager, @ryanxingql, @ruoningYu, @gaoyang07. ## v1.0.0rc1(23/9/2022) diff --git a/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py b/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py index 1b77725604..1c810bf8bf 100644 --- a/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py +++ b/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest import torch from mmedit.registry import BACKBONES @@ -6,6 +9,9 @@ from mmedit.utils import register_all_modules +@pytest.mark.skipif( + 'win' in platform.system().lower() and 'cu' in torch.__version__, + reason='skip on windows-cuda due to limited RAM.') class TestInstColorization: def test_inst_colorization(self): diff --git a/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py b/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py index e009af981f..8293eb116e 100644 --- a/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py +++ b/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py @@ -1,9 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest import torch from mmedit.models.editors.inst_colorization.weight_layer import WeightLayer +@pytest.mark.skipif( + 'win' in platform.system().lower() and 'cu' in torch.__version__, + reason='skip on windows-cuda due to limited RAM.') def test_weight_layer(): weight_layer = WeightLayer(64) diff --git a/tests/test_models/test_editors/test_liif/test_liif_net.py b/tests/test_models/test_editors/test_liif/test_liif_net.py index c1642a06f6..ab8f409adb 100644 --- a/tests/test_models/test_editors/test_liif/test_liif_net.py +++ b/tests/test_models/test_editors/test_liif/test_liif_net.py @@ -1,9 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest import torch from mmedit.registry import BACKBONES +@pytest.mark.skipif( + 'win' in platform.system().lower() and 'cu' in torch.__version__, + reason='skip on windows-cuda due to limited RAM.') def test_liif_edsr_net(): model_cfg = dict( From 136c9a339bc8756328427d687aa401e5d8b6036a Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Thu, 3 Nov 2022 11:03:05 +0800 Subject: [PATCH 29/32] enable tmate --- .github/workflows/pr_stage_test.yml | 18 +++++++++--------- docs/en/notes/3_changelog.md | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/pr_stage_test.yml b/.github/workflows/pr_stage_test.yml index f7a4db6aa5..4f0d9d3136 100644 --- a/.github/workflows/pr_stage_test.yml +++ b/.github/workflows/pr_stage_test.yml @@ -59,9 +59,9 @@ jobs: env_vars: OS,PYTHON name: codecov-umbrella fail_ci_if_error: false - # - name: Setup tmate session - # if: ${{ failure() }} - # uses: mxschmitt/action-tmate@v3 + - name: Setup tmate session + if: ${{ failure() }} + uses: mxschmitt/action-tmate@v3 build_cu102: runs-on: ubuntu-18.04 @@ -102,9 +102,9 @@ jobs: run: | python setup.py check -m -s TORCH_CUDA_ARCH_LIST=7.0 pip install -e . - # - name: Setup tmate session - # if: ${{ failure() }} - # uses: mxschmitt/action-tmate@v3 + - name: Setup tmate session + if: ${{ failure() }} + uses: mxschmitt/action-tmate@v3 build_windows: runs-on: ${{ matrix.os }} @@ -138,6 +138,6 @@ jobs: - name: Run unittests and generate coverage report run: | pytest tests/ - # - name: Setup tmate session - # if: ${{ failure() }} - # uses: mxschmitt/action-tmate@v3 + - name: Setup tmate session + if: ${{ failure() }} + uses: mxschmitt/action-tmate@v3 diff --git a/docs/en/notes/3_changelog.md b/docs/en/notes/3_changelog.md index 189f8ed46a..8a618a5788 100644 --- a/docs/en/notes/3_changelog.md +++ b/docs/en/notes/3_changelog.md @@ -46,7 +46,7 @@ We want to sincerely thank our community for continuously improving MMEditing. **Contributors** A total of 7 developers contributed to this release. -Thanks @LeoXing1996, @Z-Fran, @zengyh1900, @plyfager, @ryanxingql, @ruoningYu, @gaoyang07. +Thanks @LeoXing1996, @Z-Fran, @zengyh1900, @plyfager, @ryanxingql, @ruoningYu, @gaoyang07. ## v1.0.0rc1(23/9/2022) From 3e4eeae1a468c03f681b40a339a1a396b63b4d92 Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Thu, 3 Nov 2022 12:12:16 +0800 Subject: [PATCH 30/32] fix ut --- .circleci/test.yml | 5 ++-- .github/workflows/merge_stage_test.yml | 12 ++++----- .github/workflows/pr_stage_test.yml | 27 +++++++++---------- .../test_transforms/test_crop.py | 3 +++ 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/.circleci/test.yml b/.circleci/test.yml index 11e7189571..b6974c6456 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -62,7 +62,7 @@ jobs: pip install git+https://github.com/open-mmlab/mmengine.git@main pip install -U openmim mim install 'mmcv >= 2.0.0rc1' - pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + mim install 'mmdet >= 3.0.0rc2' pip install -r requirements/tests.txt - run: name: Build and install @@ -95,7 +95,6 @@ jobs: name: Clone Repos command: | git clone -b main --depth 1 https://github.com/open-mmlab/mmengine.git /home/circleci/mmengine - git clone -b dev-3.x --depth 1 https://github.com/open-mmlab/mmdetection.git /home/circleci/mmdetection - run: name: Build Docker image command: | @@ -107,7 +106,7 @@ jobs: docker exec mmedit pip install -e /mmengine docker exec mmedit pip install -U openmim docker exec mmedit mim install 'mmcv >= 2.0.0rc1' - docker exec mmedit pip install -e /mmdetection + docker exec mmedit mim install 'mmdet >= 3.0.0rc2' docker exec mmedit pip install -r requirements/tests.txt - run: name: Build and install diff --git a/.github/workflows/merge_stage_test.yml b/.github/workflows/merge_stage_test.yml index be938eb1dd..f5bf7d8c0d 100644 --- a/.github/workflows/merge_stage_test.yml +++ b/.github/workflows/merge_stage_test.yml @@ -41,12 +41,11 @@ jobs: run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html - name: Install MMEngine run: pip install git+https://github.com/open-mmlab/mmengine.git@main - - name: Install MMCV + - name: Install MMCV and MMDet run: | pip install -U openmim mim install 'mmcv >= 2.0.0rc1' - - name: Install MMDet - run: pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + mim install 'mmdet >= 3.0.0rc2' - name: Install other dependencies run: pip install -r requirements/tests.txt - name: Build and install @@ -90,11 +89,11 @@ jobs: run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html - name: Install MMEngine run: pip install git+https://github.com/open-mmlab/mmengine.git@main - - name: Install MMCV + - name: Install MMCV and MMDet run: | pip install -U openmim mim install 'mmcv >= 2.0.0rc1' - pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + mim install 'mmdet >= 3.0.0rc2' - name: Install other dependencies run: pip install -r requirements/tests.txt - name: Build and install @@ -148,7 +147,6 @@ jobs: - name: Install mmediting dependencies run: | pip install git+https://github.com/open-mmlab/mmengine.git@main - pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x pip install -U openmim mim install 'mmcv >= 2.0.0rc1' pip install -r requirements/tests.txt @@ -179,9 +177,9 @@ jobs: - name: Install mmediting dependencies run: | python -m pip install git+https://github.com/open-mmlab/mmengine.git@main - python -m pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x python -m pip install -U openmim mim install 'mmcv >= 2.0.0rc1' + mim install 'mmdet >= 3.0.0rc2' python -m pip install -r requirements/tests.txt - name: Build and install run: | diff --git a/.github/workflows/pr_stage_test.yml b/.github/workflows/pr_stage_test.yml index 4f0d9d3136..e27703b379 100644 --- a/.github/workflows/pr_stage_test.yml +++ b/.github/workflows/pr_stage_test.yml @@ -35,12 +35,11 @@ jobs: run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html - name: Install MMEngine run: pip install git+https://github.com/open-mmlab/mmengine.git@main - - name: Install MMCV + - name: Install MMCV and MMDet run: | pip install -U openmim mim install 'mmcv >= 2.0.0rc1' - - name: Install MMDet - run: pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + mim install 'mmdet >= 3.0.0rc2' - name: Install other dependencies run: pip install -r requirements/tests.txt - name: Build and install @@ -59,9 +58,9 @@ jobs: env_vars: OS,PYTHON name: codecov-umbrella fail_ci_if_error: false - - name: Setup tmate session - if: ${{ failure() }} - uses: mxschmitt/action-tmate@v3 + # - name: Setup tmate session + # if: ${{ failure() }} + # uses: mxschmitt/action-tmate@v3 build_cu102: runs-on: ubuntu-18.04 @@ -94,17 +93,17 @@ jobs: - name: Install mmedit dependencies run: | pip install git+https://github.com/open-mmlab/mmengine.git@main - pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x pip install -U openmim mim install 'mmcv >= 2.0.0rc1' + mim install 'mmdet >= 3.0.0rc2' pip install -r requirements/tests.txt - name: Build and install run: | python setup.py check -m -s TORCH_CUDA_ARCH_LIST=7.0 pip install -e . - - name: Setup tmate session - if: ${{ failure() }} - uses: mxschmitt/action-tmate@v3 + # - name: Setup tmate session + # if: ${{ failure() }} + # uses: mxschmitt/action-tmate@v3 build_windows: runs-on: ${{ matrix.os }} @@ -128,9 +127,9 @@ jobs: - name: Install mmedit dependencies run: | python -m pip install git+https://github.com/open-mmlab/mmengine.git@main - python -m pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x python -m pip install -U openmim mim install 'mmcv >= 2.0.0rc1' + mim install 'mmdet >= 3.0.0rc2' python -m pip install -r requirements/tests.txt - name: Build and install run: | @@ -138,6 +137,6 @@ jobs: - name: Run unittests and generate coverage report run: | pytest tests/ - - name: Setup tmate session - if: ${{ failure() }} - uses: mxschmitt/action-tmate@v3 + # - name: Setup tmate session + # if: ${{ failure() }} + # uses: mxschmitt/action-tmate@v3 diff --git a/tests/test_datasets/test_transforms/test_crop.py b/tests/test_datasets/test_transforms/test_crop.py index c763a679cc..8d54a183a0 100644 --- a/tests/test_datasets/test_transforms/test_crop.py +++ b/tests/test_datasets/test_transforms/test_crop.py @@ -5,6 +5,7 @@ import cv2 import numpy as np import pytest +import torch from mmedit.datasets.transforms import (Crop, CropLike, FixedCrop, InstanceCrop, ModCrop, @@ -355,6 +356,8 @@ def test_crop_like(): assert sum_diff < 1e-6 +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='require cuda support') def test_instance_crop(): croper = InstanceCrop( From 831896d274db07f69f61503064cf599f64eff95a Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Thu, 3 Nov 2022 14:20:19 +0800 Subject: [PATCH 31/32] fix ut --- tests/test_datasets/test_transforms/test_crop.py | 7 +++++-- .../test_inst_colorization/test_inst_colorization.py | 5 +++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_datasets/test_transforms/test_crop.py b/tests/test_datasets/test_transforms/test_crop.py index 8d54a183a0..703e7657c6 100644 --- a/tests/test_datasets/test_transforms/test_crop.py +++ b/tests/test_datasets/test_transforms/test_crop.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy import os.path as osp +import unittest import cv2 import numpy as np @@ -356,10 +357,12 @@ def test_crop_like(): assert sum_diff < 1e-6 -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='require cuda support') def test_instance_crop(): + if not torch.cuda.is_available(): + # RoI pooling only support in GPU + return unittest.skip('test requires GPU and torch+cuda') + croper = InstanceCrop( key='img', finesize=256, diff --git a/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py b/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py index 1c810bf8bf..5d769b2134 100644 --- a/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py +++ b/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import platform +import unittest import pytest import torch @@ -15,6 +16,10 @@ class TestInstColorization: def test_inst_colorization(self): + if not torch.cuda.is_available(): + # RoI pooling only support in GPU + return unittest.skip('test requires GPU and torch+cuda') + register_all_modules() model_cfg = dict( type='InstColorization', From f1c6c80c4a222296c74e69a4f9a2b481af09cfd7 Mon Sep 17 00:00:00 2001 From: zenggyh1900 Date: Thu, 3 Nov 2022 15:01:38 +0800 Subject: [PATCH 32/32] skip workflow on windows-cu111 due to limited ram --- tests/test_apis/test_colorization_inference.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_apis/test_colorization_inference.py b/tests/test_apis/test_colorization_inference.py index 8772a31fc2..3e574633bb 100644 --- a/tests/test_apis/test_colorization_inference.py +++ b/tests/test_apis/test_colorization_inference.py @@ -1,6 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp +import platform +import unittest +import pytest import torch from mmengine import Config from mmengine.runner import load_checkpoint @@ -10,9 +13,16 @@ from mmedit.utils import register_all_modules, tensor2img +@pytest.mark.skipif( + 'win' in platform.system().lower() and 'cu' in torch.__version__, + reason='skip on windows-cuda due to limited RAM.') def test_colorization_inference(): register_all_modules() + if not torch.cuda.is_available(): + # RoI pooling only support in GPU + return unittest.skip('test requires GPU and torch+cuda') + if torch.cuda.is_available(): device = torch.device('cuda', 0) else: