HuggingFace项目实战之使用Trainer执行训练

目录:

一、加载tokenizer

python 复制代码
import torch

from transformers import AutoTokenizer

#加载tokenizer
tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-chinese')

tokenizer

二、加载数据集和编码

python 复制代码
from datasets import load_dataset

#加载数据集
dataset = load_dataset(path='lansinuote/ChnSentiCorp')

#编码
f = lambda x: tokenizer(x['text'], truncation=True, max_length=500)
dataset = dataset.map(f, remove_columns=['text'])

#设置数据类型
dataset.set_format('pt')

dataset, dataset['train'][0]

三、加载模型

python 复制代码
#定义模型
from transformers import BertConfig, BertForSequenceClassification

#在线加载一个语句分类模型
model = BertForSequenceClassification.from_pretrained(
    'google-bert/bert-base-chinese', num_labels=2)

model.config

四、执行训练

python 复制代码
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding

#配置训练参数
args = TrainingArguments(output_dir='output_dir',
                         use_cpu=True,
                         num_train_epochs=1,
                         max_steps=300,
                         eval_strategy='no',
                         per_device_train_batch_size=8)

#创建trainer
trainer = Trainer(model=model,
                  args=args,
                  train_dataset=dataset['train'],
                  data_collator=DataCollatorWithPadding(tokenizer))

#执行训练
trainer.train()

五、执行测试

python 复制代码
#执行测试
def test():
    loader_test = torch.utils.data.DataLoader(
        dataset['test'],
        batch_size=8,
        shuffle=True,
        drop_last=True,
        collate_fn=DataCollatorWithPadding(tokenizer))

    correct = 0
    total = 0
    for i, data in enumerate(loader_test):
        with torch.no_grad():
            out = model(**data).logits

        out = out.argmax(dim=1)
        correct += (out == data.labels).sum().item()
        total += len(data.labels)

        print(i, len(loader_test), correct / total)

        if i == 5:
            break

    return correct / total


test()
相关推荐
星期五不见面1 小时前
机器人学习!(二)ROS-基于Gazebo项目-YOLO(3)2026/01/13
人工智能·学习·机器人
d0ublεU0x002 小时前
注意力机制与transformer
人工智能·深度学习·transformer
凤希AI伴侣2 小时前
凤希AI提出:FXPA2P - 当P2P技术遇上AI,重新定义数据与服务的边界
人工智能·凤希ai伴侣
腾迹2 小时前
2026年企业微信SCRM系统服务推荐:微盛·企微管家的AI私域增长方案
大数据·人工智能
寰宇视讯2 小时前
脑科技走进日常 消费级应用开启新蓝海,安全与普惠成关键
人工智能·科技·安全
云卓SKYDROID2 小时前
无人机电机模块选型与技术要点
人工智能·无人机·遥控器·高科技·云卓科技
小酒星小杜2 小时前
在AI时代,技术人应该每天都要花两小时来构建一个自身的构建系统 - 总结篇
前端·vue.js·人工智能
云卓SKYDROID2 小时前
无人机螺旋桨材料与技术解析
人工智能·无人机·高科技·云卓科技·技术解析、
智驱力人工智能2 小时前
矿山皮带锚杆等异物识别 从事故预防到智慧矿山的工程实践 锚杆检测 矿山皮带铁丝异物AI预警系统 工厂皮带木桩异物实时预警技术
人工智能·算法·安全·yolo·目标检测·计算机视觉·边缘计算