# Table of Contents
# 1.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat  # `repeat` is not used in this file
from einops.layers.torch import Rearrange  # not used in this file


class PreNorm(nn.Module):
    """Apply ``LayerNorm`` over the last dimension, then call the wrapped module.

    Args:
        dim: Size of the last (normalised) dimension of the input.
        fn: Module applied to the normalised tensor; any keyword arguments
            given to ``forward`` are forwarded to it unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """Two-layer MLP: ``Linear -> GELU -> Dropout -> Linear -> Dropout``.

    Args:
        dim: Input and output feature size (last dimension).
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability used after each linear stage.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class PPM(nn.Module):
    """Pyramid pooling that turns a 4-D feature map into a short token sequence.

    For each size ``s`` the ``(b, c, h, w)`` input is adaptively average-pooled
    to an ``s x s`` grid; all grid cells are concatenated on the last axis, so
    the output is ``(b, c, sum(s * s))`` -- 1 + 9 + 25 = 35 tokens for the
    default sizes, regardless of ``h`` and ``w``.
    """

    def __init__(self, pooling_sizes=(1, 3, 5)):
        super().__init__()
        self.layer = nn.ModuleList(
            [nn.AdaptiveAvgPool2d(output_size=(size, size)) for size in pooling_sizes]
        )

    def forward(self, feat):
        b, c, h, w = feat.shape  # unpacking also rejects non-4-D input
        output = [layer(feat).view(b, c, -1) for layer in self.layer]
        return torch.cat(output, dim=-1)


def _pooled_kv_attention(q, k, v, heads, scale, attend):
    """Multi-head attention of per-pixel queries against pooled keys/values.

    Shared by ``ESA_layer`` and ``LCA_layer``, which differ only in how the
    keys/values are pooled.

    Args:
        q: Query map of shape ``(b, heads * d, h, w)``.
        k: Pooled keys of shape ``(b, heads * d, n)``.
        v: Pooled values of shape ``(b, heads * d, n)``.
        heads: Number of heads the channel axis is split into.
        scale: Multiplier applied to the ``q @ k^T`` logits.
        attend: Normalisation applied to the logits over the key axis.

    Returns:
        Tensor of shape ``(b, h * w, heads * d)``.
    """
    q = rearrange(q, 'b (head d) h w -> b head (h w) d', head=heads)
    k = rearrange(k, 'b (head d) n -> b head n d', head=heads)
    v = rearrange(v, 'b (head d) n -> b head n d', head=heads)
    dots = torch.matmul(q, k.transpose(-1, -2)) * scale
    attn = attend(dots)
    out = torch.matmul(attn, v)
    return rearrange(out, 'b head n d -> b n (head d)')


