目录
- [3 SFT源码分析](#3 SFT源码分析)
-
- [3.1 accelerate](#3.1 accelerate)
-
- [3.1.1 关键特性](#3.1.1 关键特性)
- [3.1.2 使用场景](#3.1.2 使用场景)
- [3.1.3 简单示例](#3.1.3 简单示例)
- [3.2 代码主入口](#3.2 代码主入口)
- [3.3 设置随机种子](#3.3 设置随机种子)
- [3.4 设置Log](#3.4 设置Log)
- [3.5 加载数据集](#3.5 加载数据集)
- [3.6 加载Tokenizer](#3.6 加载Tokenizer)
- [3.7 模型参数配置初始化](#3.7 模型参数配置初始化)
- [3.8 初始化SFT Trainer](#3.8 初始化SFT Trainer)
- [3.9 开始训练](#3.9 开始训练)
-
- [3.9.1 主函数](#3.9.1 主函数)
- [3.9.2 核心循环](#3.9.2 核心循环)
- [3.9.3 单步训练](#3.9.3 单步训练)
- [3.9.4 原始Loss计算方法](#3.9.4 原始Loss计算方法)
- [3.9.5 标签平滑](#3.9.5 标签平滑)
- [3.9.6 SFT的Loss计算方法](#3.9.6 SFT的Loss计算方法)
- [3.9.7 计算令牌准确性](#3.9.7 计算令牌准确性)
- [3.10 保存模型](#3.10 保存模型)
- [3.11 评估](#3.11 评估)
- [3.12 推送到Hub](#3.12 推送到Hub)
【复现DeepSeek-R1之Open R1实战】系列3:SFT和GRPO源码逐行深度解析(上)
【复现DeepSeek-R1之Open R1实战】系列5:SFT和GRPO源码逐行深度解析(中)
省流:本文重点是【3.9 开始训练】小节。
3 SFT源码分析
HuggingFace已经将很多重要的函数都封装好了,我们只需要掉包就能简单实现SFT了。
前面几篇博文我们详细介绍了如何一步步搭建环境了,感兴趣的话可以翻阅一下,此处不展开细说了:
- 【复现DeepSeek-R1之Open R1实战】系列1:跑通SFT(一步步操作,手把手教学)
- 【复现DeepSeek-R1之Open R1实战】系列2:没有卡也能训模型!Colab跑OpenR1(附源码)
3.1 accelerate
我们使用了accelerate库来训练模型:
bash
# Train via command line
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
--dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
--learning_rate 2.0e-5 \
--num_train_epochs 1 \
--packing \
--max_seq_length 4096 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--bf16 \
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
# Train via YAML config
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml
Accelerate
是 Hugging Face 开发的一个库,旨在简化深度学习模型的训练过程,特别是在分布式环境或使用不同硬件(如多个GPU、TPU等)时。它提供了一个统一且灵活的接口,使得用户能够轻松地配置和运行训练脚本,而无需深入理解复杂的分布式计算概念。以下是 Accelerate
的一些关键特性和优势:
3.1.1 关键特性
-
简化分布式训练 :无论是单机多卡、多机多卡还是TPU训练,
Accelerate
都能通过简单的配置文件或者命令行参数进行设置,大大降低了分布式训练的复杂性。 -
灵活性与可扩展性:支持多种深度学习框架,但主要与PyTorch集成得最为紧密。它允许用户在不修改核心训练代码的情况下调整训练策略,包括混合精度训练、梯度累积、梯度检查点等高级功能。
-
易于使用的API :
Accelerate
提供了一个高层次的API,使得启动训练任务变得非常简单。例如,你可以使用Accelerator()
对象来包裹你的训练循环,它会自动处理设备分配、数据加载器的优化等细节。 -
配置管理:通过一个简单的YAML格式配置文件,用户可以指定训练所需的各种参数,比如使用的设备类型(CPU/GPU/TPU)、是否启用混合精度训练等,这极大地提高了实验的可重复性。
-
兼容性:与Hugging Face Transformers库高度集成,可以直接用于Transformer模型的训练。当然,它也适用于其他类型的神经网络模型。
3.1.2 使用场景
- 当你需要在不同的硬件环境中快速部署训练任务时。
- 在探索不同的训练策略(如改变批大小、学习率等)时,
Accelerate
能让你以最小的代码改动实现这些变化。 - 如果你正在寻找一种方法来简化分布式训练的配置和执行流程,
Accelerate
是一个很好的选择。
3.1.3 简单示例
下面是一个如何使用Accelerate
进行简单训练任务的例子:
python
from accelerate import Accelerator
accelerator = Accelerator()
model, optimizer, train_dataloader, scheduler = accelerator.prepare(
model, optimizer, train_dataloader, scheduler
)
for epoch in range(num_epochs):
for batch in train_dataloader:
outputs = model(batch)
loss = loss_function(outputs, labels)
accelerator.backward(loss)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
在这个例子中,Accelerator
对象帮助我们自动化了许多底层细节,如将模型和数据迁移到正确的设备上,以及处理分布式训练中的通信问题。这样,开发者就可以专注于模型设计和训练策略本身。
3.2 代码主入口
python
if __name__ == "__main__":
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)
首先调用了TrlPaser库,将输入的参数归类分成script_args, training_args, model_args这三类,每一类都是封装好的函数,这样便于拓展和迁移使用。
- script_args 主要是一些关于数据集的参数,例如 dataset_name(数据名称/路径)、dataset_config(数据集的配置)、dataset_train_split(训练集)、dataset_test_split(测试集)等等。
- training_args 继承自SFTConfig类,主要是一些关于训练的参数,例如 max_seq_length(tokenized序列的最大长度)、learning_rate等等。
- model_args 主要是一些关于模型的参数,例如 model_name_or_path(模型名称/路径)、torch_dtype(数据类型:bfloat16、float16、float32和auto)。
3.3 设置随机种子
设置随机种子,默认是42。主要是为了确保实验的可重复性,在训练模型时,涉及许多随机过程,例如初始化权重、数据集的shuffle等。通过固定随机种子,可以使得这些随机过程在每次运行时都产生相同的结果,从而保证实验结果的一致性和可重复性。
另外,当模型出现问题或需要调整参数时,固定的随机种子可以帮助开发者更容易地进行调试。因为相同的输入会得到相同的输出,这有助于定位问题。
在进行模型选择或调参时,使用相同的随机种子可以让不同的实验之间只存在因模型架构或参数设置不同而产生的差异,而非由于随机因素导致的变化,这样可以更准确地评估模型性能。
python
# Set seed for reproducibility
set_seed(training_args.seed)
3.4 设置Log
主要是打印一些关键信息,例如 系统时间、训练和模型参数配置等等。
python
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process a small summary
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Training parameters {training_args}")
if "wandb" in training_args.report_to:
init_wandb_training(training_args)
此外,还会从output的文件夹中获取最新的checkpoint,打印checkpoint信息。
python
# Check for last checkpoint
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
3.5 加载数据集
通过load_dataset加载来自Hugging Face Hub的数据集或本地数据集,我们可以在Hugging Face Hub上找到数据集列表,或者使用[huggingface_hub.list_datasets]进行查找。
这个函数在后台执行以下操作:
- 加载数据集构建器:
- 确定数据集中最常见的数据格式并选择其关联的构建器(例如JSON、CSV、Parquet、Webdataset、ImageFolder等)。
- 根据文件名和目录名或YAML配置确定哪些文件属于哪个分割(例如训练/测试)。
- 也可以手动指定data_files以及要使用的数据集构建器(例如"parquet")。
- 运行数据集构建器:
- 在一般情况下:
- 如果数据文件尚未在本地可用或缓存,则从数据集中下载这些文件。
- 将数据集处理并缓存为类型化的Arrow表以用于缓存。Arrow表是任意长度的、类型化的表格,可以存储嵌套对象,并映射到numpy/pandas/python的通用类型。它们可以直接从磁盘访问、加载到RAM中甚至通过网络流式传输。
- 在流式处理的情况下:
- 不下载或缓存任何内容。相反,数据集将被惰性加载并在迭代时动态流式传输。
- 在一般情况下:
- 返回由split参数(默认:所有)请求的分割构建的数据集。
python
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
3.6 加载Tokenizer
关于Tokenizer的详细介绍可以看上一篇博文。
执行完这段,就会从预训练的大模型文件夹中自动加载Tokenizer。
python
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
)
tokenizer.pad_token = tokenizer.eos_token
3.7 模型参数配置初始化
主要是完成模型加载时的一些参数配置,例如数据类型、量化配置等等。
python
logger.info("*** Initializing model kwargs ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
training_args.model_init_kwargs = model_kwargs
3.8 初始化SFT Trainer
SFT Trainer继承自transformers库的Trainer类,
python
trainer = SFTTrainer(
model=model_args.model_name_or_path,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
processing_class=tokenizer,
peft_config=get_peft_config(model_args),
callbacks=get_callbacks(training_args, model_args),
)
3.9 开始训练
python
logger.info("*** Train ***")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
3.9.1 主函数
train()函数会先加载模型以及完成一些初始化工作,然后通过 find_executable_batch_size 装饰器函数以某种方式调用目标函数 _inner_training_loop:要么是直接使用给定的批处理大小,要么是经过调整找到的最佳批处理大小。find_executable_batch_size 函数的目的是帮助自动找到适合执行的batch size,特别是对于那些可能因为内存不足(out-of-memory)或CUDNN相关异常而失败的操作。
python
inner_training_loop = find_executable_batch_size(
self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
)
3.9.2 核心循环
最关键的是成员函数_inner_training_loop,该方法涵盖了从初始化到训练结束的整个过程。
- 初始化与状态设置
- 记录训练参数如批处理大小、总训练批处理大小、梯度累积步数、优化步骤总数及可训练参数数量。
- 初始化训练状态变量。
python
if self.args.per_device_train_batch_size != self._train_batch_size:
logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps:,}")
logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
self.state.epoch = 0
start_time = time.time()
epochs_trained = 0
steps_trained_in_current_epoch = 0
steps_trained_progress_bar = None
- 检查点恢复
- 如果提供了检查点路径并且存在相应的状态文件,则从检查点恢复训练状态。
python
if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)):
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
if not args.ignore_data_skip:
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(f" Continuing training from global step {self.state.global_step}")
- 更新引用
- 更新回调处理器中的模型、优化器、学习率调度器和数据加载器的引用。
python
self.callback_handler.model = self.model
self.callback_handler.optimizer = self.optimizer
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
- 状态更新
- 设置
self.state.max_steps
和self.state.num_train_epochs
,并确保进程零的状态正确性。
python
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
- 初始化损失变量
- 初始化
tr_loss
和_total_loss_scalar
,并将模型梯度置零。
python
tr_loss = torch.tensor(0.0).to(args.device)
self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step
model.zero_grad()
- 回调处理
- 调用
on_train_begin
回调,并在训练开始时进行一次评估(如果配置了)。
python
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
if args.eval_on_start:
self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
- 主训练循环
- 遍历每个epoch,并在每个epoch开始时调用
on_epoch_begin
回调。 - 根据是否需要同步梯度设置加速器的状态,并执行单步训练 (
training_step
)。 - 对于同步梯度步骤:
- 进行梯度裁剪。
- 执行优化器步骤,并根据情况更新学习率调度器。
- 将模型梯度置零,并更新全局步数和当前epoch。
- 调用
on_step_end
回调并可能进行日志记录、保存和评估。
- 对于非同步梯度步骤,调用
on_substep_end
回调。
python
for epoch in range(epochs_trained, num_train_epochs):
epoch_dataloader = train_dataloader
if hasattr(epoch_dataloader, "set_epoch"):
epoch_dataloader.set_epoch(epoch)
steps_in_epoch = len(epoch_dataloader) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
# 处理从检查点恢复的情况
if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
steps_skipped = 0
if steps_trained_in_current_epoch > 0:
epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
steps_skipped = steps_trained_in_current_epoch
steps_trained_in_current_epoch = 0
rng_to_sync = True
step = -1
epoch_iterator = iter(epoch_dataloader)
for _ in range(total_updates):
update_step += 1
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
for i, inputs in enumerate(batch_samples):
step += 1
do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
if not do_sync_step:
self.accelerator.gradient_state._set_sync_gradients(False)
else:
self.accelerator.gradient_state._set_sync_gradients(True)
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
if args.logging_nan_inf_filter and not is_torch_xla_available() and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)):
tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
else:
tr_loss = tr_loss + tr_loss_step
if do_sync_step:
if args.max_grad_norm is not None and args.max_grad_norm > 0:
if is_sagemaker_mp_enabled() and args.fp16:
_grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
elif self.use_apex:
_grad_norm = nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), args.max_grad_norm)
else:
_grad_norm = self.accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
grad_norm = _grad_norm.item() if hasattr(_grad_norm, "item") else _grad_norm
self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
self.optimizer.step()
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
if optimizer_was_run and not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.lr_scheduler.step()
model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
else:
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
if self.control.should_epoch_stop or self.control.should_training_stop:
if is_torch_xla_available():
xm.mark_step()
break
if self.control.should_epoch_stop or self.control.should_training_stop:
if is_torch_xla_available():
xm.mark_step()
break
- Epoch结束处理
- 调用
on_epoch_end
回调并可能进行日志记录、保存和评估。 - 如果启用了TPU调试选项,则打印调试指标报告。
python
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
if is_torch_xla_available():
xm.master_print(met.metrics_report())
else:
logger.warning("You enabled PyTorch/XLA debug metrics but you don't have a TPU configured.")
if self.control.should_training_stop:
break
- 训练结束处理
- 输出一条信息提示训练完成。
- 如果配置了在训练结束时加载最佳模型,则加载最佳模型检查点。
- 计算总损失并将结果添加到
self._total_loss_scalar
中。 - 计算训练速度指标 (
speed_metrics
) 并记录它们。 - 停止内存跟踪器并更新指标。
- 记录最终的训练指标。
- 根据保存限制删除旧的检查点。
- 调用
on_train_end
回调并完成当前推送操作。 - 清理嵌入层的前向后钩子(如果使用了NEFTune噪声)。
- 返回包含全局步数、训练损失和指标的
TrainOutput
对象。
python
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
if is_torch_xla_available():
xm.rendezvous("load_best_model_at_end")
elif args.parallel_mode == ParallelMode.DISTRIBUTED:
dist.barrier()
elif is_sagemaker_mp_enabled():
smp.barrier()
self._load_best_model()
self._total_loss_scalar += tr_loss.item()
effective_global_step = max(self.state.global_step, 0.001)
train_loss = self._total_loss_scalar / effective_global_step
metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps, num_tokens=num_train_tokens)
self.store_flos()
metrics["total_flos"] = self.state.total_flos
metrics["train_loss"] = train_loss
self.log(metrics)
run_dir = self._get_output_dir(trial)
checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
for checkpoint in checkpoints_sorted:
if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
shutil.rmtree(checkpoint, ignore_errors=True)
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
self._finish_current_push()
if self.neftune_noise_alpha is not None:
self._deactivate_neftune(self.model)
return TrainOutput(self.state.global_step, train_loss, metrics)
3.9.3 单步训练
在 _inner_training_loop 方法中,单步训练是在 training_step 方法内完成的。
python
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
training_step 方法是训练过程中对每个批次数据执行单步训练的核心函数。它负责前向传播、计算损失、后向传播等操作,并返回当前批次的训练损失,主要包括以下几个步骤:
- 前向传播:通过 self.compute_loss 方法计算损失,该方法通常包含模型的前向传播和损失函数的计算。
- 后向传播:根据是否使用Apex混合精度训练,选择不同的方式进行后向传播。
- 多GPU处理:如果使用多个GPU进行分布式训练,需要对损失值进行平均。
- 梯度累积:根据配置的梯度累积步数,对损失值进行缩放。
- 内存管理:根据配置,定期清空不同硬件类型的缓存以释放内存。
3.9.4 原始Loss计算方法
损失的计算主要是在单步训练中的compute_loss函数中完成,它处理了标签平滑、自定义损失函数以及多设备(如多GPU)的损失平均等问题。
python
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
- model (
nn.Module
): 要训练的模型。 - inputs (
Dict[str, Union[torch.Tensor, Any]]
): 包含输入和目标的字典,通常包括输入ID、注意力掩码、标签等。 - return_outputs (
bool
) : 是否返回模型输出,默认为False
。 - num_items_in_batch (
int
, optional): 批次中的样本数量(可选参数)。
-
处理标签
pythonif (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs: labels = inputs.pop("labels") else: labels = None
- 如果启用了标签平滑器 (
label_smoother
) 或者存在自定义损失函数 (compute_loss_func
) 并且输入中包含labels
,则从中提取标签并从inputs
字典中移除。
- 如果启用了标签平滑器 (
-
准备损失计算的关键字参数
pythonif self.model_accepts_loss_kwargs: loss_kwargs = {} if num_items_in_batch is not None: loss_kwargs["num_items_in_batch"] = num_items_in_batch inputs = {**inputs, **loss_kwargs}
- 如果模型接受额外的损失关键字参数,则将
num_items_in_batch
添加到inputs
中。
- 如果模型接受额外的损失关键字参数,则将
-
前向传播
pythonoutputs = model(**inputs)
- 将输入数据传递给模型进行前向传播,并获取模型输出。
-
保存过去的状态(如果适用)
pythonif self.args.past_index >= 0: self._past = outputs[self.args.past_index]
- 如果配置了过去索引 (
past_index
),则保存模型输出中的相应部分(例如,对于某些生成任务)。
- 如果配置了过去索引 (
-
计算损失
- 根据是否有标签、是否使用自定义损失函数或标签平滑器,选择不同的方式计算损失。
情况一:有标签且使用自定义损失函数或标签平滑器
python
if labels is not None:
unwrapped_model = self.accelerator.unwrap_model(model)
if _is_peft_model(unwrapped_model):
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
if self.compute_loss_func is not None:
loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
- 如果存在标签:
- 解包加速器中的模型。
- 判断模型是否为PEFT模型(Parameter-Efficient Fine-Tuning),并获取模型名称。
- 如果存在自定义损失函数 (
compute_loss_func
),则调用该函数计算损失。 - 如果模型属于因果语言模型(Causal Language Model),则使用标签平滑器 (
label_smoother
) 并设置shift_labels=True
。 - 否则,直接使用标签平滑器计算损失。
情况二:无标签或模型未返回损失
python
else:
if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
- 如果没有标签或者模型输出中不包含
loss
键,则抛出异常提示用户模型未返回损失。 - 否则,从模型输出中提取损失值。
-
多设备损失平均
pythonif self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: loss *= self.accelerator.num_processes
- 如果配置了跨设备平均令牌数 (
average_tokens_across_devices
),则根据设备数量调整损失值。
- 如果配置了跨设备平均令牌数 (
-
返回结果
pythonreturn (loss, outputs) if return_outputs else loss
- 如果
return_outputs
参数为True
,则返回一个元组(loss, outputs)
;否则仅返回损失值。
- 如果
3.9.5 标签平滑
LabelSmoother
是一个用于在预计算的模型输出上添加标签平滑(label smoothing)的类,标签平滑是一种正则化技术,旨在防止模型对训练数据中的特定标签过度自信,从而提高泛化能力。
python
@dataclass
class LabelSmoother:
"""
Adds label-smoothing on a pre-computed output from a Transformers model.
Args:
epsilon (`float`, *optional*, defaults to 0.1):
The label smoothing factor.
ignore_index (`int`, *optional*, defaults to -100):
The index in the labels to ignore when computing the loss.
"""
epsilon: float = 0.1
ignore_index: int = -100
- epsilon (
float
): 标签平滑因子,默认值为 0.1。 - ignore_index (
int
): 在计算损失时忽略的标签索引,默认值为 -100。
-
提取
logits
pythonlogits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
- 从模型输出中提取
logits
,如果输出是字典形式,则通过键"logits"
获取;否则直接取第一个元素。
- 从模型输出中提取
-
偏移处理(如果需要)
pythonif shift_labels: logits = logits[..., :-1, :].contiguous() labels = labels[..., 1:].contiguous()
- 如果需要偏移标签(如因果语言模型),则对
logits
和labels
进行偏移处理,使它们对齐。
- 如果需要偏移标签(如因果语言模型),则对
-
计算负对数概率
pythonlog_probs = -nn.functional.log_softmax(logits, dim=-1)
- 使用
log_softmax
函数计算负对数概率(即负对数似然)。
- 使用
-
调整标签维度
pythonif labels.dim() == log_probs.dim() - 1: labels = labels.unsqueeze(-1)
- 如果
labels
的维度比log_probs
少一维,则增加一个维度以匹配log_probs
的形状。
- 如果
-
创建填充掩码
pythonpadding_mask = labels.eq(self.ignore_index) labels = torch.clamp(labels, min=0)
- 创建一个填充掩码
padding_mask
,标记哪些位置是填充(使用ignore_index
)。 - 使用
clamp
函数将标签限制在非负值范围内,避免在后续操作中出现错误。
- 创建一个填充掩码
-
计算负对数似然损失(NLL Loss)
pythonnll_loss = log_probs.gather(dim=-1, index=labels)
- 使用
gather
函数从log_probs
中提取对应于真实标签的负对数概率。
- 使用
-
计算平滑损失
pythonsmoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
- 计算所有类别的负对数概率之和,并保持维度不变。
-
应用填充掩码
pythonnll_loss.masked_fill_(padding_mask, 0.0) smoothed_loss.masked_fill_(padding_mask, 0.0)
- 使用填充掩码将填充位置的损失置为零。
-
计算有效元素数量
pythonnum_active_elements = padding_mask.numel() - padding_mask.long().sum()
- 计算非填充位置的有效元素数量。
-
归一化损失
pythonnll_loss = nll_loss.sum() / num_active_elements smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
- 对负对数似然损失和平滑损失进行归一化处理。
-
组合最终损失
pythonreturn (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss
- 组合负对数似然损失和平滑损失,得到最终的标签平滑损失。
3.9.6 SFT的Loss计算方法
SFT Trainer重写了compute_loss方法,不仅计算训练损失,还额外计算了令牌(token)准确性,这对于评估模型在生成任务中的表现特别有用。
-
调用父类的
compute_loss
方法python(loss, outputs) = super().compute_loss( model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch )
- 调用父类的
compute_loss
方法来计算损失值和模型输出。这里使用return_outputs=True
确保返回模型输出以便后续计算令牌准确性。
- 调用父类的
-
计算令牌准确性(如果适用)
pythonif "labels" in inputs and not self.args.use_liger: shift_logits = outputs.logits[..., :-1, :].contiguous() shift_labels = inputs["labels"][..., 1:].contiguous()
- 如果输入中包含标签且未使用 Liger(一种特定的优化器或模型配置),则从模型输出中提取
logits
和labels
。 - 对于因果语言模型(Causal Language Model),通常需要对
logits
和labels
进行偏移处理:shift_logits
: 将logits
的最后一个维度去掉一个位置,使其与labels
对齐。shift_labels
: 将labels
的第一个位置去掉,使其与logits
对齐。
- 如果输入中包含标签且未使用 Liger(一种特定的优化器或模型配置),则从模型输出中提取
-
多GPU环境下收集 logits 和 labels
pythonshift_logits = self.accelerator.gather_for_metrics(shift_logits) shift_labels = self.accelerator.gather_for_metrics(shift_labels)
- 使用加速器的
gather_for_metrics
方法将所有GPU上的logits
和labels
收集到主进程中。这一步确保了在分布式训练环境中能够正确地计算全局指标。
- 使用加速器的
-
计算令牌准确性
pythonif self.accelerator.is_main_process: accuracy = compute_token_accuracy(shift_logits, shift_labels) self._metrics["mean_token_accuracy"].append(accuracy)
- 在主进程中(即
is_main_process
为True
),调用compute_token_accuracy
函数计算令牌准确性,并将其添加到_metrics
字典中。
- 在主进程中(即
-
返回结果
pythonreturn (loss, outputs) if return_outputs else loss
- 如果
return_outputs
参数为True
,则返回一个元组(loss, outputs)
;否则仅返回损失值。
- 如果
3.9.7 计算令牌准确性
该函数用于计算令牌(token)的准确性,即模型预测的正确率。它通过比较模型输出的预测值和真实标签来计算准确率,并忽略填充(padding)部分的令牌。
python
def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float:
- logits (
torch.Tensor
) : 模型输出的对数概率张量,形状通常为(batch_size, sequence_length, vocab_size)
。 - labels (
torch.Tensor
) : 真实标签张量,形状通常为(batch_size, sequence_length)
。 - ignore_index (
int
) : 忽略的索引,默认值为-100
,表示填充部分的标签。
-
获取预测值
pythonpredictions = logits.argmax(dim=-1)
- 使用
argmax
函数从logits
中获取每个位置的最大概率对应的索引,作为模型的预测值。dim=-1
表示在最后一个维度(词汇表维度)上进行操作,结果是一个形状为(batch_size, sequence_length)
的张量。
- 使用
-
创建非填充掩码
pythonmask = labels != ignore_index
- 创建一个布尔掩码
mask
,标记哪些位置不是填充部分(即不等于ignore_index
)。这个掩码用于后续计算时忽略填充部分的令牌。
- 创建一个布尔掩码
-
计算正确的预测
pythoncorrect_predictions = (predictions == labels) & mask
- 计算预测值与真实标签相等的位置,并结合掩码
mask
过滤掉填充部分的令牌。结果是一个布尔张量,其中True
表示正确预测且非填充位置。
- 计算预测值与真实标签相等的位置,并结合掩码
-
统计有效令牌数量
pythontotal_tokens = mask.sum() correct_tokens = correct_predictions.sum()
- 使用
sum
函数统计掩码中True
的数量,得到有效令牌的总数total_tokens
。 - 同样地,使用
sum
函数统计correct_predictions
中True
的数量,得到正确预测的令牌数correct_tokens
。
- 使用
-
计算准确性
pythonaccuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0
- 计算准确性:将正确预测的令牌数除以总的有效令牌数。如果有效令牌数为零,则返回
0.0
以避免除零错误。
- 计算准确性:将正确预测的令牌数除以总的有效令牌数。如果有效令牌数为零,则返回
-
返回准确性
pythonreturn accuracy
- 返回计算出的令牌准确性。
3.10 保存模型
python
logger.info("*** Save model ***")
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")
# Save everything else on main process
kwargs = {
"dataset_name": script_args.dataset_name,
"tags": ["open-r1"],
}
if trainer.accelerator.is_main_process:
trainer.create_model_card(**kwargs)
# Restore k,v cache for fast inference
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(training_args.output_dir)
3.11 评估
直接调用trainer的evaluate()函数完成评测。
python
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
3.12 推送到Hub
将训练结果推送到HuggingFace Hub上。
python
if training_args.push_to_hub:
logger.info("Pushing to hub...")
trainer.push_to_hub(**kwargs)