Gemma 4 + LoRA 电商评论好差评分类

Gemma 4 + LoRA 电商评论好差评分类

  • Gemma 4 + LoRA 微调:电商评论好差评分类(二分类)
  • 从原 emotion 分类(6分类)改为电商评论情感分类(正面/负面)。
  • 数据集:DAMO_NLP/jd(京东商品评论,来自 ModelScope)
  • 模型:google/gemma-4-E4B-it(从 ModelScope 下载)

Cell 1 --- 安装依赖(在 Radeon Cloud / Colab 中运行)

bash 复制代码
!uv pip install -U vllm modelscope transformers accelerate datasets trl peft scikit-learn pandas tqdm torchvision --no-cache -i https://mirrors.cloud.tencent.com/pypi/simple/ --extra-index-url https://wheels.vllm.ai/rocm/

Cell 2 --- 导入依赖与全局配置

python 复制代码
import os
import glob
import re
import json
import random
import warnings

import numpy as np
import pandas as pd
import torch

from tqdm.auto import tqdm
from datasets import Dataset, DatasetDict, ClassLabel, load_dataset
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score

from modelscope import snapshot_download
from modelscope.hub.snapshot_download import dataset_snapshot_download

from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from peft import LoraConfig, PeftModel
from trl import SFTConfig, SFTTrainer

warnings.filterwarnings("ignore")

# ---- Configuration ----
# ModelScope identifiers for the base model and the JD product-review dataset,
# plus the directory that receives the LoRA adapter and evaluation artifacts.
MODELSCOPE_MODEL_ID = "google/gemma-4-E4B-it"
MODELSCOPE_DATASET_ID = "DAMO_NLP/jd"
OUTPUT_DIR = "./gemma4-it-jd-sentiment-lora"

# Upper bounds on how many examples each split (and each evaluation run) uses.
TRAIN_LIMIT = 4000
VALIDATION_LIMIT = 400
TEST_LIMIT = 400
EVAL_LIMIT = 400

SEED = 42

# Precision settings: weights are loaded in bfloat16 and training runs in bf16.
MODEL_DTYPE = torch.bfloat16
BF16 = True
FP16 = False

# ---- Task definition: binary sentiment (negative / positive) ----
# The list index is the integer label id: 0 -> negative, 1 -> positive.
LABEL_NAMES = ["negative", "positive"]
LABEL_NAMES_ZH = ["负面", "正面"]

SYSTEM_PROMPT = (
    "You are a Chinese e-commerce review sentiment classifier.\n"
    "Read the user's product review and answer with exactly one label.\n"
    "Only choose from: negative, positive.\n"
    "Return only the label and nothing else."
)

# Finds the first occurrence of either label name, case-insensitively.
LABEL_PATTERN = re.compile(r"(negative|positive)", re.IGNORECASE)

# Make sure the output and cache directories exist before anything is downloaded.
for _required_dir in (OUTPUT_DIR, "./models", "./datasets"):
    os.makedirs(_required_dir, exist_ok=True)

# Report the torch build and the visible accelerator, if any.
print("torch version:", torch.__version__)
print("torch.cuda.is_available():", torch.cuda.is_available())
print("torch.cuda.device_count():", torch.cuda.device_count())
if torch.cuda.is_available():
    print("current device:", torch.cuda.current_device())
    print("device name:", torch.cuda.get_device_name(0))

Cell 3 --- 固定随机种子

