【Hugging Face】解决BART模型调用时KeyError: ‘new_zeros‘的问题

错误代码:

python 复制代码
tokenizer = AutoTokenizer.from_pretrained("philschmid/bart-large-cnn-samsum")
model = AutoModelForSeq2SeqLM.from_pretrained("philschmid/bart-large-cnn-samsum")

model.eval()
model.to("cuda")
loss = 0
for i in range(len(self.dataset)):
    batch = tokenizer([self.dataset[i]["source"]], return_tensors="pt", padding=True).to("cuda")
    labels = tokenizer([self.dataset[i]["target"]], return_tensors="pt", padding=True).to("cuda")
    print(batch)
    outputs = model(**batch, labels=labels)
    print(outputs.loss.item())

报错内容:

python 复制代码
Traceback (most recent call last):
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\transformers\tokenization_utils_base.py", line 266, in __getattr__
    return self.data[item]
KeyError: 'new_zeros'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "E:\supTextDebug\supTextDebugCode\textDebugger.py", line 360, in <module>
    debugger.run_baselines()
  File "E:\supTextDebug\supTextDebugCode\textDebugger.py", line 299, in run_baselines
    loss.get_loss()
  File "E:\supTextDebug\supTextDebugCode\lossbased.py", line 26, in get_loss
    outputs = model(**batch, labels=labels)
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\transformers\models\bart\modeling_bart.py", line 1724, in forward
    decoder_input_ids = shift_tokens_right(
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\transformers\models\bart\modeling_bart.py", line 104, in shift_tokens_right
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  File "D:\anaconda\envs\supTextDebug\lib\site-packages\transformers\tokenization_utils_base.py", line 268, in __getattr__
    raise AttributeError
AttributeError

解决方案:

错误行:outputs = model(**batch, labels=labels)

直接使用模型的forward方法,而不是将所有参数传递给 model:

python 复制代码
tokenizer = AutoTokenizer.from_pretrained("philschmid/bart-large-cnn-samsum")
model = AutoModelForSeq2SeqLM.from_pretrained("philschmid/bart-large-cnn-samsum")

model.eval()
model.to("cuda")
loss = 0
for i in range(len(self.dataset)):
    batch = tokenizer([self.dataset[i]["source"]], return_tensors="pt", padding=True).to("cuda")
    labels = tokenizer([self.dataset[i]["target"]], return_tensors="pt", padding=True).to("cuda")
    print(batch)
    outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=labels["input_ids"])
    print(outputs.loss.item())
相关推荐
Chat_zhanggong34534 分钟前
K4A8G165WC-BITD产品推荐
人工智能·嵌入式硬件·算法
霍格沃兹软件测试开发37 分钟前
Playwright MCP浏览器自动化指南:让AI精准理解你的命令
运维·人工智能·自动化
强化学习与机器人控制仿真42 分钟前
RSL-RL:开源人形机器人强化学习控制研究库
开发语言·人工智能·stm32·神经网络·机器人·强化学习·模仿学习
网易智企1 小时前
智能玩具新纪元:一个AI能力底座开启创新“加速度”
人工智能·microsoft
咚咚王者1 小时前
人工智能之数据分析 numpy:第十二章 数据持久化
人工智能·数据分析·numpy
沛沛老爹1 小时前
AI应用入门之LangChain中SerpAPI、LLM-Math等Tools的集成方法实践
人工智能·langchain·llm·ai入门·serpapi
roman_日积跬步-终至千里2 小时前
【强化学习基础(5)】策略搜索与学徒学习:从专家行为中学习加速学习过程
人工智能
杭州泽沃电子科技有限公司4 小时前
在线监测:为医药精细化工奠定安全、合规与质量基石
运维·人工智能·物联网·安全·智能监测
GIS数据转换器4 小时前
GIS+大模型助力安全风险精细化管理
大数据·网络·人工智能·安全·无人机
OJAC1114 小时前
AI跨界潮:金融精英与应届生正涌入人工智能领域
人工智能·金融