[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
得赢科技1 小时前
2025年GEO营销应用白皮书 - 服务业区域推广深度剖析
大数据·人工智能
Deepoch1 小时前
Deepoc具身智能家庭系统:重塑居家生活新体验
人工智能·科技·机器人·生活·具身模型·deepoc·deepoc具身模型开发板
GIS数据转换器1 小时前
基于GIS的宠物救助服务平台
大数据·人工智能·科技·机器学习·无人机·智慧城市·宠物
qwy7152292581631 小时前
3-用摄像头拍摄图像及视频
人工智能·opencv·音视频
AI街潜水的八角1 小时前
基于YOLO26苹果水果缺陷检测系统1:苹果水果缺陷检测数据集说明(含下载链接)
人工智能·深度学习·神经网络
Solar20251 小时前
工程材料企业如何借助数字化工具突破获客瓶颈:方法论与实践路径
大数据·人工智能·物联网
audyxiao0011 小时前
会议热点扫描|通过智能交通顶级会议IEEE IV 2025看自动驾驶领域研究热点
人工智能·机器学习·自动驾驶·热点分析·ieee iv
茶栀(*´I`*)2 小时前
【视觉探索】OpenCV 全景导论:从数字图像基石到核心模块体系
人工智能·opencv·计算机视觉
喝可乐的希饭a2 小时前
AI Agent 的九种设计模式
人工智能·设计模式
春日见2 小时前
Docker中如何删除镜像
运维·前端·人工智能·驱动开发·算法·docker·容器