[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
AndrewHZ几秒前
【AI黑话日日新】什么是访存bound?
人工智能·语言模型·大模型·cpu·访存·计算逻辑
天一生水water1 分钟前
地质工程一体化从入门到精通:油气勘探开发核心技术教程
人工智能·智慧油田
努力也学不会java2 分钟前
【Spring Cloud】环境和工程基本搭建
java·人工智能·后端·spring·spring cloud·容器
狮子座明仔2 分钟前
PRL:让大模型推理不再“开盲盒“——过程奖励学习的理论与实践
人工智能·深度学习·学习·机器学习·语言模型
发哥来了3 分钟前
主流AI视频生成模型商用化能力评测:五大核心维度深度对比
人工智能·音视频
博思云为4 分钟前
企业级智能PPT生成:Amazon云+AI驱动,全流程自动化提效
人工智能·语言模型·云原生·数据挖掘·云计算·语音识别·aws
龙山云仓4 分钟前
No126:AI中国故事-仓颉:智能的符号编码、知识压缩与文明记忆
大数据·人工智能·深度学习·机器学习·计算机视觉·重构
柠檬丶抒情7 分钟前
Rust深度学习框架Burn 0.20是否能超过python?
python·深度学习·rust·vllm
乾元7 分钟前
范式转移:从基于规则的“特征码”到基于统计的“特征向量”
运维·网络·人工智能·网络协议·安全