from __future__ import division, print_function, absolute_import
import copy
import numpy as np
import os.path as osp
import tarfile
import zipfile
import torch
from torchreid.utils import read_image, download_url, mkdir_if_missing
class Dataset(object):
    """An abstract class representing a Dataset.

    This is the base class for ``ImageDataset`` and ``VideoDataset``.

    Args:
        train (list): contains tuples of (img_path(s), pid, camid).
        query (list): contains tuples of (img_path(s), pid, camid).
        gallery (list): contains tuples of (img_path(s), pid, camid).
        transform: transform function.
        k_tfm (int): number of times to apply augmentation to an image
            independently. If k_tfm > 1, the transform function will be
            applied k_tfm times to an image. This variable will only be
            useful for training and is currently valid for image datasets only.
        mode (str): 'train', 'query' or 'gallery'.
        combineall (bool): combines train, query and gallery in a
            dataset for training.
        verbose (bool): show information.

    Raises:
        ValueError: if ``mode`` is not 'train', 'query' or 'gallery'.
    """

    # junk_pids contains useless person IDs, e.g. background,
    # false detections, distractors. These IDs will be ignored
    # when combining all images in a dataset for training, i.e.
    # combineall=True
    _junk_pids = []

    # Some datasets are only used for training, like CUHK-SYSU
    # In this case, "combineall=True" is not used for them
    _train_only = False

    def __init__(
        self,
        train,
        query,
        gallery,
        transform=None,
        k_tfm=1,
        mode='train',
        combineall=False,
        verbose=True,
        **kwargs
    ):
        # extend 3-tuple (img_path(s), pid, camid) to
        # 4-tuple (img_path(s), pid, camid, dsetid) by
        # adding a dataset indicator "dsetid"
        self.train = self._add_dsetid(train)
        self.query = self._add_dsetid(query)
        self.gallery = self._add_dsetid(gallery)

        self.transform = transform
        self.k_tfm = k_tfm
        self.mode = mode
        self.combineall = combineall
        self.verbose = verbose

        self.num_train_pids = self.get_num_pids(self.train)
        self.num_train_cams = self.get_num_cams(self.train)
        self.num_datasets = self.get_num_datasets(self.train)

        if self.combineall:
            self.combine_all()

        # ``data`` is the split served by __getitem__/__len__
        if self.mode == 'train':
            self.data = self.train
        elif self.mode == 'query':
            self.data = self.query
        elif self.mode == 'gallery':
            self.data = self.gallery
        else:
            raise ValueError(
                'Invalid mode. Got {}, but expected to be '
                'one of [train | query | gallery]'.format(self.mode)
            )

        if self.verbose:
            self.show_summary()

    @staticmethod
    def _add_dsetid(data):
        """Returns ``data`` with dsetid 0 appended to every 3-tuple.

        The first element decides whether the list needs extending, so an
        empty list is returned unchanged instead of raising ``IndexError``.
        """
        if len(data) > 0 and len(data[0]) == 3:
            return [(*items, 0) for items in data]
        return data

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        return len(self.data)

    def __add__(self, other):
        """Adds two datasets together (only the train set).

        The pid, camid and dsetid of every item in ``other.train`` are
        shifted by this dataset's ``num_train_pids``, ``num_train_cams``
        and ``num_datasets``. Query and gallery are taken from ``self``.

        Returns:
            ImageDataset if the first train item holds a single path (str),
            otherwise VideoDataset.
        """
        train = copy.deepcopy(self.train)

        for img_path, pid, camid, dsetid in other.train:
            train.append(
                (
                    img_path,
                    pid + self.num_train_pids,
                    camid + self.num_train_cams,
                    dsetid + self.num_datasets,
                )
            )

        ###################################
        # Note that
        # 1. set verbose=False to avoid unnecessary print
        # 2. set combineall=False because combineall would have been applied
        #    if it was True for a specific dataset; setting it to True will
        #    create new IDs that should have already been included
        # 3. k_tfm is forwarded so that the merged dataset keeps applying
        #    the same number of augmentations as this one
        ###################################
        shared = dict(
            transform=self.transform,
            k_tfm=self.k_tfm,
            mode=self.mode,
            combineall=False,
            verbose=False,
        )
        if isinstance(train[0][0], str):
            return ImageDataset(train, self.query, self.gallery, **shared)
        return VideoDataset(
            train,
            self.query,
            self.gallery,
            seq_len=self.seq_len,
            sample_method=self.sample_method,
            **shared
        )

    def __radd__(self, other):
        """Supports sum([dataset1, dataset2, dataset3])."""
        # sum() starts from the integer 0, which has nothing to merge
        if other == 0:
            return self
        return self.__add__(other)

    def get_num_pids(self, data):
        """Returns the number of distinct person identities in ``data``.

        Each tuple in data contains (img_path(s), pid, camid, dsetid).
        """
        return len({items[1] for items in data})

    def get_num_cams(self, data):
        """Returns the number of distinct cameras in ``data``.

        Each tuple in data contains (img_path(s), pid, camid, dsetid).
        """
        return len({items[2] for items in data})

    def get_num_datasets(self, data):
        """Returns the number of distinct datasets in ``data``.

        Each tuple in data contains (img_path(s), pid, camid, dsetid).
        """
        return len({items[3] for items in data})

    def show_summary(self):
        """Shows dataset statistics."""
        pass

    def combine_all(self):
        """Combines train, query and gallery in a dataset for training.

        Does nothing when ``_train_only`` is set. Otherwise the non-junk
        pids of query and gallery are relabelled to follow the existing
        train pids, their items are appended to ``self.train``, and
        ``self.num_train_pids`` is recomputed.
        """
        if self._train_only:
            return

        combined = copy.deepcopy(self.train)

        # relabel pids in gallery and query (they share the same scope).
        # Gallery is scanned first; query pids are included as well so that
        # an identity missing from the gallery does not raise KeyError below.
        test_pids = set()
        for split in (self.gallery, self.query):
            for items in split:
                pid = items[1]
                if pid in self._junk_pids:
                    continue
                test_pids.add(pid)
        pid2label = {pid: i for i, pid in enumerate(test_pids)}

        def _combine_data(data):
            for img_path, pid, camid, dsetid in data:
                if pid in self._junk_pids:
                    continue
                pid = pid2label[pid] + self.num_train_pids
                combined.append((img_path, pid, camid, dsetid))

        _combine_data(self.query)
        _combine_data(self.gallery)

        self.train = combined
        self.num_train_pids = self.get_num_pids(self.train)

    def download_dataset(self, dataset_dir, dataset_url):
        """Downloads and extracts dataset.

        Returns immediately if ``dataset_dir`` already exists.

        Args:
            dataset_dir (str): dataset directory.
            dataset_url (str): url to download dataset.

        Raises:
            RuntimeError: if ``dataset_dir`` does not exist and
                ``dataset_url`` is None.
        """
        if osp.exists(dataset_dir):
            return

        if dataset_url is None:
            raise RuntimeError(
                '{} dataset needs to be manually '
                'prepared, please follow the '
                'document to prepare this dataset'.format(
                    self.__class__.__name__
                )
            )

        print('Creating directory "{}"'.format(dataset_dir))
        mkdir_if_missing(dataset_dir)
        fpath = osp.join(dataset_dir, osp.basename(dataset_url))

        print(
            'Downloading {} dataset to "{}"'.format(
                self.__class__.__name__, dataset_dir
            )
        )
        download_url(dataset_url, fpath)

        print('Extracting "{}"'.format(fpath))
        # Only a "not a tar file" failure falls back to zip; errors raised
        # during extraction propagate instead of being hidden by the fallback.
        try:
            archive = tarfile.open(fpath)
        except tarfile.ReadError:
            archive = zipfile.ZipFile(fpath, 'r')
        # the context manager closes the archive even if extraction fails
        with archive:
            archive.extractall(dataset_dir)

        print('{} dataset is ready'.format(self.__class__.__name__))

    def check_before_run(self, required_files):
        """Checks if required files exist before going deeper.

        Args:
            required_files (str or list): string file name(s).

        Raises:
            RuntimeError: if any of the given paths does not exist.
        """
        if isinstance(required_files, str):
            required_files = [required_files]

        for fpath in required_files:
            if not osp.exists(fpath):
                raise RuntimeError('"{}" is not found'.format(fpath))

    def __repr__(self):
        """Returns a table of ids, items and cameras for each split."""
        num_train_pids = self.get_num_pids(self.train)
        num_train_cams = self.get_num_cams(self.train)

        num_query_pids = self.get_num_pids(self.query)
        num_query_cams = self.get_num_cams(self.query)

        num_gallery_pids = self.get_num_pids(self.gallery)
        num_gallery_cams = self.get_num_cams(self.gallery)

        msg = (
            ' ----------------------------------------\n'
            ' subset | # ids | # items | # cameras\n'
            ' ----------------------------------------\n'
            ' train | {:5d} | {:7d} | {:9d}\n'
            ' query | {:5d} | {:7d} | {:9d}\n'
            ' gallery | {:5d} | {:7d} | {:9d}\n'
            ' ----------------------------------------\n'
            ' items: images/tracklets for image/video dataset\n'
        ).format(
            num_train_pids, len(self.train), num_train_cams,
            num_query_pids, len(self.query), num_query_cams,
            num_gallery_pids, len(self.gallery), num_gallery_cams
        )

        return msg

    def _transform_image(self, tfm, k_tfm, img0):
        """Transforms a raw image (img0) k_tfm times with
        the transform function tfm.

        Returns:
            The single transformed image when exactly one was produced,
            otherwise the list of transformed images.
        """
        img_list = [tfm(img0) for _ in range(k_tfm)]

        if len(img_list) == 1:
            return img_list[0]
        return img_list
