from __future__ import absolute_import
import os
import sys
import os.path as osp
from .tools import mkdir_if_missing
__all__ = ['Logger', 'RankLogger']
class Logger(object):
    """Writes console output to external text file.

    Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py>`_

    Every message passed to :meth:`write` is echoed to the console stream
    that was ``sys.stdout`` when the logger was created and, if ``fpath``
    is given, also to that file.

    Args:
        fpath (str, optional): path of the logging file. Its parent
            directory is created if missing (via ``mkdir_if_missing``).
            The file is opened in ``'w'`` mode, so existing content is
            overwritten. If None, output only goes to the console.

    Examples::
        >>> import sys
        >>> import os
        >>> import os.path as osp
        >>> from torchreid.utils import Logger
        >>> save_dir = 'log/resnet50-softmax-market1501'
        >>> log_name = 'train.log'
        >>> sys.stdout = Logger(osp.join(save_dir, log_name))
    """

    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            dirname = osp.dirname(fpath)
            # A bare file name has an empty dirname: the file goes into the
            # current directory, so there is no directory to create.
            if dirname:
                mkdir_if_missing(dirname)
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        # Return the logger so ``with Logger(path) as log:`` binds it.
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        """Writes ``msg`` to the console and, if present, to the log file."""
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        """Flushes the console and the log file, syncing the file to disk."""
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        """Closes the log file.

        The console stream is deliberately left open. It is the
        ``sys.stdout`` captured in ``__init__`` and is not owned by this
        logger. Closing it here, which also runs from ``__exit__`` and
        ``__del__``, would make every later console write fail.
        """
        if self.file is not None:
            self.file.close()
class RankLogger(object):
    """Collects rank1 accuracies per test dataset and prints a summary.

    Each call to :meth:`write` appends one (epoch, rank1) pair to the
    history of a test dataset. :meth:`show_summary` prints all recorded
    pairs grouped by dataset. A dataset is labelled ``source`` if it also
    appears in ``sources``, and ``target`` otherwise.

    Args:
        sources (str or list): source dataset name(s).
        targets (str or list): target (test) dataset name(s). Only these
            names can be passed to :meth:`write`.

    Examples::
        >>> from torchreid.utils import RankLogger
        >>> s = 'market1501'
        >>> t = 'market1501'
        >>> ranklogger = RankLogger(s, t)
        >>> ranklogger.write(t, 10, 0.5)
        >>> ranklogger.write(t, 20, 0.7)
        >>> ranklogger.write(t, 30, 0.9)
        >>> ranklogger.show_summary()
        >>> # You will see:
        >>> # => Show performance summary
        >>> # market1501 (source)
        >>> # - epoch 10   rank1 50.0%
        >>> # - epoch 20   rank1 70.0%
        >>> # - epoch 30   rank1 90.0%
        >>> # If there are multiple test datasets
        >>> t = ['market1501', 'dukemtmcreid']
        >>> ranklogger = RankLogger(s, t)
        >>> ranklogger.write(t[0], 10, 0.5)
        >>> ranklogger.write(t[0], 20, 0.7)
        >>> ranklogger.write(t[0], 30, 0.9)
        >>> ranklogger.write(t[1], 10, 0.1)
        >>> ranklogger.write(t[1], 20, 0.2)
        >>> ranklogger.write(t[1], 30, 0.3)
        >>> ranklogger.show_summary()
        >>> # You can see:
        >>> # => Show performance summary
        >>> # market1501 (source)
        >>> # - epoch 10   rank1 50.0%
        >>> # - epoch 20   rank1 70.0%
        >>> # - epoch 30   rank1 90.0%
        >>> # dukemtmcreid (target)
        >>> # - epoch 10   rank1 10.0%
        >>> # - epoch 20   rank1 20.0%
        >>> # - epoch 30   rank1 30.0%
    """

    def __init__(self, sources, targets):
        # A single name given as a plain string becomes a one-element list.
        self.sources = [sources] if isinstance(sources, str) else sources
        self.targets = [targets] if isinstance(targets, str) else targets
        # One history per target: parallel lists of epochs and rank1 values.
        self.logger = {}
        for target_name in self.targets:
            self.logger[target_name] = {'epoch': [], 'rank1': []}

    def write(self, name, epoch, rank1):
        """Appends one result to the history of dataset ``name``.

        Args:
            name (str): dataset name; must be one of ``targets``.
            epoch (int): current epoch.
            rank1 (float): rank1 accuracy as a fraction (printed as a
                percentage by :meth:`show_summary`).

        Raises:
            KeyError: if ``name`` is not one of the target datasets.
        """
        history = self.logger[name]
        history['epoch'].append(epoch)
        history['rank1'].append(rank1)

    def show_summary(self):
        """Prints every recorded result, grouped by target dataset."""
        print('=> Show performance summary')
        for name in self.targets:
            label = 'target'
            if name in self.sources:
                label = 'source'
            print('{} ({})'.format(name, label))
            history = self.logger[name]
            pairs = zip(history['epoch'], history['rank1'])
            for epoch, rank1 in pairs:
                print('- epoch {}\t rank1 {:.1%}'.format(epoch, rank1))