《动手学深度学习》-49Style_Transfer实现

复制代码
风格迁移:将图A的图片套上图B的风格,有点类似于滤镜。

代码关键

1.找到一个预训练大模型,参数不更新

2.选内容层和风格层,内容层数值比较大,风格层从低到高都要有

3.内容图经过内容层,得到目标内容,训练时让预训练的图片接近目标内容;风格图经过风格层,得到目标风格,然后再用gram得到最终目标风格,训练时让预训练的图片接近最终目标风格

4.让内容图作为与训练图,更新内容图的参数

5.根据内容、目标以及降噪损失三个loss,加权求和得到l,并对l做梯度,更新图像像素值。

复制代码
import torch
import torchvision
from torch import nn
from PIL import Image
from torchvision import transforms
import d2l
#读取图片
d2l.set_figsize()
content_img=Image.Open('')
d2l.plt.imshow(content_img)
style_img=Image.Open('')
d2l.plt.imshow(style_img)
#预处理和后处理,rgb图-tensor转化
rgb_mean=torch.tensor([0.485, 0.456, 0.406])
rgb_std=torch.tensor([0.229, 0.224, 0.225])
def prepocess(img,img_shape):
    transform = transforms.Compose([transforms.Resize(img_shape),
                                    transforms.ToTensor(),
                                    transforms.Normalize(rgb_mean,rgb_std)])
    return transform(img).unsqueeze(0)
def postpocess(img):
    img=img[0].to(rgb_std.device)
    img=torch.clamp(img.permute(1,2,0)*rgb_std+rgb_mean,0,1)
    return torchvision.transforms.ToPILImage()(img.permute(2,0,1))
#模型
#vgg对于抽取特征比较好
pretrained_net=torch.models.vgg19(pretrained=True)
style_layer,content_layer=[0,5,10,19,28],[25]
net=nn.Sequential(*[pretrained_net[i] for i in range(max(content_layer+style_layer+1))])
def extract_features(X,content_layer,style_layer):
    contents=[]
    styles=[]
    for i in range(len(net)):
        X=net[i](X)
        if i in style_layer:
            styles.append(X)
        elif i in content_layer:
            contents.append(X)
    return contents,styles
def get_contents(imge_shape,device):
    content_X=prepocess(content_img,imge_shape).to(device)
    contents_Y,_=extract_features(content_img,content_layer,style_layer)
    return content_X,contents_Y
def get_styles(img_shape,device):
    style_X=prepocess(style_img,img_shape).to(device)
    _,styles_Y=extract_features(style_img,content_layer,style_layer)
    return style_X,styles_Y
#定义损失
def content_loss(Y_hat,Y):
    return torch.square(Y_hat-Y.detach()).mean()
def gram(X):
    num_channels,n=X.shape[1],X.numel()//X.shape[1]
    X=X.reshape((num_channels,n))
    return torch.matmul(X.T,X)/(num_channels*n)
def style_loss(Y_hat,gram_Y):
    return torch.square(gram(Y_hat)-gram_Y.detach()).mean()
def tv_loss(Y_hat):#降噪损失
    return 0.5*(torch.abs(Y_hat[:,:,1:,:]-Y_hat[:,:,:-1,:]).mean()+torch.abs(Y_hat[:,:,:,1:]-Y_hat[:,:,:,:-1]).mean())
content_weight,style_weight,tv_weight=1,1e4,10
def compute_loss(X,contents_Y_hat,style_Y_hat,contents_Y,style_Y):
    """

    :rtype: tuple[list[Tensor], list[Tensor], Tensor, Union[int, Tensor, Literal[0], list[Tensor]]]
    """
    contents_l=[content_loss(Y_hat,Y)*content_weight for Y_hat,Y in zip(contents_Y_hat,contents_Y)]
    style_l=[style_loss(Y_hat,Y)*style_weight for Y_hat,Y in zip(style_Y_hat,style_Y)]
    tv_l=tv_loss(X)*tv_weight
    l=sum(contents_l,style_l,[tv_l])
    return contents_l,style_l,tv_l,l
# 初始化
class SynthesizedImage(nn.Module):
    def __init__(self,img_shape,**kwargs):
        super(SynthesizedImage,self).__init__(**kwargs)
        self.weight=nn.parameter(torch.random(*img_shape.shape))
    def forward(self):#模型没有输入,输出为可训练的图片参数
        return self.weight
def get_inits(X,device,lr,styles_Y):
    gen_img=SynthesizedImage(X.shape).to(device)
    gen_img.weight.data.copy_(X.data)
    trainer=torch.optim.Adam(gen_img.parameters(),lr=lr)
    styles_Y_gram=[gram(Y) for Y in styles_Y]#提前算出,省计算
    return gen_img,styles_Y_gram,trainer
def train(X,contents_Y,styles_Y,device,lr,num_epochs,lr_decay_epoch):
    X,styles_Y_gram,trainer=get_inits(X,device,lr,styles_Y)
    scheduler=torch.optim.lr_scheduler.StepLR(trainer,lr_decay_epoch,0.8)
    animator=d2l.Animator(xlabel='epoch',ylabel='loss',xlim=[10,num_epochs],legend=['content','style','tv'],ncols=2,figsize=[7,2.5])
    for epoch in range(num_epochs):
        trainer.zero_grad()
        contents_Y_hat,style_Y_hat=extract_features(X,content_layer,style_layer)
        contents_l, style_l, tv_l, l=compute_loss(X,contents_Y_hat,style_Y_hat,contents_Y,styles_Y_gram)
        l.backward()
        trainer.step()
        scheduler.step()
        if (epoch+1)%10==0:
            animator.axes[1].imshow(postpocess(X))
            animator.add(epoch+1,[float(sum(contents_l)),float(sum(style_l)),float(sum(tv_l))])
    return X
device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)
相关推荐
大江东去浪淘尽千古风流人物10 分钟前
【VLN】VLN(Vision-and-Language Navigation视觉语言导航)算法本质,范式难点及解决方向(1)
人工智能·python·算法
Swift社区12 分钟前
Gunicorn 与 Uvicorn 部署 Python 后端详解
开发语言·python·gunicorn
饭饭大王66614 分钟前
CANN 生态中的轻量化部署利器:`lite-inference` 项目实战解析
深度学习
Coinsheep16 分钟前
SSTI-flask靶场搭建及通关
python·flask·ssti
IT实战课堂小元酱16 分钟前
大数据深度学习|计算机毕设项目|计算机毕设答辩|flask露天矿爆破效果分析系统开发及应用
人工智能·python·flask
码农阿豪17 分钟前
Flask应用上下文问题解析与解决方案:从错误日志到完美修复
后端·python·flask
wqq631085520 分钟前
Python基于Vue的实验室管理系统 django flask pycharm
vue.js·python·django
Q_Q196328847521 分钟前
python大学生爱心校园互助代购网站_nyvlx_django Flask vue pycharm项目
python·django·flask
码农阿豪24 分钟前
Python Flask应用中文件处理与异常处理的实践指南
开发语言·python·flask
xcLeigh25 分钟前
Python 项目实战:用 Flask 实现 MySQL 数据库增删改查 API
数据库·python·mysql·flask·教程·python3