# Gemma 4 + LoRA 电商评论好差评分类
# - Gemma 4 + LoRA 微调:电商评论好差评分类(二分类)
# - 从原 emotion 分类(6分类)改为电商评论情感分类(正面/负面)。
# - 数据集:DAMO_NLP/jd(京东商品评论,来自 ModelScope)
# - 模型:google/gemma-4-E4B-it(从 ModelScope 下载)
# Cell 1 --- 安装依赖(在 Radeon Cloud / Colab 中运行)
!uv pip install -U vllm modelscope transformers accelerate datasets trl peft scikit-learn pandas tqdm torchvision --no-cache -i https://mirrors.cloud.tencent.com/pypi/simple/ --extra-index-url https://wheels.vllm.ai/rocm/
# Cell 2 --- 导入依赖与全局配置
import os
import glob
import re
import json
import random
import warnings
import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from datasets import Dataset, DatasetDict, ClassLabel, load_dataset
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from modelscope import snapshot_download
from modelscope.hub.snapshot_download import dataset_snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from peft import LoraConfig, PeftModel
from trl import SFTConfig, SFTTrainer
warnings.filterwarnings("ignore")

# ---- Run configuration ----
MODELSCOPE_MODEL_ID = "google/gemma-4-E4B-it"
MODELSCOPE_DATASET_ID = "DAMO_NLP/jd"  # JD.com product-review dataset on ModelScope
OUTPUT_DIR = "./gemma4-it-jd-sentiment-lora"

# Per-split sample caps (applied after shuffling in the data cell).
TRAIN_LIMIT = 4000
VALIDATION_LIMIT = 400
TEST_LIMIT = 400
EVAL_LIMIT = 400

SEED = 42

# Precision: weights are loaded as bfloat16 and training uses bf16 (not fp16).
MODEL_DTYPE = torch.bfloat16
BF16, FP16 = True, False

# ---- Task definition: binary e-commerce review sentiment ----
# Position i in LABEL_NAMES is the name for integer dataset label i (0 = negative, 1 = positive).
LABEL_NAMES = ["negative", "positive"]
LABEL_NAMES_ZH = ["负面", "正面"]
SYSTEM_PROMPT = "\n".join([
    "You are a Chinese e-commerce review sentiment classifier.",
    "Read the user's product review and answer with exactly one label.",
    f"Only choose from: {', '.join(LABEL_NAMES)}.",
    "Return only the label and nothing else.",
])
# Case-insensitive search for either label name anywhere in a model reply.
LABEL_PATTERN = re.compile("(" + "|".join(LABEL_NAMES) + ")", re.IGNORECASE)

for _needed_dir in (OUTPUT_DIR, "./models", "./datasets"):
    os.makedirs(_needed_dir, exist_ok=True)

# Environment report.
_cuda_ok = torch.cuda.is_available()
print("torch version:", torch.__version__)
print("torch.cuda.is_available():", _cuda_ok)
print("torch.cuda.device_count():", torch.cuda.device_count())
if _cuda_ok:
    print("current device:", torch.cuda.current_device())
    print("device name:", torch.cuda.get_device_name(0))
