目录
- 前言
- [style transfer原理](#style transfer原理)
- [style transfer代码](#style transfer代码)
- [fast style transfer 代码](#fast style transfer 代码)
前言
本篇来带大家看看VGG的实战篇,这次来带大家看看计算机视觉中一个有趣的小任务,图像风格迁移。
可运行代码位于: Style_transfer (可直接下载运行)
我们所进行的图像风格迁移是基于VGG实现的哦,如果对VGG网络还不太了解的,可以先去看看这边文章:【深度学习基础】:VGG原理篇 - carpell - 博客园
style transfer原理
原理解析
图像风格迁移是一种通过将一张图像的风格应用到另一张图像的内容上,从而生成一张新图像的技术。这种技术通常基于深度学习模型,特别是卷积神经网络(CNN)。
实现步骤:
- 首先,我们需要两张图像:内容图像和风格图像。内容图像是我们想要保持内容的图像,而风格图像是我们想要应用到内容图像上的风格的图像。接下来,我们通过使用预训练的卷积神经网络(如VGG19)来提取图像的特征。
- 在提取特征之后,我们计算内容损失和风格损失。内容损失度量生成图像与内容图像在内容上的差异,通常通过比较两个图像在浅层特征上的差异来计算。风格损失度量生成图像与风格图像在风格上的差异,风格通常通过计算图像特征之间的相关性来表示,即风格图像的特征图之间的gram矩阵。(这个等会会详细讲述)
其实图像风格迁移的核心是一个优化问题,目标是生成一张图像,使得内容损失和风格损失都最小化。我们初始化一张随机生成的图像。然后,我们使用卷积神经网络提取这张图像的特征,并计算内容损失和风格损失。接着,我们计算总损失,并进行反向传播和优化,以更新生成图像的像素值。这个过程会重复进行多次,直到生成图像在内容和风格上都接近目标图像。
损失函数
好,其实大概的原理基本上都清楚了,但是我们来看,我们是该如何去保证图片的内容是我们所选的内容图片,但是风格确实我们所选的风格图片呢?
首先来看内容函数的损失函数,其是这个还是很好理解的,我们所使用的就是平方差损失函数,我们希望我们的生成图片能够和内容图片具有相同的内容,所以最小化两者的差距即可。但是这里有个细节,我们该选择那个层的输出作为我们的标准呢?其实看最上面的图也能看出,我们的网络越深,提取到的内容图片损失的细节就会越多,所以我们会更多的选择浅层的内容图片特征作为标准。

python
def calculate_content_loss(original_feat, generated_feat) -> torch.Tensor:
"""计算内容损失,即生成特征图与标准特征图的规范化误差平方和"""
b, c, h, w = original_feat.shape
x = 2. * c * h * w # 规范化系数
return torch.sum((generated_feat - original_feat) ** 2) / x
然后我们在来看风格损失函数,其实风格损失函数的关键在于gram矩阵。那么我们回到最初,我们为什么能够提取出来图像的风格特征呢?其实就是依靠Gram矩阵了,其用于捕捉图像中不同通道之间的相关性,这些相关性可以反映出图像的纹理、颜色分布等风格特征。通俗说就是出现尖尖的形状的时候边上应该出现些什么。

然后风格损失函数如下,同样的我们会选择深层的风格图像特征作为我们的标准,因为深层网络提取出来的更多的是抽象的语义信息,能够代表其本质的特征,抽象的特征。

python
def gram_matrix(feature):
"""
计算特征图的格莱姆矩阵,作为风格特征表示
:param feature: 输入特征图
:return: 格莱姆矩阵
"""
b, c, h, w = feature.size()
feature = feature.view(b, c, h * w) # 拉平空间维度
G = torch.bmm(feature, feature.transpose(1, 2))
return G
def calculate_style_loss(style_feat, generated_feat) -> torch.Tensor:
"""计算风格损失,即生成特征图与标准特征图的格拉姆矩阵的规范化误差平方和"""
b, c, h, w = style_feat.shape
G = gram_matrix(generated_feat)
A = gram_matrix(style_feat)
x = 4. * ((h * w) ** 2) * (c ** 2) # 规范化系数
return torch.sum((G - A) ** 2) / x
所以最终的损失函数就是风格损失和内容损失的和,其中会乘上不同的系数。

style transfer代码
这里只展现了部分的核心代码,所有代码位于Style_transfer ,下载可直接使用。
我们通过处理目标图片,其一开始就是一张白噪声图,然后通过网络损失不断修改其像素,使得目标图片内容与内容图片相近,风格与风格图片相近。

python
import argparse
import os
import torch
import torch.optim as optim
from tqdm import tqdm
from models import VGG
from utils import make_transform, load_image, save_image, calculate_style_loss, calculate_content_loss
def parse_args():
parser = argparse.ArgumentParser(description='Style Transfer=')
parser.add_argument('--content_image', type=str, default='./data/content1.jpg', help='Path to the content image')
parser.add_argument('--style_image', type=str, default='./data/style2.jpg', help='Path to the style image')
parser.add_argument('--output_dir', type=str, default='./output/iterative_style_transfer', help='Output directory')
parser.add_argument('--image_size', type=int, nargs=2, default=[300, 450], help='Image size (height, width)')
parser.add_argument('--content_weight', type=float, default=1, help='Content weight')
parser.add_argument('--style_weight', type=float, default=15, help='Style weight')
parser.add_argument('--epochs', type=int, default=20, help='Number of training epochs')
parser.add_argument('--steps_per_epoch', type=int, default=100, help='Number of steps per epoch')
parser.add_argument('--learning_rate', type=float, default=0.03, help='Learning rate')
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
# 检查文件路径
assert os.path.exists(args.content_image), f"content image is not exist: {args.content_image}"
assert os.path.exists(args.style_image), f"style image is not exist: {args.style_image}"
# 内容特征层及loss加权系数
content_layers = {'5': 0.5, '10': 0.5}
# 风格特征层及loss加权系数
style_layers = {'0': 0.2, '5': 0.2, '10': 0.2, '19': 0.2, '28': 0.2}
# ----------------训练即推理过程----------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = make_transform(args.image_size, normalize=True)
content_img = load_image(args.content_image, transform).to(device)
style_img = load_image(args.style_image, transform).to(device)
# 用小随机噪声初始化图像(避免标准正态分布过大震荡)
generated_img = torch.rand_like(content_img).mul(0.1).requires_grad_().to(device)
save_image(generated_img, args.output_dir, 'noise_init.jpg')
vgg_model = VGG(content_layers, style_layers).to(device).eval()
# 关闭梯度追踪,节省显存
with torch.no_grad():
content_features, _ = vgg_model(content_img)
_, style_features = vgg_model(style_img)
optimizer = optim.Adam([generated_img], lr=args.learning_rate)
for epoch in range(args.epochs):
p_bar = tqdm(range(args.steps_per_epoch), desc=f'Epoch {epoch + 1}/{args.epochs}')
for step in p_bar:
generated_content, generated_style = vgg_model(generated_img)
content_loss = sum(
args.content_weight * content_layers[name] * calculate_content_loss(content_features[name], gen_content)
for name, gen_content in generated_content.items()
)
style_loss = sum(
args.style_weight * style_layers[name] * calculate_style_loss(style_features[name], gen_style)
for name, gen_style in generated_style.items()
)
total_loss = style_loss + content_loss
optimizer.zero_grad()
total_loss.backward()
# 梯度裁剪(可选但稳健)
torch.nn.utils.clip_grad_norm_([generated_img], max_norm=1.0)
optimizer.step()
p_bar.set_postfix(style_loss=style_loss.item(), content_loss=content_loss.item(), total_loss=total_loss.item())
# 保存中间结果
save_image(generated_img, args.output_dir, f'generated_epoch_{epoch + 1}.jpg')
效果图

fast style transfer 代码
刚才我们看了基于迭代的图像风格迁移,那么能不能有个方法,我们训练一个专用于一种风格的图像迁移的网络,其网络训练好之后,我们就不再需要老是输入风格图像,这个训练好的网络可以处理所有输入的图像输出其已经训练好的风格。所以fast style transfer 来了。其不仅可以处理图片,同样可以处理视频哦!

这里只展现了部分的核心代码,所有代码位于Style_transfer ,下载可直接使用。
python
import argparse
import os
from datetime import datetime, timedelta
from typing import List, Iterable
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import cv2
import numpy as np
from models import VGG, TransNet
from datasets import ImageDataset
from utils import load_image, save_image, make_transform, save_model, calculate_style_loss, calculate_content_loss, \
denormalize
def parse_args():
parser = argparse.ArgumentParser(description='Fast Style Transfer')
parser.add_argument('--mode', type=str, choices=['train', 'image', 'video'], default='train',
help='Operation mode: train, image, video')
parser.add_argument('--image_size', type=int, nargs=2, default=[300, 450], help='Image size (height, width)')
parser.add_argument('--style_image', type=str, default='./data/style3.jpg', help='Style image path (training mode only)')
# The directory data/train2017/default_class contains several images. Due to the format required by ImageFolder, we need to use a class folder to contain the images, even though the class label is not used.
parser.add_argument('--content_dataset', type=str, default='data/train2014', help='Content image dataset path (training mode only)')
parser.add_argument('--content_weight', type=float, default=1., help='Content weight (training mode only)')
parser.add_argument('--style_weight', type=float, default=15., help='Style weight (training mode only)')
parser.add_argument('--model_save_path', type=str, default=f'./checkpoint/style2.pth', help='Model save path (training mode)')
parser.add_argument('--pretrained_model_path', type=str, help='Pretrained model load path (training mode only)')
parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs (training mode)')
parser.add_argument('--save_interval', type=int, default=120, help='Save interval (seconds) (training mode)')
parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate (training mode)')
parser.add_argument('--output_dir', type=str, default='./output/realtime_transfer', help='Output directory (training mode)')
parser.add_argument('--model_path', type=str, default='./checkpoint/style1.pth', help='Model load path (image and video mode)')
parser.add_argument('--input_images_dir', type=str, help='Input images root directory (image mode only)')
parser.add_argument('--output_images_dir', type=str, default='./output/fast_style_transfer/image_generated.jpg',
help='Output images root directory (image mode only)')
parser.add_argument('--video_input', type=str, default='data/maigua.mp4', help='Input video path (video mode only)')
parser.add_argument('--video_output', type=str, default='output/videos/maigua.mp4',
help='Output video path (video mode only)')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size')
return parser.parse_args()
def train(model,
vgg,
lr, epochs, batch_size, style_weight, content_weight, style_layers, content_layers,
device, transform,
image_style,
content_dataset_root,
save_path, output_dir,
save_interval=timedelta(seconds=120)):
optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 对目标网络进行优化
dataset = ImageDataset(content_dataset_root, transform=transform)
dataloader = DataLoader(dataset, batch_size, shuffle=True)
_, style_features = vgg(image_style)
p_bar = tqdm(range(epochs))
last_save_time = datetime.now() - save_interval
for epoch in p_bar:
running_content_loss, running_style_loss = 0.0, 0.0
for i, content_img in enumerate(dataloader):
content_img = content_img.to(device)
image_generated = model(content_img) # 只使用内容图像进行风格迁移
generated_content, generated_style = vgg(image_generated)
style_loss = sum(
style_weight * style_layers[name] * calculate_style_loss(style_features[name], gen_style) for
name, gen_style in generated_style.items())
content_features, _ = vgg(content_img) # 计算内容图的内容特征
content_loss = sum(
content_weight * content_layers[name] * calculate_content_loss(content_features[name], gen_content) for
name, gen_content in generated_content.items())
total_loss = style_loss + content_loss
optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1) # 梯度裁剪
optimizer.step()
running_content_loss += content_loss.item()
running_style_loss += style_loss.item()
p_bar.set_postfix(progress=f'{(i + 1) / len(dataloader) * 100:.3f}%',
style_loss=f"{style_loss.item():.3f}",
content_loss=f"{content_loss.item():.3f}",
last_save_time=last_save_time)
if datetime.now() - last_save_time > save_interval:
last_save_time = datetime.now()
#writer.add_images('image_generated', denormalize(image_generated), epoch * len(dataloader) + i)
save_model(model, save_path) # 'fast_style_transfer.pth'
save_image(torch.cat((image_generated, content_img), 3), output_dir, f'{epoch}_{i}.jpg')
def process_images(images: Iterable[Image.Image], transform, model, device) -> List[Image.Image]:
images = torch.stack([transform(image) for image in images]).to(device)
model.to(device)
batch_generated = model(images)
batch_generated = denormalize(batch_generated).detach().cpu()
batch_generated = [transforms.ToPILImage()(image) for image in batch_generated]
return batch_generated
def process_video(video_path, output_path, transform, model, device, batch_size=4):
# 打开视频文件
cap = cv2.VideoCapture(video_path)
output_dir, filename = os.path.split(output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
# 获取视频属性
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# 定义视频编码器和创建 VideoWriter 对象
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
# 初始化 tqdm 进度条
pbar = tqdm(total=total_frames, desc="Processing Video")
# 读取视频并批量处理
frames = []
while True:
ret, frame = cap.read()
if not ret:
break
frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
if len(frames) == batch_size:
batch_generated = process_images(frames, transform, model, device)
for gen_frame in batch_generated:
gen = cv2.cvtColor(np.array(gen_frame), cv2.COLOR_RGB2BGR)
cv2.imshow("output", gen)
gen = cv2.resize(gen, (frame_width, frame_height))
out.write(gen)
if cv2.waitKey(1) & 0xFF == ord('s'):
break
frames.clear()
pbar.update(batch_size)
if frames:
batch_generated = process_images(frames, transform, model, device)
for gen_frame in batch_generated:
out.write(cv2.cvtColor(np.array(gen_frame), cv2.COLOR_RGB2BGR))
pbar.update(len(frames))
print(f'video successfully saved to: {output_path}')
pbar.close()
cap.release()
out.release()
if __name__ == '__main__':
args = parse_args()
# print(args)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ----------------路径参数----------------
# 内容特征层及loss加权系数
content_layers = {'5': 0.5, '10': 0.5} # 使用vgg的较浅层特征作为内容特征,保证生成图片内容结构相似性
# 风格特征层及loss加权系数
style_layers = {'0': 0.2, '5': 0.2, '10': 0.2, '19': 0.2, '28': 0.2} # 使用vgg不同深度的风格特征,生成风格更加层次丰富
transform = make_transform(size=args.image_size, normalize=True) # 图像变换
image_style = load_image(args.style_image, transform=transform).to(device) # 风格图像
vgg = VGG(content_layers, style_layers).to(device) # 特征提取网络,只用来提取特征,不进行训练
model = TransNet(input_size=args.image_size).to(device) # 内容生成网络,用于生成风格图片,进行训练
if args.mode != 'train' and getattr(args, 'model_path'):
if not os.path.exists(args.model_path):
raise FileNotFoundError(f'{args.model_path}不存在!')
model.load_state_dict(torch.load(args.model_path))
elif args.mode == 'train' and getattr(args, 'pretrained_model_path') and os.path.exists(args.pretrained_model_path):
if not os.path.exists(args.pretrained_model_path):
raise FileNotFoundError(f'{args.pretrained_model_path}不存在!')
model.load_state_dict(torch.load(args.pretrained_model_path))
if args.mode == 'train':
# 训练模式
# 使用大规模内容图像数据训练快速图像风格迁移网络,比如COCO2017数据集
train(model, vgg, args.learning_rate, args.epochs, args.batch_size, args.style_weight, args.content_weight,
style_layers,
content_layers, device, transform, image_style, args.content_dataset, args.model_save_path,
args.output_dir,
save_interval=timedelta(seconds=args.save_interval))
elif args.mode == 'image':
# 使用训练好的风格迁移模型演示批量处理图片
if not os.path.exists(args.output_images_dir):
os.makedirs(args.output_images_dir)
for filename in tqdm(os.listdir(args.input_images_dir), desc='Processing Images'):
try:
filepath = os.path.join(args.input_images_dir, filename)
images_generated = process_images([Image.open(filepath)], transform, model, device)
images_generated[0].save(os.path.join(args.output_images_dir, filename))
except Exception as e:
pass
elif args.mode == 'video':
# 视频处理模式
process_video(args.video_input, args.video_output, transform, model, device,
batch_size=args.batch_size)
else:
raise ValueError("未知的运行模式")
效果图
