Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] support mscoco dataset #1520

Merged
merged 8 commits into from
Dec 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions mmedit/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,14 @@
from .comp1k_dataset import AdobeComp1kDataset
from .grow_scale_image_dataset import GrowScaleImgDataset
from .imagenet_dataset import ImageNet
from .mscoco_dataset import MSCoCoDataset
from .paired_image_dataset import PairedImageDataset
from .singan_dataset import SinGANDataset
from .unpaired_image_dataset import UnpairedImageDataset

__all__ = [
'AdobeComp1kDataset',
'BasicImageDataset',
'BasicFramesDataset',
'BasicConditionalDataset',
'UnpairedImageDataset',
'PairedImageDataset',
'ImageNet',
'CIFAR10',
'GrowScaleImgDataset',
'SinGANDataset',
'AdobeComp1kDataset', 'BasicImageDataset', 'BasicFramesDataset',
'BasicConditionalDataset', 'UnpairedImageDataset', 'PairedImageDataset',
'ImageNet', 'CIFAR10', 'GrowScaleImgDataset', 'SinGANDataset',
'MSCoCoDataset'
]
101 changes: 101 additions & 0 deletions mmedit/datasets/mscoco_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import random
from typing import Optional, Sequence, Union

import mmengine
from mmengine import FileClient

from mmedit.registry import DATASETS
from .basic_conditional_dataset import BasicConditionalDataset


@DATASETS.register_module()
@DATASETS.register_module('MSCOCO')
class MSCoCoDataset(BasicConditionalDataset):
"""MSCoCo 2014 dataset.

Args:
ann_file (str): Annotation file path. Defaults to ''.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_root (str): The root directory for ``data_prefix`` and
``ann_file``. Defaults to ''.
drop_caption_rate (float, optional): Rate of dropping caption,
used for training. Defaults to 0.0.
phase (str, optional): Subdataset used for certain phase, can be set
to `train`, `test` and `val`. Defaults to 'train'.
year (int, optional): Version of CoCo dataset, can be set to 2014
and 2017. Defaults to 2014.
data_prefix (str | dict): Prefix for the data. Defaults to ''.
extensions (Sequence[str]): A sequence of allowed extensions. Defaults
to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
lazy_init (bool): Whether to load annotation during instantiation.
In some cases, such as visualization, only the meta information of
the dataset is needed, which is not necessary to load annotation
file. ``Basedataset`` can skip load annotations to save time by set
``lazy_init=False``. Defaults to False.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""
METAINFO = dict(dataset_type='text_image_dataset', task_name='editing')

def __init__(self,
ann_file: str = '',
metainfo: Optional[dict] = None,
data_root: str = '',
drop_caption_rate=0.0,
phase='train',
year=2014,
data_prefix: Union[str, dict] = '',
extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
'.bmp', '.pgm', '.tif'),
lazy_init: bool = False,
classes: Union[str, Sequence[str], None] = None,
**kwargs):
ann_file = os.path.join('annotations', 'captions_' + phase +
f'{year}.json') if ann_file == '' else ann_file
self.image_prename = 'COCO_' + phase + f'{year}_'
self.phase = phase
self.drop_rate = drop_caption_rate
self.year = year
assert self.year == 2014, 'We only support CoCo2014 now.'

super().__init__(
ann_file=ann_file,
metainfo=metainfo,
data_root=data_root,
data_prefix=data_prefix,
extensions=extensions,
lazy_init=lazy_init,
classes=classes,
**kwargs)

def load_data_list(self):
"""Load image paths and gt_labels."""
if self.img_prefix:
file_client = FileClient.infer_client(uri=self.img_prefix)
json_file = mmengine.fileio.io.load(self.ann_file)

def add_prefix(filename, prefix=''):
if not prefix:
return filename
else:
return file_client.join_path(prefix, filename)

data_list = []
for item in json_file['annotations']:
image_name = self.image_prename + str(
item['image_id']).zfill(12) + '.jpg'
img_path = add_prefix(
os.path.join(self.phase + str(self.year), image_name),
self.img_prefix)
caption = item['caption'].lower()
info = {
'img_path':
img_path,
'gt_label':
caption if (self.phase != 'train' or self.drop_rate < 1e-6
or random.random() >= self.drop_rate) else ''
}
data_list.append(info)
return data_list
3 changes: 3 additions & 0 deletions tests/data/coco/annotations/captions_train2014.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"annotations": [{"image_id": 9, "caption": "a good meal"}]
}
3 changes: 3 additions & 0 deletions tests/data/coco/annotations/captions_val2014.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"annotations": [{"image_id": 42, "caption": "a pair of slippers"}]
}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
31 changes: 31 additions & 0 deletions tests/test_datasets/test_mscoco_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from pathlib import Path

from mmedit.datasets import MSCoCoDataset


class TestMSCoCoDatasets:

@classmethod
def setup_class(cls):
cls.data_root = Path(__file__).parent.parent / 'data' / 'coco'

def test_mscoco(self):

# test basic usage
dataset = MSCoCoDataset(data_root=self.data_root, pipeline=[])
assert dataset[0] == dict(
gt_label='a good meal',
img_path=os.path.join(self.data_root, 'train2014',
'COCO_train2014_000000000009.jpg'),
sample_idx=0)

# test with different phase
dataset = MSCoCoDataset(
data_root=self.data_root, phase='val', pipeline=[])
assert dataset[0] == dict(
gt_label='a pair of slippers',
img_path=os.path.join(self.data_root, 'val2014',
'COCO_val2014_000000000042.jpg'),
sample_idx=0)