[Bert] 提取特征之后训练模型报梯度图错误

报错:

RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

或者

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

原因:

训练模型的时候,输入数据x,y不应该requires_grad,而bert模型输出的embeddings默认是requires_grad的,所以会报错。

解决方法:

提取完embeddings之后,使用 embeddings.detach() 解除绑定就行了。

最后的代码:

复制代码
from transformers import BertTokenizer, BertModel

class BertFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        self.model = BertModel.from_pretrained('bert-base-chinese')

    def extract_features(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt")
        if len(inputs["input_ids"]) > 512:
            inputs["input_ids"] = inputs["input_ids"][:512]
            inputs["attention_mask"] = inputs["attention_mask"][:512]
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:,0,:]
    
feat = feat.detach()
相关推荐
ZFSS7 分钟前
VS Code + Luma MCP 使用教程
人工智能·ai·ai作画·copilot·ai编程·ai写作
某林2127 分钟前
ROS2 语音机器人实战:从 KCF 跟随失效到 RTAB-Map 建图闭环的完整排障
人工智能·机器人·语音识别·ros2·架构重构·技术复盘·c++底层排错
Tongpao_SSDHDD9 分钟前
希捷酷鹰ST6000VX008实测解析:中小安防监控高性价比存储方案
大数据·数据库·人工智能
Ricky055311 分钟前
基于作物特性的语义分割技术用于高效农业病害评估(西班牙德国2025年联合研究)
人工智能·目标检测·图像分割
jkyy201413 分钟前
车载健康座舱成新赛道?汽车健康数字化重塑出行新价值
大数据·人工智能·汽车·健康医疗
jllllyuz13 分钟前
MATLAB实现滚动轴承故障诊断(外圈故障)
开发语言·人工智能·matlab
xianghongtao011615 分钟前
把 Prompt 当成“可训练参数“:SkillOpt 如何用深度学习的纪律去优化 Agent 技能
人工智能·深度学习·性能优化·prompt
2601_9594801518 分钟前
Moneta Markets亿汇:“应用软件股遭遇AI再定价”
人工智能
IT古董18 分钟前
AI行业最新动态 | 2026年6月10日
人工智能
zhuhai_xigedian20 分钟前
源网荷储一体化 vs 传统供用电模式:差异、优势与转型路径
大数据·人工智能·分布式·系统架构·能源