【Hugging Face】解决BART模型调用时KeyError: ‘new_zeros‘的问题

错误代码:

python 复制代码
tokenizer = AutoTokenizer.from_pretrained("philschmid/bart-large-cnn-samsum")
model = AutoModelForSeq2SeqLM.from_pretrained("philschmid/bart-large-cnn-samsum")

model.eval()
model.to("cuda")
loss = 0
for i in range(len(self.dataset)):
    batch = tokenizer([self.dataset[i]["source"]], return_tensors="pt", padding=True).to("cuda")
    labels = tokenizer([self.dataset[i]["target"]], return_tensors="pt", padding=True).to("cuda")
    print(batch)
    outputs = model(**batch, labels=labels)
    print(outputs.loss.item())

报错内容:

python 复制代码
Traceback (most recent call last):
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\transformers\tokenization_utils_base.py", line 266, in __getattr__
    return self.data[item]
KeyError: 'new_zeros'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "E:\supTextDebug\supTextDebugCode\textDebugger.py", line 360, in <module>
    debugger.run_baselines()
  File "E:\supTextDebug\supTextDebugCode\textDebugger.py", line 299, in run_baselines
    loss.get_loss()
  File "E:\supTextDebug\supTextDebugCode\lossbased.py", line 26, in get_loss
    outputs = model(**batch, labels=labels)
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\transformers\models\bart\modeling_bart.py", line 1724, in forward
    decoder_input_ids = shift_tokens_right(
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\transformers\models\bart\modeling_bart.py", line 104, in shift_tokens_right
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\transformers\tokenization_utils_base.py", line 268, in __getattr__
    raise AttributeError
AttributeError

解决方案:

错误行:outputs = model(**batch, labels=labels)

直接使用模型的forward方法,而不是将所有参数传递给 model:

python 复制代码
tokenizer = AutoTokenizer.from_pretrained("philschmid/bart-large-cnn-samsum")
model = AutoModelForSeq2SeqLM.from_pretrained("philschmid/bart-large-cnn-samsum")

model.eval()
model.to("cuda")
loss = 0
for i in range(len(self.dataset)):
    batch = tokenizer([self.dataset[i]["source"]], return_tensors="pt", padding=True).to("cuda")
    labels = tokenizer([self.dataset[i]["target"]], return_tensors="pt", padding=True).to("cuda")
    print(batch)
    outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=labels["input_ids"])
    print(outputs.loss.item())
相关推荐
后端小肥肠2 分钟前
放弃漫画内卷!育儿赛道才是黑马,用 Coze 智能体做10w+育儿漫画,成品直接发
人工智能·agent·coze
whaosoft-1435 分钟前
51c~Pytorch~合集6
人工智能
后端小张7 分钟前
[AI 学习日记] 深入解析MCP —— 从基础配置到高级应用指南
人工智能·python·ai·开源协议·mcp·智能化转型·通用协议
天青色等烟雨..10 分钟前
AI+Python驱动的无人机生态三维建模与碳储/生物量/LULC估算全流程实战技术
人工智能·python·无人机
渡我白衣14 分钟前
深度学习进阶(七)——智能体的进化:从 LLM 到 AutoGPT 与 OpenDevin
人工智能·深度学习
乌恩大侠31 分钟前
【USRP】AI-RAN Sionna 5G NR 开发者套件
人工智能·5g
孤狼灬笑32 分钟前
机器学习十大经典算法解析与对比
人工智能·算法·机器学习
聚梦小课堂34 分钟前
ComfyUI Blog: ImagenWorld 发布:面向图像生成与编辑的真实世界基准测试数据集
人工智能·深度学习·图像生成·benchmark·imagenworld
星际棋手39 分钟前
【AI】一文说清楚神经网络、机器学习、专家系统
人工智能·神经网络·机器学习
测试开发技术44 分钟前
什么样的 prompt 是好的 prompt?
人工智能·ai·大模型·prompt