[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
kisshuan123961 分钟前
基于Mask-RCNN与Res2Net的排水系统缺陷检测与分类
人工智能·数据挖掘
P.H. Infinity6 分钟前
【QLIB】一、系统架构
人工智能·金融
搬砖的kk10 分钟前
openJiuwen 快速入门:使用华为云大模型搭建 AI 智能体
数据库·人工智能·华为云
Gavin在路上17 分钟前
SpringAIAlibaba之从执行生命周期到实战落地(7)
人工智能
万俟淋曦25 分钟前
【论文速递】2025年第50周(Dec-07-13)(Robotics/Embodied AI/LLM)
人工智能·深度学习·机器人·大模型·论文·robotics·具身智能
没有不重的名么34 分钟前
When Hypergraph Meets Heterophily: New Benchmark Datasets and Baseline
人工智能·深度学习·opencv·计算机视觉·超图
zxsz_com_cn1 小时前
设备预测性维护优势全景解读:安全、降本、增效与可量化ROI
人工智能
爬点儿啥1 小时前
[Ai Agent] 13 用 Streamlit 为 Agents SDK 打造可视化“驾驶舱”
人工智能·ai·状态模式·agent·streamlit·智能体
机器学习算法与Python实战1 小时前
腾讯翻译大模型,手机可运行
人工智能
百***58841 小时前
MATLAB高效算法实战技术文章大纲1
人工智能·算法·matlab