[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
Where-8 分钟前
深度学习中的过拟合问题及解决方式
人工智能·深度学习
wen__xvn14 分钟前
目标检测的局限
人工智能·目标检测·计算机视觉
力学与人工智能24 分钟前
博士答辩PPT分享 | 高雷诺数湍流场数据同化与湍流模型机器学习研究
人工智能·机器学习·ppt分享·高雷诺数·流场数据同化·湍流模型
调参札记36 分钟前
医学研究中的因果推断:重视态度与实践流程的结构性落差
人工智能
木卫四科技39 分钟前
Chonkie 技术深度学习
人工智能·python·rag
努力毕业的小土博^_^1 小时前
【地学应用】溜砂坡scree slope / talus slope的定义、机制、分布、危害、与滑坡区别、研究方向与代表论文
人工智能·深度学习·遥感·地质灾害·地学应用
JeffDingAI1 小时前
【Datawhale学习笔记】基于Gensim的词向量实战
人工智能·笔记·学习
Ryan老房1 小时前
自动驾驶数据标注-L4-L5级别的数据挑战
人工智能·目标检测·目标跟踪·自动驾驶
weixin_398187751 小时前
YOLOv8结合SCI低光照图像增强算法实现夜晚目标检测
人工智能·yolo
万行1 小时前
机器人系统ROS2
人工智能·python·机器学习·机器人·计算机组成原理