《动手学深度学习》-49Style_Transfer实现

复制代码
风格迁移:将图A的图片套上图B的风格,有点类似于滤镜。

代码关键

1.找到一个预训练大模型,参数不更新

2.选内容层和风格层,内容层数值比较大,风格层从低到高都要有

3.内容图经过内容层,得到目标内容,训练时让预训练的图片接近目标内容;风格图经过风格层,得到目标风格,然后再用gram得到最终目标风格,训练时让预训练的图片接近最终目标风格

4.让内容图作为与训练图,更新内容图的参数

5.根据内容、目标以及降噪损失三个loss,加权求和得到l,并对l做梯度,更新图像像素值。

复制代码
import torch
import torchvision
from torch import nn
from PIL import Image
from torchvision import transforms
import d2l
#读取图片
d2l.set_figsize()
content_img=Image.Open('')
d2l.plt.imshow(content_img)
style_img=Image.Open('')
d2l.plt.imshow(style_img)
#预处理和后处理,rgb图-tensor转化
rgb_mean=torch.tensor([0.485, 0.456, 0.406])
rgb_std=torch.tensor([0.229, 0.224, 0.225])
def prepocess(img,img_shape):
    transform = transforms.Compose([transforms.Resize(img_shape),
                                    transforms.ToTensor(),
                                    transforms.Normalize(rgb_mean,rgb_std)])
    return transform(img).unsqueeze(0)
def postpocess(img):
    img=img[0].to(rgb_std.device)
    img=torch.clamp(img.permute(1,2,0)*rgb_std+rgb_mean,0,1)
    return torchvision.transforms.ToPILImage()(img.permute(2,0,1))
#模型
#vgg对于抽取特征比较好
pretrained_net=torch.models.vgg19(pretrained=True)
style_layer,content_layer=[0,5,10,19,28],[25]
net=nn.Sequential(*[pretrained_net[i] for i in range(max(content_layer+style_layer+1))])
def extract_features(X,content_layer,style_layer):
    contents=[]
    styles=[]
    for i in range(len(net)):
        X=net[i](X)
        if i in style_layer:
            styles.append(X)
        elif i in content_layer:
            contents.append(X)
    return contents,styles
def get_contents(imge_shape,device):
    content_X=prepocess(content_img,imge_shape).to(device)
    contents_Y,_=extract_features(content_img,content_layer,style_layer)
    return content_X,contents_Y
def get_styles(img_shape,device):
    style_X=prepocess(style_img,img_shape).to(device)
    _,styles_Y=extract_features(style_img,content_layer,style_layer)
    return style_X,styles_Y
#定义损失
def content_loss(Y_hat,Y):
    return torch.square(Y_hat-Y.detach()).mean()
def gram(X):
    num_channels,n=X.shape[1],X.numel()//X.shape[1]
    X=X.reshape((num_channels,n))
    return torch.matmul(X.T,X)/(num_channels*n)
def style_loss(Y_hat,gram_Y):
    return torch.square(gram(Y_hat)-gram_Y.detach()).mean()
def tv_loss(Y_hat):#降噪损失
    return 0.5*(torch.abs(Y_hat[:,:,1:,:]-Y_hat[:,:,:-1,:]).mean()+torch.abs(Y_hat[:,:,:,1:]-Y_hat[:,:,:,:-1]).mean())
content_weight,style_weight,tv_weight=1,1e4,10
def compute_loss(X,contents_Y_hat,style_Y_hat,contents_Y,style_Y):
    """

    :rtype: tuple[list[Tensor], list[Tensor], Tensor, Union[int, Tensor, Literal[0], list[Tensor]]]
    """
    contents_l=[content_loss(Y_hat,Y)*content_weight for Y_hat,Y in zip(contents_Y_hat,contents_Y)]
    style_l=[style_loss(Y_hat,Y)*style_weight for Y_hat,Y in zip(style_Y_hat,style_Y)]
    tv_l=tv_loss(X)*tv_weight
    l=sum(contents_l,style_l,[tv_l])
    return contents_l,style_l,tv_l,l
# 初始化
class SynthesizedImage(nn.Module):
    def __init__(self,img_shape,**kwargs):
        super(SynthesizedImage,self).__init__(**kwargs)
        self.weight=nn.parameter(torch.random(*img_shape.shape))
    def forward(self):#模型没有输入,输出为可训练的图片参数
        return self.weight
def get_inits(X,device,lr,styles_Y):
    gen_img=SynthesizedImage(X.shape).to(device)
    gen_img.weight.data.copy_(X.data)
    trainer=torch.optim.Adam(gen_img.parameters(),lr=lr)
    styles_Y_gram=[gram(Y) for Y in styles_Y]#提前算出,省计算
    return gen_img,styles_Y_gram,trainer
def train(X,contents_Y,styles_Y,device,lr,num_epochs,lr_decay_epoch):
    X,styles_Y_gram,trainer=get_inits(X,device,lr,styles_Y)
    scheduler=torch.optim.lr_scheduler.StepLR(trainer,lr_decay_epoch,0.8)
    animator=d2l.Animator(xlabel='epoch',ylabel='loss',xlim=[10,num_epochs],legend=['content','style','tv'],ncols=2,figsize=[7,2.5])
    for epoch in range(num_epochs):
        trainer.zero_grad()
        contents_Y_hat,style_Y_hat=extract_features(X,content_layer,style_layer)
        contents_l, style_l, tv_l, l=compute_loss(X,contents_Y_hat,style_Y_hat,contents_Y,styles_Y_gram)
        l.backward()
        trainer.step()
        scheduler.step()
        if (epoch+1)%10==0:
            animator.axes[1].imshow(postpocess(X))
            animator.add(epoch+1,[float(sum(contents_l)),float(sum(style_l)),float(sum(tv_l))])
    return X
device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)
相关推荐
程序媛徐师姐13 小时前
Python基于深度学习的声音识别青少年防沉迷系统【附源码、文档说明】
python·深度学习·声音识别青少年防沉迷系统·python声音识别·python青少年防沉迷系统·python深度学习声音识别·青少年防沉迷系统
月光有害13 小时前
简单理解深度学习中的多种归一化方法
人工智能·深度学习
汤姆yu13 小时前
基于python大数据的天气可视化及预测系统
大数据·开发语言·python
huohuopro13 小时前
只能录入不能粘贴?这里解决
python
代码探秘者13 小时前
【算法篇】2.滑动窗口
java·数据结构·后端·python·算法·spring
进击的雷神13 小时前
展位号后缀清理、详情页JS数据提取、重试机制控制、地址字段重构——美国NPE展爬虫四大技术难关攻克纪实
javascript·爬虫·python·重构
这张生成的图像能检测吗13 小时前
(论文速读)ASFRMT:基于对抗的超特征重构元传递网络弱特征增强与谐波传动故障诊断
人工智能·深度学习·计算机视觉·故障诊断
oem11013 小时前
Django全栈开发入门:构建一个博客系统
jvm·数据库·python
社畜码农且逊13 小时前
安装虚拟环境工具virtualenv,通过virtualenv myenv创建虚拟环境,激活虚拟环境后安装项目依赖
python·virtualenv
前端摸鱼匠13 小时前
面试题7:Encoder-only、Decoder-only、Encoder-Decoder三种架构的差异与适用场景?
人工智能·深度学习·ai·面试·职场和发展·架构·transformer