[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
人工智能训练几秒前
windows系统中的docker,xinference直接运行在容器目录和持载在宿主机目录中的区别
linux·服务器·人工智能·windows·ubuntu·docker·容器
南蓝12 分钟前
【AI 日记】调用大模型的时候如何按照 sse 格式输出
前端·人工智能
robot_learner15 分钟前
11 月 AI 动态:多模态突破・智能体模型・开源浪潮・机器人仿真・AI 安全与主权 AI
人工智能·机器人·开源
Mintopia43 分钟前
🌐 动态网络环境中 WebAIGC 的断点续传与容错技术
人工智能·aigc·trae
后端小张1 小时前
【AI 学习】从0到1深入理解Agent AI智能体:理论与实践融合指南
人工智能·学习·搜索引擎·ai·agent·agi·ai agent
Mintopia1 小时前
🧩 Claude Code Hooks 最佳实践指南
人工智能·claude·全栈
【建模先锋】1 小时前
精品数据分享 | 锂电池数据集(四)PINN+锂离子电池退化稳定性建模和预测
深度学习·预测模型·pinn·锂电池剩余寿命预测·锂电池数据集·剩余寿命
星空的资源小屋1 小时前
极速精准!XSearch本地文件搜索神器
javascript·人工智能·django·电脑
九年义务漏网鲨鱼1 小时前
【大模型学习】现代大模型架构(二):旋转位置编码和SwiGLU
深度学习·学习·大模型·智能体
CoovallyAIHub1 小时前
破局红外小目标检测:异常感知Anomaly-Aware YOLO以“俭”驭“繁”
深度学习·算法·计算机视觉