文章目录
- 一、概述
- 二、代码编写
- 1. 数据处理
- 2. 准备配置文件
- 3. 自定义DataSet和DataLoader
- 4. 构建模型
- 5. 训练模型
- 6. 编写预测模块
- 三、效果展示
- 四、源码地址
一、概述
本项目使用Pytroch
,并基于ResNet50
模型,实现了对天气图片的识别,过程详细,十分适合基础阶段的同学阅读。
项目目录结构:
核心步骤:
- 数据处理
- 准备配置文件
- 构建自定义
DataSet
及Dataloader
- 构建模型
- 训练模型
- 编写预测模块
- 效果展示
二、代码编写
1. 数据处理
本项目数据来源:
https://www.heywhale.com/mw/dataset/60d9bd7c056f570017c305ee/file
http://vcc.szu.edu.cn/research/2017/RSCM.html
由于数据是直接下载,且目录分的很规整,本项目的数据处理部分较为简单,直接手动复制,合并两个数据集即可。
数据概览:
总数据量约7万张。
2. 准备配置文件
配置文件的主要存储一些各个模块通用的一些全局变量,如各种文件的存放位置等等(本人Java程序员出身,一些Python的代码规范不太熟悉,望见谅)。
config.py
:
import timeimport torch# 项目配置文件class Common:'''通用配置'''basePath = "D:/Data/weather/source/all/"# 图片文件基本路径device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 设备配置imageSize = (224,224) # 图片大小labels = ["cloudy","haze","rainy","shine","snow","sunny","sunrise","thunder"] # 标签名称/文件夹名称class Train:'''训练相关配置'''batch_size = 128num_workers = 0# 对于Windows用户,这里应设置为0,否则会出现多线程错误lr = 0.001epochs = 40logDir = "./log/" + time.strftime('%Y-%m-%d-%H-%M-%S',time.gmtime()) # 日志存放位置modelDir = "./model/" # 模型存放位置
3. 自定义DataSet和DataLoader
dada_loader.py
# 自定义数据加载器import torchfrom torch import nnfrom torch.utils.data import Dataset, DataLoaderfrom torchvision import transformsfrom config import Commonfrom config import Trainimport osfrom PIL import Imageimport torch.utils.data as Dataimport numpy# 定义数据处理transformtransform = transforms.Compose([transforms.Resize(Common.imageSize),transforms.ToTensor()])def loadDataFromDir():'''从文件夹中获取数据'''images = []labels = []# 1. 获取根文件夹下所有分类文件夹for d in os.listdir(Common.basePath):for imagePath in os.listdir(Common.basePath + d):# 2. 获取某一类型下所有的图片名称# 3. 读取文件image = Image.open(Common.basePath + d + "/" + imagePath).convert('RGB')print("加载数据" + str(len(images)) + "条")# 4. 添加到图片列表中images.append(transform(image))# 5. 构造labelcategoryIndex = Common.labels.index(d)# 获取分类下标label = [0] * 8# 初始化labellabel[categoryIndex] = 1# 根据下标确定目标值label = torch.tensor(label,dtype=torch.float)# 转为tensor张量# 6. 添加到目标值列表labels.append(label)# 7. 关闭资源image.close()# 返回图片列表和目标值列表return images, labelsclass WeatherDataSet(Dataset):'''自定义DataSet'''def __init__(self):'''初始化DataSet:param transform: 自定义转换器'''images, labels = loadDataFromDir()# 在文件夹中加载图片self.images = imagesself.labels = labelsdef __len__(self):'''返回数据总长度:return:'''return len(self.images)def __getitem__(self, idx):image = self.images[idx]label = self.labels[idx]return image, labeldef splitData(dataset):'''分割数据集:param dataset::return:'''# 求解一下数据的总量total_length = len(dataset)# 确认一下将80%的数据作为训练集, 剩下的20%的数据作为测试集train_length = int(total_length * 0.8)validation_length = total_length - train_length# 利用Data.random_split()直接切分数据集, 按照80%, 20%的比例进行切分train_dataset,validation_dataset = Data.random_split(dataset=dataset, lengths=[train_length, validation_length])return train_dataset, validation_dataset# 1. 分割数据集train_dataset, validation_dataset = splitData(WeatherDataSet())# 2. 训练数据集加载器trainLoader = DataLoader(train_dataset, batch_size=Train.batch_size, shuffle=True, num_workers=Train.num_workers)# 3. 验证集数据加载器valLoader = DataLoader(validation_dataset, batch_size=Train.batch_size, shuffle=False, num_workers=Train.num_workers)
主要步骤:
- 读取图片使用的是Python自带的
PIL
库
PIL
教程:https://blog.csdn.net/weixin_43790276/article/details/108478270
- 由于使用的是残差网络,其图片尺寸必须是
3*224*224
,故需要使用Pytroch的transforms
工具进行处理
transforms
教程:https://blog.csdn.net/qq_38410428/article/details/94719553
- 自定义
DataSet
(继承DataSet类,并实现重写三个核心方法) - 分割数据
- 创建验证集和训练集各自的加载器
4. 构建模型
model.py
import torchfrom torch import nnimport torchvision.models as modelsfrom config import Common, Train# 引入rest50模型net = models.resnet50()net.load_state_dict(torch.load("./model/resnet50-11ad3fa6.pth"))class WeatherModel(nn.Module):def __init__(self, net):super(WeatherModel, self).__init__()# resnet50self.net = netself.relu = nn.ReLU()self.dropout = nn.Dropout(0.1)self.fc = nn.Linear(1000, 8)self.output = nn.Softmax(dim=1)def forward(self, x):x = self.net(x)x = self.relu(x)x = self.dropout(x)x = self.fc(x)x = self.output(x)return xmodel = WeatherModel(net)
主要步骤:
- 引入
Pytorch
官方的残差网络预训练模型
关于新版本的引入方法:https://blog.csdn.net/Sihang_Xie/article/details/125646287
- 添加自己的全连接输出层
- 创建模型
5. 训练模型
train.py
# 训练部分import timeimport torchfrom torch import nnimport matplotlib.pyplot as pltfrom torch.utils.tensorboard import SummaryWriterfrom config import Common, Trainfrom model import model as weatherModelfrom data_loader import trainLoader, valLoaderfrom torch import optim# 1. 获取模型model = weatherModelmodel.to(Common.device)# 2. 定义损失函数criterion = nn.CrossEntropyLoss()# 3. 定义优化器optimizer = optim.Adam(model.parameters(), lr=0.001)# 4. 创建writerwriter = SummaryWriter(log_dir=Train.logDir, flush_secs=500)def train(epoch):'''训练函数'''# 1. 获取dataLoaderloader = trainLoader# 2. 调整为训练状态model.train()print()print('========== Train Epoch:{} Start =========='.format(epoch))epochLoss = 0# 每个epoch的损失epochAcc = 0# 每个epoch的准确率correctNum = 0# 正确预测的数量for data, label in loader:data, label = data.to(Common.device), label.to(Common.device)# 加载到对应设备batchAcc = 0# 单批次正确率batchCorrectNum = 0# 单批次正确个数optimizer.zero_grad()# 清空梯度output = model(data)# 获取模型输出loss = criterion(output, label)# 计算损失loss.backward()# 反向传播梯度optimizer.step()# 更新参数epochLoss += loss.item() * data.size(0)# 计算损失之和# 计算正确预测的个数labels = torch.argmax(label, dim=1)outputs = torch.argmax(output, dim=1)for i in range(0, len(labels)):if labels[i] == outputs[i]:correctNum += 1batchCorrectNum += 1batchAcc = batchCorrectNum / data.size(0)print("Epoch:{}\t TrainBatchAcc:{}".format(epoch, batchAcc))epochLoss = epochLoss / len(trainLoader.dataset)# 平均损失epochAcc = correctNum / len(trainLoader.dataset)# 正确率print("Epoch:{}\t Loss:{} \t Acc:{}".format(epoch, epochLoss, epochAcc))writer.add_scalar("train_loss", epochLoss, epoch)# 写入日志writer.add_scalar("train_acc", epochAcc, epoch)# 写入日志return epochAccdef val(epoch):'''验证函数:param epoch: 轮次:return:'''# 1. 获取dataLoaderloader = valLoader# 2. 初始化损失、准确率列表valLoss = []valAcc = []# 3. 调整为验证状态model.eval()print()print('========== Val Epoch:{} Start =========='.format(epoch))epochLoss = 0# 每个epoch的损失epochAcc = 0# 每个epoch的准确率correctNum = 0# 正确预测的数量with torch.no_grad():for data, label in loader:data, label = data.to(Common.device), label.to(Common.device)# 加载到对应设备batchAcc = 0# 单批次正确率batchCorrectNum = 0# 单批次正确个数output = model(data)# 获取模型输出loss = criterion(output, label)# 计算损失epochLoss += loss.item() * data.size(0)# 计算损失之和# 计算正确预测的个数labels = torch.argmax(label, dim=1)outputs = torch.argmax(output, dim=1)for i in range(0, len(labels)):if labels[i] == outputs[i]:correctNum += 1batchCorrectNum += 1batchAcc = batchCorrectNum / data.size(0)print("Epoch:{}\t ValBatchAcc:{}".format(epoch, batchAcc))epochLoss = epochLoss / len(valLoader.dataset)# 平均损失epochAcc = correctNum / len(valLoader.dataset)# 正确率print("Epoch:{}\t Loss:{} \t Acc:{}".format(epoch, epochLoss, epochAcc))writer.add_scalar("val_loss", epochLoss, epoch)# 写入日志writer.add_scalar("val_acc", epochAcc, epoch)# 写入日志return epochAccif __name__ == '__main__':maxAcc = 0.75for epoch in range(1,Train.epochs + 1):trainAcc = train(epoch)valAcc = val(epoch)if valAcc > maxAcc:maxAcc = valAcc# 保存最大模型torch.save(model, Train.modelDir + "weather-" + time.strftime('%Y-%m-%d-%H-%M-%S', time.gmtime()) + ".pth")# 保存模型torch.save(model,Train.modelDir+"weather-"+time.strftime('%Y-%m-%d-%H-%M-%S',time.gmtime())+".pth")
主要步骤:
- 加载模型
- 准备损失函数及优化器
- 创建
tensorboard
的writer
关于
tensorboard
的使用:https://blog.csdn.net/weixin_43637851/article/details/116003280
- 编写训练函数及验证函数,同时记录损失和正确率
验证函数和训练函数的区别就是是否需要更新参数
- 循环训练
epochs
次,不断保存正确率最大的模型,以及最后一次的训练模型 - 开始训练
- 不断调参(我就只训了3次),知道有一个比较满意的效果
训练过程中电脑的状态:
查看训练日志(tensorboard):
保存的模型:
6. 编写预测模块
pridect.py
import torchimport torchvision.transforms as transformsfrom PIL import Imagefrom config import Commondef pridect(imagePath, modelPath):'''预测函数:param imagePath: 图片路径:param modelPath: 模型路径:return:'''# 1. 读取图片image = Image.open(imagePath)# 2. 进行缩放image = image.resize(Common.imageSize)image.show()# 3. 加载模型model = torch.load(modelPath)model = model.to(Common.device)# 4. 转为tensor张量transform = transforms.ToTensor()x = transform(image)x = torch.unsqueeze(x, 0)# 升维x = x.to(Common.device)# 5. 传入模型output = model(x)# 6. 使用argmax选出最有可能的结果output = torch.argmax(output)print("预测结果:",Common.labels[output.item()])if __name__ == '__main__':pridect("D:/Download/76ee4c5e833499949eac41561dcb487d.jpeg","./model/weather-2022-10-14-07-36-57.pth")
三、效果展示
去网上随便找的图片:
四、源码地址
https://github.com/mengxianglong123/weather-recognition
欢迎交流学习