最近对生成对抗网络(GAN)比较感兴趣,相关知识点的文章还在编辑中,下面是一个练手的小项目~

(在原模型的基础上做了一些修改,以减少计算量,让模型更容易训练。)

一、导入工具包

import glob
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from IPython import display
from IPython.display import clear_output
from tensorflow import keras
from tensorflow.keras import layers

1.1 设置GPU

# Use only the first GPU (if one exists) and let TensorFlow grow its memory on demand
gpus = tf.config.list_physical_devices("GPU")
if len(gpus) > 0:
    (gpu0,) = gpus[:1]  # with several GPUs present, keep just the first one
    tf.config.experimental.set_memory_growth(gpu0, True)  # allocate GPU memory as needed
    tf.config.set_visible_devices([gpu0], "GPU")
gpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

二、导入训练数据

链接:点这里

# Collect the paths of every JPEG in ./ani_face and report how many were found
fileList = list(glob.iglob('./ani_face/*.jpg'))
len(fileList)
41621

2.1 数据可视化

# Show a few randomly chosen training images (at most 3, fewer if the folder is smaller)
for path in random.sample(fileList, k=min(3, len(fileList))):
    display.display(display.Image(path))

2.2 数据预处理

# Dataset of file-name strings, one element per image
path_ds = tf.data.Dataset.from_tensor_slices(fileList)


def load_and_preprocess_image(path):
    """Read one JPEG file and turn it into a model-ready tensor.

    The image is decoded as RGB, resized to 64x64, scaled to the [0, 1]
    range and reshaped to (1, 64, 64, 3).
    """
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, [64, 64])
    scaled = resized / 255.0  # normalize to [0,1] range
    # Leading axis of size 1: every dataset element is itself a batch of one image
    # (NOTE(review): this means fit() trains with batch size 1 — confirm intended)
    return tf.reshape(scaled, [1, 64, 64, 3])


image_ds = path_ds.map(load_and_preprocess_image)
image_ds
# Preview one sample: undo the [0, 1] scaling and drop the leading batch axis
for x in image_ds.take(1):
    pixels = (x.numpy() * 255).astype("int32")
    plt.axis("off")
    plt.imshow(pixels[0])

三、网络构建

3.1 D网络

# Discriminator: three stride-2 conv blocks (64x64 -> 8x8), then a single sigmoid unit
_d_layers = [keras.Input(shape=(64, 64, 3))]
for _filters in (64, 128, 128):
    _d_layers.append(layers.Conv2D(_filters, kernel_size=4, strides=2, padding="same"))
    _d_layers.append(layers.LeakyReLU(alpha=0.2))
_d_layers.extend(
    [
        layers.Flatten(),
        layers.Dropout(0.2),
        layers.Dense(1, activation="sigmoid"),
    ]
)
discriminator = keras.Sequential(_d_layers, name="discriminator")
discriminator.summary()
Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 32, 32, 64)        3136
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 32, 32, 64)        0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 16, 16, 128)       131200
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 16, 16, 128)       0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 8, 8, 128)         262272
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 8, 8, 128)         0
_________________________________________________________________
flatten (Flatten)            (None, 8192)              0
_________________________________________________________________
dropout (Dropout)            (None, 8192)              0
_________________________________________________________________
dense (Dense)                (None, 1)                 8193
=================================================================
Total params: 404,801
Trainable params: 404,801
Non-trainable params: 0

3.2 G网络

# Size of the noise vector fed to the generator
latent_dim = 128

# Generator: project noise to an 8x8x128 map, upsample three times (8 -> 64),
# then map to 3 RGB channels with a sigmoid so pixels land in [0, 1]
_g_layers = [
    keras.Input(shape=(latent_dim,)),
    layers.Dense(8 * 8 * 128),
    layers.Reshape((8, 8, 128)),
]
for _filters in (128, 256, 512):
    _g_layers.append(
        layers.Conv2DTranspose(_filters, kernel_size=4, strides=2, padding="same")
    )
    _g_layers.append(layers.LeakyReLU(alpha=0.2))
