[PyTorch] Vision Transformer for Image Classification + Visualization + Saving Training Data

I. Introduction to the Vision Transformer

The core of the Transformer is the self-attention mechanism.

Paper: https://arxiv.org/pdf/2010.11929.pdf

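Compared with convolutional neural networks (CNNs) and recurrent neural networks (RNNs), self-attention offers two advantages at once: parallel computation and the shortest maximum path length. This makes it an attractive building block for deep architectures. Earlier self-attention models still relied on RNNs to produce the input representations [Cheng et al., 2016, Lin et al., 2017b, Paulus et al., 2017]. The Transformer, by contrast, is built entirely on attention, with no convolutional or recurrent layers [Vaswani et al., 2017]. Although it was originally applied to sequence-to-sequence learning on text, it has since spread across modern deep learning, including language, vision, speech, and reinforcement learning.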

[Figure 1]
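To make the two advantages above concrete, here is a minimal single-head self-attention sketch in NumPy. It is an illustration, not the code used later in vit_model.py: the sizes are toy values, the projection matrices are random, and the name `self_attention` is our own. All n×n scores come out of a single matrix product, so every position is processed in parallel. Every token is also connected to every other token in one step, which is why the maximum path length is O(1).

```python
import numpy as np


def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over a sequence x of shape [n, d]."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v                       # each [n, d]
    scores = q @ k.T / np.sqrt(k.shape[-1])                   # [n, n]: every token scores every token
    scores = scores - scores.max(axis=-1, keepdims=True)      # subtract row max for a stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ v, weights                               # outputs [n, d], attention map [n, n]


rng = np.random.default_rng(0)
n, d = 5, 8                                   # 5 tokens with 8-dim embeddings (toy sizes)
x = rng.standard_normal((n, d))
w_q, w_k, w_v = (rng.standard_normal((d, d)) for _ in range(3))
out, attn = self_attention(x, w_q, w_k, w_v)
print(out.shape, attn.shape)                  # (5, 8) (5, 5)
```

The Attention class in vit_model.py computes the same quantity, but it splits the embedding into `num_heads` heads and runs the product for all heads at once.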

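When it was published in 2017, the Transformer was used mainly for machine translation between languages. Later research found that a Transformer applied to computer vision (CV) can match convolutional neural networks and, when pre-trained on sufficiently large datasets, even outperform them. This led to the original Vision Transformer, ViT for short.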

[Figure 2]

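How does the Vision Transformer differ from the Transformer? Put as simply as possible, the Transformer's job is to translate a sentence from one language into another. It splits the sentence into multiple words (tokens) and runs them through an encoder and a decoder. At each output position, the candidate word with the highest score becomes the translation.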

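The Vision Transformer instead treats an image as the counterpart of a sentence. It cuts the image into fixed-size patches, treats each patch as a token, and feeds the sequence through the Transformer encoder. For classification, ViT uses only the encoder and no decoder. A classification head on an extra [class] token produces a score for every class, and the highest-scoring class is the prediction. (This is my own understanding; corrections are welcome.)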

[Figure 3]
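The "image as a sentence" idea can be made concrete with a short NumPy sketch. This is an illustration under our own assumptions, not the author's code, and the helper name `image_to_patches` is hypothetical. It cuts a 224×224×3 image into 16×16 patches and flattens each one, which gives 196 tokens of 16·16·3 = 768 values each. PatchEmbed in vit_model.py performs the same split using a Conv2d whose kernel size and stride both equal the patch size, and it also projects each patch to `embed_dim` with learned weights. Prepending the [class] token then gives the 197 tokens noted in `forward_features`.

```python
import numpy as np


def image_to_patches(img, patch_size=16):
    """Split an [H, W, C] image into non-overlapping patches, one flattened patch per row."""
    h, w, c = img.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image size must be divisible by patch size"
    gh, gw = h // patch_size, w // patch_size
    patches = img.reshape(gh, patch_size, gw, patch_size, c)   # split rows and columns into grid cells
    patches = patches.transpose(0, 2, 1, 3, 4)                 # -> [gh, gw, p, p, c]
    return patches.reshape(gh * gw, patch_size * patch_size * c)  # -> [num_patches, p*p*c]


img = np.random.rand(224, 224, 3).astype(np.float32)
tokens = image_to_patches(img)
print(tokens.shape)  # (196, 768): 14 x 14 patches, each with 16*16*3 values
```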

II. Dataset

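My dataset consists of plant leaf disease images in three classes. It has no separate annotation files: each image's label is the name of the folder it is stored in.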

{"0": "Huanglong_disease","1": "Magnesium_deficiency","2": "Normal"}

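The data is split train : val : test = 8 : 1 : 1. All three subsets contain the same three classes; only the number of images differs. Note that train.py reads only the train folder. read_split_data then holds out a random 20% of it for validation (val_rate=0.2), so the separate val and test folders are not used during training.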

train
├── Huanglong_disease
│   ├── 000000.jpg
│   ├── 000001.jpg
│   ├── 000002.jpg
│   ├── .............
│   └── 000607.jpg
├── Magnesium_deficiency
└── Normal
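read_split_data in utils.py turns this folder layout into the class mapping shown earlier. The sketch below is a stripped-down version of that step. The function name `build_class_indices` is hypothetical, and it takes a list of folder names instead of scanning a directory.

```python
import json


def build_class_indices(class_folders):
    """Mirror read_split_data: sort folder names, number them, and invert for class_indices.json."""
    classes = sorted(class_folders)                       # sorting keeps the index order reproducible
    class_indices = {name: idx for idx, name in enumerate(classes)}
    # JSON object keys are always strings, so the indices become "0", "1", ...
    return json.dumps({idx: name for name, idx in class_indices.items()}, indent=4)


text = build_class_indices(["Normal", "Huanglong_disease", "Magnesium_deficiency"])
index_to_name = json.loads(text)
print(index_to_name)  # {'0': 'Huanglong_disease', '1': 'Magnesium_deficiency', '2': 'Normal'}
```

Because the keys come back as strings, predict.py has to look classes up with `class_indict[str(index)]`.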

Sample images look roughly like this:

[Figure 4]
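The scripts themselves do not create the 8:1:1 split described above; the train/val/test folders have to be prepared beforehand. One way to split a single class is sketched below with the standard library. `split_811` is a hypothetical helper, and the file names are generated stand-ins.

```python
import random


def split_811(paths, seed=0):
    """Hypothetical helper: shuffle one class's image paths and cut them 8:1:1 into train/val/test."""
    paths = sorted(paths)                   # fixed starting order so the seed fully determines the split
    random.Random(seed).shuffle(paths)      # local RNG; the global random state is left untouched
    n_val = n_test = len(paths) // 10
    n_train = len(paths) - n_val - n_test   # the remainder goes to train
    return paths[:n_train], paths[n_train:n_train + n_val], paths[n_train + n_val:]


files = ["{:06d}.jpg".format(i) for i in range(608)]  # stand-ins: 000000.jpg ... 000607.jpg
train, val, test = split_811(files)
print(len(train), len(val), len(test))  # 488 60 60
```

Running the split per class keeps the class proportions the same in all three subsets. Copying the files into the train/, val/ and test/ folders is left out.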

III. Code

1.vit_model.py

"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x


class Attention(nn.Module):
    def __init__(self,
                 dim,   # dim of the input tokens
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = x.shape

        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during training, return both classifier outputs
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x


def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


def vit_base_patch16_224(num_classes: int = 1000):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    Link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  Password: eu9f
    """
    model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
                              representation_size=None, num_classes=num_classes)
    return model


def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
    """
    model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
                              representation_size=768 if has_logits else None, num_classes=num_classes)
    return model


def vit_base_patch32_224(num_classes: int = 1000):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    Link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  Password: s5hl
    """
    model = VisionTransformer(img_size=224, patch_size=32, embed_dim=768, depth=12, num_heads=12,
                              representation_size=None, num_classes=num_classes)
    return model


def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
    """
    model = VisionTransformer(img_size=224, patch_size=32, embed_dim=768, depth=12, num_heads=12,
                              representation_size=768 if has_logits else None, num_classes=num_classes)
    return model


def vit_large_patch16_224(num_classes: int = 1000):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    Link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  Password: qqt8
    """
    model = VisionTransformer(img_size=224, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
                              representation_size=None, num_classes=num_classes)
    return model


def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
    """
    model = VisionTransformer(img_size=224, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
                              representation_size=1024 if has_logits else None, num_classes=num_classes)
    return model


def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
    """
    model = VisionTransformer(img_size=224, patch_size=32, embed_dim=1024, depth=24, num_heads=16,
                              representation_size=1024 if has_logits else None, num_classes=num_classes)
    return model


def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: converted weights not currently available, too large for github release hosting.
    """
    model = VisionTransformer(img_size=224, patch_size=14, embed_dim=1280, depth=32, num_heads=16,
                              representation_size=1280 if has_logits else None, num_classes=num_classes)
    return model
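The `drop_path` function above is easiest to understand on concrete numbers. Below is a NumPy re-statement of it. The name `drop_path_np` is hypothetical, and it takes an explicit NumPy Generator instead of calling `torch.rand`. It shows that each sample in the batch is either zeroed entirely or scaled by 1/keep_prob, so the expected value of the residual branch is unchanged.

```python
import numpy as np


def drop_path_np(x, drop_prob, rng, training=True):
    """NumPy re-statement of drop_path: zero whole samples, rescale the survivors by 1/keep_prob."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)      # one random draw per sample, broadcast over the rest
    mask = np.floor(keep_prob + rng.random(shape))   # 1 with probability keep_prob, otherwise 0
    return x / keep_prob * mask


rng = np.random.default_rng(0)
x = np.ones((20000, 4, 3))          # [batch, tokens, dim], all ones so the effect is easy to read
y = drop_path_np(x, drop_prob=0.1, rng=rng)
per_sample = y.reshape(len(y), -1)
print(np.unique(per_sample))        # each sample is either all 0 or all 1/0.9
print(y.mean())                     # close to 1.0: the expected value is preserved
```

In VisionTransformer the drop probability grows linearly across blocks (`torch.linspace(0, drop_path_ratio, depth)`). With the default `drop_path_ratio=0.`, every Block uses `nn.Identity` instead of DropPath.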

2.utils.py

import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # make the random split reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # iterate over the folders; each folder corresponds to one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # sort so that the class order is always the same
    flower_class.sort()
    # map each class name to a numeric index
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []   # paths of all training images
    train_images_label = []  # class indices of the training images
    val_images_path = []     # paths of all validation images
    val_images_label = []    # class indices of the validation images
    every_class_num = []     # total number of samples in each class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
    # iterate over the files in each class folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # collect the paths of all files with a supported extension
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # index of this class
        image_class = class_indices[cla]
        # record the number of samples in this class
        every_class_num.append(len(images))
        # randomly sample validation images according to val_rate
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # sampled paths go to the validation set
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # everything else goes to the training set
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = False
    if plot_image:
        # bar chart of the number of images per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # replace the x ticks 0, 1, 2, ... with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # add a value label on top of each bar
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # x-axis label
        plt.xlabel('image class')
        # y-axis label
        plt.ylabel('number of images')
        # chart title
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    with open(json_path, 'r') as json_file:
        class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # undo Normalize; mean/std must match the transform in train.py ([0.5]*3, [0.5]*3)
            img = (img * [0.5, 0.5, 0.5] + [0.5, 0.5, 0.5]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # hide x-axis ticks
            plt.yticks([])  # hide y-axis ticks
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
        return info_list


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated number of correctly predicted samples
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num


@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)   # accumulated number of correctly predicted samples
    accu_loss = torch.zeros(1).to(device)  # accumulated loss

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
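plot_data_loader_image undoes Normalize before drawing. The inverse restores the original pixels only if it uses the same mean and std as the training transform, which is 0.5 for every channel in train.py. The NumPy check below illustrates that arithmetic. It is applied to an [H, W, C] array, whereas `transforms.Normalize` works on [C, H, W] tensors, and the function names are hypothetical.

```python
import numpy as np

MEAN = np.array([0.5, 0.5, 0.5])  # mean used by transforms.Normalize in train.py
STD = np.array([0.5, 0.5, 0.5])   # std used by transforms.Normalize in train.py


def normalize(img):
    """Same per-channel arithmetic as transforms.Normalize: (x - mean) / std."""
    return (img - MEAN) / STD


def denormalize(img):
    """Inverse used for plotting: x * std + mean, then scale [0, 1] back to [0, 255]."""
    return (img * STD + MEAN) * 255


pixels = np.random.rand(4, 4, 3)  # ToTensor output lies in [0, 1]
restored = denormalize(normalize(pixels))
print(np.allclose(restored, pixels * 255))  # True
wrong = (normalize(pixels) * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
print(np.allclose(wrong, pixels * 255))     # False: ImageNet stats do not undo a 0.5/0.5 Normalize
```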

3.my_dataset.py

from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """Custom dataset"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # 'RGB' is a color image, 'L' is a grayscale image
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # for the official default_collate implementation, see
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
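`collate_fn` relies on `zip(*batch)` to turn a list of (image, label) pairs into one tuple of images and one tuple of labels. The sketch below isolates that regrouping step, with strings standing in for image tensors (`collate_pairs` is a hypothetical name).

```python
def collate_pairs(batch):
    """Same regrouping as MyDataSet.collate_fn, without torch.stack / torch.as_tensor."""
    images, labels = tuple(zip(*batch))  # [(img0, y0), (img1, y1), ...] -> (img0, img1, ...), (y0, y1, ...)
    return images, labels


batch = [("img_a", 0), ("img_b", 2), ("img_c", 1)]  # stand-ins for (tensor, label) pairs
images, labels = collate_pairs(batch)
print(images)  # ('img_a', 'img_b', 'img_c')
print(labels)  # (0, 2, 1)
```

In `MyDataSet.collate_fn`, `torch.stack` then turns the image tuple into a [B, C, H, W] tensor. `torch.as_tensor` turns the integer labels into a 1-D int64 tensor, which is the label format `CrossEntropyLoss` expects.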

4.train.py

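To use the pre-trained model, download it first. The download links are given in the docstrings of the model-building functions in vit_model.py (not utils.py). The code loads pre-trained weights by default: the --weights argument defaults to ./vit_base_patch16_224_in21k.pth, so place the downloaded file in the project root directory under that name. My dataset has three classes, so I changed the output size of the network's final fully connected layer (the classification head) to 3. Adjust it to match your own dataset.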

If downloading is inconvenient, you can also download the copy I uploaded:

vit_base_patch16_224_in21k.zip - Deep learning resources - CSDN Download
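When the pre-trained weights are loaded, train.py deletes the classifier weights before calling `load_state_dict(..., strict=False)`. This is necessary because the 21843-class head in the checkpoint cannot be copied into a 3-class head. The key selection can be sketched on a plain dict. `drop_head_keys` is a hypothetical name, and the checkpoint below is a hand-written summary of key names and shapes, not a real file.

```python
def drop_head_keys(state_dict, has_logits):
    """Mirror of the del_keys logic in train.py: remove classifier weights tied to the old class count."""
    if has_logits:
        del_keys = ["head.weight", "head.bias"]
    else:
        del_keys = ["pre_logits.fc.weight", "pre_logits.fc.bias", "head.weight", "head.bias"]
    return {k: v for k, v in state_dict.items() if k not in del_keys}


# Hypothetical key -> shape summary of an in21k checkpoint (shapes only, no real tensors)
ckpt = {
    "cls_token": (1, 1, 768),
    "pre_logits.fc.weight": (768, 768),
    "pre_logits.fc.bias": (768,),
    "head.weight": (21843, 768),
    "head.bias": (21843,),
}
kept = drop_head_keys(ckpt, has_logits=False)
print(sorted(kept))  # ['cls_token']
```

Unlike the `del` loop in train.py, the dict comprehension silently skips missing keys instead of raising KeyError. After loading, the new head is reported under `missing_keys` and keeps its random initialization.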

import os
import math
import argparse

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import xlwt

from my_dataset import MyDataSet
from vit_model import vit_base_patch16_224_in21k as create_model
from utils import read_split_data, train_one_epoch, evaluate

book = xlwt.Workbook(encoding='utf-8')  # create a Workbook, i.e. an Excel file
# create a sheet named 'Train_data'; cell_overwrite_ok=True allows cells to be overwritten
sheet1 = book.add_sheet(u'Train_data', cell_overwrite_ok=True)
# header row
sheet1.write(0, 0, 'epoch')
sheet1.write(0, 1, 'Train_Loss')
sheet1.write(0, 2, 'Train_Acc')
sheet1.write(0, 3, 'Val_Loss')
sheet1.write(0, 4, 'Val_Acc')
sheet1.write(0, 5, 'lr')
sheet1.write(0, 6, 'Best val Acc')


def main(args):
    best_acc = 0
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    tb_writer = SummaryWriter()  # logs are written under ./runs/ by default

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

    # training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    # the classification head is sized by --num_classes (3 for this dataset)
    model = create_model(num_classes=args.num_classes, has_logits=False).to(device)
    images = torch.zeros(1, 3, 224, 224).to(device)  # must match the input image size
    tb_writer.add_graph(model, images, verbose=False)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)
        # delete weights that are not needed (the classifier was sized for the pre-training classes)
        del_keys = ['head.weight', 'head.bias'] if model.has_logits \
            else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias']
        for k in del_keys:
            del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze all weights except head and pre_logits
            if "head" not in name and "pre_logits" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        sheet1.write(epoch + 1, 0, epoch + 1)
        sheet1.write(epoch + 1, 5, str(optimizer.state_dict()['param_groups'][0]['lr']))
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        scheduler.step()
        sheet1.write(epoch + 1, 1, str(train_loss))
        sheet1.write(epoch + 1, 2, str(train_acc))

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)
        sheet1.write(epoch + 1, 3, str(val_loss))
        sheet1.write(epoch + 1, 4, str(val_acc))

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "./weights/best_model.pth")

        # torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))

    sheet1.write(1, 6, str(best_acc))
    # xlwt writes the legacy .xls format, so the file gets a matching extension
    book.save('./Train_data.xls')
    print("The Best Acc = : {:.4f}".format(best_acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=3)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lrf', type=float, default=0.01)

    # root directory of the dataset
    parser.add_argument('--data-path', type=str,
                        default=r"D:\pyCharmdata\resnet50_plant_3\datasets\train")
    parser.add_argument('--model-name', default='', help='create model name')

    # path to the pre-trained weights; set to an empty string to skip loading
    parser.add_argument('--weights', type=str, default='./vit_base_patch16_224_in21k.pth',
                        help='initial weights path')
    # freeze weights: a plain flag, because type=bool would turn any non-empty string (even "False") into True
    parser.add_argument('--freeze-layers', action='store_true',
                        help='freeze all weights except head and pre_logits')
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)
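The `lf` lambda in train.py scales the base learning rate by a half cosine that starts at 1 and ends at `lrf`. The standard-library sketch below uses the same formula; `cosine_lambda` is a hypothetical wrapper. It prints the resulting learning rate at a few epochs, assuming the defaults `--lr 0.001`, `--lrf 0.01` and `--epochs 100`.

```python
import math


def cosine_lambda(epochs, lrf):
    """Same formula as lf in train.py: the LR multiplier decays from 1 to lrf over `epochs` epochs."""
    return lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf


lf = cosine_lambda(epochs=100, lrf=0.01)
base_lr = 0.001
for epoch in (0, 25, 50, 75, 99):
    print(epoch, base_lr * lf(epoch))  # LambdaLR sets lr = base_lr * lf(epoch)
```

`scheduler.step()` is called once per epoch and the loop stops at epoch 99, so the last epoch trains with `lf(99)`, slightly above `lrf`. The value `lf(epochs) = lrf` is never actually used for training.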

5.predict.py

predict.py predicts the class of a single image; the class with the highest score is the model's prediction.

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from vit_model import vit_base_patch16_224_in21k as create_model


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    # load image
    img_path = r"D:\pyCharmdata\resnet50_plant_3\datasets\test\Huanglong_disease\000000.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [C, H, W]
    img = data_transform(img)
    # expand batch dimension -> [N, C, H, W]
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model (num_classes must match the value used for training)
    model = create_model(num_classes=3, has_logits=False).to(device)
    # load model weights
    model_weight_path = "./weights/best_model.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

Prediction result:

[Figure 5]
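After the forward pass, predict.py applies softmax to the logits and takes the argmax. The plain-Python sketch below performs the same post-processing, using made-up logits and the class mapping from class_indices.json. The softmax helper is written out by hand here instead of calling `torch.softmax`.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a 1-D list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class_indict = {"0": "Huanglong_disease", "1": "Magnesium_deficiency", "2": "Normal"}
logits = [2.0, 0.5, -1.0]  # made-up model output for one image
probs = softmax(logits)
predict_cla = max(range(len(probs)), key=probs.__getitem__)  # argmax
print("class: {}   prob: {:.3}".format(class_indict[str(predict_cla)], probs[predict_cla]))
```

Softmax is monotonic, so the argmax of the probabilities equals the argmax of the raw logits. Softmax only affects the printed probabilities.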

IV. Training Data

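After setting up the environment and configuring the paths to the dataset and the pre-trained model, run train.py to start training. By default it trains for 100 epochs.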

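Training uses SGD with momentum (SGDM, momentum=0.9, weight_decay=5e-5) with an initial learning rate of 0.001. The learning rate follows a custom cosine schedule implemented with LambdaLR. The pre-trained weights are loaded, but no layers or parameters are frozen.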

[Figure 6]
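For reference, the sketch below shows one SGD-with-momentum update in the form `torch.optim.SGD` uses by default (dampening=0, nesterov=False), with weight decay added to the gradient. It is a scalar illustration only. `sgd_momentum_step` is a hypothetical name, and the toy objective f(p) = p² and lr=0.1 were chosen only so the change is visible.

```python
def sgd_momentum_step(param, grad, buf, lr=0.001, momentum=0.9, weight_decay=5e-5):
    """One SGD-with-momentum update (dampening=0, nesterov=False) on a scalar parameter."""
    d_p = grad + weight_decay * param                     # weight decay is folded into the gradient
    buf = d_p if buf is None else momentum * buf + d_p    # the first step initializes the momentum buffer
    return param - lr * buf, buf


# Toy problem: minimize f(p) = p**2, whose gradient is 2p
p, buf = 1.0, None
for step in range(3):
    p, buf = sgd_momentum_step(p, 2 * p, buf, lr=0.1, weight_decay=0.0)
    print(step, p)  # p goes 0.8, 0.46, then about 0.062
```

The momentum buffer keeps growing while the gradients point the same way, so the step size increases even though the gradient shrinks.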

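During training, open a terminal in the project directory and run:

tensorboard --logdir=runs/

This lets you monitor training in real time and view the network graph of the Vision Transformer. SummaryWriter() writes its logs under runs/ by default.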

[Figure 7]

Network graph of the Vision Transformer: [Figure 8]

After a simple 100-epoch training run, the best val_acc was 0.9976.

[Figure 9]

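When training finishes, an Excel file is created in the project root directory. It records the data of the entire training run. You can also load it into Matlab to make highly customized comparison plots, which is a real boon for anyone writing papers.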

Only the data from the first 10 epochs is shown here.

[Figure 10]
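If xlwt is not available, or you prefer a plain-text log, Python's built-in csv module can write the same columns. Both Excel and MATLAB (readtable) can open CSV files. This is an alternative suggestion, not what train.py does, and the metric values below are made up.

```python
import csv
import io

FIELDS = ["epoch", "Train_Loss", "Train_Acc", "Val_Loss", "Val_Acc", "lr"]  # same columns as the xlwt sheet


def write_history(rows, fh):
    """Write per-epoch metrics (a list of dicts) as CSV with the standard library."""
    writer = csv.DictWriter(fh, fieldnames=FIELDS)
    writer.writeheader()
    writer.writerows(rows)


# made-up numbers, only to show the round trip
rows = [{"epoch": 1, "Train_Loss": 0.9, "Train_Acc": 0.6, "Val_Loss": 0.7, "Val_Acc": 0.7, "lr": 0.001}]
stream = io.StringIO()  # in a real run: open("Train_data.csv", "w", newline="")
write_history(rows, stream)
stream.seek(0)
print(list(csv.DictReader(stream)))
```

Appending one row per epoch, instead of saving once at the end, would also keep the log if a long run is interrupted.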

My complete project, for anyone who needs it:

Vit_myself.zip - Deep learning resources - CSDN Download

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If this article helped you, please give it a like, a favorite, and a follow!
