[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
富唯智能32 分钟前
移动+协作+视觉:开箱即用的下一代复合机器人如何重塑智能工厂
人工智能·工业机器人·复合机器人
Antonio9151 小时前
【图像处理】图像的基础几何变换
图像处理·人工智能·计算机视觉
新加坡内哥谈技术2 小时前
Perplexity AI 的 RAG 架构全解析:幕后技术详解
人工智能
武子康3 小时前
AI研究-119 DeepSeek-OCR PyTorch FlashAttn 2.7.3 推理与部署 模型规模与资源详细分析
人工智能·深度学习·机器学习·ai·ocr·deepseek·deepseek-ocr
Sirius Wu4 小时前
深入浅出:Tongyi DeepResearch技术解读
人工智能·语言模型·langchain·aigc
忙碌5444 小时前
AI大模型时代下的全栈技术架构:从深度学习到云原生部署实战
人工智能·深度学习·架构
LZ_Keep_Running4 小时前
智能变电巡检:AI检测新突破
人工智能
InfiSight智睿视界5 小时前
AI 技术助力汽车美容行业实现精细化运营管理
大数据·人工智能
没有钱的钱仔6 小时前
机器学习笔记
人工智能·笔记·机器学习
听风吹等浪起6 小时前
基于改进TransUNet的港口船只图像分割系统研究
人工智能·深度学习·cnn·transformer