《动手学深度学习》-49Style_Transfer实现

复制代码
风格迁移:将图A的图片套上图B的风格,有点类似于滤镜。

代码关键

1.找到一个预训练大模型,参数不更新

2.选内容层和风格层,内容层数值比较大,风格层从低到高都要有

3.内容图经过内容层,得到目标内容,训练时让预训练的图片接近目标内容;风格图经过风格层,得到目标风格,然后再用gram得到最终目标风格,训练时让预训练的图片接近最终目标风格

4.让内容图作为与训练图,更新内容图的参数

5.根据内容、目标以及降噪损失三个loss,加权求和得到l,并对l做梯度,更新图像像素值。

复制代码
import torch
import torchvision
from torch import nn
from PIL import Image
from torchvision import transforms
import d2l
#读取图片
d2l.set_figsize()
content_img=Image.Open('')
d2l.plt.imshow(content_img)
style_img=Image.Open('')
d2l.plt.imshow(style_img)
#预处理和后处理,rgb图-tensor转化
rgb_mean=torch.tensor([0.485, 0.456, 0.406])
rgb_std=torch.tensor([0.229, 0.224, 0.225])
def prepocess(img,img_shape):
    transform = transforms.Compose([transforms.Resize(img_shape),
                                    transforms.ToTensor(),
                                    transforms.Normalize(rgb_mean,rgb_std)])
    return transform(img).unsqueeze(0)
def postpocess(img):
    img=img[0].to(rgb_std.device)
    img=torch.clamp(img.permute(1,2,0)*rgb_std+rgb_mean,0,1)
    return torchvision.transforms.ToPILImage()(img.permute(2,0,1))
#模型
#vgg对于抽取特征比较好
pretrained_net=torch.models.vgg19(pretrained=True)
style_layer,content_layer=[0,5,10,19,28],[25]
net=nn.Sequential(*[pretrained_net[i] for i in range(max(content_layer+style_layer+1))])
def extract_features(X,content_layer,style_layer):
    contents=[]
    styles=[]
    for i in range(len(net)):
        X=net[i](X)
        if i in style_layer:
            styles.append(X)
        elif i in content_layer:
            contents.append(X)
    return contents,styles
def get_contents(imge_shape,device):
    content_X=prepocess(content_img,imge_shape).to(device)
    contents_Y,_=extract_features(content_img,content_layer,style_layer)
    return content_X,contents_Y
def get_styles(img_shape,device):
    style_X=prepocess(style_img,img_shape).to(device)
    _,styles_Y=extract_features(style_img,content_layer,style_layer)
    return style_X,styles_Y
#定义损失
def content_loss(Y_hat,Y):
    return torch.square(Y_hat-Y.detach()).mean()
def gram(X):
    num_channels,n=X.shape[1],X.numel()//X.shape[1]
    X=X.reshape((num_channels,n))
    return torch.matmul(X.T,X)/(num_channels*n)
def style_loss(Y_hat,gram_Y):
    return torch.square(gram(Y_hat)-gram_Y.detach()).mean()
def tv_loss(Y_hat):#降噪损失
    return 0.5*(torch.abs(Y_hat[:,:,1:,:]-Y_hat[:,:,:-1,:]).mean()+torch.abs(Y_hat[:,:,:,1:]-Y_hat[:,:,:,:-1]).mean())
content_weight,style_weight,tv_weight=1,1e4,10
def compute_loss(X,contents_Y_hat,style_Y_hat,contents_Y,style_Y):
    """

    :rtype: tuple[list[Tensor], list[Tensor], Tensor, Union[int, Tensor, Literal[0], list[Tensor]]]
    """
    contents_l=[content_loss(Y_hat,Y)*content_weight for Y_hat,Y in zip(contents_Y_hat,contents_Y)]
    style_l=[style_loss(Y_hat,Y)*style_weight for Y_hat,Y in zip(style_Y_hat,style_Y)]
    tv_l=tv_loss(X)*tv_weight
    l=sum(contents_l,style_l,[tv_l])
    return contents_l,style_l,tv_l,l
# 初始化
class SynthesizedImage(nn.Module):
    def __init__(self,img_shape,**kwargs):
        super(SynthesizedImage,self).__init__(**kwargs)
        self.weight=nn.parameter(torch.random(*img_shape.shape))
    def forward(self):#模型没有输入,输出为可训练的图片参数
        return self.weight
def get_inits(X,device,lr,styles_Y):
    gen_img=SynthesizedImage(X.shape).to(device)
    gen_img.weight.data.copy_(X.data)
    trainer=torch.optim.Adam(gen_img.parameters(),lr=lr)
    styles_Y_gram=[gram(Y) for Y in styles_Y]#提前算出,省计算
    return gen_img,styles_Y_gram,trainer
def train(X,contents_Y,styles_Y,device,lr,num_epochs,lr_decay_epoch):
    X,styles_Y_gram,trainer=get_inits(X,device,lr,styles_Y)
    scheduler=torch.optim.lr_scheduler.StepLR(trainer,lr_decay_epoch,0.8)
    animator=d2l.Animator(xlabel='epoch',ylabel='loss',xlim=[10,num_epochs],legend=['content','style','tv'],ncols=2,figsize=[7,2.5])
    for epoch in range(num_epochs):
        trainer.zero_grad()
        contents_Y_hat,style_Y_hat=extract_features(X,content_layer,style_layer)
        contents_l, style_l, tv_l, l=compute_loss(X,contents_Y_hat,style_Y_hat,contents_Y,styles_Y_gram)
        l.backward()
        trainer.step()
        scheduler.step()
        if (epoch+1)%10==0:
            animator.axes[1].imshow(postpocess(X))
            animator.add(epoch+1,[float(sum(contents_l)),float(sum(style_l)),float(sum(tv_l))])
    return X
device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)
相关推荐
闻缺陷则喜何志丹1 小时前
【超音速专利 CN117058421A】基于多头模型的图像检测关键点方法、系统、平台及介质
人工智能·深度学习·计算机视觉·关键点·专栏·超音速
Hcoco_me2 小时前
大模型面试题91:合并访存是什么?原理是什么?
人工智能·深度学习·算法·机器学习·vllm
充值修改昵称3 小时前
数据结构基础:B树磁盘IO优化的数据结构艺术
数据结构·b树·python·算法
C系语言4 小时前
python用pip生成requirements.txt
开发语言·python·pip
william_djj4 小时前
python3.8 提取xlsx表格内容填入单个文件
windows·python·xlsx
kszlgy8 小时前
Day 52 神经网络调参指南
python
wrj的博客10 小时前
python环境安装
python·学习·环境配置
Pyeako10 小时前
深度学习--BP神经网络&梯度下降&损失函数
人工智能·python·深度学习·bp神经网络·损失函数·梯度下降·正则化惩罚
哥布林学者11 小时前
吴恩达深度学习课程五:自然语言处理 第二周:词嵌入(四)分层 softmax 和负采样
深度学习·ai