Table of Contents
- model.py
- main.py
- Parameter settings
- Notes
- Weight initialization
- If loss and acc do not change
- About data download
- About the output format
- Run screenshot
model.py
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class MLP_cls(nn.Module):
    def __init__(self, in_dim=28*28):
        super(MLP_cls, self).__init__()
        self.lin1 = nn.Linear(in_dim, 128)
        self.lin2 = nn.Linear(128, 64)
        self.lin3 = nn.Linear(64, 10)
        self.relu = nn.ReLU()
        # Xavier-uniform initialization of each linear layer's weights
        init.xavier_uniform_(self.lin1.weight)
        init.xavier_uniform_(self.lin2.weight)
        init.xavier_uniform_(self.lin3.weight)

    def forward(self, x):
        x = x.view(-1, 28*28)  # flatten each 28x28 image into a 784-dim vector
        x = self.lin1(x)
        x = self.relu(x)
        x = self.lin2(x)
        x = self.relu(x)
        x = self.lin3(x)
        x = self.relu(x)
        return x
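As a quick sanity check (a small sketch of my own, not part of the original files), the model can be fed a dummy batch to confirm that it maps 28x28 images to 10 class scores:

import torch
from model import MLP_cls

net = MLP_cls()
dummy = torch.randn(4, 1, 28, 28)  # a fake batch of 4 grayscale MNIST-sized images
out = net(dummy)
print(out.shape)  # torch.Size([4, 10]) -- one score per digit class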
main.py
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from model import MLP_cls

seed = 42
torch.manual_seed(seed)
batch_size_train = 64
batch_size_test = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5

mlp_net = MLP_cls()

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.5,), (0.5,))
                               ])),
    batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.5,), (0.5,))
                               ])),
    batch_size=batch_size_test, shuffle=True)

optimizer = optim.SGD(mlp_net.parameters(), lr=learning_rate, momentum=momentum)
criterion = nn.CrossEntropyLoss()

print("****************Begin Training****************")
mlp_net.train()
for epoch in range(epochs):
    run_loss = 0
    correct_num = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        out = mlp_net(data)
        _, pred = torch.max(out, dim=1)
        optimizer.zero_grad()
        loss = criterion(out, target)
        loss.backward()
        run_loss += loss
        optimizer.step()
        correct_num += torch.sum(pred == target)
    print('epoch', epoch,
          'loss {:.2f}'.format(run_loss.item() / len(train_loader)),
          'accuracy {:.2f}'.format(correct_num.item() / (len(train_loader) * batch_size_train)))

print("****************Begin Testing****************")
mlp_net.eval()
test_loss = 0
test_correct_num = 0
for batch_idx, (data, target) in enumerate(test_loader):
    out = mlp_net(data)
    _, pred = torch.max(out, dim=1)
    test_loss += criterion(out, target)
    test_correct_num += torch.sum(pred == target)
print('loss {:.2f}'.format(test_loss.item() / len(test_loader)),
      'accuracy {:.2f}'.format(test_correct_num.item() / (len(test_loader) * batch_size_test)))
Parameter settings
'./data/'              # path where the data is saved
seed = 42              # random seed
batch_size_train = 64
batch_size_test = 64
epochs = 10
optimizer: SGD
learning_rate = 0.01
momentum = 0.5
Notes
Weight initialization
The weights are initialized as follows:
init.xavier_uniform_(self.lin1.weight)
init.xavier_uniform_(self.lin2.weight)
init.xavier_uniform_(self.lin3.weight)
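If you do not want to list every layer by hand, an equivalent sketch (my own variant, assuming the same three nn.Linear layers; the helper name init_weights is made up here) uses Module.apply to visit each submodule:

import torch.nn as nn
import torch.nn.init as init

def init_weights(m):
    # Xavier-uniform init for the weight of every Linear layer
    if isinstance(m, nn.Linear):
        init.xavier_uniform_(m.weight)

# inside MLP_cls.__init__, after the layers are created:
# self.apply(init_weights)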
If loss and acc do not change
Check whether you forgot to call optimizer.step().
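For reference, here is a minimal sketch of the per-batch update order used in main.py above (assuming mlp_net, optimizer, criterion, and train_loader are defined as shown there); without optimizer.step() the weights never change, so loss and accuracy stay flat:

for data, target in train_loader:
    out = mlp_net(data)            # forward pass
    loss = criterion(out, target)  # compute the loss
    optimizer.zero_grad()          # clear gradients from the previous batch
    loss.backward()                # backpropagate
    optimizer.step()               # update the weights -- the easy step to forget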
About data download
With download=True, the data is downloaded into the ./data folder.
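For example, this is a trimmed-down version of the dataset call from main.py; with download=True, torchvision only downloads the files if they are not already present under ./data/:

import torchvision

train_set = torchvision.datasets.MNIST(
    './data/',       # root folder where MNIST is stored
    train=True,
    download=True,   # skipped automatically if the files already exist
    transform=torchvision.transforms.ToTensor())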
About the output format
Here 'xxx {:.2f}'.format(xxx) is used to keep two decimal places. Note the space in the middle, and be careful to distinguish the format-string :.2f from the %-style %.2f.
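A quick illustration of the two syntaxes (both lines print accuracy 0.33):

value = 1 / 3
print('accuracy {:.2f}'.format(value))  # str.format style: {:.2f}
print('accuracy %.2f' % value)          # old %-style: %.2f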