Source code for torchreid.models.hacnn

from __future__ import division, absolute_import
import torch
from torch import nn
from torch.nn import functional as F

__all__ = ['HACNN']


class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    Args:
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): kernel size.
        s (int or tuple): stride.
        p (int or tuple): padding.
    """

    def __init__(self, in_c, out_c, k, s=1, p=0):
        super(ConvBlock, self).__init__()
        # The attribute names are also the state_dict prefixes; keep them
        # stable so saved checkpoints continue to load.
        self.conv = nn.Conv2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=k,
            stride=s,
            padding=p
        )
        self.bn = nn.BatchNorm2d(num_features=out_c)

    def forward(self, x):
        """Apply convolution, batch normalization and ReLU, in that order."""
        convolved = self.conv(x)
        normalized = self.bn(convolved)
        return F.relu(normalized)


class InceptionA(nn.Module):
    """Four-stream inception block with stride-1 layers only.

    Every stream emits ``out_channels // 4`` channels; ``forward``
    concatenates the four stream outputs along dimension 1.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): nominal number of channels after concatenation
            (each stream gets ``out_channels // 4``).
    """

    def __init__(self, in_channels, out_channels):
        super(InceptionA, self).__init__()
        branch_width = out_channels // 4

        def reduce_then_conv3x3():
            # 1x1 ConvBlock followed by a padded 3x3 ConvBlock.
            return nn.Sequential(
                ConvBlock(in_channels, branch_width, 1),
                ConvBlock(branch_width, branch_width, 3, p=1),
            )

        # stream1..stream3 share one layout but each owns its own weights.
        self.stream1 = reduce_then_conv3x3()
        self.stream2 = reduce_then_conv3x3()
        self.stream3 = reduce_then_conv3x3()
        # stream4: 3x3 average pooling (stride 1, padding 1), then a 1x1 ConvBlock.
        self.stream4 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1),
            ConvBlock(in_channels, branch_width, 1),
        )

    def forward(self, x):
        """Feed ``x`` to all four streams and concatenate along dim 1."""
        streams = (self.stream1, self.stream2, self.stream3, self.stream4)
        return torch.cat([stream(x) for stream in streams], dim=1)


class InceptionB(nn.Module):
    """Three-stream inception block whose streams all apply a stride of 2.

    Streams 1 and 2 emit ``out_channels // 4`` channels each and stream 3
    emits twice that; ``forward`` concatenates them along dimension 1.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): nominal number of channels after concatenation.
    """

    def __init__(self, in_channels, out_channels):
        super(InceptionB, self).__init__()
        branch_width = out_channels // 4

        # stream1: 1x1 ConvBlock -> strided 3x3 ConvBlock.
        stream1_layers = [
            ConvBlock(in_channels, branch_width, 1),
            ConvBlock(branch_width, branch_width, 3, s=2, p=1),
        ]
        self.stream1 = nn.Sequential(*stream1_layers)

        # stream2: 1x1 ConvBlock -> 3x3 ConvBlock -> strided 3x3 ConvBlock.
        stream2_layers = [
            ConvBlock(in_channels, branch_width, 1),
            ConvBlock(branch_width, branch_width, 3, p=1),
            ConvBlock(branch_width, branch_width, 3, s=2, p=1),
        ]
        self.stream2 = nn.Sequential(*stream2_layers)

        # stream3: strided 3x3 max pooling -> 1x1 ConvBlock with double width.
        stream3_layers = [
            nn.MaxPool2d(3, stride=2, padding=1),
            ConvBlock(in_channels, branch_width * 2, 1),
        ]
        self.stream3 = nn.Sequential(*stream3_layers)

    def forward(self, x):
        """Feed ``x`` to the three streams and concatenate along dim 1."""
        streams = (self.stream1, self.stream2, self.stream3)
        return torch.cat([stream(x) for stream in streams], dim=1)


class SpatialAttn(nn.Module):
    """Spatial Attention (Sec. 3.1.I.1)

    Collapses the channel dimension by averaging, applies a strided 3x3
    ConvBlock, resizes the result bilinearly to twice the conv output size
    and finishes with a 1x1 ConvBlock. The output has a single channel; its
    height/width equal the input's when the input height/width are even.
    """

    def __init__(self):
        super(SpatialAttn, self).__init__()
        self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
        self.conv2 = ConvBlock(1, 1, 1)

    def forward(self, x):
        # global cross-channel averaging
        x = x.mean(1, keepdim=True)
        # 3-by-3 conv
        x = self.conv1(x)
        # bilinear resizing; F.interpolate replaces the deprecated F.upsample
        # and takes the same size/mode/align_corners arguments
        x = F.interpolate(
            x,
            size=(x.size(2) * 2, x.size(3) * 2),
            mode='bilinear',
            align_corners=True
        )
        # scaling conv
        x = self.conv2(x)
        return x


class ChannelAttn(nn.Module):
    """Channel Attention (Sec. 3.1.I.2)

    Average-pools the input over its full spatial extent, then passes the
    pooled tensor through two 1x1 ConvBlocks: the first divides the channel
    count by ``reduction_rate`` and the second restores ``in_channels``.

    Args:
        in_channels (int): number of input channels; must be divisible by
            ``reduction_rate``.
        reduction_rate (int): channel reduction factor of the first conv.
    """

    def __init__(self, in_channels, reduction_rate=16):
        super(ChannelAttn, self).__init__()
        assert in_channels % reduction_rate == 0
        reduced_channels = in_channels // reduction_rate
        self.conv1 = ConvBlock(in_channels, reduced_channels, 1)
        self.conv2 = ConvBlock(reduced_channels, in_channels, 1)

    def forward(self, x):
        # squeeze: the pooling kernel covers the whole spatial extent
        spatial_size = x.size()[2:]
        squeezed = F.avg_pool2d(x, spatial_size)
        # excitation: two stacked 1x1 conv blocks
        return self.conv2(self.conv1(squeezed))


class SoftAttn(nn.Module):
    """Soft Attention (Sec. 3.1.I)

    Multiplies the spatial attention map with the channel attention map,
    mixes the product with a 1x1 ConvBlock and squashes it with a sigmoid.

    Args:
        in_channels (int): number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super(SoftAttn, self).__init__()
        self.spatial_attn = SpatialAttn()
        self.channel_attn = ChannelAttn(in_channels)
        self.conv = ConvBlock(in_channels, in_channels, 1)

    def forward(self, x):
        """Return the sigmoid-activated attention map computed from ``x``."""
        spatial_map = self.spatial_attn(x)
        channel_map = self.channel_attn(x)
        # The product relies on broadcasting between the two maps.
        mixed = self.conv(spatial_map * channel_map)
        return torch.sigmoid(mixed)


