记.backward()报错

最近我在模型训练损失里加入了LPIPS深度感知损失,训练的时候就出现了如上的报错,具体解释为:调用梯度反向传播loss.backward()时,我们计算梯度,需要一个标量的loss(即该loss张量的维度为1,只包含一个元素);而LPIPS的输出的loss为一个[4,1,1,1]的4维张量(batch_size,c,h,w),因此报错。

修正:

python 复制代码
def lpips_loss(img1, img2):
    # loss_fn_alex = lpips.LPIPS(net='alex')  # best forward scores
    loss_fn_vgg = lpips.LPIPS(net='vgg')  # closer to "traditional" perceptual loss, when used for optimization
    loss_fn_vgg.cuda()
    loss = loss_fn_vgg.forward(img1, img2)
    loss = torch.mean(loss)
    return loss

参考:

grad can be implicitly created only for scalar outputs-CSDN博客https://blog.csdn.net/qq_39208832/article/details/117415229
lpips · PyPIhttps://pypi.org/project/lpips/

相关推荐
Lucifer__hell5 分钟前
【测试】Axure原型的AI测试用例生成方案
人工智能·测试用例·axure
跨境卫士苏苏10 分钟前
清关链路更透明以后跨境卖家如何减少资料反复修改
大数据·人工智能·安全·跨境电商·亚马逊
easy_coder12 分钟前
ReAct 进入死循环?用 Harness 把它拉回来
人工智能·架构·云计算
我是无敌小恐龙22 分钟前
Java SE 零基础入门Day06 方法重载+Debug调试+String字符串全套API详解(超全干货)
java·开发语言·人工智能·python·transformer·无人机·量子计算
aidesignplus23 分钟前
从平方到线性:Mamba如何挑战Transformer的长序列效率瓶颈?
人工智能·python·深度学习·vim·transformer
三维频道25 分钟前
工业级三维扫描实测:汽车灯具复杂结构件的全尺寸 3D 测量方案分析
java·人工智能·python·数码相机·3d·汽车·汽车轻量化制造
人工智能AI技术25 分钟前
过拟合与欠拟合:机器学习最基础核心问题
人工智能
码农飞哥31 分钟前
从Java后端到AI应用开发,我这两年做了什么
java·开发语言·人工智能
大龄码农-涵哥33 分钟前
Spring Boot项目集成AI对话:使用Spring AI打造智能客服
人工智能·spring boot·spring
Jmayday42 分钟前
Pytorch:神经网络基础
人工智能·pytorch·神经网络