bug:
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
源代码如下:
if __name__ == "__main__":from torchsummary import summarymodel = UNet()print(model)summary(model, input_size=(1, 480, 480))
在使用torchsummary可视化模型时候报错,报这个错误是因为类型不匹配,根据报错内容可以看出Input type为torch.FloatTensor(CPU数据类型),而weight type(即网络权重参数这些)为torch.cuda.FloatTensor(GPU数据类型)。
我们将model传到GPU上便可。将代码如下修改便可正常运行:
if __name__ == "__main__":from torchsummary import summarydevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = UNet().to(device)# modifyprint(model)summary(model, input_size=(1, 480, 480))