torchreid.engine
Base Engine
class torchreid.engine.engine.Engine(datamanager, use_gpu=True)
A generic base Engine class for both image- and video-reid.
Parameters
datamanager (DataManager) – an instance of torchreid.data.ImageDataManager or torchreid.data.VideoDataManager.
use_gpu (bool, optional) – use gpu. Default is True.
run(save_dir='log', max_epoch=0, start_epoch=0, print_freq=10, fixbase_epoch=0, open_layers=None, start_eval=0, eval_freq=-1, test_only=False, dist_metric='euclidean', normalize_feature=False, visrank=False, visrank_topk=10, use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False)
A unified pipeline for training and evaluating a model.
Parameters
save_dir (str) – directory in which to save the model.
max_epoch (int) – maximum number of epochs.
start_epoch (int, optional) – starting epoch. Default is 0.
print_freq (int, optional) – print frequency. Default is 10.
fixbase_epoch (int, optional) – number of epochs to train open_layers (new layers) while keeping base layers frozen. Default is 0. fixbase_epoch is counted in max_epoch.
open_layers (str or list, optional) – layers (attribute names) open for training.
start_eval (int, optional) – from which epoch to start evaluation. Default is 0.
eval_freq (int, optional) – evaluation frequency. Default is -1 (meaning evaluation is only performed at the end of training).
test_only (bool, optional) – if True, only runs evaluation on test datasets. Default is False.
dist_metric (str, optional) – distance metric used to compute distance matrix between query and gallery. Default is “euclidean”.
normalize_feature (bool, optional) – performs L2 normalization on feature vectors before computing feature distance. Default is False.
visrank (bool, optional) – visualizes ranked results. Default is False. It is recommended to enable visrank when test_only is True. The ranked images will be saved to “save_dir/visrank_dataset”, e.g. “save_dir/visrank_market1501”.
visrank_topk (int, optional) – top-k ranked images to be visualized. Default is 10.
use_metric_cuhk03 (bool, optional) – use single-gallery-shot setting for cuhk03. Default is False. This should be enabled when using cuhk03 classic split.
ranks (list, optional) – CMC ranks to be computed. Default is [1, 5, 10, 20].
rerank (bool, optional) – uses person re-ranking (by Zhong et al. CVPR’17). Default is False. This is only enabled when test_only=True.
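The interplay of start_epoch, max_epoch, start_eval and eval_freq can be sketched as follows. This is a minimal sketch of one plausible reading of the parameter descriptions above, not torchreid's own code: the helper evaluation_epochs is hypothetical, epochs are counted from 1, and a final evaluation is assumed to run after the last epoch whenever training takes place (which matches "eval_freq=-1 means evaluation only at the end of training").

```python
from typing import List


def evaluation_epochs(max_epoch: int, start_epoch: int = 0,
                      start_eval: int = 0, eval_freq: int = -1) -> List[int]:
    """Return the (1-based) epochs after which evaluation would run.

    Hypothetical helper: periodic evaluation happens every eval_freq epochs
    once start_eval is reached; a final evaluation is assumed to run after
    the last epoch whenever at least one epoch is trained.
    """
    epochs = []
    for epoch in range(start_epoch, max_epoch):
        done = epoch + 1  # number of epochs completed so far
        periodic = (eval_freq > 0 and done >= start_eval
                    and done % eval_freq == 0)
        if periodic or done == max_epoch:
            epochs.append(done)
    return epochs


print(evaluation_epochs(60))                               # [60]
print(evaluation_epochs(60, eval_freq=10))                 # [10, 20, 30, 40, 50, 60]
print(evaluation_epochs(60, start_eval=30, eval_freq=10))  # [30, 40, 50, 60]
```

Setting eval_freq to a positive value therefore trades extra evaluation time for intermediate checkpoints of test accuracy, while start_eval skips the early epochs where evaluation is rarely informative.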
test(dist_metric='euclidean', normalize_feature=False, visrank=False, visrank_topk=10, save_dir='', use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False)
Tests the model on the target datasets.
Note
This function is called in run().
Note
The test pipeline implemented in this function suits both image- and video-reid. In general, a subclass of Engine only needs to re-implement extract_features() and parse_data_for_eval() (most of the time), though this is not required. Please refer to the source code for more details.
two_stepped_transfer_learning(epoch, fixbase_epoch, open_layers, model=None)
Two-stepped transfer learning.
The idea is to freeze base layers for a certain number of epochs and then open all layers for training.
Reference: https://arxiv.org/abs/1611.05244
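The two steps can be sketched as a per-epoch decision about which layers are trained. This is a minimal sketch assuming layers are identified by attribute name, as for open_layers; trainable_layers is a hypothetical helper that only chooses the layer names. In a PyTorch model the choice would then be applied by toggling requires_grad on the parameters of the corresponding submodules, a step left out here so the sketch runs without PyTorch. Because fixbase_epoch is counted in max_epoch, max_epoch=60 with fixbase_epoch=10 means 10 epochs with frozen base layers followed by 50 epochs with all layers open.

```python
from typing import Iterable, Optional, Set, Union


def trainable_layers(epoch: int, fixbase_epoch: int,
                     open_layers: Optional[Union[str, Iterable[str]]],
                     all_layers: Iterable[str]) -> Set[str]:
    """Names of the layers updated in a given (0-based) epoch.

    Step 1: for the first fixbase_epoch epochs only open_layers are trained
    and the base layers stay frozen.
    Step 2: afterwards every layer is opened for training.
    """
    every_layer = set(all_layers)
    if open_layers is None or epoch + 1 > fixbase_epoch:
        return every_layer  # step 2: all layers are trained
    names = {open_layers} if isinstance(open_layers, str) else set(open_layers)
    unknown = names - every_layer
    if unknown:
        raise ValueError(f"unknown layers: {sorted(unknown)}")
    return names  # step 1: only the new layers are trained


layers = ["conv1", "layer1", "layer2", "layer3", "layer4", "classifier"]
for epoch in (0, 9, 10, 59):
    print(epoch + 1, sorted(trainable_layers(epoch, 10, "classifier", layers)))
```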
Image Engines
class torchreid.engine.image.softmax.ImageSoftmaxEngine(datamanager, model, optimizer, scheduler=None, use_gpu=True, label_smooth=True)
Softmax-loss engine for image-reid.
Parameters
datamanager (DataManager) – an instance of torchreid.data.ImageDataManager or torchreid.data.VideoDataManager.
model (nn.Module) – model instance.
optimizer (Optimizer) – an Optimizer.
scheduler (LRScheduler, optional) – if None, no learning rate decay will be performed.
use_gpu (bool, optional) – use gpu. Default is True.
label_smooth (bool, optional) – use label smoothing regularizer. Default is True.
Examples:
    import torchreid

    datamanager = torchreid.data.ImageDataManager(
        root='path/to/reid-data',
        sources='market1501',
        height=256,
        width=128,
        combineall=False,
        batch_size=32
    )
    model = torchreid.models.build_model(
        name='resnet50',
        num_classes=datamanager.num_train_pids,
        loss='softmax'
    )
    model = model.cuda()
    optimizer = torchreid.optim.build_optimizer(
        model, optim='adam', lr=0.0003
    )
    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer,
        lr_scheduler='single_step',
        stepsize=20
    )
    engine = torchreid.engine.ImageSoftmaxEngine(
        datamanager, model, optimizer, scheduler=scheduler
    )
    engine.run(
        max_epoch=60,
        save_dir='log/resnet50-softmax-market1501',
        print_freq=10
    )
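What label_smooth=True changes can be illustrated with a per-sample loss. This sketch assumes the common label-smoothing formulation, where the one-hot target is replaced by (1 - epsilon) on the true class plus epsilon / K spread over all K classes; the smoothing factor epsilon=0.1 and the helper name label_smoothed_cross_entropy are assumptions, not values read from torchreid.

```python
import math
from typing import Sequence


def label_smoothed_cross_entropy(logits: Sequence[float], target: int,
                                 epsilon: float = 0.1) -> float:
    """Cross-entropy of one sample against a label-smoothed target."""
    k = len(logits)
    m = max(logits)  # subtract the max logit for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]  # log-softmax
    loss = 0.0
    for c, lp in enumerate(log_probs):
        # smoothed target probability for class c
        q = epsilon / k + (1.0 - epsilon if c == target else 0.0)
        loss -= q * lp
    return loss


print(label_smoothed_cross_entropy([4.0, 0.0, 0.0], target=0))               # smoothed
print(label_smoothed_cross_entropy([4.0, 0.0, 0.0], target=0, epsilon=0.0))  # plain cross-entropy
```

With smoothing, the loss cannot reach zero even for a very confident correct prediction, which discourages the classifier from producing over-confident logits.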
class torchreid.engine.image.triplet.ImageTripletEngine(datamanager, model, optimizer, margin=0.3, weight_t=1, weight_x=1, scheduler=None, use_gpu=True, label_smooth=True)
Triplet-loss engine for image-reid.
Parameters
datamanager (DataManager) – an instance of torchreid.data.ImageDataManager or torchreid.data.VideoDataManager.
model (nn.Module) – model instance.
optimizer (Optimizer) – an Optimizer.
margin (float, optional) – margin for triplet loss. Default is 0.3.
weight_t (float, optional) – weight for triplet loss. Default is 1.
weight_x (float, optional) – weight for softmax loss. Default is 1.
scheduler (LRScheduler, optional) – if None, no learning rate decay will be performed.
use_gpu (bool, optional) – use gpu. Default is True.
label_smooth (bool, optional) – use label smoothing regularizer. Default is True.
Examples:
    import torchreid

    datamanager = torchreid.data.ImageDataManager(
        root='path/to/reid-data',
        sources='market1501',
        height=256,
        width=128,
        combineall=False,
        batch_size=32,
        num_instances=4,
        train_sampler='RandomIdentitySampler' # this is important
    )
    model = torchreid.models.build_model(
        name='resnet50',
        num_classes=datamanager.num_train_pids,
        loss='triplet'
    )
    model = model.cuda()
    optimizer = torchreid.optim.build_optimizer(
        model, optim='adam', lr=0.0003
    )
    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer,
        lr_scheduler='single_step',
        stepsize=20
    )
    engine = torchreid.engine.ImageTripletEngine(
        datamanager, model, optimizer, margin=0.3,
        weight_t=0.7, weight_x=1, scheduler=scheduler
    )
    engine.run(
        max_epoch=60,
        save_dir='log/resnet50-triplet-market1501',
        print_freq=10
    )
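The roles of margin, weight_t and weight_x can be sketched in plain Python. This assumes a batch-hard triplet loss (for each anchor, the farthest same-identity sample and the closest other-identity sample, using Euclidean distance) combined with the softmax loss as weight_t * loss_t + weight_x * loss_x; the mining strategy and the helper names are assumptions for illustration. Batch-hard mining also shows why the example's identity sampler matters: an anchor without a positive in the same batch contributes nothing.

```python
import math
from typing import List, Sequence


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def batch_hard_triplet_loss(features: List[Sequence[float]], pids: List[int],
                            margin: float = 0.3) -> float:
    """Mean over anchors of max(0, d(a, hardest_pos) - d(a, hardest_neg) + margin)."""
    losses = []
    for i, anchor in enumerate(features):
        pos = [euclidean(anchor, f) for j, f in enumerate(features)
               if pids[j] == pids[i] and j != i]
        neg = [euclidean(anchor, f) for j, f in enumerate(features)
               if pids[j] != pids[i]]
        if not pos or not neg:
            continue  # anchor needs at least one positive and one negative
        losses.append(max(0.0, max(pos) - min(neg) + margin))
    return sum(losses) / len(losses) if losses else 0.0


def total_loss(loss_t: float, loss_x: float,
               weight_t: float = 1.0, weight_x: float = 1.0) -> float:
    """Weighted sum of the triplet (loss_t) and softmax (loss_x) terms."""
    return weight_t * loss_t + weight_x * loss_x


feats = [[0.0], [1.0], [1.1], [5.0]]  # toy 1-D embeddings
pids = [0, 0, 1, 1]
loss_t = batch_hard_triplet_loss(feats, pids, margin=0.3)
loss_x = 2.0  # hypothetical softmax loss value for the same batch
print(loss_t, total_loss(loss_t, loss_x, weight_t=0.7, weight_x=1.0))
```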
Video Engines
class torchreid.engine.video.softmax.VideoSoftmaxEngine(datamanager, model, optimizer, scheduler=None, use_gpu=True, label_smooth=True, pooling_method='avg')
Softmax-loss engine for video-reid.
Parameters
datamanager (DataManager) – an instance of torchreid.data.ImageDataManager or torchreid.data.VideoDataManager.
model (nn.Module) – model instance.
optimizer (Optimizer) – an Optimizer.
scheduler (LRScheduler, optional) – if None, no learning rate decay will be performed.
use_gpu (bool, optional) – use gpu. Default is True.
label_smooth (bool, optional) – use label smoothing regularizer. Default is True.
pooling_method (str, optional) – how to pool features for a tracklet. Default is “avg” (average). Choices are [“avg”, “max”].
Examples:
    import torch
    import torchreid

    # Each batch contains batch_size*seq_len images
    datamanager = torchreid.data.VideoDataManager(
        root='path/to/reid-data',
        sources='mars',
        height=256,
        width=128,
        combineall=False,
        batch_size=8, # number of tracklets
        seq_len=15 # number of images in each tracklet
    )
    model = torchreid.models.build_model(
        name='resnet50',
        num_classes=datamanager.num_train_pids,
        loss='softmax'
    )
    model = model.cuda()
    optimizer = torchreid.optim.build_optimizer(
        model, optim='adam', lr=0.0003
    )
    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer,
        lr_scheduler='single_step',
        stepsize=20
    )
    engine = torchreid.engine.VideoSoftmaxEngine(
        datamanager, model, optimizer, scheduler=scheduler,
        pooling_method='avg'
    )
    engine.run(
        max_epoch=60,
        save_dir='log/resnet50-softmax-mars',
        print_freq=10
    )
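With batch_size=8 and seq_len=15, each batch holds 8 × 15 = 120 images, and one feature per tracklet is obtained by pooling the features of its frames. The plain-Python sketch below shows the two pooling_method choices. It assumes frame features arrive as a flat list in which the frames of each tracklet are contiguous; that layout and the helper name pool_tracklets are assumptions for illustration.

```python
from typing import List, Sequence


def pool_tracklets(frame_features: Sequence[Sequence[float]], seq_len: int,
                   pooling_method: str = "avg") -> List[List[float]]:
    """Collapse batch_size * seq_len frame features into one feature per tracklet.

    Assumes tracklet t owns rows t * seq_len .. (t + 1) * seq_len - 1.
    """
    if len(frame_features) % seq_len != 0:
        raise ValueError("number of frames must be a multiple of seq_len")
    pooled = []
    for start in range(0, len(frame_features), seq_len):
        frames = frame_features[start:start + seq_len]
        columns = list(zip(*frames))  # one tuple per feature dimension
        if pooling_method == "avg":
            pooled.append([sum(col) / seq_len for col in columns])
        elif pooling_method == "max":
            pooled.append([max(col) for col in columns])
        else:
            raise ValueError(f"unknown pooling_method: {pooling_method}")
    return pooled


# Two tracklets (batch_size=2) of seq_len=3 frames with 2-D features.
frames = [[1.0, 4.0], [2.0, 5.0], [3.0, 9.0],
          [0.0, 1.0], [6.0, 1.0], [0.0, 1.0]]
print(pool_tracklets(frames, seq_len=3, pooling_method="avg"))  # [[2.0, 6.0], [2.0, 1.0]]
print(pool_tracklets(frames, seq_len=3, pooling_method="max"))  # [[3.0, 9.0], [6.0, 1.0]]
```

Average pooling smooths out frames with occlusion or blur, while max pooling keeps the strongest response per dimension, which can preserve a distinctive detail seen in only a few frames.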
class torchreid.engine.video.triplet.VideoTripletEngine(datamanager, model, optimizer, margin=0.3, weight_t=1, weight_x=1, scheduler=None, use_gpu=True, label_smooth=True, pooling_method='avg')
Triplet-loss engine for video-reid.
Parameters
datamanager (DataManager) – an instance of torchreid.data.ImageDataManager or torchreid.data.VideoDataManager.
model (nn.Module) – model instance.
optimizer (Optimizer) – an Optimizer.
margin (float, optional) – margin for triplet loss. Default is 0.3.
weight_t (float, optional) – weight for triplet loss. Default is 1.
weight_x (float, optional) – weight for softmax loss. Default is 1.
scheduler (LRScheduler, optional) – if None, no learning rate decay will be performed.
use_gpu (bool, optional) – use gpu. Default is True.
label_smooth (bool, optional) – use label smoothing regularizer. Default is True.
pooling_method (str, optional) – how to pool features for a tracklet. Default is “avg” (average). Choices are [“avg”, “max”].
Examples:
    import torch
    import torchreid

    # Each batch contains batch_size*seq_len images
    # Each identity is sampled with num_instances tracklets
    datamanager = torchreid.data.VideoDataManager(
        root='path/to/reid-data',
        sources='mars',
        height=256,
        width=128,
        combineall=False,
        num_instances=4,
        train_sampler='RandomIdentitySampler',
        batch_size=8, # number of tracklets
        seq_len=15 # number of images in each tracklet
    )
    model = torchreid.models.build_model(
        name='resnet50',
        num_classes=datamanager.num_train_pids,
        loss='triplet'
    )
    model = model.cuda()
    optimizer = torchreid.optim.build_optimizer(
        model, optim='adam', lr=0.0003
    )
    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer,
        lr_scheduler='single_step',
        stepsize=20
    )
    engine = torchreid.engine.VideoTripletEngine(
        datamanager, model, optimizer, margin=0.3,
        weight_t=0.7, weight_x=1, scheduler=scheduler,
        pooling_method='avg'
    )
    engine.run(
        max_epoch=60,
        save_dir='log/resnet50-triplet-mars',
        print_freq=10
    )
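The comments in this example describe identity-balanced sampling: with batch_size=8 and num_instances=4, each batch covers 8 / 4 = 2 identities with 4 tracklets each, i.e. 8 × 15 = 120 frames. The sketch below draws the indices of one such batch. It is a simplified stand-in for RandomIdentitySampler, not its implementation; the helper name identity_batch and the with-replacement fallback for identities that have fewer than num_instances samples are assumptions.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def identity_batch(pids: Sequence[int], batch_size: int, num_instances: int,
                   rng: random.Random) -> List[int]:
    """Indices for one batch: batch_size // num_instances identities,
    each contributing num_instances samples (images or tracklets)."""
    if batch_size % num_instances != 0:
        raise ValueError("batch_size must be divisible by num_instances")
    index_by_pid: Dict[int, List[int]] = defaultdict(list)
    for idx, pid in enumerate(pids):
        index_by_pid[pid].append(idx)
    num_ids = batch_size // num_instances
    if num_ids > len(index_by_pid):
        raise ValueError("not enough identities for one batch")
    batch: List[int] = []
    for pid in rng.sample(sorted(index_by_pid), num_ids):
        candidates = index_by_pid[pid]
        if len(candidates) >= num_instances:
            batch.extend(rng.sample(candidates, num_instances))
        else:
            # too few samples for this identity: draw with replacement
            batch.extend(rng.choices(candidates, k=num_instances))
    return batch


rng = random.Random(0)
pids = [0] * 5 + [1] * 3 + [2] * 6 + [3] * 4  # identity label of each tracklet
batch = identity_batch(pids, batch_size=8, num_instances=4, rng=rng)
print(sorted(pids[i] for i in batch))
```

Guaranteeing several samples per identity in every batch is what gives the triplet loss valid positive pairs to mine.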