[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
aneasystone本尊17 小时前
OpenClaw 快速入门:从安装到第一次对话
人工智能
aneasystone本尊17 小时前
OpenClaw 接入第一个通道:Telegram
人工智能
IT_陈寒17 小时前
Redis这个内存杀手,差点让我们运维半夜追杀我
前端·人工智能·后端
私人珍藏库17 小时前
【Android】聆听岛[特殊字符]聚合全网音乐[特殊字符]免费听歌下载神器[特殊字符] 聚合音乐平台|无损母带下载|歌词封面同步|免费无广告听歌工具
android·人工智能·工具·软件·多功能
aneasystone本尊17 小时前
OpenClaw 介绍:一款运行在自己设备上的开源 AI 助手
人工智能
OneBlock Community17 小时前
穿越熊市与 AI 浪潮,Polkadot 仍以“自由”为锚!
人工智能
纤纡.18 小时前
本地部署 AI 大模型保姆级教程:Ollama 安装、模型下载与终端实战全流程
人工智能·深度学习·语言模型·llama
沸点小助手18 小时前
「新晋AI顶流PK:GPT-5.5 vs DeepSeek V4&掘友吐槽小会」沸点获奖名单公示|本周互动话题上新🎊
前端·人工智能
nikolay18 小时前
AI重塑企业信息安全:攻防升级与信任重构
网络·人工智能·网络安全
天辛大师18 小时前
天辛大师谈人工智能时代,如何用AI研究历代放生劝善忏悔文
大数据·人工智能·随机森林·启发式算法