Simple Image Neural Style Transfer (PyTorch)

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import copy
from torchvision.models import vgg19, VGG19_Weights

Checking for GPU Availability

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)
device
device(type='cuda')
content_image_directory = 'content_image.jpg'
style_image_directory = 'style_image.jpg'

Initial Data Preprocessing and Basic Functions

image_size = 512 if torch.cuda.is_available() else 128
transform = transforms.Compose([
                  transforms.Resize(image_size),
                  transforms.ToTensor()])
def image_loader(image_directory):
  image = Image.open(image_directory)
  image = transform(image).unsqueeze(0)
  return image.to(device, torch.float)
content_img = image_loader(content_image_directory)
style_img = image_loader(style_image_directory)
assert content_img.size() == style_img.size()
reform = transforms.ToPILImage()


def image_unloader(tensor_input,title=None):
  image = tensor_input.cpu().clone()
  image = image.squeeze(0)
  image = reform(image)
  plt.imshow(image)
  plt.title(title)
  plt.show()
  plt.close()
plt.figure()
image_unloader(content_img, title = 'Content Image')
plt.figure()
image_unloader(style_img, title = 'Style Image')
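
Note that the assert above only passes when the two photos end up with exactly the same tensor shape. A minimal sketch (an assumption, not part of the original run) that guarantees matching sizes by resizing both images to a fixed square and dropping any alpha channel:

square_transform = transforms.Compose([
                  transforms.Resize((image_size, image_size)),  # fixed (H, W) instead of shortest-side resize
                  transforms.ToTensor()])
def square_loader(image_directory):
  # hypothetical helper: convert('RGB') drops an alpha channel if the file has one
  image = Image.open(image_directory).convert('RGB')
  return square_transform(image).unsqueeze(0).to(device, torch.float)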

Loss Functions for Custom Model

class Content_Loss(nn.Module):
  def __init__(self, target):
    super(Content_Loss, self).__init__()
    self.target = target.detach()
  def forward(self, input):
    self.loss = F.mse_loss(input, self.target)
    return input
def gram_matrix(matrix):
  a,b,c,d = matrix.size()
  features = matrix.view(a*b, c*d)
  G = torch.mm(features, features.t())
  return G.div(a*b*c*d)
class Style_Loss(nn.Module):
  def __init__(self, target):
    super(Style_Loss,self).__init__()
    self.target = gram_matrix(target).detach()
  def forward(self, input):
    G_input = gram_matrix(input)
    self.loss = F.mse_loss(G_input, self.target)
    return input
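
As a quick illustrative check (not part of the original notebook), gram_matrix turns a feature map of shape (1, C, H, W) into a C x C channel-correlation matrix, which is exactly what Style_Loss compares against its target:

dummy_features = torch.rand(1, 64, 32, 32)   # batch x channels x height x width
G = gram_matrix(dummy_features)
print(G.shape)                               # torch.Size([64, 64])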

Pretrained VGG19

vgg = vgg19(weights=VGG19_Weights.DEFAULT).features.eval()
vgg
Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (20): ReLU(inplace=True)
  (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (22): ReLU(inplace=True)
  (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (24): ReLU(inplace=True)
  (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (26): ReLU(inplace=True)
  (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (29): ReLU(inplace=True)
  (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (31): ReLU(inplace=True)
  (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (33): ReLU(inplace=True)
  (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (35): ReLU(inplace=True)
  (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

vgg_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
vgg_normalization_std = torch.tensor([0.229, 0.224, 0.225])
class Normalization(nn.Module):
  def __init__(self, mean, std):
    super(Normalization, self).__init__()
    # reshape to [C, 1, 1] so mean and std broadcast over a [B, C, H, W] image batch
    self.mean = torch.tensor(mean).view(-1, 1, 1)
    self.std = torch.tensor(std).view(-1, 1, 1)

  def forward(self, img):
    return (img - self.mean) / self.std

Content Layers and Style Layers

content_layers = ['conv_4']
style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
for i in vgg.children():
  print(i)
  break
Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

Function for getting style and content losses

def style_and_content_losses(pre_model = vgg, normalization_mean=vgg_normalization_mean,
                             normalization_std=vgg_normalization_std,
                                content_layers=content_layers, style_layers=style_layers,
                                content_img=content_img, style_img=style_img):
  normalization = Normalization(normalization_mean, normalization_std)


  content_losses = []
  style_losses = []


  model = nn.Sequential(normalization)


  i = 0
  for layer in pre_model.children():
    if isinstance(layer, nn.Conv2d):
      i+=1
      name = 'conv_{}'.format(i)
    elif isinstance(layer, nn.ReLU):
      name = 'ReLU_{}'.format(i)
      layer = nn.ReLU(inplace=False)
    elif isinstance(layer, nn.MaxPool2d):
      name = 'pool_{}'.format(i)
    elif isinstance(layer, nn.BatchNorm2d):
      name = 'batch_norm_{}'.format(i)
    else:
      raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
    model.add_module(name,layer)


    if name in content_layers:
      target = model(content_img).detach()
      content_loss = Content_Loss(target)
      model.add_module('content_loss_{}'.format(i), content_loss)
      content_losses.append(content_loss)
    if name in style_layers:
      target_feature = model(style_img).detach()
      style_loss = Style_Loss(target=target_feature)
      model.add_module('style_loss_{}'.format(i), style_loss)
      style_losses.append(style_loss)


  for i in range(len(model)-1, -1, -1):
    if isinstance(model[i], Content_Loss) or isinstance(model[i], Style_Loss):
      break


  model = model[:(i+1)]
  return model, style_losses, content_losses
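
Before optimizing, a quick sanity check (not part of the original post) confirms the assembled model holds the expected number of loss modules:

check_model, check_style_losses, check_content_losses = style_and_content_losses()
print(len(check_style_losses), len(check_content_losses))   # expected: 5 style losses, 1 content loss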

Input Image Initialized from the Content Image

input_img = content_img.clone()
plt.figure()
image_unloader(input_img, title = 'Input_Image')
def get_optimizer(input_img):
  optimizer = torch.optim.LBFGS([input_img])
  return optimizer
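
LBFGS optimizes the image pixels directly and converges in relatively few steps, but it is memory-hungry at 512x512. A hedged alternative sketch (an assumption, not used in the runs below): Adam accepts the same closure in optimizer.step(closure), at the cost of needing more iterations.

def get_adam_optimizer(input_img, lr=0.02):
  # hypothetical drop-in replacement for get_optimizer; lr=0.02 is an assumed value
  return torch.optim.Adam([input_img], lr=lr)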

Final Function for Neural Transfer

def style_transfer(pre_model=vgg, normalization_mean=vgg_normalization_mean,
                   normalization_std=vgg_normalization_std,
                   content_img=content_img, style_img=style_img,
                   input_img = input_img, num_epochs=400,
                   style_weight=1000000, content_weight=1):
  print('Building the style transfer model')
  model, style_losses, content_losses = style_and_content_losses(
      pre_model, normalization_mean, normalization_std,
      content_img=content_img, style_img=style_img)
  input_img.requires_grad_(True)
  model.eval()
  model.requires_grad_(False)
  optimizer= get_optimizer(input_img)


  print('Optimizing...')
  run = [0]
  while run[0]<=num_epochs:
    def closure():
      with torch.no_grad():
        input_img.clamp_(0,1)
      optimizer.zero_grad()
      model(input_img)
      style_score=0
      content_score = 0
      for s_loss in style_losses:
        style_score += s_loss.loss
      for c_loss in content_losses:
        content_score += c_loss.loss


      style_score *= style_weight
      content_score *= content_weight


      loss = style_score + content_score
      loss.backward()


      run[0] += 1
      if run[0] % 50 == 0:
        print('run {}:'.format(run))
        print('Style_loss : {:4f} & Content_loss : {:4f}'.format(
            style_score.item(), content_score.item()
        ))
        print()
      return style_score + content_score
    optimizer.step(closure)


  with torch.no_grad():
        input_img.clamp_(0, 1)


  return input_img

Getting Output

output = style_transfer()
Building the style transfer model
Optimizing...
run [50]:
Style_loss : 4.008237 & Content_loss : 4.146012

run [100]:
Style_loss : 1.135739 & Content_loss : 3.031179

run [150]:
Style_loss : 0.714863 & Content_loss : 2.651623

run [200]:
Style_loss : 0.481216 & Content_loss : 2.490812

run [250]:
Style_loss : 0.349222 & Content_loss : 2.403970

run [300]:
Style_loss : 0.264658 & Content_loss : 2.349341

run [350]:
Style_loss : 0.214068 & Content_loss : 2.314269

run [400]:
Style_loss : 0.184173 & Content_loss : 2.288790

The Output

plt.figure()
image_unloader(output, title='Output Image')


plt.ioff()
plt.show()
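
To keep the result on disk (a small addition, not in the original post), the clamped output tensor can be converted back to a PIL image with the reform transform defined earlier and saved:

final_image = reform(output.detach().cpu().squeeze(0))
final_image.save('styled_output.jpg')   # hypothetical output filename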

Initialization with a white noise image
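
Since this run starts from random noise, fixing the seed first (an optional addition, not in the original run) makes the result reproducible:

torch.manual_seed(0)   # hypothetical seed value; any fixed integer works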

input_img2 = torch.randn(content_img.size(), device=device)
plt.figure()
image_unloader(input_img2, title = 'white noise image')
output2 = style_transfer(input_img=input_img2,num_epochs=700)
Building the style transfer model
Optimizing...
run [50]:
Style_loss : 83.288612 & Content_loss : 11.821552

run [100]:
Style_loss : 23.298397 & Content_loss : 9.524027

run [150]:
Style_loss : 5.048073 & Content_loss : 7.467306

run [200]:
Style_loss : 1.898000 & Content_loss : 5.806565

run [250]:
Style_loss : 1.291034 & Content_loss : 4.793440

run [300]:
Style_loss : 0.958348 & Content_loss : 4.172914

run [350]:
Style_loss : 0.734883 & Content_loss : 3.720185

run [400]:
Style_loss : 0.569834 & Content_loss : 3.423393

run [450]:
Style_loss : 0.448026 & Content_loss : 3.190432

run [500]:
Style_loss : 0.360451 & Content_loss : 3.018604

run [550]:
Style_loss : 0.303826 & Content_loss : 2.879138

run [600]:
Style_loss : 0.262109 & Content_loss : 2.767931

run [650]:
Style_loss : 0.232545 & Content_loss : 2.680881

run [700]:
Style_loss : 0.209969 & Content_loss : 2.609402
plt.figure()
image_unloader(output2, title='Output2 Image')