Source code for torchreid.models.shufflenet

from __future__ import division, absolute_import
import torch
import torch.utils.model_zoo as model_zoo
from torch import nn
from torch.nn import functional as F

__all__ = ['shufflenet']

model_urls = {
    # training epoch = 90, top1 = 61.8
    'imagenet':
    'https://mega.nz/#!RDpUlQCY!tr_5xBEkelzDjveIYBBcGcovNCOrgfiJO9kiidz9fZM',
}


class ChannelShuffle(nn.Module):

    def __init__(self, num_groups):
        super(ChannelShuffle, self).__init__()
        self.g = num_groups

    def forward(self, x):
        b, c, h, w = x.size()
        n = c // self.g
        # reshape
        x = x.view(b, self.g, n, h, w)
        # transpose
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        # flatten
        x = x.view(b, c, h, w)
        return x

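# Illustrative sketch (not part of the original torchreid module; the helper
# name below is hypothetical): shows how ChannelShuffle interleaves channels
# across groups so that the following group convolution mixes information
# between groups.
def _demo_channel_shuffle():
    shuffle = ChannelShuffle(num_groups=2)
    # 6 channels laid out as two groups: [0, 1, 2 | 3, 4, 5]
    x = torch.arange(6, dtype=torch.float32).view(1, 6, 1, 1)
    # After shuffling, channels from the two groups alternate:
    # tensor([0., 3., 1., 4., 2., 5.])
    return shuffle(x).view(-1)
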

class Bottleneck(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        num_groups,
        group_conv1x1=True
    ):
        super(Bottleneck, self).__init__()
        assert stride in [1, 2], 'Warning: stride must be either 1 or 2'
        self.stride = stride
        mid_channels = out_channels // 4
        if stride == 2:
            out_channels -= in_channels
        # group conv is not applied to first conv1x1 at stage 2
        num_groups_conv1x1 = num_groups if group_conv1x1 else 1
        self.conv1 = nn.Conv2d(
            in_channels,
            mid_channels,
            1,
            groups=num_groups_conv1x1,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.shuffle1 = ChannelShuffle(num_groups)
        self.conv2 = nn.Conv2d(
            mid_channels,
            mid_channels,
            3,
            stride=stride,
            padding=1,
            groups=mid_channels,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(
            mid_channels, out_channels, 1, groups=num_groups, bias=False
        )
        self.bn3 = nn.BatchNorm2d(out_channels)
        if stride == 2:
            self.shortcut = nn.AvgPool2d(3, stride=2, padding=1)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.shuffle1(out)
        out = self.bn2(self.conv2(out))
        out = self.bn3(self.conv3(out))
        if self.stride == 2:
            res = self.shortcut(x)
            out = F.relu(torch.cat([res, out], 1))
        else:
            out = F.relu(x + out)
        return out

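# Illustrative sketch (not part of the original torchreid module; the helper
# name below is hypothetical): the two Bottleneck variants. With stride=1 the
# residual is added, so in_channels must equal out_channels; with stride=2 the
# 3x3 average-pooled input is concatenated, so the shortcut contributes
# in_channels of the final out_channels and the spatial size halves.
def _demo_bottleneck_shapes():
    x = torch.randn(1, 24, 56, 56)
    # stride=2: 24 -> 240 channels (24 of them come from the shortcut), 56x56 -> 28x28
    down = Bottleneck(24, 240, 2, num_groups=3, group_conv1x1=False)
    # stride=1: channels and spatial size are preserved
    same = Bottleneck(240, 240, 1, num_groups=3)
    y = same(down(x))
    return y.shape  # torch.Size([1, 240, 28, 28])
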

# configuration {num_groups: out_channels of stages 2-4} based on Table 1 in the paper
cfg = {
    1: [144, 288, 576],
    2: [200, 400, 800],
    3: [240, 480, 960],
    4: [272, 544, 1088],
    8: [384, 768, 1536],
}


class ShuffleNet(nn.Module):
    """ShuffleNet.

    Reference:
        Zhang et al. ShuffleNet: An Extremely Efficient Convolutional Neural
        Network for Mobile Devices. CVPR 2018.

    Public keys:
        - ``shufflenet``: ShuffleNet (groups=3).
    """

    def __init__(self, num_classes, loss='softmax', num_groups=3, **kwargs):
        super(ShuffleNet, self).__init__()
        self.loss = loss

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(24),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        self.stage2 = nn.Sequential(
            Bottleneck(
                24, cfg[num_groups][0], 2, num_groups, group_conv1x1=False
            ),
            Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
            Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
            Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
        )

        self.stage3 = nn.Sequential(
            Bottleneck(cfg[num_groups][0], cfg[num_groups][1], 2, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
        )

        self.stage4 = nn.Sequential(
            Bottleneck(cfg[num_groups][1], cfg[num_groups][2], 2, num_groups),
            Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
            Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
            Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
        )

        self.classifier = nn.Linear(cfg[num_groups][2], num_classes)
        self.feat_dim = cfg[num_groups][2]

    def forward(self, x):
        x = self.conv1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1)

        if not self.training:
            return x

        y = self.classifier(x)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, x
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))
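
# Illustrative sketch (not part of the original torchreid module; the helper
# name and the class count below are hypothetical examples): forward behaviour
# of ShuffleNet. In eval mode the model returns the pooled feature vector; in
# training mode it returns logits (loss='softmax') or (logits, features)
# (loss='triplet').
def _demo_shufflenet_forward():
    model = ShuffleNet(num_classes=751, loss='softmax', num_groups=3)
    x = torch.randn(2, 3, 256, 128)  # a typical re-id input size
    model.train()
    logits = model(x)       # shape (2, 751)
    model.eval()
    with torch.no_grad():
        feats = model(x)    # shape (2, 960) == model.feat_dim for groups=3
    return logits.shape, feats.shape
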
def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def shufflenet(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ShuffleNet(num_classes, loss, **kwargs)
    if pretrained:
        # init_pretrained_weights(model, model_urls['imagenet'])
        import warnings
        warnings.warn(
            'The imagenet pretrained weights need to be manually downloaded from {}'
            .format(model_urls['imagenet'])
        )
    return model
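
# Illustrative sketch (not part of the original torchreid module; the helper
# name is hypothetical): building the model through the factory. The ImageNet
# weights are not downloaded automatically (see the warning above); they would
# have to be fetched manually from model_urls['imagenet'] and then loaded,
# e.g. with load_state_dict on a locally stored state dict.
def _demo_build_shufflenet():
    model = shufflenet(num_classes=1000, loss='softmax', pretrained=False)
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 256, 128))
    return feats.shape  # torch.Size([1, 960])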