python 复制代码
def setup_seed(seed: int = 42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, torch, and calls transformers'
    ``set_seed``; when CUDA is available, all CUDA devices are seeded too.

    Args:
        seed: The seed value applied to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, set_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


setup_seed(SEED)

Cell 4 --- 从 ModelScope 下载模型

python 复制代码
# Fetch the base model snapshot from ModelScope into the local ./models cache.
print("Downloading model from ModelScope...")
print("ModelScope model id:", MODELSCOPE_MODEL_ID)

model_dir = snapshot_download(MODELSCOPE_MODEL_ID, cache_dir="./models")
# Later cells load the tokenizer and the weights from this local directory.
LOCAL_MODEL_DIR = model_dir
print("Downloaded model dir:", model_dir)

Cell 5 --- 从 ModelScope 下载并加载京东评论数据集

python 复制代码
# Download the dataset snapshot from ModelScope into the local ./datasets cache.
print("Downloading dataset from ModelScope...")
print("ModelScope dataset id:", MODELSCOPE_DATASET_ID)

dataset_dir = dataset_snapshot_download(MODELSCOPE_DATASET_ID, cache_dir="./datasets")
print("Downloaded dataset dir:", dataset_dir)

# The snapshot ships train.csv and dev.csv; there is no test file, so a test
# split is carved out of train further down.
# Columns: sentence (review text), label (0 = negative, 1 = positive).
train_csv = os.path.join(dataset_dir, "train.csv")
dev_csv = os.path.join(dataset_dir, "dev.csv")
raw_dataset = load_dataset("csv", data_files={"train": train_csv, "dev": dev_csv})

# Rename sentence -> text and drop the extra "dataset" column when it exists.
raw_dataset = raw_dataset.rename_column("sentence", "text")
if "dataset" in raw_dataset["train"].column_names:
    raw_dataset = raw_dataset.remove_columns(["dataset"])

# Debug aid: show the raw value and Python type of the first label per split.
print("=== 原始数据检查 ===")
for sn, split_ds in raw_dataset.items():
    ex0 = split_ds[0]
    first_label = ex0["label"]
    print(f"  {sn}[0]: label={first_label!r} (type={type(first_label).__name__}), "
          f"text={ex0['text'][:40]!r}")
print()

# 将 label 转为 int:用 filter + map,防御 None 值漏进来
def _safe_int_label(example):
    val = example["label"]
    # None / NaN / 空字符串 都视为无效,标为 -1 后续过滤
    if val is None or (isinstance(val, float) and (val != val)):
        return {"label": -1}
    return {"label": int(val)}

# Normalize labels in every split, then drop rows that _safe_int_label marked
# with the -1 sentinel. Iterate over a snapshot of the keys because the dict
# entries are reassigned inside the loop.
for split_name in tuple(raw_dataset):
    normalized_split = raw_dataset[split_name].map(_safe_int_label)
    before = len(normalized_split)
    raw_dataset[split_name] = normalized_split.filter(lambda row: row["label"] != -1)
    after = len(raw_dataset[split_name])
    if after != before:
        print(f"  ⚠ {split_name}: 过滤掉 {before - after} 条无效 label")

print("Raw dataset:", raw_dataset)
for _title, _split_key in (("Train size:", "train"), ("Dev size:", "dev")):
    print(_title, len(raw_dataset[_split_key]))

def maybe_limit(split, limit):
    """Shuffle ``split`` with the global ``SEED`` and keep at most ``limit`` rows.

    Args:
        split: An object exposing ``shuffle(seed=...)``, ``select(...)`` and
            ``len()``.
        limit: Maximum number of rows to keep, or ``None`` to keep them all.

    Returns:
        The shuffled split, truncated to ``limit`` rows when a limit is given.
    """
    shuffled = split.shuffle(seed=SEED)
    if limit is not None:
        keep = min(limit, len(shuffled))
        return shuffled.select(range(keep))
    return shuffled

# Build train / validation / test from the original train split only.
train_full = raw_dataset["train"]

# Two-stage split: first hold out 20% as test, then take 12.5% of the
# remaining 80% (0.8 * 0.125 = 0.1 of the total) as validation. That leaves
# 70% / 10% / 20% for train / validation / test before the limits are applied.
train_test = train_full.train_test_split(test_size=0.2, seed=SEED)
train_val = train_test["train"].train_test_split(test_size=0.125, seed=SEED)
split_plan = (
    ("train", train_val["train"], TRAIN_LIMIT),
    ("validation", train_val["test"], VALIDATION_LIMIT),
    ("test", train_test["test"], TEST_LIMIT),
)
dataset = DatasetDict(
    {name: maybe_limit(part, cap) for name, part, cap in split_plan}
)

# Labels are plain ints (0/1) at this point; LABEL_NAMES maps them to strings.
label_names = LABEL_NAMES
VALID_LABELS = {*label_names}
# "INVALID" is the extra bucket used for unparseable model outputs.
ALL_EVAL_LABELS = [*label_names, "INVALID"]

print(dataset)
print("label_names:", label_names)
train_label_counts = pd.Series(dataset["train"]["label"]).value_counts().to_dict()
print("Label distribution (train):", train_label_counts)
print("example:", dataset["train"][0])

Cell 6 --- 构造 prompt-completion 格式

python 复制代码
def to_prompt_completion(example):
    """Convert one ``{"text", "label"}`` row into prompt/completion chat messages.

    The prompt holds a system turn (``SYSTEM_PROMPT``) and a user turn that
    wraps the review text; the completion is a single assistant turn whose
    content is the label name looked up in ``LABEL_NAMES``.

    Args:
        example: A row with a ``"text"`` string and a ``"label"`` value that
            ``int()`` can convert to an index into ``LABEL_NAMES``.

    Returns:
        A dict with ``"prompt"`` and ``"completion"`` message lists.
    """
    text = example["text"]
    # int() guards against labels that arrive as non-int numeric types.
    gold_label = LABEL_NAMES[int(example["label"])]
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {
        "role": "user",
        "content": f"Classify the sentiment of this product review:\n\n{text}",
    }
    assistant_turn = {"role": "assistant", "content": gold_label}
    return {"prompt": [system_turn, user_turn], "completion": [assistant_turn]}

# Replace the raw columns with the prompt/completion fields in every split.
raw_columns = dataset["train"].column_names
sft_dataset = dataset.map(to_prompt_completion, remove_columns=raw_columns)

# Show the resulting structure and one converted training example.
print(sft_dataset)
print(sft_dataset["train"][0])

Cell 7 --- 加载 tokenizer 和基础模型

python 复制代码
# Model id used when the chat template has to be fetched separately.
TEMPLATE_SOURCE_MODEL_ID = "google/gemma-4-E4B-it"

print("Loading tokenizer from:", LOCAL_MODEL_DIR)

tokenizer_options = {"use_fast": True, "trust_remote_code": True}
tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR, **tokenizer_options)

