[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
lifallen9 小时前
淘宝RecGPT:通过LLM增强推荐
人工智能·深度学习·ai·推荐算法
IT学长编程9 小时前
计算机毕业设计 基于深度学习的酒店评论文本情感分析研究 Python毕业设计项目 Hadoop毕业设计选题 机器学习选题【附源码+文档报告+安装调试】
hadoop·python·深度学习·机器学习·数据分析·毕业设计·酒店评论文本情感分析
金井PRATHAMA10 小时前
认知语义学对人工智能自然语言处理的深层语义分析:理论启示与实践路径
人工智能·自然语言处理·知识图谱
小王爱学人工智能10 小时前
OpenCV的特征检测
人工智能·opencv·计算机视觉
羊羊小栈10 小时前
基于「YOLO目标检测 + 多模态AI分析」的铁路轨道缺陷检测安全系统(vue+flask+数据集+模型训练)
人工智能·yolo·目标检测·语言模型·毕业设计·创业创新·大作业
钝挫力PROGRAMER10 小时前
GPT与BERT BGE
人工智能·gpt·bert
Baihai IDP10 小时前
2025 年大语言模型架构演进:DeepSeek V3、OLMo 2、Gemma 3 与 Mistral 3.1 核心技术剖析
人工智能·ai·语言模型·llm·transformer
☼←安于亥时→❦10 小时前
PyTorch之张量创建与运算
人工智能·算法·机器学习
nuczzz10 小时前
pytorch非线性回归
人工智能·pytorch·机器学习·ai
~-~%%10 小时前
Moe机制与pytorch实现
人工智能·pytorch·python