Custom Reward Functions#

In this lesson we’ll add the underlying math functions that determine rewards for our policy. Our first approach combines two reward terms based on positional error (how far the robot is from a goal position).

Defining custom reward functions#

In the rewards manager config file from earlier, we referenced a few MDP functions that don’t exist yet. As the last step in setting up our task, let’s define those functions now. As mentioned before, we’ll put these in an mdp folder for organization.

  1. Open the following Python file: source/Reach/Reach/tasks/manager_based/reach/mdp/rewards.py

  2. Delete the contents of this file and replace them with the following imports. Some are used by the functions in this lesson, and the rest will be needed in the next lesson.

# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
from typing import TYPE_CHECKING

from isaaclab.assets import RigidObject
from isaaclab.managers import SceneEntityCfg
from isaaclab.utils.math import combine_frame_transforms, quat_error_magnitude, quat_mul

if TYPE_CHECKING:
    from isaaclab.envs import ManagerBasedRLEnv

Next, we’ll define the new reward functions in this file.

Note

While PyTorch, linear algebra, and trigonometry will be key tools throughout your RL learning journey, don’t worry if this is your first encounter with them. We’ll walk through the “why” behind the math.

Position Command Error#

This function computes the position error between the desired position (from the command) and the current position of the asset’s body (in world frame). The position error is computed as the L2-norm of the difference between the desired and current positions.

What’s an L2-norm?
The L2 norm, also known as the Euclidean norm, measures the length (magnitude) of a vector. Here it is the straight-line distance between the current and desired positions.

How does this function interact with the managers, and our scene?
The environment is passed in, and it contains every parallel instance of the robot in the scene. That’s why the function works on batched PyTorch tensors: each operation computes the error for all environments at once. This is where the massively parallel nature of training is realized!
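
To make this concrete, here is a minimal, standalone sketch of the L2-norm computed for a small batch of environments at once. The positions are made up for illustration and are not values from our task:

import torch

# one row per parallel environment: hypothetical desired and current (x, y, z) positions
des_pos_w = torch.tensor([[0.5, 0.0, 0.3],
                          [0.2, 0.4, 0.6]])
curr_pos_w = torch.tensor([[0.2, 0.1, 0.3],
                           [0.2, 0.4, 0.1]])

# L2-norm of the difference along the position axis: one error value per environment
error = torch.norm(curr_pos_w - des_pos_w, dim=1)
print(error)  # roughly tensor([0.3162, 0.5000])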

We also referenced this function in one of the reward terms we registered with the Reward Manager. In other words, this function defines how that term is calculated; the manager then scales each term by its weight and sums the terms into the final reward.
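
For reference, here is a hedged sketch of how such a reward term might reference this function in the rewards config. The term name, weight, body name, and command name below are placeholders; use the values from the rewards manager config you wrote in the earlier lesson:

from isaaclab.managers import RewardTermCfg as RewTerm
from isaaclab.managers import SceneEntityCfg
from isaaclab.utils import configclass

from . import mdp  # assumes the config module imports this mdp package


@configclass
class RewardsCfg:
    # illustrative term: a negative weight turns the position error into a penalty
    end_effector_position_tracking = RewTerm(
        func=mdp.position_command_error,
        weight=-0.2,  # placeholder weight
        params={
            "asset_cfg": SceneEntityCfg("robot", body_names=["ee_link"]),  # placeholder body name
            "command_name": "ee_pose",  # placeholder command name
        },
    )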

Add this code to your source/Reach/Reach/tasks/manager_based/reach/mdp/rewards.py file.

def position_command_error(env: ManagerBasedRLEnv, command_name: str, asset_cfg: SceneEntityCfg) -> torch.Tensor:
    """Penalize tracking of the position error using L2-norm.

    The function computes the position error between the desired position (from the command) and the
    current position of the asset's body (in world frame). The position error is computed as the L2-norm
    of the difference between the desired and current positions.
    """
    # extract the asset (to enable type hinting)
    asset: RigidObject = env.scene[asset_cfg.name]
    command = env.command_manager.get_command(command_name)
    # obtain the desired and current positions
    des_pos_b = command[:, :3]
    des_pos_w, _ = combine_frame_transforms(asset.data.root_state_w[:, :3], asset.data.root_state_w[:, 3:7], des_pos_b)
    curr_pos_w = asset.data.body_state_w[:, asset_cfg.body_ids[0], :3]  # type: ignore
    return torch.norm(curr_pos_w - des_pos_w, dim=1)

Position Command Error Tanh#

This function computes the position error between the desired position (from the command) and the current position of the asset’s body (in world frame), then maps it through a tanh kernel.

Why do we have a second Position Command reward term?
The tanh kernel changes quickly when the error is small (relative to std) and saturates when the error is large. For typical small values of std, this produces a stronger learning signal near the goal than the linear L2 penalty, rewarding precise positioning, while the L2 term still provides useful guidance when the arm is far away. Because the distance is never negative, the resulting reward is also bounded between 0 and 1.
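
As a quick illustration (assuming a placeholder std of 0.1; the real value is supplied through the reward term’s parameters), notice how the kernel stays sensitive near zero error but saturates for large errors:

import torch

std = 0.1  # placeholder value for illustration
distances = torch.tensor([0.0, 0.05, 0.1, 0.5, 1.0])
reward = 1 - torch.tanh(distances / std)
# roughly 1.00, 0.54, 0.24, then effectively 0: bounded in (0, 1] and steepest near the goal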

Add this code to your source/Reach/Reach/tasks/manager_based/reach/mdp/rewards.py file.

def position_command_error_tanh(
    env: ManagerBasedRLEnv, std: float, command_name: str, asset_cfg: SceneEntityCfg
) -> torch.Tensor:
    """Reward tracking of the position using the tanh kernel.

    The function computes the position error between the desired position (from the command) and the
    current position of the asset's body (in world frame) and maps it with a tanh kernel.
    """
    # extract the asset (to enable type hinting)
    asset: RigidObject = env.scene[asset_cfg.name]
    command = env.command_manager.get_command(command_name)
    # obtain the desired and current positions
    des_pos_b = command[:, :3]
    des_pos_w, _ = combine_frame_transforms(asset.data.root_state_w[:, :3], asset.data.root_state_w[:, 3:7], des_pos_b)
    curr_pos_w = asset.data.body_state_w[:, asset_cfg.body_ids[0], :3]  # type: ignore
    distance = torch.norm(curr_pos_w - des_pos_w, dim=1)
    return 1 - torch.tanh(distance / std)

Completed rewards file#

The completed rewards.py file should now look like this:
# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
from typing import TYPE_CHECKING

from isaaclab.assets import RigidObject
from isaaclab.managers import SceneEntityCfg
from isaaclab.utils.math import combine_frame_transforms, quat_error_magnitude, quat_mul

if TYPE_CHECKING:
    from isaaclab.envs import ManagerBasedRLEnv


def position_command_error(env: ManagerBasedRLEnv, command_name: str, asset_cfg: SceneEntityCfg) -> torch.Tensor:
    """Penalize tracking of the position error using L2-norm.

    The function computes the position error between the desired position (from the command) and the
    current position of the asset's body (in world frame). The position error is computed as the L2-norm
    of the difference between the desired and current positions.
    """
    # extract the asset (to enable type hinting)
    asset: RigidObject = env.scene[asset_cfg.name]
    command = env.command_manager.get_command(command_name)
    # obtain the desired and current positions
    des_pos_b = command[:, :3]
    des_pos_w, _ = combine_frame_transforms(asset.data.root_state_w[:, :3], asset.data.root_state_w[:, 3:7], des_pos_b)
    curr_pos_w = asset.data.body_state_w[:, asset_cfg.body_ids[0], :3]  # type: ignore
    return torch.norm(curr_pos_w - des_pos_w, dim=1)


def position_command_error_tanh(
    env: ManagerBasedRLEnv, std: float, command_name: str, asset_cfg: SceneEntityCfg
) -> torch.Tensor:
    """Reward tracking of the position using the tanh kernel.

    The function computes the position error between the desired position (from the command) and the
    current position of the asset's body (in world frame) and maps it with a tanh kernel.
    """
    # extract the asset (to enable type hinting)
    asset: RigidObject = env.scene[asset_cfg.name]
    command = env.command_manager.get_command(command_name)
    # obtain the desired and current positions
    des_pos_b = command[:, :3]
    des_pos_w, _ = combine_frame_transforms(asset.data.root_state_w[:, :3], asset.data.root_state_w[:, 3:7], des_pos_b)
    curr_pos_w = asset.data.body_state_w[:, asset_cfg.body_ids[0], :3]  # type: ignore
    distance = torch.norm(curr_pos_w - des_pos_w, dim=1)
    return 1 - torch.tanh(distance / std)

Configuring Hyperparameters#

Settings and hyperparameters for PPO#

Lastly, let’s replace the contents of source/Reach/Reach/tasks/manager_based/reach/agents/skrl_ppo_cfg.yaml with the configuration below.

This is a configuration file that defines the settings and hyperparameters for training our RL agents using the Proximal Policy Optimization (PPO) algorithm with the SKRL library.

This file is essential for customizing how PPO operates within Isaac Lab. While we won’t go into detail on these configurations, links are provided in the file below for further learning.

Read about the intuition behind PPO.

seed: 42


# Models are instantiated using skrl's model instantiator utility
# https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html
models:
  separate: False
  policy:  # see gaussian_model parameters
    class: GaussianMixin
    clip_actions: False
    clip_log_std: True
    min_log_std: -20.0
    max_log_std: 2.0
    initial_log_std: 0.0
    network:
      - name: net
        input: STATES
        layers: [64, 64]
        activations: elu
    output: ACTIONS
  value:  # see deterministic_model parameters
    class: DeterministicMixin
    clip_actions: False
    network:
      - name: net
        input: STATES
        layers: [64, 64]
        activations: elu
    output: ONE


# Rollout memory
# https://skrl.readthedocs.io/en/latest/api/memories/random.html
memory:
  class: RandomMemory
  memory_size: -1  # automatically determined (same as agent:rollouts)


# PPO agent configuration (field names are from PPO_DEFAULT_CONFIG)
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html
agent:
  class: PPO
  rollouts: 24
  learning_epochs: 5
  mini_batches: 4
  discount_factor: 0.99
  lambda: 0.95
  learning_rate: 1.0e-03
  learning_rate_scheduler: KLAdaptiveLR
  learning_rate_scheduler_kwargs:
    kl_threshold: 0.01
  state_preprocessor: RunningStandardScaler
  state_preprocessor_kwargs: null
  value_preprocessor: RunningStandardScaler
  value_preprocessor_kwargs: null
  random_timesteps: 0
  learning_starts: 0
  grad_norm_clip: 1.0
  ratio_clip: 0.2
  value_clip: 0.2
  clip_predicted_values: True
  entropy_loss_scale: 0.01
  value_loss_scale: 1.0
  kl_threshold: 0.0
  rewards_shaper_scale: 1.0
  time_limit_bootstrap: False
  # logging and checkpoint
  experiment:
    directory: "reach_ur10"
    experiment_name: ""
    write_interval: auto
    checkpoint_interval: auto


# Sequential trainer
# https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html
trainer:
  class: SequentialTrainer
  timesteps: 24000
  environment_info: log
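
For a rough sense of how a few of these values interact, here is a small, hedged arithmetic sketch. The number of parallel environments is a placeholder chosen when you launch training, not something set in this file:

num_envs = 2048     # placeholder: chosen at launch time, not in this YAML
rollouts = 24       # steps collected per environment before each PPO update
mini_batches = 4    # chunks the collected batch is split into per learning epoch

transitions_per_update = rollouts * num_envs               # 49,152 transitions
mini_batch_size = transitions_per_update // mini_batches   # 12,288 transitions per mini-batch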