# Source code for torchreid.data.sampler

from __future__ import division, absolute_import
import copy
import numpy as np
import random
from collections import defaultdict
from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler

# Sampler names accepted by ``build_train_sampler`` (checked by membership).
AVAI_SAMPLERS = [
    'RandomIdentitySampler',
    'SequentialSampler',
    'RandomSampler',
    'RandomDomainSampler',
    'RandomDatasetSampler',
]


class RandomIdentitySampler(Sampler):
    """Randomly samples N identities each with K instances.

    Args:
        data_source (list): contains tuples of (img_path(s), pid, camid, dsetid).
        batch_size (int): batch size.
        num_instances (int): number of instances per identity in a batch.

    Raises:
        ValueError: if ``batch_size`` is smaller than ``num_instances``.
    """

    def __init__(self, data_source, batch_size, num_instances):
        if num_instances > batch_size:
            raise ValueError(
                'batch_size={} must be no less than num_instances={}'.format(
                    batch_size, num_instances
                )
            )
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        # How many identities (K images each) fit into one batch.
        self.num_pids_per_batch = batch_size // num_instances

        # pid (second field of each record) -> positions of its records.
        self.index_dic = defaultdict(list)
        for position, record in enumerate(data_source):
            self.index_dic[record[1]].append(position)
        self.pids = list(self.index_dic)
        assert len(self.pids) >= self.num_pids_per_batch

        # Rough epoch size: an identity with fewer than K images counts as K,
        # otherwise its count is truncated to a multiple of K.
        # TODO: improve precision -- __iter__ can stop before every chunk
        # is used, so the real epoch may be shorter than this.
        k = num_instances
        self.length = sum(
            max(len(positions), k) // k * k
            for positions in self.index_dic.values()
        )

    def __iter__(self):
        """Return an iterator over dataset positions for one epoch."""
        k = self.num_instances

        # Shuffle each identity's positions and cut them into chunks of K;
        # a trailing partial chunk is discarded.
        chunks_by_pid = defaultdict(list)
        for pid in self.pids:
            pool = list(self.index_dic[pid])
            if len(pool) < k:
                # Too few images: draw K positions with replacement.
                pool = np.random.choice(pool, size=k, replace=True)
            random.shuffle(pool)
            for start in range(0, len(pool) - k + 1, k):
                chunks_by_pid[pid].append(list(pool[start:start + k]))

        # Repeatedly pick N distinct identities that still own chunks and
        # emit one chunk from each, until fewer than N identities remain.
        live_pids = copy.deepcopy(self.pids)
        order = []
        while len(live_pids) >= self.num_pids_per_batch:
            for pid in random.sample(live_pids, self.num_pids_per_batch):
                pending = chunks_by_pid[pid]
                order.extend(pending.pop(0))
                if not pending:
                    live_pids.remove(pid)
        return iter(order)

    def __len__(self):
        """Return the epoch-size estimate computed in ``__init__``."""
        return self.length


class RandomDomainSampler(Sampler):
    """Random domain sampler.

    We consider each camera as a visual domain.

    How does the sampling work:
    1. Randomly sample N cameras (based on the "camid" label).
    2. From each camera, randomly sample K images.

    An epoch ends as soon as a sampled camera can no longer provide K unused
    images, so not every image is necessarily visited in an epoch.

    Args:
        data_source (list): contains tuples of (img_path(s), pid, camid, dsetid).
        batch_size (int): batch size.
        n_domain (int): number of cameras to sample in a batch. ``None`` or a
            non-positive value selects all cameras.

    Raises:
        AssertionError: if ``batch_size`` is not divisible by the number of
            cameras per batch.
        ValueError: if more cameras are requested than exist, or if
            ``batch_size`` is not positive.
    """

    def __init__(self, data_source, batch_size, n_domain):
        self.data_source = data_source

        # Keep track of image indices for each domain
        self.domain_dict = defaultdict(list)
        for i, items in enumerate(data_source):
            camid = items[2]
            self.domain_dict[camid].append(i)
        self.domains = list(self.domain_dict.keys())

        # Make sure each domain can be assigned an equal number of images
        if n_domain is None or n_domain <= 0:
            n_domain = len(self.domains)
        assert batch_size % n_domain == 0
        if n_domain > len(self.domains):
            # random.sample() would fail on every epoch; report it clearly.
            raise ValueError(
                'n_domain={} exceeds the number of cameras ({})'.format(
                    n_domain, len(self.domains)
                )
            )

        self.n_img_per_domain = batch_size // n_domain
        if self.n_img_per_domain <= 0:
            # Zero images per camera would make __iter__ loop forever.
            raise ValueError(
                'batch_size={} must be positive'.format(batch_size)
            )
        self.batch_size = batch_size
        self.n_domain = n_domain

        # Length of one sampled epoch; because cameras are drawn at random,
        # later epochs may differ in length.
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        """Return an iterator over dataset indices for one epoch."""
        n_img = self.n_img_per_domain
        domain_dict = copy.deepcopy(self.domain_dict)
        final_idxs = []
        stop_sampling = False
        while not stop_sampling:
            selected_domains = random.sample(self.domains, self.n_domain)
            # A camera holding fewer than K images at the start of a round
            # (only possible if it was undersized from the beginning) cannot
            # fill its share of the batch: end the epoch instead of letting
            # random.sample raise ValueError.
            if any(len(domain_dict[d]) < n_img for d in selected_domains):
                break
            for domain in selected_domains:
                idxs = domain_dict[domain]
                selected_idxs = random.sample(idxs, n_img)
                final_idxs.extend(selected_idxs)
                # Drop the drawn indices while keeping the rest in order
                # (indices are unique, so this matches repeated list.remove).
                chosen = set(selected_idxs)
                domain_dict[domain] = [i for i in idxs if i not in chosen]
                if len(domain_dict[domain]) < n_img:
                    stop_sampling = True
        return iter(final_idxs)

    def __len__(self):
        """Return the epoch length measured in ``__init__``."""
        return self.length


