A Checkpoint is a directory of files that Ray Train manages for you: it copies the directory to durable storage, attaches it to a trial, and exposes it through the result object.

Save a checkpoint

import tempfile
import torch
import ray.train
from ray.train import Checkpoint

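def train_loop_per_worker(config):
    ...  # elided setup; it must define `model`
    for epoch in range(config["epochs"]):
        ...  # elided training step; it must define `loss`
        with tempfile.TemporaryDirectory() as tmpdir:
            # Write the checkpoint files into a scratch directory.
            torch.save(model.state_dict(), f"{tmpdir}/model.pt")
            # Report inside the `with` block, while tmpdir still exists.
            ray.train.report(
                metrics={"loss": loss.item(), "epoch": epoch},
                checkpoint=Checkpoint.from_directory(tmpdir),
            )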
Ray Train copies the directory to the configured storage_path and increments the checkpoint index.

Storage path

Checkpoints land under <storage_path>/<run_name>/checkpoint_<index>/.
from ray.train import RunConfig

run_config = RunConfig(storage_path="s3://bucket/runs/", name="resnet-finetune")
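To make the layout above concrete, here is a minimal sketch that lists the checkpoint directories of a run and orders them by index. It assumes the run directory is on a local filesystem (not an s3:// URI) and that each directory name is `checkpoint_` followed by digits. The helper name `list_checkpoint_dirs` and the zero-padded names in the demo are illustrative, not part of Ray Train.

```python
import re
import tempfile
from pathlib import Path

# Matches directory names of the form checkpoint_<index>.
_CHECKPOINT_DIR = re.compile(r"^checkpoint_(\d+)$")


def list_checkpoint_dirs(run_dir):
    """Return (index, path) pairs for checkpoint_<index> directories, sorted by index."""
    found = []
    for child in Path(run_dir).iterdir():
        match = _CHECKPOINT_DIR.match(child.name)
        if child.is_dir() and match:
            found.append((int(match.group(1)), child))
    # Sort numerically so that index 10 comes after index 2.
    return sorted(found, key=lambda pair: pair[0])


# Demo on a throwaway directory with hypothetical checkpoint names.
with tempfile.TemporaryDirectory() as run_dir:
    for name in ("checkpoint_000010", "checkpoint_000000", "checkpoint_000002"):
        (Path(run_dir) / name).mkdir()
    (Path(run_dir) / "result.json").write_text("{}")  # non-checkpoint file is ignored
    indices = [index for index, _ in list_checkpoint_dirs(run_dir)]
```

The last pair returned is the most recent checkpoint, which is handy when inspecting a run directory by hand.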

Configure retention

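from ray.train import CheckpointConfig, RunConfig

# Retention settings take effect once they are attached to the run config.
run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=3,
        checkpoint_score_attribute="val_loss",
        checkpoint_score_order="min",
    ),
)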
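num_to_keep controls how many checkpoints are retained in storage. Checkpoints are ranked by checkpoint_score_attribute, and the worst-ranked ones are evicted first: with checkpoint_score_order="min", those are the checkpoints with the highest val_loss. The attribute must be one of the metric keys passed to ray.train.report.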
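The eviction rule can be illustrated with a small pure-Python simulation. This is a sketch of the policy as described here, not Ray Train's implementation, and it does not model any special handling Ray may apply (for example, to the most recent checkpoint). The function name `retained_checkpoints` and the checkpoint names are hypothetical.

```python
def retained_checkpoints(reports, num_to_keep, score_attribute, score_order="max"):
    """Simulate retention over (name, metrics) pairs given in report order.

    Returns the names still kept after the last report.
    """
    if score_order not in ("min", "max"):
        raise ValueError("score_order must be 'min' or 'max'")
    kept = []
    for name, metrics in reports:
        kept.append((name, metrics[score_attribute]))
        if len(kept) > num_to_keep:
            # With "min", lower is better, so the worst entry has the highest score.
            if score_order == "min":
                worst = max(kept, key=lambda item: item[1])
            else:
                worst = min(kept, key=lambda item: item[1])
            kept.remove(worst)
    return [name for name, _ in kept]


reports = [
    ("checkpoint_0", {"val_loss": 0.9}),
    ("checkpoint_1", {"val_loss": 0.7}),
    ("checkpoint_2", {"val_loss": 0.8}),
    ("checkpoint_3", {"val_loss": 0.5}),
    ("checkpoint_4", {"val_loss": 0.6}),
]
survivors = retained_checkpoints(reports, num_to_keep=3,
                                 score_attribute="val_loss", score_order="min")
print(survivors)  # ['checkpoint_1', 'checkpoint_3', 'checkpoint_4']
```

With score_order="max" the same reports would instead keep the three checkpoints with the highest val_loss, which is why the order must match the metric.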

Resume from a checkpoint

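from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

# Point restore() at the run directory: <storage_path>/<run_name>.
trainer = TorchTrainer.restore(
    "s3://bucket/runs/resnet-finetune",
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()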
Inside train_loop_per_worker, recover state from the latest checkpoint:
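checkpoint = ray.train.get_checkpoint()
if checkpoint:  # None on a fresh run with nothing to resume from
    # as_directory() exposes the checkpoint files as a local directory.
    with checkpoint.as_directory() as ckpt_dir:
        state = torch.load(f"{ckpt_dir}/model.pt")
        model.load_state_dict(state)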

Best checkpoint

result = trainer.fit()
print(result.best_checkpoints[0])         # (Checkpoint, metrics)
print(result.checkpoint)                  # latest
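If you would rather not rely on the position of an entry in best_checkpoints, you can select the best one explicitly by metric. The sketch below assumes only what the comment above states, that each entry is a (Checkpoint, metrics) tuple. The helper name `pick_best` is hypothetical, and the strings in the demo stand in for Checkpoint objects.

```python
def pick_best(best_checkpoints, metric, mode="min"):
    """Return the (checkpoint, metrics) pair with the best value of `metric`."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    if not best_checkpoints:
        raise ValueError("no checkpoints to choose from")
    chooser = min if mode == "min" else max
    return chooser(best_checkpoints, key=lambda pair: pair[1][metric])


# Stand-in data; with Ray Train you would pass result.best_checkpoints.
candidates = [
    ("ckpt_a", {"val_loss": 0.42}),
    ("ckpt_b", {"val_loss": 0.31}),
    ("ckpt_c", {"val_loss": 0.37}),
]
best_checkpoint, best_metrics = pick_best(candidates, metric="val_loss", mode="min")
print(best_checkpoint, best_metrics)  # ckpt_b {'val_loss': 0.31}
```

Use the same metric and mode as in CheckpointConfig so that the checkpoint you select is one that retention was trying to keep.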

Best practices

Save model and optimizer state plus epoch/step counters in a single dict. On restore, you’ll need all three to resume mid-epoch.
Avoid pickling arbitrary Python objects into checkpoints. Use framework-native serializers (torch.save, model.save_pretrained, keras.saving.save_model) so checkpoints are portable across Ray versions.
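The first practice can be sketched as a pair of helpers that write and read a single state dict. This is a minimal illustration, not a Ray Train API: the helper names, the `state.pt` file name, and the dict keys are assumptions chosen for the example.

```python
import os
import tempfile

import torch
from torch import nn


def save_training_state(directory, model, optimizer, epoch, step):
    """Write model, optimizer, and progress counters to one file."""
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "step": step,
        },
        os.path.join(directory, "state.pt"),
    )


def load_training_state(directory, model, optimizer):
    """Restore model and optimizer in place and return (epoch, step)."""
    state = torch.load(os.path.join(directory, "state.pt"), map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"], state["step"]


# Round trip through a temporary directory with a toy model.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
with tempfile.TemporaryDirectory() as tmpdir:
    save_training_state(tmpdir, model, optimizer, epoch=3, step=120)
    restored_model = nn.Linear(4, 2)
    restored_optimizer = torch.optim.SGD(
        restored_model.parameters(), lr=0.1, momentum=0.9
    )
    epoch, step = load_training_state(tmpdir, restored_model, restored_optimizer)
```

In a training loop, you would call `save_training_state` on the temporary directory before passing it to Checkpoint.from_directory, and call `load_training_state` on the directory yielded by checkpoint.as_directory(), then continue from the returned counters.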

Next steps

Fault tolerance

Resume after a crash.

Run config

All run-level options.