[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
就这个丶调调19 分钟前
VLLM部署全部参数详解及其作用说明
深度学习·模型部署·vllm·参数配置
余俊晖22 分钟前
3秒实现语音克隆的Qwen3-TTS的Qwen-TTS-Tokenizer和方法架构概览
人工智能·语音识别
森屿~~23 分钟前
AI 手势识别系统:踩坑与实现全记录 (PyTorch + MediaPipe)
人工智能·pytorch·python
运维行者_1 小时前
2026 技术升级,OpManager 新增 AI 网络拓扑与带宽预测功能
运维·网络·数据库·人工智能·安全·web安全·自动化
淬炼之火1 小时前
图文跨模态融合基础:大语言模型(LLM)
人工智能·语言模型·自然语言处理
Elastic 中国社区官方博客1 小时前
Elasticsearch:上下文工程 vs. 提示词工程
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
轴测君1 小时前
SE Block(Squeeze and Excitation Block)
深度学习·机器学习·计算机视觉
正宗咸豆花1 小时前
LangGraph实战:构建可自愈的多智能体客服系统架构
人工智能·系统架构·claude
檐下翻书1731 小时前
文本创作进化:从辅助写作到内容策划的全面赋能
人工智能
仙人掌_lz2 小时前
AI代理记忆设计指南:从单一特征到完整系统,打造可靠智能体
人工智能