风格迁移项目一:如何使用

前言

由于我不太会pr,所以直接新建的项目,

原项目地址:https://github.com/Optimistism/Style-transfer

原项目代码的讲解地址:https://www.bilibili.com/video/BV1yY4y1c7Cz/

本项目是对原项目的一点点完善。

项目地址:https://github.com/Knighthood2001/Style-transfer

更新

  • 2024-08-03 更新:这个项目,是不能保存模型参数的,也就是说,你只能通过训练,得出图片,而不能直接进行推理。

项目如何运行

  1. 下载项目
    这个比较简单,直接在github上下载zip包,然后解压即可。或者使用git clone命令克隆项目。
  2. 安装依赖
  • 这里由于我没有用到虚拟环境,无法准确得出需要用的包,大致就是torch和我requirements.txt中的包。
  • 安装命令
    安装torch的话,会比较慢
shell 复制代码
pip install -r requirements.txt
  • 这里需要注意,需要安装tensorboard,作者给出的并没有这个,但是如果没有这个,你是查看不了运行结果的。
  1. 运行项目
  • 首先就是运行代码
  • 然后的话,它会去https://download.pytorch.org/models/vgg19-dcbb9e9d.pth 去下载vgg19的预训练权重,这个文件大概548mb。
    由于我这里不知道pytorch如何加载这个文件,因此,大家可以自己下载这个文件,或者将我这个文件,放到你代码运行后,它给出的那个路径下。可以看下面我更改部分第一点的第二张图片,运行代码会给出最终下载的路径的。
  • 然后想要看到tensorboard中的数据,就在终端输入
shell 复制代码
tensorboard --logdir=runs

然后点击浏览器,输入localhost:6006
点击这里也可以

后续你跑自己的图片,只需要在main.py中修改content_imgstyle_img即可。

更改

  1. 更新了vgg19传参
    最开始会报这个警告

使用这行代码后

shell 复制代码
# self.vgg = models.vgg19(pretrained=True).features  # .features用于提取卷积层
self.vgg = models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features

这行代码,会去自动下载官网的vgg19预训练权重,548mb,下载还是很慢的。我下载这个花了30分钟。

我这里由于文件太大,无法上传到github,因此,大家可以自己下载这个文件。

不会警告了

  1. 添加了content_weight参数,更加符合其公式要求
shell 复制代码
loss = content_weight * content_loss + style_weight * style_loss
  1. 运行代码后提示报错
shell 复制代码
RuntimeError: Input type (torch.FloatTensor) and weight type
 (torch.cuda.FloatTensor) should be the same or input should 
 be a MKLDNN tensor and weight is a dense tensor

当模型的权重(weights)和输入数据(inputs)不在同一个设备上时。

这个命令报错表示你的模型权重被移动到了CUDA设备上(即GPU),但是你的输入数据还在CPU上。

因此,具体做法是,将输入数据也移动到device中,它将根据你代码中的变换,自动选择设备。

shell 复制代码
# 图片输入到gpu,否则就会报错
content_img = content_img.to(device)
style_img = style_img.to(device)
  1. 添加了日志记录

main.py的最开始,添加了一些代码,主要就是设置了日志,并且以时间命名,方便后续将结果保存到日志中,方便查看。免得你自己去创建txt进行cv。

可以看到,这里也会有时间,你可以通过看日志,计算训练的时间。

实际体验

我本机是1650的显卡,跑3000轮,花了20分钟。

原图

结果图

梵高风的狗头自拍图

2999轮的结果

6000轮的结果

效果其实还不错

相关推荐
埃菲尔铁塔_CV算法10 分钟前
人工智能图像算法:开启视觉新时代的钥匙
人工智能·算法
EasyCVR11 分钟前
EHOME视频平台EasyCVR视频融合平台使用OBS进行RTMP推流,WebRTC播放出现抖动、卡顿如何解决?
人工智能·算法·ffmpeg·音视频·webrtc·监控视频接入
打羽毛球吗️17 分钟前
机器学习中的两种主要思路:数据驱动与模型驱动
人工智能·机器学习
好喜欢吃红柚子34 分钟前
万字长文解读空间、通道注意力机制机制和超详细代码逐行分析(SE,CBAM,SGE,CA,ECA,TA)
人工智能·pytorch·python·计算机视觉·cnn
小馒头学python38 分钟前
机器学习是什么?AIGC又是什么?机器学习与AIGC未来科技的双引擎
人工智能·python·机器学习
神奇夜光杯1 小时前
Python酷库之旅-第三方库Pandas(202)
开发语言·人工智能·python·excel·pandas·标准库及第三方库·学习与成长
正义的彬彬侠1 小时前
《XGBoost算法的原理推导》12-14决策树复杂度的正则化项 公式解析
人工智能·决策树·机器学习·集成学习·boosting·xgboost
Debroon1 小时前
RuleAlign 规则对齐框架:将医生的诊断规则形式化并注入模型,无需额外人工标注的自动对齐方法
人工智能
羊小猪~~1 小时前
神经网络基础--什么是正向传播??什么是方向传播??
人工智能·pytorch·python·深度学习·神经网络·算法·机器学习
AI小杨1 小时前
【车道线检测】一、传统车道线检测:基于霍夫变换的车道线检测史诗级详细教程
人工智能·opencv·计算机视觉·霍夫变换·车道线检测