Source code for torchreid.utils.tools

from __future__ import division, print_function, absolute_import
import os
import sys
import json
import time
import errno
import numpy as np
import random
import os.path as osp
import warnings
import PIL
import torch
from PIL import Image

# Names exported by ``from <this module> import *``.
__all__ = [
    'mkdir_if_missing', 'check_isfile', 'read_json', 'write_json',
    'set_random_seed', 'download_url', 'read_image', 'collect_env_info',
    'listdir_nohidden'
]


def mkdir_if_missing(dirname):
    """Creates dirname if it is missing.

    Args:
        dirname (str): directory path. An empty string (which is what
            ``os.path.dirname`` returns for a bare file name) is treated
            as "nothing to create" and the call is a no-op.

    Raises:
        OSError: if the directory cannot be created for any reason other
            than it already existing.
    """
    # ``os.makedirs('')`` raises ENOENT, so bail out early on an empty path.
    if dirname == '' or osp.exists(dirname):
        return

    try:
        os.makedirs(dirname)
    except OSError as e:
        # Tolerate the directory appearing between the existence check
        # above and the makedirs call; re-raise everything else.
        if e.errno != errno.EEXIST:
            raise
def check_isfile(fpath):
    """Tells whether ``fpath`` points to an existing regular file.

    A warning is issued (via ``warnings.warn``) when it does not.

    Args:
        fpath (str): file path.

    Returns:
        bool
    """
    if osp.isfile(fpath):
        return True

    warnings.warn('No file found at "{}"'.format(fpath))
    return False
def read_json(fpath):
    """Loads and returns the object stored in the json file at ``fpath``."""
    with open(fpath, 'r') as handle:
        return json.load(handle)
def write_json(obj, fpath):
    """Serializes ``obj`` as json into the file at ``fpath``.

    The parent directory of ``fpath`` is passed to ``mkdir_if_missing``
    before the file is opened for writing.
    """
    parent_dir = osp.dirname(fpath)
    mkdir_if_missing(parent_dir)

    with open(fpath, 'w') as handle:
        json.dump(obj, handle, indent=4, separators=(',', ': '))
def set_random_seed(seed):
    """Seeds the random number generators used by this module's imports.

    The same ``seed`` is passed, in order, to ``random.seed``,
    ``np.random.seed``, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed (int): seed value.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def download_url(url, dst):
    """Downloads file from a url to a destination.

    Progress (percentage when the total size is known, size so far, speed
    and elapsed time) is written to ``sys.stdout`` while downloading.

    Args:
        url (str): url to download file.
        dst (str): destination path.
    """
    try:
        from urllib.request import urlretrieve
    except ImportError:  # Python 2: fall back to the six compatibility layer
        from six.moves import urllib
        urlretrieve = urllib.request.urlretrieve

    print('* url="{}"'.format(url))
    print('* destination="{}"'.format(dst))

    # Start time lives in a closure-local dict rather than a module global,
    # so nothing leaks into the module namespace between calls.
    progress = {'start': None}

    def _reporthook(count, block_size, total_size):
        if count == 0 or progress['start'] is None:
            progress['start'] = time.time()
            return

        duration = time.time() - progress['start']
        progress_size = int(count * block_size)
        # The clock may not have advanced yet; avoid dividing by zero.
        speed = int(progress_size / (1024*duration)) if duration > 0 else 0

        if total_size > 0:
            # The final block usually overshoots the real size; cap at 100.
            percent = min(100, int(progress_size * 100 / total_size))
            sys.stdout.write(
                '\r...%d%%, %d MB, %d KB/s, %d seconds passed' %
                (percent, progress_size / (1024*1024), speed, duration)
            )
        else:
            # Total size unknown (reported as <= 0): skip the percentage.
            sys.stdout.write(
                '\r...%d MB, %d KB/s, %d seconds passed' %
                (progress_size / (1024*1024), speed, duration)
            )
        sys.stdout.flush()

    urlretrieve(url, dst, _reporthook)
    sys.stdout.write('\n')
def read_image(path):
    """Reads image from path using ``PIL.Image``.

    The image is opened and converted to RGB. If that raises ``IOError``
    a message is printed and the read is attempted again; the number of
    attempts is not bounded.

    Args:
        path (str): path to an image.

    Returns:
        PIL image

    Raises:
        IOError: if ``path`` does not exist.
    """
    if not osp.exists(path):
        raise IOError('"{}" does not exist'.format(path))

    while True:
        try:
            return Image.open(path).convert('RGB')
        except IOError:
            print(
                'IOError incurred when reading "{}". Will redo. Don\'t worry. Just chill.'
                .format(path)
            )
def collect_env_info():
    """Returns env info as a string.

    The text produced by ``torch.utils.collect_env.get_pretty_env_info``
    is extended with one extra line giving the Pillow version.

    Code source: github.com/facebookresearch/maskrcnn-benchmark
    """
    from torch.utils.collect_env import get_pretty_env_info

    torch_report = get_pretty_env_info()
    pillow_line = '\n Pillow ({})'.format(PIL.__version__)
    return torch_report + pillow_line
def listdir_nohidden(path, sort=False):
    """Lists the entries of a directory whose names do not start with a dot.

    Args:
        path (str): directory path.
        sort (bool): sort the items.

    Returns:
        list: names of the non-hidden entries.
    """
    visible = []
    for entry in os.listdir(path):
        if entry.startswith('.'):
            continue
        visible.append(entry)

    return sorted(visible) if sort else visible