UNet – 预测数据predict(多个图像的分割)

目录

1. 介绍

2. predict 预测分割图片

3. 结果展示

4. 完整代码


1. 介绍

之前已经将unet的网络模块、dataset数据加载和train训练数据已经解决了,这次要将unet网络去分割图像,下面是之前的链接

unet 网络:UNet – unet网络

dataset 数据处理:UNet – 数据加载 Dataset

train 网络训练:UNet – 训练数据train

待分割的图像如下:

图片[1] - UNet – 预测数据predict(多个图像的分割) - MaxSSL

存放的路径在U-net项目的predict里面

图片[2] - UNet – 预测数据predict(多个图像的分割) - MaxSSL

我们的目标是将predict里面所有的图片分割出来,按照名称顺序保存在result文件夹里面:

图片[3] - UNet – 预测数据predict(多个图像的分割) - MaxSSL

2. predict 预测分割图片

首先定义图片的预处理,按照dataset里面相同的方式进行预处理

图片[4] - UNet – 预测数据predict(多个图像的分割) - MaxSSL

然后是加载网络的模型和网络参数

图片[5] - UNet – 预测数据predict(多个图像的分割) - MaxSSL

然后加载predict里面所有待处理图片的路径

需要注意的是,os.listdir 加载的只是里面每个图片,并不是图片的具体路径。tests_path 里面的内容如下面的注释所示:

图片[6] - UNet – 预测数据predict(多个图像的分割) - MaxSSL

接下来就可以分割图片了

因为tests_path 里面每个文件是 x.png 即文件名+后缀的方式。通过split的 ‘.’ 分割成x和后缀名png的形式,[-2]代表取倒数第二个值,就可以将每个文件名x取出来,然后将路径拼接就可以存放到result里面

open图像的时候,也要注意,test_path 只是遍历tests_path 里面的文件,需要加上之前的predict路径才能正确的读取到每个待分割的图片

因为这里处理图像会改变size成480*480的形式,想要将输出的结果保持不变的话,在网络预测前将图像的大小保存下来就可以了。(注:这里的size和opencv里面的shape返回值是反过来的

这里不清楚的可以通过调试,打印每个变量的内容看一下就可以了

图片[7] - UNet – 预测数据predict(多个图像的分割) - MaxSSL

接下来就是网络预测的部分,这里输出的size是(batch,channel,height,width),因为这里的batch是1,channel 灰度图片因此也是1,这里通过squeeze将1的维度删去,只需要图像的大小

下面是squeeze的用法图片[8] - UNet – 预测数据predict(多个图像的分割) - MaxSSL

然后图像保存的话,要转到cpu上面,这一步不知道为啥,但是不加这一步会报错

图片[9] - UNet – 预测数据predict(多个图像的分割) - MaxSSL

最后就是保存图像了,将网络的结果二值化后,还原图像再保存就可以了

图片[10] - UNet – 预测数据predict(多个图像的分割) - MaxSSL

3. 结果展示

predict里面待预测的图片

图片[11] - UNet – 预测数据predict(多个图像的分割) - MaxSSL

result 里面分割好的图片

图片[12] - UNet – 预测数据predict(多个图像的分割) - MaxSSL

下面是 参考文章 博主的分割结果

图片[13] - UNet – 预测数据predict(多个图像的分割) - MaxSSL

对比发现,有些小的细节会丢失,但是大概的轮廓分割出来了

4. 完整代码

完整的项目可以在 这里 下载

import numpy as npimport torchimport cv2from model import UNetfrom torchvision import transformsfrom PIL import Imageimport os# 预处理transform = transforms.Compose([transforms.Resize((480,480)),# 缩放图像transforms.ToTensor(),])# 加载模型device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')net = UNet(in_channels=1, num_classes=1)net.load_state_dict(torch.load('Unet.pth', map_location=device))net.to(device)# 测试模式net.eval()# 读取所有图片路径tests_path = os.listdir('./predict/') # 获取 './predict/' 路径下所有文件,这里的路径只是里面文件的路径''''print(tests_path)['0.png', '1.png', '10.png', '11.png', '12.png', '13.png', '14.png', '15.png', '16.png', '17.png', '18.png', '19.png', '2.png', '20.png', '21.png', '22.png', '23.png', '24.png', '25.png', '26.png', '27.png', '28.png', '29.png', '3.png', '4.png', '5.png', '6.png', '7.png', '8.png', '9.png']'''with torch.no_grad(): # 预测的时候不需要计算梯度for test_path in tests_path:# 遍历每个predict的文件save_pre_path = './result/'+test_path.split('.')[-2] + '_res.png'# 将保存的路径按照原图像的后缀,按照数字排序保存img = Image.open('./predict/' +test_path) # 预测图片的路径width,height = img.size[0],img.size[1]# 保存图像的大小img = transform(img)img = torch.unsqueeze(img,dim = 0)# 扩展图像的维度pred = net(img.to(device))# 网络预测pred = torch.squeeze(pred)# 将(batch、channel)维度去掉pred = np.array(pred.data.cpu())# 保存图片需要转为cpu处理pred[pred >= 0] = 255 # 处理结果二值化pred[pred < 0] = 0pred = np.uint8(pred) # 转为图片的形式pred = cv2.resize(pred,(width,height),cv2.INTER_CUBIC)# 还原图像的sizecv2.imwrite(save_pre_path, pred)# 保存图片

© 版权声明
THE END
喜欢就支持一下吧
点赞0 分享