python transformers笔记(TrainingArguments类)

TrainingArguments类

TrainingArguments是Hugging Face Transformers库中用于集中管理超参数和配置的核心类。它定义了模型训练、评估、保存和日志记录的所有关键参数,并通过Trainer类实现自动化训练流程。

1、核心功能

(1)统一管理训练配置:替代手动定义学习率、批次大小等分散的超参数

(2)自动化流程配置:控制训练、评估、保存、日志记录等行为的触发策略

(3)支持混合精度训练(FP16/BP16)、多GPU/TPU分布式训练、梯度积累等

2、关键参数

(1)基础训练配置

参数名 作用 示例 默认值
output_dir 模型和日志的保存路径 "./results" "tmp_trainer"
num_train_epochs 训练总轮次 3 3.0
per_device_train_batch_size 每个GPU/CPU的批次大小 16 8
per_device_eval_batch_size 评估时的批次大小 32 8
learning_rate 初始学习率 2e-5 5e-5
weight_decay 权重衰减(L2正则化)系数 0.01 0.0

(2)训练策略控制

参数名 作用 示例 默认值
gradient_accumulation_steps 梯度累积步数(模拟更大批次) 4 1
evaluation_strategy 评估触发策略:"epoch"/"steps"/"no" "epoch" "no"
save_strategy 模型保存策略(需要与evaluation_strategy一致) "epoch" "steps"
logging_strategy 日志记录策略 "steps" "steps"
logging_steps 每多少步记录一次日志(当logging_strategy="steps"时生效) 100 500
save_steps 每多少步保存一次模型 100 500
eval_steps 每多少步评估一次模型 100 None

(3)硬件与性能优化

参数名 作用 示例 默认值
fp16 是否启用FP16混合精度训练(NVIDIA GPU) True False
bf16 是否启用BF16混合精度训练(AMD/Intel GPU/TPU) False False
optim 优化器类型(如"adamw_torch"、"adafactor") "adamw_torch" "adamw_hf"
dataloader_num_workers 数据加载的线程数 4 0
gradient_checkpointing 是否启用梯度检查点(节省显存,但降低速度) False False

(4)模型保存或加载

参数名 作用 示例 默认值
save_total_limit 最大保存的检查点数量(旧的会被删除) 3 None
load_best_model_at_end 训练结束后是否加载最佳模型(需启用评估) True False
metric_for_best_model 用于选择最佳模型的指标(如"eval_loss"或自定义指标) "eval_accuracy" None
great_is_better 指标是否越大越好(需与metric_for_best_model配合) None None

(5)日志与监控

参数名 作用 示例 默认值
report_to 日志上报工具(如"tensorboard"、"wandb") ["tensorboard"]
logging_dir TensorBoard日志记录 "./logs"
disable_tqdm 是否禁用进度条 False
log_level 日志级别("debug"、"warning"、"info"等) "debug" "passive"

3、使用示例

python 复制代码
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    weight_decay=0.01,
    fp16=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    logging_dir="./logs",
    report_to=["tensorboard"],
)

4、高级用法

(1)自定义优化器和调度器

(2)动态参数调整

(3)超参数搜索

5、配置文件支持

python 复制代码
# 保存配置
training_args.save_to_json("./training_args.json")

# 加载配置
new_args = TrainingArguments.from_json_file("./training_args.json")

6、注意事项

(1)策略一致性

save_strategy和evaluation_strategy必须相同(除非设置为"no")

(2)混合精度选择

NVIDIA GPU用fp16,AMD/Intel GPU或TPU用bf16

(3)内存优化

遇到OOM错误时,减小per_device_train_batch_size或增加gradient_accumulation_steps。

相关推荐
数据科学作家20 小时前
学数据分析必囤!数据分析必看!清华社9本书覆盖Stata/SPSS/Python全阶段学习路径
人工智能·python·机器学习·数据分析·统计·stata·spss
HXQ_晴天21 小时前
CASToR 生成的文件进行转换
python
java1234_小锋1 天前
Scikit-learn Python机器学习 - 特征预处理 - 标准化 (Standardization):StandardScaler
python·机器学习·scikit-learn
Python×CATIA工业智造1 天前
Python带状态生成器完全指南:从基础到高并发系统设计
python·pycharm
向qian看_-_1 天前
Linux 使用pip报错(error: externally-managed-environment )解决方案
linux·python·pip
Nicole-----1 天前
Python - Union联合类型注解
开发语言·python
Eric.5651 天前
python advance -----object-oriented
python
云天徽上1 天前
【数据可视化-107】2025年1-7月全国出口总额Top 10省市数据分析:用Python和Pyecharts打造炫酷可视化大屏
开发语言·python·信息可视化·数据挖掘·数据分析·pyecharts
THMAIL1 天前
机器学习从入门到精通 - 数据预处理实战秘籍:清洗、转换与特征工程入门
人工智能·python·算法·机器学习·数据挖掘·逻辑回归
@HNUSTer1 天前
Python数据可视化科技图表绘制系列教程(六)
python·数据可视化·科技论文·专业制图·科研图表