HuggingFace项目实战之使用Trainer执行训练

目录:

一、加载tokenizer

python 复制代码
import torch

from transformers import AutoTokenizer

#加载tokenizer
tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-chinese')

tokenizer

二、加载数据集和编码

python 复制代码
from datasets import load_dataset

#加载数据集
dataset = load_dataset(path='lansinuote/ChnSentiCorp')

#编码
f = lambda x: tokenizer(x['text'], truncation=True, max_length=500)
dataset = dataset.map(f, remove_columns=['text'])

#设置数据类型
dataset.set_format('pt')

dataset, dataset['train'][0]

三、加载模型

python 复制代码
#定义模型
from transformers import BertConfig, BertForSequenceClassification

#在线加载一个语句分类模型
model = BertForSequenceClassification.from_pretrained(
    'google-bert/bert-base-chinese', num_labels=2)

model.config

四、执行训练

python 复制代码
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding

#配置训练参数
args = TrainingArguments(output_dir='output_dir',
                         use_cpu=True,
                         num_train_epochs=1,
                         max_steps=300,
                         eval_strategy='no',
                         per_device_train_batch_size=8)

#创建trainer
trainer = Trainer(model=model,
                  args=args,
                  train_dataset=dataset['train'],
                  data_collator=DataCollatorWithPadding(tokenizer))

#执行训练
trainer.train()

五、执行测试

python 复制代码
#执行测试
def test():
    loader_test = torch.utils.data.DataLoader(
        dataset['test'],
        batch_size=8,
        shuffle=True,
        drop_last=True,
        collate_fn=DataCollatorWithPadding(tokenizer))

    correct = 0
    total = 0
    for i, data in enumerate(loader_test):
        with torch.no_grad():
            out = model(**data).logits

        out = out.argmax(dim=1)
        correct += (out == data.labels).sum().item()
        total += len(data.labels)

        print(i, len(loader_test), correct / total)

        if i == 5:
            break

    return correct / total


test()
相关推荐
杭州泽沃电子科技有限公司1 天前
为电气风险定价:如何利用监测数据评估工厂的“电气安全风险指数”?
人工智能·安全
Godspeed Zhao1 天前
自动驾驶中的传感器技术24.3——Camera(18)
人工智能·机器学习·自动驾驶
顾北121 天前
MCP协议实战|Spring AI + 高德地图工具集成教程
人工智能
wfeqhfxz25887821 天前
毒蝇伞品种识别与分类_Centernet模型优化实战
人工智能·分类·数据挖掘
中杯可乐多加冰1 天前
RAG 深度实践系列(七):从“能用”到“好用”——RAG 系统优化与效果评估
人工智能·大模型·llm·大语言模型·rag·检索增强生成
珠海西格电力科技1 天前
微电网系统架构设计:并网/孤岛双模式运行与控制策略
网络·人工智能·物联网·系统架构·云计算·智慧城市
FreeBuf_1 天前
AI扩大攻击面,大国博弈引发安全新挑战
人工智能·安全·chatgpt
weisian1511 天前
进阶篇-8-数学篇-7--特征值与特征向量:AI特征提取的核心逻辑
人工智能·pca·特征值·特征向量·降维
Java程序员 拥抱ai1 天前
撰写「从0到1构建下一代游戏AI客服」系列技术博客的初衷
人工智能
186******205311 天前
AI重构项目开发全流程:效率革命与实践指南
人工智能·重构