class ImageDataset(Dataset):
    """Base class for image datasets.

    Concrete image datasets should subclass it.

    ``__getitem__`` reads the image stored at the given index and returns a
    dict with the keys ``img``, ``pid``, ``camid``, ``impath`` and
    ``dsetid``. When a transform is set, ``img`` is the output of
    ``_transform_image`` for that image; otherwise it is whatever
    ``read_image`` returned.
    """

    def __init__(self, train, query, gallery, **kwargs):
        super(ImageDataset, self).__init__(train, query, gallery, **kwargs)

    def __getitem__(self, index):
        """Loads one sample from the active split (``self.data``)."""
        path, person_id, camera_id, dataset_id = self.data[index]

        image = read_image(path)
        if self.transform is not None:
            image = self._transform_image(self.transform, self.k_tfm, image)

        return {
            'img': image,
            'pid': person_id,
            'camid': camera_id,
            'impath': path,
            'dsetid': dataset_id
        }

    def show_summary(self):
        """Prints the number of ids, images and cameras of every split."""
        # gather every statistic first so nothing is printed if one fails
        stats = [
            (self.get_num_pids(split), len(split), self.get_num_cams(split))
            for split in (self.train, self.query, self.gallery)
        ]
        row_templates = (
            ' train | {:5d} | {:8d} | {:9d}',
            ' query | {:5d} | {:8d} | {:9d}',
            ' gallery | {:5d} | {:8d} | {:9d}',
        )
        rule = ' ----------------------------------------'

        print('=> Loaded {}'.format(self.__class__.__name__))
        print(rule)
        print(' subset | # ids | # images | # cameras')
        print(rule)
        for template, row in zip(row_templates, stats):
            print(template.format(*row))
        print(rule)
