微调生成特定写作风格助手

参考

如何把你的 DeePseek-R1 微调为某个领域的专家?今天我们一起来聊聊大模型的进阶使用:"模型微调" ,也就是较大 - 掘金

看完就想试!Unsloth打造个性化AI助手案例展示-CSDN博客

unsloth 布署见

unsloth 部署(简单易上手版本)-CSDN博客

把基础模型下载到本地,确定数据集的格式

训练过程

python 复制代码
import os
from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import load_dataset
import subprocess
from trl import SFTTrainer,SFTConfig

# 设置代理(可选),详见autodl参考文档

# 1. 加载模型和 tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/root/Qwen/Qwen2-1.5B-Instruct",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)

# 2. 添加 LoRA(Unsloth 自动优化)
model = FastLanguageModel.get_peft_model(
    model,
    r=32,  # LoRA rank
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj", "lm_head"],
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    use_gradient_checkpointing="unsloth",  # 更高效
    random_state=3407,
)

# 3. 加载数据集(假设是 Alpaca 格式)
dataset = load_dataset("json", data_files="/root/ft_data.json", split="train")
# load_dataset批量处理数据集

# 4. 定义 formatting_func(关键!)把结构化的 instruction/input/output 字段,批量转换成模型能识别的对话格式文本
#更标准的写法,和test的写法一致
def formatting_func(examples):
    outputs = []
    for i in range(len(examples["instruction"])):
        instruction = examples["instruction"][i].strip()
        input_text = examples["input"][i].strip() if "input" in examples and examples["input"][i] else ""
        output = examples["output"][i].strip()
        
        # 构造标准化 messages(和推理时一致)
        messages = [
            {"role": "user", "content": f"{instruction}\n{input_text}" if input_text else instruction},
            {"role": "assistant", "content": output}
        ]
        # 用 apply_chat_template 生成格式(和推理时完全相同)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,  # 只生成文本,不编码
            add_generation_prompt=False  # 训练时不需要加 assistant 生成提示符
        )
        outputs.append(text)
    return outputs 
python 复制代码
# 5. 使用 Unsloth 的 train 方法(自动处理 packing、formatting 等)
training_config = SFTConfig(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    warmup_steps=10,
    max_steps=100,
    learning_rate=1e-4,
    logging_steps=2,
    save_steps=50,
    output_dir="/root/my_qwen_assistant",
    optim="adamw_8bit",
    seed=3407,
    fp16=False,
    bf16=True,
    packing=False,  # 关闭打包,与 Unsloth 无填充批处理兼容
    remove_unused_columns=False,  # 保留数据集字段,避免格式化函数报错
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=formatting_func,  # ←← Unsloth 的 SFTTrainer 要求这个
    max_seq_length=2048,
    dataset_num_proc=2,
    packing=False,
    args=training_config,
)

# 6. 训练并保存
trainer_stats = trainer.train()
model.save_pretrained("my_qwen_assistant")
tokenizer.save_pretrained("my_qwen_assistant")

推理过程

python 复制代码
from peft import PeftModel
from unsloth import FastLanguageModel, is_bfloat16_supported

# 1. 加载基础模型(保持和训练一致的dtype/device_map)
base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/root/Qwen/Qwen2-1.5B-Instruct",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)

# 2. 加载LoRA适配器(校验路径)
lora_model_path = "/root/my_qwen_assistant"
# 校验LoRA权重是否存在
import os
if not os.path.exists(lora_model_path):
    raise ValueError(f"LoRA权重路径不存在:{lora_model_path}")
print(f"成功加载LoRA权重:{lora_model_path}")

model = PeftModel.from_pretrained(base_model, lora_model_path)
model.eval()  # 评估模式

# 3. 推理(仅用LoRA模型,删除基础模型推理代码)
messages = [
    {"role": "user", "content": "项目名称:xxxx;一,xxx,二,xxx。三,xxx。。。"}
]

# 用和训练时一致的chat template生成输入
input_ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

# 生成参数(简化,聚焦核心)
outputs = model.generate(
    input_ids=input_ids["input_ids"],
    attention_mask = input_ids["attention_mask"],
    max_new_tokens=800,
    do_sample=True,  # 开启采样,提升回复多样性
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.1,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# 解码输出
response = tokenizer.decode(outputs[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True)
print("LoRA模型推理结果:")
print(response.strip())

#(可选)验证基础模型结果(对比用)
base_model.eval()
outputs_base = base_model.generate(
    input_ids=input_ids["input_ids"],
    attention_mask = input_ids["attention_mask"],
    max_new_tokens=800,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.1,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
response_base = tokenizer.decode(outputs_base[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True)
print("\n基础模型推理结果:")
print(response_base.strip())
相关推荐
Learn Beyond Limits5 小时前
文献阅读:A Probabilistic U-Net for Segmentation of Ambiguous Images
论文阅读·人工智能·深度学习·算法·机器学习·计算机视觉·ai
-To be number.wan5 小时前
Python数据分析:Matplotlib 绘图练习
python·数据分析·matplotlib
naruto_lnq5 小时前
Python生成器(Generator)与Yield关键字:惰性求值之美
jvm·数据库·python
Stream_Silver5 小时前
【Agent学习笔记1:Python调用Function Calling,阿里云API函数调用与DeepSeek API对比分析】
开发语言·python·阿里云
OpenMiniServer5 小时前
电气化能源革命下的社会
java·人工智能·能源
猿小羽6 小时前
探索 Codex:AI 编程助手的未来潜力
人工智能·openai·代码生成·codex·ai编程助手
没事儿写两篇6 小时前
Python 包管理工具-uv
python·uv·开源包管理工具
2501_941418556 小时前
基于YOLO11-C3k2-ESC的避雷器外部缺陷检测实现
python
流㶡6 小时前
Python爬虫:POST与Selenium
爬虫·python·selenium