paper:GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond
official implementaion:https://github.com/xvjiarui/GCNet
Third party implementation:https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/bricks/context_block.py
存在的问题
通过捕获long-range dependency提取全局信息,对各种视觉任务都是很有帮助的。Non-local Network(介绍见https://blog.csdn.net/ooooocj/article/details/124573078)通过自注意力机制来解决这个问题。对于每个查询位置(query position),non-local network首先计算该位置和所有位置之间一个两两成对的关系,得到一个attention map。然后对attention map所有位置的权重加权求和得到汇总特征,每一个查询位置都得到一个汇总特征,将汇总特征与原始特征相加得到最终输出。
对于某个query position,non-local network计算的另一个位置与该位置的关系即一个权重值表示这个位置对query位置的重要程度。本文可视化attention map发现,不同的query位置其对应的attention map几乎一样,如下图所示
non-local block可以表示为下式
其中 \(i\) 是query position的索引,\(j\) 遍历所有位置,\(f(\mathbf{x}_{i},\mathbf{x}_{j})\) 表示位置 \(i\) 和 \(j\) 之间的关系,\(\mathcal{C}(\mathbf{x})\) 是归一化因子,\(W_{z}\) 和 \(W_{v}\) 是线性变换矩阵例如 \(1\times1\) 卷积。non-local block有多种不同的实例化方法,例如Gaussian、Embedded Gaussian、Dot product、Concat,下图(a)是Embedded Gaussian的结构。
由于需要计算每个query位置的attention map,因此non-local block的时间和空间复杂度都是所有位置的平方关系。
下图是作者从COCO数据集中随机挑选的6张图片,并可视化出3个不同的query position即图中的红点与对应的query-specific attention map,可以看出对于不同的查询位置,它们的attention map几乎是相同的。
为了进一步验证这一观察结果,作者又分析了不同的查询位置与全局上下文之间的距离。结果如下表所示。其中计算了三种向量之间的余弦距离cosine distance,分别为non-local block的输入、输出、以及查询的注意力图,对应表中的input、output、att。
从表中可以看出,input列的余弦距离比较大表明non-local的输入特征可以在不同的位置进行区分,但output列的余弦距离非常小,表明non-local block建模的全局上下文特征对于不同的query position几乎是相同的。attention map上的距离非常小,也验证了可视化的观察结果。
尽管non-local block是打算针对每个位置计算全局上下文的,但训练后的全局上下文实际上是独立于查询位置的。因为没有必要为每个查询位置单独计算query-specific全局上下文。
本文的创新点
本文通过观察发现non-local block针对每个query position计算的attention map最终结果是独立于查询位置的,那么就没有必要针对每个查询位置计算了,因此提出计算一个通用的attention map并应用于输入feature map上的所有位置,大大减少了计算量的同时又没有导致性能的降低。此外,结合SE block,设计了一个新的Global Context (GC) block,既轻量又可以有效地建模全局上下文。GC Block结合了Non-local block和SE block的优点,基于GC Block设计的GCNet在多个任务上均超过了NLNet和SENet。
方法介绍
作者舍去了式(1)中的 \(W_{z}\) 即图(3)(a)中的query分支,得到下式
这里采用的是最常用的Embedded Gaussiian的实例化方式。简化后的non-local block如图(3)(b)所示。
为了进一步减少计算量,作者应用分配率将 \(W_{v}\) 移到attention pool的外面,如下
这里简化后的non-local block如图4(b)所示。1×1卷积 \(W_{v}\) 的FLOPs从 \(\mathcal{O}(HWC^{2})\) 减小到 \(\mathcal{O}(C^{2})\)。
到目前为止简化的NL block中参数量最大的部分在transform module即图4(b)中的Transform部分,这里是一个1×1卷积但参数量为 \(C\cdot C\),当把Nl block应用到较深的层例如resnet中的 \(res_{5}\) 时,CxC=2028×2048,占据了整个block大部分的计算量。为了进一步减少计算量,作者借鉴了SE block的思想如图4(c),将图4(b)中的transform module换成了图4(d)中的bottleneck transform module,其中 \(r\) 是reduction ratio,这样参数量就从C·C变成了2·C·C/r,默认情况下r=16,因此参数量就减少为1/8。
实验结果
下表baseline是backbone为ResNet-50的Mask R-CNN在COCO数据集上的目标检测和实例分割的结果。将1个non-local block(NL)、1个simplified non-local block(SNL)、1个global context block(GC)插入到c4的最后一个residual block前,可以看出GC block获得的相似的性能但参数量更小。将GC block添加到所有的residual block中在参数量相似的情况下得到了更高的性能。
代码解析
这里的代码是mmcv中的实现。
# Copyright (c) OpenMMLab. All rights reserved.from typing import Unionimport torchfrom torch import nnfrom ..utils import constant_init, kaiming_initfrom .registry import PLUGIN_LAYERSdef last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:if isinstance(m, nn.Sequential):constant_init(m[-1], val=0)else:constant_init(m, val=0)@PLUGIN_LAYERS.register_module()class ContextBlock(nn.Module):"""ContextBlock module in GCNet.See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'(https://arxiv.org/abs/1904.11492) for details.Args:in_channels (int): Channels of the input feature map.ratio (float): Ratio of channels of transform bottleneckpooling_type (str): Pooling method for context modeling.Options are 'att' and 'avg', stand for attention pooling andaverage pooling respectively. Default: 'att'.fusion_types (Sequence[str]): Fusion method for feature fusion,Options are 'channels_add', 'channel_mul', stand for channelwiseaddition and multiplication respectively. Default: ('channel_add',)"""_abbr_ = 'context_block'def __init__(self, in_channels: int, ratio: float, pooling_type: str = 'att', fusion_types: tuple = ('channel_add', )):super().__init__()assert pooling_type in ['avg', 'att']assert isinstance(fusion_types, (list, tuple))valid_fusion_types = ['channel_add', 'channel_mul']assert all([f in valid_fusion_types for f in fusion_types])assert len(fusion_types) > 0, 'at least one fusion should be used'self.in_channels = in_channelsself.ratio = ratioself.planes = int(in_channels * ratio)self.pooling_type = pooling_typeself.fusion_types = fusion_typesif pooling_type == 'att':self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)self.softmax = nn.Softmax(dim=2)else:self.avg_pool = nn.AdaptiveAvgPool2d(1)if 'channel_add' in fusion_types:self.channel_add_conv = nn.Sequential(nn.Conv2d(self.in_channels, self.planes, kernel_size=1),nn.LayerNorm([self.planes, 1, 1]),nn.ReLU(inplace=True),# yapf: disablenn.Conv2d(self.planes, self.in_channels, kernel_size=1))else:self.channel_add_conv = Noneif 'channel_mul' in fusion_types:self.channel_mul_conv = nn.Sequential(nn.Conv2d(self.in_channels, self.planes, kernel_size=1),nn.LayerNorm([self.planes, 1, 1]),nn.ReLU(inplace=True),# yapf: disablenn.Conv2d(self.planes, self.in_channels, kernel_size=1))else:self.channel_mul_conv = Noneself.reset_parameters()def reset_parameters(self):if self.pooling_type == 'att':kaiming_init(self.conv_mask, mode='fan_in')self.conv_mask.inited = Trueif self.channel_add_conv is not None:last_zero_init(self.channel_add_conv)if self.channel_mul_conv is not None:last_zero_init(self.channel_mul_conv)def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:batch, channel, height, width = x.size()if self.pooling_type == 'att':input_x = x# [N, C, H * W]input_x = input_x.view(batch, channel, height * width)# [N, 1, C, H * W]input_x = input_x.unsqueeze(1)# [N, 1, H, W]context_mask = self.conv_mask(x)# [N, 1, H * W]context_mask = context_mask.view(batch, 1, height * width)# [N, 1, H * W]context_mask = self.softmax(context_mask)# [N, 1, H * W, 1]context_mask = context_mask.unsqueeze(-1)# [N, 1, C, 1]context = torch.matmul(input_x, context_mask)# [N, C, 1, 1]context = context.view(batch, channel, 1, 1)else:# [N, C, 1, 1]context = self.avg_pool(x)return contextdef forward(self, x: torch.Tensor) -> torch.Tensor:# [N, C, 1, 1]context = self.spatial_pool(x)out = xif self.channel_mul_conv is not None:# [N, C, 1, 1]channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))out = out * channel_mul_termif self.channel_add_conv is not None:# [N, C, 1, 1]channel_add_term = self.channel_add_conv(context)out = out + channel_add_termreturn out