记.backward()报错

最近我在模型训练损失里加入了LPIPS深度感知损失,训练的时候就出现了如上的报错,具体解释为:调用梯度反向传播loss.backward()时,我们计算梯度,需要一个标量的loss(即该loss张量的维度为1,只包含一个元素);而LPIPS的输出的loss为一个[4,1,1,1]的4维张量(batch_size,c,h,w),因此报错。

修正:

python 复制代码
def lpips_loss(img1, img2):
    # loss_fn_alex = lpips.LPIPS(net='alex')  # best forward scores
    loss_fn_vgg = lpips.LPIPS(net='vgg')  # closer to "traditional" perceptual loss, when used for optimization
    loss_fn_vgg.cuda()
    loss = loss_fn_vgg.forward(img1, img2)
    loss = torch.mean(loss)
    return loss

参考:

grad can be implicitly created only for scalar outputs-CSDN博客https://blog.csdn.net/qq_39208832/article/details/117415229
lpips · PyPIhttps://pypi.org/project/lpips/

相关推荐
康康的AI博客几秒前
多模态大一统:从GPT-4突破到AI领域质的飞跃之路
人工智能·ai
咚咚王者4 分钟前
人工智能之核心基础 机器学习 第十九章 强化学习入门
人工智能·机器学习
flying_13145 分钟前
图神经网络分享系列-GGNN(GATED GRAPH SEQUENCE NEURAL NETWORKS)(一)
人工智能·深度学习·神经网络·图神经网络·ggnn·门控机制·图特征学习
Hcoco_me10 分钟前
大模型面试题89:GPU的内存结构是什么样的?
人工智能·算法·机器学习·chatgpt·机器人
sanggou18 分钟前
Spring Boot 中基于 WebClient 的 SSE 流式接口实战
java·人工智能
DREAM依旧21 分钟前
本地微调的Ollama模型部署到Dify平台上
人工智能·python
辰阳星宇22 分钟前
【工具调用】BFCL榜单数据分析
人工智能·数据挖掘·数据分析
小陈phd22 分钟前
langGraph从入门到精通(九)——基于LangGraph构建具备多工具调用与自动化摘要能力的智能 Agent
人工智能·python·langchain
wishchin23 分钟前
Jetson Orin Trt: No CMAKE_CUDA_COMPILER could be found
linux·运维·深度学习
Das128 分钟前
【机器学习】07_降维与度量学习
人工智能·学习·机器学习