微调生成特定写作风格助手

参考

如何把你的 DeePseek-R1 微调为某个领域的专家?今天我们一起来聊聊大模型的进阶使用:"模型微调" ,也就是较大 - 掘金

看完就想试!Unsloth打造个性化AI助手案例展示-CSDN博客

unsloth 布署见

unsloth 部署(简单易上手版本)-CSDN博客

把基础模型下载到本地,确定数据集的格式

训练过程

python 复制代码
import os
from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import load_dataset
import subprocess
from trl import SFTTrainer,SFTConfig

# 设置代理(可选),详见autodl参考文档

# 1. 加载模型和 tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/root/Qwen/Qwen2-1.5B-Instruct",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)

# 2. 添加 LoRA(Unsloth 自动优化)
model = FastLanguageModel.get_peft_model(
    model,
    r=32,  # LoRA rank
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj", "lm_head"],
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    use_gradient_checkpointing="unsloth",  # 更高效
    random_state=3407,
)

# 3. 加载数据集(假设是 Alpaca 格式)
dataset = load_dataset("json", data_files="/root/ft_data.json", split="train")
# load_dataset批量处理数据集

# 4. 定义 formatting_func(关键!)把结构化的 instruction/input/output 字段,批量转换成模型能识别的对话格式文本
#更标准的写法,和test的写法一致
def formatting_func(examples):
    outputs = []
    for i in range(len(examples["instruction"])):
        instruction = examples["instruction"][i].strip()
        input_text = examples["input"][i].strip() if "input" in examples and examples["input"][i] else ""
        output = examples["output"][i].strip()
        
        # 构造标准化 messages(和推理时一致)
        messages = [
            {"role": "user", "content": f"{instruction}\n{input_text}" if input_text else instruction},
            {"role": "assistant", "content": output}
        ]
        # 用 apply_chat_template 生成格式(和推理时完全相同)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,  # 只生成文本,不编码
            add_generation_prompt=False  # 训练时不需要加 assistant 生成提示符
        )
        outputs.append(text)
    return outputs 
python 复制代码
# 5. 使用 Unsloth 的 train 方法(自动处理 packing、formatting 等)
training_config = SFTConfig(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    warmup_steps=10,
    max_steps=100,
    learning_rate=1e-4,
    logging_steps=2,
    save_steps=50,
    output_dir="/root/my_qwen_assistant",
    optim="adamw_8bit",
    seed=3407,
    fp16=False,
    bf16=True,
    packing=False,  # 关闭打包,与 Unsloth 无填充批处理兼容
    remove_unused_columns=False,  # 保留数据集字段,避免格式化函数报错
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=formatting_func,  # ←← Unsloth 的 SFTTrainer 要求这个
    max_seq_length=2048,
    dataset_num_proc=2,
    packing=False,
    args=training_config,
)

# 6. 训练并保存
trainer_stats = trainer.train()
model.save_pretrained("my_qwen_assistant")
tokenizer.save_pretrained("my_qwen_assistant")

推理过程

python 复制代码
from peft import PeftModel
from unsloth import FastLanguageModel, is_bfloat16_supported

# 1. 加载基础模型(保持和训练一致的dtype/device_map)
base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/root/Qwen/Qwen2-1.5B-Instruct",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)

# 2. 加载LoRA适配器(校验路径)
lora_model_path = "/root/my_qwen_assistant"
# 校验LoRA权重是否存在
import os
if not os.path.exists(lora_model_path):
    raise ValueError(f"LoRA权重路径不存在:{lora_model_path}")
print(f"成功加载LoRA权重:{lora_model_path}")

model = PeftModel.from_pretrained(base_model, lora_model_path)
model.eval()  # 评估模式

# 3. 推理(仅用LoRA模型,删除基础模型推理代码)
messages = [
    {"role": "user", "content": "项目名称:xxxx;一,xxx,二,xxx。三,xxx。。。"}
]

# 用和训练时一致的chat template生成输入
input_ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

# 生成参数(简化,聚焦核心)
outputs = model.generate(
    input_ids=input_ids["input_ids"],
    attention_mask = input_ids["attention_mask"],
    max_new_tokens=800,
    do_sample=True,  # 开启采样,提升回复多样性
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.1,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# 解码输出
response = tokenizer.decode(outputs[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True)
print("LoRA模型推理结果:")
print(response.strip())

#(可选)验证基础模型结果(对比用)
base_model.eval()
outputs_base = base_model.generate(
    input_ids=input_ids["input_ids"],
    attention_mask = input_ids["attention_mask"],
    max_new_tokens=800,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.1,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
response_base = tokenizer.decode(outputs_base[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True)
print("\n基础模型推理结果:")
print(response_base.strip())
相关推荐
企业架构师老王几秒前
药品生产环节:用实在Agent自动生成批记录与打印领料单的合规设计与架构落地
大数据·人工智能·ai·架构
m0_588758484 分钟前
高效实现分组内跨行时间戳匹配:为每组生成布尔标记列 user_rejects
jvm·数据库·python
黎阳之光5 分钟前
视频孪生重构轨交数字孪生新范式|黎阳之光以自主核心技术破解落地难题
大数据·人工智能·算法·安全·数字孪生
ai产品老杨5 分钟前
告别重复造轮子:深度解析支持源码交付的 AI 视频平台架构,实现 X86/ARM 与 GPU/NPU 异构算力融合
人工智能·架构·音视频
好运的阿财6 分钟前
OpenClaw工具拆解之 web_fetch+image_generate
前端·python·机器学习·ai·ai编程·openclaw·openclaw工具
写代码的小阿帆8 分钟前
AI工具使用——外挂AI插件、AI原生IDE与AI终端
ide·人工智能·ai-native
谢谢 啊sir9 分钟前
L2-060 大语言模型的推理 - java
java·人工智能·语言模型
阿杰学AI9 分钟前
AI核心知识140—大语言模型之 推理期算力(简洁且通俗易懂版)
人工智能·语言模型·自然语言处理·思维链·思维树·慢思考·推理期算力
云淡风轻~窗明几净9 分钟前
关于TSP的sealine算法与角谷猜想(2026-04-25)
数据结构·人工智能·算法·动态规划·模拟退火算法
wayz1110 分钟前
Day 13:朴素贝叶斯分类器
人工智能·算法·机器学习·朴素贝叶斯