基于Optuna的transformers模型自动调参

文章目录

六、Trainer和文本分类


一、导入相关包

python 复制代码
!pip install transformers datasets evaluate accelerate
python 复制代码
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

二、加载数据集

python 复制代码
dataset = load_dataset("csv", data_files="./ChnSentiCorp_htl_all.csv", split="train")
dataset = dataset.filter(lambda x: x["review"] is not None)
dataset
'''
Dataset({
    features: ['label', 'review'],
    num_rows: 7765
})
'''

三、划分数据集

python 复制代码
datasets = dataset.train_test_split(test_size=0.1)
datasets
'''
DatasetDict({
    train: Dataset({
        features: ['label', 'review'],
        num_rows: 6988
    })
    test: Dataset({
        features: ['label', 'review'],
        num_rows: 777
    })
})
'''

四、数据集预处理

python 复制代码
import torch

tokenizer = AutoTokenizer.from_pretrained("hfl/rbt3")

def process_function(examples):
    tokenized_examples = tokenizer(examples["review"], max_length=128, truncation=True)
    tokenized_examples["labels"] = examples["label"]
    return tokenized_examples

tokenized_datasets = datasets.map(process_function, batched=True, 
                                  remove_columns=datasets["train"].column_names)
tokenized_datasets
'''
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 6988
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 777
    })
})
'''

五、创建模型(区别一)

python 复制代码
def model_init():
    model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3")
    return model

六、创建评估函数

python 复制代码
import evaluate

acc_metric = evaluate.load("accuracy")
f1_metirc = evaluate.load("f1")
python 复制代码
def eval_metric(eval_predict):
    predictions, labels = eval_predict
    predictions = predictions.argmax(axis=-1)
    acc = acc_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metirc.compute(predictions=predictions, references=labels)
    acc.update(f1)
    return acc

七、创建 TrainingArguments(区别二)

  • logging_steps=500为了防止多次训练 log 太多可以增大 logging_steps
python 复制代码
train_args = TrainingArguments(output_dir="./checkpoints",      # 输出文件夹
                               per_device_train_batch_size=64,  # 训练时的batch_size
                               per_device_eval_batch_size=128,  # 验证时的batch_size
                               logging_steps=500,               # log 打印的频率
                               evaluation_strategy="epoch",     # 评估策略
                               save_strategy="epoch",           # 保存策略
                               save_total_limit=3,              # 最大保存数
                               learning_rate=2e-5,              # 学习率
                               weight_decay=0.01,               # weight_decay
                               metric_for_best_model="f1",      # 设定评估指标
                               load_best_model_at_end=True)     # 训练完成后加载最优模型

八、创建 Trainer(区别三)

  • 没有指定 model而是指定 model_init
python 复制代码
from transformers import DataCollatorWithPadding
trainer = Trainer(model_init=model_init, 
                  args=train_args, 
                  train_dataset=tokenized_datasets["train"], 
                  eval_dataset=tokenized_datasets["test"], 
                  data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
                  compute_metrics=eval_metric)


# 之前
from transformers import DataCollatorWithPadding
trainer = Trainer(model=model,
                  args=train_args,
                  train_dataset=tokenized_datasets["train"],
                  eval_dataset=tokenized_datasets["test"],
                  data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
                  compute_metrics=eval_metric)

九、模型训练

python 复制代码
trainer.train()

十、模型训练(自动搜索)(区别四)

python 复制代码
!pip install optuna
  • 使用默认的超参数空间
  • compute_objective=lambda x: x["eval_f1"]中的 x是指的评价函数的返回值,在这里因为没有显示的指定评价函数返回值的 key,所以 f1key采用默认值 eval_f1
python 复制代码
trainer.hyperparameter_search(compute_objective=lambda x: x["eval_f1"], direction="maximize", n_trials=10)
  • 自定义超参数空间
    • 可以在default_hp_space_optuna 函数中增加 trainer 的选项
python 复制代码
def default_hp_space_optuna(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
        "seed": trial.suggest_int("seed", 1, 40),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]),
        "optim": trial.suggest_categorical("optim", ["sgd", "adamw_hf"]),
    }

trainer.hyperparameter_search(hp_space=default_hp_space_optuna, compute_objective=lambda x: x["eval_f1"], direction="maximize", n_trials=10)

启动 tensorboard

  • 进入运行日志文件夹
    • 终端启动
python 复制代码
!tensorboard --logdir runs
  • jupyter 启动

    运行这行代码将加载 TensorBoard并允许我们将其用于可视化

    %reload_ext tensorboard
    %tensorboard --logdir=./runs/


相关推荐
weixin_4374977742 分钟前
读书笔记:Context Engineering 2.0 (上)
人工智能·nlp
喝拿铁写前端1 小时前
前端开发者使用 AI 的能力层级——从表面使用到工程化能力的真正分水岭
前端·人工智能·程序员
goodfat1 小时前
Win11如何关闭自动更新 Win11暂停系统更新的设置方法【教程】
人工智能·禁止windows更新·win11优化工具
北京领雁科技1 小时前
领雁科技反洗钱案例白皮书暨人工智能在反洗钱系统中的深度应用
人工智能·科技·安全
落叶,听雪1 小时前
河南建站系统哪个好
大数据·人工智能·python
清月电子1 小时前
杰理AC109N系列AC1082 AC1074 AC1090 芯片停产替代及资料说明
人工智能·单片机·嵌入式硬件·物联网
Dev7z1 小时前
非线性MPC在自动驾驶路径跟踪与避障控制中的应用及Matlab实现
人工智能·matlab·自动驾驶
七月shi人2 小时前
AI浪潮下,前端路在何方
前端·人工智能·ai编程
橙汁味的风2 小时前
1隐马尔科夫模型HMM与条件随机场CRF
人工智能·深度学习·机器学习
极客小云2 小时前
【生物医学NLP信息抽取:药物识别、基因识别与化学物质实体识别教程与应用】
python·机器学习·nlp