# Fall back to the EOS token when the tokenizer defines no padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def _load_official_gemma_chat_template() -> str:
    """Return the text of ``chat_template.jinja`` for ``TEMPLATE_SOURCE_MODEL_ID``.

    First tries a ModelScope ``snapshot_download`` restricted to that single
    file and reads it from the returned directory. If that raises, the error
    is printed; if it raises or the file is simply absent, the template is
    fetched over HTTP from the ModelScope repo API instead.

    Returns:
        The chat template as a UTF-8 decoded string.

    Raises:
        Whatever ``urllib.request.urlopen`` raises when the HTTP fallback
        fails (it is not caught here).
    """
    try:
        snapshot_dir = snapshot_download(
            TEMPLATE_SOURCE_MODEL_ID,
            cache_dir="./models",
            allow_file_pattern=["chat_template.jinja"],
        )
        local_template = os.path.join(snapshot_dir, "chat_template.jinja")
        if os.path.exists(local_template):
            with open(local_template, "r", encoding="utf-8") as template_file:
                return template_file.read()
    except Exception as e:
        # Best effort: any failure on the snapshot path falls through to HTTP.
        print("snapshot_download(allow_file_pattern) failed, fallback to HTTP. err =", e)

    import urllib.request

    url = (
        "https://www.modelscope.cn/api/v1/models/"
        f"{TEMPLATE_SOURCE_MODEL_ID}/repo?Revision=master&FilePath=chat_template.jinja"
    )
    with urllib.request.urlopen(url, timeout=60) as resp:
        payload = resp.read()
    return payload.decode("utf-8")

# Only fetch a chat template when the tokenizer did not ship with one.
if getattr(tokenizer, "chat_template", None):
    print("tokenizer.chat_template already set, leaving as-is.")