class HardAttn(nn.Module):
    """Hard Attention (Sec. 3.1.II)

    Pools the input over its full spatial extent and maps the pooled vector
    through one linear layer plus tanh to eight numbers per sample, returned
    as a ``(batch, 4, 2)`` tensor.

    Args:
        in_channels (int): number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super(HardAttn, self).__init__()
        self.fc = nn.Linear(in_channels, 4 * 2)
        self.init_params()

    def init_params(self):
        """Zero the linear weights and load the fixed initial bias values."""
        initial_bias = torch.tensor(
            [0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float
        )
        self.fc.weight.data.zero_()
        self.fc.bias.data.copy_(initial_bias)

    def forward(self, x):
        batch_size, num_channels = x.size(0), x.size(1)
        # squeeze operation (global average pooling), flattened to 2-D
        pooled = F.avg_pool2d(x, x.size()[2:]).view(batch_size, num_channels)
        # predict transformation parameters, bounded to (-1, 1) by tanh
        theta = torch.tanh(self.fc(pooled))
        return theta.view(-1, 4, 2)


class HarmAttn(nn.Module):
    """Harmonious Attention (Sec. 3.1)

    Bundles a soft attention module and a hard attention module that are
    both applied to the same input.

    Args:
        in_channels (int): number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super(HarmAttn, self).__init__()
        self.soft_attn = SoftAttn(in_channels)
        self.hard_attn = HardAttn(in_channels)

    def forward(self, x):
        """Return ``(soft_attention_output, hard_attention_output)`` for ``x``."""
        soft_map = self.soft_attn(x)
        region_theta = self.hard_attn(x)
        return soft_map, region_theta


