[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
paopao_wu6 分钟前
解析 skill-creator:如何编写高质量的 AI Skill
人工智能·ai编程
IvanCodes6 分钟前
人工智能、机器学习和深度学习,其实不是一回事
人工智能·机器学习
前端Fusion6 分钟前
一文讲透 MCP 和 Skills 的分工与协作
人工智能·vibecoding
逐梦苍穹13 分钟前
谷歌新研究:训练大模型时“偷懒跳过“50%更新,性能反而提升20%?
人工智能·google·论文·梯度更新
向哆哆14 分钟前
单车/共享单车目标检测数据集(适用YOLO系列)(已标注+划分/可直接训练)
人工智能·yolo·目标检测
新缸中之脑16 分钟前
轻量AI助手的兴起
人工智能
陈天伟教授37 分钟前
人工智能应用- 预测化学反应:02. 化学反应简介
人工智能·神经网络·算法·机器学习·推荐算法
光的方向_1 小时前
04-Tokenization实战-从BPE到Hugging-Face应用
人工智能·深度学习·chatgpt·transformer
后端小肥肠1 小时前
喂饭级教程!免费部署云端 OpenClaw + 打通飞书,自动抓取 ClawHub 技能并写入飞书表格
人工智能·agent
AI_56781 小时前
Nmap端口扫描:SYN扫描+脚本绕过提升成功率
人工智能·nmap