
Ray Train integrates with TensorFlow via TensorflowTrainer. Each worker runs a TensorFlow process configured with MultiWorkerMirroredStrategy.
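MultiWorkerMirroredStrategy discovers its peers through the TF_CONFIG environment variable, which Ray Train populates for each worker. Purely for illustration (you never set this yourself with Ray Train, and the hosts, ports, and index below are made up), TF_CONFIG has roughly this shape:

```python
import json
import os

# Illustrative only: Ray Train writes an equivalent value for each worker.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"worker": ["10.0.0.1:12345", "10.0.0.2:12345"]},
    "task": {"type": "worker", "index": 0},  # this worker's rank
})
```

Each worker receives the same cluster spec but its own task index, which is how the strategy assigns ranks.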

Install

pip install -U "ray[train]" tensorflow

Minimal example

import numpy as np
import pandas as pd
import tensorflow as tf

import ray
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

# Toy dataset: one float feature column "x", integer labels "y" in [0, 10).
df = pd.DataFrame({
    "x": np.random.rand(256).astype("float32"),
    "y": np.random.randint(0, 10, size=256),
})


def train_loop_per_worker(config):
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(10),
        ])
        model.compile(
            optimizer="adam",
            # Dense(10) emits raw logits, so configure the loss accordingly.
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=["accuracy"],
        )

    train_ds = ray.train.get_dataset_shard("train")
    tf_dataset = train_ds.to_tf(feature_columns="x", label_columns="y", batch_size=64)
    model.fit(tf_dataset, epochs=config["epochs"])


trainer = TensorflowTrainer(
    train_loop_per_worker,
    train_loop_config={"epochs": 3},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    datasets={"train": ray.data.from_pandas(df)},
)
trainer.fit()

Notes

  • Ray Train sets TF_CONFIG automatically for each worker, so MultiWorkerMirroredStrategy discovers its peers.
  • To write a final checkpoint, save the model with model.save on rank 0 (check ray.train.get_context().get_world_rank()) and report it with ray.train.report.
  • For per-epoch metric reporting, call ray.train.report from a Keras callback.

Next steps

Checkpointing

Save and resume TensorFlow training.

Data loading

Pipe Ray Data into Keras models.