[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
~央千澈~几秒前
《2026鸿蒙NEXT纯血开发与AI辅助》第五章:选择成熟方案,创建第一个鸿蒙应用并成功运行-卓伊凡
人工智能·华为·harmonyos·harmony·harmony os
ting94520001 分钟前
Kimi-VL-A3B-Thinking 技术全解
人工智能·架构
monkeyhlj4 分钟前
AI Agent开发课程笔记记录 - 提升篇 About RAG
人工智能·笔记
qq_411262425 分钟前
四博 AI 智能音箱 4G S3架构方案
人工智能·架构·智能音箱
skywalk81638 分钟前
基于 Kotti CMS 的 AI 共创社区插件 —— 实现 AI 资源共享、协作交流和项目孵化(先放弃)
人工智能
qq_411262429 分钟前
四博AI智能拍学机方案设计
人工智能·智能音箱
格林威10 分钟前
面阵相机 vs 线阵相机:堡盟与Basler选型差异全解析 +C# 实战演示
开发语言·人工智能·数码相机·计算机视觉·c#·视觉检测·工业相机
爱上好庆祝16 分钟前
学习js的第三天
前端·css·人工智能·学习·计算机外设·js
隔壁大炮18 分钟前
10.PyTorch_元素类型转换
人工智能·pytorch·深度学习·算法