class VideoDataset(Dataset):
    """Base class for video (tracklet) datasets.

    Concrete video datasets should subclass it.

    ``__getitem__`` samples frames from the tracklet at the given index and
    returns a dict with the keys ``img``, ``pid``, ``camid`` and ``dsetid``.
    ``img`` is the concatenation along dim 0 of the transformed frames,
    each given a leading dimension first, so the transform must return a
    ``torch.Tensor``.

    Args:
        seq_len (int): number of frames sampled per tracklet by the
            'random' and 'evenly' methods.
        sample_method (str): 'random', 'evenly' or 'all'.

    Raises:
        RuntimeError: if no transform is given.
    """

    def __init__(
        self,
        train,
        query,
        gallery,
        seq_len=15,
        sample_method='evenly',
        **kwargs
    ):
        super(VideoDataset, self).__init__(train, query, gallery, **kwargs)
        self.seq_len = seq_len
        self.sample_method = sample_method

        if self.transform is None:
            raise RuntimeError('transform must not be None')

    def _sample_indices(self, num_imgs):
        """Returns the frame indices to load from a tracklet of num_imgs frames."""
        method = self.sample_method

        if method == 'all':
            # Every frame of the tracklet. batch_size must be set to 1
            return np.arange(num_imgs)

        if method == 'random':
            # Draw seq_len frames; frames are repeated only when the
            # tracklet is shorter than seq_len
            with_replacement = num_imgs < self.seq_len
            picked = np.random.choice(
                np.arange(num_imgs),
                size=self.seq_len,
                replace=with_replacement
            )
            # sorted to keep temporal order (skip the sort to be order-agnostic)
            return np.sort(picked)

        if method == 'evenly':
            if num_imgs >= self.seq_len:
                # trim the tail so the length is a multiple of seq_len,
                # then take one frame per equally sized stride
                num_imgs -= num_imgs % self.seq_len
                picked = np.arange(0, num_imgs, num_imgs / self.seq_len)
            else:
                # short tracklet: keep every frame and repeat the last one
                # until seq_len frames are available
                num_pads = self.seq_len - num_imgs
                picked = np.concatenate(
                    [
                        np.arange(0, num_imgs),
                        np.full(num_pads, num_imgs - 1, dtype=np.int32)
                    ]
                )
            assert len(picked) == self.seq_len
            return picked

        raise ValueError(
            'Unknown sample method: {}'.format(self.sample_method)
        )

    def __getitem__(self, index):
        """Loads the sampled frames of one tracklet from the active split."""
        img_paths, pid, camid, dsetid = self.data[index]
        frame_ids = self._sample_indices(len(img_paths))

        frames = []
        for frame_id in frame_ids:
            frame = read_image(img_paths[int(frame_id)])
            if self.transform is not None:
                frame = self.transform(frame)
            # frame must be torch.Tensor
            frames.append(frame.unsqueeze(0))
        clip = torch.cat(frames, dim=0)

        return {'img': clip, 'pid': pid, 'camid': camid, 'dsetid': dsetid}

    def show_summary(self):
        """Prints the number of ids, tracklets and cameras of every split."""
        # gather every statistic first so nothing is printed if one fails
        stats = [
            (self.get_num_pids(split), len(split), self.get_num_cams(split))
            for split in (self.train, self.query, self.gallery)
        ]
        row_templates = (
            ' train | {:5d} | {:11d} | {:9d}',
            ' query | {:5d} | {:11d} | {:9d}',
            ' gallery | {:5d} | {:11d} | {:9d}',
        )
        rule = ' -------------------------------------------'

        print('=> Loaded {}'.format(self.__class__.__name__))
        print(rule)
        print(' subset | # ids | # tracklets | # cameras')
        print(rule)
        for template, row in zip(row_templates, stats):
            print(template.format(*row))
        print(rule)