YOLOv7 Improvement: Adding the CBAM Attention Mechanism
- CBAM attention mechanism
- Code
- Adding the CBAM module to common.py
- Registering the CBAM module name in yolo.py
- Adding CBAM to the cfg file
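My project called for it, so I tried adding the CBAM attention mechanism to YOLOv7 to see whether it would improve performance. I had already added CBAM to YOLOv5 before, so I ported the YOLOv5 CBAM implementation over directly. Without further ado, let's look at the code!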
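CBAM attention mechanism
First, a brief introduction to the CBAM attention mechanism:
Paper: https://arxiv.org/pdf/1807.06521.pdf
The Convolutional Block Attention Module (CBAM) consists of two sub-modules: a channel attention module (CAM) and a spatial attention module (SAM). CAM lets the network concentrate on the image foreground, i.e. the meaningful ground-truth regions, while SAM lets it attend to the locations across the whole image that carry rich contextual information. Both modules are plug-and-play, and the recommended arrangement is sequential (channel attention first, then spatial attention). The paper reports that the sequential arrangement beats the parallel one; on my own dataset the difference between the two was not noticeable. In my view this kind of feature fusion has no strict requirements and depends on the dataset, so use whichever arrangement works best. The code below uses the sequential arrangement.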
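For reference, the sequential arrangement can be written out as follows, with notation following the CBAM paper: F ∈ R^{C×H×W} is the input feature map, σ is the sigmoid, r is the reduction ratio, and ⊗ is element-wise multiplication with broadcasting. The subscripts HW (global pooling over spatial positions) and C (pooling along the channel axis) are my own labels added to tell the two kinds of pooling apart. In the code below, f1 and f2 correspond to W_0 and W_1, ratio corresponds to r, and SpatialAttention's conv is the 7×7 convolution.

```latex
\begin{aligned}
\mathrm{MLP}(z) &= W_1\,\mathrm{ReLU}(W_0 z), \qquad W_0 \in \mathbb{R}^{(C/r)\times C},\; W_1 \in \mathbb{R}^{C\times (C/r)} \\
M_c(F) &= \sigma\big(\mathrm{MLP}(\mathrm{AvgPool}_{HW}(F)) + \mathrm{MLP}(\mathrm{MaxPool}_{HW}(F))\big) \in \mathbb{R}^{C\times 1\times 1} \\
M_s(F) &= \sigma\big(f^{7\times 7}([\mathrm{AvgPool}_{C}(F);\, \mathrm{MaxPool}_{C}(F)])\big) \in \mathbb{R}^{1\times H\times W} \\
F' &= M_c(F) \otimes F, \qquad F'' = M_s(F') \otimes F'
\end{aligned}
```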
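Code
Adding the CBAM module to common.py
This code is essentially the same as the YOLOv5 version, so it can be reused directly. Add it to models/common.py: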
```python
# models/common.py already imports torch and torch.nn as nn and defines autopad()

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP implemented with 1x1 convolutions (reduce by `ratio`, then restore)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)  # (N, C, 1, 1) channel weights
        return out


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)    # average over channels
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # max over channels
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)  # (N, 1, H, W) spatial weights


class CBAM(nn.Module):
    # Conv + BN + activation, then channel attention followed by spatial attention
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(CBAM, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        # act=True -> Hardswish (as in the YOLOv5 version); an nn.Module passed from the cfg,
        # e.g. nn.LeakyReLU(0.1), is used as given instead of being silently replaced
        self.act = nn.Hardswish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
        self.ca = ChannelAttention(c2)
        self.sa = SpatialAttention()

    def forward(self, x):
        x = self.act(self.bn(self.conv(x)))
        x = self.ca(x) * x  # channel attention
        x = self.sa(x) * x  # spatial attention
        return x

    def fuseforward(self, x):
        # Forward pass after bn has been folded into conv; the attention must still be applied
        x = self.act(self.conv(x))
        x = self.ca(x) * x
        x = self.sa(x) * x
        return x
```
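Both attention maps are multiplied into the feature map by broadcasting (ChannelAttention returns an (N, C, 1, 1) tensor, SpatialAttention an (N, 1, H, W) tensor), so the output shape of CBAM is decided by its convolution alone. That is what makes it a drop-in replacement for a Conv with the same arguments. The pure-Python sketch below computes that shape; autopad mirrors the helper of the same name in common.py, while conv_out_size and cbam_out_shape are hypothetical helpers written only for illustration (integer kernel sizes only).

```python
def autopad(k, p=None):
    # Same-padding rule used by common.py: pad k // 2 when no padding is given
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p


def conv_out_size(size, k, s, p=None):
    # Output length of a convolution along one spatial axis: floor((size + 2p - k) / s) + 1
    return (size + 2 * autopad(k, p) - k) // s + 1


def cbam_out_shape(n, h, w, c2, k=1, s=1, p=None):
    """Shape (N, C, H, W) produced by CBAM(c1, c2, k, s, p) for an (N, c1, H, W) input.

    The channel and spatial attention maps are broadcast-multiplied and never
    change the shape, so only the convolution determines the result.
    """
    return (n, c2, conv_out_size(h, k, s, p), conv_out_size(w, k, s, p))


# The first two backbone layers of the cfg: CBAM [32, 3, 2] followed by CBAM [64, 3, 2]
shape0 = cbam_out_shape(1, 640, 640, 32, k=3, s=2)
shape1 = cbam_out_shape(1, shape0[2], shape0[3], 64, k=3, s=2)
print(shape0, shape1)
```

For a 640×640 input this gives (1, 32, 320, 320) and then (1, 64, 160, 160), exactly what the two Conv rows produced before the swap.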
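Registering the CBAM module name in yolo.py
In yolo.py, find the module list in parse_model (line 459 in the version I used; the exact line may differ between versions) and append CBAM to it. yolo.py already imports everything from models.common (from models.common import *), so CBAM needs no separate import.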
```python
if m in [nn.Conv2d, Conv, RobustConv, RobustConv2, DWConv, GhostConv, RepConv, RepConv_OREPA, DownC,
         SPP, SPPF, SPPCSPC, GhostSPPCSPC, MixConv2d, Focus, Stem, GhostStem, CrossConv,
         Bottleneck, BottleneckCSPA, BottleneckCSPB, BottleneckCSPC,
         RepBottleneck, RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,
         Res, ResCSPA, ResCSPB, ResCSPC,
         RepRes, RepResCSPA, RepResCSPB, RepResCSPC,
         ResX, ResXCSPA, ResXCSPB, ResXCSPC,
         RepResX, RepResXCSPA, RepResXCSPB, RepResXCSPC,
         Ghost, GhostCSPA, GhostCSPB, GhostCSPC,
         SwinTransformerBlock, STCSPA, STCSPB, STCSPC,
         SwinTransformer2Block, ST2CSPA, ST2CSPB, ST2CSPC,
         CBAM]:  # CBAM added at the end of the list
    c1, c2 = ch[f], args[0]
    if c2 != no:  # if not output
        c2 = make_divisible(c2 * gw, 8)

    args = [c1, c2, *args[1:]]
    if m in [DownC, SPPCSPC, GhostSPPCSPC,
             BottleneckCSPA, BottleneckCSPB, BottleneckCSPC,
             RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,
             ResCSPA, ResCSPB, ResCSPC,
             RepResCSPA, RepResCSPB, RepResCSPC,
             ResXCSPA, ResXCSPB, ResXCSPC,
             RepResXCSPA, RepResXCSPB, RepResXCSPC,
             GhostCSPA, GhostCSPB, GhostCSPC,
             STCSPA, STCSPB, STCSPC,
             ST2CSPA, ST2CSPB, ST2CSPC]:
        args.insert(2, n)  # number of repeats
        n = 1
```
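Putting CBAM in this list is what makes parse_model prepend the input channel count to the cfg arguments, so a row such as [-1, 1, CBAM, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]] becomes the call CBAM(3, 32, 3, 2, None, 1, nn.LeakyReLU(0.1)). Without the entry, the cfg arguments would be passed through unchanged and c1 would receive 32, c2 would receive 3, and so on. The sketch below reproduces just this branch in pure Python. make_divisible follows the helper of the same name in utils/general.py; expand_args is a hypothetical name, and the activation is represented by a plain string instead of a real module.

```python
import math


def make_divisible(x, divisor):
    # Round x up to a multiple of divisor (same formula as make_divisible in utils/general.py)
    return math.ceil(x / divisor) * divisor


def expand_args(ch_in, args, gw=1.0, no=None):
    """Hypothetical helper mirroring the parse_model branch that CBAM now goes through.

    ch_in: channels produced by the 'from' layer (ch[f] in parse_model)
    args:  the cfg argument list, starting with the output channels c2
    gw:    width_multiple from the cfg
    no:    Detect output channels, which are never rescaled
    """
    c1, c2 = ch_in, args[0]
    if c2 != no:  # if not output
        c2 = make_divisible(c2 * gw, 8)
    return [c1, c2, *args[1:]]


act = "nn.LeakyReLU(0.1)"  # placeholder for the activation module built from the cfg string
print(expand_args(3, [32, 3, 2, None, 1, act]))          # layer 0 with width_multiple 1.0
print(expand_args(3, [32, 3, 2, None, 1, act], gw=0.5))  # same row with width_multiple 0.5
```

With width_multiple 1.0 the channels pass through unchanged; with 0.5 the 32 output channels become 16, still a multiple of 8.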
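Adding CBAM to the cfg file
The example below modifies the yolov7-tiny cfg. Taking the backbone as an example, simply replace Conv with CBAM in the rows you want (here the first two down-sampling layers, with the original Conv rows kept as comments). The same replacement also works in the FPN (head).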
```yaml
# parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple

# anchors
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

backbone:
  # [from, number, module, args] c2, k=1, s=1, p=None, g=1, act=True
  # [[-1, 1, Conv, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 0-P1/2
  [[-1, 1, CBAM, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 0-P1/2

   # [-1, 1, Conv, [64, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 1-P2/4
   [-1, 1, CBAM, [64, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 1-P2/4

   [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 7

   [-1, 1, MP, []],  # 8-P3/8
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 14

   [-1, 1, MP, []],  # 15-P4/16
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 21

   [-1, 1, MP, []],  # 22-P5/32
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [512, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 28
  ]

# yolov7-tiny head
head:
  [[-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, SP, [5]],
   [-2, 1, SP, [9]],
   [-3, 1, SP, [13]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -7], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 37

   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [21, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # route backbone P4
   [[-1, -2], 1, Concat, [1]],

   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 47

   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [14, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # route backbone P3
   [[-1, -2], 1, Concat, [1]],

   [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 57

   [-1, 1, Conv, [128, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, 47], 1, Concat, [1]],

   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 65

   [-1, 1, Conv, [256, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, 37], 1, Concat, [1]],

   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 73

   [57, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [65, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [73, 1, Conv, [512, 3, 1, None, 1, nn.LeakyReLU(0.1)]],

   [[74,75,76], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
```
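Because each CBAM row keeps the same [c2, k, s] arguments as the Conv row it replaces, the number of layers, the output channels and the strides stay the same. As a result, the absolute route indices used in the head (14, 21, 37, 47, 57, 65, 73) remain valid. If you instead insert CBAM as an extra row, every later index shifts by one and those references must be updated. The pure-Python sketch below checks the stride bookkeeping for the backbone above. The (from, stride) encoding and the names elan_block, BACKBONE and cumulative_strides are hypothetical simplifications, not part of YOLOv7.

```python
def elan_block():
    # Six stride-1 rows of a yolov7-tiny backbone block: two 1x1 branches (-1 and -2),
    # two 3x3 convs, a Concat of the last four outputs, and a 1x1 transition conv
    return [(-1, 1), (-2, 1), (-1, 1), (-1, 1), ([-1, -2, -3, -4], 1), (-1, 1)]


# Backbone rows as (from, stride); the CBAM rows keep the stride of the Conv rows they replace
BACKBONE = (
    [(-1, 2), (-1, 2)]          # 0-1: CBAM, k=3, s=2
    + elan_block()              # 2-7
    + [(-1, 2)] + elan_block()  # 8: MP, then 9-14
    + [(-1, 2)] + elan_block()  # 15: MP, then 16-21
    + [(-1, 2)] + elan_block()  # 22: MP, then 23-28
)


def cumulative_strides(rows):
    """Return each layer's total down-sampling factor relative to the input image."""
    strides = []
    for f, s in rows:
        src = f[0] if isinstance(f, list) else f  # Concat inputs share one resolution
        # At layer 0, -1 means the input image; otherwise index into earlier layers
        # (negative values count back from the current layer, like Python list indexing)
        base = 1 if not strides else strides[src]
        strides.append(base * s)
    return strides


strides = cumulative_strides(BACKBONE)
for idx in (0, 1, 7, 14, 21, 28):
    print(idx, strides[idx])
```

The strides come out as 2, 4, 4, 8, 16 and 32 for layers 0, 1, 7, 14, 21 and 28, consistent with the P1/2 to P5/32 comments in the cfg and with layers 14 and 21 being routed to the head as P3 and P4.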