[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
Macbethad1 分钟前
智能硬件产品系统技术报告
大数据·人工智能
J_Xiong01172 分钟前
【VLMs篇】10:使用Transformer的端到端目标检测(DETR)
深度学习·目标检测·transformer
有Li2 分钟前
泛用型nnUNet脑血管周围间隙识别系统(PINGU)|文献速递-医疗影像分割与目标检测最新技术
论文阅读·深度学习·文献·医学生
这张生成的图像能检测吗4 分钟前
(论文速读)基于M-LLM的高效视频理解视频帧选择
人工智能·贪心算法·视频生成·多模态大语言模型
Shiyuan77 分钟前
【IEEE冠名EI会议】2026年IEEE第三届深度学习与计算机视觉国际会议
人工智能·深度学习·计算机视觉
q_30238195569 分钟前
YOLOv11训练NEU-DET钢材缺陷数据集并部署香橙派推理全流程
人工智能·python·深度学习·课程设计
编码小哥12 分钟前
OpenCV图像金字塔与图像拼接技术
人工智能·opencv·计算机视觉
LeeZhao@13 分钟前
【狂飙全模态】灵曦星灿视频助手-影视级音画同步视频生成
人工智能·语言模型·音视频·agi
java1234_小锋15 分钟前
Transformer 大语言模型(LLM)基石 - Transformer PyTorch2内置实现
深度学习·语言模型·transformer
丝斯201115 分钟前
AI学习笔记整理(35)——生成模型与视觉大模型
人工智能·笔记·学习