[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
龙山云仓17 分钟前
小G&老D求解:第7日·立夏·蝼蝈鸣
人工智能·机器学习
LaughingZhu29 分钟前
Product Hunt 每日热榜 | 2026-04-30
人工智能·经验分享·深度学习·神经网络·产品运营
sunneo34 分钟前
专栏D-团队与组织-03-产品文化
人工智能·产品运营·aigc·产品经理·ai编程
Muyuan199834 分钟前
28.Paper RAG Agent 开发记录:修复 LLM Rerank 的解析、Fallback 与可验证性
linux·人工智能·windows·python·django·fastapi
小呆呆6661 小时前
Codex 穷鬼大救星
前端·人工智能·后端
薛定猫AI1 小时前
【深度解析】Kimi K2.6 的长上下文 Agentic Coding 能力与 OpenAI 兼容 API 接入实践
人工智能·自动化·知识图谱
星爷AG I1 小时前
20-6 记忆整合(AGI基础理论)
人工智能·agi
AI创界者1 小时前
人工智能 GPT-Image DMXAPI Python AI绘画
人工智能
播播资源1 小时前
GPT-5.5 模型功能深度解析:从模型介绍、核心特点到应用场景全景分析 如何快速接入使用
人工智能·gpt
谁似人间西林客1 小时前
工厂大脑是什么?从经验驱动到AI辅助的决策跃迁
人工智能