HuggingFace项目实战之使用Trainer执行训练

目录:

一、加载tokenizer

python 复制代码
import torch

from transformers import AutoTokenizer

#加载tokenizer
tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-chinese')

tokenizer

二、加载数据集和编码

python 复制代码
from datasets import load_dataset

#加载数据集
dataset = load_dataset(path='lansinuote/ChnSentiCorp')

#编码
f = lambda x: tokenizer(x['text'], truncation=True, max_length=500)
dataset = dataset.map(f, remove_columns=['text'])

#设置数据类型
dataset.set_format('pt')

dataset, dataset['train'][0]

三、加载模型

python 复制代码
#定义模型
from transformers import BertConfig, BertForSequenceClassification

#在线加载一个语句分类模型
model = BertForSequenceClassification.from_pretrained(
    'google-bert/bert-base-chinese', num_labels=2)

model.config

四、执行训练

python 复制代码
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding

#配置训练参数
args = TrainingArguments(output_dir='output_dir',
                         use_cpu=True,
                         num_train_epochs=1,
                         max_steps=300,
                         eval_strategy='no',
                         per_device_train_batch_size=8)

#创建trainer
trainer = Trainer(model=model,
                  args=args,
                  train_dataset=dataset['train'],
                  data_collator=DataCollatorWithPadding(tokenizer))

#执行训练
trainer.train()

五、执行测试

python 复制代码
#执行测试
def test():
    loader_test = torch.utils.data.DataLoader(
        dataset['test'],
        batch_size=8,
        shuffle=True,
        drop_last=True,
        collate_fn=DataCollatorWithPadding(tokenizer))

    correct = 0
    total = 0
    for i, data in enumerate(loader_test):
        with torch.no_grad():
            out = model(**data).logits

        out = out.argmax(dim=1)
        correct += (out == data.labels).sum().item()
        total += len(data.labels)

        print(i, len(loader_test), correct / total)

        if i == 5:
            break

    return correct / total


test()
相关推荐
CDYXY15 小时前
2026年4月成都卡布灯箱源头口碑深度调研与避坑指南
大数据·人工智能
小真zzz20 小时前
2026年GEO监测工具深度横评:谁在AI时代守护品牌心智?
人工智能·百度·重构
ZFSS20 小时前
Localization Translate API 集成与使用指南
java·服务器·数据库·人工智能·mysql·ai编程
天行健,君子而铎21 小时前
合规对标·低误报漏报·稳定运行——知源-AI数据分类分级系统金融行业解决方案
人工智能·金融·分类
视觉&物联智能21 小时前
【杂谈】-游戏生成数据:人工智能训练中极易被低估的核心资源
人工智能·游戏·ai·chatgpt·openai·agi·deepseek
扫地的小何尚21 小时前
NVIDIA Vera Rubin 平台如何解决 Agentic AI 的 Scale-up 难题
大数据·人工智能·机器学习
莞凰21 小时前
昇腾CANN的“灵脉根基“:Runtime仓库探秘
android·人工智能·transformer
5201-1 天前
ops-conv:卷积算子从 CPU 到昇腾 NPU 的优化之路
人工智能·深度学习
HIT_Weston1 天前
92、【Agent】【OpenCode】edit 工具提示词
人工智能·agent·opencode
Shan12051 天前
机器学习评价指标之基础指标与综合指标
人工智能·机器学习