风格迁移生成图片

前言

本文详细介绍了风格迁移图像生成的全部过程,模型方面主要是使用了现成的 VGG19 模型,效果符合预期。

数据处理

我们使用的基础图片如下

我们使用的风格图片如下

实现了两个函数 preprocess_image 和 deprocess_image 用来处理图像。

  1. preprocess_image: 这个函数负责对图像进行预处理。它接受一个图像路径作为输入,并返回一个预处理后的图像张量。具体步骤如下:

    • 使用 keras.utils.load_img 加载图像,并将其调整为指定的目标大小 (img_nrows, img_ncols)
    • 使用 keras.utils.img_to_array 将图像转换为 NumPy 数组。
    • 使用 np.expand_dims 在数组的第一个维度上添加一个维度,将其变成形状为 (1, img_nrows, img_ncols, 3) 的张量,其中 3 表示图像通道数。
    • 使用 vgg19.preprocess_input 函数对图像进行 VGG19 模型的预处理,例如将像素值归一化到特定的范围。
  2. deprocess_image: 这个函数用于将经过处理的图像张量恢复为可视化图像。它接受一个图像张量作为输入,并返回一个经过反向处理的图像数组。具体步骤如下:

    • 将图像张量 x 重塑为形状为 (img_nrows, img_ncols, 3) 的张量。
    • 执行逆操作来撤消预处理步骤,例如对图像进行的平均值偏移操作。这些值在训练 VGG 模型时用于归一化图像。
    • 将图像从 BGR 格式转换为 RGB 格式,这通常是由于不同库或模型使用不同的通道顺序。
    • 使用 np.clip 将图像像素值限制在 0 到 255 之间,并将其转换为 uint8 类型,以确保图像值在正确的范围内。
ini 复制代码
def preprocess_image(image_path):
    img = keras.utils.load_img(image_path, target_size=(img_nrows, img_ncols))
    img = keras.utils.img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = vgg19.preprocess_input(img)
    return tf.convert_to_tensor(img)
def deprocess_image(x):
    x = x.reshape((img_nrows, img_ncols, 3))
    x[:, :, 0] += 103.939
    x[:, :, 1] += 116.779
    x[:, :, 2] += 123.68
    x = x[:, :, ::-1]
    x = np.clip(x, 0, 255).astype("uint8")
    return x

损失计算

我们定义了几个函数来计算模型的损失值,主要有以下几个部分:

  • gram_matrix :用于计算给定特征张量 x 的格拉姆矩阵,格拉姆矩阵是用来捕捉特征之间相关性的一种方式
  • style_loss :使用风格图像的格拉姆矩阵 S 和合成图像的格拉姆矩阵 C,计算两个格拉姆矩阵之间的均方差,作为风格损失的度量,将最后的结果通过归一化系数对损失进行缩放。
  • content_loss :计算合成图像与内容图像之间的像素级差异(均方差),这种损失直接反映了合成图像与内容图像之间的内容差异。
  • total_variation_loss :计算图像 x 在水平和垂直方向上相邻像素之间的差异的平方,对这些差异进行幂次处理以增强平滑度,对增强后的差异进行求和,作为总变差损失的度量。
  • compute_loss:用于计算图像风格迁移中的总损失,这个总损失由内容损失、风格损失和总变差损失三部分组成。
ini 复制代码
def gram_matrix(x):  # [400, 225, 64]
    x = tf.transpose(x, (2, 0, 1))  # [64, 400, 225,]
    features = tf.reshape(x, (tf.shape(x)[0], -1))  # [64, 90000]
    gram = tf.matmul(features, tf.transpose(features))  # [64, 64]
    return gram

def style_loss(style, combination):
    S = gram_matrix(style)
    C = gram_matrix(combination)
    channels = 3
    size = img_nrows * img_ncols
    return tf.reduce_sum(tf.square(S - C)) / (4. * (channels ** 2) * (size ** 2))

def content_loss(base, combination):
    return tf.reduce_sum(tf.square(combination - base))

def total_variation_loss(x):  # [1,400,225,3]
    a = tf.square(x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, 1:, : img_ncols - 1, :])  # [1,399,224,3]
    b = tf.square(x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, : img_nrows - 1, 1:, :])  # [1,399,224,3]
    return tf.reduce_sum(tf.pow(a + b, 1.25))

