[LLM] Single-GPU LoRA Fine-Tuning of Gemma4 for Spam Detection

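  This article continues the previous post, "Single-GPU LoRA Fine-Tuning of Gemma4 for Emotion Classification", and runs the same LoRA fine-tuning test on a spam classification dataset.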

I. The Spam Dataset

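  The SMS Spam Collection is a public, labeled SMS dataset built for research on content-based spam filtering. It contains 5,574 real English text messages, each tagged as ham (a legitimate message) or spam, and is one of the classic benchmark datasets in natural language processing and text classification.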

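  • Source: UCI Machine Learning Repository

  • Official link: https://archive.ics.uci.edu/dataset/228/sms+spam+collection

  • Format: tab-separated text file (TSV) with no header row; one message per line, in two columns: the label and the message text.

  • File name: SMSSpamCollection (can be renamed with a .tsv extension)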
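  To make the format concrete, here is a minimal stdlib-only sketch that parses lines in this label-TAB-text layout. The sample lines and the parse_sms_tsv name are hypothetical; csv.QUOTE_NONE is chosen so that a double quote inside a message is kept as literal text instead of being treated as a field delimiter.

```python
import csv
import io

# Hypothetical sample lines in the SMSSpamCollection layout: label<TAB>text, no header
SAMPLE = "ham\tSee you at lunch\nspam\tWIN a free prize now, reply YES\n"


def parse_sms_tsv(stream):
    """Yield (label, text) pairs from a tab-separated, header-less text stream."""
    # QUOTE_NONE: quote characters inside a message are read literally
    reader = csv.reader(stream, delimiter="\t", quoting=csv.QUOTE_NONE)
    for row in reader:
        if not row:
            continue  # skip blank lines
        # Everything after the first tab belongs to the message text
        label, text = row[0], "\t".join(row[1:])
        yield label, text


rows = list(parse_sms_tsv(io.StringIO(SAMPLE)))
print(rows)
```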

II. Fine-Tuning

1. Modify the Prompt

  Change SYSTEM_PROMPT and LABEL_PATTERN to the following:

python
import re

SYSTEM_PROMPT = """You are a spam classification assistant.
Read the user's text and answer with exactly one label.
Only choose from: spam, ham.
Return only the label and nothing else."""    # changed for the spam task

LABEL_PATTERN = re.compile(r"(spam|ham)", re.IGNORECASE)
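  LABEL_PATTERN is what turns the model's free-form reply into a label. The parsing helper from the previous article is not shown here, so the following is only a sketch of how such a pattern might be applied: parse_label is a hypothetical name, and unlike the pattern above it adds \b word boundaries so that a word such as "Graham" is not read as "ham". Replies with no match map to "INVALID", the extra class that appears in ALL_EVAL_LABELS.

```python
import re

# Sketch only: \b word boundaries are an addition relative to the article's pattern
LABEL_PATTERN = re.compile(r"\b(spam|ham)\b", re.IGNORECASE)


def parse_label(raw_output: str) -> str:
    """Map a raw model reply to 'spam', 'ham', or 'INVALID' (hypothetical helper)."""
    match = LABEL_PATTERN.search(raw_output)
    # Lower-case the match so 'Spam' and 'SPAM' both become 'spam'
    return match.group(1).lower() if match else "INVALID"


for reply in ["spam", " Ham.\n", "Label: SPAM", "I am not sure"]:
    print(repr(reply), "->", parse_label(reply))
```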

2. Modify the Dataset Configuration

python
# load dataset
import urllib.request
import zipfile
import os
from pathlib import Path
os.makedirs("./datasets/classify_finetune", exist_ok=True)

url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "./datasets/classify_finetune/sms_spam_collection.zip"
extracted_path = "./datasets/classify_finetune/sms_spam_collection"
data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"

save_root_path = "./datasets/classify_finetune"
save_train_path = os.path.join(save_root_path, "train.csv")
save_val_path = os.path.join(save_root_path, "validation.csv")  # no leading dot, so the file is not hidden
save_test_path = os.path.join(save_root_path, "test.csv")

def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return

    # Download the archive
    with urllib.request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())

    # Extract it
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # Add the .tsv extension to the extracted data file
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")

download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)

import pandas as pd

df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
print(df['Label'].value_counts())

def create_balanced_dataset(df):
    
    # Count the "spam" instances
    num_spam = df[df["Label"] == "spam"].shape[0]

    # Randomly sample as many "ham" instances as there are "spam" instances
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)

    # Combine the "ham" subset with all "spam" instances
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])

    return balanced_df

balanced_df = create_balanced_dataset(df)
print(balanced_df["Label"].value_counts())

balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})

# Split the dataset
def random_split(df, train_frac, validation_frac):
    # Shuffle the whole DataFrame
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    # Compute the split indices
    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)

    # Slice the DataFrame
    train_df = df[:train_end]
    val_df = df[train_end:validation_end]
    test_df = df[validation_end:]

    return train_df, val_df, test_df

train_df, val_df, test_df = random_split(balanced_df, 0.7, 0.1)
# The remaining 20% becomes the test set

train_df.to_csv(save_train_path, index=None)
val_df.to_csv(save_val_path, index=None)
test_df.to_csv(save_test_path, index=None)

import json  # required by json.dumps below

# [Optional] load previously saved CSV splits instead
# train_df = pd.read_csv("./datasets/classify_finetune/train.csv")
# val_df = pd.read_csv("./datasets/classify_finetune/validation.csv")
# test_df = pd.read_csv("./datasets/classify_finetune/test.csv")

# Map integer labels back to text labels
label_map = {0: "ham", 1: "spam"}   # label classes

def csv_to_sharegpt(df, system_prompt="You are a spam classifier. Reply with only 'spam' or 'ham'."):
    """Convert a DataFrame into a list of ShareGPT-style records."""
    records = []
    for _, row in df.iterrows():
        text = row["Text"]
        label = label_map[row["Label"]]
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"SMS: {text}"},
            {"role": "assistant", "content": label}
        ]
        # Unlike the emotion dataset (a plain "text" field), each record stores a "messages" list
        records.append({"messages": messages})
    return records

# Convert each split and save it as JSONL (split_df avoids overwriting the full df)
for name, split_df in [("train", train_df), ("validation", val_df), ("test", test_df)]:
    sharegpt_data = csv_to_sharegpt(split_df)
    out_path = os.path.join(save_root_path, f"{name}.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        for item in sharegpt_data:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
    print(f"Saved {len(sharegpt_data)} samples to {out_path}")
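  As a worked example of the split arithmetic, assume the 747 spam messages listed on the UCI page (the value_counts() output of your own run is the authoritative figure). Balancing doubles that to 1,494 rows, and because random_split truncates with int(), any rounding remainder falls into the test split. split_sizes is a hypothetical helper that repeats the same index computation without pandas.

```python
def split_sizes(n_total: int, train_frac: float, validation_frac: float):
    """Repeat random_split's index arithmetic and return (train, val, test) sizes."""
    train_end = int(n_total * train_frac)
    validation_end = train_end + int(n_total * validation_frac)
    # int() truncates, so the leftover rows end up in the test split
    return train_end, validation_end - train_end, n_total - validation_end


num_spam = 747                 # spam count reported on the UCI dataset page (assumption)
balanced_total = 2 * num_spam  # create_balanced_dataset samples as many ham as spam
print(split_sizes(balanced_total, 0.7, 0.1))  # (1045, 149, 300)
```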

  The dataset-loading code in the training script must be changed as well:

python
import os

from datasets import ClassLabel, DatasetDict, load_dataset

# ClassLabel names for this task (variable name carried over from the emotion-classification script).
# SEED, TRAIN_LIMIT, VALIDATION_LIMIT and TEST_LIMIT come from the previous article's configuration.
EMOTION_LABEL_NAMES = ["spam", "ham"]


# Load the local JSONL files written by the data-preparation code above
# (save_root_path is defined there) instead of downloading from ModelScope.
print("Loading local JSONL splits from:", save_root_path)

raw_dataset = load_dataset(
    "json",
    data_files={
        "train": os.path.join(save_root_path, "train.jsonl"),
        "validation": os.path.join(save_root_path, "validation.jsonl"),
        "test": os.path.join(save_root_path, "test.jsonl"),
    },
)

def extract_label_from_messages(example):
    # 取 assistant 的最后一条消息内容作为标签
    last_assistant = [msg["content"] for msg in example["messages"] if msg["role"] == "assistant"]
    example["label"] = last_assistant[-1] if last_assistant else "INVALID"
    return example

raw_dataset = raw_dataset.map(extract_label_from_messages)

# Loaded from JSON, "label" is a plain string column; cast it to ClassLabel explicitly
# so that `dataset["train"].features["label"].names` behaves exactly as with the original HF dataset.
for split_name in list(raw_dataset.keys()):
    if not isinstance(raw_dataset[split_name].features.get("label"), ClassLabel):
        raw_dataset[split_name] = raw_dataset[split_name].cast_column(
            "label", ClassLabel(names=EMOTION_LABEL_NAMES)
        )

print("Raw dataset:", raw_dataset)


def maybe_limit(split, limit):
    split = split.shuffle(seed=SEED)
    if limit is None:
        return split
    return split.select(range(min(limit, len(split))))


dataset = DatasetDict({
    "train": maybe_limit(raw_dataset["train"], TRAIN_LIMIT),
    "validation": maybe_limit(raw_dataset["validation"], VALIDATION_LIMIT),
    "test": maybe_limit(raw_dataset["test"], TEST_LIMIT),
})

label_names = dataset["train"].features["label"].names
VALID_LABELS = set(label_names)
ALL_EVAL_LABELS = label_names + ["INVALID"]

print(dataset)
print("label_names:", label_names)
print("example:", dataset["train"][0])
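  ClassLabel assigns integer ids by position in names, so with EMOTION_LABEL_NAMES = ["spam", "ham"] spam becomes 0 and ham 1, the reverse of label_map = {0: "ham", 1: "spam"} used when writing the CSV files. This is harmless here because the JSONL stores label text and to_prompt_completion converts ids back through label_names, but it matters if integer ids are compared across the two stages. The sketch below mimics ClassLabel's str2int/int2str with plain dictionaries (illustrative stand-ins, no datasets import).

```python
# Plain-dict stand-ins for datasets.ClassLabel's str2int / int2str
EMOTION_LABEL_NAMES = ["spam", "ham"]

str2int = {name: idx for idx, name in enumerate(EMOTION_LABEL_NAMES)}
int2str = dict(enumerate(EMOTION_LABEL_NAMES))

csv_label_map = {0: "ham", 1: "spam"}  # the label_map used when writing CSV/JSONL

for csv_id, name in csv_label_map.items():
    print(f"CSV id {csv_id} -> '{name}' -> ClassLabel id {str2int[name]}")
```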

3. Build Prompt-Completion Data

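python
def to_prompt_completion(example):
    # Use only the user turn: formatting the whole "messages" list would also put
    # the assistant's answer (the label) into the prompt
    text = next(m["content"] for m in example["messages"] if m["role"] == "user")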
    label = label_names[example["label"]]
    user_content = f"Classify the emotion of this text:\n\n{text}"
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ],
        "completion": [
            {"role": "assistant", "content": label},
        ],
    }

sft_dataset = dataset.map(
    to_prompt_completion,
    remove_columns=dataset["train"].column_names,
)

print(sft_dataset)
print(sft_dataset["train"][0])
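  Formatting example["messages"] directly in an f-string embeds the Python repr of the whole conversation, including the assistant turn that holds the gold label, into the user prompt, so the model can simply copy the answer. The pure-Python sketch below checks this on a hypothetical record shaped like one line of train.jsonl; user_text is an illustrative helper name.

```python
# Hypothetical record shaped like one line of train.jsonl
record = {
    "messages": [
        {"role": "system", "content": "You are a spam classifier."},
        {"role": "user", "content": "SMS: Free entry! Text WIN now"},
        {"role": "assistant", "content": "spam"},
    ]
}


def user_text(messages):
    """Return the content of the first user turn."""
    return next(m["content"] for m in messages if m["role"] == "user")


leaky_prompt = f"Classify the emotion of this text:\n\n{record['messages']}"
clean_prompt = f"Classify the emotion of this text:\n\n{user_text(record['messages'])}"

# True False: only the leaky prompt contains the assistant turn (and thus the label)
print("'assistant'" in leaky_prompt, "'assistant'" in clean_prompt)
```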

4. Everything Else

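  Everything else is the same as in the previous article, "Single-GPU LoRA Fine-Tuning of Gemma4 for Emotion Classification". Copy it from there, or get the project files directly from https://github.com/Reversev/attempts/tree/main/Gemma4_LoRA_ROCm.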

  Because the dataset is fairly simple, I used early stopping during LoRA fine-tuning to avoid overfitting, so training ran for less than one epoch.
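  The exact early-stopping settings are not given here. In Hugging Face transformers the usual route is EarlyStoppingCallback (early_stopping_patience, early_stopping_threshold) together with load_best_model_at_end=True and metric_for_best_model in the training arguments; the plain-Python sketch below only illustrates the same patience idea. The EarlyStopper class and the loss values are hypothetical.

```python
class EarlyStopper:
    """Signal a stop when eval loss fails to improve for `patience` evaluations."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_evals = 0

    def should_stop(self, eval_loss: float) -> bool:
        if eval_loss < self.best - self.min_delta:
            self.best = eval_loss  # improvement: remember it and reset the counter
            self.bad_evals = 0
        else:
            self.bad_evals += 1    # no sufficient improvement at this evaluation
        return self.bad_evals >= self.patience


# Hypothetical eval losses at successive evaluation steps
losses = [0.90, 0.40, 0.12, 0.13, 0.12, 0.15]
stopper = EarlyStopper(patience=3)
for step, loss in enumerate(losses):
    if stopper.should_stop(loss):
        print(f"stop at eval {step}, best loss {stopper.best}")
        break
```

Stopping on a validation metric rather than a fixed epoch count is what allows training to end part-way through the first epoch, as happened in this run.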

5. Evaluation Before and After Fine-Tuning

  Evaluation results before fine-tuning:

  Evaluation results after fine-tuning:

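  Conclusion: after fine-tuning, Gemma4's accuracy and the other metrics on spam classification all improved, reaching 100%, which suggests this task is indeed quite easy.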
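  The evaluation code comes from the previous article and is not reproduced here. As a hedged sketch of how such metrics can be computed from gold labels and parsed predictions, the hypothetical helper below reports accuracy plus precision, recall and F1 with spam as the positive class; unparseable outputs ("INVALID") simply count as errors.

```python
def binary_metrics(gold, pred, positive="spam"):
    """Accuracy plus precision/recall/F1 for the positive class (hypothetical helper).

    Any prediction outside the label set, such as "INVALID", counts as wrong.
    """
    tp = sum(g == positive and p == positive for g, p in zip(gold, pred))
    fp = sum(g != positive and p == positive for g, p in zip(gold, pred))
    fn = sum(g == positive and p != positive for g, p in zip(gold, pred))
    accuracy = sum(g == p for g, p in zip(gold, pred)) / len(gold)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Hypothetical gold labels and parsed model outputs
gold = ["spam", "ham", "spam", "ham"]
pred = ["spam", "ham", "INVALID", "spam"]
metrics = binary_metrics(gold, pred)
print(metrics)
```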

III. Deployment Testing

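  This section covers manually testing the fine-tuned model and running inference with the locally saved LoRA adapter, to verify what the fine-tuned model can do.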

1. Manually Testing the Fine-Tuned Model

  The code is as follows:

python
# generate_label, ft_model, tokenizer and SYSTEM_PROMPT come from the earlier code
def predict_label_ft(user_text: str) -> str:
    return generate_label(ft_model, tokenizer, user_text, SYSTEM_PROMPT)

test_texts = [
    "Hey, just wanted to check if we're still on",
]
for text in test_texts:
    print(text, "=>", predict_label_ft(text))

  The output of the code above is:

2. Inference with the Local LoRA Adapter

  The code is as follows:

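python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# OUTPUT_DIR, LOCAL_MODEL_DIR, MODEL_DTYPE, SYSTEM_PROMPT and generate_label come from the earlier code
RUN_RELOAD_TEST = True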

if RUN_RELOAD_TEST:
    reload_tokenizer = AutoTokenizer.from_pretrained(
        OUTPUT_DIR,
        use_fast=True,
        trust_remote_code=True,
    )
    if reload_tokenizer.pad_token is None:
        reload_tokenizer.pad_token = reload_tokenizer.eos_token

    reload_base_model = AutoModelForCausalLM.from_pretrained(
        LOCAL_MODEL_DIR,
        torch_dtype=MODEL_DTYPE,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )

    reload_model = PeftModel.from_pretrained(
        reload_base_model,
        OUTPUT_DIR,
    )
    reload_model.eval()

    # Pass SYSTEM_PROMPT explicitly, as in the manual test above
    print(generate_label(reload_model, reload_tokenizer,
                         "You are a winner you have been specially selected to receive $1000 cash or a $2000 award..",
                         SYSTEM_PROMPT))

  The output of the code above is:

  The test results show that the model separates spam from legitimate messages accurately. More complex datasets or other models could be tried next.

  To keep the weight files, you can pack the files you need into a tar archive with the following command:

shell
# Replace FILES with the files or directories to pack, or use * for everything in the current directory
tar -czvf archive.tar.gz FILES
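  If you prefer to stay in Python (for example, inside the training notebook), the standard-library tarfile module can produce the same kind of gzip-compressed archive. This is a sketch: pack_dir and the lora_adapter directory are hypothetical names, and the demo packs a throwaway directory rather than real weights.

```python
import tarfile
import tempfile
from pathlib import Path


def pack_dir(src_dir, archive_path):
    """Create a gzip-compressed tar archive of src_dir, similar to `tar -czvf`."""
    src = Path(src_dir)
    with tarfile.open(str(archive_path), "w:gz") as tar:
        # add() recurses into directories; arcname keeps paths relative to src's name
        tar.add(str(src), arcname=src.name)


# Demo on a throwaway directory; in practice you might call pack_dir(OUTPUT_DIR, "archive.tar.gz")
with tempfile.TemporaryDirectory() as tmp:
    adapter_dir = Path(tmp) / "lora_adapter"  # hypothetical directory name
    adapter_dir.mkdir()
    (adapter_dir / "adapter_config.json").write_text("{}")
    archive = Path(tmp) / "archive.tar.gz"
    pack_dir(adapter_dir, archive)
    with tarfile.open(str(archive), "r:gz") as tar:
        members = tar.getnames()
print(members)
```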

IV. References

1. https://archive.ics.uci.edu/dataset/228/sms+spam+collection

2. https://github.com/datawhalechina/hello-rocm