# Run the test
def test(max_batches=6):
    """Estimate classification accuracy on a sample of ``dataset['test']``.

    Builds a shuffled ``DataLoader`` over the test split (batch size 8,
    incomplete last batch dropped), runs the module-level ``model`` on up to
    ``max_batches`` batches without gradient tracking, and compares the
    arg-max of the logits with ``data.labels``. Running accuracy is printed
    after every batch.

    Because the loader is shuffled and only a prefix of it is consumed, the
    result is an estimate over a random subset, not the exact accuracy of
    the full test split.

    Relies on the names ``model``, ``dataset``, ``tokenizer``,
    ``DataCollatorWithPadding`` and ``torch`` being defined at module level.

    Args:
        max_batches: Maximum number of batches to evaluate. The default of 6
            matches the previous hard-coded limit (stop after batch index
            5). Pass ``None`` to iterate over the whole loader.

    Returns:
        The fraction of correctly classified examples among those evaluated,
        or ``0.0`` if the loader produced no batch at all.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    from itertools import islice

    loader_test = torch.utils.data.DataLoader(
        dataset['test'],
        batch_size=8,
        shuffle=True,
        drop_last=True,
        collate_fn=DataCollatorWithPadding(tokenizer),
    )

    correct = 0
    total = 0

    # Evaluation must run in eval mode so dropout and similar layers are
    # disabled; remember the previous mode so a surrounding training loop is
    # not silently left in eval mode afterwards.
    # Assumes `model` is a torch.nn.Module (it is called as one below).
    was_training = model.training
    model.eval()
    try:
        # One no_grad context for the whole loop instead of one per batch.
        with torch.no_grad():
            for i, data in enumerate(islice(loader_test, max_batches)):
                out = model(**data).logits
                # Predicted class = index of the largest logit along dim 1
                # (presumably logits are (batch, num_classes) - confirm).
                out = out.argmax(dim=1)
                correct += (out == data.labels).sum().item()
                total += len(data.labels)
                print(i, len(loader_test), correct / total)
    finally:
        model.train(was_training)

    # With drop_last=True a test split smaller than one batch yields no
    # batches; avoid dividing by zero in that case.
    if total == 0:
        return 0.0
    return correct / total
# Run the evaluation once; test() returns correct / total as the accuracy.
test()