Ray Train sets up the PyTorch distributed environment so you don’t have to. The same train_loop_per_worker runs across N workers; you just pick the parallelism strategy.
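
A minimal sketch of the wiring (the worker count here is arbitrary):
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config):
    # Runs once on each worker; the process group is already
    # initialized when this function starts.
    ...

trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()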

DistributedDataParallel (DDP)

The default. Each worker holds a full model replica; gradients are all-reduced at every backward pass.
from ray.train.torch import prepare_model

model = MyModel()
model = prepare_model(model)  # wraps in DDP under the hood

prepare_model moves the model to the worker’s GPU and wraps it in torch.nn.parallel.DistributedDataParallel.
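
After prepare_model, the training step is plain PyTorch; in this sketch, loader, loss_fn, and optimizer are assumed to already exist:
for features, labels in loader:
    optimizer.zero_grad()
    loss = loss_fn(model(features), labels)
    loss.backward()  # DDP all-reduces gradients across workers here
    optimizer.step()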

Fully Sharded Data Parallel (FSDP)

For models too large to fit on one GPU. FSDP shards model parameters, gradients, and optimizer state across workers.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def train_loop_per_worker(config):
    model = MyLargeModel().cuda()
    model = FSDP(model)
    ...

Use Lightning’s RayFSDPStrategy for a higher-level integration.
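
With Lightning, that integration might look like the following sketch (MyLightningModule is a placeholder for your own LightningModule):
import pytorch_lightning as pl
from ray.train.lightning import (
    RayFSDPStrategy,
    RayLightningEnvironment,
    prepare_trainer,
)

def train_loop_per_worker(config):
    trainer = pl.Trainer(
        strategy=RayFSDPStrategy(),
        plugins=[RayLightningEnvironment()],
        accelerator="auto",
        devices="auto",
    )
    trainer = prepare_trainer(trainer)
    trainer.fit(MyLightningModule())  # placeholder LightningModule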

Tensor parallelism

For very large models, combine FSDP with tensor parallelism. Use a library like Megatron-LM or DeepSpeed, or PyTorch’s built-in DeviceMesh API. Ray Train provides the process group; the partitioning strategy is up to your library of choice.
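
For example, a 2D DeviceMesh (PyTorch 2.2+) can carve the workers Ray Train launches into data-parallel and tensor-parallel groups; the 2x4 shape below assumes ScalingConfig(num_workers=8):
from torch.distributed.device_mesh import init_device_mesh

def train_loop_per_worker(config):
    # 2 data-parallel replicas x 4 tensor-parallel shards = 8 workers.
    mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
    tp_group = mesh["tp"].get_group()  # hand this group to your TP library
    ...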

Process group initialization

Ray Train initializes the default process group before train_loop_per_worker runs. You can use any PyTorch distributed primitive directly:
import torch.distributed as dist

if dist.get_rank() == 0:
    print("I am the rank-0 worker")

Communication backend

By default, Ray uses nccl for GPU workers and gloo for CPU workers. Override with TorchConfig:
from ray.train.torch import TorchConfig, TorchTrainer

trainer = TorchTrainer(
    ...,
    torch_config=TorchConfig(backend="nccl", timeout_s=1800),
)

Multi-node training

Just bump num_workers past one node’s GPU count. Ray Train places workers across nodes and configures the process group with the correct master address.
from ray.train import ScalingConfig

ScalingConfig(num_workers=16, use_gpu=True)  # 16 GPU workers, possibly across many nodes
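
To confirm where each worker landed, the train context exposes world, local, and node ranks; a sketch of a loop body using ray.train.get_context():
import ray.train

def train_loop_per_worker(config):
    ctx = ray.train.get_context()
    print(
        f"world rank {ctx.get_world_rank()}, "
        f"local rank {ctx.get_local_rank()}, "
        f"node rank {ctx.get_node_rank()}"
    )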

Next steps

Data loading

Feed Ray Datasets to PyTorch workers.

Fault tolerance

Recover from worker failures.