[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
女神下凡18 分钟前
office系列软件 激活破解(office 2019, 2021, 2024)
人工智能·microsoft
2503_9317124823 分钟前
京东裸眼3D展示——30分钟建模绒感褶皱光泽都能还原
人工智能
星马梦缘24 分钟前
机器学习与模式识别 第八章 MAP与偏方差 考点压缩
人工智能·机器学习·map·岭回归·mle·双重下降
一楼的猫29 分钟前
AI写作合规技术方案:平台检测机制分析与规避策略
人工智能·学习·机器学习·ai写作
阿拉斯攀登1 小时前
Agent 核心架构:思考-行动-观察循环(ReAct)
人工智能·ai·agent·react
HyperAI超神经1 小时前
活动预告|智源/TileRT/腾讯/华为/智元创新同台,共探 AI 编译的多层级协同优化
人工智能·ai 编译器·腾讯·具身智能·矩阵乘法·算子优化·华为昇腾
在水一缸1 小时前
GLM 5.2 发布:当长上下文与智能体走向深度融合
人工智能·大模型·智能体·智谱ai·长上下文·glm-5.2
小妖同学学AI1 小时前
AI编程 AI Ping+Cline搭建自己的编程助手!
人工智能·ai编程
星马梦缘1 小时前
机器学习与模式识别 第十四章 神经网络中的反向传播 考点压缩
人工智能·机器学习·微分·反向传播