记.backward()报错

最近我在模型训练损失里加入了LPIPS深度感知损失,训练的时候就出现了如上的报错,具体解释为:调用梯度反向传播loss.backward()时,我们计算梯度,需要一个标量的loss(即该loss张量的维度为1,只包含一个元素);而LPIPS的输出的loss为一个[4,1,1,1]的4维张量(batch_size,c,h,w),因此报错。

修正:

python 复制代码
def lpips_loss(img1, img2):
    # loss_fn_alex = lpips.LPIPS(net='alex')  # best forward scores
    loss_fn_vgg = lpips.LPIPS(net='vgg')  # closer to "traditional" perceptual loss, when used for optimization
    loss_fn_vgg.cuda()
    loss = loss_fn_vgg.forward(img1, img2)
    loss = torch.mean(loss)
    return loss

参考:

grad can be implicitly created only for scalar outputs-CSDN博客https://blog.csdn.net/qq_39208832/article/details/117415229
lpips · PyPIhttps://pypi.org/project/lpips/

相关推荐
andafaAPS15 小时前
安达发|aps软件系统:塑料薄膜业数字化升级,破生产管理难题
人工智能·aps生产排程·安达发aps·计划排产软件·自动排单软件·aps软件系统
前端若水15 小时前
【无标题】
java·人工智能·python·机器学习
数字供应链安全产品选型15 小时前
数字供应链安全治理体系研究:从软件供应链到AI原生安全的演进与实践
人工智能·安全·ai-native
iDao技术魔方15 小时前
GEO 生成式引擎优化完全指南:让你的内容成为 AI 的默认答案
人工智能
HIT_Weston15 小时前
87、【Agent】【OpenCode】read 工具提示词
人工智能·agent·opencode
墨北小七15 小时前
使用火山引擎 HiAgent 构建工业级设备智能运维智能体
运维·人工智能·火山引擎
晚霞的不甘16 小时前
CANN-ATB加速库:Transformer推理性能密码
人工智能·深度学习·transformer
创世宇图16 小时前
【AI入门知识点】Function Calling 是什么?为什么 AI 开始会“调用工具”了?
人工智能·ai·llm·functioncalling
微软技术栈16 小时前
Microsoft AI Genius 4.0 | 使用 GitHub Copilot SDK 升级开发者体验
人工智能·microsoft·github