class ESA_layer(nn.Module):
    """Attention whose keys/values are pyramid-pooled (``PPM``) summaries.

    Queries keep full resolution (one per pixel), while keys/values are reduced
    to a fixed number of pooled tokens. The attention matrix is therefore
    ``(h * w) x n_pooled`` instead of ``(h * w) x (h * w)``.

    Args:
        dim: Channels of the input map and of the output tokens.
        heads: Number of attention heads.
        dim_head: Channels per head.
        dropout: Dropout after the output projection.
        pooling_sizes: Grid sizes used by ``PPM`` for keys/values.

    ``forward(x)`` takes ``(b, dim, h, w)`` and returns ``(b, h * w, dim)``.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0., pooling_sizes=(1, 3, 5)):
        super().__init__()
        inner_dim = dim_head * heads
        # The output projection is only redundant when one head already has
        # exactly `dim` channels.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1, stride=1, padding=0, bias=False)
        self.ppm = PPM(pooling_sizes=pooling_sizes)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        ) if project_out else nn.Identity()

    def forward(self, x):
        q, k, v = self.to_qkv(x).chunk(3, dim=1)
        k, v = self.ppm(k), self.ppm(v)
        out = _pooled_kv_attention(q, k, v, self.heads, self.scale, self.attend)
        return self.to_out(out)


class ESA_blcok(nn.Module):
    """Transformer-style block built on ``ESA_layer`` for ``(b, dim, h, w)`` maps.

    The flattened tokens receive the ESA attention output and a pre-normed
    ``FeedForward``, each as a residual update. The tokens are reshaped back to
    ``(b, dim, h, w)``, and the input ``x`` is added once more.

    The historical misspelling of the class name is kept for compatibility;
    ``ESA_block`` is an alias.
    """

    def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512, dropout=0., pooling_sizes=(1, 3, 5)):
        super().__init__()
        self.ESAlayer = ESA_layer(
            dim, heads=heads, dim_head=dim_head, dropout=dropout, pooling_sizes=pooling_sizes
        )
        self.ff = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))

    def forward(self, x):
        _, _, h, _ = x.shape
        out = rearrange(x, 'b c h w -> b (h w) c')
        # Only the feed-forward branch is pre-normalised; the attention branch
        # sees the raw input.
        out = self.ESAlayer(x) + out
        out = self.ff(out) + out
        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        # NOTE(review): `out` already contains `x` through the token residual
        # above, so `x` is effectively added twice here (LCA_blcok does not do
        # this). Kept unchanged so trained weights behave identically.
        return out + x


def MaskAveragePooling(x, mask, eps=0.0005):
    """Mask-weighted spatial average of a feature map.

    ``mask`` is passed through a sigmoid, so raw logits are expected. It is then
    used as per-pixel weights: ``sum(x * m) / (sum(m) + eps)`` over the last two
    (spatial) axes.

    Args:
        x: Feature map of shape ``(b, c, h, w)``.
        mask: Mask logits broadcastable against ``x`` (e.g. ``(b, 1, h, w)``).
        eps: Added to the mask area so an all-zero mask cannot divide by zero.

    Returns:
        Tensor of shape ``(b, c, 1)``: one pooled token per channel.
    """
    mask = torch.sigmoid(mask)
    b, c, _, _ = x.shape
    area = mask.sum(dim=(-2, -1), keepdim=True) + eps
    x_feat = (x * mask).sum(dim=(-2, -1), keepdim=True) / area
    return x_feat.view(b, c, -1)


class LCA_layer(nn.Module):
    """Attention whose keys/values are mask-weighted averages of the feature map.

    Args:
        dim: Channels of the input map and of the output tokens.
        heads: Number of attention heads.
        dim_head: Channels per head.
        dropout: Dropout after the output projection.

    ``forward(x, mask)`` takes ``x`` of shape ``(b, dim, h, w)`` plus mask
    logits (see ``MaskAveragePooling``) and returns ``(b, h * w, dim)``.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # The output projection is only redundant when one head already has
        # exactly `dim` channels.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1, stride=1, padding=0, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        ) if project_out else nn.Identity()

    def forward(self, x, mask):
        q, k, v = self.to_qkv(x).chunk(3, dim=1)
        k, v = MaskAveragePooling(k, mask), MaskAveragePooling(v, mask)
        # NOTE(review): MaskAveragePooling yields a single key/value token, so
        # the softmax runs over one element and is identically 1. Every query
        # therefore receives the same pooled value, and q/k do not affect the
        # output. Confirm this is intended.
        out = _pooled_kv_attention(q, k, v, self.heads, self.scale, self.attend)
        return self.to_out(out)


class LCA_blcok(nn.Module):
    """Transformer-style block built on ``LCA_layer`` for ``(b, dim, h, w)`` maps.

    The flattened tokens receive the mask-guided attention output and a
    pre-normed ``FeedForward``, each as a residual update. They are then
    reshaped back to ``(b, dim, h, w)``.

    The historical misspelling of the class name is kept for compatibility;
    ``LCA_block`` is an alias.
    """

    def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512, dropout=0.):
        super().__init__()
        self.LCAlayer = LCA_layer(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.ff = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))

    def forward(self, x, mask):
        _, _, h, _ = x.shape
        out = rearrange(x, 'b c h w -> b (h w) c')
        out = self.LCAlayer(x, mask) + out
        out = self.ff(out) + out
        return rearrange(out, 'b (h w) c -> b c h w', h=h)


# Correctly spelled aliases; the original names remain for existing callers.
ESA_block = ESA_blcok
LCA_block = LCA_blcok


if __name__ == '__main__':
    x = torch.rand((4, 3, 320, 320))
    mask = torch.rand(4, 1, 320, 320)
    lca = LCA_blcok(dim=3)
    esa = ESA_blcok(dim=3)
    # Shape smoke test only. With 8 heads x 64 channels over 320*320 queries
    # the activations are very large, so do not keep them alive for autograd.
    with torch.no_grad():
        print(lca(x, mask).shape)
        print(esa(x).shape)