在Mac M1/M2上使用Hugging Face Transformers进行中文文本分类(完整指南)

在Mac M1/M2上使用Hugging Face Transformers进行中文文本分类(完整指南)

前言

随着Apple Silicon芯片(M1/M2)的普及,越来越多的开发者希望在Mac上运行深度学习任务。本文将详细介绍如何在Mac M1/M2设备上使用Hugging Face Transformers库进行中文文本分类任务,包括环境配置、数据处理、模型训练和性能优化等完整流程。

环境准备

1. 硬件和系统要求

设备 :Apple M1/M2系列芯片的Mac

系统 :macOS 12.3 (Monterey)或更高版本

Python:3.8或更高版本

2. 安装必要的库

bash 复制代码
# 创建虚拟环境
python -m venv .venv
source .venv/bin/activate

# 安装支持MPS的PyTorch
pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu

# 安装其他依赖
pip install transformers datasets evaluate

数据处理

1. 加载和预处理数据集

我们使用中文情感分析数据集ChnSentiCorp:

python 复制代码
from datasets import load_from_disk
from transformers import AutoTokenizer

# 加载数据集
dataset = load_from_disk('./data/ChnSentiCorp')

# 缩小数据集规模
dataset['train'] = dataset['train'].shuffle().select(range(1500))
dataset['test'] = dataset['test'].shuffle().select(range(100))

# 初始化tokenizer
tokenizer = AutoTokenizer.from_pretrained('hfl/rbt3')

# 编码函数
def encode_data(data):
    return tokenizer.batch_encode_plus(data['text'], truncation=True)

# 应用编码
dataset = dataset.map(encode_data, batched=True, batch_size=1000, num_proc=4, remove_columns=['text'])

# 过滤过长的句子
dataset = dataset.filter(lambda x: len(x['input_ids']) <= 512, batched=True)

2. 数据格式转换

python 复制代码
from transformers import DataCollatorWithPadding

# 数据整理器
data_collator = DataCollatorWithPadding(tokenizer)

模型加载与配置

1. 加载预训练模型

python 复制代码
from transformers import AutoModelForSequenceClassification
import torch

# 检查MPS是否可用
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# 加载模型
model = AutoModelForSequenceClassification.from_pretrained('./model/rbt3', num_labels=2)
model = model.to(device)

2. 自定义训练器(适配MPS)

python 复制代码
from transformers import Trainer

class MPSReadyTrainer(Trainer):
    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        inputs = {k: v.to('mps') for k, v in inputs.items()}
        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)

训练配置

1. 设置训练参数

python 复制代码
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./out/rbt3',
    evaluation_strategy="steps",
    eval_steps=30,
    save_strategy='steps',
    save_steps=30,
    learning_rate=5e-5,
    per_device_train_batch_size=16,  # M1/M2建议较小batch size
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    logging_dir='./log/rbt3',
    no_cuda=True,
    use_mps_device=True,
    fp16=False  # MPS暂不支持混合精度
)

2. 定义评估指标

python 复制代码
from evaluate import load

def compute_metrics(eval_pred):
    metric = load('accuracy')
    logits, labels = eval_pred
    
    if isinstance(logits, torch.Tensor):
        predictions = logits.argmax(dim=-1)
    else:
        predictions = torch.from_numpy(logits).argmax(dim=-1)
    
    return metric.compute(predictions=predictions, references=labels)

训练与评估

1. 初始化训练器

python 复制代码
trainer = MPSReadyTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    compute_metrics=compute_metrics,
    data_collator=data_collator
)

2. 开始训练

python 复制代码
print("========== 开始训练 ==========")
trainer.train()

print("========== 最终评估 ==========")
trainer.evaluate()

性能优化技巧

  1. 调整batch size:M1/M2芯片建议使用8-32的batch size
  2. 禁用pin_memory :在TrainingArguments中设置dataloader_pin_memory=False
  3. 减少数据加载线程 :设置num_proc=2或更低
  4. 简化模型:使用更小的预训练模型如'rbt3'而非'bert-base'

常见问题解决

  1. MPS不可用错误

    • 确保安装了正确版本的PyTorch

    • 检查macOS版本≥12.3

    • 运行python -c "import torch; print(torch.backends.mps.is_available())"确认

  2. 内存不足错误

    • 减小batch size

    • 缩短序列长度(max_length=256)

  3. 数据类型不匹配

    • 确保所有张量都通过.to(device)转移到MPS

结语

本文详细介绍了在Apple Silicon Mac上使用Hugging Face Transformers进行中文文本分类的完整流程。通过合理配置和优化,可以在Mac设备上高效地进行NLP模型训练。希望这篇指南能帮助开发者充分利用M1/M2芯片的性能优势。

完整代码 已上传GitHub:项目链接
问题讨论欢迎在评论区留言

相关推荐
FrankHuang8884 小时前
使用高斯朴素贝叶斯算法对鸢尾花数据集进行分类
算法·机器学习·ai·分类
自由鬼6 小时前
数据分析图表类型及其应用场景
信息可视化·数据挖掘·数据分析
Hello.Reader7 小时前
Git 安装全攻略Linux、macOS、Windows 与源码编译
linux·git·macos
Hope Fancy7 小时前
macOS 连接 Docker 运行 postgres,使用navicat添加并关联数据库
macos·docker·postgresql
John Song7 小时前
macOS 上使用 Homebrew 安装redis-cli
数据库·redis·macos
yanjiee7 小时前
编译一个Mac M系列可以用的yuview
macos
数据知道7 小时前
Mac电脑上本地安装 redis并配置开启自启完整流程
数据库·redis·macos
狂小虎11 小时前
01 Deep learning神经网络的编程基础 二分类--吴恩达
深度学习·神经网络·分类
deephub11 小时前
让回归模型不再被异常值“带跑偏“,MSE和Cauchy损失函数在噪声数据环境下的实战对比
人工智能·机器学习·数据挖掘·回归·异常值
电手12 小时前
Win10停更,Win11不好用?现在Mac电脑比Win11电脑更便宜
windows·macos·电脑·mac