class HACNN(nn.Module):
    """Harmonious Attention Convolutional Neural Network.

    Reference:
        Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.

    Public keys:
        - ``hacnn``: HACNN.

    Args:
        num_classes (int): number of classes to predict.
        loss (str): ``'softmax'`` or ``'triplet'``; selects what ``forward``
            returns in training mode. Any other value raises ``KeyError``
            in a training-mode ``forward``.
        nchannels (list): number of channels AFTER concatenation.
        feat_dim (int): feature dimension for a single stream.
        learn_region (bool): whether to learn region features (i.e. local branch).
        use_gpu (bool): stored on the instance for backward compatibility.
            The affine matrices built in ``transform_theta`` are placed on
            the device of the incoming parameters, not according to this flag.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        num_classes,
        loss='softmax',
        nchannels=[128, 256, 384],
        feat_dim=512,
        learn_region=True,
        use_gpu=True,
        **kwargs
    ):
        super(HACNN, self).__init__()
        self.loss = loss
        self.learn_region = learn_region
        self.use_gpu = use_gpu

        self.conv = ConvBlock(3, 32, 3, s=2, p=1)

        # Construct Inception + HarmAttn blocks
        # ============== Block 1 ==============
        self.inception1 = nn.Sequential(
            InceptionA(32, nchannels[0]),
            InceptionB(nchannels[0], nchannels[0]),
        )
        self.ha1 = HarmAttn(nchannels[0])

        # ============== Block 2 ==============
        self.inception2 = nn.Sequential(
            InceptionA(nchannels[0], nchannels[1]),
            InceptionB(nchannels[1], nchannels[1]),
        )
        self.ha2 = HarmAttn(nchannels[1])

        # ============== Block 3 ==============
        self.inception3 = nn.Sequential(
            InceptionA(nchannels[1], nchannels[2]),
            InceptionB(nchannels[2], nchannels[2]),
        )
        self.ha3 = HarmAttn(nchannels[2])

        self.fc_global = nn.Sequential(
            nn.Linear(nchannels[2], feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(),
        )
        self.classifier_global = nn.Linear(feat_dim, num_classes)

        if self.learn_region:
            self.init_scale_factors()
            self.local_conv1 = InceptionB(32, nchannels[0])
            self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
            self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
            # The four regional feature vectors are concatenated, hence "* 4".
            self.fc_local = nn.Sequential(
                nn.Linear(nchannels[2] * 4, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.ReLU(),
            )
            self.classifier_local = nn.Linear(feat_dim, num_classes)
            self.feat_dim = feat_dim * 2
        else:
            self.feat_dim = feat_dim

    def init_scale_factors(self):
        """Create the fixed 2x2 scale matrix (s_w, s_h) for each of the four regions."""
        # All four regions use the same scale matrix; they are kept as
        # separate plain tensors (not parameters or buffers) in a list.
        self.scale_factors = [
            torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)
            for _ in range(4)
        ]

    def stn(self, x, theta):
        """Performs spatial transform

        x: (batch, channel, height, width)
        theta: (batch, 2, 3)
        """
        # NOTE(review): align_corners is left at the library default for both
        # calls; that default differs between PyTorch releases -- confirm the
        # intended setting before pinning it explicitly.
        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        return x

    def transform_theta(self, theta_i, region_idx):
        """Transforms theta to include (s_w, s_h), resulting in (batch, 2, 3)

        Args:
            theta_i (torch.Tensor): per-sample parameters of shape (batch, 2)
                that are written into the last column of the result.
            region_idx (int): index into ``self.scale_factors``.

        Returns:
            torch.Tensor: (batch, 2, 3) matrices located on ``theta_i.device``.
        """
        # Build the matrix where theta_i already lives. An unconditional
        # ``.cuda()`` driven by ``use_gpu`` fails on CPU-only hosts and when
        # the flag does not match the device the model actually runs on.
        device = theta_i.device
        scale_factors = self.scale_factors[region_idx].to(device)
        theta = torch.zeros(theta_i.size(0), 2, 3, device=device)
        theta[:, :, :2] = scale_factors
        theta[:, :, -1] = theta_i
        return theta

    def _local_branch(self, source, thetas, out_size, local_conv, residuals=None):
        """Crop, resize and convolve the four regions of one block.

        Args:
            source (torch.Tensor): feature map the regions are sampled from.
            thetas (torch.Tensor): hard-attention output, indexed as
                ``thetas[:, region_idx, :]``.
            out_size (tuple): (height, width) each sampled region is resized to.
            local_conv (nn.Module): module applied to every resized region.
            residuals (list, optional): per-region tensors added to the
                resized region before ``local_conv``.

        Returns:
            list: the four per-region outputs of ``local_conv``.
        """
        outputs = []
        for region_idx in range(4):
            theta = self.transform_theta(thetas[:, region_idx, :], region_idx)
            region = self.stn(source, theta)
            # F.interpolate replaces the deprecated F.upsample (same arguments).
            region = F.interpolate(
                region, size=out_size, mode='bilinear', align_corners=True
            )
            if residuals is not None:
                region = region + residuals[region_idx]
            outputs.append(local_conv(region))
        return outputs

    def forward(self, x):
        """Compute global (and optionally local) features or class scores.

        In eval mode returns the feature tensor: the L2-normalized global and
        local features concatenated when ``learn_region`` is set, otherwise
        the global feature alone. In training mode returns classifier outputs
        for ``loss == 'softmax'`` and ``(classifier outputs, features)`` for
        ``loss == 'triplet'``; with ``learn_region`` each of these is a
        (global, local) tuple.

        Raises:
            AssertionError: if the input height/width is not (160, 64).
            KeyError: in training mode, if ``self.loss`` is unsupported.
        """
        assert x.size(2) == 160 and x.size(3) == 64, \
            'Input size does not match, expected (160, 64) but got ({}, {})'.format(x.size(2), x.size(3))
        x = self.conv(x)

        # ============== Block 1 ==============
        # global branch
        x1 = self.inception1(x)
        x1_attn, x1_theta = self.ha1(x1)
        x1_out = x1 * x1_attn
        # local branch
        if self.learn_region:
            x1_local_list = self._local_branch(
                x, x1_theta, (24, 28), self.local_conv1
            )

        # ============== Block 2 ==============
        # global branch
        x2 = self.inception2(x1_out)
        x2_attn, x2_theta = self.ha2(x2)
        x2_out = x2 * x2_attn
        # local branch
        if self.learn_region:
            x2_local_list = self._local_branch(
                x1_out, x2_theta, (12, 14), self.local_conv2,
                residuals=x1_local_list
            )

        # ============== Block 3 ==============
        # global branch
        x3 = self.inception3(x2_out)
        x3_attn, x3_theta = self.ha3(x3)
        x3_out = x3 * x3_attn
        # local branch
        if self.learn_region:
            x3_local_list = self._local_branch(
                x2_out, x3_theta, (6, 7), self.local_conv3,
                residuals=x2_local_list
            )

        # ============== Feature generation ==============
        # global branch
        x_global = F.avg_pool2d(
            x3_out, x3_out.size()[2:]
        ).view(x3_out.size(0), x3_out.size(1))
        x_global = self.fc_global(x_global)
        # local branch
        if self.learn_region:
            x_local_list = [
                F.avg_pool2d(feat, feat.size()[2:]).view(feat.size(0), -1)
                for feat in x3_local_list
            ]
            x_local = torch.cat(x_local_list, 1)
            x_local = self.fc_local(x_local)

        if not self.training:
            # l2 normalization before concatenation
            if self.learn_region:
                x_global = x_global / x_global.norm(p=2, dim=1, keepdim=True)
                x_local = x_local / x_local.norm(p=2, dim=1, keepdim=True)
                return torch.cat([x_global, x_local], 1)
            else:
                return x_global

        prelogits_global = self.classifier_global(x_global)
        if self.learn_region:
            prelogits_local = self.classifier_local(x_local)

        if self.loss == 'softmax':
            if self.learn_region:
                return (prelogits_global, prelogits_local)
            else:
                return prelogits_global

        elif self.loss == 'triplet':
            if self.learn_region:
                return (prelogits_global, prelogits_local), (x_global, x_local)
            else:
                return prelogits_global, x_global

        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))