记.backward()报错

最近我在模型训练损失里加入了LPIPS深度感知损失,训练的时候就出现了如上的报错,具体解释为:调用梯度反向传播loss.backward()时,我们计算梯度,需要一个标量的loss(即该loss张量的维度为1,只包含一个元素);而LPIPS的输出的loss为一个[4,1,1,1]的4维张量(batch_size,c,h,w),因此报错。

修正:

python 复制代码
def lpips_loss(img1, img2):
    # loss_fn_alex = lpips.LPIPS(net='alex')  # best forward scores
    loss_fn_vgg = lpips.LPIPS(net='vgg')  # closer to "traditional" perceptual loss, when used for optimization
    loss_fn_vgg.cuda()
    loss = loss_fn_vgg.forward(img1, img2)
    loss = torch.mean(loss)
    return loss

参考:

grad can be implicitly created only for scalar outputs-CSDN博客https://blog.csdn.net/qq_39208832/article/details/117415229
lpips · PyPIhttps://pypi.org/project/lpips/

相关推荐
蒙双眼看世界3 分钟前
AI应用实战:Excel表的操作工具
人工智能
jndingxin17 分钟前
OpenCV 图形API(64)图像结构分析和形状描述符------在图像中查找轮廓函数findContours()
人工智能·opencv
唯创电子19 分钟前
芯资讯|WTR096-16S录音语音芯片:重塑智能家居的情感连接与安全守护
人工智能·智能家居·语音识别·语音芯片·录音芯片
开发小能手-roy25 分钟前
使用PyTorch实现简单图像识别(基于MNIST手写数字数据集)的完整代码示例,包含数据加载、模型定义、训练和预测全流程
人工智能·pytorch·python
嗨,紫玉灵神熊38 分钟前
使用 OpenCV 实现图像中心旋转
图像处理·人工智能·opencv·计算机视觉
cmoaciopm43 分钟前
FastGPT部署的一些问题整理
人工智能
odoo中国1 小时前
机器学习实操 第一部分 机器学习基础 第6章 决策树
人工智能·决策树·机器学习
giszz1 小时前
DeepSeek提示词技巧
人工智能
AI技术学长1 小时前
训练神经网络的批量标准化(使用 PyTorch)
人工智能·pytorch·神经网络·数据科学·计算机技术·批量标准化
ccLianLian1 小时前
深度学习·经典模型·Transformer
人工智能·深度学习·transformer