[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
TsingtaoAI2 分钟前
技术博客外,Gen1比Gen0进化了什么
人工智能·具身智能
敢敢のwings2 分钟前
NVIDIA Thor学习之 |部署NVIDIA Cosmos Reason 2B视觉语言模型完整指南(一)
人工智能
荪荪4 分钟前
开发板断电启动相机报错
人工智能·机器人
财经资讯数据_灵砚智能5 分钟前
基于全球经济类多源新闻的NLP情感分析与数据可视化(日间)2026年4月15日
人工智能·python·信息可视化·自然语言处理·ai编程
青瓷程序设计6 分钟前
基于深度学习的【犬类识别】系统~Python+人工智能+卷积算法+图像识别+计算机毕设项目
人工智能·python·深度学习
Zzj_tju6 分钟前
大语言模型技术指南:长上下文是怎么做出来的?RoPE、位置插值、滑窗注意力与 KV Cache 详解
人工智能·语言模型·自然语言处理
茫忙然11 分钟前
[特殊字符]️ CTF AI大模型提示词注入 (Prompt Injection) 核心攻防方法总结大全
人工智能·prompt
OAK中国_官方13 分钟前
在OAK 4 系列上以480帧运行神经网络
人工智能·深度学习·神经网络
ASKED_201914 分钟前
Harness Enginner记录-驾驭AI Agent之术
人工智能
薛定猫AI17 分钟前
【技术干货】Claude Code 桌面版重大更新:AI 辅助编程进入 IDE 原生时代
ide·人工智能