# Cell 3 --- 固定随机种子
def setup_seed(seed: int = 42):
    """Seed Python ``random``, NumPy, PyTorch and transformers with one value.

    CUDA devices are additionally seeded when CUDA is available.

    Args:
        seed: the value handed to every seeding call.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, set_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


setup_seed(SEED)
# Cell 4 --- 从 ModelScope 下载模型
print("Downloading model from ModelScope...")
print("ModelScope model id:", MODELSCOPE_MODEL_ID)
# Fetch the model snapshot into ./models; the returned path is used for all later loads.
LOCAL_MODEL_DIR = model_dir = snapshot_download(MODELSCOPE_MODEL_ID, cache_dir="./models")
print("Downloaded model dir:", model_dir)
# Cell 5 --- 从 ModelScope 下载并加载京东评论数据集
print("Downloading dataset from ModelScope...")
print("ModelScope dataset id:", MODELSCOPE_DATASET_ID)
dataset_dir = dataset_snapshot_download(MODELSCOPE_DATASET_ID, cache_dir="./datasets")
print("Downloaded dataset dir:", dataset_dir)

# Files: train.csv / dev.csv (no test file; a test split is carved out of train later).
# Columns: sentence (review text), label (0 = negative, 1 = positive).
csv_files = {split: os.path.join(dataset_dir, f"{split}.csv") for split in ("train", "dev")}
raw_dataset = load_dataset("csv", data_files=csv_files)

# Rename sentence -> text, then drop the extra "dataset" column when it exists.
raw_dataset = raw_dataset.rename_column("sentence", "text")
if "dataset" in raw_dataset["train"].column_names:
    raw_dataset = raw_dataset.remove_columns(["dataset"])

# Debug: show the raw label value/type and a short text preview of each split's first row.
print("=== 原始数据检查 ===")
for split_key, split_ds in raw_dataset.items():
    first_row = split_ds[0]
    first_label = first_row["label"]
    print(f" {split_key}[0]: label={first_label!r} (type={type(first_label).__name__}), "
          f"text={first_row['text'][:40]!r}")
print()
# 将 label 转为 int:用 filter + map,防御 None 值漏进来
def _safe_int_label(example):
val = example["label"]
# None / NaN / 空字符串 都视为无效,标为 -1 后续过滤
if val is None or (isinstance(val, float) and (val != val)):
return {"label": -1}
return {"label": int(val)}
# Coerce labels to int per split and drop rows whose label could not be used (-1).
for split_name in list(raw_dataset):
    converted = raw_dataset[split_name].map(_safe_int_label)
    kept = converted.filter(lambda row: row["label"] != -1)
    dropped = len(converted) - len(kept)
    if dropped:
        print(f" ⚠ {split_name}: 过滤掉 {dropped} 条无效 label")
    raw_dataset[split_name] = kept

print("Raw dataset:", raw_dataset)
print("Train size:", len(raw_dataset["train"]))
print("Dev size:", len(raw_dataset["dev"]))
def maybe_limit(split, limit):
    """Shuffle ``split`` with the global SEED and keep at most ``limit`` rows.

    Args:
        split: a ``datasets.Dataset``.
        limit: maximum number of rows to keep; ``None`` keeps every row (still shuffled).

    Returns:
        The shuffled (and possibly truncated) dataset.
    """
    shuffled = split.shuffle(seed=SEED)
    if limit is not None:
        return shuffled.select(range(min(limit, len(shuffled))))
    return shuffled
# 从 train 中分出 test(train/validation/test 三分)
# Carve train / validation / test out of train.csv (dev.csv is loaded above but not used here).
# Target proportions of train.csv: 80% train, 10% validation, 10% test.
# (The previous fractions, 0.2 then 0.125, actually produced 70/10/20.)
TEST_FRACTION = 0.1
VALIDATION_FRACTION = 0.1

train_full = raw_dataset["train"]
train_test = train_full.train_test_split(test_size=TEST_FRACTION, seed=SEED)
# Validation comes out of what remains after removing test: 0.1 / (1 - 0.1) = 1/9.
train_val = train_test["train"].train_test_split(
    test_size=VALIDATION_FRACTION / (1 - TEST_FRACTION), seed=SEED,
)

dataset = DatasetDict({
    "train": maybe_limit(train_val["train"], TRAIN_LIMIT),
    "validation": maybe_limit(train_val["test"], VALIDATION_LIMIT),
    "test": maybe_limit(train_test["test"], TEST_LIMIT),
})

# Labels are plain ints (0/1) now; every name lookup goes through LABEL_NAMES.
label_names = LABEL_NAMES
VALID_LABELS = set(label_names)
# Evaluation adds an extra "INVALID" bucket for unparseable model outputs.
ALL_EVAL_LABELS = label_names + ["INVALID"]

print(dataset)
print("label_names:", label_names)
print("Label distribution (train):", pd.Series(dataset["train"]["label"]).value_counts().to_dict())
print("example:", dataset["train"][0])
# Cell 6 --- 构造 prompt-completion 格式
def to_prompt_completion(example):
    """Turn one ``{"text", "label"}`` row into a conversational SFT record.

    The prompt holds the system instruction plus the user's review. The
    completion is a single assistant turn whose content is the English label
    name, ``LABEL_NAMES[label]``.
    """
    review = example["text"]
    answer = LABEL_NAMES[int(example["label"])]  # int() guards against non-int label dtypes
    prompt_turns = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Classify the sentiment of this product review:\n\n{review}"},
    ]
    completion_turns = [{"role": "assistant", "content": answer}]
    return {"prompt": prompt_turns, "completion": completion_turns}
# Convert every split to prompt/completion records and drop the original columns.
_raw_columns = dataset["train"].column_names
sft_dataset = dataset.map(to_prompt_completion, remove_columns=_raw_columns)

print(sft_dataset)
print(sft_dataset["train"][0])
# Cell 7 --- 加载 tokenizer 和基础模型
print("Loading tokenizer from:", LOCAL_MODEL_DIR)
tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_DIR, use_fast=True, trust_remote_code=True)

# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Repo the official chat template is fetched from when the local tokenizer lacks one.
TEMPLATE_SOURCE_MODEL_ID = "google/gemma-4-E4B-it"
def _load_official_gemma_chat_template() -> str:
    """Return the text of ``chat_template.jinja`` from TEMPLATE_SOURCE_MODEL_ID.

    First attempts a ModelScope snapshot download restricted to that one file
    and reads it from disk. If that raises (the error is printed) or the file
    is not present, it falls back to a direct HTTP GET against the ModelScope
    repo-file API with a 60 s timeout. Errors from the HTTP fallback propagate.
    """
    local_text = None
    try:
        template_dir = snapshot_download(
            TEMPLATE_SOURCE_MODEL_ID,
            cache_dir="./models",
            allow_file_pattern=["chat_template.jinja"],
        )
        template_path = os.path.join(template_dir, "chat_template.jinja")
        if os.path.exists(template_path):
            with open(template_path, "r", encoding="utf-8") as fh:
                local_text = fh.read()
    except Exception as err:
        # Deliberate best-effort: any download problem just routes to the HTTP fallback.
        print("snapshot_download(allow_file_pattern) failed, fallback to HTTP. err =", err)

    if local_text is not None:
        return local_text

    import urllib.request
    api_url = (
        "https://www.modelscope.cn/api/v1/models/"
        f"{TEMPLATE_SOURCE_MODEL_ID}/repo?Revision=master&FilePath=chat_template.jinja"
    )
    with urllib.request.urlopen(api_url, timeout=60) as response:
        payload = response.read()
    return payload.decode("utf-8")
# Only install the official template when the tokenizer shipped without one.
if getattr(tokenizer, "chat_template", None):
    print("tokenizer.chat_template already set, leaving as-is.")
else:
    print(f"Loading official chat_template.jinja from {TEMPLATE_SOURCE_MODEL_ID} ...")
    tokenizer.chat_template = _load_official_gemma_chat_template()

device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = AutoModelForCausalLM.from_pretrained(
    LOCAL_MODEL_DIR,
    torch_dtype=MODEL_DTYPE,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).to(device)

# KV cache off while training (gradient_checkpointing is enabled in SFTConfig);
# the training cell turns it back on afterwards.
base_model.config.use_cache = False
# Mirror the tokenizer's special-token ids into both the model and generation configs.
# NOTE(review): this replaces generation_config.eos_token_id with the tokenizer's single
# EOS id; if the checkpoint's generation config lists extra stop ids (e.g. an
# end-of-turn token), those are dropped — confirm this is intended.
for _cfg in (base_model.config, base_model.generation_config):
    _cfg.pad_token_id = tokenizer.pad_token_id
    _cfg.bos_token_id = tokenizer.bos_token_id
    _cfg.eos_token_id = tokenizer.eos_token_id

print("Base model loaded.")
print("Base model device:", next(base_model.parameters()).device)
# Cell 8 --- 推理辅助函数
def extract_label(raw_text: str) -> str:
    """Pull a label string out of raw model output.

    Returns:
        The first ``LABEL_PATTERN`` match in the lower-cased, stripped text. If
        there is none, the first whitespace-separated word with surrounding
        punctuation/brackets removed. ``"INVALID"`` if the text is blank.
    """
    cleaned = raw_text.strip().lower()
    hit = LABEL_PATTERN.search(cleaned)
    if hit is not None:
        return hit.group(1)
    words = cleaned.split()
    if words:
        return words[0].strip(".,!?:;\"'()[]{}")
    return "INVALID"
def generate_label(model, tokenizer, user_text: str,
                   system_prompt: str = SYSTEM_PROMPT,
                   max_new_tokens: int = 4) -> str:
    """Classify one review by greedy generation and parse the reply.

    Builds the same system/user chat used for SFT, applies the tokenizer's chat
    template with a generation prompt, and generates up to ``max_new_tokens``
    tokens without sampling. Only the newly generated tokens are decoded, and
    the result is passed to ``extract_label``.

    Args:
        model: causal LM to query; inputs are moved to its parameters' device.
        tokenizer: tokenizer providing the chat template and decoding.
        user_text: raw review text.
        system_prompt: system instruction for the chat.
        max_new_tokens: generation budget (labels are single words).

    Returns:
        The label string produced by ``extract_label``.
    """
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Classify the sentiment of this product review:\n\n{user_text}"},
    ]
    target_device = next(model.parameters()).device
    encoded = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}
    prompt_len = encoded["input_ids"].shape[-1]

    model.eval()
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Slice off the echoed prompt so only the model's answer is decoded.
    completion = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
    return extract_label(completion.strip())
def predict_sentiment(text: str, model=None):
    """Predict the sentiment label for one review.

    Args:
        text: review text.
        model: model to query; when ``None`` the module-level ``base_model`` is used.

    Returns:
        The label string produced by ``generate_label``.
    """
    # Explicit None check: ``model or base_model`` depended on the module's truthiness,
    # and container modules that define __len__ (e.g. an empty nn.Sequential or
    # nn.ModuleList) are falsy, so a caller's model could be silently replaced.
    if model is None:
        model = base_model
    return generate_label(model, tokenizer, text)
# 快速测试
# Quick sanity check on two hand-written reviews (one positive, one negative) before training.
print("\n--- 微调前快速测试 ---")
_smoke_reviews = ("这个商品质量很好,物流也很快", "太差了,用了两天就坏了")
for review in _smoke_reviews:
    print(f" 「{review}」→ {predict_sentiment(review)}")
print()
# Cell 9 --- 评估函数
def evaluate_model(model, tokenizer, split="test", limit=EVAL_LIMIT):
    """Score greedy label generation on one split of the module-level ``dataset``.

    Predictions outside ``VALID_LABELS`` are recorded as ``"INVALID"`` and count
    as misses.

    Args:
        model: model passed to ``generate_label``.
        tokenizer: tokenizer passed to ``generate_label``.
        split: key into ``dataset``.
        limit: evaluate only the first ``limit`` rows; ``None`` for the whole split.

    Returns:
        ``(metrics, report, predictions)`` where ``metrics`` has keys accuracy,
        macro_f1, invalid_predictions and evaluated_examples; ``report`` is the
        sklearn ``classification_report`` dict; and ``predictions`` is a
        per-example DataFrame.
    """
    source = dataset[split]
    if limit is not None:
        source = source.select(range(min(limit, len(source))))

    model.eval()
    records = []
    for example in tqdm(source, desc=f"Evaluating {split}", leave=False):
        gold = label_names[int(example["label"])]
        raw = generate_label(model, tokenizer, example["text"], SYSTEM_PROMPT)
        predicted = raw if raw in VALID_LABELS else "INVALID"
        records.append({
            "text": example["text"],
            "true_label": gold,
            "pred_label": predicted,
            "raw_pred_label": raw,
            "correct": gold == predicted,
        })

    y_true = [rec["true_label"] for rec in records]
    y_pred = [rec["pred_label"] for rec in records]

    metrics = {
        "accuracy": accuracy_score(y_true, y_pred),
        # Macro-F1 over the real classes only; INVALID predictions still hurt recall.
        "macro_f1": f1_score(y_true, y_pred, labels=label_names, average="macro", zero_division=0),
        "invalid_predictions": y_pred.count("INVALID"),
        "evaluated_examples": len(y_true),
    }
    report = classification_report(
        y_true, y_pred, labels=label_names, output_dict=True, zero_division=0,
    )
    return metrics, report, pd.DataFrame(records)
def confusion_matrix_df(pred_df):
    """Return a labelled confusion matrix (rows = true, columns = predicted).

    Both axes use ``ALL_EVAL_LABELS``, i.e. the real classes plus ``"INVALID"``.

    Args:
        pred_df: DataFrame with ``true_label`` and ``pred_label`` columns.
    """
    axis_labels = ALL_EVAL_LABELS
    counts = confusion_matrix(pred_df["true_label"], pred_df["pred_label"], labels=axis_labels)
    return pd.DataFrame(counts, index=axis_labels, columns=axis_labels)
# Cell 10 --- LoRA 配置
# LoRA adapters on every linear layer: rank 16, alpha 32, dropout 0.05, no bias training.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)
# Cell 11 --- 训练参数
training_args = SFTConfig(
    # ---- Output & reproducibility ----
    output_dir=OUTPUT_DIR,
    report_to="none",
    seed=SEED,
    data_seed=SEED,
    # ---- Batching (per device: 4 samples x 4 accumulation steps per optimizer step) ----
    per_device_train_batch_size=4,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    dataloader_num_workers=2,
    # ---- Optimizer & schedule ----
    optim="adamw_torch",
    learning_rate=1e-4,
    weight_decay=0.01,
    lr_scheduler_type="linear",
    warmup_steps=50,
    num_train_epochs=1,
    # ---- Logging every 5 steps; evaluation and checkpointing every 25 steps ----
    logging_steps=5,
    eval_strategy="steps",
    eval_steps=25,
    save_strategy="steps",
    save_steps=25,
    save_total_limit=2,
    # NOTE(review): load_best_model_at_end is not set, so training presumably ends on
    # the last step's weights and these two only affect best-checkpoint tracking.
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # ---- Precision & memory ----
    bf16=BF16,
    fp16=FP16,
    tf32=False,
    gradient_checkpointing=True,
    # ---- SFT data handling ----
    # NOTE(review): sequences are capped at 256 tokens; confirm long reviews are not
    # truncated into losing the assistant label, which is the only supervised part
    # when completion_only_loss=True.
    max_length=256,
    packing=False,
    completion_only_loss=True,
    remove_unused_columns=False,
)
# Cell 12 --- 训练
print("\n" + "="*60)
print("开始微调训练 (电商评论好差评分类)")
print("="*60 + "\n")

# Re-run safety. SFTTrainer(peft_config=...) wraps base_model via PEFT, which injects
# the LoRA layers into base_model's modules in place; the PeftModel wrapper lives on
# trainer.model, so base_model itself is never a PeftModel and the old isinstance
# check alone could not fire. If a previous run of this cell wrapped *this*
# base_model, unload that adapter first so a second run does not stack adapters.
_prev_trainer = globals().get("trainer")
if isinstance(base_model, PeftModel):
    base_model = base_model.unload()
elif (
    _prev_trainer is not None
    and isinstance(getattr(_prev_trainer, "model", None), PeftModel)
    and _prev_trainer.model.get_base_model() is base_model
):
    base_model = _prev_trainer.model.unload()
base_model.config.use_cache = False

trainer = SFTTrainer(
    model=base_model,
    train_dataset=sft_dataset["train"],
    eval_dataset=sft_dataset["validation"],
    peft_config=lora_config,
    args=training_args,
    processing_class=tokenizer,
)

# Count parameters on the model the trainer actually optimizes (the PEFT-wrapped one).
trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in trainer.model.parameters())
print(f"Trainable params: {trainable_params:,} / {total_params:,} "
      f"({100 * trainable_params / total_params:.2f}%)")

train_result = trainer.train()

# Switch to inference mode and re-enable the KV cache for generation-based evaluation.
trainer.model.eval()
trainer.model.config.use_cache = True
print("\n训练完成!")
print("Training metrics:", train_result.metrics)
# Cell 13 --- 保存模型
# Persist the trained model/adapter and the tokenizer, then the training metrics as JSON.
for _artifact in (trainer.model, tokenizer):
    _artifact.save_pretrained(OUTPUT_DIR)

_metrics_path = os.path.join(OUTPUT_DIR, "train_metrics.json")
with open(_metrics_path, "w", encoding="utf-8") as fh:
    json.dump(train_result.metrics, fh, ensure_ascii=False, indent=2)
print("Saved adapter and tokenizer to:", OUTPUT_DIR)
# Cell 14 --- 微调前后对比评估
print("\n" + "="*60)
print("微调前 → 微调后 效果对比")
print("="*60 + "\n")

lora_model = trainer.model

# ---- Before fine-tuning ----
# SFTTrainer's PEFT wrapping injected the trained LoRA layers into base_model in
# place, so evaluating base_model at this point would include the adapter and
# report post-training numbers as the baseline. Disable the adapter temporarily
# to measure the real pre-fine-tuning model.
print("评估微调前 (base model) ...")
if isinstance(lora_model, PeftModel):
    with lora_model.disable_adapter():
        pre_metrics, pre_report, pre_preds = evaluate_model(
            lora_model, tokenizer, split="test", limit=EVAL_LIMIT)
else:
    # NOTE(review): without PEFT, base_model *is* the fine-tuned model here, so this
    # "before" result is not a true baseline.
    pre_metrics, pre_report, pre_preds = evaluate_model(
        base_model, tokenizer, split="test", limit=EVAL_LIMIT)
print(f" Accuracy: {pre_metrics['accuracy']:.4f}")
print(f" Macro F1: {pre_metrics['macro_f1']:.4f}")
print(f" Invalid: {pre_metrics['invalid_predictions']}")
print()

# ---- After fine-tuning (adapter enabled) ----
print("评估微调后 (fine-tuned model) ...")
post_metrics, post_report, post_preds = evaluate_model(lora_model, tokenizer, split="test", limit=EVAL_LIMIT)
print(f" Accuracy: {post_metrics['accuracy']:.4f}")
print(f" Macro F1: {post_metrics['macro_f1']:.4f}")
print(f" Invalid: {post_metrics['invalid_predictions']}")
print()

# ---- Summary table ----
comparison_df = pd.DataFrame([
    {"metric": "accuracy", "before": pre_metrics["accuracy"], "after": post_metrics["accuracy"]},
    {"metric": "macro_f1", "before": pre_metrics["macro_f1"], "after": post_metrics["macro_f1"]},
    {"metric": "invalid_preds", "before": pre_metrics["invalid_predictions"], "after": post_metrics["invalid_predictions"]},
])
print("效果对比:")
print(comparison_df.to_string(index=False))
print()

# ---- Per-example comparison ----
# Both evaluations walk the same split in the same order, so align rows by position.
# Merging on "text" multiplied rows whenever the same review appeared more than once.
if not pre_preds["text"].equals(post_preds["text"]):
    raise RuntimeError("pre/post prediction rows are not aligned; cannot compare per example")
merged_examples = pd.DataFrame({
    "text": pre_preds["text"],
    "true_label": pre_preds["true_label"],
    "pred_before": pre_preds["pred_label"],
    "pred_after": post_preds["pred_label"],
})
changed_predictions = merged_examples[merged_examples["pred_before"] != merged_examples["pred_after"]]
print(f"预测发生变化的样本数: {len(changed_predictions)} / {len(merged_examples)}")
if len(changed_predictions) > 0:
    print("示例(微调前 → 微调后):")
    for _, row in changed_predictions.head(5).iterrows():
        before_icon = "✓" if row["pred_before"] == row["true_label"] else "✗"
        after_icon = "✓" if row["pred_after"] == row["true_label"] else "✗"
        print(f" [{before_icon}→{after_icon}] {row['text'][:50]}...")
        print(f" 真实: {row['true_label']} | 微调前: {row['pred_before']} | 微调后: {row['pred_after']}")
# Cell 15 --- 保存评估结果
# Row-level tables: written without the DataFrame index.
_flat_tables = {
    "sentiment_before_after_metrics.csv": comparison_df,
    "sentiment_prediction_examples.csv": merged_examples,
    "sentiment_changed_predictions.csv": changed_predictions,
    "pre_finetuning_predictions.csv": pre_preds,
    "post_finetuning_predictions.csv": post_preds,
}
for _fname, _frame in _flat_tables.items():
    _frame.to_csv(os.path.join(OUTPUT_DIR, _fname), index=False)

# Reports and confusion matrices: the index carries the class / true-label names, so keep it.
_indexed_tables = {
    "pre_finetuning_classification_report.csv": pd.DataFrame(pre_report).transpose(),
    "post_finetuning_classification_report.csv": pd.DataFrame(post_report).transpose(),
    "pre_finetuning_confusion_matrix.csv": confusion_matrix_df(pre_preds),
    "post_finetuning_confusion_matrix.csv": confusion_matrix_df(post_preds),
}
for _fname, _frame in _indexed_tables.items():
    _frame.to_csv(os.path.join(OUTPUT_DIR, _fname))

print("\nSaved all outputs to:", OUTPUT_DIR)

# ---- Final summary ----
_rule = "="*60
print("\n" + _rule)
print("电商好差评分类微调完成!")
print(f" 数据集: {MODELSCOPE_DATASET_ID} (京东商品评论)")
print(f" 模型: {MODELSCOPE_MODEL_ID} + LoRA")
print(f" 类别: {LABEL_NAMES_ZH}")
print(f" 输出: {OUTPUT_DIR}/")
print(_rule)