[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
leo在掘金1 小时前
从DeepSeek 510亿融资到GitHub 33K Star开源项目:这周的技术生态发生了什么?
人工智能
小姜前线技术2 小时前
AI流式渲染打字机效果抖动?节流方案踩坑实录
人工智能
用户018349301692 小时前
AI对话状态管理:useReducer还是XState
人工智能
先锋部队2 小时前
给AI对话加「停止生成」按钮:abort SSE实战
人工智能
新新技术迷2 小时前
移动端H5接AI对话的坑:键盘顶起与滚动到底
人工智能
aqi005 小时前
15天学会AI应用开发(七)有了大模型为什么还要引入RAG
人工智能·python·大模型·ai编程·ai应用
用户5191495848456 小时前
libcurl Headers API 释放后重利用漏洞:跨请求复用头句柄导致堆内存安全风险
人工智能·aigc
踩蚂蚁6 小时前
自定义语音唤醒词:从训练到部署的完整链路实践
人工智能
用户5191495848456 小时前
CVE-2025-1094 PostgreSQL SQL注入与WebSocket劫持远程代码执行利用工具
人工智能·aigc