[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
wuxinyan1239 小时前
工业级大模型学习之路020:LangChain零基础入门教程(第三篇):提示词工程与提示模板系统
人工智能·python·学习·langchain
海盗12349 小时前
科技与科学领域每日新闻摘要-2026年5月20日
人工智能·科技
threelab9 小时前
Three.js 3D 热力图效果 | 三维可视化 / AI 提示词
开发语言·前端·javascript·人工智能·3d·着色器
听风吹等浪起10 小时前
基于改进ResUNet的植物叶片语义分割系统设计与实现
人工智能·深度学习·分类
得物技术10 小时前
Claude Code Harness 工程:数仓侧落地方案|得物技术
数据库·人工智能·ai编程
隔窗听雨眠10 小时前
AI开发者的网络卡点:Anthropic连接超时实战避坑
网络·人工智能
8K超高清10 小时前
CCBN展会多图回顾
人工智能·算法·fpga开发·接口隔离原则·智能硬件
AI大法师10 小时前
从 Adobe 焕新看品牌系统升级:Logo、主色、字体与产品体验如何重新对齐
大数据·人工智能·adobe·设计模式
解局易否结局10 小时前
从零搭建 ops-transformer 开发环境:在昇腾NPU上跑通第一个算子
人工智能·深度学习·transformer
xiaoxiaoxiaolll10 小时前
Light: Sci. Appl. 封面级研究:光谱奇点拓扑环绕 + BIC共振 = 新一代多功能平面器件
人工智能·机器学习