else:
    print(f"Loading official chat_template.jinja from {TEMPLATE_SOURCE_MODEL_ID} ...")
    tokenizer.chat_template = _load_official_gemma_chat_template()

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the base weights in MODEL_DTYPE, then move them to the chosen device.
model_load_options = {
    "torch_dtype": MODEL_DTYPE,
    "low_cpu_mem_usage": True,
    "trust_remote_code": True,
}
base_model = AutoModelForCausalLM.from_pretrained(LOCAL_MODEL_DIR, **model_load_options)

base_model.to(device)
base_model.config.use_cache = False

# Copy the tokenizer's special-token ids onto both the model config and the
# generation config so they agree with the tokenizer.
# NOTE(review): this overwrites whatever eos_token_id the checkpoint's
# generation config shipped with by the tokenizer's single eos id. Confirm
# no additional stop token is lost by doing so.
for target_config in (base_model.config, base_model.generation_config):
    for token_attr in ("pad_token_id", "bos_token_id", "eos_token_id"):
        setattr(target_config, token_attr, getattr(tokenizer, token_attr))

print("Base model loaded.")
print("Base model device:", next(base_model.parameters()).device)

Cell 8 --- 推理辅助函数

python 复制代码
def extract_label(raw_text: str) -> str:
    """Extract a label from the model's raw text output.

    The text is stripped and lower-cased. If ``LABEL_PATTERN`` matches
    anywhere, the first matched label is returned. Otherwise the first
    whitespace-separated token, with surrounding punctuation removed, is
    returned so callers can see what the model actually said.

    Args:
        raw_text: Decoded model output.

    Returns:
        The matched label, the cleaned first token, or ``"INVALID"`` when the
        output contains no usable token (empty, whitespace only, or a first
        token made of punctuation only).
    """
    normalized = raw_text.strip().lower()
    match = LABEL_PATTERN.search(normalized)
    if match:
        return match.group(1)
    tokens = normalized.split()
    if not tokens:
        return "INVALID"
    candidate = tokens[0].strip(".,!?:;\"'()[]{}")
    # A punctuation-only token strips down to "", which is not a label; report
    # it the same way as empty output instead of returning an empty string.
    return candidate or "INVALID"

