[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
humors2219 分钟前
AI案例:头脑风暴创作-正反论证-报告撰写-摘要总结
人工智能·ai·写作·总结·案例·论证
HIT_Weston12 分钟前
115、【Agent】【OpenCode】项目配置(SemVer)
人工智能·agent·opencode
Sam092713 分钟前
OpenClaw 和 Hermes 怎么结合:从聊天入口到隔离执行器的 Agent 工程实践
人工智能·ai
沪漂阿龙19 分钟前
LangChain 系列之 Messages:为什么大模型对话不是简单字符串?
人工智能·深度学习·langchain
jiuLives19 分钟前
从 Prompt Engineering 到 Loop Engineering:AI 工程范式的演进
人工智能·prompt
ACP广源盛1392462567320 分钟前
IX7008 PCIe 交换芯片@ACP#RTX Spark 经济型 8 口扩展芯片(对比 ASM1806)
大数据·人工智能·分布式·嵌入式硬件·gpt·spark·电脑
SEOETC29 分钟前
GEO:杭州AI优化企业实战指南
人工智能·搜索引擎
大模型任我行30 分钟前
腾讯:原生多模态建模路线图
人工智能·语言模型·自然语言处理·论文笔记
搜移IT科技31 分钟前
工业设备更新行动全面推进,通用设备板块增量空间与受益环节解析
人工智能·科技·生活
lauo36 分钟前
碳基心脏最后的堡垒——ibbot青春版:你的随身Token生产厂
大数据·人工智能·chatgpt·智能手机·ai-native