Constructor __init__
def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):
- The initializer defines the module's main structure and parameters.
  - channel: number of channels in the input feature map.
  - dim: feature dimension used inside the Transformer branch.
  - depth: number of Transformer layers.
  - kernel_size: kernel size of the n x n convolutions.
  - patch_size: size of the patches the feature map is split into.
  - mlp_dim: hidden dimension of the feed-forward network inside the Transformer.
  - dropout: dropout rate, used for regularization.
Layer definitions
self.mv01 = IRBlock(channel, channel)
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
self.conv3 = conv_1x1_bn(dim, channel)
self.conv2 = conv_1x1_bn(channel, dim)
self.transformer = UserDefined(dim, depth, 4, 8, mlp_dim, dropout)
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
- IRBlock, conv_nxn_bn, and conv_1x1_bn handle feature extraction and dimension changes (a hedged sketch of IRBlock follows the helper definitions below).
- UserDefined is the Transformer-based structure introduced earlier, used to process the patch sequences; the positional arguments 4 and 8 presumably set the number of attention heads and the per-head dimension.
- Combining these layers exploits the CNN's spatial feature extraction and the Transformer's sequence-modeling ability.
def conv_1x1_bn(inp, oup):
    # 1x1 pointwise convolution + BatchNorm + SiLU, used for channel projection
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )

def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    # n x n convolution + BatchNorm + SiLU; padding=1 preserves the spatial size only for kernel_size=3
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )
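IRBlock itself is not defined in this section. The following is a minimal sketch only, assuming a standard MobileNetV2-style inverted residual (1x1 expand, 3x3 depthwise, 1x1 project, with an identity skip when shapes match); the expansion factor of 4 is a hypothetical choice and the real IRBlock may differ:

import torch.nn as nn

class IRBlock(nn.Module):
    # Sketch only: MobileNetV2-style inverted residual; not the article's actual IRBlock.
    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        hidden = inp * expansion
        self.use_res = stride == 1 and inp == oup   # identity skip only when shapes match
        self.block = nn.Sequential(
            nn.Conv2d(inp, hidden, 1, 1, 0, bias=False),   # 1x1 expand
            nn.BatchNorm2d(hidden),
            nn.SiLU(),
            nn.Conv2d(hidden, hidden, 3, stride, 1, groups=hidden, bias=False),  # 3x3 depthwise
            nn.BatchNorm2d(hidden),
            nn.SiLU(),
            nn.Conv2d(hidden, oup, 1, 1, 0, bias=False),   # 1x1 project
            nn.BatchNorm2d(oup),
        )

    def forward(self, x):
        return x + self.block(x) if self.use_res else self.block(x)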
Forward pass forward
def forward(self, x):
    y = x.clone()
    x = self.conv1(x)
    x = self.conv2(x)
    z = x.clone()
    _, _, h, w = x.shape
    x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
    x = self.transformer(x)
    x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
    x = self.conv3(x)
    x = torch.cat((x, z), 1)
    x = self.conv4(x)
    x = x + y
    x = self.mv01(x)
    return x
The forward method defines how data flows through the module.
- The input x first passes through the convolution layers conv1 and conv2 for feature extraction and dimension changes.
- The feature map is then rearranged (rearrange) into patch sequences ready for the Transformer.
- After the Transformer processes them, the sequences are rearranged back into the original spatial layout.
- After further convolution (conv3, concatenation with z, conv4), a residual connection adds the original input, and the result passes through the IRBlock. A shape walk-through of the two rearrange calls follows this list.
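To make the two rearrange calls concrete, here is a small shape walk-through; the sizes (batch 1, dim 64, an 8x8 map, 2x2 patches) are illustrative choices, not values from the article:

import torch
from einops import rearrange

x = torch.randn(1, 64, 8, 8)   # (b, d, H, W): dim=64, 8x8 feature map
ph = pw = 2                    # 2x2 patches

# Unfold: each patch contributes one pixel to each of the ph*pw sequences.
seq = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=ph, pw=pw)
print(seq.shape)   # torch.Size([1, 4, 16, 64]): 4 pixels per patch, 16 patches

# Fold: the inverse pattern restores the feature-map layout exactly.
back = rearrange(seq, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
                 h=8 // ph, w=8 // pw, ph=ph, pw=pw)
print(back.shape)  # torch.Size([1, 64, 8, 8])
assert torch.equal(x, back)    # pure permute/reshape, so the round trip is exact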
Full code:
import torch
import torch.nn as nn
from einops import rearrange

class MobileViTBv3(nn.Module):
    def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size
        self.mv01 = IRBlock(channel, channel)
        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv2 = conv_1x1_bn(channel, dim)
        self.transformer = UserDefined(dim, depth, 4, 8, mlp_dim, dropout)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()                    # keep the input for the outer residual
        x = self.conv1(x)                # n x n conv: local spatial features
        x = self.conv2(x)                # 1x1 conv: project channel -> dim
        z = x.clone()                    # keep the projection for later fusion
        _, _, h, w = x.shape
        # unfold the feature map into patch sequences for the Transformer
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        # fold the sequences back into the feature-map layout
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
        x = self.conv3(x)                # 1x1 conv: project dim -> channel
        x = torch.cat((x, z), 1)         # fuse local and global features
        x = self.conv4(x)                # n x n conv on the fused features
        x = x + y                        # outer residual connection
        x = self.mv01(x)                 # final inverted-residual block
        return x
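Note that conv4 expects 2 * channel input channels while the concatenation produces channel + dim, so the code as written requires dim == channel. A minimal smoke test under that constraint, assuming IRBlock, conv_1x1_bn, conv_nxn_bn, and UserDefined are defined as above and that the Transformer preserves the sequence shape; the sizes are illustrative:

model = MobileViTBv3(channel=64, dim=64)
x = torch.randn(1, 64, 32, 32)   # H and W must be divisible by the patch size
out = model(x)
print(out.shape)                  # torch.Size([1, 64, 32, 32]): channels and spatial size preserved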