def generate_label(model, tokenizer, user_text: str,
                   system_prompt: str = SYSTEM_PROMPT,
                   max_new_tokens: int = 4) -> str:
    """Classify one review by generating a label with ``model``.

    Builds a system + user chat, renders it with the tokenizer's chat template
    (generation prompt appended), generates greedily (``do_sample=False``),
    decodes only the newly generated tokens, and passes them through
    ``extract_label``.

    Args:
        model: The model to generate with; inputs are moved to the device of
            its first parameter.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``decode``.
        user_text: The review text to classify.
        system_prompt: System message placed before the user turn.
        max_new_tokens: Generation budget for the answer.

    Returns:
        The value returned by ``extract_label`` for the decoded completion.
    """
    chat = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": f"Classify the sentiment of this product review:\n\n{user_text}",
        },
    ]
    target_device = next(model.parameters()).device
    encoded = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}
    # Remember the prompt length so only the generated tail is decoded below.
    prompt_len = encoded["input_ids"].shape[-1]

    generation_options = {
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    model.eval()
    with torch.no_grad():
        generated = model.generate(**encoded, **generation_options)

    completion_ids = generated[0][prompt_len:]
    raw_pred = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    return extract_label(raw_pred)

def predict_sentiment(text: str, model=None):
    """Predict the sentiment label of a single review.

    Args:
        text: The review text to classify.
        model: Model to use; when ``None`` the module-level ``base_model`` is
            used.

    Returns:
        The label string produced by ``generate_label``.
    """
    # Compare against None explicitly rather than relying on the truthiness of
    # a model object, which is not a reliable "was an argument passed" signal.
    chosen_model = base_model if model is None else model
    return generate_label(chosen_model, tokenizer, text)

# Quick smoke test on one clearly positive and one clearly negative review.
print("\n--- 微调前快速测试 ---")
smoke_reviews = ("这个商品质量很好,物流也很快", "太差了,用了两天就坏了")
for t in smoke_reviews:
    pred = predict_sentiment(t)
    print(f"  「{t}」→ {pred}")
print()

Cell 9 --- 评估函数

python 复制代码
def evaluate_model(model, tokenizer, split="test", limit=EVAL_LIMIT):
    """Run ``model`` over one split of ``dataset`` and score its predictions.

    Each example is classified with ``generate_label``. Any prediction that is
    not in ``VALID_LABELS`` is recorded as ``"INVALID"``, which can never equal
    a true label and therefore counts as an error in the accuracy.

    Args:
        model: Model to evaluate.
        tokenizer: Tokenizer passed through to ``generate_label``.
        split: Key of the split in the module-level ``dataset``.
        limit: Evaluate only the first ``limit`` examples; ``None`` for all.

    Returns:
        A ``(metrics, report, predictions)`` tuple: ``metrics`` holds accuracy,
        macro F1 over ``label_names``, the number of INVALID predictions and
        the number of evaluated examples; ``report`` is the dict form of
        ``classification_report``; ``predictions`` is a DataFrame with one row
        per example.
    """
    examples = dataset[split]
    if limit is not None:
        examples = examples.select(range(min(limit, len(examples))))

    model.eval()
    rows = []
    for ex in tqdm(examples, desc=f"Evaluating {split}", leave=False):
        true_label = label_names[int(ex["label"])]
        raw_pred_label = generate_label(model, tokenizer, ex["text"], SYSTEM_PROMPT)
        if raw_pred_label in VALID_LABELS:
            pred_label = raw_pred_label
        else:
            pred_label = "INVALID"
        rows.append({
            "text": ex["text"],
            "true_label": true_label,
            "pred_label": pred_label,
            "raw_pred_label": raw_pred_label,
            "correct": true_label == pred_label,
        })

    y_true = [row["true_label"] for row in rows]
    y_pred = [row["pred_label"] for row in rows]

    metrics = {
        "accuracy": accuracy_score(y_true, y_pred),
        "macro_f1": f1_score(y_true, y_pred, labels=label_names, average="macro", zero_division=0),
        "invalid_predictions": y_pred.count("INVALID"),
        "evaluated_examples": len(y_true),
    }
    report = classification_report(
        y_true, y_pred, labels=label_names, output_dict=True, zero_division=0,
    )
    return metrics, report, pd.DataFrame(rows)

def confusion_matrix_df(pred_df):
    """Build a labelled confusion matrix from a predictions DataFrame.

    Args:
        pred_df: DataFrame with ``"true_label"`` and ``"pred_label"`` columns.

    Returns:
        A DataFrame whose index and columns are both ``ALL_EVAL_LABELS``
        (the task labels plus ``"INVALID"``), filled with the counts returned
        by ``confusion_matrix``.
    """
    counts = confusion_matrix(
        pred_df["true_label"],
        pred_df["pred_label"],
        labels=ALL_EVAL_LABELS,
    )
    return pd.DataFrame(counts, index=ALL_EVAL_LABELS, columns=ALL_EVAL_LABELS)

Cell 10 --- LoRA 配置

python 复制代码
# LoRA adapter settings for causal-LM fine-tuning.
lora_config = LoraConfig(
    # What is adapted: a causal LM, targeting "all-linear" modules.
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    # Adapter capacity and scaling: rank 16 with alpha 32.
    r=16,
    lora_alpha=32,
    # Regularization; bias terms are left out of the adapter.
    lora_dropout=0.05,
    bias="none",
)

Cell 11 --- 训练参数

python 复制代码
# Supervised fine-tuning hyperparameters, grouped by concern.
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    # Reproducibility.
    seed=SEED,
    data_seed=SEED,
    # Batching: 4 per device x 4 accumulation steps = 16 examples per update.
    per_device_train_batch_size=4,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    dataloader_num_workers=2,
    # Optimizer and schedule.
    optim="adamw_torch",
    learning_rate=1e-4,
    weight_decay=0.01,
    lr_scheduler_type="linear",
    warmup_steps=50,
    num_train_epochs=1,
    # Logging, evaluation and checkpointing cadence (in optimizer steps).
    logging_steps=5,
    eval_strategy="steps",
    eval_steps=25,
    save_strategy="steps",
    save_steps=25,
    save_total_limit=2,
    # NOTE(review): load_best_model_at_end is not set in this config, so the
    # best-model metric below presumably only affects checkpoint tracking.
    # Confirm whether reloading the best checkpoint was intended.
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="none",
    # Precision and memory.
    bf16=BF16,
    fp16=FP16,
    tf32=False,
    gradient_checkpointing=True,
    # Sequence handling: loss is computed on the completion only.
    # NOTE(review): sequences are capped at 256 tokens. Verify that long
    # reviews are not truncated so far that the label completion is cut off.
    max_length=256,
    packing=False,
    completion_only_loss=True,
    remove_unused_columns=False,
)

