SFT(监督式微调)

一、SFT(监督式微调)

SFT(Supervised Fine-Tuning,监督式微调)是大语言模型(LLM)对齐(Alignment)流程中最基础也最核心的第一步,是在海量无标注数据预训练后的关键优化环节。

  • 本质:用人工标注的"指令-回答"成对数据,在预训练好的大模型基础上继续训练,让模型学习"理解人类指令→生成符合预期回答"的映射关系。
  • 对比预训练:预训练是让模型"认识语言"(学习语法、语义、知识),但只会无目的续写文本;SFT是让模型"听懂指令"(理解人类意图),并针对性生成有用的回答。

明白SFT的作用:一般大模型先开展预训练,核心采用从左到右的自回归生成训练模式:通过学习海量无标注文本中 "前文推导后文" 的语言规律,让模型先掌握基础的文字理解能力与通顺表达能力,既能精准捕捉语法逻辑、语义关联,也能积累广泛的通用知识,为后续理解人类指令、生成针对性内容筑牢基础。而SFT(监督式微调) 正是在这一通用能力之上的关键 "对齐" 步骤,它借助人工标注的 "指令 - 回答" 成对数据,让模型学会将预训练阶段积累的语言知识和通用信息,转化为对人类指令的精准响应,通过仅计算回答部分的损失优化,实现从 "无目的文本续写" 到 "有意图指令遵循" 的能力跨越。

二、SFT训练逻辑

代码是基于Qwen2-0.5B-Instruct模型的中文问答SFT微调实现,完美贴合SFT的核心逻辑:

1. 输入数据处理(指令+回答的格式化)
  • 原始数据 :代码中定义了10条"instruction(指令)+ output(回答)"数据,比如:

    python 复制代码
    {"instruction": "请介绍一下人工智能。", 
    "output": "人工智能(AI)是计算机科学的一个分支..."}
  • 格式拼接(替代分隔符) :没有用通用的,而是采用Qwen模型专属的对话分隔符拼接,本质是统一"指令-回答"的输入格式:

    python 复制代码
    text = f"<|im_start|>user\n{item['instruction']}
    <|im_end|>\n<|im_start|>assistant\n{item['output']}<|im_end|>"

    最终拼接后的输入文本示例:

    复制代码
    <|im_start|>user
    请介绍一下人工智能。<|im_end|>
    <|im_start|>assistant
    人工智能(AI)是计算机科学的一个分支,
    致力于创建能够执行通常需要人类智能的任务的系统。<|im_end|>
  • 词嵌入转换 :代码中通过AutoTokenizer(分词器)将拼接后的文本转换成模型可识别的token(词嵌入向量),对应提到的"转化为词嵌入向量(length * embedding_dim)"。

2. 模型复用(基于预训练模型)

代码中直接加载预训练好的Qwen2-0.5B-Instruct模型,没有重新训练底层架构,完全符合SFT"复用预训练模型"的核心逻辑:

python 复制代码
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,  # 本地预训练模型路径
    trust_remote_code=True,
    torch_dtype=torch.float32,
    device_map="auto"
)
3. 损失计算(仅优化回答部分)

虽然代码中没有显式写损失计算逻辑,但SFTTrainer( trl 框架的核心类)已经封装了这一关键逻辑:

  • 模型接收拼接后的"用户指令+助手回答"token序列作为输入,预测每个位置的下一个token;
  • 损失计算时,仅对"assistant(回答)"部分的token计算交叉熵损失(指令部分的token不计算损失),让模型只学习"根据指令生成回答"的映射;
  • 代码中<|im_end|>作为回答的终止标识,对应提到的<eos>(End of Sentence)作用。
4. 训练参数适配

为了让SFT训练能在普通电脑上执行,做了针对性优化:

  • per_device_train_batch_size=2:小批次大小,降低显存占用;
  • gradient_accumulation_steps=4:梯度累积,等效增大批次大小,保证训练效果;
  • torch.float32/fp16=False:禁用高精度训练,避免普通显卡的梯度缩放问题。
5. 训练与保存(完成微调)
  • 代码中trainer.train()执行SFT训练,模型在10条标注数据上迭代3轮(num_train_epochs=3),逐步优化参数;
  • 训练完成后,trainer.save_model()将微调后的模型保存到./output,此时模型已学会针对10条指令生成对应回答。

代码总结

代码中加载的Qwen2-0.5B-Instruct本身是 "预训练 + 基础 SFT" 的模型,在此基础上做二次 SFT,核心逻辑不变:

  1. 预训练打底:该模型已通过万亿级文本学会 "语言通顺" 和 "AI / 编程相关知识";
  2. 基础 SFT:模型已初步学会 "响应指令"(知道<|im_start|>user是用户提问,<|im_start|>assistant是需要回答);
  3. 二次 SFT:用 10 条专属问答数据,让模型更精准地回答指定的 AI / 编程问题(比如 "解释梯度下降算法")。

三、SFT流程图

以下是代码逻辑的SFT完整训练流程图,用文字步骤清晰呈现:

