[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
redreamSo4 分钟前
大模型是不是到顶了?瓶颈到底在哪
人工智能·openai
Oo9207 分钟前
Tool Use 背后的技术逻辑
人工智能
姗姗来迟了9 分钟前
Vue3封装AI流式对话组件踩坑实录
人工智能
码上天下1 小时前
用Pinia管理AI多会话状态
人工智能
用户054324329702 小时前
Next.js接大模型流式SSE实操踩坑
人工智能
Assby2 小时前
从 Function Calling 到 MCP:理解 Agent 工具调用的底层通信机制
人工智能·后端
小星AI2 小时前
Claude Code 从入门到精通,一步到位
人工智能
后端小肥肠2 小时前
Codex + Obsidian 做人生副本视频:输入主题文案,直通剪映草稿
人工智能·aigc·agent
百度Geek说3 小时前
全链路研发智能体 ——从"体感能用"到"实际可用"的工程实践
人工智能
甲维斯4 小时前
500块的豆包,能帮我搞定这个么?!
人工智能