利用 PyTorch Lightning 搭建一个文本分类模型

引言

在这篇博文^[1]^中,将逐步介绍如何使用 PyTorch Lightning 来构建和部署一个基础的文本分类模型。该项目借助了 PyTorch 生态中的多个强大工具,例如 torch、pytorch_lightning 以及 Hugging Face 提供的 transformers,从而构建了一个强大且可扩展的机器学习流程。

代码库包含四个核心的 Python 脚本:

  • data.py:负责数据的加载和预处理工作。
  • model.py:构建模型的结构。
  • train.py:包含了训练循环和训练的配置。
  • inference.py:支持使用训练好的模型进行推断。

下面详细解析每个部分,以便理解它们是如何协同作用,以实现文本分类的高效工作流程。

1. 数据加载与预处理

data.py 文件中,DataModule 类被设计用来处理数据加载和预处理的所有环节。它利用了 PyTorch Lightning 的 LightningDataModule,这有助于保持数据处理任务的模块化和可复用性。

class DataModule(pl.LightningDataModule):
    def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", batch_size=32):
        super().__init__()
        self.batch_size = batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

这个类在初始化时需要指定模型名称和批量大小,并从 Hugging Face 的 Transformers 库加载一个分词器。prepare_data() 函数会从 GLUE 基准测试套件中下载 CoLA 数据集,这个数据集经常用来评估自然语言理解(NLU)模型的性能。

setup() 函数负责对文本数据进行分词处理,并创建用于训练和验证的 PyTorch DataLoader 对象:

def setup(self, stage=None):
    if stage == "fit" or stage is None:
        self.train_data = self.train_data.map(self.tokenize_data, batched=True)
        self.train_data.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
        self.val_data = self.val_data.map(self.tokenize_data, batched=True)
        self.val_data.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

2. 模型架构

model.py 文件中定义的 ColaModel 类继承自 PyTorch Lightning 的 LightningModule。该模型采用 BERT(一种双向编码器表示,源自 Transformers)的简化版本作为文本表示的核心模型。

class ColaModel(pl.LightningModule):
    def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=1e-2):
        super(ColaModel, self).__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.W = nn.Linear(self.bert.config.hidden_size, 2)

模型在前向传播过程中提取 BERT 的最终隐藏状态,并通过一个线性层来生成用于二分类的对数几率(logits):

def forward(self, input_ids, attention_mask):
    outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
    h_cls = outputs.last_hidden_state[:, 0]
    logits = self.W(h_cls)
    return logits

另外,training_step()validation_step() 函数分别负责处理训练和验证的逻辑,并记录诸如损失和准确率等关键指标。

3. Training Loop

train.py 脚本利用 PyTorch Lightning 的 Trainer 类来控制训练过程。它还包含了模型检查点和提前停止的回调机制,以防止模型过拟合。

checkpoint_callback = ModelCheckpoint(dirpath="./models", monitor="val_loss", mode="min")
early_stopping_callback = EarlyStopping(monitor="val_loss", patience=3, verbose=True, mode="min")

训练过程设定了最大周期数,并在可能的情况下利用 GPU 进行加速:

trainer = pl.Trainer(
    default_root_dir="logs",
    gpus=(1 if torch.cuda.is_available() else 0),
    max_epochs=5,
    fast_dev_run=False,
    logger=pl.loggers.TensorBoardLogger("logs/", name="cola", version=1),
    callbacks=[checkpoint_callback, early_stopping_callback],
)
trainer.fit(cola_model, cola_data)

这样的配置不仅让训练变得更加简便,还保证了模型能够定期保存并对其性能进行监控。

4. 推理

训练结束后,将利用模型来进行预测。inference.py 脚本中定义了一个名为 ColaPredictor 的类,该类负责加载经过训练的模型检查点,并提供了一个用于生成预测的方法:

class ColaPredictor:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = ColaModel.load_from_checkpoint(model_path)
        self.model.eval()
        self.model.freeze()

Predict() 方法接受文本输入,使用分词器对其进行处理,并返回模型的预测:

def predict(self, text):
    inference_sample = {"sentence": text}
    processed = self.processor.tokenize_data(inference_sample)
    logits = self.model(
        torch.tensor([processed["input_ids"]]),
        torch.tensor([processed["attention_mask"]]),
    )
    scores = self.softmax(logits[0]).tolist()
    predictions = [{"label": label, "score": score} for score, label in zip(scores, self.labels)]
    return predictions

总结

本项目展示了如何采用 PyTorch Lightning 进行构建、训练和部署文本分类模型的系统化方法。尽情地尝试代码,调整参数,并试用不同的数据集或模型吧。编程愉快!
Reference [1]

Source: https://medium.com/@mzeynali01/building-a-text-classification-model-with-pytorch-lightning-a-deep-dive-7a262cb5784b

本文由mdnice多平台发布

相关推荐
Guofu_Liao9 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
sp_fyf_202416 小时前
【大语言模型】ACL2024论文-19 SportsMetrics: 融合文本和数值数据以理解大型语言模型中的信息融合
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理
思通数科多模态大模型18 小时前
10大核心应用场景,解锁AI检测系统的智能安全之道
人工智能·深度学习·安全·目标检测·计算机视觉·自然语言处理·数据挖掘
龙的爹233318 小时前
论文翻译 | RECITATION-AUGMENTED LANGUAGE MODELS
人工智能·语言模型·自然语言处理·prompt·gpu算力
sp_fyf_202418 小时前
【大语言模型】ACL2024论文-18 MINPROMPT:基于图的最小提示数据增强用于少样本问答
人工智能·深度学习·神经网络·目标检测·机器学习·语言模型·自然语言处理
爱喝白开水a19 小时前
Sentence-BERT实现文本匹配【分类目标函数】
人工智能·深度学习·机器学习·自然语言处理·分类·bert·大模型微调
Guofu_Liao21 小时前
大语言模型中Softmax函数的计算过程及其参数描述
人工智能·语言模型·自然语言处理
曼城周杰伦1 天前
自然语言处理:第六十二章 KAG 超越GraphRAG的图谱框架
人工智能·pytorch·神经网络·自然语言处理·chatgpt·nlp·gpt-3
Donvink1 天前
多模态大语言模型——《动手学大模型》实践教程第六章
人工智能·深度学习·语言模型·自然语言处理·llama
我爱学Python!1 天前
解决复杂查询难题:如何通过 Self-querying Prompting 提高 RAG 系统效率?
人工智能·程序人生·自然语言处理·大模型·llm·大语言模型·rag