Source code for torchreid.models.xception

from __future__ import division, absolute_import
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo

__all__ = ['xception']

pretrained_settings = {
    'xception': {
        'imagenet': {
            'url':
            'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000,
            'scale':
            0.8975 # resize validation images to 333 (299 / 0.8975), then center-crop to 299x299
        }
    }
}
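
# A minimal sketch (not part of torchreid) of the ImageNet validation transform
# implied by the settings above, assuming torchvision is available:
#
#   from torchvision import transforms
#   val_transform = transforms.Compose([
#       transforms.Resize(333), # 299 / 0.8975, per the 'scale' entry
#       transforms.CenterCrop(299),
#       transforms.ToTensor(), # maps pixels to input_range [0, 1]
#       transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
#   ])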


class SeparableConv2d(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=False
    ):
        super(SeparableConv2d, self).__init__()

        # depthwise convolution: groups=in_channels applies one spatial filter
        # per input channel, with no mixing across channels
        self.conv1 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups=in_channels,
            bias=bias
        )
        # pointwise 1x1 convolution: mixes information across channels
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x
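

# A minimal sketch (not part of torchreid): the depthwise-separable
# factorization above uses far fewer parameters than a dense convolution of
# the same shape. The 728-channel width matches the middle flow used below.
def _separable_conv_param_demo():
    sep = SeparableConv2d(728, 728, 3, padding=1)
    dense = nn.Conv2d(728, 728, 3, padding=1, bias=False)
    n_sep = sum(p.numel() for p in sep.parameters()) # 728*3*3 + 728*728 = 536,536
    n_dense = sum(p.numel() for p in dense.parameters()) # 728*728*3*3 = 4,769,856
    return n_sep, n_dense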


class Block(nn.Module):

    def __init__(
        self,
        in_filters,
        out_filters,
        reps,
        strides=1,
        start_with_relu=True,
        grow_first=True
    ):
        super(Block, self).__init__()

        # project the identity with a 1x1 conv when the residual branch changes
        # channel count or spatial resolution; otherwise pass the input through
        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(
                in_filters, out_filters, 1, stride=strides, bias=False
            )
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)
        rep = []

        filters = in_filters
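        # grow_first: widen from in_filters to out_filters at the first
        # separable conv rather than at the last one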
        if grow_first:
            rep.append(self.relu)
            rep.append(
                SeparableConv2d(
                    in_filters,
                    out_filters,
                    3,
                    stride=1,
                    padding=1,
                    bias=False
                )
            )
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        for i in range(reps - 1):
            rep.append(self.relu)
            rep.append(
                SeparableConv2d(
                    filters, filters, 3, stride=1, padding=1, bias=False
                )
            )
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(
                SeparableConv2d(
                    in_filters,
                    out_filters,
                    3,
                    stride=1,
                    padding=1,
                    bias=False
                )
            )
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            # drop the leading ReLU so the block consumes its input un-activated
            rep = rep[1:]
        else:
            # the leading ReLU must not be in-place: forward() still needs the
            # unmodified input for the skip branch
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            # downsample spatially with a strided 3x3 max pool
            rep.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x += skip
        return x
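

# A minimal sketch (not part of torchreid): the first entry-flow block halves
# spatial resolution via its strided max pool and matches channels via the
# 1x1 skip projection. The input size here is illustrative.
def _block_shape_demo():
    import torch
    blk = Block(64, 128, reps=2, strides=2, start_with_relu=False, grow_first=True)
    out = blk(torch.randn(1, 64, 112, 112))
    return out.shape # torch.Size([1, 128, 56, 56])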


class Xception(nn.Module):
    """Xception.

    Reference:
        Chollet. Xception: Deep Learning with Depthwise
        Separable Convolutions. CVPR 2017.

    Public keys:
        - ``xception``: Xception.
    """

    def __init__(
        self, num_classes, loss, fc_dims=None, dropout_p=None, **kwargs
    ):
        super(Xception, self).__init__()
        self.loss = loss

        self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        self.block1 = Block(
            64, 128, 2, 2, start_with_relu=False, grow_first=True
        )
        self.block2 = Block(
            128, 256, 2, 2, start_with_relu=True, grow_first=True
        )
        self.block3 = Block(
            256, 728, 2, 2, start_with_relu=True, grow_first=True
        )
        self.block4 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )
        self.block5 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )
        self.block6 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )
        self.block7 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )
        self.block8 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )
        self.block9 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )
        self.block10 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )
        self.block11 = Block(
            728, 728, 3, 1, start_with_relu=True, grow_first=True
        )
        self.block12 = Block(
            728, 1024, 2, 2, start_with_relu=True, grow_first=False
        )

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.feature_dim = 2048
        self.fc = self._construct_fc_layer(fc_dims, 2048, dropout_p)
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        self._init_params()

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """Constructs fully connected layer.

        Args:
            fc_dims (list or tuple): dimensions of fc layers, if None, no fc
                layers are constructed
            input_dim (int): input dimension
            dropout_p (float): dropout probability, if None, dropout is unused
        """
        if fc_dims is None:
            self.feature_dim = input_dim
            return None

        assert isinstance(
            fc_dims, (list, tuple)
        ), 'fc_dims must be either list or tuple, but got {}'.format(
            type(fc_dims)
        )

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def _init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu'
                )
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x, inplace=True)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x, inplace=True)

        x = self.conv4(x)
        x = self.bn4(x)
        x = F.relu(x, inplace=True)
        return x

    def forward(self, x):
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if self.fc is not None:
            v = self.fc(v)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))


def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept
    unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)


def xception(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = Xception(num_classes, loss, fc_dims=None, dropout_p=None, **kwargs)
    if pretrained:
        model_url = pretrained_settings['xception']['imagenet']['url']
        init_pretrained_weights(model, model_url)
    return model
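

# A minimal sketch (not part of torchreid) of how the builder is typically
# used; num_classes=751 (Market-1501) is an illustrative value.
def _xception_usage_demo():
    import torch
    model = xception(num_classes=751, loss='softmax', pretrained=False)
    model.eval()
    with torch.no_grad():
        v = model(torch.randn(2, 3, 299, 299))
    return v.shape # eval mode returns features: torch.Size([2, 2048])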