diff --git a/.circleci/test.yml b/.circleci/test.yml index d1a3489802..b6974c6456 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -62,6 +62,7 @@ jobs: pip install git+https://github.com/open-mmlab/mmengine.git@main pip install -U openmim mim install 'mmcv >= 2.0.0rc1' + mim install 'mmdet >= 3.0.0rc2' pip install -r requirements/tests.txt - run: name: Build and install @@ -98,13 +99,14 @@ jobs: name: Build Docker image command: | docker build .circleci/docker -t mmedit:gpu --build-arg PYTORCH=<< parameters.torch >> --build-arg CUDA=<< parameters.cuda >> --build-arg CUDNN=<< parameters.cudnn >> - docker run --gpus all -t -d -v /home/circleci/project:/mmedit -v /home/circleci/mmengine:/mmengine -w /mmedit --name mmedit mmedit:gpu + docker run --gpus all -t -d -v /home/circleci/project:/mmedit -v /home/circleci/mmengine:/mmengine -v /home/circleci/mmdetection:/mmdetection -w /mmedit --name mmedit mmedit:gpu - run: name: Install mmedit dependencies command: | docker exec mmedit pip install -e /mmengine docker exec mmedit pip install -U openmim docker exec mmedit mim install 'mmcv >= 2.0.0rc1' + docker exec mmedit mim install 'mmdet >= 3.0.0rc2' docker exec mmedit pip install -r requirements/tests.txt - run: name: Build and install diff --git a/.github/workflows/merge_stage_test.yml b/.github/workflows/merge_stage_test.yml index e22dcb528a..f5bf7d8c0d 100644 --- a/.github/workflows/merge_stage_test.yml +++ b/.github/workflows/merge_stage_test.yml @@ -41,10 +41,11 @@ jobs: run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html - name: Install MMEngine run: pip install git+https://github.com/open-mmlab/mmengine.git@main - - name: Install MMCV + - name: Install MMCV and MMDet run: | pip install -U openmim mim install 'mmcv >= 2.0.0rc1' + mim install 'mmdet >= 3.0.0rc2' - name: Install other dependencies run: pip install -r requirements/tests.txt - name: Build and install @@ -88,10 +89,11 @@ jobs: run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html - name: Install MMEngine run: pip install git+https://github.com/open-mmlab/mmengine.git@main - - name: Install MMCV + - name: Install MMCV and MMDet run: | pip install -U openmim mim install 'mmcv >= 2.0.0rc1' + mim install 'mmdet >= 3.0.0rc2' - name: Install other dependencies run: pip install -r requirements/tests.txt - name: Build and install @@ -177,6 +179,7 @@ jobs: python -m pip install git+https://github.com/open-mmlab/mmengine.git@main python -m pip install -U openmim mim install 'mmcv >= 2.0.0rc1' + mim install 'mmdet >= 3.0.0rc2' python -m pip install -r requirements/tests.txt - name: Build and install run: | diff --git a/.github/workflows/pr_stage_test.yml b/.github/workflows/pr_stage_test.yml index b55f810642..e27703b379 100644 --- a/.github/workflows/pr_stage_test.yml +++ b/.github/workflows/pr_stage_test.yml @@ -35,10 +35,11 @@ jobs: run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html - name: Install MMEngine run: pip install git+https://github.com/open-mmlab/mmengine.git@main - - name: Install MMCV + - name: Install MMCV and MMDet run: | pip install -U openmim mim install 'mmcv >= 2.0.0rc1' + mim install 'mmdet >= 3.0.0rc2' - name: Install other dependencies run: pip install -r requirements/tests.txt - name: Build and install @@ -94,6 +95,7 @@ jobs: pip install 
git+https://github.com/open-mmlab/mmengine.git@main pip install -U openmim mim install 'mmcv >= 2.0.0rc1' + mim install 'mmdet >= 3.0.0rc2' pip install -r requirements/tests.txt - name: Build and install run: | @@ -127,6 +129,7 @@ jobs: python -m pip install git+https://github.com/open-mmlab/mmengine.git@main python -m pip install -U openmim mim install 'mmcv >= 2.0.0rc1' + mim install 'mmdet >= 3.0.0rc2' python -m pip install -r requirements/tests.txt - name: Build and install run: | diff --git a/README.md b/README.md index f79f5b4551..9ace05994b 100644 --- a/README.md +++ b/README.md @@ -104,22 +104,26 @@ hope MMEditing could provide better experience. ## What's New -- \[2022-09-13\] 🎉[MMGeneration](https://github.com/open-mmlab/mmgeneration/tree/1.x) was merged into MMEditing! And we are calling for your [suggestion](https://github.com/open-mmlab/mmediting/discussions/1108)! -- \[2022-08-31\] v1.0.0rc0 was released. This release introduced a brand new and flexible training & test engine, but it's still in progress. Welcome - to try according to [the documentation](https://mmediting.readthedocs.io/en/1.x/). -- \[2022-06-01\] v0.15.0 was released. - - Support FLAVR - - Support AOT-GAN - - Support CAIN with ReduceLROnPlateau Scheduler -- \[2022-04-01\] v0.14.0 was released. - - Support TOFlow in video frame interpolation -- \[2022-03-01\] v0.13.0 was released. - - Support CAIN - - Support EDVR-L - - Support running in Windows -- \[2022-02-11\] Switch to **PyTorch 1.5+**. The compatibility to earlier versions of PyTorch will no longer be guaranteed. - -Please refer to [changelog.md](docs/en/notes/3_changelog.md) for details and release history. +### 🌟 Preview of 1.x version + +A brand new version of [**MMEditing v1.0.0rc2**](https://github.com/open-mmlab/mmediting/releases/tag/v1.0.0rc2) was released in 02/11/2022: + +- Support all the tasks, models, metrics, and losses in [MMGeneration](https://github.com/open-mmlab/mmgeneration) 😍。 +- Unifies interfaces of all components based on [MMEngine](https://github.com/open-mmlab/mmengine). +- Support patch-based and slider-based image and video comparison viewer. +- Support image colorization. + +Find more new features in [1.x branch](https://github.com/open-mmlab/mmediting/tree/1.x). Issues and PRs are welcome! + +### 💎 Stable version + +**0.16.0** was released in 31/10/2022: + +- `VisualizationHook` is deprecated. Users should use `MMEditVisualizationHook` instead. +- Fix FLAVR register. +- Fix the number of channels in RDB. + +Please refer to [changelog.md](docs/en/changelog.md) for details and release history. ## Installation @@ -215,6 +219,13 @@ Supported algorithms: +
+Image Colorization + +- ✅ [InstColorization](configs/inst_colorization/README.md) (CVPR'2020) + +
+
Unconditional GANs diff --git a/README_zh-CN.md b/README_zh-CN.md index 7f820b5456..64fca3222b 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -37,7 +37,7 @@ [English](/README.md) | 简体中文 -## Introduction +## 介绍 MMEditing 是基于 PyTorch 的图像&视频编辑和生成开源工具箱。是 [OpenMMLab](https://openmmlab.com/) 项目的成员之一。 @@ -101,24 +101,28 @@ https://user-images.githubusercontent.com/12756472/158972813-d8d0f19c-f49c-4618- 需要注意的是 **MMSR** 已作为 MMEditing 的一部分并入本仓库。 MMEditing 缜密地设计新的框架并将其精心实现,希望能够为您带来更好的体验。 -## 最新消息 - -- \[2022-09-13\] 🎉 [MMGeneration](<(https://github.com/open-mmlab/mmgeneration/tree/1.x)>) 合入 MMEditing! 对于该合入计划,我们期待您的 [建议](https://github.com/open-mmlab/mmediting/discussions/1108)! -- \[2022-08-31\] v1.0.0rc0 版本发布 - 这个版本引入一个全新的,可扩展性强的训练和测试引擎,但目前仍在开发中。欢迎根据[文档](https://mmediting.readthedocs.io/en/1.x/)进行试用。 -- \[2022-06-01\] v0.15.0 版本发布 - - 支持 FLAVR - - 支持 AOT-GAN - - 新版 CAIN,支持 ReduceLROnPlateau 策略 -- \[2022-04-01\] v0.14.0 版本发布 - - 支持视频插帧算法 TOFlow -- \[2022-03-01\] v0.13.0 版本发布 - - 支持 CAIN - - 支持 EDVR-L - - 支持在 Windows 系统中运行 -- \[2022-02-11\] 切换到 **PyTorch 1.5+**. 将不再保证与早期版本的 PyTorch 的兼容性 - -请查看 [changelog.md](docs/zh_cn/notes/3_changelog.md) 以获取更多细节与发版记录 +## 最新进展 + +### 🌟 1.x 预览版本 + +全新的 [**MMEditing v1.0.0rc2**](https://github.com/open-mmlab/mmediting/releases/tag/v1.0.0rc2) 已经在 02/11/2022 发布: + +- 支持[MMGeneration](https://github.com/open-mmlab/mmgeneration)中的全量任务、模型、优化函数和评价指标 😍。 +- 基于[MMEngine](https://github.com/open-mmlab/mmengine)统一了各组件接口。 +- 支持基于图像子块以及滑动条的图像和视频比较可视化工具。 +- 支持图像上色任务。 + +在[1.x 分支](https://github.com/open-mmlab/mmediting/tree/1.x)中发现更多特性!欢迎提 Issues 和 PRs! + +### 💎 稳定版本 + +最新的 **0.16.0** 版本已经在 31/10/2022 发布: + +- `VisualizationHook` 将被启用,建议用户使用 `MMEditVisualizationHook`。 +- 修复 FLAVR 的注册问题。 +- 修正 RDB 模型中的通道数。 + +如果像了解更多版本更新细节和历史信息,请阅读[更新日志](docs/en/changelog.md)。 ## 安装 @@ -213,6 +217,13 @@ pip3 install -e .
+
+图像上色 + +- ✅ [InstColorization](configs/inst_colorization/README.md) (CVPR'2020) + +
+
Unconditional GANs diff --git a/configs/inst_colorization/README.md b/configs/inst_colorization/README.md new file mode 100644 index 0000000000..fdfdfb9e50 --- /dev/null +++ b/configs/inst_colorization/README.md @@ -0,0 +1,55 @@ +# Instance-aware Image Colorization (CVPR'2020) + +> [Instance-Aware Image Colorization](https://openaccess.thecvf.com/content_CVPR_2020/html/Su_Instance-Aware_Image_Colorization_CVPR_2020_paper.html) + +> **Task**: Colorization + + + +## Abstract + + + +Image colorization is inherently an ill-posed problem with multi-modal uncertainty. Previous methods leverage the deep neural network to map input grayscale images to plausible color outputs directly. Although these learning-based methods have shown impressive performance, they usually fail on the input images that contain multiple objects. The leading cause is that existing models perform learning and colorization on the entire image. In the absence of a clear figure-ground separation, these models cannot effectively locate and learn meaningful object-level semantics. In this paper, we propose a method for achieving instance-aware colorization. Our network architecture leverages an off-the-shelf object detector to obtain cropped object images and uses an instance colorization network to extract object-level features. We use a similar network to extract the full-image features and apply a fusion module to full object-level and image-level features to predict the final colors. Both colorization networks and fusion modules are learned from a large-scale dataset. Experimental results show that our work outperforms existing methods on different quality metrics and achieves state-of-the-art performance on image colorization. + + + +
+ +
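+The sketch below illustrates the core idea described above: colorize each detected instance separately, then fuse the instance-level features back into the full-image features. It is a conceptual toy in plain PyTorch with a hypothetical `fuse` helper and a fixed 50/50 blend; the actual implementation obtains boxes with the `InstanceCrop` transform (a Mask R-CNN detector from MMDetection) and fuses the two streams with a learned `FusionNet`/`WeightLayer` at every feature resolution.
+
+```python
+# Conceptual sketch only; the real FusionNet learns per-pixel fusion
+# weights at each scale instead of the fixed 0.5/0.5 mix used here.
+import torch
+import torch.nn.functional as F
+
+
+def fuse(full_feat, inst_feats, boxes):
+    """Blend per-instance features back into the full-image feature map.
+
+    full_feat:  (C, H, W) feature map of the whole image.
+    inst_feats: list of (C, h, w) feature maps, one per detected instance.
+    boxes:      list of (x1, y1, x2, y2) boxes in feature-map coordinates.
+    """
+    fused = full_feat.clone()
+    for feat, (x1, y1, x2, y2) in zip(inst_feats, boxes):
+        # resize the instance feature to its box and mix it with the
+        # full-image feature inside that region
+        patch = F.interpolate(
+            feat[None], size=(y2 - y1, x2 - x1), mode='bilinear',
+            align_corners=False)[0]
+        fused[:, y1:y2, x1:x2] = 0.5 * fused[:, y1:y2, x1:x2] + 0.5 * patch
+    return fused
+
+
+full_feat = torch.randn(64, 64, 64)     # feature map from the full-image branch
+inst_feats = [torch.randn(64, 32, 16)]  # feature map from one cropped instance
+boxes = [(8, 8, 24, 40)]                # (x1, y1, x2, y2) of the detected object
+print(fuse(full_feat, inst_feats, boxes).shape)  # torch.Size([64, 64, 64])
+```
+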
+ +## Results and models + +| Method | Download | +| :-------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------: | +| [instance_aware_colorization_official](/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth) | + +## Quick Start + +
+Colorization demo + +You can use the following commands to colorize an image. + +```shell + +python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth input.jpg output.jpg +``` + +For more demos, you can refer to [Tutorial 3: inference with pre-trained models](https://mmediting.readthedocs.io/en/1.x/user_guides/3_inference.html). + +
+ +
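+You can also run the same inference from Python. The snippet below is a minimal sketch that mirrors `demo/colorization_demo.py`; it assumes `mmedit`, `mmcv` and `mmdet` are installed and that the config and checkpoint above are reachable.
+
+```python
+# Minimal Python-API sketch mirroring demo/colorization_demo.py.
+import mmcv
+import torch
+
+from mmedit.apis import colorization_inference, init_model
+from mmedit.utils import tensor2img
+
+config = 'configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py'
+checkpoint = ('https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/'
+              'inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth')
+
+device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
+model = init_model(config, checkpoint, device=device)
+
+output = colorization_inference(model, 'input.jpg')  # predicted color image as a tensor
+mmcv.imwrite(tensor2img(output), 'output.jpg')       # convert to uint8 and save
+```
+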
+Instance-aware Image Colorization (CVPR'2020) + +```bibtex +@inproceedings{Su-CVPR-2020, + author = {Su, Jheng-Wei and Chu, Hung-Kuo and Huang, Jia-Bin}, + title = {Instance-aware Image Colorization}, + booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2020} +} +``` + +
diff --git a/configs/inst_colorization/README_zh-CN.md b/configs/inst_colorization/README_zh-CN.md new file mode 100644 index 0000000000..19e59c64fc --- /dev/null +++ b/configs/inst_colorization/README_zh-CN.md @@ -0,0 +1,54 @@ +# Instance-aware Image Colorization (CVPR'2020) + +> [Instance-Aware Image Colorization](https://openaccess.thecvf.com/content_CVPR_2020/html/Su_Instance-Aware_Image_Colorization_CVPR_2020_paper.html) + +> **任务**: 图像上色 + + + +## 摘要 + + + +Image colorization is inherently an ill-posed problem with multi-modal uncertainty. Previous methods leverage the deep neural network to map input grayscale images to plausible color outputs directly. Although these learning-based methods have shown impressive performance, they usually fail on the input images that contain multiple objects. The leading cause is that existing models perform learning and colorization on the entire image. In the absence of a clear figure-ground separation, these models cannot effectively locate and learn meaningful object-level semantics. In this paper, we propose a method for achieving instance-aware colorization. Our network architecture leverages an off-the-shelf object detector to obtain cropped object images and uses an instance colorization network to extract object-level features. We use a similar network to extract the full-image features and apply a fusion module to full object-level and image-level features to predict the final colors. Both colorization networks and fusion modules are learned from a large-scale dataset. Experimental results show that our work outperforms existing methods on different quality metrics and achieves state-of-the-art performance on image colorization. + + + +
+ +
+ +## 结果和模型 + +| Method | Download | +| :-------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------: | +| [instance_aware_colorization_official](/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth) | + +## 快速开始 + +
+图像上色模型 + +您可以使用以下命令来对一张图像进行上色。 + +```shell +python demo/colorization_demo.py configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth input.jpg output.jpg +``` + +更多细节可以参考 [Tutorial 3: inference with pre-trained models](https://mmediting.readthedocs.io/en/1.x/user_guides/3_inference.html)。 + +
+ +
+Instance-aware Image Colorization (CVPR'2020) + +```bibtex +@inproceedings{Su-CVPR-2020, + author = {Su, Jheng-Wei and Chu, Hung-Kuo and Huang, Jia-Bin}, + title = {Instance-aware Image Colorization}, + booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2020} +} +``` + +
diff --git a/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py b/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py new file mode 100644 index 0000000000..952bc74cda --- /dev/null +++ b/configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py @@ -0,0 +1,59 @@ +_base_ = ['../_base_/default_runtime.py'] + +experiment_name = 'inst-colorization_full_official_cocostuff_256x256' +work_dir = f'./work_dirs/{experiment_name}' +save_dir = './work_dirs/' + +stage = 'full' + +model = dict( + type='InstColorization', + data_preprocessor=dict( + type='EditDataPreprocessor', + mean=[127.5], + std=[127.5], + ), + image_model=dict( + type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'), + instance_model=dict( + type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'), + fusion_model=dict( + type='FusionNet', input_nc=4, output_nc=2, norm_type='batch'), + color_data_opt=dict( + ab_thresh=0, + p=1.0, + sample_PS=[ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + ], + ab_norm=110, + ab_max=110., + ab_quant=10., + l_norm=100., + l_cent=50., + mask_cent=0.5), + which_direction='AtoB', + loss=dict(type='HuberLoss', delta=.01)) + +# yapf: disable +test_pipeline = [ + dict(type='LoadImageFromFile', key='img', channel_order='rgb'), + dict( + type='InstanceCrop', + config_file='mmdet::mask_rcnn/mask-rcnn_x101-32x8d_fpn_ms-poly-3x_coco.py', # noqa + finesize=256, + box_num_upbound=5), + dict( + type='Resize', + keys=['img', 'cropped_img'], + scale=(256, 256), + keep_ratio=False), + dict(type='PackEditInputs'), +] diff --git a/configs/inst_colorization/metafile.yml b/configs/inst_colorization/metafile.yml new file mode 100644 index 0000000000..c13dabfb11 --- /dev/null +++ b/configs/inst_colorization/metafile.yml @@ -0,0 +1,19 @@ +Collections: +- Metadata: + Architecture: + - Instance-aware Image Colorization + Name: Instance-aware Image Colorization + Paper: + - https://openaccess.thecvf.com/content_CVPR_2020/html/Su_Instance-Aware_Image_Colorization_CVPR_2020_paper.html + README: configs/inst_colorization/README.md +Models: +- Config: configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py + In Collection: Instance-aware Image Colorization + Metadata: + Training Data: Others + Name: inst-colorizatioon_full_official_cocostuff-256x256 + Results: + - Dataset: Others + Metrics: {} + Task: Inst_colorization + Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmediting/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256-5b9d4eee.pth diff --git a/demo/colorization_demo.py b/demo/colorization_demo.py new file mode 100644 index 0000000000..926515f637 --- /dev/null +++ b/demo/colorization_demo.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import argparse + +import mmcv +import torch + +from mmedit.apis import colorization_inference, init_model +from mmedit.utils import modify_args, tensor2img + + +def parse_args(): + modify_args() + parser = argparse.ArgumentParser(description='Colorzation demo') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoints', help='checkpoints file path') + parser.add_argument('img_path', help='path to input image file') + parser.add_argument('save_path', help='path to save generation result') + parser.add_argument( + '--imshow', action='store_true', help='whether show image with opencv') + parser.add_argument('--device', type=int, default=0, help='CUDA device id') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.device < 0 or not torch.cuda.is_available(): + device = torch.device('cpu') + else: + device = torch.device('cuda', args.device) + + model = init_model(args.config, args.checkpoints, device=device) + output = colorization_inference(model, args.img_path) + result = tensor2img(output) + mmcv.imwrite(result, args.save_path) + + if args.imshow: + mmcv.imshow(output, 'predicted generation result') + + +if __name__ == '__main__': + main() diff --git a/docs/en/3_model_zoo.md b/docs/en/3_model_zoo.md index 7f2d8716b6..af3547569a 100644 --- a/docs/en/3_model_zoo.md +++ b/docs/en/3_model_zoo.md @@ -1,19 +1,20 @@ # Overview -- Number of checkpoints: 168 -- Number of configs: 168 -- Number of papers: 41 - - ALGORITHM: 42 +- Number of checkpoints: 169 +- Number of configs: 169 +- Number of papers: 42 + - ALGORITHM: 43 - Tasks: - - unconditional gans - image2image translation - - conditional gans - - matting + - video interpolation + - unconditional gans - image super-resolution - - video super-resolution - - inpainting - internal learning - - video interpolation + - conditional gans + - inpainting + - video super-resolution + - colorization + - matting For supported datasets, see [datasets overview](dataset_zoo/0_overview.md). @@ -185,6 +186,14 @@ For supported datasets, see [datasets overview](dataset_zoo/0_overview.md). - Number of papers: 1 - \[ALGORITHM\] Indices Matter: Learning to Index for Deep Image Matting ([⇨](https://github.com/open-mmlab/mmediting/blob/1.x/configs/indexnet/README.md#citation)) +## Instance-aware Image Colorization (CVPR'2020) + +- Tasks: colorization +- Number of checkpoints: 1 +- Number of configs: 1 +- Number of papers: 1 + - \[ALGORITHM\] Instance-Aware Image Colorization ([⇨](https://github.com/open-mmlab/mmediting/blob/1.x/configs/inst_colorization/README.md#quick-start)) + ## LIIF (CVPR'2021) - Tasks: image super-resolution diff --git a/docs/en/notes/3_changelog.md b/docs/en/notes/3_changelog.md index be177b9f25..8a618a5788 100644 --- a/docs/en/notes/3_changelog.md +++ b/docs/en/notes/3_changelog.md @@ -1,5 +1,53 @@ # Changelog +## v1.0.0rc2 (02/11/2022) + +**Highlights** + +We are excited to announce the release of MMEditing 1.0.0rc2. This release supports 43+ models, 170+ configs and 169+ checkpoints in MMGeneration and MMEditing. We highlight the following new features + +- patch-based and slider-based image and video comparison viewer. +- image colorization. + +We want to sincerely thank our community for continuously improving MMEditing. + +**New Features & Improvements** + +- Support qualitative comparison tools. (#1303) +- Support instance aware colorization. (#1370) +- Support multi-metrics with different sample-model. 
(#1171) +- Improve the implementation + - refactoring evaluation metrics. (#1164) + - Save gt images in PGGAN's `forward`. (#1332) + - Improve type and change default number of `preprocess_div2k_dataset.py`. (#1380) + - Support pixel value clip in visualizer. (#1365) + - Support SinGAN Dataset and SinGAN demo. (#1363) + - Avoid cast int and float in GenDataPreprocessor. (#1385) +- Improve the documentation + - Update a menu switcher. (#1162) + - Fix TTSR's README. (#1325) + - Revise docs (change `PackGenInputs` and `GenDataSample`). (#1382) + +**Bug Fixes** + +- Fix PPL bug. (#1172) +- Fix RDN number of channels. (#1328) +- Fix types of exceptions in demos. (#1372) +- Fix realesrgan ema. (#1341) +- Improve the assertion to ensuer `GenerateFacialHeatmap` as `np.float32`. (#1310) +- Fix sampling behavior of `unpaired_dataset.py` and urls in cyclegan's README. (#1308) +- Fix vsr models in pytorch2onnx. (#1300) +- Fix incorrect settings in configs. (#1167,#1200,#1236,#1293,#1302,#1304,#1319,#1331,#1336,#1349,#1352,#1353,#1358,#1364,#1367,#1384,#1386,#1391,#1392,#1393) + +**New Contributors** + +- @gaoyang07 made their first contribution in https://github.com/open-mmlab/mmediting/pull/1372 + +**Contributors** + +A total of 7 developers contributed to this release. +Thanks @LeoXing1996, @Z-Fran, @zengyh1900, @plyfager, @ryanxingql, @ruoningYu, @gaoyang07. + ## v1.0.0rc1(23/9/2022) MMEditing 1.0.0rc1 has merged MMGeneration 1.x. diff --git a/mmedit/apis/__init__.py b/mmedit/apis/__init__.py index 63989da131..75033721b0 100644 --- a/mmedit/apis/__init__.py +++ b/mmedit/apis/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .colorization_inference import colorization_inference from .gan_inference import sample_conditional_model, sample_unconditional_model from .inference import delete_cfg, init_model, set_random_seed from .inpainting_inference import inpainting_inference @@ -10,16 +11,10 @@ from .video_interpolation_inference import video_interpolation_inference __all__ = [ - 'init_model', - 'delete_cfg', - 'set_random_seed', - 'matting_inference', - 'inpainting_inference', - 'restoration_inference', - 'restoration_video_inference', - 'restoration_face_inference', - 'video_interpolation_inference', - 'sample_conditional_model', - 'sample_unconditional_model', - 'sample_img2img_model', + 'init_model', 'delete_cfg', 'set_random_seed', 'matting_inference', + 'inpainting_inference', 'restoration_inference', + 'restoration_video_inference', 'restoration_face_inference', + 'video_interpolation_inference', 'sample_conditional_model', + 'sample_unconditional_model', 'sample_img2img_model', + 'colorization_inference' ] diff --git a/mmedit/apis/colorization_inference.py b/mmedit/apis/colorization_inference.py new file mode 100644 index 0000000000..ddef7ef587 --- /dev/null +++ b/mmedit/apis/colorization_inference.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.dataset import Compose +from mmengine.dataset.utils import default_collate as collate +from torch.nn.parallel import scatter + + +def colorization_inference(model, img): + """Inference image with the model. + + Args: + model (nn.Module): The loaded model. + img (str): Image file path. + + Returns: + Tensor: The predicted colorization result. 
+ """ + device = next(model.parameters()).device + + # build the data pipeline + test_pipeline = Compose(model.cfg.test_pipeline) + # prepare data + data = dict(img_path=img) + _data = test_pipeline(data) + data = dict() + data['inputs'] = _data['inputs'] / 255.0 + data = collate([data]) + data['data_samples'] = [_data['data_samples']] + if 'cuda' in str(device): + data = scatter(data, [device])[0] + if not data['data_samples'][0].empty_box: + data['data_samples'][0].cropped_img.data = scatter( + data['data_samples'][0].cropped_img.data, [device])[0] / 255.0 + + data['data_samples'][0].box_info.data = scatter( + data['data_samples'][0].box_info.data, [device])[0] + + data['data_samples'][0].box_info_2x.data = scatter( + data['data_samples'][0].box_info_2x.data, [device])[0] + + data['data_samples'][0].box_info_4x.data = scatter( + data['data_samples'][0].box_info_4x.data, [device])[0] + + data['data_samples'][0].box_info_8x.data = scatter( + data['data_samples'][0].box_info_8x.data, [device])[0] + + # forward the model + with torch.no_grad(): + result = model(mode='tensor', **data) + + return result diff --git a/mmedit/datasets/transforms/__init__.py b/mmedit/datasets/transforms/__init__.py index f5eb2a02c6..39543085f1 100644 --- a/mmedit/datasets/transforms/__init__.py +++ b/mmedit/datasets/transforms/__init__.py @@ -6,8 +6,9 @@ from .aug_shape import (Flip, NumpyPad, RandomRotation, RandomTransposeHW, Resize) from .crop import (CenterCropLongEdge, Crop, CropAroundCenter, CropAroundFg, - CropAroundUnknown, CropLike, FixedCrop, ModCrop, - PairedRandomCrop, RandomCropLongEdge, RandomResizedCrop) + CropAroundUnknown, CropLike, FixedCrop, InstanceCrop, + ModCrop, PairedRandomCrop, RandomCropLongEdge, + RandomResizedCrop) from .fgbg import (CompositeFg, MergeFgAndBg, PerturbBg, RandomJitter, RandomLoadResizeBg) from .formatting import PackEditInputs, ToTensor @@ -45,5 +46,5 @@ 'GenerateSoftSeg', 'FormatTrimap', 'TransformTrimap', 'GenerateTrimap', 'GenerateTrimapWithDistTransform', 'CompositeFg', 'RandomLoadResizeBg', 'MergeFgAndBg', 'PerturbBg', 'RandomJitter', 'LoadPairedImageFromFile', - 'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad' + 'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad', 'InstanceCrop' ] diff --git a/mmedit/datasets/transforms/aug_shape.py b/mmedit/datasets/transforms/aug_shape.py index 6e62c3a4a9..fb50fe1134 100644 --- a/mmedit/datasets/transforms/aug_shape.py +++ b/mmedit/datasets/transforms/aug_shape.py @@ -319,24 +319,31 @@ def _resize(self, img): Returns: img (np.ndarray): Resized image. 
""" - - if self.keep_ratio: - img, self.scale_factor = mmcv.imrescale( - img, - self.scale, - return_scale=True, - interpolation=self.interpolation, - backend=self.backend) + if isinstance(img, list): + for i, image in enumerate(img): + size, img[i] = self._resize(image) + return size, img else: - img, w_scale, h_scale = mmcv.imresize( - img, - self.scale, - return_scale=True, - interpolation=self.interpolation, - backend=self.backend) - self.scale_factor = np.array((w_scale, h_scale), dtype=np.float32) - - return img + if self.keep_ratio: + img, self.scale_factor = mmcv.imrescale( + img, + self.scale, + return_scale=True, + interpolation=self.interpolation, + backend=self.backend) + else: + img, w_scale, h_scale = mmcv.imresize( + img, + self.scale, + return_scale=True, + interpolation=self.interpolation, + backend=self.backend) + self.scale_factor = np.array((w_scale, h_scale), + dtype=np.float32) + + if len(img.shape) == 2: + img = np.expand_dims(img, axis=2) + return img.shape, img def transform(self, results: Dict) -> Dict: """Transform function to resize images. @@ -358,11 +365,11 @@ def transform(self, results: Dict) -> Dict: new_w = min(self.max_size - (self.max_size % self.size_factor), new_w) self.scale = (new_w, new_h) + for key, out_key in zip(self.keys, self.output_keys): - results[out_key] = self._resize(results[key]) - if len(results[out_key].shape) == 2: - results[out_key] = np.expand_dims(results[out_key], axis=2) - results[f'{out_key}_shape'] = results[out_key].shape + if key in results: + size, results[out_key] = self._resize(results[key]) + results[f'{out_key}_shape'] = size results['scale_factor'] = self.scale_factor results['keep_ratio'] = self.keep_ratio diff --git a/mmedit/datasets/transforms/crop.py b/mmedit/datasets/transforms/crop.py index 906955254f..8554b72018 100644 --- a/mmedit/datasets/transforms/crop.py +++ b/mmedit/datasets/transforms/crop.py @@ -2,14 +2,19 @@ import math import random +import cv2 as cv import mmcv import numpy as np +import torch from mmcv.transforms import BaseTransform +from mmdet.apis import inference_detector, init_detector +from mmengine.hub import get_config +from mmengine.registry import DefaultScope from mmengine.utils import is_list_of, is_tuple_of from torch.nn.modules.utils import _pair from mmedit.registry import TRANSFORMS -from mmedit.utils import random_choose_unknown +from mmedit.utils import get_box_info, random_choose_unknown @TRANSFORMS.register_module() @@ -916,3 +921,108 @@ def __repr__(self): repr_str = self.__class__.__name__ repr_str += (f'(keys={self.keys})') return repr_str + + +@TRANSFORMS.register_module() +class InstanceCrop(BaseTransform): + """Use maskrcnn to detect instances on image. + + Mask R-CNN is used to detect the instance on the image + pred_bbox is used to segment the instance on the image + + Args: + config_file (str): config file name relative to detectron2's "configs/" + key (str): Unused + box_num_upbound (int):The upper limit on the number of instances + in the figure + """ + + def __init__(self, + config_file, + key='img', + box_num_upbound=-1, + finesize=256): + + cfg = get_config(config_file, pretrained=True) + with DefaultScope.overwrite_default_scope('mmdet'): + self.predictor = init_detector(cfg, cfg.model_path) + + self.key = key + self.box_num_upbound = box_num_upbound + self.final_size = finesize + + def transform(self, results: dict) -> dict: + """The transform function of InstanceCrop. 
+ + Args: + results (dict): A dict containing the necessary information and + data for Conversion + + Returns: + results (dict): A dict containing the processed data + and information. + """ + # get consistent box prediction based on L channel + + full_img = results['img'] + full_img_size = results['ori_img_shape'][:-1][::-1] + pred_bbox, pred_scores = self.predict_bbox(full_img) + + if self.box_num_upbound > 0 and pred_bbox.shape[ + 0] > self.box_num_upbound: + index_mask = np.argsort(pred_scores, axis=0) + index_mask = index_mask[pred_scores.shape[0] - + self.box_num_upbound:pred_scores.shape[0]] + pred_bbox = pred_bbox[index_mask] + + # get cropped images and box info + cropped_img_list = [] + index_list = range(len(pred_bbox)) + box_info, box_info_2x, box_info_4x, box_info_8x = np.zeros( + (4, len(index_list), 6)) + for i in index_list: + startx, starty, endx, endy = pred_bbox[i] + cropped_img = full_img[starty:endy, startx:endx, :] + cropped_img_list.append(cropped_img) + box_info[i] = np.array( + get_box_info(pred_bbox[i], full_img_size, self.final_size)) + box_info_2x[i] = np.array( + get_box_info(pred_bbox[i], full_img_size, + self.final_size // 2)) + box_info_4x[i] = np.array( + get_box_info(pred_bbox[i], full_img_size, + self.final_size // 4)) + box_info_8x[i] = np.array( + get_box_info(pred_bbox[i], full_img_size, + self.final_size // 8)) + + # update results + if len(pred_bbox) > 0: + results['cropped_img'] = cropped_img_list + results['box_info'] = torch.from_numpy(box_info).type(torch.long) + results['box_info_2x'] = torch.from_numpy(box_info_2x).type( + torch.long) + results['box_info_4x'] = torch.from_numpy(box_info_4x).type( + torch.long) + results['box_info_8x'] = torch.from_numpy(box_info_8x).type( + torch.long) + results['empty_box'] = False + else: + results['empty_box'] = True + return results + + def predict_bbox(self, image): + lab_image = cv.cvtColor(image, cv.COLOR_BGR2LAB) + l_channel, _, _ = cv.split(lab_image) + l_stack = np.stack([l_channel, l_channel, l_channel], axis=2) + + with DefaultScope.overwrite_default_scope('mmdet'): + with torch.no_grad(): + results = inference_detector(self.predictor, l_stack) + + bboxes = results.pred_instances.bboxes.cpu().numpy().astype(np.int32) + scores = results.pred_instances.scores.cpu().numpy() + index_mask = [i for i, x in enumerate(scores) if x >= 0.7] + scores = np.array(scores[index_mask]) + bboxes = np.array(bboxes[index_mask]) + return bboxes, scores diff --git a/mmedit/datasets/transforms/formatting.py b/mmedit/datasets/transforms/formatting.py index d375b56c47..5df741884d 100644 --- a/mmedit/datasets/transforms/formatting.py +++ b/mmedit/datasets/transforms/formatting.py @@ -229,6 +229,21 @@ def transform(self, results: dict) -> dict: gt_bg_tensor = images_to_tensor(gt_bg) data_sample.gt_bg = PixelData(data=gt_bg_tensor) + if 'rgb_img' in results: + gt_rgb = results.pop('rgb_img') + gt_rgb_tensor = images_to_tensor(gt_rgb) + data_sample.gt_rgb = PixelData(data=gt_rgb_tensor) + + if 'gray_img' in results: + gray = results.pop('gray_img') + gray_tensor = images_to_tensor(gray) + data_sample.gray = PixelData(data=gray_tensor) + + if 'cropped_img' in results: + cropped_img = results.pop('cropped_img') + cropped_img = images_to_tensor(cropped_img) + data_sample.cropped_img = PixelData(data=cropped_img) + metainfo = dict() for key in results: metainfo[key] = results[key] diff --git a/mmedit/models/base_models/__init__.py b/mmedit/models/base_models/__init__.py index ff1107fb67..0ec81d6d5a 100644 --- 
a/mmedit/models/base_models/__init__.py +++ b/mmedit/models/base_models/__init__.py @@ -10,7 +10,14 @@ from .two_stage import TwoStageInpaintor __all__ = [ - 'BaseEditModel', 'BaseGAN', 'BaseConditionalGAN', 'BaseMattor', - 'BasicInterpolator', 'BaseTranslationModel', 'OneStageInpaintor', - 'TwoStageInpaintor', 'ExponentialMovingAverage', 'RampUpEMA' + 'BaseEditModel', + 'BaseGAN', + 'BaseConditionalGAN', + 'BaseMattor', + 'BasicInterpolator', + 'BaseTranslationModel', + 'OneStageInpaintor', + 'TwoStageInpaintor', + 'ExponentialMovingAverage', + 'RampUpEMA', ] diff --git a/mmedit/models/editors/__init__.py b/mmedit/models/editors/__init__.py index e201db961c..2ecb5668e6 100644 --- a/mmedit/models/editors/__init__.py +++ b/mmedit/models/editors/__init__.py @@ -28,6 +28,7 @@ from .indexnet import (DepthwiseIndexBlock, HolisticIndexBlock, IndexedUpsample, IndexNet, IndexNetDecoder, IndexNetEncoder) +from .inst_colorization import InstColorization from .liif import LIIF, MLPRefiner from .lsgan import LSGAN from .mspie import MSPIEStyleGAN2, PESinGAN @@ -73,5 +74,5 @@ 'FBADecoder', 'WGANGP', 'CycleGAN', 'SAGAN', 'LSGAN', 'GGAN', 'Pix2Pix', 'StyleGAN1', 'StyleGAN2', 'StyleGAN3', 'BigGAN', 'DCGAN', 'ProgressiveGrowingGAN', 'SinGAN', 'IDLossModel', 'PESinGAN', - 'MSPIEStyleGAN2', 'StyleGAN3Generator' + 'MSPIEStyleGAN2', 'StyleGAN3Generator', 'InstColorization' ] diff --git a/mmedit/models/editors/inst_colorization/__init__.py b/mmedit/models/editors/inst_colorization/__init__.py new file mode 100644 index 0000000000..434ebe14d0 --- /dev/null +++ b/mmedit/models/editors/inst_colorization/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .colorization_net import ColorizationNet +from .fusion_net import FusionNet +from .inst_colorization import InstColorization + +__all__ = [ + 'InstColorization', + 'ColorizationNet', + 'FusionNet', +] diff --git a/mmedit/models/editors/inst_colorization/color_utils.py b/mmedit/models/editors/inst_colorization/color_utils.py new file mode 100644 index 0000000000..6ecc57b72f --- /dev/null +++ b/mmedit/models/editors/inst_colorization/color_utils.py @@ -0,0 +1,307 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch + + +def xyz2rgb(xyz): + """Conversion images from xyz to rgb. + + Args: + xyz (tensor): The images to be conversion + + Returns: + out (tensor): The converted image + """ + r = 3.24048134 * xyz[:, 0, :, :] - 1.53715152 * xyz[:, 1, :, :] \ + - 0.49853633 * xyz[:, 2, :, :] + g = -0.96925495 * xyz[:, 0, :, :] + 1.87599 * xyz[:, 1, :, :] \ + + .04155593 * xyz[:, 2, :, :] + b = .05564664 * xyz[:, 0, :, :] - .20404134 * xyz[:, 1, :, :] \ + + 1.05731107 * xyz[:, 2, :, :] + + # sometimes reaches a small negative number, which causes NaNs + rgb = torch.cat((r[:, None, :, :], g[:, None, :, :], b[:, None, :, :]), + dim=1) + rgb = torch.max(rgb, torch.zeros_like(rgb)) + + mask = (rgb > .0031308).type(torch.FloatTensor) + if rgb.is_cuda: + mask = mask.cuda() + + rgb = (1.055 * (rgb**(1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask) + return rgb + + +def lab2xyz(lab): + """Conversion images from lab to xyz. + + Args: + lab (tensor): The images to be conversion + + Returns: + out (tensor): The converted image + """ + y_int = (lab[:, 0, :, :] + 16.) / 116. + x_int = (lab[:, 1, :, :] / 500.) + y_int + z_int = y_int - (lab[:, 2, :, :] / 200.) 
+ if (z_int.is_cuda): + z_int = torch.max(torch.Tensor((0, )).cuda(), z_int) + else: + z_int = torch.max(torch.Tensor((0, )), z_int) + + out = torch.cat( + (x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]), + dim=1) + mask = (out > .2068966).type(torch.FloatTensor) + if (out.is_cuda): + mask = mask.cuda() + + out = (out**3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask) + + sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None] + sc = sc.to(out.device) + + out = out * sc + return out + + +def lab2rgb(lab_rs, color_data_opt): + """Conversion images from lab to rgb. + + Args: + lab_rs (tensor): The images to be conversion + color_data_opt (dict): Config for image colorspace transformation. + Include: l_norm, ab_norm, l_cent + + Returns: + out (tensor): The converted image + """ + L = lab_rs[:, + [0], :, :] * color_data_opt['l_norm'] + color_data_opt['l_cent'] + AB = lab_rs[:, 1:, :, :] * color_data_opt['ab_norm'] + lab = torch.cat((L, AB), dim=1) + out = xyz2rgb(lab2xyz(lab)) + return out + + +def encode_ab_ind(data_ab, color_data_opt): + """Encode ab value into an index. + + Args: + data_ab: Nx2xHxW from [-1,1] + color_data_opt: Config for image colorspace transformation. + ab_max, ab_quant, ab_norm, ab_quant + Returns: + Nx1xHxW from [0,Q) + """ + A = 2 * color_data_opt['ab_max'] / color_data_opt['ab_quant'] + 1 + data_ab_rs = torch.round( + (data_ab * color_data_opt['ab_norm'] + color_data_opt['ab_max']) / + color_data_opt['ab_quant']) # normalized bin number + data_q = data_ab_rs[:, [0], :, :] * A + data_ab_rs[:, [1], :, :] + return data_q + + +def rgb2xyz(rgb): + """Conversion images from rgb to xyz + rgb from [0,1] + xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], + [0.212671, 0.715160, 0.072169], + [0.019334, 0.119193, 0.950227]]) + Args: + rgb (Tensor): image in rgb colorspace + + Returns: + xyz (Tensor): image in xyz colorspace + + """ + mask = (rgb > .04045).type(torch.FloatTensor) + if (rgb.is_cuda): + mask = mask.cuda() + + rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask) + + x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] \ + + .180423 * rgb[:, 2, :, :] + y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] \ + + .072169 * rgb[:, 2, :, :] + z = .019334 * rgb[:, 0, :, :] + .119193 * rgb[:, 1, :, :] \ + + .950227 * rgb[:, 2, :, :] + out = torch.cat((x[:, None, :, :], y[:, None, :, :], z[:, None, :, :]), + dim=1) + + return out + + +def xyz2lab(xyz): + """Conversion images from xyz to lab + xyz from [0,1] + factors: 0.95047, 1., 1.08883 + + Args: + xyz (Tensor): image in xyz colorspace + + Returns: + out (Tensor): Image in lab colorspace + """ + sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None] + if (xyz.is_cuda): + sc = sc.cuda() + + xyz_scale = xyz / sc + + mask = (xyz_scale > .008856).type(torch.FloatTensor) + if (xyz_scale.is_cuda): + mask = mask.cuda() + + xyz_int = xyz_scale**(1 / 3.) * mask + (7.787 * xyz_scale + + 16. / 116.) * (1 - mask) + + L = 116. * xyz_int[:, 1, :, :] - 16. + a = 500. * (xyz_int[:, 0, :, :] - xyz_int[:, 1, :, :]) + b = 200. * (xyz_int[:, 1, :, :] - xyz_int[:, 2, :, :]) + out = torch.cat((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]), + dim=1) + + return out + + +def rgb2lab(rgb, color_opt): + """Conversion images from rgb to lab. + + Args: + data_raw (tensor): The images to be conversion + color_opt (dict): Config for image colorspace transformation. 
+ Include: ab_thresh, ab_norm, sample_PS, mask_cent + + Returns: + out (tensor): The converted image + """ + lab = xyz2lab(rgb2xyz(rgb)) + l_rs = (lab[:, [0], :, :] - color_opt['l_cent']) / color_opt['l_norm'] + ab_rs = lab[:, 1:, :, :] / color_opt['ab_norm'] + out = torch.cat((l_rs, ab_rs), dim=1) + return out + + +def get_colorization_data(data_raw, color_opt, num_points=None): + """Conversion images from rgb to lab. + + Args: + data_raw (tensor): The images to be conversion + color_opt (dict): Config for image colorspace transformation. + Include: ab_thresh, ab_norm, sample_PS, mask_cent + + Returns: + results (dict): Output in add_color_patches_rand_gt + """ + data = {} + data_lab = rgb2lab(data_raw, color_opt) + data['A'] = data_lab[:, [ + 0, + ], :, :] + data['B'] = data_lab[:, 1:, :, :] + + # mask out grayscale images + if color_opt['ab_thresh'] > 0: + thresh = 1. * color_opt['ab_thresh'] / color_opt['ab_norm'] + mask = torch.sum( + torch.abs( + torch.max(torch.max(data['B'], dim=3)[0], dim=2)[0] - + torch.min(torch.min(data['B'], dim=3)[0], dim=2)[0]), + dim=1) >= thresh + data['A'] = data['A'][mask, :, :, :] + data['B'] = data['B'][mask, :, :, :] + if torch.sum(mask) == 0: + return None + + return add_color_patches_rand_gt( + data, color_opt, p=color_opt['p'], num_points=num_points) + + +def add_color_patches_rand_gt(data, + color_opt, + p=.125, + num_points=None, + use_avg=True, + samp='normal'): + """Add random color points sampled from ground truth based on: Number of + points. + + - if num_points is 0, then sample from geometric distribution, + drawn from probability p + - if num_points > 0, then sample that number of points + Location of points + - if samp is 'normal', draw from N(0.5, 0.25) of image + - otherwise, draw from U[0, 1] of image + + Args: + data (tensor): The images to be conversion + color_opt (dict): Config for image colorspace transformation + Include: ab_thresh, ab_norm, sample_PS, mask_cent + p (float): Sampling geometric distribution, 1.0 means no hints + num_points (int): Certain number of points + use_avg (bool): Whether to use the mean when add color point + Default: True. + samp (str): Geometric distribution or uniform distribution when + sample location. Default: normal. + + Returns: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. 
+ """ + N, C, H, W = data['B'].shape + + data['hint_B'] = torch.zeros_like(data['B']) + data['mask_B'] = torch.zeros_like(data['A']) + + for nn in range(N): + pp = 0 + cont_cond = True + while cont_cond: + # draw from geometric + if num_points is None: + cont_cond = np.random.rand() < (1 - p) + else: + # add certain number of points + cont_cond = pp < num_points + # skip out of loop if condition not met + if not cont_cond: + continue + + # patch size + P = np.random.choice(color_opt['sample_PS']) + # sample location: geometric distribution + if samp == 'normal': + h = int( + np.clip( + np.random.normal((H - P + 1) / 2., (H - P + 1) / 4.), + 0, H - P)) + w = int( + np.clip( + np.random.normal((W - P + 1) / 2., (W - P + 1) / 4.), + 0, W - P)) + else: # uniform distribution + h = np.random.randint(H - P + 1) + w = np.random.randint(W - P + 1) + + # add color point + if use_avg: + data['hint_B'][nn, :, h:h + P, w:w + P] = torch.mean( + torch.mean( + data['B'][nn, :, h:h + P, w:w + P], + dim=2, + keepdim=True), + dim=1, + keepdim=True).view(1, C, 1, 1) + else: + data['hint_B'][nn, :, h:h + P, w:w + P] = \ + data['B'][nn, :, h:h + P, w:w + P] + + data['mask_B'][nn, :, h:h + P, w:w + P] = 1 + + # increment counter + pp += 1 + + data['mask_B'] -= color_opt['mask_cent'] + + return data diff --git a/mmedit/models/editors/inst_colorization/colorization_net.py b/mmedit/models/editors/inst_colorization/colorization_net.py new file mode 100644 index 0000000000..6d62209e07 --- /dev/null +++ b/mmedit/models/editors/inst_colorization/colorization_net.py @@ -0,0 +1,313 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmedit.registry import MODULES +from .weight_layer import get_norm_layer + + +@MODULES.register_module() +class ColorizationNet(BaseModule): + """Real-Time User-Guided Image Colorization with Learned Deep Priors. The + backbone used for. + + https://arxiv.org/abs/1705.02999 + + Codes adapted from 'https://github.com/ericsujw/InstColorization.git' + 'InstColorization/blob/master/models/networks.py#L108' + + Args: + input_nc (int): input image channels + output_nc (int): output image channels + norm_type (str): instance normalization or batch normalization + use_tanh (bool): Whether to use nn.Tanh() Default: True. + classification (bool): backprop trunk using classification, + otherwise use regression. 
Default: True + """ + + def __init__(self, + input_nc, + output_nc, + norm_type, + use_tanh=True, + classification=True): + super().__init__() + self.input_nc = input_nc + self.output_nc = output_nc + self.classification = classification + + norm_layer = get_norm_layer(norm_type) + + use_bias = True + + # Conv1 + self.model1 = nn.Sequential( + nn.Conv2d( + input_nc, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + norm_layer(64), + ) + + # Conv2 + self.model2 = nn.Sequential( + nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + norm_layer(128), + ) + + # Conv3 + self.model3 = nn.Sequential( + nn.Conv2d( + 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + norm_layer(256), + ) + + # Conv4 + self.model4 = nn.Sequential( + nn.Conv2d( + 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + norm_layer(512), + ) + + # Conv5 + self.model5 = nn.Sequential( + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + nn.ReLU(True), + norm_layer(512), + ) + + # Conv6 + self.model6 = nn.Sequential( + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + nn.ReLU(True), + norm_layer(512), + ) + + # Conv7 + self.model7 = nn.Sequential( + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + norm_layer(512), + ) + + # Conv8 + self.model8up = nn.ConvTranspose2d( + 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) + + self.model3short8 = nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias) + + self.model8 = nn.Sequential( + nn.ReLU(True), + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + norm_layer(256), + ) + + # Conv9 + self.model9up = nn.ConvTranspose2d( + 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias) + + self.model2short9 = nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias) + self.model9 = nn.Sequential( + nn.ReLU(True), + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + norm_layer(128), + ) + + 
# Conv10 + self.model10up = nn.ConvTranspose2d( + 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias) + + self.model1short10 = nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias) + + self.model10 = nn.Sequential( + nn.ReLU(True), + nn.Conv2d( + 128, + 128, + kernel_size=3, + dilation=1, + stride=1, + padding=1, + bias=use_bias), + nn.LeakyReLU(negative_slope=.2), + ) + + # classification output + self.model_class = nn.Conv2d( + 256, + 529, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias) + + # regression output + model_out = [ + nn.Conv2d( + 128, + 2, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias), + ] + if (use_tanh): + model_out += [nn.Tanh()] + self.model_out = nn.Sequential(*model_out) + + self.upsample4 = nn.Upsample(scale_factor=4, mode='nearest') + self.softmax = nn.Softmax(dim=1) + + def forward(self, input_A, input_B, mask_B): + """Forward function. + + Args: + input_A (tensor): Channel of the image in lab color space + input_B (tensor): Color patch + mask_B (tensor): Color patch mask + + Returns: + out_class (tensor): Classification output + out_reg (tensor): Regression output + feature_map (dict): The full-image feature + """ + conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1)) + conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) + conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) + conv4_3 = self.model4(conv3_3[:, :, ::2, ::2]) + conv5_3 = self.model5(conv4_3) + conv6_3 = self.model6(conv5_3) + conv7_3 = self.model7(conv6_3) + conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3) + conv8_3 = self.model8(conv8_up) + + if (self.classification): + out_class = self.model_class(conv8_3) + conv9_up = self.model9up(conv8_3.detach()) + self.model2short9( + conv2_2.detach()) + conv9_3 = self.model9(conv9_up) + conv10_up = self.model10up(conv9_3) + self.model1short10( + conv1_2.detach()) + else: + out_class = self.model_class(conv8_3.detach()) + conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) + conv9_3 = self.model9(conv9_up) + conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) + + conv10_2 = self.model10(conv10_up) + out_reg = self.model_out(conv10_2) + + feature_map = {} + feature_map['conv1_2'] = conv1_2 + feature_map['conv2_2'] = conv2_2 + feature_map['conv3_3'] = conv3_3 + feature_map['conv4_3'] = conv4_3 + feature_map['conv5_3'] = conv5_3 + feature_map['conv6_3'] = conv6_3 + feature_map['conv7_3'] = conv7_3 + feature_map['conv8_up'] = conv8_up + feature_map['conv8_3'] = conv8_3 + feature_map['conv9_up'] = conv9_up + feature_map['conv9_3'] = conv9_3 + feature_map['conv10_up'] = conv10_up + feature_map['conv10_2'] = conv10_2 + feature_map['out_reg'] = out_reg + + return (out_class, out_reg, feature_map) diff --git a/mmedit/models/editors/inst_colorization/fusion_net.py b/mmedit/models/editors/inst_colorization/fusion_net.py new file mode 100644 index 0000000000..10c5732680 --- /dev/null +++ b/mmedit/models/editors/inst_colorization/fusion_net.py @@ -0,0 +1,353 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmedit.registry import MODULES +from .weight_layer import WeightLayer, get_norm_layer + + +@MODULES.register_module() +class FusionNet(BaseModule): + """Instance-aware Image Colorization. 
+ + https://arxiv.org/abs/2005.10825 + + Codes adapted from 'https://github.com/ericsujw/InstColorization.git' + 'InstColorization/blob/master/models/networks.py#L314' + FusionNet: the full image model with weight layer for fusion. + + Args: + input_nc (int): input image channels + output_nc (int): output image channels + norm_type (str): instance normalization or batch normalization + use_tanh (bool): Whether to use nn.Tanh() Default: True. + classification (bool): backprop trunk using classification, + otherwise use regression. Default: True + """ + + def __init__(self, + input_nc, + output_nc, + norm_type, + use_tanh=True, + classification=True): + super().__init__() + self.input_nc = input_nc + self.output_nc = output_nc + self.classification = classification + + norm_layer = get_norm_layer(norm_type) + use_bias = True + + # Conv1 + self.model1 = nn.Sequential( + nn.Conv2d( + input_nc, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + norm_layer(64), + ) + + self.weight_layer = WeightLayer(64) + + # Conv2 + self.model2 = nn.Sequential( + nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + norm_layer(128), + ) + + self.weight_layer2 = WeightLayer(128) + + # Conv3 + self.model3 = nn.Sequential( + nn.Conv2d( + 128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + norm_layer(256), + ) + + self.weight_layer3 = WeightLayer(256) + + # Conv4 + self.model4 = nn.Sequential( + nn.Conv2d( + 256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + norm_layer(512), + ) + + self.weight_layer4 = WeightLayer(512) + + # Conv5 + self.model5 = nn.Sequential( + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + nn.ReLU(True), + norm_layer(512), + ) + + self.weight_layer5 = WeightLayer(512) + + # Conv6 + self.model6 = nn.Sequential( + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, + 512, + kernel_size=3, + dilation=2, + stride=1, + padding=2, + bias=use_bias), + nn.ReLU(True), + norm_layer(512), + ) + + self.weight_layer6 = WeightLayer(512) + + # Conv7 + self.model7 = nn.Sequential( + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + norm_layer(512), + ) + + self.weight_layer7 = WeightLayer(512) + + # Conv8 + 
self.model8up = nn.ConvTranspose2d( + 512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias) + + self.model3short8 = nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias) + + self.weight_layer8_1 = WeightLayer(256) + + self.model8 = nn.Sequential( + nn.ReLU(True), + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + nn.Conv2d( + 256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + norm_layer(256), + ) + + self.weight_layer8_2 = WeightLayer(256) + + # Conv9 + self.model9up = nn.ConvTranspose2d( + 256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias) + + self.model2short9 = nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias) + + self.weight_layer9_1 = WeightLayer(128) + + self.model9 = nn.Sequential( + nn.ReLU(True), + nn.Conv2d( + 128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), + nn.ReLU(True), + norm_layer(128), + ) + + self.weight_layer9_2 = WeightLayer(128) + + # Conv10 + self.model10up = nn.ConvTranspose2d( + 128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias) + + self.model1short10 = nn.Conv2d( + 64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias) + + self.weight_layer10_1 = WeightLayer(128) + + self.model10 = nn.Sequential( + nn.ReLU(True), + nn.Conv2d( + 128, + 128, + kernel_size=3, + dilation=1, + stride=1, + padding=1, + bias=use_bias), + nn.LeakyReLU(negative_slope=.2), + ) + + self.weight_layer10_2 = WeightLayer(128) + + # classification output + self.model_class = nn.Conv2d( + 256, + 529, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias) + + # regression output + model_out = [ + nn.Conv2d( + 128, + 2, + kernel_size=1, + padding=0, + dilation=1, + stride=1, + bias=use_bias), + ] + if (use_tanh): + model_out += [nn.Tanh()] + self.model_out = nn.Sequential(*model_out) + + self.weight_layerout = WeightLayer(2) + + self.upsample4 = nn.Upsample(scale_factor=4, mode='nearest') + self.softmax = nn.Softmax(dim=1) + + def forward(self, input_A, input_B, mask_B, instance_feature, + box_info_list): + """Forward function. 
+ + Args: + input_A (tensor): Channel of the image in lab color space + input_B (tensor): Color patch + mask_B (tensor): Color patch mask + instance_feature (dict): A bunch of instance features + box_info_list (list): Bounding box information corresponding + to the instance + + Returns: + out_reg (tensor): Regression output + """ + conv1_2 = self.model1(torch.cat((input_A, input_B, mask_B), dim=1)) + conv1_2 = self.weight_layer(instance_feature['conv1_2'], conv1_2, + box_info_list[0]) + + conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) + conv2_2 = self.weight_layer2(instance_feature['conv2_2'], conv2_2, + box_info_list[1]) + + conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) + conv3_3 = self.weight_layer3(instance_feature['conv3_3'], conv3_3, + box_info_list[2]) + + conv4_3 = self.model4(conv3_3[:, :, ::2, ::2]) + conv4_3 = self.weight_layer4(instance_feature['conv4_3'], conv4_3, + box_info_list[3]) + + conv5_3 = self.model5(conv4_3) + conv5_3 = self.weight_layer5(instance_feature['conv5_3'], conv5_3, + box_info_list[3]) + + conv6_3 = self.model6(conv5_3) + conv6_3 = self.weight_layer6(instance_feature['conv6_3'], conv6_3, + box_info_list[3]) + + conv7_3 = self.model7(conv6_3) + conv7_3 = self.weight_layer7(instance_feature['conv7_3'], conv7_3, + box_info_list[3]) + + conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3) + conv8_up = self.weight_layer8_1(instance_feature['conv8_up'], conv8_up, + box_info_list[2]) + + conv8_3 = self.model8(conv8_up) + conv8_3 = self.weight_layer8_2(instance_feature['conv8_3'], conv8_3, + box_info_list[2]) + + conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) + conv9_up = self.weight_layer9_1(instance_feature['conv9_up'], conv9_up, + box_info_list[1]) + + conv9_3 = self.model9(conv9_up) + conv9_3 = self.weight_layer9_2(instance_feature['conv9_3'], conv9_3, + box_info_list[1]) + + conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) + conv10_up = self.weight_layer10_1(instance_feature['conv10_up'], + conv10_up, box_info_list[0]) + + conv10_2 = self.model10(conv10_up) + conv10_2 = self.weight_layer10_2(instance_feature['conv10_2'], + conv10_2, box_info_list[0]) + + out_reg = self.model_out(conv10_2) + return out_reg diff --git a/mmedit/models/editors/inst_colorization/inst_colorization.py b/mmedit/models/editors/inst_colorization/inst_colorization.py new file mode 100644 index 0000000000..4c63aac225 --- /dev/null +++ b/mmedit/models/editors/inst_colorization/inst_colorization.py @@ -0,0 +1,238 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +from mmengine.config import Config +from mmengine.model import BaseModel +from mmengine.optim import OptimWrapperDict + +from mmedit.registry import MODULES +from mmedit.structures import EditDataSample, PixelData +from .color_utils import get_colorization_data, lab2rgb + + +@MODULES.register_module() +class InstColorization(BaseModel): + """Colorization InstColorization method. + + This Colorization is implemented according to the paper: + Instance-aware Image Colorization, CVPR 2020 + + Adapted from 'https://github.com/ericsujw/InstColorization.git' + 'InstColorization/models/train_model' + Copyright (c) 2020, Su, under MIT License. + + Args: + data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. 
+ + image_model (dict): Config for single image model + instance_model (dict): Config for instance model + fusion_model (dict): Config for fusion model + color_data_opt (dict): Options for color space conversion + which_direction (str): AtoB or BtoA + loss (dict): Config for loss. + init_cfg (dict, optional): Initialization config dict. Default: None. + train_cfg (dict): Config for training. Default: None. + test_cfg (dict): Config for testing. Default: None. + """ + + def __init__(self, + data_preprocessor: Union[dict, Config], + image_model, + instance_model, + fusion_model, + color_data_opt, + which_direction='AtoB', + loss=None, + init_cfg=None, + train_cfg=None, + test_cfg=None): + + super().__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor) + + # colorization networks + # image_model: used to colorize a single image + self.image_model = MODULES.build(image_model) + + # instance model: used to colorize cropped instance + self.instance_model = MODULES.build(instance_model) + + # fusion model: input a single image with related instance features + self.fusion_model = MODULES.build(fusion_model) + + self.color_data_opt = color_data_opt + self.which_direction = which_direction + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[EditDataSample]] = None, + mode: str = 'tensor', + **kwargs): + """Returns losses or predictions of training, validation, testing, and + simple inference process. + + ``forward`` method of BaseModel is an abstract method, its subclasses + must implement this method. + + Accepts ``inputs`` and ``data_samples`` processed by + :attr:`data_preprocessor`, and returns results according to mode + arguments. + + During non-distributed training, validation, and testing process, + ``forward`` will be called by ``BaseModel.train_step``, + ``BaseModel.val_step`` and ``BaseModel.test_step`` directly. + + During distributed data parallel training process, + ``MMSeparateDistributedDataParallel.train_step`` will first call + ``DistributedDataParallel.forward`` to enable automatic + gradient synchronization, and then call ``forward`` to get training + loss. + + Args: + inputs (torch.Tensor): batch input tensor collated by + :attr:`data_preprocessor`. + data_samples (List[BaseDataElement], optional): + data samples collated by :attr:`data_preprocessor`. + mode (str): mode should be one of ``loss``, ``predict`` and + ``tensor``. Default: 'tensor'. + + - ``loss``: Called by ``train_step`` and return loss ``dict`` + used for logging + - ``predict``: Called by ``val_step`` and ``test_step`` + and return list of ``BaseDataElement`` results used for + computing metric. + - ``tensor``: Called by custom use to get ``Tensor`` type + results. + + Returns: + ForwardResults: + + - If ``mode == loss``, return a ``dict`` of loss tensor used + for backward and logging. + - If ``mode == predict``, return a ``list`` of + :obj:`BaseDataElement` for computing metric + and getting inference result. + - If ``mode == tensor``, return a tensor or ``tuple`` of tensor + or ``dict`` of tensor for custom use.
+ """ + + if mode == 'tensor': + return self.forward_tensor(inputs, data_samples, **kwargs) + + elif mode == 'predict': + predictions = self.forward_inference(inputs, data_samples, + **kwargs) + predictions = self.convert_to_datasample(data_samples, predictions) + return predictions + + elif mode == 'loss': + return self.forward_train(inputs, data_samples, **kwargs) + + def convert_to_datasample(self, inputs, data_samples): + for data_sample, output in zip(inputs, data_samples): + data_sample.output = output + return inputs + + def forward_train(self, inputs, data_samples=None, **kwargs): + """Forward function for training.""" + raise NotImplementedError( + 'Instance Colorization has not supported training.') + + def train_step(self, data: List[dict], + optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]: + """Train step function. + + Args: + data (List[dict]): Batch of data as input. + optim_wrapper (dict[torch.optim.Optimizer]): Dict with optimizers + for generator and discriminator (if have). + Returns: + dict: Dict with loss, information for logger, the number of + samples and results for visualization. + """ + raise NotImplementedError( + 'Instance Colorization has not supported training.') + + def forward_inference(self, inputs, data_samples=None, **kwargs): + """Forward inference. Returns predictions of validation, testing. + + Args: + inputs (torch.Tensor): batch input tensor collated by + :attr:`data_preprocessor`. + data_samples (List[BaseDataElement], optional): + data samples collated by :attr:`data_preprocessor`. + + Returns: + List[EditDataSample]: predictions. + """ + feats = self.forward_tensor(inputs, data_samples, **kwargs) + predictions = [] + for idx in range(feats.shape[0]): + batch_tensor = feats[idx] * 127.5 + 127.5 + pred_img = PixelData(data=batch_tensor.to('cpu')) + predictions.append( + EditDataSample( + pred_img=pred_img, metainfo=data_samples[idx].metainfo)) + + return predictions + + def forward_tensor(self, inputs, data_samples): + """Forward function in tensor mode. + + Args: + inputs (torch.Tensor): Input tensor. + data_sample (dict): Dict contains data sample. + + Returns: + dict: Dict contains output results. 
+ """ + + # prepare data + + assert len(data_samples) == 1, \ + 'fusion model supports only one image due to different numbers '\ + 'of instances of different images' + + full_img_data = get_colorization_data(inputs, self.color_data_opt) + AtoB = self.which_direction == 'AtoB' + + # preprocess input for a single image + full_real_A = full_img_data['A' if AtoB else 'B'] + full_hint_B = full_img_data['hint_B'] + full_mask_B = full_img_data['mask_B'] + + if not data_samples[0].empty_box: + # preprocess instance input + cropped_img = data_samples[0].cropped_img.data + box_info_list = [ + data_samples[0].box_info, data_samples[0].box_info_2x, + data_samples[0].box_info_4x, data_samples[0].box_info_8x + ] + cropped_data = get_colorization_data(cropped_img, + self.color_data_opt) + + real_A = cropped_data['A' if AtoB else 'B'] + hint_B = cropped_data['hint_B'] + mask_B = cropped_data['mask_B'] + + # network forward + _, output, feature_map = self.instance_model( + real_A, hint_B, mask_B) + output = self.fusion_model(full_real_A, full_hint_B, full_mask_B, + feature_map, box_info_list) + + else: + _, output, _ = self.image_model(full_real_A, full_hint_B, + full_mask_B) + + output = [ + full_real_A.type(torch.cuda.FloatTensor), + output.type(torch.cuda.FloatTensor) + ] + output = torch.cat(output, dim=1) + output = torch.clamp(lab2rgb(output, self.color_data_opt), 0.0, 1.0) + return output diff --git a/mmedit/models/editors/inst_colorization/weight_layer.py b/mmedit/models/editors/inst_colorization/weight_layer.py new file mode 100644 index 0000000000..c2f05b34f0 --- /dev/null +++ b/mmedit/models/editors/inst_colorization/weight_layer.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmedit.registry import MODULES + + +def get_norm_layer(norm_type='instance'): + """Gets the normalization layer. + + Args: + norm_type (str): Type of the normalization layer. + + Returns: + norm_layer (BatchNorm2d or InstanceNorm2d or None): + normalization layer. Default: instance + """ + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + elif norm_type == 'none': + norm_layer = None + else: + raise NotImplementedError('normalization layer [%s] is not found' % + norm_type) + return norm_layer + + +@MODULES.register_module() +class WeightLayer(BaseModule): + """Weight layer of the fusion_net. A small neural network with three + convolutional layers to predict full-image weight map and perinstance + weight map. + + Args: + input_ch (int): Number of channels in the input image. + inner_ch (int): Number of channels produced by the convolution. 
+ Default: True + """ + + def __init__(self, input_ch, inner_ch=16): + super().__init__() + self.simple_instance_conv = nn.Sequential( + nn.Conv2d(input_ch, inner_ch, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(inner_ch, inner_ch, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(inner_ch, 1, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + ) + + self.simple_bg_conv = nn.Sequential( + nn.Conv2d(input_ch, inner_ch, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(inner_ch, inner_ch, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(inner_ch, 1, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + ) + + self.normalize = nn.Softmax(1) + + def resize_and_pad(self, feauture_maps, info_array): + """Resize the instance feature as well as the weight map to match the + size of full-image and do zero padding on both of them. + + Args: + feauture_maps (tensor): Feature map + info_array (tensor): The bounding box + + Returns: + feauture_maps (tensor): Feature maps after resize and padding + """ + feauture_maps = torch.nn.functional.interpolate( + feauture_maps, + size=(info_array[5], info_array[4]), + mode='bilinear') + feauture_maps = torch.nn.functional.pad(feauture_maps, + (info_array[0], info_array[1], + info_array[2], info_array[3]), + 'constant', 0) + return feauture_maps + + def forward(self, instance_feature, bg_feature, box_info): + """Forward function. + + Args: + instance_feature (tensor): Instance feature obtained from the + colorization_net + bg_feature (tensor): full-image feature + box_info (tensor): The bounding box corresponding to the instance + + Returns: + out (tensor): Fused feature + """ + mask_list = [] + featur_map_list = [] + mask_sum_for_pred = torch.zeros_like(bg_feature)[:1, :1] + for i in range(instance_feature.shape[0]): + tmp_crop = torch.unsqueeze(instance_feature[i], 0) + conv_tmp_crop = self.simple_instance_conv(tmp_crop) + pred_mask = self.resize_and_pad(conv_tmp_crop, box_info[i]) + + tmp_crop = self.resize_and_pad(tmp_crop, box_info[i]) + + mask = torch.zeros_like(bg_feature)[:1, :1] + mask[0, 0, box_info[i][2]:box_info[i][2] + box_info[i][5], + box_info[i][0]:box_info[i][0] + box_info[i][4]] = 1.0 + device = mask.device + mask = mask.type(torch.FloatTensor).to(device) + + mask_sum_for_pred = torch.clamp(mask_sum_for_pred + mask, 0.0, 1.0) + + mask_list.append(pred_mask) + featur_map_list.append(tmp_crop) + + pred_bg_mask = self.simple_bg_conv(bg_feature) + mask_list.append(pred_bg_mask + (1 - mask_sum_for_pred) * 100000.0) + mask_list = self.normalize(torch.cat(mask_list, 1)) + + mask_list_maskout = mask_list.clone() + + featur_map_list.append(bg_feature) + featur_map_list = torch.cat(featur_map_list, 0) + mask_list_maskout = mask_list_maskout.permute(1, 0, 2, 3).contiguous() + out = featur_map_list * mask_list_maskout + out = torch.sum(out, 0, keepdim=True) + return out diff --git a/mmedit/models/losses/__init__.py b/mmedit/models/losses/__init__.py index 4027c72013..df66126d39 100644 --- a/mmedit/models/losses/__init__.py +++ b/mmedit/models/losses/__init__.py @@ -18,13 +18,35 @@ from .pixelwise_loss import CharbonnierLoss, L1Loss, MaskedTVLoss, MSELoss __all__ = [ - 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'L1CompositionLoss', - 'MSECompositionLoss', 'CharbonnierCompLoss', 'GANLoss', 'GaussianBlur', - 'GradientPenaltyLoss', 'PerceptualLoss', 'PerceptualVGG', 'reduce_loss', - 'mask_reduce_loss', 'DiscShiftLoss', 'MaskedTVLoss', 'GradientLoss', - 'TransferalPerceptualLoss', 
'LightCNNFeatureLoss', 'gradient_penalty_loss', - 'r1_gradient_penalty_loss', 'gen_path_regularizer', 'FaceIdLoss', - 'CLIPLoss', 'CLIPLossComps', 'DiscShiftLossComps', 'FaceIdLossComps', - 'GANLossComps', 'GeneratorPathRegularizerComps', - 'GradientPenaltyLossComps', 'R1GradientPenaltyComps', 'disc_shift_loss' + 'L1Loss', + 'MSELoss', + 'CharbonnierLoss', + 'L1CompositionLoss', + 'MSECompositionLoss', + 'CharbonnierCompLoss', + 'GANLoss', + 'GaussianBlur', + 'GradientPenaltyLoss', + 'PerceptualLoss', + 'PerceptualVGG', + 'reduce_loss', + 'mask_reduce_loss', + 'DiscShiftLoss', + 'MaskedTVLoss', + 'GradientLoss', + 'TransferalPerceptualLoss', + 'LightCNNFeatureLoss', + 'gradient_penalty_loss', + 'r1_gradient_penalty_loss', + 'gen_path_regularizer', + 'FaceIdLoss', + 'CLIPLoss', + 'CLIPLossComps', + 'DiscShiftLossComps', + 'FaceIdLossComps', + 'GANLossComps', + 'GeneratorPathRegularizerComps', + 'GradientPenaltyLossComps', + 'R1GradientPenaltyComps', + 'disc_shift_loss', ] diff --git a/mmedit/models/utils/__init__.py b/mmedit/models/utils/__init__.py index b579869d60..6d98db512e 100644 --- a/mmedit/models/utils/__init__.py +++ b/mmedit/models/utils/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. + from .bbox_utils import extract_around_bbox, extract_bbox_patch from .flow_warp import flow_warp from .model_utils import (default_init_weights, generation_init_weights, @@ -8,17 +9,9 @@ from .tensor_utils import get_unknown_tensor __all__ = [ - 'default_init_weights', - 'make_layer', - 'flow_warp', - 'generation_init_weights', - 'set_requires_grad', - 'extract_bbox_patch', - 'extract_around_bbox', - 'get_unknown_tensor', - 'noise_sample_fn', - 'label_sample_fn', - 'get_valid_num_batches', - 'get_valid_noise_size', - 'get_module_device', + 'default_init_weights', 'make_layer', 'flow_warp', + 'generation_init_weights', 'set_requires_grad', 'extract_bbox_patch', + 'extract_around_bbox', 'get_unknown_tensor', 'noise_sample_fn', + 'label_sample_fn', 'get_valid_num_batches', 'get_valid_noise_size', + 'get_module_device' ] diff --git a/mmedit/utils/__init__.py b/mmedit/utils/__init__.py index 3fdaa607f2..1533e91fc7 100644 --- a/mmedit/utils/__init__.py +++ b/mmedit/utils/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .cli import modify_args -from .img_utils import reorder_image, tensor2img, to_numpy +from .img_utils import get_box_info, reorder_image, tensor2img, to_numpy from .io_utils import MMEDIT_CACHE_DIR, download_from_url # TODO replace with engine's API from .logger import print_colored_log @@ -17,5 +17,5 @@ 'download_from_url', 'get_sampler', 'tensor2img', 'random_choose_unknown', 'add_gaussian_noise', 'adjust_gamma', 'make_coord', 'bbox2mask', 'brush_stroke_mask', 'get_irregular_mask', 'random_bbox', 'reorder_image', - 'to_numpy' + 'to_numpy', 'get_box_info' ] diff --git a/mmedit/utils/img_utils.py b/mmedit/utils/img_utils.py index bf420910b5..ff5b2c1f03 100644 --- a/mmedit/utils/img_utils.py +++ b/mmedit/utils/img_utils.py @@ -125,3 +125,40 @@ def to_numpy(img, dtype=np.float64): img = img.astype(dtype) return img + + +def get_box_info(pred_bbox, original_shape, final_size): + """ + + Args: + pred_bbox: The bounding box for the instance + original_shape: Original image shape + final_size: Size of the final output + + Returns: + List: [L_pad, R_pad, T_pad, B_pad, rh, rw] + """ + assert len(pred_bbox) == 4 + resize_startx = int(pred_bbox[0] / original_shape[0] * final_size) + resize_starty = int(pred_bbox[1] / original_shape[1] * final_size) + resize_endx = int(pred_bbox[2] / original_shape[0] * final_size) + resize_endy = int(pred_bbox[3] / original_shape[1] * final_size) + rh = resize_endx - resize_startx + rw = resize_endy - resize_starty + if rh < 1: + if final_size - resize_endx > 1: + resize_endx += 1 + else: + resize_startx -= 1 + rh = 1 + if rw < 1: + if final_size - resize_endy > 1: + resize_endy += 1 + else: + resize_starty -= 1 + rw = 1 + L_pad = resize_startx + R_pad = final_size - resize_endx + T_pad = resize_starty + B_pad = final_size - resize_endy + return [L_pad, R_pad, T_pad, B_pad, rh, rw] diff --git a/mmedit/version.py b/mmedit/version.py index b367889550..da34a5e738 100644 --- a/mmedit/version.py +++ b/mmedit/version.py @@ -1,6 +1,6 @@ # Copyright (c) Open-MMLab. All rights reserved. -__version__ = '1.0.0rc1' +__version__ = '1.0.0rc2' def parse_version_info(version_str): diff --git a/model-index.yml b/model-index.yml index 6373a6dc4d..b0ce511cac 100644 --- a/model-index.yml +++ b/model-index.yml @@ -20,6 +20,7 @@ Import: - configs/global_local/metafile.yml - configs/iconvsr/metafile.yml - configs/indexnet/metafile.yml +- configs/inst_colorization/metafile.yml - configs/liif/metafile.yml - configs/lsgan/metafile.yml - configs/partial_conv/metafile.yml diff --git a/requirements/optional.txt b/requirements/optional.txt new file mode 100644 index 0000000000..95688066cc --- /dev/null +++ b/requirements/optional.txt @@ -0,0 +1 @@ +mmdet diff --git a/tests/test_apis/test_colorization_inference.py b/tests/test_apis/test_colorization_inference.py new file mode 100644 index 0000000000..3e574633bb --- /dev/null +++ b/tests/test_apis/test_colorization_inference.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
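+# A minimal usage sketch of the inference helper exercised below, kept as a
+# comment because the unit test builds the model manually via MODELS.build.
+# `init_model` and the demo paths are illustrative assumptions rather than
+# part of this test:
+#
+#     from mmedit.apis import colorization_inference, init_model
+#     from mmedit.utils import register_all_modules, tensor2img
+#
+#     register_all_modules()
+#     model = init_model(config_path, checkpoint=None, device='cuda:0')
+#     result = colorization_inference(model, 'demo/gray_image.jpg')
+#     rgb = tensor2img(result)[..., ::-1]  # reverse channel order, as below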
+import os.path as osp +import platform +import unittest + +import pytest +import torch +from mmengine import Config +from mmengine.runner import load_checkpoint + +from mmedit.apis import colorization_inference +from mmedit.registry import MODELS +from mmedit.utils import register_all_modules, tensor2img + + +@pytest.mark.skipif( + 'win' in platform.system().lower() and 'cu' in torch.__version__, + reason='skip on windows-cuda due to limited RAM.') +def test_colorization_inference(): + register_all_modules() + + if not torch.cuda.is_available(): + # RoI pooling only support in GPU + return unittest.skip('test requires GPU and torch+cuda') + + if torch.cuda.is_available(): + device = torch.device('cuda', 0) + else: + device = torch.device('cpu') + + config = osp.join( + osp.dirname(__file__), + '../..', + 'configs/inst_colorization/inst-colorizatioon_full_official_cocostuff-256x256.py' # noqa + ) + checkpoint = None + + cfg = Config.fromfile(config) + model = MODELS.build(cfg.model) + + if checkpoint is not None: + checkpoint = load_checkpoint(model, checkpoint) + + model.cfg = cfg + model.to(device) + model.eval() + + img_path = osp.join( + osp.dirname(__file__), '..', 'data/image/img_root/horse/horse.jpeg') + + result = colorization_inference(model, img_path) + assert tensor2img(result)[..., ::-1].shape == (256, 256, 3) diff --git a/tests/test_datasets/test_transforms/test_crop.py b/tests/test_datasets/test_transforms/test_crop.py index b755e264ff..703e7657c6 100644 --- a/tests/test_datasets/test_transforms/test_crop.py +++ b/tests/test_datasets/test_transforms/test_crop.py @@ -1,10 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import os.path as osp +import unittest +import cv2 import numpy as np import pytest +import torch -from mmedit.datasets.transforms import (Crop, CropLike, FixedCrop, ModCrop, +from mmedit.datasets.transforms import (Crop, CropLike, FixedCrop, + InstanceCrop, ModCrop, PairedRandomCrop, RandomResizedCrop) @@ -350,3 +355,31 @@ def test_crop_like(): assert results['gt'].shape == (512, 512) sum_diff = np.sum(abs(results['gt'][:480, :512] - img[:480, :512, 0])) assert sum_diff < 1e-6 + + +def test_instance_crop(): + + if not torch.cuda.is_available(): + # RoI pooling only support in GPU + return unittest.skip('test requires GPU and torch+cuda') + + croper = InstanceCrop( + key='img', + finesize=256, + box_num_upbound=2, + config_file='mmdet::mask_rcnn/' + 'mask-rcnn_x101-32x8d_fpn_ms-poly-3x_coco.py') # noqa + + img_path = osp.join( + osp.dirname(__file__), '..', '..', + 'data/image/img_root/horse/horse.jpeg') + img = cv2.imread(img_path) + data = dict(img=img, ori_img_shape=img.shape, img_channel_order='rgb') + + results = croper(data) + + assert 'empty_box' in results + if results['empty_box']: + cropped_img = results['cropped_img'] + assert len(cropped_img) == 0 + assert len(cropped_img) <= 2 diff --git a/tests/test_models/test_editors/test_inst_colorization/test_color_utils.py b/tests/test_models/test_editors/test_inst_colorization/test_color_utils.py new file mode 100644 index 0000000000..ff113d0917 --- /dev/null +++ b/tests/test_models/test_editors/test_inst_colorization/test_color_utils.py @@ -0,0 +1,153 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
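+# The conversions tested here use the same normalization options as
+# InstColorization: l_cent=50 and l_norm=100 map L in [0, 100] to roughly
+# [-0.5, 0.5], and ab_norm=110 maps the ab channels from [-110, 110] to
+# [-1, 1]. For example, L=50 -> (50 - 50) / 100 = 0.0 and a=110 ->
+# 110 / 110 = 1.0, which is exactly the scaling asserted in test_rgb2lab
+# and inverted in test_lab2rgb.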
+ +import torch + +from mmedit.models.editors.inst_colorization import color_utils + + +class TestColorUtils: + color_data_opt = dict( + ab_thresh=0, + p=1.0, + sample_PS=[ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + ], + ab_norm=110, + ab_max=110., + ab_quant=10., + l_norm=100., + l_cent=50., + mask_cent=0.5) + + def test_xyz2lab(self): + xyz = torch.rand(1, 3, 8, 8) + lab = color_utils.xyz2lab(xyz) + + sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None] + xyz_scale = xyz / sc + mask = (xyz_scale > .008856).type(torch.FloatTensor) + + xyz_int = xyz_scale**(1 / 3.) * mask + (7.787 * xyz_scale + + 16. / 116.) * (1 - mask) + L = 116. * xyz_int[:, 1, :, :] - 16. + a = 500. * (xyz_int[:, 0, :, :] - xyz_int[:, 1, :, :]) + b = 200. * (xyz_int[:, 1, :, :] - xyz_int[:, 2, :, :]) + + assert lab.shape == (1, 3, 8, 8) + assert lab.equal( + torch.cat((L[:, None, :, :], a[:, None, :, :], b[:, None, :, :]), + dim=1)) + + def test_rgb2xyz(self): + rgb = torch.rand(1, 3, 8, 8) + xyz = color_utils.rgb2xyz(rgb) + + mask = (rgb > .04045).type(torch.FloatTensor) + rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask) + + x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] \ + + .180423 * rgb[:, 2, :, :] + y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] \ + + .072169 * rgb[:, 2, :, :] + z = .019334 * rgb[:, 0, :, :] + .119193 * rgb[:, 1, :, :] \ + + .950227 * rgb[:, 2, :, :] + + assert xyz.shape == (1, 3, 8, 8) + assert xyz.equal( + torch.cat((x[:, None, :, :], y[:, None, :, :], z[:, None, :, :]), + dim=1)) + + def test_rgb2lab(self): + rgb = torch.rand(1, 3, 8, 8) + lab = color_utils.rgb2lab(rgb, self.color_data_opt) + _lab = color_utils.xyz2lab(color_utils.rgb2xyz(rgb)) + + l_rs = (_lab[:, [0], :, :] - + self.color_data_opt['l_cent']) / self.color_data_opt['l_norm'] + ab_rs = _lab[:, 1:, :, :] / self.color_data_opt['ab_norm'] + + assert lab.shape == (1, 3, 8, 8) + assert lab.equal(torch.cat((l_rs, ab_rs), dim=1)) + + def test_lab2rgb(self): + lab = torch.rand(1, 3, 8, 8) + rgb = color_utils.lab2rgb(lab, self.color_data_opt) + + L = lab[:, [0], :, :] * self.color_data_opt[ + 'l_norm'] + self.color_data_opt['l_cent'] + AB = lab[:, 1:, :, :] * self.color_data_opt['ab_norm'] + + lab = torch.cat((L, AB), dim=1) + + assert rgb.shape == (1, 3, 8, 8) + assert rgb.equal(color_utils.xyz2rgb(color_utils.lab2xyz(lab))) + + def test_lab2xyz(self): + lab = torch.rand(1, 3, 8, 8) + xyz = color_utils.lab2xyz(lab) + y_int = (lab[:, 0, :, :] + 16.) / 116. + x_int = (lab[:, 1, :, :] / 500.) + y_int + z_int = y_int - (lab[:, 2, :, :] / 200.) + z_int = torch.max(torch.Tensor((0, )), z_int) + + out = torch.cat( + (x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]), + dim=1) + mask = (out > .2068966).type(torch.FloatTensor) + sc = torch.Tensor((0.95047, 1., 1.08883))[None, :, None, None] + out = (out**3.) * mask + (out - 16. / 116.) 
/ 7.787 * (1 - mask) + target = sc * out + assert xyz.shape == (1, 3, 8, 8) + assert xyz.equal(target) + + def test_xyz2rgb(self): + xyz = torch.rand(1, 3, 8, 8) + + rgb = color_utils.xyz2rgb(xyz) + + r = 3.24048134 * xyz[:, 0, :, :] - 1.53715152 * xyz[:, 1, :, :] \ + - 0.49853633 * xyz[:, 2, :, :] + g = -0.96925495 * xyz[:, 0, :, :] + 1.87599 * xyz[:, 1, :, :] \ + + .04155593 * xyz[:, 2, :, :] + b = .05564664 * xyz[:, 0, :, :] - .20404134 * xyz[:, 1, :, :] \ + + 1.05731107 * xyz[:, 2, :, :] + + _rgb = torch.cat( + (r[:, None, :, :], g[:, None, :, :], b[:, None, :, :]), dim=1) + _rgb = torch.max(_rgb, torch.zeros_like(_rgb)) + + mask = (_rgb > .0031308).type(torch.FloatTensor) + + assert rgb.shape == (1, 3, 8, 8) and mask.shape == (1, 3, 8, 8) + assert rgb.equal((1.055 * (_rgb**(1. / 2.4)) - 0.055) * mask + + 12.92 * _rgb * (1 - mask)) + + def test_get_colorization_data(self): + data_raw = torch.rand(1, 3, 8, 8) + + res = color_utils.get_colorization_data(data_raw, self.color_data_opt) + + assert isinstance(res, dict) + assert 'A' in res.keys() and 'B' in res.keys() \ + and 'hint_B' in res.keys() and 'mask_B' in res.keys() + assert res['A'].shape == res['mask_B'].shape == (1, 1, 8, 8) + assert res['hint_B'].shape == res['B'].shape == (1, 2, 8, 8) + + def test_encode_ab_ind(self): + data_ab = torch.rand(1, 2, 8, 8) + data_q = color_utils.encode_ab_ind(data_ab, self.color_data_opt) + A = 2 * 110. / 10. + 1 + + data_ab_rs = torch.round((data_ab * 110 + 110.) / 10.) + + assert data_q.shape == (1, 1, 8, 8) + assert data_q.equal(data_ab_rs[:, [0], :, :] * A + + data_ab_rs[:, [1], :, :]) diff --git a/tests/test_models/test_editors/test_inst_colorization/test_colorization_net.py b/tests/test_models/test_editors/test_inst_colorization/test_colorization_net.py new file mode 100644 index 0000000000..c6d4454cab --- /dev/null +++ b/tests/test_models/test_editors/test_inst_colorization/test_colorization_net.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmedit.registry import MODULES + + +def test_colorization_net(): + + model_cfg = dict( + type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch') + + # build model + model = MODULES.build(model_cfg) + + # test attributes + assert model.__class__.__name__ == 'ColorizationNet' + + # prepare data + input_A = torch.rand(1, 1, 256, 256) + input_B = torch.rand(1, 2, 256, 256) + mask_B = torch.rand(1, 1, 256, 256) + + target_shape = (1, 2, 256, 256) + + # test on cpu + (out_class, out_reg, feature_map) = model(input_A, input_B, mask_B) + assert isinstance(feature_map, dict) + assert feature_map['conv1_2'].shape == (1, 64, 256, 256) \ + and feature_map['out_reg'].shape == target_shape + + # test on gpu + if torch.cuda.is_available(): + model = model.cuda() + input_A = input_A.cuda() + input_B = input_B.cuda() + mask_B = mask_B.cuda() + (out_class, out_reg, feature_map) = \ + model(input_A, input_B, mask_B) + + assert isinstance(feature_map, dict) + for item in feature_map.keys(): + assert torch.is_tensor(feature_map[item]) diff --git a/tests/test_models/test_editors/test_inst_colorization/test_fusion_net.py b/tests/test_models/test_editors/test_inst_colorization/test_fusion_net.py new file mode 100644 index 0000000000..929c9c8eb1 --- /dev/null +++ b/tests/test_models/test_editors/test_inst_colorization/test_fusion_net.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
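+# Each row of `box_info_box` below follows the layout produced by the new
+# mmedit.utils.get_box_info helper: [L_pad, R_pad, T_pad, B_pad, rh, rw].
+# As a hypothetical worked example, a box spanning 10-120 and 20-200 inside
+# a 256x256 image with final_size=256 gives
+# [10, 256 - 120, 20, 256 - 200, 110, 180] = [10, 136, 20, 56, 110, 180].
+# The four tensors provide this information at the 256/128/64/32 feature
+# resolutions (box_info, box_info_2x, box_info_4x, box_info_8x) that
+# FusionNet fuses through its WeightLayers.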
+ +import torch + +from mmedit.registry import MODULES + + +def test_fusion_net(): + + model_cfg = dict( + type='FusionNet', input_nc=4, output_nc=2, norm_type='batch') + + # build model + model = MODULES.build(model_cfg) + + # test attributes + assert model.__class__.__name__ == 'FusionNet' + + # prepare data + input_A = torch.rand(1, 1, 256, 256) + input_B = torch.rand(1, 2, 256, 256) + mask_B = torch.rand(1, 1, 256, 256) + + instance_feature = dict( + conv1_2=torch.rand(1, 64, 256, 256), + conv2_2=torch.rand(1, 128, 128, 128), + conv3_3=torch.rand(1, 256, 64, 64), + conv4_3=torch.rand(1, 512, 32, 32), + conv5_3=torch.rand(1, 512, 32, 32), + conv6_3=torch.rand(1, 512, 32, 32), + conv7_3=torch.rand(1, 512, 32, 32), + conv8_up=torch.rand(1, 256, 64, 64), + conv8_3=torch.rand(1, 256, 64, 64), + conv9_up=torch.rand(1, 128, 128, 128), + conv9_3=torch.rand(1, 128, 128, 128), + conv10_up=torch.rand(1, 128, 256, 256), + conv10_2=torch.rand(1, 128, 256, 256), + ) + + target_shape = (1, 2, 256, 256) + + box_info_box = [ + torch.tensor([[175, 29, 96, 54, 52, 106], [14, 191, 84, 61, 51, 111], + [117, 64, 115, 46, 75, 95], [41, 165, 121, 47, 50, 88], + [46, 136, 94, 45, 74, 117], [79, 124, 62, 115, 53, 79], + [156, 64, 77, 138, 36, 41], [200, 48, 114, 131, 8, 11], + [115, 78, 92, 81, 63, 83]]), + torch.tensor([[87, 15, 48, 27, 26, 53], [7, 96, 42, 31, 25, 55], + [58, 32, 57, 23, 38, 48], [20, 83, 60, 24, 25, 44], + [23, 68, 47, 23, 37, 58], [39, 62, 31, 58, 27, 39], + [78, 32, 38, 69, 18, 21], [100, 24, 57, 66, 4, 5], + [57, 39, 46, 41, 32, 41]]), + torch.tensor([[43, 8, 24, 14, 13, 26], [3, 48, 21, 16, 13, 27], + [29, 16, 28, 12, 19, 24], [10, 42, 30, 12, 12, 22], + [11, 34, 23, 12, 19, 29], [19, 31, 15, 29, 14, 20], + [39, 16, 19, 35, 9, 10], [50, 12, 28, 33, 2, 3], + [28, 20, 23, 21, 16, 20]]), + torch.tensor([[21, 4, 12, 7, 7, 13], [1, 24, 10, 8, 7, 14], + [14, 8, 14, 6, 10, 12], [5, 21, 15, 6, 6, 11], + [5, 17, 11, 6, 10, 15], [9, 16, 7, 15, 7, 10], + [19, 8, 9, 18, 5, 5], [25, 6, 14, 17, 1, 1], + [14, 10, 11, 11, 8, 10]]) + ] + + # test on cpu + out = model(input_A, input_B, mask_B, instance_feature, box_info_box) + assert torch.is_tensor(out) + assert out.shape == target_shape + + # test on gpu + if torch.cuda.is_available(): + model = model.cuda() + input_A = input_A.cuda() + input_B = input_B.cuda() + mask_B = mask_B.cuda() + for item in instance_feature.keys(): + instance_feature[item] = instance_feature[item].cuda() + box_info_box = [i.cuda() for i in box_info_box] + output = model(input_A, input_B, mask_B, instance_feature, + box_info_box) + assert torch.is_tensor(output) diff --git a/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py b/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py new file mode 100644 index 0000000000..5d769b2134 --- /dev/null +++ b/tests/test_models/test_editors/test_inst_colorization/test_inst_colorization.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
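+# A short summary of the tensor-mode data flow checked below (see
+# InstColorization.forward_tensor): the full image and the cropped instances
+# are converted to Lab hints with get_colorization_data; when boxes were
+# detected, instance_model extracts per-instance feature maps that
+# fusion_model blends into the full-image prediction using the four
+# box_info tensors, otherwise image_model colorizes the image alone. The
+# predicted ab channels are concatenated with the input L channel and mapped
+# back to RGB via lab2rgb, producing the (1, 3, 256, 256) output asserted
+# here.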
+import platform +import unittest + +import pytest +import torch + +from mmedit.registry import BACKBONES +from mmedit.structures import EditDataSample, PixelData +from mmedit.utils import register_all_modules + + +@pytest.mark.skipif( + 'win' in platform.system().lower() and 'cu' in torch.__version__, + reason='skip on windows-cuda due to limited RAM.') +class TestInstColorization: + + def test_inst_colorization(self): + if not torch.cuda.is_available(): + # RoI pooling only support in GPU + return unittest.skip('test requires GPU and torch+cuda') + + register_all_modules() + model_cfg = dict( + type='InstColorization', + data_preprocessor=dict( + type='EditDataPreprocessor', + mean=[127.5], + std=[127.5], + ), + image_model=dict( + type='ColorizationNet', + input_nc=4, + output_nc=2, + norm_type='batch'), + instance_model=dict( + type='ColorizationNet', + input_nc=4, + output_nc=2, + norm_type='batch'), + fusion_model=dict( + type='FusionNet', input_nc=4, output_nc=2, norm_type='batch'), + color_data_opt=dict( + ab_thresh=0, + p=1.0, + sample_PS=[ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + ], + ab_norm=110, + ab_max=110., + ab_quant=10., + l_norm=100., + l_cent=50., + mask_cent=0.5), + which_direction='AtoB', + loss=dict(type='HuberLoss', delta=.01)) + + model = BACKBONES.build(model_cfg) + + # test attributes + assert model.__class__.__name__ == 'InstColorization' + + # prepare data + inputs = torch.rand(1, 3, 256, 256) + target_shape = (1, 3, 256, 256) + + data_sample = EditDataSample(gt_img=PixelData(data=inputs)) + metainfo = dict( + cropped_img=PixelData(data=torch.rand(9, 3, 256, 256)), + box_info=torch.tensor([[175, 29, 96, 54, 52, 106], + [14, 191, 84, 61, 51, 111], + [117, 64, 115, 46, 75, 95], + [41, 165, 121, 47, 50, 88], + [46, 136, 94, 45, 74, 117], + [79, 124, 62, 115, 53, 79], + [156, 64, 77, 138, 36, 41], + [200, 48, 114, 131, 8, 11], + [115, 78, 92, 81, 63, 83]]), + box_info_2x=torch.tensor([[87, 15, 48, 27, 26, 53], + [7, 96, 42, 31, 25, 55], + [58, 32, 57, 23, 38, 48], + [20, 83, 60, 24, 25, 44], + [23, 68, 47, 23, 37, 58], + [39, 62, 31, 58, 27, 39], + [78, 32, 38, 69, 18, 21], + [100, 24, 57, 66, 4, 5], + [57, 39, 46, 41, 32, 41]]), + box_info_4x=torch.tensor([[43, 8, 24, 14, 13, 26], + [3, 48, 21, 16, 13, 27], + [29, 16, 28, 12, 19, 24], + [10, 42, 30, 12, 12, 22], + [11, 34, 23, 12, 19, 29], + [19, 31, 15, 29, 14, 20], + [39, 16, 19, 35, 9, 10], + [50, 12, 28, 33, 2, 3], + [28, 20, 23, 21, 16, 20]]), + box_info_8x=torch.tensor([[21, 4, 12, 7, 7, 13], + [1, 24, 10, 8, 7, 14], + [14, 8, 14, 6, 10, 12], + [5, 21, 15, 6, 6, 11], + [5, 17, 11, 6, 10, 15], + [9, 16, 7, 15, 7, 10], + [19, 8, 9, 18, 5, 5], + [25, 6, 14, 17, 1, 1], + [14, 10, 11, 11, 8, 10]]), + empty_box=False) + data_sample.set_metainfo(metainfo=metainfo) + + data = dict(inputs=inputs, data_samples=[data_sample]) + + res = model(mode='tensor', **data) + + assert torch.is_tensor(res) + assert res.shape == target_shape diff --git a/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py b/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py new file mode 100644 index 0000000000..8293eb116e --- /dev/null +++ b/tests/test_models/test_editors/test_inst_colorization/test_weight_layer.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
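+# WeightLayer blends instance features into a full-image feature map: for
+# each box it predicts a one-channel weight map (simple_instance_conv),
+# resizes and zero-pads the map and the instance feature to the box location
+# given by box_info, predicts a background map for the whole image
+# (simple_bg_conv, strongly boosted wherever no box falls so that uncovered
+# pixels keep the full-image feature), normalizes all maps with a softmax
+# and returns their weighted sum. The fused output keeps the spatial shape
+# of the full-image input, which is the property asserted below.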
+import platform + +import pytest +import torch + +from mmedit.models.editors.inst_colorization.weight_layer import WeightLayer + + +@pytest.mark.skipif( + 'win' in platform.system().lower() and 'cu' in torch.__version__, + reason='skip on windows-cuda due to limited RAM.') +def test_weight_layer(): + + weight_layer = WeightLayer(64) + + instance_feature_conv1_2 = torch.rand(1, 64, 256, 256) + conv1_2 = torch.rand(1, 64, 256, 256) + box_info = torch.tensor([[175, 29, 96, 54, 52, 106], + [14, 191, 84, 61, 51, 111], + [117, 64, 115, 46, 75, 95], + [41, 165, 121, 47, 50, 88], + [46, 136, 94, 45, 74, 117], + [79, 124, 62, 115, 53, 79], + [156, 64, 77, 138, 36, 41], + [200, 48, 114, 131, 8, 11], + [115, 78, 92, 81, 63, 83]]) + conv1_2 = weight_layer(instance_feature_conv1_2, conv1_2, box_info) + + assert conv1_2.shape == instance_feature_conv1_2.shape diff --git a/tests/test_models/test_editors/test_liif/test_liif_net.py b/tests/test_models/test_editors/test_liif/test_liif_net.py index c1642a06f6..ab8f409adb 100644 --- a/tests/test_models/test_editors/test_liif/test_liif_net.py +++ b/tests/test_models/test_editors/test_liif/test_liif_net.py @@ -1,9 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest import torch from mmedit.registry import BACKBONES +@pytest.mark.skipif( + 'win' in platform.system().lower() and 'cu' in torch.__version__, + reason='skip on windows-cuda due to limited RAM.') def test_liif_edsr_net(): model_cfg = dict(