复制代码
1. 准备阶段
   ├── 定义配置参数:指定本地预训练模型路径(Qwen2-0.5B-Instruct)、输出路径、最大序列长度
   └── 构造标注数据:准备"指令-回答"成对的中文问答数据(10条AI/编程相关问题)

2. 数据预处理
   ├── 格式标准化:将"指令+回答"按Qwen模型要求拼接(<|im_start|>user/assistant分隔符)
   ├── 转换为Dataset格式:方便Hugging Face框架加载
   └── 分词编码:用AutoTokenizer将文本转换为token向量(词嵌入)

3. 模型加载
   ├── 加载预训练模型:复用Qwen2-0.5B-Instruct的权重和架构
   ├── 加载分词器:匹配预训练模型的token规则
   └── 配置模型参数:设置float32精度、自动分配设备(CPU/GPU)

4. 训练配置
   ├── 设置SFT训练参数:批次大小、训练轮数、学习率、梯度累积等
   └── 初始化SFTTrainer:绑定模型、数据、分词器、训练参数

5. 核心训练流程
   ├── 输入token序列到模型:"指令部分+回答部分"的完整向量
   ├── 模型预测每个token的下一个值
   ├── 计算损失:仅对"回答部分"的token计算交叉熵损失
   ├── 反向传播:更新模型参数(仅微调顶层/部分参数,保留预训练知识)
   └── 迭代训练:重复上述步骤,直到完成指定训练轮数

6. 结果保存
   ├── 保存微调后的模型权重到输出目录
   └── 保存分词器:保证后续推理时的文本编码一致性

7. 训练结束:得到适配中文问答场景的微调后模型

总结

  1. SFT核心:用"指令-回答"标注数据微调预训练模型,让模型从"文本续写"升级为"指令理解+回答生成"。
  2. 代码落地关键:数据格式需匹配目标模型(Qwen的分隔符规则)、训练参数适配硬件(小批次+梯度累积)、损失仅计算回答部分。
  3. 流程逻辑:数据预处理→模型加载→训练配置→损失优化→模型保存,是SFT落地的通用步骤,的代码完整覆盖了这一流程。

实验设置

未经SFT微调的预训练模型虽具备直接生成文本的能力,但核心仅为从左到右的续写能力------它能基于前文规律补全内容,却完全不具备理解人类指令的对话交互能力。而通过SFT(监督式微调)这一关键步骤,用"指令-回答"的标注数据对预训练模型进行针对性训练,就能让模型学会映射人类指令与对应回答的关系,实现从"被动续写文本"到"主动响应问答"的能力跃迁,这正是初代大语言模型获得实用对话问答能力的核心流程。

完整代码

训练代码

python 复制代码
"""
大语言模型SFT(Supervised Fine-Tuning)微调示例
使用trl框架进行监督式微调
"""

import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments
)
# trl	大模型对齐微调专用框架,封装 SFT 训练逻辑
from trl import SFTTrainer, SFTConfig

# 配置
# 使用本地模型路径,确保可以离线运行
# 首次运行前,下载模型
MODEL_NAME = "./models/Qwen2-0.5B-Instruct"  # 本地模型路径
OUTPUT_DIR = "./output"
MAX_SEQ_LENGTH = 512

# 构造中文问答数据
def create_chinese_dataset():
    """构造简单的中文问答数据集"""
    data = [
        {
            "instruction": "请介绍一下人工智能。",
            "input": "",
            "output": "人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。这包括学习、推理、感知和自然语言处理等能力。"
        },
        {
            "instruction": "什么是深度学习?",
            "input": "",
            "output": "深度学习是机器学习的一个子领域,使用具有多个层(深度)的神经网络来学习数据的复杂模式和表示。它模仿人脑的神经网络结构。"
        },
        {
            "instruction": "Python中的列表和元组有什么区别?",
            "input": "",
            "output": "列表(list)是可变的,可以修改、添加或删除元素,使用方括号[]。元组(tuple)是不可变的,一旦创建就不能修改,使用圆括号()。"
        },
        {
            "instruction": "解释一下什么是监督学习。",
            "input": "",
            "output": "监督学习是机器学习的一种方法,使用标记的训练数据来训练模型。模型学习输入和输出之间的映射关系,然后可以对新的未标记数据进行预测。"
        },
        {
            "instruction": "如何提高模型的泛化能力?",
            "input": "",
            "output": "提高模型泛化能力的方法包括:1) 增加训练数据量和多样性 2) 使用正则化技术(如Dropout、L2正则化)3) 数据增强 4) 交叉验证 5) 防止过拟合。"
        },
        {
            "instruction": "什么是Transformer模型?",
            "input": "",
            "output": "Transformer是一种基于注意力机制的神经网络架构,由编码器和解码器组成。它摒弃了RNN和CNN,完全依赖注意力机制来处理序列数据,成为现代NLP的基础架构。"
        },
        {
            "instruction": "请解释一下梯度下降算法。",
            "input": "",
            "output": "梯度下降是一种优化算法,用于最小化损失函数。它通过计算损失函数对参数的梯度,然后沿着梯度反方向更新参数,逐步接近最优解。学习率控制每次更新的步长。"
        },
        {
            "instruction": "什么是迁移学习?",
            "input": "",
            "output": "迁移学习是将在一个任务或领域上学到的知识应用到另一个相关任务上的技术。它允许模型利用预训练的知识,从而在目标任务上更快地学习和获得更好的性能。"
        },
        {
            "instruction": "如何处理自然语言处理中的文本分类问题?",
            "input": "",
            "output": "文本分类的常见步骤包括:1) 文本预处理(分词、去停用词)2) 特征提取(词袋、TF-IDF、词向量)3) 选择分类算法(朴素贝叶斯、SVM、神经网络)4) 训练和评估模型。"
        },
        {
            "instruction": "请介绍一下大语言模型。",
            "input": "",
            "output": "大语言模型(LLM)是拥有数十亿甚至千亿参数的深度学习模型,通过在海量文本数据上预训练获得语言理解能力。它们可以执行各种NLP任务,如文本生成、问答、翻译等。"
        }
    ]
    
    # 将数据转换为对话格式(Qwen格式)
    formatted_data = []
    for item in data:
        # 使用Qwen的对话格式
        text = f"<|im_start|>user\n{item['instruction']}<|im_end|>\n<|im_start|>assistant\n{item['output']}<|im_end|>"
        formatted_data.append({"text": text})
    
    return formatted_data