def compute_loss(combination_image, base_image, style_reference_image):
    input_tensor = tf.concat([base_image, style_reference_image, combination_image], axis=0)
    features = feature_extractor(input_tensor)
    loss = tf.zeros(shape=())
    layer_features = features[content_layer_name]
    base_image_features = layer_features[0, :, :, :]
    combination_features = layer_features[2, :, :, :]
    loss = loss + content_weight * content_loss(base_image_features, combination_features)
    for layer_name in style_layer_names:
        layer_features = features[layer_name]  # [3, 400, 255 ,64]
        style_reference_features = layer_features[1, :, :, :]  # [400, 255 ,64]
        combination_features = layer_features[2, :, :, :]  # [400, 255 ,64]
        sl = style_loss(style_reference_features, combination_features)
        loss += (style_weight / len(style_layer_names)) * sl
    loss += total_variation_weight * total_variation_loss(combination_image)
    return loss

模型训练

  1. 模型加载与特征提取器创建 : 使用 VGG19 模型作为特征提取器,通过遍历模型的层,创建一个字典,其中键是层的名称,值是层的输出张量,使用 keras.Model 创建一个新的模型,该模型接受原始 VGG19 模型的输入,并输出特定层的特征张量。

  2. 损失函数与梯度计算函数compute_loss_and_grads 函数计算给定组合图像、基础图像和风格参考图像的总损失,并计算相对于组合图像的梯度。并采用了 SGD 优化器,对由损失计算出的梯度进行优化。

  3. 迭代优化: 对基础图像和风格参考图像进行预处理,以适应 VGG19 模型的输入要求,将基础图像作为初始的组合图像。通过 10000 次迭代优化过程,不断更新组合图像以最小化损失,在每次迭代中计算损失和梯度,并应用梯度更新组合图像,每 100 次迭代输出一次损失值,并保存当前的组合图像。

ini 复制代码
model = vgg19.VGG19(weights="imagenet", include_top=False)
output_dict = dict([(layer.name, layer.output) for layer in model.layers])
feature_extractor = keras.Model(inputs=model.inputs, outputs=output_dict)
style_layer_names = ["block1_conv1", "block2_conv1", "block3_conv1", "block4_conv1", "block5_conv1", ]
content_layer_name = "block5_conv2"
def compute_loss_and_grads(combination_image, base_image, style_reference_image):
    with tf.GradientTape() as tape:
        loss = compute_loss(combination_image, base_image, style_reference_image)
    grads = tape.gradient(loss, combination_image)
    return loss, grads
optimizer = keras.optimizers.SGD(keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=100.0, decay_steps=100, decay_rate=0.96))

base_image = preprocess_image(base_image_path)
style_reference_image = preprocess_image(style_reference_image_path)
combination_image = tf.Variable(preprocess_image(base_image_path))
iterations = 10000
for i in range(1, iterations + 1):
    loss, grads = compute_loss_and_grads(combination_image, base_image, style_reference_image)
    optimizer.apply_gradients([(grads, combination_image)])
    if i % 100 == 0:
        print("Iteration %d: loss=%.2f" % (i, loss))
        img = deprocess_image(combination_image.numpy())
        fname = result_prefix + "_at_iteration_%d.png" % i
        keras.utils.save_img(fname, img)

训练过程消耗 9G 左右显存,日志打印:

ini 复制代码
Iteration 100: loss=8640.24
Iteration 200: loss=7211.42
Iteration 300: loss=6196.27
...
Iteration 9800: loss=3071.73
Iteration 9900: loss=3071.40
Iteration 10000: loss=3071.09

结果展示,风格发生了变化,生成了毕加索风格的图片。

参考

github.com/wangdayaya/...

相关推荐
HPC_fac1305206781638 分钟前
以科学计算为切入点:剖析英伟达服务器过热难题
服务器·人工智能·深度学习·机器学习·计算机视觉·数据挖掘·gpu算力
小陈phd3 小时前
OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
人工智能·opencv·计算机视觉
Guofu_Liao4 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
ZHOU_WUYI8 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt
如若1238 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
老艾的AI世界9 小时前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲
DK221519 小时前
机器学习系列----关联分析
人工智能·机器学习
Robot2519 小时前
Figure 02迎重大升级!!人形机器人独角兽[Figure AI]商业化加速
人工智能·机器人·微信公众平台
浊酒南街10 小时前
Statsmodels之OLS回归
人工智能·数据挖掘·回归
畅联云平台10 小时前
美畅物联丨智能分析,安全管控:视频汇聚平台助力智慧工地建设
人工智能·物联网