[Enhance] support controlnet inferencer #1891

Merged
merged 1 commit into from
Jun 6, 2023
33 changes: 31 additions & 2 deletions configs/controlnet/README.md
@@ -26,10 +26,10 @@ This model has several weights including vae, unet and clip. You should download

| Model | Dataset | Download |
| :---------------------------------------------: | :-----: | :----------------------------------------------------------------------------------------------: |
| [ControlNet-Demo](./controlnet-1xb1-fill50k.py) | - | - |
| [ControlNet-Canny](./controlnet-canny.py) | - | [model](https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_canny.pth) |
| [ControlNet-Segmentation](./controlnet-seg.py) | - | [model](https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_seg.pth) |
| [ControlNet-Pose](./controlnet-pose.py) | - | [model](https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_openpose.pth) |
| [ControlNet-Demo](./controlnet-1xb1-fill50k.py) | - | - |
| [ControlNet-Segmentation](./controlnet-seg.py) | - | [model](https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_seg.pth) |

Note that [ControlNet-Demo](./controlnet-1xb1-demo_dataset.py) is a demo config for training ControlNet on the toy dataset Fill50K.

@@ -159,6 +159,35 @@ for idx, control in enumerate(controls):
</thead>
</table>

### Using MMagicInferencer

You need only a few lines of code to run ControlNet with MMagic!

```python
from mmagic.apis import MMagicInferencer

# controlnet-canny
controlnet_canny_inferencer = MMagicInferencer(model_name='controlnet', model_setting=1)
text_prompts = 'Room with blue walls and a yellow ceiling.'
control = 'https://user-images.githubusercontent.com/28132635/230297033-4f5c32df-365c-4cf4-8e4f-1b76a4cbb0b7.png'
result_out_dir = 'controlnet_canny_res.png'
controlnet_canny_inferencer.infer(text=text_prompts, control=control, result_out_dir=result_out_dir)

# controlnet-pose
controlnet_pose_inferencer = MMagicInferencer(model_name='controlnet', model_setting=2)
text_prompts = 'masterpiece, best quality, sky, black hair, skirt, sailor collar, looking at viewer, short hair, building, bangs, neckerchief, long sleeves, cloudy sky, power lines, shirt, cityscape, pleated skirt, scenery, blunt bangs, city, night, black sailor collar, closed mouth'
control = 'https://user-images.githubusercontent.com/28132635/230380893-2eae68af-d610-4f7f-aa68-c2f22c2abf7e.png'
result_out_dir = 'controlnet_pose_res.png'
controlnet_pose_inferencer.infer(text=text_prompts, control=control, result_out_dir=result_out_dir)

# controlnet-seg
controlnet_seg_inferencer = MMagicInferencer(model_name='controlnet', model_setting=3)
text_prompts = 'black house, blue sky'
control = 'https://github-production-user-asset-6210df.s3.amazonaws.com/49083766/243599897-553a4c46-c61d-46df-b820-59a49aaf6678.png'
result_out_dir = 'controlnet_seg_res.png'
controlnet_seg_inferencer.infer(text=text_prompts, control=control, result_out_dir=result_out_dir)
```
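
The calls above differ only in `model_setting`, prompt, control image, and output name, so they can be driven from a small table. The following is a minimal sketch and not part of this PR: `CONTROLNET_JOBS`, `infer_kwargs`, and `run_all` are hypothetical names, and `make_inferencer` stands in for whatever builds the inferencer. Only the `text`, `control`, and `result_out_dir` keywords are taken from the snippet above.

```python
from typing import Callable, Dict, List

# Two of the jobs from the README snippet (setting 1 = canny, 3 = seg).
CONTROLNET_JOBS: List[Dict] = [
    dict(
        name='canny',
        model_setting=1,
        text='Room with blue walls and a yellow ceiling.',
        control='https://user-images.githubusercontent.com/28132635/'
        '230297033-4f5c32df-365c-4cf4-8e4f-1b76a4cbb0b7.png'),
    dict(
        name='seg',
        model_setting=3,
        text='black house, blue sky',
        control='https://github-production-user-asset-6210df.s3.amazonaws.com/'
        '49083766/243599897-553a4c46-c61d-46df-b820-59a49aaf6678.png'),
]


def infer_kwargs(job: Dict) -> Dict[str, str]:
    """Build the keyword arguments passed to `infer` for one job."""
    return dict(
        text=job['text'],
        control=job['control'],
        result_out_dir=f"controlnet_{job['name']}_res.png")


def run_all(make_inferencer: Callable[[int], object],
            jobs: List[Dict] = CONTROLNET_JOBS) -> List[str]:
    """Run every job and return the output paths that were requested."""
    outputs = []
    for job in jobs:
        # One inferencer per setting, as in the README snippet.
        inferencer = make_inferencer(job['model_setting'])
        kwargs = infer_kwargs(job)
        inferencer.infer(**kwargs)
        outputs.append(kwargs['result_out_dir'])
    return outputs
```

With MMagic installed, `run_all(lambda s: MMagicInferencer(model_name='controlnet', model_setting=s))` would run both jobs. Each job builds its own inferencer, so the models are loaded once per setting.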

## Train your own ControlNet!

You can start training your own ControlNet on the toy dataset [Fill50K](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip) using the following command:
6 changes: 3 additions & 3 deletions configs/controlnet/controlnet-pose.py
@@ -6,12 +6,12 @@
     type='ControlStableDiffusion',
     vae=dict(
         type='AutoencoderKL',
-        from_pretrained=stable_diffusion_v15_url,
+        from_pretrained='gsdf/Counterfeit-V2.5',
         subfolder='vae'),
     unet=dict(
         type='UNet2DConditionModel',
         subfolder='unet',
-        from_pretrained=stable_diffusion_v15_url),
+        from_pretrained='gsdf/Counterfeit-V2.5'),
     text_encoder=dict(
         type='ClipWrapper',
         clip_type='huggingface',
@@ -29,4 +29,4 @@
         from_pretrained=stable_diffusion_v15_url,
         subfolder='scheduler'),
     data_preprocessor=dict(type='DataPreprocessor'),
-    init_cfg=dict(type='init_from_unet'))
+    init_cfg=dict(type='convert_from_unet'))
16 changes: 8 additions & 8 deletions configs/controlnet/metafile.yml
@@ -8,22 +8,21 @@ Collections:
   - text2image
   Year: 2023
 Models:
-- Config: configs/controlnet/controlnet-canny.py
+- Config: configs/controlnet/controlnet-1xb1-fill50k.py
   In Collection: Control Net
-  Name: controlnet-canny
+  Name: controlnet-1xb1-fill50k
   Results:
   - Dataset: '-'
     Metrics: {}
     Task: Text2Image
-  Weights: https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_canny.pth
-- Config: configs/controlnet/controlnet-seg.py
+- Config: configs/controlnet/controlnet-canny.py
   In Collection: Control Net
-  Name: controlnet-seg
+  Name: controlnet-canny
   Results:
   - Dataset: '-'
     Metrics: {}
     Task: Text2Image
-  Weights: https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_seg.pth
+  Weights: https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_canny.pth
 - Config: configs/controlnet/controlnet-pose.py
   In Collection: Control Net
   Name: controlnet-pose
@@ -32,10 +31,11 @@ Models:
     Metrics: {}
     Task: Text2Image
   Weights: https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_openpose.pth
-- Config: configs/controlnet/controlnet-1xb1-fill50k.py
+- Config: configs/controlnet/controlnet-seg.py
   In Collection: Control Net
-  Name: controlnet-1xb1-fill50k
+  Name: controlnet-seg
   Results:
   - Dataset: '-'
     Metrics: {}
     Task: Text2Image
+  Weights: https://huggingface.co/lllyasviel/ControlNet/blob/main/models/control_sd15_seg.pth
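
The reordering in this file appears to matter because the README examples pick a model by position: `model_setting=1` is used for canny, `2` for pose, and `3` for seg, which matches the new order of the `Models` entries. Below is a minimal sketch of that lookup, assuming settings are indexed in metafile order (an inference from this PR, not something the diff states). `METAFILE_MODELS` and `resolve_setting` are hypothetical names.

```python
from typing import Dict, List, Optional

_HF = 'https://huggingface.co/lllyasviel/ControlNet/blob/main/models/'

# `Models` entries in the order they appear in metafile.yml after this change.
METAFILE_MODELS: List[Dict[str, Optional[str]]] = [
    dict(name='controlnet-1xb1-fill50k', weights=None),
    dict(name='controlnet-canny', weights=_HF + 'control_sd15_canny.pth'),
    dict(name='controlnet-pose', weights=_HF + 'control_sd15_openpose.pth'),
    dict(name='controlnet-seg', weights=_HF + 'control_sd15_seg.pth'),
]


def resolve_setting(models: List[Dict[str, Optional[str]]],
                    model_setting: int) -> Dict[str, Optional[str]]:
    """Return the entry selected by a positional `model_setting` index."""
    if not 0 <= model_setting < len(models):
        raise IndexError(f'model_setting {model_setting} is out of range')
    return models[model_setting]
```

Under this assumption, index `0` selects the Fill50K demo config, which has no `Weights` entry.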
35 changes: 35 additions & 0 deletions demo/README.md
@@ -211,13 +211,48 @@ python mmagic_inference_demo.py \

#### 2.2.9 Text-to-Image

stable diffusion

```shell
python mmagic_inference_demo.py \
--model-name stable_diffusion \
--text "A panda is having dinner at KFC" \
--result-out-dir ../resources/output/text2image/demo_text2image_stable_diffusion_res.png
```

controlnet-canny

```shell
python mmagic_inference_demo.py \
--model-name controlnet \
--model-setting 1 \
--text "Room with blue walls and a yellow ceiling." \
--control 'https://user-images.githubusercontent.com/28132635/230297033-4f5c32df-365c-4cf4-8e4f-1b76a4cbb0b7.png' \
--result-out-dir demo_text2image_controlnet_canny_res.png
```

controlnet-pose

```shell
python mmagic_inference_demo.py \
--model-name controlnet \
--model-setting 2 \
--text "masterpiece, best quality, sky, black hair, skirt, sailor collar, looking at viewer, short hair, building, bangs, neckerchief, long sleeves, cloudy sky, power lines, shirt, cityscape, pleated skirt, scenery, blunt bangs, city, night, black sailor collar, closed mouth" \
--control 'https://user-images.githubusercontent.com/28132635/230380893-2eae68af-d610-4f7f-aa68-c2f22c2abf7e.png' \
--result-out-dir demo_text2image_controlnet_pose_res.png
```

controlnet-seg

```shell
python mmagic_inference_demo.py \
--model-name controlnet \
--model-setting 3 \
--text "black house, blue sky" \
--control 'https://github-production-user-asset-6210df.s3.amazonaws.com/49083766/243599897-553a4c46-c61d-46df-b820-59a49aaf6678.png' \
--result-out-dir demo_text2image_controlnet_seg_res.png
```
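
The ControlNet commands above add `--model-setting` and `--control` to the flags used for plain Stable Diffusion. The argument parser of `mmagic_inference_demo.py` is not part of this diff, so the following is only a hypothetical parser that mirrors those flags; the types, defaults, and the names `build_parser` and `parse` are assumptions.

```python
import argparse
from typing import Any, Dict, List


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser accepting the flags used in the commands above."""
    parser = argparse.ArgumentParser(description='Text-to-image demo flags')
    parser.add_argument('--model-name', type=str, required=True)
    parser.add_argument('--model-setting', type=int, default=None)
    parser.add_argument('--text', type=str, default=None)
    parser.add_argument('--control', type=str, default=None)
    parser.add_argument('--result-out-dir', type=str, default=None)
    return parser


def parse(argv: List[str]) -> Dict[str, Any]:
    """Parse an argument list into a plain dict (dashes become underscores)."""
    return vars(build_parser().parse_args(argv))
```

For example, `parse(['--model-name', 'controlnet', '--model-setting', '1'])` yields a dict whose `model_setting` is the integer `1` and whose `control` is `None`.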

#### 2.2.10 3D-aware Generation

```shell
3 changes: 3 additions & 0 deletions mmagic/apis/inferencers/base_mmagic_inferencer.py
@@ -69,6 +69,9 @@ def _init_model(self, cfg: Union[ConfigType, str], ckpt: Optional[str],
         model = MODELS.build(cfg.model)
         if ckpt is not None and ckpt != '':
             ckpt = load_checkpoint(model, ckpt, map_location='cpu')
+        if cfg.model.get(
+                'init_cfg') and cfg.model.init_cfg.type == 'convert_from_unet':
+            model.init_weights()
         model.cfg = cfg
         model.to(device)
         model.eval()
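
The added branch calls `model.init_weights()` only when the model config carries an `init_cfg` whose type is `convert_from_unet`. A minimal sketch of that guard follows, using plain dicts and a stand-in model so it runs without MMagic. `needs_unet_conversion`, `FakeModel`, and `init_model` are hypothetical names; the real code works on mmengine config objects with attribute access and also handles checkpoint loading, which is omitted here.

```python
def needs_unet_conversion(model_cfg: dict) -> bool:
    """Mirror the guard: init_cfg must exist and be 'convert_from_unet'."""
    init_cfg = model_cfg.get('init_cfg')
    return bool(init_cfg) and init_cfg.get('type') == 'convert_from_unet'


class FakeModel:
    """Stand-in that only records whether init_weights() was called."""

    def __init__(self) -> None:
        self.initialized = False

    def init_weights(self) -> None:
        self.initialized = True


def init_model(model_cfg: dict) -> FakeModel:
    """Build the stand-in model and apply the conditional init step."""
    model = FakeModel()
    if needs_unet_conversion(model_cfg):
        model.init_weights()
    return model
```

A config with `init_cfg=dict(type='convert_from_unet')` triggers the call, while one without `init_cfg`, or with a different type such as `init_from_unet`, leaves the model untouched.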
13 changes: 10 additions & 3 deletions mmagic/apis/inferencers/text2image_inferencer.py
@@ -2,9 +2,10 @@
 import os
 from typing import Dict, List

+import mmcv
 import numpy as np
 from mmengine import mkdir_or_exist
-from PIL.Image import Image
+from PIL.Image import Image, fromarray
 from torchvision.utils import save_image

 from .base_mmagic_inferencer import BaseMMagicInferencer, InputsType, PredType
@@ -14,14 +15,14 @@ class Text2ImageInferencer(BaseMMagicInferencer):
     """inferencer that predicts with text2image models."""

     func_kwargs = dict(
-        preprocess=['text'],
+        preprocess=['text', 'control'],
         forward=[],
         visualize=['result_out_dir'],
         postprocess=[])

     extra_parameters = dict(height=None, width=None, seed=1)

-    def preprocess(self, text: InputsType) -> Dict:
+    def preprocess(self, text: InputsType, control: str = None) -> Dict:
         """Process the inputs into a model-feedable format.

         Args:
@@ -36,6 +37,12 @@ def preprocess(self, text: InputsType) -> Dict:
         else:
             result['prompt'] = text

+        if control:
+            control_img = mmcv.imread(control)
+            control_img = fromarray(control_img)
+            result['control'] = control_img
+            result.pop('seed', None)
+
         return result
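
The new `control` argument is optional: when it is given, the image is loaded, converted to a PIL image, stored under `result['control']`, and the `seed` entry is dropped. The sketch below reproduces that flow with an injectable loader so it runs without `mmcv` or PIL. `build_inputs`, `load_image`, and `DEFAULTS` are hypothetical names, and copying the defaults before modifying them is a choice of this sketch rather than something shown in the diff.

```python
from typing import Any, Callable, Dict, Optional

# Same defaults as `extra_parameters` in the class above.
DEFAULTS: Dict[str, Any] = dict(height=None, width=None, seed=1)


def build_inputs(
        text: str,
        control: Optional[str] = None,
        load_image: Callable[[str], Any] = lambda path: path
) -> Dict[str, Any]:
    """Assemble model inputs; attach a control image when one is given."""
    result = dict(DEFAULTS)  # copy so the shared defaults are never mutated
    result['prompt'] = text
    if control:
        result['control'] = load_image(control)
        # The diff drops the seed whenever a control image is supplied.
        result.pop('seed', None)
    return result
```

In the diff the loader is `mmcv.imread` followed by `fromarray`. Since `mmcv.imread` returns BGR arrays by default and PIL interprets three-channel arrays as RGB, it may be worth checking the channel order for colour-coded controls such as segmentation maps.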

     def forward(self, inputs: InputsType) -> PredType:
3 changes: 3 additions & 0 deletions mmagic/apis/mmagic_inferencer.py
@@ -106,6 +106,7 @@ class MMagicInferencer:
         'restormer',

         # text2image models
+        'controlnet',
         'disco_diffusion',
         'stable_diffusion',

@@ -163,6 +164,8 @@ def _get_inferencer_kwargs(self, model_name: Optional[str],
                 osp.dirname(__file__), '..', '.mim', config_dir)
             if 'Weights' in cfgs['settings'][setting_to_use].keys():
                 kwargs['ckpt'] = cfgs['settings'][setting_to_use]['Weights']
+            if model_name == 'controlnet':
+                kwargs['ckpt'] = None

         if model_config is not None:
             if kwargs.get('config', None) is not None:
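
For `controlnet`, the added lines discard the `Weights` URL taken from the metafile and pass `ckpt=None`, presumably because the `from_pretrained` entries in the config and the `convert_from_unet` init step supply the weights instead (an inference from this PR, not something it states). A minimal sketch of that selection follows; `select_ckpt` is a hypothetical helper, not a function in MMagic.

```python
from typing import Dict, Optional


def select_ckpt(model_name: str, setting: Dict[str, str]) -> Optional[str]:
    """Pick the checkpoint for a setting, overriding it for controlnet."""
    ckpt = setting.get('Weights')  # None when the metafile entry has no Weights
    if model_name == 'controlnet':
        # Mirrors the added override: never load the metafile checkpoint.
        ckpt = None
    return ckpt
```

With this rule, a `controlnet` setting always resolves to `None`, while other models keep whatever `Weights` value their setting declares.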