def load_model_and_tokenizer():
    """加载模型和分词器"""
    print(f"正在加载模型: {MODEL_NAME}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # 使用float32以避免fp16训练时的梯度缩放问题
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        torch_dtype=torch.float32,  # 使用float32以确保训练稳定性
        device_map="auto" if torch.cuda.is_available() else None
    )
    
    # 设置pad_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    return model, tokenizer

def main():
    """主训练函数"""
    print("=" * 50)
    print("开始SFT微调训练")
    print("=" * 50)
    
    # 加载模型和分词器
    model, tokenizer = load_model_and_tokenizer()
    
    # 创建数据集
    print("\n正在构造训练数据...")
    train_data = create_chinese_dataset()
    dataset = Dataset.from_list(train_data)
    
    print(f"训练样本数量: {len(dataset)}")
    
    # 训练参数 - 使用SFTConfig(新版本trl推荐)
    sft_config = SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=3,
        per_device_train_batch_size=2,  # batch size适合普通电脑
        gradient_accumulation_steps=4,  # 通过梯度累积增大有效batch size
        learning_rate=2e-5,
        warmup_steps=10,
        logging_steps=1,
        save_steps=50,
        save_total_limit=2,
        fp16=False,  # 禁用fp16以避免梯度缩放问题
        bf16=False,  # 禁用bf16
        remove_unused_columns=False,
        report_to=None,  # 不使用wandb等
        dataloader_pin_memory=False,  # Windows上可能需要设置为False
        max_length=MAX_SEQ_LENGTH,  # 最大序列长度(注意:新版本使用max_length而不是max_seq_length)
        dataset_text_field="text",  # 指定数据集中文本字段的名称
    )
    
    # 创建SFTTrainer
    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=dataset,
        processing_class=tokenizer,  # 新版本使用processing_class而不是tokenizer
    )
    
    # 开始训练
    print("\n开始训练...")
    trainer.train()
    
    # 保存模型
    print(f"\n训练完成,正在保存模型到 {OUTPUT_DIR}")
    trainer.save_model()
    tokenizer.save_pretrained(OUTPUT_DIR)
    
    print("=" * 50)
    print("训练完成!")
    print("=" * 50)

if __name__ == "__main__":
    main()

测试代码

python 复制代码
"""
测试微调后的模型
"""

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

MODEL_PATH = "./output"  # 微调后的模型路径
# 如果还没有微调,可以使用本地原始模型路径进行测试
# MODEL_PATH = "./models/Qwen2-0.5B-Instruct"  # 本地模型路径

def test_model():
    """测试模型生成能力"""
    print("正在加载模型...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None
    )
    
    # 测试问题
    test_questions = [
        "请介绍一下人工智能。",
        "什么是深度学习?",
        "请介绍一下大语言模型。"
    ]
    
    print("\n开始测试...")
    print("=" * 50)
    
    for question in test_questions:
        # 构建输入(使用Qwen的对话格式)
        prompt = f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
        
        # 编码
        inputs = tokenizer(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
        
        # 生成
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id else tokenizer.pad_token_id
            )
        
        # 解码
        response = tokenizer.decode(outputs[0], skip_special_tokens=False)
        
        # 提取回答部分
        if "<|im_start|>assistant\n" in response:
            answer = response.split("<|im_start|>assistant\n")[1].split("<|im_end|>")[0].strip()
        else:
            answer = response.replace(prompt, "").replace("<|im_end|>", "").strip()
        
        print(f"问题: {question}")
        print(f"回答: {answer}")
        print("-" * 50)

if __name__ == "__main__":
    test_model()
相关推荐
NAGNIP9 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab10 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab10 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP14 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年14 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼14 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS14 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区15 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈16 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang16 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx