[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
百胜软件@百胜软件28 分钟前
黄飞对话阿里云AI专家:当零售中台拥有AI大脑,未来将去向何方?
人工智能·阿里云·零售
数科云8 小时前
AI提示词(Prompt)入门:什么是Prompt?为什么要写好Prompt?
人工智能·aigc·ai写作·ai工具集·最新ai资讯
Devlive 开源社区8 小时前
技术日报|Claude Code超级能力库superpowers登顶日增1538星,自主AI循环ralph爆火登榜第二
人工智能
软件供应链安全指南8 小时前
灵脉 IAST 5.4 升级:双轮驱动 AI 漏洞治理与业务逻辑漏洞精准检测
人工智能·安全
lanmengyiyu8 小时前
单塔和双塔的区别和共同点
人工智能·双塔模型·网络结构·单塔模型
微光闪现8 小时前
AI识别宠物焦虑、紧张和晕车行为,是否已经具备实际可行性?
大数据·人工智能·宠物
技术小黑屋_9 小时前
用好Few-shot Prompting,AI 准确率提升100%
人工智能
中草药z9 小时前
【嵌入模型】概念、应用与两大 AI 开源社区(Hugging Face / 魔塔)
人工智能·算法·机器学习·数据集·向量·嵌入模型
知乎的哥廷根数学学派9 小时前
基于数据驱动的自适应正交小波基优化算法(Python)
开发语言·网络·人工智能·pytorch·python·深度学习·算法
DisonTangor9 小时前
GLM-Image:面向密集知识与高保真图像生成的自回归模型
人工智能·ai作画·数据挖掘·回归·aigc