[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
Neolnfra7 小时前
拒绝数据“裸奔”!把顶级AI装进自己的硬盘,这款神仙开源工具我粉了
人工智能·开源·蓝耘maas
code_li7 小时前
只花了几分钟,用AI开发了一个微信小程序!(附教程)
人工智能·微信小程序·小程序
飞Link7 小时前
瑞萨联姻 Irida Labs:嵌入式开发者如何玩转“端侧视觉 AI”新范式?
人工智能
RSTJ_16258 小时前
PYTHON+AI LLM DAY THREETY-SEVEN
开发语言·人工智能·python
郝学胜-神的一滴8 小时前
深度学习优化核心:梯度下降与网络训练全解析
数据结构·人工智能·python·深度学习·算法·机器学习
Aision_8 小时前
Agent 为什么需要 Checkpoint?
人工智能·python·gpt·langchain·prompt·aigc·agi
小贺儿开发8 小时前
《唐朝诡事录之长安》——盛世马球
人工智能·unity·ai·shader·绘画·影视·互动
秋98 小时前
ESP32 与 Air780E 4G 模块配合做 MQTT 数据传输
人工智能
DeepFlow 零侵扰全栈可观测8 小时前
运动战:AI 时代 IT 运维的决胜之道——DeepFlow 业务全链路可观测性的落地实践
运维·网络·人工智能·arcgis·云计算
链上日记8 小时前
AgentWin:AI Agent驱动的Web4智能金融新纪元
人工智能·金融