Siamese net #11

Merged 2 commits on Jul 27, 2020
4 changes: 4 additions & 0 deletions README.md
@@ -57,6 +57,10 @@ indoor_cat_names = coco_api.get_sub_category_names('indoor')

We primarily train detection models using the tools and model definitions from `torchvision` as described in the [Torchvision Object Detection Finetuning Tutorial](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html). Guidelines for training and using models can be found in [`docs/object_detection.md`](docs/object_detection.md)

## Image Comparison Model Training

We use Siamese networks for object/person comparison and subsequent recognition. Instructions for training and using a comparison model are given in [`docs/image_comparison.md`](docs/image_comparison.md).

## Image augmentation

We aim to ease the process of generating data for object detection. Using the green box in the picture below,
64 changes: 64 additions & 0 deletions docs/image_comparison.md
@@ -0,0 +1,64 @@
# Image Comparison: Dataset Format, Training, and Inference

We use a Siamese network for comparing images (e.g. objects or faces). As in the case of object detection, we use a PyTorch-based model for this purpose.

# Dataset Format

The Siamese network training script expects a data directory with multiple subdirectories - one per object - where each such directory contains images corresponding to the same object. The expected directory structure is shown below:

```
train_dataset_dir
| obj_1
| | 1.jpg
| | ...
| |____n.jpg
| obj_2
| | 1.jpg
| | ...
| |____n.jpg
| ...
|____obj_n
| 1.jpg
| ...
|____n.jpg
```
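Internally, the training script reads this structure with `torchvision.datasets.ImageFolder`, which treats every subdirectory as a separate class. A minimal sketch of how the structure is interpreted (the directory name `train_dataset_dir` is just the example from above):

```
import torchvision

folder_dataset = torchvision.datasets.ImageFolder(root='train_dataset_dir')
print(folder_dataset.classes)   # e.g. ['obj_1', 'obj_2', ..., 'obj_n']
print(folder_dataset.imgs[0])   # e.g. ('train_dataset_dir/obj_1/1.jpg', 0)
```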

# Training a Comparison Model

To train a comparison model, use the `train_siamese_net.py` script under `scripts`. The script accepts the following arguments:

* `-d --data_path`: Directory containing training data (default `/home/lucy/data`)
* `-m --model_path`: Path to a directory where the trained models (one per epoch) should be saved (default `/home/lucy/models`)
* `-e --num_epochs`: Number of training epochs (default `10`)
* `-lr --learning_rate`: Initial learning rate (default `1e-4`)
* `-b --training_batch_size`: Training batch size (default `1`)
* `-l --train_loss_file_path`: Path to a file in which training losses will be saved (default `/home/lucy/data/train_loss.log`)

An example call is given below:

```
./train_siamese_net.py \
-d /home/lucy/images/ \
-m /home/lucy/models \
-e 10 \
-lr 0.0001 \
-b 64 \
-l /home/lucy/training_loss.log
```
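
After training, the models are saved as `model_<epoch>.pt` in the model directory, and the average training loss of each epoch is appended as one line to the loss file. A minimal sketch for inspecting the logged losses (assuming `matplotlib` is installed; the log path is the one used in the example call above):

```
import matplotlib.pyplot as plt

with open('/home/lucy/training_loss.log') as f:
    losses = [float(line) for line in f if line.strip()]

plt.plot(range(len(losses)), losses)
plt.xlabel('Epoch')
plt.ylabel('Average contrastive loss')
plt.show()
```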

# Using a Comparison Model

The script `test_siamese_net.py`, also included under `scripts`, illustrates how a trained model can be used to calculate the distance between two images; a smaller distance indicates more similar images. The script takes the following arguments:

* `-i1 --image1_path`: Path to the first (anchor) image (default `''`)
* `-i2 --image2_path`: Path to the second image, which is compared with the first (default `''`)
* `-m --model_path`: Path to a trained model (default `/home/lucy/data/model.pt`)

An example call is given below:

```
./test_siamese_net.py \
-i1 /home/lucy/img1.jpg \
-i2 /home/lucy/img2.jpg \
-m /home/lucy/models/model.pt
```
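
The script prints the Euclidean distance between the embeddings of the two images. A minimal sketch of how the comparison could be wrapped into a reusable function (the helper `images_match` and the threshold of `1.0` are hypothetical examples, not part of this repository; a suitable threshold has to be tuned on your own data):

```
import torch
import torch.nn as nn

from dataset_interface.siamese_net.model import SiameseNetwork
from dataset_interface.siamese_net.utils import get_grayscale_image_tensor

def images_match(model_path, img1_path, img2_path, threshold=1.0):
    """Return True if the embedding distance between the two images is below the threshold."""
    model = SiameseNetwork()
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()
    with torch.no_grad():
        out1, out2 = model(get_grayscale_image_tensor(img1_path),
                           get_grayscale_image_tensor(img2_path))
    distance = nn.functional.pairwise_distance(out1, out2)
    return distance.item() < threshold

print(images_match('/home/lucy/models/model.pt', '/home/lucy/img1.jpg', '/home/lucy/img2.jpg'))
```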
45 changes: 45 additions & 0 deletions scripts/test_siamese_net.py
@@ -0,0 +1,45 @@
#!/usr/bin/env python3
import argparse

import sys
try:
sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
except ValueError:
pass

import torch
import torch.nn as nn

from dataset_interface.siamese_net.model import SiameseNetwork
from dataset_interface.siamese_net.utils import get_grayscale_image_tensor

if __name__ == '__main__':
argparser = argparse.ArgumentParser()
argparser.add_argument('-i1', '--image1_path', type=str,
help='Path to an anchor image',
default='')
argparser.add_argument('-i2', '--image2_path', type=str,
help='Path to a target image',
default='')
argparser.add_argument('-m', '--model_path', type=str,
help='Path to a trained model',
default='/home/lucy/data/model.pt')

args = argparser.parse_args()
img1_path = args.image1_path
img2_path = args.image2_path
model_path = args.model_path

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = SiameseNetwork()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    model.to(device)

    # move the images to the same device as the model before the forward pass
    img0 = get_grayscale_image_tensor(img1_path).to(device)
    img1 = get_grayscale_image_tensor(img2_path).to(device)

    with torch.no_grad():
        out1, out2 = model(img0, img1)
    distance = nn.functional.pairwise_distance(out1, out2)
print('Distance: {0}'.format(distance.item()))
117 changes: 117 additions & 0 deletions scripts/train_siamese_net.py
@@ -0,0 +1,117 @@
#!/usr/bin/env python3
import os
import argparse

import numpy as np
import sys
try:
sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
except ValueError:
pass

import torch
import torchvision
from torch.utils.data import DataLoader

from dataset_interface.siamese_net.model import SiameseNetwork
from dataset_interface.siamese_net.dataset import SiameseNetworkDataset
from dataset_interface.siamese_net.loss import ContrastiveLoss
from dataset_interface.siamese_net.utils import get_transforms

if __name__ == '__main__':
argparser = argparse.ArgumentParser()
argparser.add_argument('-d', '--data_path', type=str,
help='Directory containing training data',
default='/home/lucy/data/')
argparser.add_argument('-m', '--model_path', type=str,
help='Path to a directory where the trained models (one per epoch) should be saved',
default='/home/lucy/models')
argparser.add_argument('-e', '--num_epochs', type=int,
help='Number of training epochs',
default=10)
argparser.add_argument('-lr', '--learning_rate', type=float,
help='Initial learning rate',
default=1e-4)
argparser.add_argument('-b', '--training_batch_size', type=int,
help='Training batch size',
default=1)
argparser.add_argument('-l', '--train_loss_file_path', type=str,
help='Path to a file in which training losses will be saved',
default='/home/lucy/data/train_loss.log')

# we read all arguments
args = argparser.parse_args()
data_path = args.data_path
model_path = args.model_path
train_loss_file_path = args.train_loss_file_path
num_epochs = args.num_epochs
training_batch_size = args.training_batch_size
learning_rate = args.learning_rate

print('\nThe following arguments were read:')
print('------------------------------------')
print('data_path: {0}'.format(data_path))
print('model_path: {0}'.format(model_path))
print('train_loss_file_path: {0}'.format(train_loss_file_path))
print('num_epochs: {0}'.format(num_epochs))
print('training_batch_size: {0}'.format(training_batch_size))
print('learning_rate: {0}'.format(learning_rate))
print('------------------------------------')
print('Proceed with training (y/n)')
proceed = input()
if proceed != 'y':
print('Aborting training')
sys.exit(1)

    # we create a Siamese dataset on top of a torchvision ImageFolder
    # and wrap it in a data loader for training
folder_dataset = torchvision.datasets.ImageFolder(root=data_path)
siamese_dataset = SiameseNetworkDataset(image_folder_dataset=folder_dataset,
transform=get_transforms(),
should_invert=False)

train_dataloader = DataLoader(siamese_dataset,
shuffle=True,
num_workers=8,
batch_size=training_batch_size)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = SiameseNetwork()
criterion = ContrastiveLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# we move the model to the correct device before training
model.to(device)

# we create the model path directory if it doesn't exist
    if not os.path.isdir(model_path):
        print('Creating model directory {0}'.format(model_path))
        os.makedirs(model_path)

    # we clear the file in which the training losses are saved
open(train_loss_file_path, 'w').close()

print('Training model for {0} epochs'.format(num_epochs))
for epoch in range(num_epochs):
losses = []
for i, data in enumerate(train_dataloader):
            img0, img1, label = data
            # move the batch to the same device as the model
            img0, img1, label = img0.to(device), img1.to(device), label.to(device)
            optimizer.zero_grad()
            output1, output2 = model(img0, img1)
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()
if i % 10 == 0:
print('Epoch number {}\n Current loss {}\n'.format(epoch,
loss_contrastive.item()))
losses.append(loss_contrastive.item())
avg_loss = np.mean(losses)

if train_loss_file_path:
with open(train_loss_file_path, 'a+') as loss_file:
                loss_file.write('{0}\n'.format(avg_loss))

lr_scheduler.step()
torch.save(model.state_dict(), os.path.join(model_path, 'model_{0}.pt'.format(epoch)))
print('Training done')
3 changes: 2 additions & 1 deletion setup.py
@@ -15,7 +15,8 @@
setup(
name='dataset_interface',
packages=['pycocotools', 'dataset_interface', 'dataset_interface.common', 'dataset_interface.coco',
'dataset_interface.augmentation', 'dataset_interface.object_detection'],
'dataset_interface.augmentation', 'dataset_interface.object_detection',
'dataset_interface.siamese_net'],
package_dir={
'dataset_interface': '.',
'pycocotools': COCO_API_PATH + 'PythonAPI/pycocotools'
Empty file added siamese_net/__init__.py
52 changes: 52 additions & 0 deletions siamese_net/dataset.py
@@ -0,0 +1,52 @@
import random
from PIL import Image, ImageOps
import numpy as np

import torch
from torch.utils.data import Dataset

class SiameseNetworkDataset(Dataset):
"""
Author: Alan Preciado
"""
def __init__(self, image_folder_dataset, transform=None, should_invert=True):
self.image_folder_dataset = image_folder_dataset
self.transform = transform
self.should_invert = should_invert

def __getitem__(self, index):
img0_tuple = random.choice(self.image_folder_dataset.imgs)

# we need to make sure approx 50% of images are in the same class
should_get_same_class = random.randint(0, 1)
if should_get_same_class:
while True:
# keep looping till the same class image is found
img1_tuple = random.choice(self.image_folder_dataset.imgs)
if img0_tuple[1] == img1_tuple[1]:
break
else:
while True:
# keep looping till a different class image is found
img1_tuple = random.choice(self.image_folder_dataset.imgs)
if img0_tuple[1] != img1_tuple[1]:
break

img0 = Image.open(img0_tuple[0])
img1 = Image.open(img1_tuple[0])
img0 = img0.convert("L")
img1 = img1.convert("L")

if self.should_invert:
img0 = ImageOps.invert(img0)
img1 = ImageOps.invert(img1)

if self.transform is not None:
img0 = self.transform(img0)
img1 = self.transform(img1)

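        # the label is 1.0 when the two images come from different classes
        # (a dissimilar pair) and 0.0 when they come from the same class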
return img0, img1, torch.from_numpy(np.array([int(img1_tuple[1] != img0_tuple[1])],
dtype=np.float32))

def __len__(self):
return len(self.image_folder_dataset.imgs)
19 changes: 19 additions & 0 deletions siamese_net/loss.py
@@ -0,0 +1,19 @@
import torch
import torch.nn as nn

class ContrastiveLoss(nn.Module):
"""
Author: Alan Preciado

Contrastive loss function based on
http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
"""
def __init__(self, margin=2.0):
super(ContrastiveLoss, self).__init__()
self.margin = margin

def forward(self, output1, output2, label):
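        # label == 0 marks a similar pair: the squared distance itself is penalised;
        # label == 1 marks a dissimilar pair: only distances smaller than the margin
        # contribute to the loss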
euclidean_distance = nn.functional.pairwise_distance(output1, output2, keepdim=True)
loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
(label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.), 2))
return loss_contrastive
39 changes: 39 additions & 0 deletions siamese_net/model.py
@@ -0,0 +1,39 @@
import torch.nn as nn

class SiameseNetwork(nn.Module):
"""
Author: Alan Preciado
"""
def __init__(self):
super(SiameseNetwork, self).__init__()
self.cnn1 = nn.Sequential(nn.ReflectionPad2d(1),
nn.Conv2d(1, 4, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(4),

nn.ReflectionPad2d(1),
nn.Conv2d(4, 8, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(8),

nn.ReflectionPad2d(1),
nn.Conv2d(8, 8, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(8))

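        # the fully connected head assumes single-channel 100x100 inputs
        # (cf. get_transforms in utils.py), which yield 8*100*100 = 80000 features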
self.fc1 = nn.Sequential(nn.Linear(8*100*100, 500),
nn.ReLU(inplace=True),
nn.Linear(500, 500),
nn.ReLU(inplace=True),
nn.Linear(500, 5))

def forward_once(self, x):
output = self.cnn1(x)
output = output.view(output.size()[0], -1)
output = self.fc1(output)
return output

def forward(self, input1, input2):
output1 = self.forward_once(input1)
output2 = self.forward_once(input2)
return output1, output2
14 changes: 14 additions & 0 deletions siamese_net/utils.py
@@ -0,0 +1,14 @@
from PIL import Image
import torchvision.transforms as transforms

def get_transforms():
transform = transforms.Compose([transforms.Resize((100, 100)),
transforms.ToTensor()])
return transform

def get_grayscale_image_tensor(img_path):
img = Image.open(img_path)
img = img.convert("L")
img = get_transforms()(img)
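    # add a batch dimension so the tensor can be fed directly to the network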
img.unsqueeze_(0)
return img