[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
伊一大数据&人工智能学习日志5 分钟前
机器学习之模型评估——混淆矩阵,交叉验证与数据标准化
人工智能·深度学习·机器学习
Jackilina_Stone8 分钟前
【HUAWEI】HCIP-AI-MindSpore Developer V1.0 | 第一章 神经网络基础(4 生成对抗网络 ) | 学习笔记
人工智能·笔记·神经网络·生成对抗网络·华为·hcip
Jackilina_Stone11 分钟前
【HUAWEI】HCIP-AI-MindSpore Developer V1.0 | 第一章 神经网络基础( 3 循环神经网络 ) | 学习笔记
人工智能·笔记·rnn·神经网络·hcip·huawei
剑盾云安全专家13 分钟前
AI智能生成PPT,告别手工操作的新选择
人工智能·科技·aigc·powerpoint·软件
又南又难20 分钟前
deepFM模型pytorch实现
人工智能·pytorch·python
阿松のblog33 分钟前
深度学习之计算机视觉相关数据集
人工智能·深度学习·计算机视觉
远洋录43 分钟前
Tailwind CSS 实战:动画效果设计与实现
前端·人工智能·react
数据分析能量站1 小时前
RWKV 语言模型
人工智能·语言模型·自然语言处理
吃个糖糖1 小时前
38 Opencv HOG特征检测
人工智能·opencv·计算机视觉
deephub2 小时前
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
人工智能·pytorch·神经网络·强化学习