Trainer

A Trainer is the entry point to a Ray Train job. Choose the trainer that matches your framework: TorchTrainer, LightningTrainer, TransformersTrainer, XGBoostTrainer, LightGBMTrainer, TensorflowTrainer, JaxTrainer. Each accepts a train_loop_per_worker callable plus configuration.
from ray.train import ScalingConfig, RunConfig
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={...},
    scaling_config=ScalingConfig(...),
    run_config=RunConfig(...),
    datasets={"train": train_ds, "valid": valid_ds},
)
result = trainer.fit()

train_loop_per_worker

A function (or other callable) that runs on each worker. Inside it, you set up your model, optimizer, dataloader, and training loop. Ray Train sets up the framework's distributed primitives (process groups, communication backends, sharding), so you write the loop once and every worker runs the same code.
def train_loop_per_worker(config):
    model = build_model()
    # Moves the model to the worker's device and wraps it for distributed training.
    model = ray.train.torch.prepare_model(model)
    ...
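
For reference, a minimal end-to-end loop might look like the sketch below. build_model, build_dataset, and the keys read from config are placeholder names, not part of the Ray Train API.
import torch
from torch.utils.data import DataLoader

import ray.train
import ray.train.torch

def train_loop_per_worker(config):
    # build_model() and build_dataset() are hypothetical user-defined helpers.
    model = ray.train.torch.prepare_model(build_model())
    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
    loader = DataLoader(build_dataset(), batch_size=config["batch_size"])
    # Adds a DistributedSampler and moves batches to the worker's device.
    loader = ray.train.torch.prepare_data_loader(loader)

    for epoch in range(config["epochs"]):
        for features, labels in loader:
            optimizer.zero_grad()
            loss = torch.nn.functional.cross_entropy(model(features), labels)
            loss.backward()
            optimizer.step()
        # Report per-epoch metrics back to the Trainer.
        ray.train.report(metrics={"loss": loss.item()})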

ScalingConfig

Specifies how many workers to run, what resources each worker needs, and whether to use GPUs.
from ray.train import ScalingConfig

ScalingConfig(
    num_workers=8,
    use_gpu=True,
    resources_per_worker={"CPU": 4, "GPU": 1, "memory": 16 * 1024**3},
    placement_strategy="PACK",
)
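
Inside train_loop_per_worker, you can check what these settings resolved to via the train context; a small sketch (the comment values are only illustrative):
import ray.train

ctx = ray.train.get_context()
# With num_workers=8, the world size is 8 and each worker has a unique rank.
world_size = ctx.get_world_size()
rank = ctx.get_world_rank()
local_rank = ctx.get_local_rank()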

RunConfig

Captures storage, run naming, callbacks, checkpointing, and failure handling.
from ray.train import RunConfig, CheckpointConfig, FailureConfig

RunConfig(
    storage_path="s3://bucket/runs/",
    name="my-experiment",
    checkpoint_config=CheckpointConfig(num_to_keep=3),
    failure_config=FailureConfig(max_failures=2),
)
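
To keep the best checkpoints rather than the most recent ones, CheckpointConfig can rank them by a reported metric; a sketch assuming the training loop reports a "loss" metric:
from ray.train import CheckpointConfig

CheckpointConfig(
    num_to_keep=3,
    checkpoint_score_attribute="loss",  # metric name passed to ray.train.report
    checkpoint_score_order="min",       # keep the three lowest-loss checkpoints
)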

Datasets

Pass Ray Datasets via the datasets argument. Each worker receives a sharded iterator.
trainer = TorchTrainer(
    train_loop_per_worker,
    datasets={"train": train_ds, "valid": valid_ds},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
Inside the loop:
def train_loop_per_worker(config):
    # Each worker gets its own shard of the "train" dataset.
    train_shard = ray.train.get_dataset_shard("train")
    for epoch in range(config["epochs"]):
        for batch in train_shard.iter_torch_batches(batch_size=64):
            ...
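
The train_ds and valid_ds objects passed to the trainer are ordinary Ray Data datasets; a minimal sketch creating them (the file paths are placeholders):
import ray

train_ds = ray.data.read_parquet("s3://bucket/train/")
valid_ds = ray.data.read_parquet("s3://bucket/valid/")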

Checkpoints

Workers report checkpoints via ray.train.report:
import tempfile

import torch
import ray.train
from ray.train import Checkpoint

def train_loop_per_worker(config):
    for epoch in range(config["epochs"]):
        ...
        with tempfile.TemporaryDirectory() as tmpdir:
            torch.save(model.state_dict(), f"{tmpdir}/model.pt")
            ray.train.report(
                metrics={"loss": loss},
                checkpoint=Checkpoint.from_directory(tmpdir),
            )
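
On a restart (for example after a failure handled by FailureConfig), the latest reported checkpoint is available inside the loop. A sketch of resuming from it, assuming the model.pt filename used above and a hypothetical build_model helper:
def train_loop_per_worker(config):
    model = build_model()  # hypothetical user-defined helper
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as ckpt_dir:
            model.load_state_dict(torch.load(f"{ckpt_dir}/model.pt"))
    ...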

Result

trainer.fit() returns a Result object with the final reported metrics, the latest checkpoint (result.checkpoint), and access to all saved checkpoints (result.best_checkpoints).
result = trainer.fit()
print(result.metrics)
print(result.checkpoint)
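
To use the trained model afterwards, load the weights back from the returned checkpoint; a sketch assuming the model.pt filename used in the checkpointing example and a freshly built model of the same architecture:
import torch

with result.checkpoint.as_directory() as ckpt_dir:
    state_dict = torch.load(f"{ckpt_dir}/model.pt")
    model.load_state_dict(state_dict)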

Next steps

PyTorch quickstart

A full PyTorch training example.

Distributed PyTorch

DDP, FSDP, and tensor parallelism.