PyTorch深度学习实战(37)——CycleGAN详解与实现
- 0. 前言
- 1. CycleGAN 基本原理
- 2. CycleGAN 模型分析
- 3. 实现 CycleGAN
- 小结
- 系列链接
0. 前言
CycleGAN
是一种用于图像转换的生成对抗网络(Generative Adversarial Network, GAN),可以在不需要配对数据的情况下将一种风格的图像转换成另一种风格,而无需为每一对输入-输出图像配对训练数据。CycleGAN
的核心思想是利用两个生成器和两个判别器,它们共同学习两个域之间的映射关系。例如,将马的图像转换成斑马的图像,或者将夏天的风景转换成冬天的风景。在本节中,我们将学习 CycleGAN
的基本原理,并实现该模型用于将苹果图像转换为橙子图像,或反之将橙子图像转换为苹果图像。
1. CycleGAN 基本原理
CycleGAN
是一种无需配对的图像转换技术,它可以将一个图像域中的图像转换为另一个图像域中的图像,而不需要匹配这两个域中的图像。它使用两个生成器和两个判别器,其中一个生成器将一个域中的图像转换为另一个域中的图像,而第二个生成器将其转换回来。这个过程被称为循环一致性,转换过程是可逆的。
CycleGAN
可以用于执行从一个类别到另一个类别的图像转换,而无需提供相匹配的输入-输出图像对来训练模型,只需要在两个不同的文件夹中提供这两个类别的图像。在本节中,我们将学习如何训练 CycleGAN
将苹果图像转换为橙子图像,或反之将橙子图像转换为苹果图像,CycleGAN
中的 Cycle
是指将图像从一个类别转换到另一个类别,然后再转换回原始类别的过程。
在 CycleGAN
中,需要使用三种不同的损失值:
- 鉴别器损失:用于区分真实图像和伪造图像
- 循环一致性损失:由于
CycleGAN
使用了两个生成器,因此需要确保转换是可逆的,循环一致性损失通过将转换过的图像再次传递到原始的生成器中,并将生成的图像与原始图像进行比较来实现 - 恒等损失 (
Identity loss
):确保生成器在不进行转换的情况下仍然能够生成与原始图像相似的图像,通过将原始图像传递到生成器中,并计算生成图像与原始图像之间的差异
2. CycleGAN 模型分析
CycleGAN
模型构建策略如下:
- 导入数据集并进行预处理
- 定义
UNet
架构用于构建生成器和判别器网络 - 定义两个生成器:
G_AB
:将类别A
图像转换为类别B
图像的生成器G_BA
:将类别B
图像转换为类别A
图像的生成器
- 定义恒等损失:
- 如果将一张橘子的图像输入到橙子生成器,理想情况下,如果生成器完全理解橙子的所有信息,它不应该改变图像,而应该“生成”完全相同的图像,据此,我们可以创建一个恒等变换
- 当类别
A
(real_A
) 的图像通过G_BA
并与real_A
进行比较时,恒等损失应该是最小的 - 当类别
B
(real_B
) 的图像通过G_AB
并与real_B
进行比较时,恒等损失应该是最小的
- 定义GAN损失:
real_A
和fake_A
的判别器和生成器损失(当real_B
图像通过G_BA
时得到fake_A
)real_B
和fake_B
的判别器和生成器损失(当real_A
图像通过G_AB
时得到fake_B
)
- 定义循环一致性损失:
- 一张苹果图像需要通过橙子生成网络进行转换,生成伪造的橘子图像,然后再通过苹果生成网络将伪造的橙子图像转换回苹果图像
fake_B
是real_A
通过G_AB
时的输出,当fake_B
通过G_BA
时应该重新生成real_A
fake_A
是real_B
通过G_BA
时的输出,当fake_A
通过G_AB
时应该重新生成real_B
- 优化三个损失函数的加权和
3. 实现 CycleGAN
接下来,我们使用 PyTorch
实现 CycleGAN
模型,用以将苹果图像转换为橙子图像,或反之将橙子图像转换为苹果图像。
(1) 导入相关数据集和库。
首先下载并解压数据集,可以自行构建数据集,也可以下载本文所用数据集,下载地址:https://pan.baidu.com/s/1iTOt2NsUQ1a3taUHjvkjfA,提取码:iuqf
。
可视化示例图像如下:
与 Pix2Pix 训练数据集不同,苹果和橙色图像之间不存在一一对应的关系。
导入所需的库:
import torchfrom torch import nnfrom torch import optimfrom matplotlib import pyplot as pltimport numpy as npfrom torchvision.utils import make_gridfrom torch.utils.data import DataLoader, Datasetimport cv2import randomfrom glob import globfrom PIL import Imageimport itertoolsfrom torchvision import transformsdevice = "cuda" if torch.cuda.is_available() else "cpu"
(2) 定义图像转换管道 transform
:
IMAGE_SIZE = 256device = 'cuda' if torch.cuda.is_available() else 'cpu'transform = transforms.Compose([transforms.Resize(int(IMAGE_SIZE*1.33)),transforms.RandomCrop((IMAGE_SIZE,IMAGE_SIZE)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])
(3) 定义数据集类 CycleGANDataset
,以苹果图像 apple
和橙子图像 orange
文件夹为输入,提供批数据:
class CycleGANDataset(Dataset):def __init__(self, apples, oranges):self.apples = glob(apples)self.oranges = glob(oranges)def __getitem__(self, ix):apple = self.apples[ix % len(self.apples)]orange = random.choice(self.oranges)apple = Image.open(apple).convert('RGB')orange = Image.open(orange).convert('RGB')return apple, orangedef __len__(self):return max(len(self.apples), len(self.oranges))def choose(self):return self[random.randint(len(self))]def collate_fn(self, batch):srcs, trgs = list(zip(*batch))srcs = torch.cat([transform(img)[None] for img in srcs], 0).to(device).float()trgs = torch.cat([transform(img)[None] for img in trgs], 0).to(device).float()return srcs.to(device), trgs.to(device)
(4) 定义训练、验证数据集和数据加载器:
trn_ds = CycleGANDataset('apples_oranges/apples_train/*.jpg', 'apples_oranges/oranges_train/*.jpg')val_ds = CycleGANDataset('apples_oranges/apples_test/*.jpg', 'apples_oranges/oranges_test/*.jpg')trn_dl = DataLoader(trn_ds, batch_size=1, shuffle=True, collate_fn=trn_ds.collate_fn)val_dl = DataLoader(val_ds, batch_size=5, shuffle=True, collate_fn=val_ds.collate_fn)
(5) 定义网络的权重初始化方法 weights_init_normal
:
def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)if hasattr(m, "bias") and m.bias is not None:torch.nn.init.constant_(m.bias.data, 0.0)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)
(6) 定义残差块 ResidualBlock
:
class ResidualBlock(nn.Module):def __init__(self, in_features):super(ResidualBlock, self).__init__()self.block = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(in_features, in_features, 3),nn.InstanceNorm2d(in_features),nn.ReLU(inplace=True),nn.ReflectionPad2d(1),nn.Conv2d(in_features, in_features, 3),nn.InstanceNorm2d(in_features),)def forward(self, x):return x + self.block(x)
(7) 定义生成器 GeneratorResNet
:
class GeneratorResNet(nn.Module):def __init__(self, num_residual_blocks=9):super(GeneratorResNet, self).__init__()out_features = 64channels = 3model = [nn.ReflectionPad2d(3),nn.Conv2d(channels, out_features, 7),nn.InstanceNorm2d(out_features),nn.ReLU(inplace=True),]in_features = out_features# Downsamplingfor _ in range(2):out_features *= 2model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),nn.InstanceNorm2d(out_features),nn.ReLU(inplace=True),]in_features = out_features# Residual blocksfor _ in range(num_residual_blocks):model += [ResidualBlock(out_features)]# Upsamplingfor _ in range(2):out_features //= 2model += [nn.Upsample(scale_factor=2),nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),nn.InstanceNorm2d(out_features),nn.ReLU(inplace=True),]in_features = out_features# Output layermodel += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]self.model = nn.Sequential(*model)self.apply(weights_init_normal)def forward(self, x):return self.model(x)
(8) 定义判别器 Discriminator
:
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()channels, height, width = 3, IMAGE_SIZE, IMAGE_SIZEdef discriminator_block(in_filters, out_filters, normalize=True):"""Returns downsampling layers of each discriminator block"""layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]if normalize:layers.append(nn.InstanceNorm2d(out_filters))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*discriminator_block(channels, 64, normalize=False),*discriminator_block(64, 128),*discriminator_block(128, 256),*discriminator_block(256, 512),nn.ZeroPad2d((1, 0, 1, 0)),nn.Conv2d(512, 1, 4, padding=1))self.apply(weights_init_normal)def forward(self, img):return self.model(img)
(9) 定义图像样本生成函数 generate_sample
:
@torch.no_grad()def generate_sample(G_AB, G_BA):data = next(iter(val_dl))G_AB.eval()G_BA.eval()real_A, real_B = datafake_B = G_AB(real_A)fake_A = G_BA(real_B)# Arange images along x-axisreal_A = make_grid(real_A, nrow=5, normalize=True)real_B = make_grid(real_B, nrow=5, normalize=True)fake_A = make_grid(fake_A, nrow=5, normalize=True)fake_B = make_grid(fake_B, nrow=5, normalize=True)# Arange images along y-axisimage_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)plt.imshow(image_grid.detach().cpu().permute(1,2,0).numpy())plt.show()
(10) 定义生成器训练函数 generator_train_step
。
该函数将两个生成器( G_AB
和 G_BA
)、优化器和两个类别的真实图像( real_A
和 real_B
)作为输入:
def generator_train_step(Gs, optimizer, real_A, real_B, D_A, D_B, criterion_identity, criterion_cycle, criterion_GAN, lambda_cyc, lambda_id):
指定生成器:
G_AB, G_BA = Gs
将优化器的梯度设置为零:
optimizer.zero_grad()
计算类别 A
(苹果)和类别 B
(橙子)图像的恒等损失 (loss_identity
):
loss_id_A = criterion_identity(G_BA(real_A), real_A)loss_id_B = criterion_identity(G_AB(real_B), real_B)loss_identity = (loss_id_A + loss_id_B) / 2
计算图像通过生成器时的 GAN
损失,此时生成的图像应尽可能接近另一类别,使用 np.ones
作为训练生成器的判别网络目标输出,因为我们将生成的伪造图像传递给相同类别的判别器:
fake_B = G_AB(real_A)loss_GAN_AB = criterion_GAN(D_B(fake_B), torch.Tensor(np.ones((len(real_A), 1, 16, 16))).to(device))fake_A = G_BA(real_B)loss_GAN_BA = criterion_GAN(D_A(fake_A), torch.Tensor(np.ones((len(real_A), 1, 16, 16))).to(device))loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
计算循环一致性损失。假设,一张苹果图像被橙子生成器转换为一张伪造橙子图像,然后伪造橙子图像通过苹果生成器转换回一张苹果图像,理想情况下,经过该过程后应该返回原始图像,即循环一致性损失应该为零:
recov_A = G_BA(fake_B)loss_cycle_A = criterion_cycle(recov_A, real_A)recov_B = G_AB(fake_A)loss_cycle_B = criterion_cycle(recov_B, real_B)loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
计算总损失并执行反向传播:
loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identityloss_G.backward()optimizer.step()return loss_G, loss_identity, loss_GAN, loss_cycle, loss_G, fake_A, fake_B
(11) 定义判别器训练函数 discriminator_train_step
:
def discriminator_train_step(D, real_data, fake_data, optimizer, criterion_GAN):optimizer.zero_grad()loss_real = criterion_GAN(D(real_data), torch.Tensor(np.ones((len(real_data), 1, 16, 16))).to(device))loss_fake = criterion_GAN(D(fake_data.detach()), torch.Tensor(np.zeros((len(real_data), 1, 16, 16))).to(device))loss_D = (loss_real + loss_fake) / 2loss_D.backward()optimizer.step()return loss_D
(12) 定义生成器、判别器对象、优化器和损失函数:
G_AB = GeneratorResNet().to(device)G_BA = GeneratorResNet().to(device)D_A = Discriminator().to(device)D_B = Discriminator().to(device)criterion_GAN = torch.nn.MSELoss()criterion_cycle = torch.nn.L1Loss()criterion_identity = torch.nn.L1Loss()optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=0.0002, betas=(0.5, 0.999))optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))lambda_cyc, lambda_id = 10.0, 5.0
(13) 训练网络:
n_epochs = 50# log = Report(n_epochs)loss_D_epochs = []loss_G_epochs = []loss_GAN_epochs = []loss_cycle_epochs = []loss_identity_epochs = []for epoch in range(n_epochs):N = len(trn_dl)loss_D_items = []loss_G_items = []loss_GAN_items = []loss_cycle_items = []loss_identity_items = []for bx, batch in enumerate(trn_dl):real_A, real_B = batchloss_G, loss_identity, loss_GAN, loss_cycle, loss_G, fake_A, fake_B = generator_train_step((G_AB,G_BA), optimizer_G, real_A, real_B, D_A, D_B, criterion_identity, criterion_cycle, criterion_GAN, lambda_cyc, lambda_id)loss_D_A = discriminator_train_step(D_A, real_A, fake_A, optimizer_D_A, criterion_GAN)loss_D_B = discriminator_train_step(D_B, real_B, fake_B, optimizer_D_B, criterion_GAN)loss_D = (loss_D_A + loss_D_B) / 2loss_D_items.append(loss_D.item())loss_G_items.append(loss_G.item())loss_GAN_items.append(loss_GAN.item())loss_cycle_items.append(loss_cycle.item())loss_identity_items.append(loss_identity.item())loss_D_epochs.append(np.average(loss_D_items))loss_G_epochs.append(np.average(loss_G_items))loss_GAN_epochs.append(np.average(loss_GAN_items))loss_cycle_epochs.append(np.average(loss_cycle_items))loss_identity_epochs.append(np.average(loss_cycle_items))
(14) 训练模型后,测试模型生成图像:
generate_sample(G_AB, G_BA)
从上图可以看出,CycleGAN
可以成功地将苹果转换为橙子(前两行),将橙子转换为苹果(后两行)。
小结
CycleGAN
是一种用于无监督图像转换的深度学习模型,它通过两个生成器和两个判别器的组合来学习两个不同域之间的映射关系。生成器负责将一个域的图像转换成另一个域的图像,而判别器则用于区分生成的图像和真实的图像。CycleGAN
引入循环一致性损失,确保图像转换是可逆的,从而提高生成图像的质量。通过对抗训练和循环一致性损失,CycleGAN
可以实现在没有配对标签的情况下进行图像域转换。
系列链接
PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——从零开始实现SSD目标检测
PyTorch深度学习实战(24)——使用U-Net架构进行图像分割
PyTorch深度学习实战(25)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(26)——多对象实例分割
PyTorch深度学习实战(27)——自编码器(Autoencoder)
PyTorch深度学习实战(28)——卷积自编码器(Convolutional Autoencoder)
PyTorch深度学习实战(29)——变分自编码器(Variational Autoencoder, VAE)
PyTorch深度学习实战(30)——对抗攻击(Adversarial Attack)
PyTorch深度学习实战(31)——神经风格迁移
PyTorch深度学习实战(32)——Deepfakes
PyTorch深度学习实战(33)——生成对抗网络(Generative Adversarial Network, GAN)
PyTorch深度学习实战(34)——DCGAN详解与实现
PyTorch深度学习实战(35)——条件生成对抗网络(Conditional Generative Adversarial Network, CGAN)
PyTorch深度学习实战(36)——Pix2Pix详解与实现