[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
Winner1300几秒前
查看rk3566摄像头设备、能力、支持格式
linux·网络·人工智能
shizhenshide13 分钟前
“绕过”与“破解”的成本账:自行研发、购买API与外包打码的性价比全分析
人工智能·验证码·recaptcha·ezcaptcha·recaptcha v2
龙腾亚太25 分钟前
大模型在工业物流领域有哪些应用
人工智能·具身智能·智能体·世界模型·智能体培训·具身智能培训
Deepoch35 分钟前
智能清洁新纪元:Deepoc开发板如何重塑扫地机器人的“大脑“
人工智能·机器人·清洁机器人·具身模型·deepoc
装不满的克莱因瓶37 分钟前
【Coze智能体实战二】一键生成儿歌背单词视频
人工智能·ai·实战·agent·工作流·智能体·coze
杰米不放弃40 分钟前
AI大模型应用开发学习-26【20251227】
人工智能·学习
一个会的不多的人1 小时前
人工智能基础篇:概念性名词浅谈(第八讲)
人工智能·制造·数字化转型
weixin_446260851 小时前
Robin: AI驱动的暗网OSINT工具
人工智能
Coder_Boy_1 小时前
基于SpringAI的智能运维平台(AI驱动)
大数据·运维·人工智能
圆号本昊1 小时前
RimWorld AI记忆系统深度技术分析
人工智能