from __future__ import division, print_function, absolute_import
import torch
from torchreid.data.sampler import build_train_sampler
from torchreid.data.datasets import init_image_dataset, init_video_dataset
from torchreid.data.transforms import build_transforms
[docs]class DataManager(object):
r"""Base data manager.
Args:
sources (str or list): source dataset(s).
targets (str or list, optional): target dataset(s). If not given,
it equals to ``sources``.
height (int, optional): target image height. Default is 256.
width (int, optional): target image width. Default is 128.
transforms (str or list of str, optional): transformations applied to model training.
Default is 'random_flip'.
norm_mean (list or None, optional): data mean. Default is None (use imagenet mean).
norm_std (list or None, optional): data std. Default is None (use imagenet std).
use_gpu (bool, optional): use gpu. Default is True.
"""
def __init__(
self,
sources=None,
targets=None,
height=256,
width=128,
transforms='random_flip',
norm_mean=None,
norm_std=None,
use_gpu=False
):
self.sources = sources
self.targets = targets
self.height = height
self.width = width
if self.sources is None:
raise ValueError('sources must not be None')
if isinstance(self.sources, str):
self.sources = [self.sources]
if self.targets is None:
self.targets = self.sources
if isinstance(self.targets, str):
self.targets = [self.targets]
self.transform_tr, self.transform_te = build_transforms(
self.height,
self.width,
transforms=transforms,
norm_mean=norm_mean,
norm_std=norm_std
)
self.use_gpu = (torch.cuda.is_available() and use_gpu)
@property
def num_train_pids(self):
"""Returns the number of training person identities."""
return self._num_train_pids
@property
def num_train_cams(self):
"""Returns the number of training cameras."""
return self._num_train_cams
[docs] def fetch_test_loaders(self, name):
"""Returns query and gallery of a test dataset, each containing
tuples of (img_path(s), pid, camid).
Args:
name (str): dataset name.
"""
query_loader = self.test_dataset[name]['query']
gallery_loader = self.test_dataset[name]['gallery']
return query_loader, gallery_loader
[docs] def preprocess_pil_img(self, img):
"""Transforms a PIL image to torch tensor for testing."""
return self.transform_te(img)
[docs]class ImageDataManager(DataManager):
r"""Image data manager.
Args:
root (str): root path to datasets.
sources (str or list): source dataset(s).
targets (str or list, optional): target dataset(s). If not given,
it equals to ``sources``.
height (int, optional): target image height. Default is 256.
width (int, optional): target image width. Default is 128.
transforms (str or list of str, optional): transformations applied to model training.
Default is 'random_flip'.
k_tfm (int): number of times to apply augmentation to an image
independently. If k_tfm > 1, the transform function will be
applied k_tfm times to an image. This variable will only be
useful for training and is currently valid for image datasets only.
norm_mean (list or None, optional): data mean. Default is None (use imagenet mean).
norm_std (list or None, optional): data std. Default is None (use imagenet std).
use_gpu (bool, optional): use gpu. Default is True.
split_id (int, optional): split id (*0-based*). Default is 0.
combineall (bool, optional): combine train, query and gallery in a dataset for
training. Default is False.
load_train_targets (bool, optional): construct train-loader for target datasets.
Default is False. This is useful for domain adaptation research.
batch_size_train (int, optional): number of images in a training batch. Default is 32.
batch_size_test (int, optional): number of images in a test batch. Default is 32.
workers (int, optional): number of workers. Default is 4.
num_instances (int, optional): number of instances per identity in a batch.
Default is 4.
num_cams (int, optional): number of cameras to sample in a batch (when using
``RandomDomainSampler``). Default is 1.
num_datasets (int, optional): number of datasets to sample in a batch (when
using ``RandomDatasetSampler``). Default is 1.
train_sampler (str, optional): sampler. Default is RandomSampler.
train_sampler_t (str, optional): sampler for target train loader. Default is RandomSampler.
cuhk03_labeled (bool, optional): use cuhk03 labeled images.
Default is False (defaul is to use detected images).
cuhk03_classic_split (bool, optional): use the classic split in cuhk03.
Default is False.
market1501_500k (bool, optional): add 500K distractors to the gallery
set in market1501. Default is False.
Examples::
datamanager = torchreid.data.ImageDataManager(
root='path/to/reid-data',
sources='market1501',
height=256,
width=128,
batch_size_train=32,
batch_size_test=100
)
# return train loader of source data
train_loader = datamanager.train_loader
# return test loader of target data
test_loader = datamanager.test_loader
# return train loader of target data
train_loader_t = datamanager.train_loader_t
"""
data_type = 'image'
def __init__(
self,
root='',
sources=None,
targets=None,
height=256,
width=128,
transforms='random_flip',
k_tfm=1,
norm_mean=None,
norm_std=None,
use_gpu=True,
split_id=0,
combineall=False,
load_train_targets=False,
batch_size_train=32,
batch_size_test=32,
workers=4,
num_instances=4,
num_cams=1,
num_datasets=1,
train_sampler='RandomSampler',
train_sampler_t='RandomSampler',
cuhk03_labeled=False,
cuhk03_classic_split=False,
market1501_500k=False
):
super(ImageDataManager, self).__init__(
sources=sources,
targets=targets,
height=height,
width=width,
transforms=transforms,
norm_mean=norm_mean,
norm_std=norm_std,
use_gpu=use_gpu
)
print('=> Loading train (source) dataset')
trainset = []
for name in self.sources:
trainset_ = init_image_dataset(
name,
transform=self.transform_tr,
k_tfm=k_tfm,
mode='train',
combineall=combineall,
root=root,
split_id=split_id,
cuhk03_labeled=cuhk03_labeled,
cuhk03_classic_split=cuhk03_classic_split,
market1501_500k=market1501_500k
)
trainset.append(trainset_)
trainset = sum(trainset)
self._num_train_pids = trainset.num_train_pids
self._num_train_cams = trainset.num_train_cams
self.train_loader = torch.utils.data.DataLoader(
trainset,
sampler=build_train_sampler(
trainset.train,
train_sampler,
batch_size=batch_size_train,
num_instances=num_instances,
num_cams=num_cams,
num_datasets=num_datasets
),
batch_size=batch_size_train,
shuffle=False,
num_workers=workers,
pin_memory=self.use_gpu,
drop_last=True
)
self.train_loader_t = None
if load_train_targets:
# check if sources and targets are identical
assert len(set(self.sources) & set(self.targets)) == 0, \
'sources={} and targets={} must not have overlap'.format(self.sources, self.targets)
print('=> Loading train (target) dataset')
trainset_t = []
for name in self.targets:
trainset_t_ = init_image_dataset(
name,
transform=self.transform_tr,
k_tfm=k_tfm,
mode='train',
combineall=False, # only use the training data
root=root,
split_id=split_id,
cuhk03_labeled=cuhk03_labeled,
cuhk03_classic_split=cuhk03_classic_split,
market1501_500k=market1501_500k
)
trainset_t.append(trainset_t_)
trainset_t = sum(trainset_t)
self.train_loader_t = torch.utils.data.DataLoader(
trainset_t,
sampler=build_train_sampler(
trainset_t.train,
train_sampler_t,
batch_size=batch_size_train,
num_instances=num_instances,
num_cams=num_cams,
num_datasets=num_datasets
),
batch_size=batch_size_train,
shuffle=False,
num_workers=workers,
pin_memory=self.use_gpu,
drop_last=True
)
print('=> Loading test (target) dataset')
self.test_loader = {
name: {
'query': None,
'gallery': None
}
for name in self.targets
}
self.test_dataset = {
name: {
'query': None,
'gallery': None
}
for name in self.targets
}
for name in self.targets:
# build query loader
queryset = init_image_dataset(
name,
transform=self.transform_te,
mode='query',
combineall=combineall,
root=root,
split_id=split_id,
cuhk03_labeled=cuhk03_labeled,
cuhk03_classic_split=cuhk03_classic_split,
market1501_500k=market1501_500k
)
self.test_loader[name]['query'] = torch.utils.data.DataLoader(
queryset,
batch_size=batch_size_test,
shuffle=False,
num_workers=workers,
pin_memory=self.use_gpu,
drop_last=False
)
# build gallery loader
galleryset = init_image_dataset(
name,
transform=self.transform_te,
mode='gallery',
combineall=combineall,
verbose=False,
root=root,
split_id=split_id,
cuhk03_labeled=cuhk03_labeled,
cuhk03_classic_split=cuhk03_classic_split,
market1501_500k=market1501_500k
)
self.test_loader[name]['gallery'] = torch.utils.data.DataLoader(
galleryset,
batch_size=batch_size_test,
shuffle=False,
num_workers=workers,
pin_memory=self.use_gpu,
drop_last=False
)
self.test_dataset[name]['query'] = queryset.query
self.test_dataset[name]['gallery'] = galleryset.gallery
print('\n')
print(' **************** Summary ****************')
print(' source : {}'.format(self.sources))
print(' # source datasets : {}'.format(len(self.sources)))
print(' # source ids : {}'.format(self.num_train_pids))
print(' # source images : {}'.format(len(trainset)))
print(' # source cameras : {}'.format(self.num_train_cams))
if load_train_targets:
print(
' # target images : {} (unlabeled)'.format(len(trainset_t))
)
print(' target : {}'.format(self.targets))
print(' *****************************************')
print('\n')
[docs]class VideoDataManager(DataManager):
r"""Video data manager.
Args:
root (str): root path to datasets.
sources (str or list): source dataset(s).
targets (str or list, optional): target dataset(s). If not given,
it equals to ``sources``.
height (int, optional): target image height. Default is 256.
width (int, optional): target image width. Default is 128.
transforms (str or list of str, optional): transformations applied to model training.
Default is 'random_flip'.
norm_mean (list or None, optional): data mean. Default is None (use imagenet mean).
norm_std (list or None, optional): data std. Default is None (use imagenet std).
use_gpu (bool, optional): use gpu. Default is True.
split_id (int, optional): split id (*0-based*). Default is 0.
combineall (bool, optional): combine train, query and gallery in a dataset for
training. Default is False.
batch_size_train (int, optional): number of tracklets in a training batch. Default is 3.
batch_size_test (int, optional): number of tracklets in a test batch. Default is 3.
workers (int, optional): number of workers. Default is 4.
num_instances (int, optional): number of instances per identity in a batch.
Default is 4.
num_cams (int, optional): number of cameras to sample in a batch (when using
``RandomDomainSampler``). Default is 1.
num_datasets (int, optional): number of datasets to sample in a batch (when
using ``RandomDatasetSampler``). Default is 1.
train_sampler (str, optional): sampler. Default is RandomSampler.
seq_len (int, optional): how many images to sample in a tracklet. Default is 15.
sample_method (str, optional): how to sample images in a tracklet. Default is "evenly".
Choices are ["evenly", "random", "all"]. "evenly" and "random" will sample ``seq_len``
images in a tracklet while "all" samples all images in a tracklet, where the batch size
needs to be set to 1.
Examples::
datamanager = torchreid.data.VideoDataManager(
root='path/to/reid-data',
sources='mars',
height=256,
width=128,
batch_size_train=3,
batch_size_test=3,
seq_len=15,
sample_method='evenly'
)
# return train loader of source data
train_loader = datamanager.train_loader
# return test loader of target data
test_loader = datamanager.test_loader
.. note::
The current implementation only supports image-like training. Therefore, each image in a
sampled tracklet will undergo independent transformation functions. To achieve tracklet-aware
training, you need to modify the transformation functions for video reid such that each function
applies the same operation to all images in a tracklet to keep consistency.
"""
data_type = 'video'
def __init__(
self,
root='',
sources=None,
targets=None,
height=256,
width=128,
transforms='random_flip',
norm_mean=None,
norm_std=None,
use_gpu=True,
split_id=0,
combineall=False,
batch_size_train=3,
batch_size_test=3,
workers=4,
num_instances=4,
num_cams=1,
num_datasets=1,
train_sampler='RandomSampler',
seq_len=15,
sample_method='evenly'
):
super(VideoDataManager, self).__init__(
sources=sources,
targets=targets,
height=height,
width=width,
transforms=transforms,
norm_mean=norm_mean,
norm_std=norm_std,
use_gpu=use_gpu
)
print('=> Loading train (source) dataset')
trainset = []
for name in self.sources:
trainset_ = init_video_dataset(
name,
transform=self.transform_tr,
mode='train',
combineall=combineall,
root=root,
split_id=split_id,
seq_len=seq_len,
sample_method=sample_method
)
trainset.append(trainset_)
trainset = sum(trainset)
self._num_train_pids = trainset.num_train_pids
self._num_train_cams = trainset.num_train_cams
train_sampler = build_train_sampler(
trainset.train,
train_sampler,
batch_size=batch_size_train,
num_instances=num_instances,
num_cams=num_cams,
num_datasets=num_datasets
)
self.train_loader = torch.utils.data.DataLoader(
trainset,
sampler=train_sampler,
batch_size=batch_size_train,
shuffle=False,
num_workers=workers,
pin_memory=self.use_gpu,
drop_last=True
)
print('=> Loading test (target) dataset')
self.test_loader = {
name: {
'query': None,
'gallery': None
}
for name in self.targets
}
self.test_dataset = {
name: {
'query': None,
'gallery': None
}
for name in self.targets
}
for name in self.targets:
# build query loader
queryset = init_video_dataset(
name,
transform=self.transform_te,
mode='query',
combineall=combineall,
root=root,
split_id=split_id,
seq_len=seq_len,
sample_method=sample_method
)
self.test_loader[name]['query'] = torch.utils.data.DataLoader(
queryset,
batch_size=batch_size_test,
shuffle=False,
num_workers=workers,
pin_memory=self.use_gpu,
drop_last=False
)
# build gallery loader
galleryset = init_video_dataset(
name,
transform=self.transform_te,
mode='gallery',
combineall=combineall,
verbose=False,
root=root,
split_id=split_id,
seq_len=seq_len,
sample_method=sample_method
)
self.test_loader[name]['gallery'] = torch.utils.data.DataLoader(
galleryset,
batch_size=batch_size_test,
shuffle=False,
num_workers=workers,
pin_memory=self.use_gpu,
drop_last=False
)
self.test_dataset[name]['query'] = queryset.query
self.test_dataset[name]['gallery'] = galleryset.gallery
print('\n')
print(' **************** Summary ****************')
print(' source : {}'.format(self.sources))
print(' # source datasets : {}'.format(len(self.sources)))
print(' # source ids : {}'.format(self.num_train_pids))
print(' # source tracklets : {}'.format(len(trainset)))
print(' # source cameras : {}'.format(self.num_train_cams))
print(' target : {}'.format(self.targets))
print(' *****************************************')
print('\n')