This guide turns a single-GPU PyTorch training script into a distributed Ray Train job.
Install
pip install -U "ray[train]" torch torchvision
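As an optional sanity check, you can confirm the install and see whether PyTorch detects any CUDA devices before scaling out:

import ray
import torch

print(ray.__version__)
print(torch.cuda.is_available())  # False is fine if you only plan to test on CPU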
A minimal training loop
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import ray
import ray.train
import ray.train.torch as raytorch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
def train_loop_per_worker(config):
    transform = transforms.Compose([transforms.ToTensor()])
    train_ds = datasets.FashionMNIST("/tmp/data", download=True, transform=transform)
    loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
    # Add a DistributedSampler and move each batch to this worker's device.
    loader = raytorch.prepare_data_loader(loader)

    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    # Move the model to this worker's device and wrap it in DistributedDataParallel.
    model = raytorch.prepare_model(model)

    optim = torch.optim.SGD(model.parameters(), lr=config["lr"])
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(config["epochs"]):
        for x, y in loader:
            optim.zero_grad()
            logits = model(x)
            loss = loss_fn(logits, y)
            loss.backward()
            optim.step()
        # Report once per epoch; metrics reach the driver and the Ray dashboard.
        ray.train.report({"loss": loss.item(), "epoch": epoch})


trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": 5},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
print(result.metrics)
What changed vs. a single-GPU script
Wrap the model with prepare_model
prepare_model moves the model to the worker’s GPU and wraps it in DistributedDataParallel.
Wrap the loader with prepare_data_loader
prepare_data_loader adds a DistributedSampler so each worker sees a unique shard of the data, and it moves each batch to the worker's device.
Report metrics with ray.train.report
Metrics passed to ray.train.report show up in the Ray dashboard and in the Result object returned by trainer.fit().
Pick a ScalingConfig
num_workers=4, use_gpu=True runs four worker processes, each on its own GPU. A CPU-only variant for local testing is sketched after this list.
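For example, to smoke-test the same loop on a machine without GPUs and then inspect what comes back, you can shrink the ScalingConfig (a minimal sketch; the worker count and single epoch are illustrative, not requirements):

trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": 1},
    # Two CPU-only worker processes, e.g. for local debugging.
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)
result = trainer.fit()

print(result.metrics)     # metrics from the last ray.train.report call
print(result.checkpoint)  # latest reported checkpoint, or None if none was reported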
Save checkpoints
import tempfile

from ray.train import Checkpoint


def train_loop_per_worker(config):
    ...
    for epoch in range(config["epochs"]):
        ...
        with tempfile.TemporaryDirectory() as tmpdir:
            torch.save(model.state_dict(), f"{tmpdir}/model.pt")
            ray.train.report(
                {"loss": loss.item(), "epoch": epoch},
                checkpoint=Checkpoint.from_directory(tmpdir),
            )
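After training, the latest reported checkpoint can be loaded back from the result (a minimal sketch; it assumes the model.pt filename and model definition from the examples above, and strips the "module." prefix that DistributedDataParallel adds to parameter names):

result = trainer.fit()

# Locate the checkpoint directory and load the saved state dict.
with result.checkpoint.as_directory() as ckpt_dir:
    state_dict = torch.load(f"{ckpt_dir}/model.pt", map_location="cpu")

# prepare_model wrapped the model in DDP, so keys may start with "module.".
state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model.load_state_dict(state_dict)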
Use a Ray Dataset
For datasets that don’t fit on a single node, use Ray Data:
ds = ray.data.read_parquet("s3://bucket/train/")


def train_loop_per_worker(config):
    # Each worker iterates over its own automatically assigned shard of the dataset.
    train_shard = ray.train.get_dataset_shard("train")
    for batch in train_shard.iter_torch_batches(batch_size=64):
        x, y = batch["features"], batch["label"]
        ...


trainer = TorchTrainer(
    train_loop_per_worker,
    datasets={"train": ds},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
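Ray Data can also apply lightweight preprocessing before the dataset reaches the workers, for example with map_batches (a sketch; the features column name and the 255.0 scale factor are assumptions about the example data):

def normalize(batch):
    # With batch_format="numpy", batches arrive as dicts of NumPy arrays.
    batch["features"] = batch["features"] / 255.0
    return batch

ds = (
    ray.data.read_parquet("s3://bucket/train/")
    .map_batches(normalize, batch_format="numpy")
)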
Next steps
Distributed PyTorch
DDP, FSDP, and tensor parallelism.
Lightning
Same scaling story with the Lightning trainer.