在Google Cloud上使用PyTorch:如何在Vertex AI上训练和调优PyTorch模型

本文将介绍如何在Google Cloud的Vertex AI平台上使用PyTorch进行模型训练和调优,重点是情感分析任务。我们将通过示例代码和详细步骤,帮助大家更好地理解这一过程。

什么是Vertex AI?

Vertex AI是Google Cloud提供的一个端到端机器学习平台,旨在简化机器学习项目的构建和管理。它整合了Google Cloud的现有机器学习工具,支持从数据准备到模型部署的整个工作流。

使用案例和数据集

我们将使用Hugging Face的预训练BERT模型来进行IMDB情感分类。BERT(双向编码器表示模型)是一种在大量无标签文本上自我监督训练的Transformer模型,非常适合处理自然语言处理(NLP)任务。

数据集准备

我们将使用IMDB情感分类数据集进行实验。首先,我们需要在Jupyter Notebook中设置开发环境。

创建开发环境

  1. 在Google Cloud控制台中创建一个Notebook实例。
  2. 打开Notebook页面,点击"OPEN JUPYTERLAB"链接以启动JupyterLab。

在Vertex Training上训练PyTorch模型

模型细节

我们将微调Hugging Face的BERT模型,以分析IMDB电影评论的情感。以下是训练过程中的主要步骤:

  1. 数据预处理:对评论数据进行分词处理。
  2. 加载预训练模型:加载BERT模型并添加用于情感分析的序列分类头。
  3. 微调模型:对BERT模型进行微调。

示例代码

以下是用于数据预处理和微调BERT模型的代码片段:

python 复制代码
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

# 加载数据集
datasets = load_dataset('imdb')

# 实例化分词器
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True)

# 数据预处理函数
def preprocess_function(examples):
    return tokenizer(examples['text'], truncation=True)

# 预处理数据集
datasets = datasets.map(preprocess_function, batched=True)

# 加载预训练BERT模型
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

# 设置训练参数
training_args = TrainingArguments(
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    output_dir='./results'
)

# 创建Trainer对象
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=datasets['train'],
    eval_dataset=datasets['test']
)

# 开始训练
trainer.train()

在Vertex AI上运行训练作业

为了处理更大的数据集或更复杂的模型,我们可以使用Vertex Training服务。其优势包括:

  • 自动配置和释放计算资源。
  • 便于重用和移植性。
  • 支持大规模分布式训练。
  • 提供日志记录和监控功能。

提交自定义作业

  1. 将训练代码打包并上传到Cloud Storage。
  2. 使用Vertex SDK提交自定义作业。

以下是提交自定义作业的示例代码:

python 复制代码
from google.cloud import aiplatform

# 初始化Vertex SDK
aiplatform.init(project='YOUR_PROJECT_ID', staging_bucket='YOUR_BUCKET_NAME')

# 定义作业参数
job = aiplatform.CustomPythonPackageTrainingJob(
    display_name='finetuned-bert-classifier',
    python_package_gcs_uri='gs://YOUR_BUCKET_NAME/YOUR_PACKAGE_PATH',
    python_module_name='trainer.task',
    container_uri='us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.1-7:latest',
)

# 提交作业
model = job.run(
    replica_count=1,
    machine_type='n1-standard-8',
    accelerator_type='NVIDIA_TESLA_T4',
    accelerator_count=1,
)

超参数调优

超参数如学习率和权重衰减对模型性能有重大影响。我们可以通过Vertex AI自动化调优这些超参数。

示例代码

以下是如何设置超参数调优的示例代码:

python 复制代码
from google.cloud import aiplatform

# 定义超参数调优作业
hp_job = aiplatform.HyperparameterTuningJob(
    display_name='finetuned-bert-hptune',
    custom_job=custom_job,
    metric_spec={'accuracy': 'maximize'},
    parameter_spec={
        'learning-rate': hpt.DoubleParameterSpec(min=1e-6, max=0.001, scale='log'),
        'weight-decay': hpt.DiscreteParameterSpec(values=[0.0001, 0.001, 0.01], scale=None),
    },
    max_trial_count=5,
    parallel_trial_count=2,
)

model = hp_job.run(sync=False)

本地运行预测

完成训练后,可以在本地运行预测。以下是预测调用的示例代码:

python 复制代码
def predict(input_text, saved_model_path):
   tokenizer = AutoTokenizer.from_pretrained(saved_model_path)
   predict_input = tokenizer.encode(input_text, truncation=True, max_length=128, return_tensors='pt')
   loaded_model = AutoModelForSequenceClassification.from_pretrained(saved_model_path)
   output = loaded_model(predict_input)
   label_id = torch.argmax(output.logits, dim=1)
   return label_id.item()

清理环境

实验结束后,可以选择停止或删除Notebook实例,以避免产生额外费用。

总结

本文介绍了如何在Google Cloud的Vertex AI上使用PyTorch进行情感分析任务,包括创建开发环境、训练和调优模型、提交作业等步骤。接下来的文章将展示如何在Vertex Prediction服务上部署PyTorch模型,并利用Vertex Pipelines自动化机器学习工作流。

相关推荐
东方翱翔6 分钟前
第十六届蓝桥杯大赛软件赛省赛第二场 C/C++ 大学 A 组
算法·职场和发展·蓝桥杯
JiangJiang14 分钟前
🧠 面试官:受控组件都分不清?还敢说自己写过 React?
前端·react.js·面试
Jenlybein14 分钟前
[ Javascript 面试题 ]:提取对应的信息,并给其赋予一个颜色,保持幂等性
前端·javascript·面试
夜熵15 分钟前
JavaScript 中的 this
前端·面试
Synmbrf19 分钟前
说说平时开发注意事项
javascript·面试·代码规范
小智疯狂敲代码28 分钟前
Spring MVC-DispatcherServlet 的源码解析
java·面试
Blossom.11834 分钟前
量子计算在密码学中的应用与挑战:重塑信息安全的未来
人工智能·深度学习·物联网·算法·密码学·量子计算·量子安全
1白天的黑夜139 分钟前
贪心算法-860.柠檬水找零-力扣(LeetCode)
c++·算法·leetcode·贪心算法
搏博1 小时前
专家系统的基本概念解析——基于《人工智能原理与方法》的深度拓展
人工智能·python·深度学习·算法·机器学习·概率论
yzx9910131 小时前
决策树随机深林
人工智能·python·算法·决策树·机器学习