[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
weixin_397574093 分钟前
向量空间携手山东信研院共建实验室,工业AI按下加速键
人工智能
DisonTangor3 分钟前
跃阶星辰开源Step 3.7 Flash:原生多模态,最高生成速度400 Tokens/s
人工智能·语言模型·数据挖掘·开源·aigc
lili00124 分钟前
Claude自动修Bug配置优化与避坑指南
java·人工智能·python·bug·ai编程
Szime6 分钟前
靠谱的终端工厂采购电子元器件供应链哪家更适合研发型企业?
人工智能·python
圣殿骑士-Khtangc9 分钟前
SuperSplat 架构深度解析:8.2K Star 的浏览器端 3D 高斯泼溅编辑器,PlayCanvas 如何用纯 WebGL 重新定义三维内容工作流
人工智能
Mem0rin10 分钟前
[Agent基础]Agent、消息和聊天模板
人工智能·transformer
智信中科张炜11 分钟前
全球及中国二板注塑机市场前景形势分析报告
人工智能
升鲜宝供应链及收银系统源代码服务12 分钟前
升鲜宝 AI 供应链分析方案业务分析、智能预警与实施落地方案(一)---升鲜宝生鲜配送供应链管理系统源代码服务
人工智能·生鲜供应链源代码·供应链源代码出售·生鲜配送源代码服务·猪肉生产加工系统源代码·生鲜供应链系统·生鲜配送系统ai应用
编程牛马姐14 分钟前
爬虫开发工具测评:Playwright vs Puppeteer
人工智能
andafaAPS17 分钟前
安达发|aps高级排产:电动工具行业智能制造的核心引擎
大数据·人工智能·制造·安达发aps·aps高级排产·aps自动排产