【Hugging Face】解决BART模型调用时KeyError: ‘new_zeros‘的问题

错误代码:

python 复制代码
tokenizer = AutoTokenizer.from_pretrained("philschmid/bart-large-cnn-samsum")
model = AutoModelForSeq2SeqLM.from_pretrained("philschmid/bart-large-cnn-samsum")

model.eval()
model.to("cuda")
loss = 0
for i in range(len(self.dataset)):
    batch = tokenizer([self.dataset[i]["source"]], return_tensors="pt", padding=True).to("cuda")
    labels = tokenizer([self.dataset[i]["target"]], return_tensors="pt", padding=True).to("cuda")
    print(batch)
    outputs = model(**batch, labels=labels)
    print(outputs.loss.item())

报错内容:

python 复制代码
Traceback (most recent call last):
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\transformers\tokenization_utils_base.py", line 266, in __getattr__
    return self.data[item]
KeyError: 'new_zeros'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "E:\supTextDebug\supTextDebugCode\textDebugger.py", line 360, in <module>
    debugger.run_baselines()
  File "E:\supTextDebug\supTextDebugCode\textDebugger.py", line 299, in run_baselines
    loss.get_loss()
  File "E:\supTextDebug\supTextDebugCode\lossbased.py", line 26, in get_loss
    outputs = model(**batch, labels=labels)
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\transformers\models\bart\modeling_bart.py", line 1724, in forward
    decoder_input_ids = shift_tokens_right(
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\transformers\models\bart\modeling_bart.py", line 104, in shift_tokens_right
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\transformers\tokenization_utils_base.py", line 268, in __getattr__
    raise AttributeError
AttributeError

解决方案:

错误行:outputs = model(**batch, labels=labels)

直接使用模型的forward方法,而不是将所有参数传递给 model:

python 复制代码
tokenizer = AutoTokenizer.from_pretrained("philschmid/bart-large-cnn-samsum")
model = AutoModelForSeq2SeqLM.from_pretrained("philschmid/bart-large-cnn-samsum")

model.eval()
model.to("cuda")
loss = 0
for i in range(len(self.dataset)):
    batch = tokenizer([self.dataset[i]["source"]], return_tensors="pt", padding=True).to("cuda")
    labels = tokenizer([self.dataset[i]["target"]], return_tensors="pt", padding=True).to("cuda")
    print(batch)
    outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=labels["input_ids"])
    print(outputs.loss.item())
相关推荐
好奇龙猫4 分钟前
【人工智能学习-AI入试相关题目练习-第九次】
人工智能·学习
抠头专注python环境配置12 分钟前
基于Python与深度学习的智能垃圾分类系统设计与实现
pytorch·python·深度学习·分类·垃圾分类·vgg·densenet
aspxiy40 分钟前
知识求解器:教会大型语言模型从知识图谱中搜索领域知识
人工智能·语言模型·自然语言处理·知识图谱
梦想是成为算法高手42 分钟前
带你从入门到精通——知识图谱(一. 知识图谱入门)
人工智能·pytorch·python·深度学习·神经网络·知识图谱
沛沛老爹1 小时前
从Web到AI:行业专属Agent Skills生态系统技术演进实战
java·开发语言·前端·vue.js·人工智能·rag·企业转型
B站计算机毕业设计超人1 小时前
计算机毕业设计Python+大模型音乐推荐系统 音乐数据分析 音乐可视化 音乐爬虫 知识图谱 大数据毕业设计
人工智能·hadoop·爬虫·python·数据分析·知识图谱·课程设计
陈天伟教授1 小时前
人工智能应用-机器视觉:AI 鉴伪 02.虚假人脸生成
人工智能·神经网络·数码相机·生成对抗网络·dnn
可能是阿伦1 小时前
探索 cccc:一个面向工程协作的多代理协作内核
人工智能·低代码·ai·web3