_g_layers.append(layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"))
generator = keras.Sequential(_g_layers, name="generator")
generator.summary()

3.3 重写train_step

class GAN(keras.Model):
    """Train a generator and a discriminator against each other through ``fit``.

    Label convention used throughout ``train_step``: generated (fake) images
    are labelled 1 and real images are labelled 0, so the generator is trained
    towards label 0 ("looks real").
    """

    def __init__(self, discriminator, generator, latent_dim):
        """Store the two sub-models and the length of the generator's noise input."""
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Attach one optimizer per network, the shared loss, and loss trackers.

        Args:
            d_optimizer: optimizer applied to the discriminator's weights.
            g_optimizer: optimizer applied to the generator's weights.
            loss_fn: callable ``loss_fn(labels, predictions)`` returning the loss.
        """
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at the start of each epoch
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: a batch of real images; axis 0 is the batch size.

        Returns:
            Dict with the running means of ``d_loss`` and ``g_loss``.
        """
        batch_size = tf.shape(real_images)[0]

        # ---- Discriminator step ----
        # Sample noise and turn it into fake images
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        generated_images = self.generator(random_latent_vectors, training=True)
        # Stack fakes first, then reals, so the labels below line up with them
        combined_images = tf.concat([generated_images, real_images], axis=0)
        # Fake -> 1, real -> 0
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add a little random noise to the labels (a common GAN training trick)
        labels += 0.05 * tf.random.uniform(tf.shape(labels))
        with tf.GradientTape() as tape:
            # training=True is required, otherwise the discriminator's Dropout
            # layer is not applied during training
            predictions = self.discriminator(combined_images, training=True)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # ---- Generator step ----
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Label 0 means "real": reward the generator when its fakes are judged real
        misleading_labels = tf.zeros((batch_size, 1))
        # Only the generator's weights are passed to the optimizer below, so the
        # discriminator is not updated in this step
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_latent_vectors, training=True)
            predictions = self.discriminator(fake_images, training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Update the running-mean loss trackers
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }

3.4 设置回调函数

class GANMonitor(keras.callbacks.Callback):
    """Sample the generator at the end of each epoch, display and save the images.

    Args:
        num_img: number of images generated per epoch.
        latent_dim: length of the noise vector; must match the generator input.
        output_dir: folder for the PNG files (created if missing).
    """

    def __init__(self, num_img=3, latent_dim=128, output_dir="gen_ani"):
        super().__init__()  # let the Keras base class initialise its own state
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.output_dir = output_dir

    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors, training=False)
        # The generator ends in a sigmoid, so outputs are in [0, 1]; map to 0-255
        generated_images = (generated_images * 255).numpy()
        # Create the folder up front so a missing directory cannot abort training
        os.makedirs(self.output_dir, exist_ok=True)
        for i in range(self.num_img):
            # scale=False keeps the real pixel values; the default scale=True would
            # min-max stretch every image independently
            img = keras.preprocessing.image.array_to_img(generated_images[i], scale=False)
            display.display(img)
            img.save(
                os.path.join(self.output_dir, "generated_img_%03d_%d.png" % (epoch, i))
            )

四、训练模型

# In practice, use ~100 epochs
epochs = 100
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)

# Both networks use Adam with the same small learning rate; the discriminator
# outputs sigmoid probabilities, so plain binary cross-entropy is the loss
learning_rate = 0.0001
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
    g_optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

monitor = GANMonitor(num_img=10, latent_dim=latent_dim)
gan.fit(image_ds, epochs=epochs, callbacks=[monitor])

五、保存模型

# Save only the generator: it is all that is needed to sample new faces later
generator_dir = './data/ani_G_model'
gan.generator.save(generator_dir)

生成模型文件:点这里

六、生成漫画脸

# Reload the saved generator for inference only (no optimizer state needed)
G_model = tf.keras.models.load_model('./data/ani_G_model/', compile=False)


def randomGenerate(n_rows=4, n_cols=4, latent_dim=128):
    """Sample ``n_rows * n_cols`` faces from ``G_model`` and show them in a grid.

    Args:
        n_rows: number of grid rows (default 4).
        n_cols: number of grid columns (default 4).
        latent_dim: length of the noise vector; must match the generator's
            input size (128 for the generator built in this notebook).
    """
    noise_seed = tf.random.normal([n_rows * n_cols, latent_dim])
    predictions = G_model(noise_seed, training=False)
    # Generator output is in [0, 1]; convert the whole batch to 8-bit pixels once
    images = (predictions.numpy() * 255).astype('uint8')
    plt.figure(figsize=(8, 8))
    for i, img in enumerate(images):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(img)
        plt.axis('off')
    plt.show()
# Redraw a fresh grid of faces 102 times, replacing the previous output each time
# (the original loop checked `count > 100` after drawing, hence 102 frames)
for count in range(102):
    randomGenerate()
    clear_output(wait=True)
    time.sleep(0.1)