A Trainer is the entry point to Ray Train. Choose the trainer that matches your framework: TorchTrainer, LightningTrainer, TransformersTrainer, XGBoostTrainer, LightGBMTrainer, TensorflowTrainer, or JaxTrainer. Each accepts a train_loop_per_worker callable plus configuration objects such as ScalingConfig.
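As a concrete example, here is a minimal sketch of constructing and running a TorchTrainer. The worker count, GPU flag, and config values are illustrative assumptions, not requirements:

```python
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config):
    ...  # per-worker setup and training loop (see below)

trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={"epochs": 10, "lr": 1e-3},  # delivered to the loop as `config`
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
```

The train_loop_config dict is passed to each worker as the `config` argument of train_loop_per_worker.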
train_loop_per_worker is a function (or class) that runs on each worker. Inside it, you set up your model, optimizer, dataloader, and training loop. Ray Train wires up the framework's distributed machinery (process groups, communication backends, data sharding) so the same loop code runs identically on every worker. With PyTorch, for example:
```python
import ray.train.torch

def train_loop_per_worker(config):
    model = build_model()  # your ordinary model-construction code
    # Wrap the model for distributed training (DDP wrapping, device placement).
    model = ray.train.torch.prepare_model(model)
    ...
```
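If you load data with a standard PyTorch DataLoader rather than Ray Data, the counterpart to prepare_model is ray.train.torch.prepare_data_loader, sketched here. The `my_dataset` name is a hypothetical placeholder:

```python
from torch.utils.data import DataLoader
import ray.train.torch

def train_loop_per_worker(config):
    # my_dataset stands in for any torch Dataset you construct here.
    loader = DataLoader(my_dataset, batch_size=64, shuffle=True)
    # Adds a DistributedSampler and moves batches to the worker's device.
    loader = ray.train.torch.prepare_data_loader(loader)
    ...
```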
When the data comes from Ray Data instead, each worker reads its own shard directly:

```python
import ray.train

def train_loop_per_worker(config):
    # Each worker receives its own shard of the "train" dataset.
    train_shard = ray.train.get_dataset_shard("train")
    for epoch in range(config["epochs"]):
        for batch in train_shard.iter_torch_batches(batch_size=64):
            ...
```
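Putting both pieces together, a complete per-worker loop might look like the following sketch. The model, optimizer, loss function, and the "x"/"y" column names are assumptions for illustration; the "train" key must match a dataset passed to the Trainer via `datasets={"train": ...}`:

```python
import torch
import ray.train
import ray.train.torch

def train_loop_per_worker(config):
    # build_model is the same user-defined placeholder as above.
    model = ray.train.torch.prepare_model(build_model())
    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
    loss_fn = torch.nn.CrossEntropyLoss()

    train_shard = ray.train.get_dataset_shard("train")
    for epoch in range(config["epochs"]):
        for batch in train_shard.iter_torch_batches(batch_size=64):
            optimizer.zero_grad()
            # "x" and "y" are assumed column names in the dataset.
            loss = loss_fn(model(batch["x"]), batch["y"])
            loss.backward()
            optimizer.step()
        # Report per-epoch metrics back to the Trainer.
        ray.train.report({"epoch": epoch, "loss": loss.item()})
```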