[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
Joseph Cooper几秒前
AI Agent 落地入门:从模型、工具到 Skills 与 MCP 的分工
人工智能·ai·agent·claude·skill·mcp
爱学习的张大2 分钟前
具身智能论文精读(五):OpenVLA
人工智能·算法
AI创界者3 分钟前
OmniVoice 语音大模型一键部署:支持批量任务、智能 SRT 配音与多人对话全攻略》
人工智能
丷丩7 分钟前
为什么Geo-UP是一款可以直接用于交付的智能应用
人工智能·gis·空间分析·geoai
xiangzhihong811 分钟前
Claude Code系列教程之Claude Code钩子
人工智能
sheji10513 分钟前
泳池机器人行业市场分析报告
人工智能·机器人·智能硬件
虾壳云管家20 分钟前
【含四月底最新安装包】OpenClaw一键安装及使用教程
人工智能·openclaw·小龙虾·openclaw安装·openclaw一键部署
无心水23 分钟前
【Hermes:Skill系统深度】21、Skill 调试与冲突解决:为什么没触发?怎么修复? —— Honcho 智能体排障完全手册
人工智能·windows·openclaw·养龙虾·hermes·养马·honcho
袖手蹲28 分钟前
把 Claude 的愚人节彩蛋跑在 行空板K10上:BLE 应用与 ASCII 宠物动画实战
人工智能·自动化·宠物
春风有信29 分钟前
【DM】DDPM与DDIM的数学原理
人工智能·深度学习·机器学习