class RandomDatasetSampler(Sampler):
    """Random dataset sampler.

    How does the sampling work:
    1. Randomly sample N datasets (based on the "dsetid" label).
    2. From each dataset, randomly sample K images.

    An epoch ends as soon as a sampled dataset can no longer provide K unused
    images, so not every image is necessarily visited in an epoch.

    Args:
        data_source (list): contains tuples of (img_path(s), pid, camid, dsetid).
        batch_size (int): batch size.
        n_dataset (int): number of datasets to sample in a batch. ``None`` or
            a non-positive value selects all datasets.

    Raises:
        AssertionError: if ``batch_size`` is not divisible by the number of
            datasets per batch.
        ValueError: if more datasets are requested than exist, or if
            ``batch_size`` is not positive.
    """

    def __init__(self, data_source, batch_size, n_dataset):
        self.data_source = data_source

        # Keep track of image indices for each dataset
        self.dataset_dict = defaultdict(list)
        for i, items in enumerate(data_source):
            dsetid = items[3]
            self.dataset_dict[dsetid].append(i)
        self.datasets = list(self.dataset_dict.keys())

        # Make sure each dataset can be assigned an equal number of images
        if n_dataset is None or n_dataset <= 0:
            n_dataset = len(self.datasets)
        assert batch_size % n_dataset == 0
        if n_dataset > len(self.datasets):
            # random.sample() would fail on every epoch; report it clearly.
            raise ValueError(
                'n_dataset={} exceeds the number of datasets ({})'.format(
                    n_dataset, len(self.datasets)
                )
            )

        self.n_img_per_dset = batch_size // n_dataset
        if self.n_img_per_dset <= 0:
            # Zero images per dataset would make __iter__ loop forever.
            raise ValueError(
                'batch_size={} must be positive'.format(batch_size)
            )
        self.batch_size = batch_size
        self.n_dataset = n_dataset

        # Length of one sampled epoch; because datasets are drawn at random,
        # later epochs may differ in length.
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        """Return an iterator over dataset indices for one epoch."""
        n_img = self.n_img_per_dset
        dataset_dict = copy.deepcopy(self.dataset_dict)
        final_idxs = []
        stop_sampling = False
        while not stop_sampling:
            selected_datasets = random.sample(self.datasets, self.n_dataset)
            # A dataset holding fewer than K images at the start of a round
            # (only possible if it was undersized from the beginning) cannot
            # fill its share of the batch: end the epoch instead of letting
            # random.sample raise ValueError.
            if any(len(dataset_dict[d]) < n_img for d in selected_datasets):
                break
            for dset in selected_datasets:
                idxs = dataset_dict[dset]
                selected_idxs = random.sample(idxs, n_img)
                final_idxs.extend(selected_idxs)
                # Drop the drawn indices while keeping the rest in order
                # (indices are unique, so this matches repeated list.remove).
                chosen = set(selected_idxs)
                dataset_dict[dset] = [i for i in idxs if i not in chosen]
                if len(dataset_dict[dset]) < n_img:
                    stop_sampling = True
        return iter(final_idxs)

    def __len__(self):
        """Return the epoch length measured in ``__init__``."""
        return self.length


def build_train_sampler(
    data_source,
    train_sampler,
    batch_size=32,
    num_instances=4,
    num_cams=1,
    num_datasets=1,
    **kwargs
):
    """Builds a training sampler.

    Args:
        data_source (list): contains tuples of (img_path(s), pid, camid, dsetid).
        train_sampler (str): sampler name; must be one of ``AVAI_SAMPLERS``.
        batch_size (int, optional): batch size. Default is 32.
        num_instances (int, optional): number of instances per identity in a
            batch (when using ``RandomIdentitySampler``). Default is 4.
        num_cams (int, optional): number of cameras to sample in a batch (when
            using ``RandomDomainSampler``). Default is 1.
        num_datasets (int, optional): number of datasets to sample in a batch
            (when using ``RandomDatasetSampler``). Default is 1.
        **kwargs: accepted so callers may pass extra options; ignored here.

    Returns:
        Sampler: the constructed sampler instance.

    Raises:
        AssertionError: if ``train_sampler`` is not in ``AVAI_SAMPLERS``.
    """
    if train_sampler not in AVAI_SAMPLERS:
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O`` (where ``sampler`` would otherwise be left unbound);
        # the exception type is kept for existing callers.
        raise AssertionError(
            'train_sampler must be one of {}, but got {}'.format(
                AVAI_SAMPLERS, train_sampler
            )
        )

    if train_sampler == 'RandomIdentitySampler':
        sampler = RandomIdentitySampler(data_source, batch_size, num_instances)
    elif train_sampler == 'RandomDomainSampler':
        sampler = RandomDomainSampler(data_source, batch_size, num_cams)
    elif train_sampler == 'RandomDatasetSampler':
        sampler = RandomDatasetSampler(data_source, batch_size, num_datasets)
    elif train_sampler == 'SequentialSampler':
        sampler = SequentialSampler(data_source)
    elif train_sampler == 'RandomSampler':
        sampler = RandomSampler(data_source)

    return sampler