使用 PyTorch 搭建深度学习模型后,如果想查看模型结构,可以直接用 print(model) 打印。但该输出结果不是特别直观,查阅后发现有一个能输出类似 Keras 风格 model.summary() 的模型可视化工具(torchsummary)。这里记录一下,方便以后查阅。
PyTorch 打印模型结构、输出维度和参数信息(torchsummary)
- 安装 torchsummary
- 输出网络信息
- AttributeError: ‘tuple’ object has no attribute ‘size’
安装 torchsummary
pip install torchsummary
输出网络信息
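summary 函数介绍:summary(model, input_size, batch_size=-1, device="cuda")
- model:网络模型
- input_size:网络输入的 shape(例如图片的 (C, H, W)),这里不用把 batch_size 加进去
- batch_size:输出表格中显示的 batch 维度,默认是 -1
- device:在 GPU 还是 CPU 上运行,默认是 "cuda"(在 GPU 上运行);如果想在 CPU 上执行,将参数改为 "cpu" 即可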
import torch
import torch.nn as nn


class Shallow_ConvNet(nn.Module):
    """Shallow ConvNet classifier built from 2-D convolutions.

    Pipeline (see ``forward``): temporal conv -> spatial conv -> BatchNorm ->
    square -> average pooling -> safe log -> dropout -> class conv -> softmax.

    Args:
        in_channel: input feature maps seen by ``temp_conv``. ``forward``
            inserts a size-1 axis at dim 1 for non-4-D input, so 1 matches
            that case.
        conv_channel_temp: output channels of the temporal convolution.
        kernel_size_temp: kernel width of the temporal convolution, applied
            as ``(1, kernel_size_temp)``.
        conv_channel_spat: output channels of the spatial convolution.
        kernel_size_spat: kernel height of the spatial convolution, applied
            as ``(kernel_size_spat, 1)``.
        pooling_size: width of the average-pooling window.
        pool_stride_size: horizontal stride of the average pooling.
        dropoutRate: dropout probability.
        n_classes: number of output channels of ``class_conv``.
        class_kernel_size: kernel width of ``class_conv``.
    """

    def __init__(self, in_channel, conv_channel_temp, kernel_size_temp,
                 conv_channel_spat, kernel_size_spat, pooling_size,
                 pool_stride_size, dropoutRate, n_classes, class_kernel_size):
        super().__init__()
        # Convolution along the last (time) axis only.
        self.temp_conv = nn.Conv2d(
            in_channels=in_channel,
            out_channels=conv_channel_temp,
            kernel_size=(1, kernel_size_temp),
            stride=1,
            bias=False,
        )
        # Convolution along the second-to-last axis only.
        self.spat_conv = nn.Conv2d(
            in_channels=conv_channel_temp,
            out_channels=conv_channel_spat,
            kernel_size=(kernel_size_spat, 1),
            stride=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(num_features=conv_channel_spat)
        # The squaring activation (x * x) is applied inline in forward().
        self.pooling = nn.AvgPool2d(
            kernel_size=(1, pooling_size), stride=(1, pool_stride_size)
        )
        # The post-pooling log activation is safe_log(), applied in forward().
        self.dropout = nn.Dropout(p=dropoutRate)
        self.class_conv = nn.Conv2d(
            in_channels=conv_channel_spat,
            out_channels=n_classes,
            kernel_size=(1, class_kernel_size),
            bias=False,
        )
        # NOTE(review): forward() returns probabilities; nn.CrossEntropyLoss
        # expects raw logits, so confirm which loss this output is trained with.
        self.softmax = nn.Softmax(dim=1)

    def safe_log(self, x, eps=1e-6):
        """Return ``log(max(x, eps))`` element-wise, preventing ``log(0)``."""
        return torch.log(torch.clamp(x, min=eps))

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: 4-D input, or a tensor of another rank (e.g. ``(batch, C, T)``)
                that gets a size-1 axis inserted at dim 1.

        Returns:
            The softmax output with only the trailing size-1 spatial axes
            removed: ``(batch, n_classes)`` when ``class_conv`` yields a
            1x1 map, as it does for the module-level parameters below. The
            batch axis is kept even when ``batch == 1``.
        """
        # `!=` instead of `is not`: identity comparison with an int literal
        # is implementation-dependent and a SyntaxWarning on Python 3.8+.
        if x.dim() != 4:
            x = torch.unsqueeze(x, 1)
        x = self.temp_conv(x)
        x = self.spat_conv(x)
        x = self.bn(x)
        x = x * x  # conv activation: square
        x = self.pooling(x)
        x = self.safe_log(x)  # pool activation: log(max(x, eps))
        x = self.dropout(x)
        x = self.class_conv(x)
        x = self.softmax(x)
        # Squeeze only dims 3 and 2 (each is a no-op if that axis is not 1).
        # A bare squeeze() would also drop the batch axis for batch_size == 1
        # and the class axis for n_classes == 1.
        return x.squeeze(3).squeeze(2)


###============================ Initialization parameters ============================###
channels = 44
samples = 534
in_channel = 1
conv_channel_temp = 40
kernel_size_temp = 25
conv_channel_spat = 40
kernel_size_spat = channels
pooling_size = 75
pool_stride_size = 15
dropoutRate = 0.3
n_classes = 4
class_kernel_size = 30


def main():
    """Run one forward pass and print the model with torchsummary."""
    # Imported here: torchsummary is only needed for this printout, so the
    # model class above stays importable without the package installed.
    from torchsummary import summary

    x = torch.randn(32, 1, channels, samples)
    model = Shallow_ConvNet(in_channel, conv_channel_temp, kernel_size_temp,
                            conv_channel_spat, kernel_size_spat,
                            pooling_size, pool_stride_size,
                            dropoutRate, n_classes, class_kernel_size)
    out = model(x)
    print('===============================================================')
    print('out', out.shape)
    print('model', model)
    # torchsummary's input_size excludes the batch dimension.
    summary(model=model, input_size=(1, channels, samples), batch_size=32, device="cpu")


if __name__ == "__main__":
    main()
输出:
out torch.Size([32, 4])
model Shallow_ConvNet(
  (temp_conv): Conv2d(1, 40, kernel_size=(1, 25), stride=(1, 1), bias=False)
  (spat_conv): Conv2d(40, 40, kernel_size=(44, 1), stride=(1, 1), bias=False)
  (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pooling): AvgPool2d(kernel_size=(1, 75), stride=(1, 15), padding=0)
  (dropout): Dropout(p=0.3, inplace=False)
  (class_conv): Conv2d(40, 4, kernel_size=(1, 30), stride=(1, 1), bias=False)
  (softmax): Softmax(dim=1)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [32, 40, 44, 510]           1,000
            Conv2d-2           [32, 40, 1, 510]          70,400
       BatchNorm2d-3           [32, 40, 1, 510]              80
         AvgPool2d-4            [32, 40, 1, 30]               0
           Dropout-5            [32, 40, 1, 30]               0
            Conv2d-6              [32, 4, 1, 1]           4,800
           Softmax-7              [32, 4, 1, 1]               0
================================================================
Total params: 76,280
Trainable params: 76,280
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 2.87
Forward/backward pass size (MB): 229.69
Params size (MB): 0.29
Estimated Total Size (MB): 232.85
----------------------------------------------------------------
AttributeError: ‘tuple’ object has no attribute ‘size’
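旧的 torchsummary 在模型中含有 LSTM 之类的层时会报上面的错误:LSTM 的返回值 (output, (h_n, c_n)) 中嵌套了一个 tuple,torchsummary 对其调用 .size() 时出错。需要改用新的 torchinfo 中的 summary。注意:与 torchsummary 不同,torchinfo 的 input_size 需要包含 batch 维度。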
pip install torchinfo
import torch
from torchinfo import summary


# NOTE: Cascade_Conv_LSTM and the hyper-parameters used in main()
# (window_size, channels, samples, in_channel, ...) are defined elsewhere in
# the original script; they are not shown in this snippet.
def main():
    """Run one forward pass and print the model structure with torchinfo."""
    x = torch.randn(32, window_size, channels, samples)
    model = Cascade_Conv_LSTM(in_channel, out_channel_conv1, out_channel_conv2, out_channel_conv3,
                              kernel_conv123, stride_conv123, padding_conv123,
                              fc1_in, fc1_out, dropoutRate1,
                              lstm1_in, lstm1_hidden, lstm1_layer,
                              lstm2_in, lstm2_hidden, lstm2_layer,
                              fc2_in, fc2_out, dropoutRate2,
                              fc3_in, n_classes)
    # To run on a GPU, move the model and the input to the same device, e.g.:
    # model = model.to('cuda:1')
    # x = x.to('cuda:1')
    out = model(x)
    print('===============================================================')
    print('out', out.shape)
    print('model', model)
    # torchinfo's input_size includes the batch dimension. Taking it from the
    # tensor actually fed to the model keeps the summary in sync with the
    # forward pass; a hard-coded window length would silently diverge
    # whenever window_size changes.
    summary(model=model, input_size=tuple(x.shape), device="cpu")


if __name__ == "__main__":
    main()
==========================================================================================
Layer (type:depth-idx)                  Output Shape             Param #
==========================================================================================
Cascade_Conv_LSTM                       [32, 4]                  --
├─Sequential: 1-1                       [320, 32, 10, 11]        --
│    └─Conv2d: 2-1                      [320, 32, 10, 11]        288
│    └─ELU: 2-2                         [320, 32, 10, 11]        --
├─Sequential: 1-2                       [320, 64, 10, 11]        --
│    └─Conv2d: 2-3                      [320, 64, 10, 11]        18,432
│    └─ELU: 2-4                         [320, 64, 10, 11]        --
├─Sequential: 1-3                       [320, 128, 10, 11]       --
│    └─Conv2d: 2-5                      [320, 128, 10, 11]       73,728
│    └─ELU: 2-6                         [320, 128, 10, 11]       --
├─Sequential: 1-4                       [320, 1024]              --
│    └─Linear: 2-7                      [320, 1024]              14,418,944
│    └─ELU: 2-8                         [320, 1024]              --
├─Dropout: 1-5                          [320, 1024]              --
├─LSTM: 1-6                             [32, 10, 1024]           8,396,800
├─LSTM: 1-7                             [32, 10, 1024]           8,396,800
├─Sequential: 1-8                       [32, 1024]               --
│    └─Linear: 2-9                      [32, 1024]               1,049,600
│    └─ELU: 2-10                        [32, 1024]               --
├─Dropout: 1-9                          [32, 1024]               --
├─Linear: 1-10                          [32, 4]                  4,100
├─Softmax: 1-11                         [32, 4]                  --
==========================================================================================
Total params: 32,358,692
Trainable params: 32,358,692
Non-trainable params: 0
Total mult-adds (G): 13.28
==========================================================================================
Input size (MB): 0.14
Forward/backward pass size (MB): 71.21
Params size (MB): 129.43
Estimated Total Size (MB): 200.78
==========================================================================================