
This guide turns a single-GPU PyTorch training script into a distributed Ray Train job.

Install

pip install -U "ray[train]" torch torchvision
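
Ray initializes a local instance automatically when the trainer runs, so nothing else is needed on a single machine. If you already have a Ray cluster running, one way to attach to it before building the trainer is sketched below (the address value depends on your setup):

import ray

# Optional: connect to an existing Ray cluster instead of starting a local one.
# Without this call, Ray starts a local instance when trainer.fit() runs.
ray.init(address="auto")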

A minimal training loop

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import ray
import ray.train
import ray.train.torch as raytorch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker(config):
    transform = transforms.Compose([transforms.ToTensor()])
    train_ds = datasets.FashionMNIST("/tmp/data", download=True, transform=transform)
    loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
    # Add a DistributedSampler and move each batch to this worker's device.
    loader = raytorch.prepare_data_loader(loader)

    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    # Move the model to this worker's device and wrap it in DistributedDataParallel.
    model = raytorch.prepare_model(model)

    optim = torch.optim.SGD(model.parameters(), lr=config["lr"])
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(config["epochs"]):
        for x, y in loader:
            optim.zero_grad()
            logits = model(x)
            loss = loss_fn(logits, y)
            loss.backward()
            optim.step()
        # Report per-epoch metrics from each worker back to Ray Train.
        ray.train.report({"loss": loss.item(), "epoch": epoch})


trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": 5},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
print(result.metrics)

What changed vs. a single-GPU script

1. Wrap the model with prepare_model. It moves the model to the worker's GPU and wraps it in DistributedDataParallel.

2. Wrap the loader with prepare_data_loader. It adds a DistributedSampler so each worker sees a unique shard of the data.

3. Report metrics with ray.train.report. Reported metrics show up in the Ray dashboard and in the result object returned by trainer.fit().

4. Pick a ScalingConfig. num_workers=4, use_gpu=True runs four worker processes, each on its own GPU; a few other resource layouts are sketched after this list.
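
ScalingConfig is also where other resource layouts are declared. The values below are illustrative placeholders, not recommendations:

from ray.train import ScalingConfig

# CPU-only run, e.g. for debugging on a laptop: two worker processes, no GPUs.
debug_scaling = ScalingConfig(num_workers=2, use_gpu=False)

# Reserve extra CPUs per GPU worker, e.g. for data loading.
gpu_scaling = ScalingConfig(
    num_workers=4,
    use_gpu=True,
    resources_per_worker={"CPU": 4, "GPU": 1},
)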

Save checkpoints

import tempfile
from ray.train import Checkpoint

def train_loop_per_worker(config):
    ...
    for epoch in range(config["epochs"]):
        ...
        with tempfile.TemporaryDirectory() as tmpdir:
            # Save the weights; with multiple workers this is the DDP-wrapped
            # model, so the state dict keys carry a "module." prefix.
            torch.save(model.state_dict(), f"{tmpdir}/model.pt")
            ray.train.report(
                {"loss": loss.item(), "epoch": epoch},
                checkpoint=Checkpoint.from_directory(tmpdir),
            )
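
After training, the last reported checkpoint is available on the result object. A minimal sketch of restoring the weights on the driver, assuming the model.pt file saved above:

result = trainer.fit()

# Materialize the checkpoint as a local directory and load the saved state dict.
with result.checkpoint.as_directory() as ckpt_dir:
    state_dict = torch.load(f"{ckpt_dir}/model.pt", map_location="cpu")

# Drop the "module." prefix that DistributedDataParallel adds, if present.
state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}

# Rebuild the same architecture and load the weights.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model.load_state_dict(state_dict)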

Use a Ray Dataset

For datasets that don't fit on a single node, pass a Ray Dataset to the trainer and iterate over it inside the training loop:

ds = ray.data.read_parquet("s3://bucket/train/")

def train_loop_per_worker(config):
    # Each worker receives its own shard of the "train" dataset; no DataLoader needed.
    train_shard = ray.train.get_dataset_shard("train")
    for batch in train_shard.iter_torch_batches(batch_size=64):
        x, y = batch["features"], batch["label"]
        ...

trainer = TorchTrainer(
    train_loop_per_worker,
    datasets={"train": ds},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
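
The datasets argument accepts more than one named dataset. A sketch of adding a validation split, assuming a second Parquet directory with the same features and label columns (how non-train datasets are divided across workers is controlled by the trainer's data configuration):

train_ds = ray.data.read_parquet("s3://bucket/train/")
val_ds = ray.data.read_parquet("s3://bucket/val/")

def train_loop_per_worker(config):
    train_shard = ray.train.get_dataset_shard("train")
    val_shard = ray.train.get_dataset_shard("val")
    for batch in val_shard.iter_torch_batches(batch_size=64):
        ...  # evaluate on the validation batches

trainer = TorchTrainer(
    train_loop_per_worker,
    datasets={"train": train_ds, "val": val_ds},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)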

Next steps

Distributed PyTorch: DDP, FSDP, and tensor parallelism.

Lightning: the same scaling approach, using the Lightning trainer.