Source code for torchreid.data.datasets.video.prid2011

from __future__ import division, print_function, absolute_import
import glob
import os.path as osp

from torchreid.utils import read_json

from ..dataset import VideoDataset


class PRID2011(VideoDataset):
    """PRID2011.

    Reference:
        Hirzer et al. Person Re-Identification by Descriptive and
        Discriminative Classification. SCIA 2011.

    URL: `<https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/PRID11/>`_

    Dataset statistics:
        - identities: 200.
        - tracklets: 400.
        - cameras: 2.

    Args:
        root (str): directory that contains the ``prid2011`` folder.
            ``~`` is expanded and the path is made absolute.
        split_id (int): index of the train/test split to load from
            ``splits_prid2011.json``. Default is 0.
        **kwargs: forwarded unchanged to ``VideoDataset.__init__``.

    Raises:
        ValueError: if ``split_id`` is not a valid index into the splits
            stored in ``splits_prid2011.json``.
        RuntimeError: if a person directory listed in the selected split
            contains no ``*.png`` frames.
    """
    dataset_dir = 'prid2011'
    dataset_url = None

    def __init__(self, root='', split_id=0, **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        # The class attribute holds the relative folder name; the instance
        # attribute shadows it with the resolved path.
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        # dataset_url is None for this dataset.
        # NOTE(review): presumably download_dataset does not fetch anything
        # in that case -- confirm against the base class.
        self.download_dataset(self.dataset_dir, self.dataset_url)

        self.split_path = osp.join(self.dataset_dir, 'splits_prid2011.json')
        self.cam_a_dir = osp.join(
            self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_a'
        )
        self.cam_b_dir = osp.join(
            self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_b'
        )

        required_files = [self.dataset_dir, self.cam_a_dir, self.cam_b_dir]
        self.check_before_run(required_files)

        splits = read_json(self.split_path)
        # Reject negative ids as well: they would otherwise silently index
        # from the end of the list, contradicting the range in the message.
        if not 0 <= split_id < len(splits):
            raise ValueError(
                'split_id exceeds range, received {}, but expected between 0 and {}'
                .format(split_id, len(splits) - 1)
            )
        split = splits[split_id]
        train_dirs, test_dirs = split['train'], split['test']

        # Training uses both cameras; at test time cam_a tracklets form the
        # query set and cam_b tracklets form the gallery.
        train = self.process_dir(train_dirs, cam1=True, cam2=True)
        query = self.process_dir(test_dirs, cam1=True, cam2=False)
        gallery = self.process_dir(test_dirs, cam1=False, cam2=True)

        super(PRID2011, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dirnames, cam1=True, cam2=True):
        """Build the tracklet list for the given person directories.

        Args:
            dirnames (list): person directory names, each expected to exist
                under every selected camera directory.
            cam1 (bool): include tracklets from ``cam_a`` (camera id 0).
            cam2 (bool): include tracklets from ``cam_b`` (camera id 1).

        Returns:
            list: tuples ``(img_paths, pid, camid)`` where ``img_paths`` is
            a tuple of sorted ``*.png`` paths and ``pid`` is the position
            of the directory name in ``dirnames``. For each directory the
            ``cam_a`` tracklet (if selected) precedes the ``cam_b`` one.

        Raises:
            RuntimeError: if a selected camera has no ``*.png`` frame for
                one of the directories.
        """
        # (camera root, camera id) pairs, in the order they are emitted.
        cameras = []
        if cam1:
            cameras.append((self.cam_a_dir, 0))
        if cam2:
            cameras.append((self.cam_b_dir, 1))

        # Labels are relative to this call: the i-th directory name gets
        # pid i (for a repeated name the last position wins).
        dirname2pid = {dirname: i for i, dirname in enumerate(dirnames)}

        tracklets = []
        for dirname in dirnames:
            pid = dirname2pid[dirname]
            for cam_dir, camid in cameras:
                person_dir = osp.join(cam_dir, dirname)
                # glob returns matches in arbitrary order; sort so that the
                # frame sequence of a tracklet is deterministic.
                img_names = sorted(glob.glob(osp.join(person_dir, '*.png')))
                # Explicit raise instead of assert so the check survives
                # running under ``python -O``.
                if not img_names:
                    raise RuntimeError(
                        'No "*.png" frames found in "{}"'.format(person_dir)
                    )
                tracklets.append((tuple(img_names), pid, camid))

        return tracklets