

RLModule is RLlib’s neural-network abstraction. It defines the forward passes used during training, exploration, and inference, leaving loss computation and parameter updates to the algorithm’s Learner.

Default modules

For most environments, RLlib’s default modules pick reasonable architectures based on observation and action spaces. You don’t have to write any modules to get started.
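As a sketch of what that looks like in practice (assuming the new API stack and Gymnasium’s CartPole-v1), a config can skip the rl_module() call entirely and rely on the default module:

```python
from ray.rllib.algorithms.ppo import PPOConfig

# No rl_module() call needed: RLlib infers an MLP encoder plus policy and
# value heads from CartPole's Box observation and Discrete action spaces.
config = PPOConfig().environment("CartPole-v1")
```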

Custom torch RLModule

import torch.nn as nn
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.core.columns import Columns

class MyModule(TorchRLModule):
    def setup(self):
        # Called once at construction; the spaces come from the RLModuleSpec.
        obs_dim = self.observation_space.shape[0]
        n_actions = self.action_space.n
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
        )
        self.policy_head = nn.Linear(128, n_actions)
        self.value_head = nn.Linear(128, 1)

    def _forward(self, batch, **kwargs):
        # Default pass shared by training, exploration, and inference.
        h = self.encoder(batch[Columns.OBS])
        return {
            Columns.ACTION_DIST_INPUTS: self.policy_head(h),
            Columns.VF_PREDS: self.value_head(h).squeeze(-1),
        }
Wire it into the algorithm config:
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

config = config.rl_module(
    rl_module_spec=RLModuleSpec(module_class=MyModule),
)
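To sanity-check the module outside a full training run, you can build it from the spec and call its inference pass on a dummy batch. A sketch using MyModule from above; it assumes Gymnasium’s CartPole-v1 just to supply the spaces:

```python
import gymnasium as gym
import torch
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

env = gym.make("CartPole-v1")
spec = RLModuleSpec(
    module_class=MyModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
)
module = spec.build()

# Batch of one observation; logits come back under Columns.ACTION_DIST_INPUTS.
obs, _ = env.reset()
out = module.forward_inference({Columns.OBS: torch.from_numpy(obs).unsqueeze(0)})
```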

Sharing layers

When the policy and value functions share a backbone, define one shared encoder and two heads, as in the example above. The default _forward runs once per batch and returns both outputs; you can override _forward_inference or _forward_exploration when a phase needs a different pass.
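As a sketch of a phase-specific override (extending the hypothetical MyModule class from above), inference can skip the value head entirely:

```python
from ray.rllib.core.columns import Columns

class MyModuleFastInfer(MyModule):
    def _forward_inference(self, batch, **kwargs):
        # At serving time only action logits are needed; skip the value head.
        h = self.encoder(batch[Columns.OBS])
        return {Columns.ACTION_DIST_INPUTS: self.policy_head(h)}
```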

Multi-agent

Define one module per policy. RLlib builds and updates each module independently; the Learner computes one loss per module.
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec

config = config.rl_module(
    rl_module_spec=MultiRLModuleSpec(rl_module_specs={
        "policy_a": RLModuleSpec(module_class=ModuleA),
        "policy_b": RLModuleSpec(module_class=ModuleB),
    })
)

Inference-only modules

For deployment, instantiate the module with inference_only=True (your setup can check this flag to skip building training-only parts such as the value head) and load weights from a checkpoint.
inference_only = MyModule(observation_space=..., action_space=..., inference_only=True)
inference_only.load_state_dict(checkpoint["module_state"])
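If the weights came from RLlib’s own checkpointing, recent Ray versions can also restore the module directly from the checkpoint directory; the path below is a placeholder, and the exact layout depends on your Ray version:

```python
from ray.rllib.core.rl_module.rl_module import RLModule

# Placeholder path: point this at the RLModule subdirectory of your checkpoint.
module = RLModule.from_checkpoint("/tmp/my_checkpoint/...")
```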

Next steps

Learner: how learners consume modules.

Training: inside the iteration loop.