
Ray Train integrates with PyTorch Lightning’s Trainer through ray.train.lightning. You keep your existing LightningModule and LightningDataModule; Ray handles distribution.

Install

pip install -U "ray[train]" "pytorch-lightning>=2.1"

Minimal example

import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x.flatten(1))

    def training_step(self, batch, _):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def train_loop_per_worker(config):
    # MyLightningDataModule stands in for your existing LightningDataModule.
    dm = MyLightningDataModule(batch_size=config["batch_size"])
    trainer = pl.Trainer(
        max_epochs=config["epochs"],
        devices="auto",
        accelerator="auto",
        # Ray-provided strategy, environment plugin, and metric-reporting callback.
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],
        enable_progress_bar=False,
    )
    # Final wiring step before fit().
    trainer = prepare_trainer(trainer)
    trainer.fit(LitModel(), datamodule=dm)


trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={"epochs": 5, "batch_size": 64},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
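
trainer.fit() returns a ray.train.Result. A quick sketch of inspecting it, using the Result object’s metrics and checkpoint attributes:

print(result.metrics)      # last metrics reported by RayTrainReportCallback
print(result.checkpoint)   # latest reported checkpoint, if any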

Key integration points

Helper | Purpose
RayDDPStrategy | Tells Lightning to use DDP with Ray’s process group setup. Variants: RayFSDPStrategy, RayDeepSpeedStrategy.
RayLightningEnvironment | Plugs Lightning’s distributed environment into Ray’s worker topology.
RayTrainReportCallback | Forwards Lightning’s logged metrics to ray.train.report.
prepare_trainer(trainer) | Final wiring step that returns a fully configured trainer.
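
To make the reporting path concrete, here is a deliberately simplified, hypothetical callback (not part of Ray) that does roughly what RayTrainReportCallback does: at the end of each training epoch it forwards Lightning’s logged metrics to ray.train.report. The real callback does more, for example it also reports checkpoints.

import ray.train
from lightning.pytorch.callbacks import Callback


class NaiveReportCallback(Callback):
    # Simplified illustration only; use RayTrainReportCallback in practice.
    def on_train_epoch_end(self, trainer, pl_module):
        metrics = {k: v.item() for k, v in trainer.callback_metrics.items()}
        ray.train.report(metrics)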

FSDP / DeepSpeed

from ray.train.lightning import RayFSDPStrategy, RayDeepSpeedStrategy

trainer = pl.Trainer(strategy=RayFSDPStrategy(...), ...)
# or
trainer = pl.Trainer(strategy=RayDeepSpeedStrategy(...), ...)
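
A sketch that fills in example arguments, assuming the Ray strategy wrappers accept the same constructor arguments as Lightning’s own FSDPStrategy and DeepSpeedStrategy:

strategy = RayFSDPStrategy(sharding_strategy="FULL_SHARD")
# or
strategy = RayDeepSpeedStrategy(stage=3)

trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    strategy=strategy,
    plugins=[RayLightningEnvironment()],
    callbacks=[RayTrainReportCallback()],
    enable_progress_bar=False,
)
trainer = prepare_trainer(trainer)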

Use a Ray Dataset

def train_loop_per_worker(config):
    # Each worker receives its own shard of the dataset registered as "train".
    train_shard = ray.train.get_dataset_shard("train")
    train_loader = train_shard.iter_torch_batches(batch_size=64)
    trainer = pl.Trainer(...)
    trainer = prepare_trainer(trainer)
    trainer.fit(LitModel(), train_dataloaders=train_loader)
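
For get_dataset_shard("train") to return anything, a Ray Dataset must be registered under that key on the TorchTrainer. A minimal sketch (the dataset contents here are made up):

import ray

train_ds = ray.data.from_items([{"x": [0.0] * 784, "y": 0} for _ in range(256)])

trainer = TorchTrainer(
    train_loop_per_worker,
    datasets={"train": train_ds},  # available as get_dataset_shard("train")
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)

Note that iter_torch_batches yields each batch as a dict keyed by column name, so your training_step should unpack the batch accordingly.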

Next steps

Distributed PyTorch: DDP and FSDP details.

Checkpointing: Save Lightning checkpoints with Ray Train.