分布式执行引擎ray入门--(3)Ray Train

Ray Train中包含4个部分

  1. Training function: 包含训练模型逻辑的函数

  2. Worker: 用来跑训练的

  3. Scaling configuration: 配置

  4. Trainer: 协调以上三个部分

Ray Train+PyTorch

这一块比较建议直接去官网看diff,官网色块标注的比较清晰,非常直观。

复制代码
import os
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray.train.torch

def train_func(config):
    # Model, Loss, Optimizer
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    # model.to("cuda")  # This is done by `prepare_model`
    # [1] Prepare model.
    model = ray.train.torch.prepare_model(model)
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # Data
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
    # [2] Prepare dataloader.
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    # Training
    for epoch in range(10):
        for images, labels in train_loader:
            # This is done by `prepare_data_loader`!
            # images, labels = images.to("cuda"), labels.to("cuda")
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # [3] Report metrics and checkpoint.
        metrics = {"loss": loss.item(), "epoch": epoch}
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            torch.save(
                model.module.state_dict(),
                os.path.join(temp_checkpoint_dir, "model.pt")
            )
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)

# [4] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

# [5] Launch distributed training job.
trainer = ray.train.torch.TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    # [5a] If running in a multi-node cluster, this is where you
    # should configure the run's persistent storage that is accessible
    # across all worker nodes.
    # run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result = trainer.fit()

# [6] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:
    model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.load_state_dict(model_state_dict)

模型

ray.train.torch.prepare_model()

复制代码
model = ray.train.torch.prepare_model(model)
复制代码
相当于model.to(device_id or "cpu") +  DistributedDataParallel(model, device_ids=[device_id])

将model移动到合适的device上,同时实现分布式

数据

ray.train.torch.prepare_data_loader()

报告 checkpoints 和 metrics

复制代码
+import ray.train
+from ray.train import Checkpoint

 def train_func(config):

     ...
     torch.save(model.state_dict(), f"{checkpoint_dir}/model.pth"))
+    metrics = {"loss": loss.item()} # Training/validation metrics.
+    checkpoint = Checkpoint.from_directory(checkpoint_dir) # Build a Ray Train checkpoint from a directory
+    ray.train.report(metrics=metrics, checkpoint=checkpoint)

     ...
复制代码
data_loader = ray.train.torch.prepare_data_loader(data_loader)

将batches移动到合适的device上,同时实现分布式sampler

配置 scale 和 GPUs

复制代码
from ray.train import ScalingConfig
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)

配置持久化存储

多节点分布式训练时必须指定,本地路径会有问题。

复制代码
from ray.train import RunConfig

# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")

# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")

# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")

启动训练任务

复制代码
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()
相关推荐
Include everything2 小时前
Rust学习笔记(三)|所有权机制 Ownership
笔记·学习·rust
杜子不疼.2 小时前
《Python学习之文件操作:从入门到精通》
数据库·python·学习
★YUI★2 小时前
学习游戏制作记录(玩家掉落系统,删除物品功能和独特物品)8.17
java·学习·游戏·unity·c#
livemetee3 小时前
Flink2.0学习笔记:Flink服务器搭建与flink作业提交
大数据·笔记·学习·flink
INS_KF4 小时前
【C++知识杂记2】free和delete区别
c++·笔记·学习
Easocen4 小时前
Mybatis学习笔记(五)
笔记·学习·mybatis
lisuwen1165 小时前
AI三国杀:马斯克炮轰苹果“偏袒”OpenAI,Grok与ChatGPT的应用商店战争揭秘
人工智能·chatgpt
暮小暮5 小时前
从ChatGPT到智能助手:Agent智能体如何颠覆AI应用
人工智能·深度学习·神经网络·ai·语言模型·chatgpt
John_ToDebug5 小时前
大模型提示词(Prompt)终极指南:从原理到实战,让AI输出质量提升300%
人工智能·chatgpt·prompt