Cell 12 --- 训练

python 复制代码
print("\n" + "="*60)
print("开始微调训练 (电商评论好差评分类)")
print("="*60 + "\n")

# If this cell is re-run with an already-wrapped model, strip the adapter
# first so the trainer starts from the plain base model again.
if isinstance(base_model, PeftModel):
    base_model = base_model.unload()
    base_model.config.use_cache = False

trainer = SFTTrainer(
    model=base_model,
    args=training_args,
    peft_config=lora_config,
    processing_class=tokenizer,
    train_dataset=sft_dataset["train"],
    eval_dataset=sft_dataset["validation"],
)

# Count trainable vs. total parameters in a single pass over the model.
trainable_params = 0
total_params = 0
for param in base_model.parameters():
    param_count = param.numel()
    total_params += param_count
    if param.requires_grad:
        trainable_params += param_count
print(f"Trainable params: {trainable_params:,} / {total_params:,} "
      f"({100 * trainable_params / total_params:.2f}%)")

train_result = trainer.train()

# Switch the trained model to inference mode and re-enable the KV cache.
trained_model = trainer.model
trained_model.eval()
trained_model.config.use_cache = True

print("\n训练完成!")
print("Training metrics:", train_result.metrics)

Cell 13 --- 保存模型

python 复制代码
# Persist the trained adapter and the tokenizer side by side in OUTPUT_DIR.
for saveable in (trainer.model, tokenizer):
    saveable.save_pretrained(OUTPUT_DIR)

# Store the training metrics next to them as pretty-printed UTF-8 JSON.
metrics_path = os.path.join(OUTPUT_DIR, "train_metrics.json")
with open(metrics_path, "w", encoding="utf-8") as f:
    f.write(json.dumps(train_result.metrics, ensure_ascii=False, indent=2))

print("Saved adapter and tokenizer to:", OUTPUT_DIR)

Cell 14 --- 微调前后对比评估

python 复制代码
print("\n" + "="*60)
print("微调前 → 微调后 效果对比")
print("="*60 + "\n")

lora_model = trainer.model

# ---- Baseline (pre-fine-tuning) evaluation ----
# Passing peft_config to SFTTrainer injects the LoRA layers into `base_model`
# in place, so evaluating `base_model` after training would silently measure
# the fine-tuned weights. Disable the adapter for the duration of the baseline
# run to get the real pre-fine-tuning numbers.
print("评估微调前 (base model) ...")
if isinstance(lora_model, PeftModel):
    with lora_model.disable_adapter():
        pre_metrics, pre_report, pre_preds = evaluate_model(
            lora_model, tokenizer, split="test", limit=EVAL_LIMIT)
else:
    # Cannot switch the adapter off through this object; make it visible that
    # the baseline below may not be a clean one.
    print("  ⚠ trainer.model is not a PeftModel; baseline may include adapter weights")
    pre_metrics, pre_report, pre_preds = evaluate_model(
        base_model, tokenizer, split="test", limit=EVAL_LIMIT)
