[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
小诸葛IT课堂12 分钟前
PyTorch 生态概览:为什么选择动态计算图框架?
人工智能·pytorch·python
雅菲奥朗24 分钟前
4大观点直面呈现|直播回顾-DeepSeek时代的AI算力管理
人工智能·ai算力·deepseek
程序员JerrySUN36 分钟前
深入解析 TensorFlow 兼容性问题及构建输出文件结构*
人工智能·tensorflow·neo4j
ruokkk44 分钟前
搭建一个声纹识别系统
人工智能
kula7751 小时前
Trae,国产首款AI编程IDE初体验
人工智能
moonless02221 小时前
【AI】MercuryCoder与LLaDA? 自回归模型与扩散模型的碰撞,谁才是未来的LLM答案?
人工智能·llm
吾名招财1 小时前
pytorch快速入门——手写数字分类GPU加速
人工智能·pytorch·分类
天天向上杰1 小时前
地基Prompt提示常用方式
人工智能·prompt·提示词
KARL1 小时前
最小闭环manus,langchainjs+mcp-client+mcp-server
前端·人工智能
zhongken2591 小时前
AI智能混剪工具:AnKo打造高效创作的利器!
人工智能·ai·ai编程·ai网站·ai工具·ai软件·ai平台