微调生成特定写作风格助手

参考

如何把你的 DeePseek-R1 微调为某个领域的专家?今天我们一起来聊聊大模型的进阶使用:"模型微调" ,也就是较大 - 掘金

看完就想试!Unsloth打造个性化AI助手案例展示-CSDN博客

unsloth 布署见

unsloth 部署(简单易上手版本)-CSDN博客

把基础模型下载到本地,确定数据集的格式

训练过程

python 复制代码
import os
from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import load_dataset
import subprocess
from trl import SFTTrainer,SFTConfig

# 设置代理(可选),详见autodl参考文档

# 1. 加载模型和 tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/root/Qwen/Qwen2-1.5B-Instruct",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)

# 2. 添加 LoRA(Unsloth 自动优化)
model = FastLanguageModel.get_peft_model(
    model,
    r=32,  # LoRA rank
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj", "lm_head"],
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    use_gradient_checkpointing="unsloth",  # 更高效
    random_state=3407,
)

# 3. 加载数据集(假设是 Alpaca 格式)
dataset = load_dataset("json", data_files="/root/ft_data.json", split="train")
# load_dataset批量处理数据集

# 4. 定义 formatting_func(关键!)把结构化的 instruction/input/output 字段,批量转换成模型能识别的对话格式文本
#更标准的写法,和test的写法一致
def formatting_func(examples):
    outputs = []
    for i in range(len(examples["instruction"])):
        instruction = examples["instruction"][i].strip()
        input_text = examples["input"][i].strip() if "input" in examples and examples["input"][i] else ""
        output = examples["output"][i].strip()
        
        # 构造标准化 messages(和推理时一致)
        messages = [
            {"role": "user", "content": f"{instruction}\n{input_text}" if input_text else instruction},
            {"role": "assistant", "content": output}
        ]
        # 用 apply_chat_template 生成格式(和推理时完全相同)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,  # 只生成文本,不编码
            add_generation_prompt=False  # 训练时不需要加 assistant 生成提示符
        )
        outputs.append(text)
    return outputs 
python 复制代码
# 5. 使用 Unsloth 的 train 方法(自动处理 packing、formatting 等)
training_config = SFTConfig(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    warmup_steps=10,
    max_steps=100,
    learning_rate=1e-4,
    logging_steps=2,
    save_steps=50,
    output_dir="/root/my_qwen_assistant",
    optim="adamw_8bit",
    seed=3407,
    fp16=False,
    bf16=True,
    packing=False,  # 关闭打包,与 Unsloth 无填充批处理兼容
    remove_unused_columns=False,  # 保留数据集字段,避免格式化函数报错
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=formatting_func,  # ←← Unsloth 的 SFTTrainer 要求这个
    max_seq_length=2048,
    dataset_num_proc=2,
    packing=False,
    args=training_config,
)

# 6. 训练并保存
trainer_stats = trainer.train()
model.save_pretrained("my_qwen_assistant")
tokenizer.save_pretrained("my_qwen_assistant")

推理过程

python 复制代码
from peft import PeftModel
from unsloth import FastLanguageModel, is_bfloat16_supported

# 1. 加载基础模型(保持和训练一致的dtype/device_map)
base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/root/Qwen/Qwen2-1.5B-Instruct",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)

# 2. 加载LoRA适配器(校验路径)
lora_model_path = "/root/my_qwen_assistant"
# 校验LoRA权重是否存在
import os
if not os.path.exists(lora_model_path):
    raise ValueError(f"LoRA权重路径不存在:{lora_model_path}")
print(f"成功加载LoRA权重:{lora_model_path}")

model = PeftModel.from_pretrained(base_model, lora_model_path)
model.eval()  # 评估模式

# 3. 推理(仅用LoRA模型,删除基础模型推理代码)
messages = [
    {"role": "user", "content": "项目名称:xxxx;一,xxx,二,xxx。三,xxx。。。"}
]

# 用和训练时一致的chat template生成输入
input_ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

# 生成参数(简化,聚焦核心)
outputs = model.generate(
    input_ids=input_ids["input_ids"],
    attention_mask = input_ids["attention_mask"],
    max_new_tokens=800,
    do_sample=True,  # 开启采样,提升回复多样性
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.1,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# 解码输出
response = tokenizer.decode(outputs[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True)
print("LoRA模型推理结果:")
print(response.strip())

#(可选)验证基础模型结果(对比用)
base_model.eval()
outputs_base = base_model.generate(
    input_ids=input_ids["input_ids"],
    attention_mask = input_ids["attention_mask"],
    max_new_tokens=800,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.1,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
response_base = tokenizer.decode(outputs_base[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True)
print("\n基础模型推理结果:")
print(response_base.strip())
相关推荐
小程故事多_807 小时前
从人工编写到自主迭代进化,SkillEvolver重构大模型智能体技能生成新范式
人工智能·重构
wengad7 小时前
机器学习实践理论基础|算法、模型和数据集
人工智能·算法·机器学习
kishu_iOS&AI7 小时前
LLM —— Prompt提示词工程
人工智能·prompt
li-xun7 小时前
2026年6月7日博客精选
人工智能·chatgpt·每日阅读
人工智能AI技术7 小时前
【VibeCoding系列教程12】 AI代码编辑器
人工智能
辣椒思密达8 小时前
Python公开数据采集实战:如何解决请求高频拦截与Session会话中断问题
开发语言·python
zhangfeng11338 小时前
ai训练 顿悟“总数据量是 m²,训练所需要的数据量是 log m
人工智能
半兽先生8 小时前
05阶段:NLP自然语言处理基础
人工智能·自然语言处理
盈飞无限8 小时前
SPC选型:智能VS传统,谁更懂中国制造?
人工智能·制造
li-xun8 小时前
LINUX DO 社区注册机制调整与公益 AI 服务动态
linux·运维·人工智能