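Ray Train integrates with TensorFlow through `TensorflowTrainer`. Each worker runs your training function in its own TensorFlow process, and Ray Train configures that process's environment so a `tf.distribute.MultiWorkerMirroredStrategy` created inside the function can coordinate with the other workers.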
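`MultiWorkerMirroredStrategy` learns the cluster layout from the `TF_CONFIG` environment variable: a JSON document listing every worker's address plus the current process's task type and index. The sketch below parses that standard layout to recover the world size and this worker's index, similar to what Ray's TensorFlow examples do inside the training function. The helper name `worker_info` and the addresses in `example_env` are hypothetical; on a real cluster Ray Train fills in `TF_CONFIG` for you.

```python
import json
import os


def worker_info(env=None):
    """Return (num_workers, worker_index) parsed from TF_CONFIG.

    Hypothetical helper. It expects the standard TensorFlow layout:
    {"cluster": {"worker": [...]}, "task": {"type": "worker", "index": i}}
    """
    env = os.environ if env is None else env
    tf_config = json.loads(env["TF_CONFIG"])
    num_workers = len(tf_config["cluster"]["worker"])
    worker_index = tf_config["task"]["index"]
    return num_workers, worker_index


# Made-up value for the second of four workers; real addresses and ports
# are chosen by Ray Train at runtime.
example_env = {
    "TF_CONFIG": json.dumps({
        "cluster": {"worker": [f"10.0.0.{i}:12345" for i in range(1, 5)]},
        "task": {"type": "worker", "index": 1},
    })
}
print(worker_info(example_env))  # (4, 1)
```

Inside `train_loop_per_worker` you would call `worker_info()` with no argument so it reads the real `os.environ`.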
Install
```shell
pip install -U "ray[train]" tensorflow
```
Minimal example
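```python
import numpy as np
import tensorflow as tf
import ray
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer


def train_loop_per_worker(config):
    # Ray Train has already set TF_CONFIG on this worker, so the strategy can
    # discover its peers. Create the strategy before running other TF ops.
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # Build and compile inside the scope so the variables are mirrored.
        model = tf.keras.Sequential([
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(10),  # no softmax: the outputs are logits
        ])
        model.compile(
            optimizer="adam",
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=["accuracy"],
        )

    # Each worker iterates over its own shard of the "train" dataset.
    train_ds = ray.train.get_dataset_shard("train")
    tf_dataset = train_ds.to_tf(feature_columns="x", label_columns="y", batch_size=64)
    model.fit(tf_dataset, epochs=config["epochs"])


# Synthetic placeholder data: 784 float features and an integer label in [0, 10).
rng = np.random.default_rng(0)
items = [{"x": rng.random(784, dtype=np.float32), "y": i % 10} for i in range(1024)]

trainer = TensorflowTrainer(
    train_loop_per_worker,
    train_loop_config={"epochs": 3},
    # Set use_gpu=False to run on a CPU-only cluster.
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    datasets={"train": ray.data.from_items(items)},
)
result = trainer.fit()
```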
Notes
- Ray Train sets `TF_CONFIG` automatically for each worker, so `MultiWorkerMirroredStrategy` discovers its peers.
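- To write a final model, call `model.save` on every worker, because saving can run collective operations under `MultiWorkerMirroredStrategy`. Only rank 0 (`ray.train.get_context().get_world_rank() == 0`) should write to the real destination; other ranks should write to a temporary location that you discard afterwards.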
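- For metric reporting, call `ray.train.report` from a Keras callback, or use the built-in `ray.train.tensorflow.keras.ReportCheckpointCallback`, which reports Keras metrics (and optionally a checkpoint) to Ray Train from inside `model.fit`.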
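TensorFlow's multi-worker Keras guide recommends that every worker call `model.save`, while only the chief writes to the real destination and the other workers write to scratch paths. A minimal sketch of that path selection follows, assuming rank 0 acts as the chief; `model_save_path` and the `worker_<rank>` scratch layout are hypothetical choices, not a Ray Train API.

```python
import os
import tempfile


def model_save_path(final_path, world_rank):
    """Choose where this worker should call model.save (hypothetical helper).

    Rank 0 writes to final_path. Every other rank gets a per-rank scratch
    location under the system temp directory, so all workers still take part
    in any collective ops that saving triggers without clobbering the result.
    """
    if world_rank == 0:
        return final_path
    scratch_dir = os.path.join(tempfile.gettempdir(), f"worker_{world_rank}")
    os.makedirs(scratch_dir, exist_ok=True)
    # Keep the original file name so extension-based format rules
    # (for example ".keras") still apply on non-chief ranks.
    return os.path.join(scratch_dir, os.path.basename(final_path.rstrip(os.sep)))


# Possible use at the end of train_loop_per_worker (sketch):
#   rank = ray.train.get_context().get_world_rank()
#   model.save(model_save_path("/tmp/final/model.keras", rank))
print(model_save_path("/tmp/final/model.keras", 0))  # /tmp/final/model.keras
```

Non-chief ranks can delete their scratch copy once `model.save` returns.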
Next steps
- Checkpointing: save and resume TensorFlow training.
- Data loading: pipe Ray Data into Keras models.