Ray Train integrates with PyTorch Lightning's `Trainer` through `ray.train.lightning`. You keep your existing `LightningModule` and `LightningDataModule`; Ray handles distribution.
## Install

```bash
pip install -U "ray[train]" "lightning>=2.1"
```
## Minimal example
```python
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x.flatten(1))

    def training_step(self, batch, _):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def train_loop_per_worker(config):
    # Your existing LightningDataModule subclass (not defined here).
    dm = MyLightningDataModule(batch_size=config["batch_size"])
    trainer = pl.Trainer(
        max_epochs=config["epochs"],
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],
        enable_progress_bar=False,
    )
    trainer = prepare_trainer(trainer)
    trainer.fit(LitModel(), datamodule=dm)


trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={"epochs": 5, "batch_size": 64},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
```
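Metrics logged with `self.log` are forwarded by `RayTrainReportCallback`, so the returned `ray.train.Result` carries them along with the latest checkpoint:

```python
# Inspect the result returned by trainer.fit().
print(result.metrics)     # last reported metrics, e.g. {"loss": ...}
print(result.checkpoint)  # most recent checkpoint reported during training
```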
## Key integration points

| Helper | Purpose |
| --- | --- |
| `RayDDPStrategy` | Tells Lightning to use DDP with Ray's process group setup. Variants: `RayFSDPStrategy`, `RayDeepSpeedStrategy`. |
| `RayLightningEnvironment` | Plugs Lightning's distributed environment into Ray's worker topology. |
| `RayTrainReportCallback` | Forwards Lightning's logged metrics to `ray.train.report`. |
| `prepare_trainer(trainer)` | Final wiring step that returns a fully configured trainer. |
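`RayTrainReportCallback` also attaches a Lightning checkpoint to each report. A minimal sketch of restoring the model from the final result, assuming the callback's default `checkpoint.ckpt` filename (verify against your Ray version):

```python
import os

# Restore the LightningModule from the checkpoint saved by RayTrainReportCallback.
with result.checkpoint.as_directory() as ckpt_dir:
    model = LitModel.load_from_checkpoint(os.path.join(ckpt_dir, "checkpoint.ckpt"))
```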
## FSDP / DeepSpeed

```python
from ray.train.lightning import RayFSDPStrategy, RayDeepSpeedStrategy

trainer = pl.Trainer(strategy=RayFSDPStrategy(...), ...)
# or
trainer = pl.Trainer(strategy=RayDeepSpeedStrategy(...), ...)
```
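Both strategies subclass their Lightning counterparts, so they take the same constructor arguments as Lightning's `FSDPStrategy` and `DeepSpeedStrategy`. A sketch with illustrative arguments (parameter names come from Lightning 2.x; check your installed version):

```python
# Illustrative arguments only; these are Lightning strategy parameters,
# passed through by the Ray subclasses.
fsdp_strategy = RayFSDPStrategy(sharding_strategy="FULL_SHARD")  # shard params, grads, optimizer state
ds_strategy = RayDeepSpeedStrategy(stage=3)                      # DeepSpeed ZeRO stage 3
```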
Use a Ray Dataset
def train_loop_per_worker ( config ):
train_shard = ray.train.get_dataset_shard( "train" )
train_loader = train_shard.iter_torch_batches( batch_size = 64 )
trainer = pl.Trainer( ... )
trainer = prepare_trainer(trainer)
trainer.fit(LitModel(), train_dataloaders = train_loader)
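`get_dataset_shard("train")` only returns a shard if the driver registers a Ray Dataset under that key. A minimal driver-side sketch, with a hypothetical Parquet path:

```python
# Register the dataset with the trainer; each worker then receives a shard.
train_ds = ray.data.read_parquet("s3://my-bucket/train/")  # hypothetical source

trainer = TorchTrainer(
    train_loop_per_worker,
    datasets={"train": train_ds},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
```

Note that `iter_torch_batches` yields dicts keyed by column name, so `training_step` must unpack the batch accordingly rather than as an `(x, y)` tuple.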
## Next steps

- **Distributed PyTorch**: DDP and FSDP details.
- **Checkpointing**: save Lightning checkpoints with Ray Train.