[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
小烤箱9 分钟前
Autoware Universe 感知模块详解 | 第十一节:检测管线的通用工程模板与拆解思路导引
人工智能·机器人·自动驾驶·autoware·感知算法
jiayong2312 分钟前
model.onnx 深度分析报告(第2篇)
人工智能·机器学习·向量数据库·向量模型
川西胖墩墩19 分钟前
团队协作泳道图制作工具 PC中文免费
大数据·论文阅读·人工智能·架构·流程图
Codebee19 分钟前
ooder SkillFlow:破解 AI 编程冲击,重构企业级开发全流程
人工智能
TOPGUS22 分钟前
黑帽GEO手法揭秘:AI搜索阴影下的新型搜索劫持与风险
人工智能·搜索引擎·chatgpt·aigc·谷歌·数字营销
Sammyyyyy24 分钟前
Symfony AI 正式发布,PHP 原生 AI 时代开启
开发语言·人工智能·后端·php·symfony·servbay
汽车仪器仪表相关领域28 分钟前
光轴精准测量,安全照明保障——NHD-8101/8000型远近光检测仪项目实战分享
数据库·人工智能·安全·压力测试·可用性测试
WJSKad123530 分钟前
基于yolov5-RepNCSPELAN的商品价格标签识别系统实现
人工智能·yolo·目标跟踪
资深web全栈开发31 分钟前
深度对比 LangChain 8 种文档分割方式:从逻辑底层到选型实战
深度学习·自然语言处理·langchain
早日退休!!!33 分钟前
现代公司开发AI编译器的多元技术路线(非LLVM方向全解析)
人工智能