大模型之Bloom&LLAMA----SFT(模型微调)

0. 简介

随着chatgpt的爆火,最近也有很多大模型在不断地出现,比如说Bloom系列以及以LLAMA为基础的ziya和baichuan。这些模型相较于chatglm来说,更加具有发展前景,因为其是完全可商用,并可以不断迭代更新的。最近作者在跟着hiyouga大佬的LLaMA-Efficient-Tuning进行学习,相较于其他的项目来说,该项目是非常适合跟着学习并入门的。

1. 什么是SFT

SFT(Scalable Fine-Tuning)是一种用于自然语言处理的技术,它通过对预训练的语言模型进行微调,使其适应特定任务。在大模型SFT中,使用的是大型的预训练语言模型,例如LLAMA、GPT等,这些模型具有数十亿甚至数百亿个参数,可以处理大量的文本数据。

SFT的主要思想是在一个大型的预训练模型的基础上,针对特定的任务对模型进行微调。在微调过程中,模型会根据任务的特点调整模型的参数和结构,以提高模型在该任务上的表现。在微调过程中,可以使用不同的技术,例如数据增强、正则化、优化算法等。

SFT的优点是可以快速地针对不同的任务进行微调,而无需重新训练整个模型。此外,由于使用的是大型的预训练模型,可以利用海量的文本数据进行训练,从而获得更好的性能。不过,SFT也有一些缺点,例如需要大量的计算资源和时间进行微调,以及可能会出现过拟合等问题。

目前常用的SFT方法有P-Tuning v2LORAQLoRA、冻结(Freeze)、全参数(full-parameter)等方法。我们先来看一看在LLaMA-Efficient-Tuning中是如何写SFT的

2. 代码阅读--train_sft.py

下面是sft对应大模型的脚本,主要包括模型和数据的准备,数据集的划分,训练和评估等步骤。

首先,代码导入了一些必要的模块和函数。这包括一些用于数据处理、训练、加载预训练模型和绘制损失图的工具函数。(这部分和pt中一样)

python 复制代码
    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")# 用于准备各种参数,包括模型参数、数据参数、训练参数和微调参数。
    dataset = prepare_data(model_args, data_args)# 用于准备数据集
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")# 用于加载sft微调的模型和分词器。
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")# 用于预处理数据,例如将文本转换为模型可以理解的格式。
    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)# 动态地对数据进行填充,使得每个batch中的数据长度一致。

下面的代码是用于Seq2SeqTrainer的解码参数进行覆盖

python 复制代码
   # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = training_args.generation_max_length if \
                training_args.generation_max_length is not None else data_args.max_target_length# 设置训练参数(training_args)中的生成最大长度
    training_args.generation_num_beams = data_args.eval_num_beams if \
                data_args.eval_num_beams is not None else training_args.generation_num_beams # 设置训练参数中的生成束搜索数(generation_num_beams)

然后,根据是否进行训练,对数据集进行划分。如果进行训练,且开发集的比例大于0,那么数据集会被划分为训练集和开发集;否则,全部数据用于训练。如果不进行训练,那么全部数据用于评估或预测。

python 复制代码
    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

接着,初始化Seq2SeqPeftTrainer对象,传入微调参数、模型、训练参数、分词器、数据处理器、回调函数和计算度量等参数(都是继承自Seq2SeqTrainer),以及前面划分的数据集。这个我们下一节将会仔细阅读里面的操作

...详情请参照古月居

相关推荐
kuokay10 小时前
MLOps 与 AIOps 的核心概
人工智能·分布式·大模型·agent·llama
Trouville011 天前
windows系统使用llama.cpp进行本地大模型部署
llama
棒棒的唐1 天前
windows 直接安装llama.cpp的方法
llama
troubles maker1 天前
LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model
llm·nlp·llama·多模态
xyz_CDragon1 天前
把旧电脑变成AI算力:llama.cpp RPC 局域网分布式推理验证与实战
人工智能·分布式·python·rpc·llama
wengad2 天前
llama.cpp进行模型格式转换和量化
llama
小七-七牛开发者3 天前
本地模型为什么能跑起来?从 llama.cpp 量化说起
agent·llama·模型部署·ollama·本地模型
七牛云行业应用3 天前
Llama 4 实战指南:Scout/Maverick 本地部署 + API 调用完整流程【2026】
llama
Soari4 天前
llama.cpp更新(b9553):LLM inference in C/C++,本地和云端实现高性能大模型推理
c语言·c++·llama
一叶知秋dong4 天前
llama.cpp 启动脚本
linux·服务器·llama