References
How to fine-tune your DeepSeek-R1 into a domain expert? An introduction to an advanced use of large models: "model fine-tuning" - Juejin
Try it after reading! A showcase of building a personalized AI assistant with Unsloth - CSDN blog
For Unsloth deployment, see:
Download the base model locally and settle on the dataset format.
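The training script below expects an Alpaca-style JSON file at /root/ft_data.json with instruction/input/output fields. A minimal sketch of what one record might look like (the field values here are illustrative placeholders, not from the original):
python
import json

# Hypothetical sample record in the Alpaca format that formatting_func (below) expects.
sample = [
    {
        "instruction": "Summarize the following project description.",
        "input": "Project name: xxxx; 1. xxx; 2. xxx; 3. xxx ...",
        "output": "A short summary of the project.",
    }
]

with open("/root/ft_data.json", "w", encoding="utf-8") as f:
    json.dump(sample, f, ensure_ascii=False, indent=2)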
Training process
python
from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig

# Set a proxy if needed (optional); see the AutoDL reference docs.

# 1. Load the model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/root/Qwen/Qwen2-1.5B-Instruct",
    max_seq_length=2048,
    dtype=None,         # auto-detect (bf16 on supported GPUs, else fp16)
    load_in_4bit=True,  # 4-bit quantized loading to cut VRAM usage
)
# 2. Add LoRA adapters (Unsloth applies its own optimizations)
model = FastLanguageModel.get_peft_model(
    model,
    r=32,  # LoRA rank
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj", "lm_head"],
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    use_gradient_checkpointing="unsloth",  # Unsloth's checkpointing is more memory-efficient
    random_state=3407,
)
# 3. Load the dataset (assumed to be in Alpaca format);
# load_dataset handles reading and batching the whole file.
dataset = load_dataset("json", data_files="/root/ft_data.json", split="train")

# 4. Define formatting_func (key step!): it converts the structured
# instruction/input/output fields into chat-formatted text the model understands.
# This mirrors the formatting used at inference time.
def formatting_func(examples):
    outputs = []
    for i in range(len(examples["instruction"])):
        instruction = examples["instruction"][i].strip()
        input_text = examples["input"][i].strip() if "input" in examples and examples["input"][i] else ""
        output = examples["output"][i].strip()
        # Build standardized messages (same structure as at inference time)
        messages = [
            {"role": "user", "content": f"{instruction}\n{input_text}" if input_text else instruction},
            {"role": "assistant", "content": output},
        ]
        # Render with apply_chat_template (identical to the inference-time format)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,               # produce text only, no token IDs
            add_generation_prompt=False,  # no trailing assistant prompt during training
        )
        outputs.append(text)
    return outputs
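Before training, it is worth eyeballing one rendered example to confirm the chat template looks right. A quick sanity check, assuming the dataset and formatting_func defined above:
python
# Slicing a Dataset with [:2] returns a dict of lists, which is
# exactly the batched shape formatting_func expects.
preview = formatting_func(dataset[:2])
print(preview[0])  # inspect the rendered chat-template text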
python
# 5. Configure training (TRL's SFTTrainer, patched by Unsloth, handles
# formatting, packing, etc. automatically)
training_config = SFTConfig(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,  # effective batch size = 4 * 2 = 8
    warmup_steps=10,
    max_steps=100,
    learning_rate=1e-4,
    logging_steps=2,
    save_steps=50,
    output_dir="/root/my_qwen_assistant",
    optim="adamw_8bit",
    seed=3407,
    fp16=not is_bfloat16_supported(),
    bf16=is_bfloat16_supported(),
    max_seq_length=2048,
    dataset_num_proc=2,
    packing=False,                # disable packing; compatible with Unsloth's padding-free batching
    remove_unused_columns=False,  # keep the raw dataset columns so formatting_func can read them
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=formatting_func,  # SFTTrainer uses this to build the training text
    args=training_config,             # max_seq_length/packing/dataset_num_proc live in SFTConfig above
)

# 6. Train and save the LoRA adapter (use the same absolute path the inference script loads)
trainer_stats = trainer.train()
model.save_pretrained("/root/my_qwen_assistant")
tokenizer.save_pretrained("/root/my_qwen_assistant")
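If the model is meant to be served without loading a PEFT adapter at runtime, Unsloth can also export merged weights via save_pretrained_merged. A minimal sketch; the output directory is an assumption:
python
# Merge the LoRA weights into the base model and save as 16-bit safetensors.
# "merged_16bit" is one of Unsloth's documented save methods.
model.save_pretrained_merged(
    "/root/my_qwen_assistant_merged",  # hypothetical output directory
    tokenizer,
    save_method="merged_16bit",
)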
Inference process
python
import os
from peft import PeftModel
from unsloth import FastLanguageModel

# 1. Load the base model (same dtype/quantization settings as training)
base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/root/Qwen/Qwen2-1.5B-Instruct",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)

# 2. Load the LoRA adapter (validate the path first)
lora_model_path = "/root/my_qwen_assistant"
if not os.path.exists(lora_model_path):
    raise ValueError(f"LoRA adapter path does not exist: {lora_model_path}")
model = PeftModel.from_pretrained(base_model, lora_model_path)
model.eval()  # evaluation mode
print(f"Loaded LoRA adapter from: {lora_model_path}")
# 3. Run inference with the LoRA model
messages = [
    {"role": "user", "content": "Project name: xxxx; 1. xxx; 2. xxx; 3. xxx ..."}
]

# Build inputs with the same chat template used during training.
# return_dict=True is required so we get both input_ids and attention_mask;
# without it, apply_chat_template returns a bare tensor and the dict
# lookups below would fail.
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)
# Generation parameters (kept minimal)
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=800,
    do_sample=True,  # enable sampling for more varied responses
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.1,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# Decode only the newly generated tokens (skip the prompt)
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print("LoRA model output:")
print(response.strip())
# (Optional) Compare against the base model. Note: PeftModel injects the
# adapter into base_model's modules in place, so calling base_model.generate()
# would still use the LoRA weights. Temporarily disable the adapter instead.
with model.disable_adapter():
    outputs_base = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=800,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
response_base = tokenizer.decode(outputs_base[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print("\nBase model output:")
print(response_base.strip())
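As an alternative to wrapping the base model with PeftModel by hand, Unsloth can load the saved adapter directory directly and enable its fast-inference path. A sketch, assuming the adapter was saved to /root/my_qwen_assistant as above:
python
from unsloth import FastLanguageModel

# Point from_pretrained at the adapter directory; Unsloth reads
# adapter_config.json and loads the referenced base model underneath.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/root/my_qwen_assistant",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # enable Unsloth's faster inference mode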