Style transfer: render the content of image A in the style of image B, somewhat like applying a filter.
Key points of the code
1. Take a pretrained large model and freeze it; its parameters are never updated.
2. Choose content layers and style layers: the content layer sits deep in the network (a large layer index), while the style layers should range from shallow to deep.
3. Pass the content image through the content layers to get the target content; during training, push the synthesized image's content features toward it. Pass the style image through the style layers, then apply the Gram matrix to obtain the final target style; during training, push the synthesized image's Gram matrices toward it.
4. Use the content image to initialize the synthesized image; the parameters being trained are that image's pixels.
5. Compute the total loss l as a weighted sum of three losses (content, style, and total-variation denoising), take the gradient of l, and update the pixel values, as the sketch after this list shows.
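A minimal sketch of steps 4 and 5 in plain PyTorch, with a stand-in loss (the real weighted content/style/TV loss is built further below):
import torch
img = torch.rand(1, 3, 8, 8, requires_grad=True)  # the trainable "parameters" are the pixels themselves
opt = torch.optim.Adam([img], lr=0.1)
loss = img.square().mean()  # stand-in for the weighted content/style/TV loss
opt.zero_grad()
loss.backward()
opt.step()  # one gradient step directly on the pixel values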
import torch
import torchvision
from torch import nn
from PIL import Image
from torchvision import transforms
from d2l import torch as d2l
# Read the images
d2l.set_figsize()
content_img = Image.open('')  # fill in the content image path
d2l.plt.imshow(content_img)
style_img = Image.open('')  # fill in the style image path
d2l.plt.imshow(style_img)
# Pre- and post-processing: convert between RGB PIL images and normalized tensors
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])
def preprocess(img, img_shape):
    # PIL image -> normalized (1, 3, H, W) tensor
    transform = transforms.Compose([transforms.Resize(img_shape),
                                    transforms.ToTensor(),
                                    transforms.Normalize(rgb_mean, rgb_std)])
    return transform(img).unsqueeze(0)
def postprocess(img):
    # normalized tensor -> PIL image: undo the normalization and clamp to [0, 1]
    img = img[0].to(rgb_std.device)
    img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)
    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))
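A quick round-trip sanity check (assuming content_img loaded successfully; the shape is illustrative):
x = preprocess(content_img, (300, 450))
print(x.shape)  # torch.Size([1, 3, 300, 450])
postprocess(x)  # PIL image with the normalization undone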
# Model
# VGG works well for feature extraction
pretrained_net = torchvision.models.vgg19(pretrained=True)
style_layer, content_layer = [0, 5, 10, 19, 28], [25]
net = nn.Sequential(*[pretrained_net.features[i]
                      for i in range(max(content_layer + style_layer) + 1)])
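A quick check that the truncated network reaches the deepest chosen layer (index 28):
print(len(net))  # 29 layers, indices 0..28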
def extract_features(X, content_layer, style_layer):
    # run X through the truncated VGG, collecting outputs of the chosen layers
    contents = []
    styles = []
    for i in range(len(net)):
        X = net[i](X)
        if i in style_layer:
            styles.append(X)
        elif i in content_layer:
            contents.append(X)
    return contents, styles
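An illustrative shape check on a random input (no real image needed):
dummy = torch.rand(1, 3, 300, 450)
c_feats, s_feats = extract_features(dummy, content_layer, style_layer)
print(len(c_feats), len(s_feats))  # 1 content feature map, 5 style feature maps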
def get_contents(image_shape, device):
    # the target content features come from the preprocessed content image
    content_X = preprocess(content_img, image_shape).to(device)
    contents_Y, _ = extract_features(content_X, content_layer, style_layer)
    return content_X, contents_Y
def get_styles(img_shape, device):
    # the target style features come from the preprocessed style image
    style_X = preprocess(style_img, img_shape).to(device)
    _, styles_Y = extract_features(style_X, content_layer, style_layer)
    return style_X, styles_Y
# Define the losses
def content_loss(Y_hat, Y):
    return torch.square(Y_hat - Y.detach()).mean()
def gram(X):
    # flatten each channel, then correlate channels: result is (channels, channels)
    num_channels, n = X.shape[1], X.numel() // X.shape[1]
    X = X.reshape((num_channels, n))
    return torch.matmul(X, X.T) / (num_channels * n)
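For intuition: the Gram matrix measures channel-to-channel correlations, so its shape depends only on the channel count, not the spatial resolution:
feat = torch.rand(1, 64, 32, 32)
print(gram(feat).shape)  # torch.Size([64, 64]), independent of the 32x32 spatial size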
def style_loss(Y_hat, gram_Y):
    return torch.square(gram(Y_hat) - gram_Y.detach()).mean()
def tv_loss(Y_hat):
    # total-variation (denoising) loss: penalize differences between neighboring pixels
    return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
                  torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())
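TV loss rewards smoothness: a constant image scores zero while random noise scores high, as a quick check shows:
print(float(tv_loss(torch.ones(1, 3, 16, 16))))  # 0.0 for a perfectly smooth image
print(float(tv_loss(torch.rand(1, 3, 16, 16))))  # noticeably larger for noise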
content_weight, style_weight, tv_weight = 1, 1e4, 10  # relative weights of the three losses
def compute_loss(X, contents_Y_hat, style_Y_hat, contents_Y, style_Y_gram):
    # weighted sum of the content, style, and TV losses
    contents_l = [content_loss(Y_hat, Y) * content_weight
                  for Y_hat, Y in zip(contents_Y_hat, contents_Y)]
    style_l = [style_loss(Y_hat, Y) * style_weight
               for Y_hat, Y in zip(style_Y_hat, style_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    l = sum(contents_l + style_l + [tv_l])
    return contents_l, style_l, tv_l, l
# Initialization
class SynthesizedImage(nn.Module):
    # a "model" whose only parameter is the image being synthesized
    def __init__(self, img_shape, **kwargs):
        super(SynthesizedImage, self).__init__(**kwargs)
        self.weight = nn.Parameter(torch.rand(*img_shape))
    def forward(self):
        # no input; the output is the trainable image itself
        return self.weight
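The whole "model" owns exactly one parameter tensor, the image itself; a quick check:
img_model = SynthesizedImage((1, 3, 4, 4))
print(sum(p.numel() for p in img_model.parameters()))  # 48 = 1 * 3 * 4 * 4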
def get_inits(X, device, lr, styles_Y):
    gen_img = SynthesizedImage(X.shape).to(device)
    gen_img.weight.data.copy_(X.data)  # initialize with the content image
    trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
    styles_Y_gram = [gram(Y) for Y in styles_Y]  # precompute once to save computation
    return gen_img(), styles_Y_gram, trainer
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[10, num_epochs],
                            legend=['content', 'style', 'tv'],
                            ncols=2, figsize=(7, 2.5))
    for epoch in range(num_epochs):
        trainer.zero_grad()
        contents_Y_hat, style_Y_hat = extract_features(X, content_layer, style_layer)
        contents_l, style_l, tv_l, l = compute_loss(
            X, contents_Y_hat, style_Y_hat, contents_Y, styles_Y_gram)
        l.backward()
        trainer.step()
        scheduler.step()
        if (epoch + 1) % 10 == 0:
            animator.axes[1].imshow(postprocess(X))
            animator.add(epoch + 1, [float(sum(contents_l)),
                                     float(sum(style_l)), float(tv_l)])
    return X
device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)
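To view or save the result, undo the normalization with postprocess (the filename below is illustrative):
d2l.plt.imshow(postprocess(output))
# postprocess(output).save('styled.jpg')  # example filename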