[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
萤丰信息9 小时前
智慧园区:科技赋能的未来产业生态新载体
大数据·运维·人工智能·科技·智慧园区
ASD123asfadxv9 小时前
【医疗影像检测】VFNet模型在医疗器械目标检测中的应用与优化
人工智能·目标检测·计算机视觉
小真zzz9 小时前
2025-2026年AI PPT工具排行榜:ChatPPT的全面领先与竞品格局解析
人工智能·ai·powerpoint·ppt·aippt
翱翔的苍鹰9 小时前
CIFAR-10 是一个经典的小型彩色图像分类数据集,广泛用于深度学习入门、模型验证和算法研究
深度学习·算法·分类
智慧化智能化数字化方案9 小时前
详解人工智能安全治理框架(中文版)【附全文阅读】
大数据·人工智能·人工智能安全治理框架
人工智能培训9 小时前
开源与闭源大模型的竞争未来会如何?
人工智能·机器学习·语言模型·大模型·大模型幻觉·开源大模型·闭源大模型
啊阿狸不会拉杆9 小时前
《机器学习》第六章-强化学习
人工智能·算法·机器学习·ai·机器人·强化学习·ml
人工智能AI技术9 小时前
【Agent从入门到实践】21 Prompt工程基础:为Agent设计“思考指令”,简单有效即可
人工智能·python
式5169 小时前
大模型学习基础(九)LoRA微调原理
人工智能·深度学习·学习
CCPC不拿奖不改名9 小时前
python基础面试编程题汇总+个人练习(入门+结构+函数+面向对象编程)--需要自取
开发语言·人工智能·python·学习·自然语言处理·面试·职场和发展