print(f"  Accuracy:  {pre_metrics['accuracy']:.4f}")
print(f"  Macro F1:  {pre_metrics['macro_f1']:.4f}")
print(f"  Invalid:   {pre_metrics['invalid_predictions']}")
print()

# ---- Fine-tuned evaluation (adapter active) ----
print("评估微调后 (fine-tuned model) ...")
post_metrics, post_report, post_preds = evaluate_model(lora_model, tokenizer, split="test", limit=EVAL_LIMIT)
print(f"  Accuracy:  {post_metrics['accuracy']:.4f}")
print(f"  Macro F1:  {post_metrics['macro_f1']:.4f}")
print(f"  Invalid:   {post_metrics['invalid_predictions']}")
print()

# ---- Side-by-side summary ----
comparison_df = pd.DataFrame([
    {"metric": "accuracy",  "before": pre_metrics["accuracy"],  "after": post_metrics["accuracy"]},
    {"metric": "macro_f1",  "before": pre_metrics["macro_f1"],  "after": post_metrics["macro_f1"]},
    {"metric": "invalid_preds", "before": pre_metrics["invalid_predictions"], "after": post_metrics["invalid_predictions"]},
])
print("效果对比:")
print(comparison_df.to_string(index=False))
print()

# ---- Per-example comparison ----
# Both prediction frames come from the same split, in the same order and with
# the same limit, so align them by position. Joining on "text" would multiply
# rows whenever the same review text appears more than once in the test set.
merged_examples = (
    pre_preds[["text", "true_label", "pred_label"]]
    .rename(columns={"pred_label": "pred_before"})
    .reset_index(drop=True)
)
# Assigning a raw array raises if the two runs differ in length, which would
# indicate the runs are not comparable.
merged_examples["pred_after"] = post_preds["pred_label"].to_numpy()

changed_predictions = merged_examples[merged_examples["pred_before"] != merged_examples["pred_after"]]
print(f"预测发生变化的样本数: {len(changed_predictions)} / {len(merged_examples)}")
if len(changed_predictions) > 0:
    print("示例(微调前 → 微调后):")
    for _, row in changed_predictions.head(5).iterrows():
        before_icon = "✓" if row["pred_before"] == row["true_label"] else "✗"
        after_icon = "✓" if row["pred_after"] == row["true_label"] else "✗"
        print(f"  [{before_icon}→{after_icon}] {row['text'][:50]}...")
        print(f"          真实: {row['true_label']} | 微调前: {row['pred_before']} | 微调后: {row['pred_after']}")

Cell 15 --- 保存评估结果

python 复制代码
# Tables written without their index column.
flat_outputs = [
    (comparison_df, "sentiment_before_after_metrics.csv"),
    (merged_examples, "sentiment_prediction_examples.csv"),
    (changed_predictions, "sentiment_changed_predictions.csv"),
    (pre_preds, "pre_finetuning_predictions.csv"),
    (post_preds, "post_finetuning_predictions.csv"),
]
for frame, filename in flat_outputs:
    frame.to_csv(os.path.join(OUTPUT_DIR, filename), index=False)

# Tables whose index carries the row labels, so the index is kept.
indexed_outputs = [
    (pd.DataFrame(pre_report).transpose(), "pre_finetuning_classification_report.csv"),
    (pd.DataFrame(post_report).transpose(), "post_finetuning_classification_report.csv"),
    (confusion_matrix_df(pre_preds), "pre_finetuning_confusion_matrix.csv"),
    (confusion_matrix_df(post_preds), "post_finetuning_confusion_matrix.csv"),
]
for frame, filename in indexed_outputs:
    frame.to_csv(os.path.join(OUTPUT_DIR, filename))

print("\nSaved all outputs to:", OUTPUT_DIR)

# ---- Final summary ----
print("\n" + "="*60)
print("电商好差评分类微调完成!")
print(f"  数据集: {MODELSCOPE_DATASET_ID} (京东商品评论)")
print(f"  模型:   {MODELSCOPE_MODEL_ID} + LoRA")
print(f"  类别:   {LABEL_NAMES_ZH}")
print(f"  输出:   {OUTPUT_DIR}/")
print("="*60)