[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
甲维斯2 分钟前
Claude Code中文界面版更一波!又改了5000+行!
人工智能·ai编程
腾讯云开发者6 分钟前
从前沿洞见到落地实践:腾讯云TVP布道澳门,燃动AI Agent新思潮
人工智能
雪隐17 分钟前
个人电脑玩AI-02让5060 Ti给你打工——Whisper语音识别篇(下)
人工智能·后端
HIT_Weston18 分钟前
110、【Agent】【OpenCode】todowrite 工具提示词(示例)(四)
人工智能·agent·opencode
ECT-OS-JiuHuaShan24 分钟前
什么是对和错?——“有针对性定义域的逻辑值的真伪”:认识论终极追问的公理化裁决
数据库·人工智能·算法·机器学习·数学建模
澹锦汐32 分钟前
从 0 到 1 构建 AI 创意工具:独立开发者的 LLM 应用实战
人工智能
道友可好33 分钟前
Superpowers vs OpenSpec vs Spec Kit:该选哪个?
前端·人工智能·后端
xixingzhe233 分钟前
AI运维注意点
运维·人工智能
morning_judger33 分钟前
Agent开发系列(十一)-知识库建设(知识地图)
人工智能