nemo_rl.models.generation.interfaces

Module Contents

Classes

GenerationConfig

Configuration for generation.

GenerationDatumSpec

Specification for input data required by generation models.

GenerationOutputSpec

Specification for output data returned by generation models.

GenerationInterface

Abstract base class defining the interface for RL policies.

Functions

verify_right_padding

Verify that a tensor is right-padded according to the provided lengths.

configure_generation_config

Apply specific configurations to generation config.

API

nemo_rl.models.generation.interfaces.verify_right_padding(
data: Union[nemo_rl.distributed.batched_data_dict.BatchedDataDict[GenerationDatumSpec], nemo_rl.distributed.batched_data_dict.BatchedDataDict[GenerationOutputSpec]],
pad_value: int = 0,
raise_error: bool = True,
) → Tuple[bool, Union[str, None]]

Verify that a tensor is right-padded according to the provided lengths.

Parameters:
  • data

    The BatchedDataDict to check, containing either:

    • For GenerationDatumSpec: input_ids and input_lengths

    • For GenerationOutputSpec: output_ids and unpadded_sequence_lengths

  • pad_value – The expected padding value (default: 0)

  • raise_error – Whether to raise an error if incorrect padding is detected (default: True)

Returns:

Tuple of (is_right_padded, error_message)

  • is_right_padded: True if right padding confirmed, False otherwise

  • error_message: None if properly padded, otherwise a description of the issue
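The check described above can be sketched in plain Python. This is a minimal illustration on nested lists, not the library's implementation: the helper name check_right_padding is hypothetical, and it assumes that every position at or beyond a row's stated length must equal pad_value.

```python
from typing import List, Optional, Tuple


def check_right_padding(
    ids: List[List[int]],
    lengths: List[int],
    pad_value: int = 0,
    raise_error: bool = True,
) -> Tuple[bool, Optional[str]]:
    """Hypothetical stand-in that mirrors the documented return contract."""
    for row_idx, (row, length) in enumerate(zip(ids, lengths)):
        # Everything at or after `length` is expected to be padding.
        tail = row[length:]
        if any(token != pad_value for token in tail):
            message = (
                f"Row {row_idx} is not right-padded: expected only "
                f"{pad_value} after position {length}, got {tail}"
            )
            if raise_error:
                raise ValueError(message)
            return False, message
    return True, None


right_padded = [[101, 2054, 2003, 0, 0], [101, 2054, 0, 0, 0]]
left_padded = [[0, 0, 101, 2054, 2003], [0, 0, 0, 101, 2054]]

ok_result = check_right_padding(right_padded, [3, 2])
bad_result = check_right_padding(left_padded, [3, 2], raise_error=False)
```

With raise_error=False the caller receives the (is_right_padded, error_message) tuple and can decide how to react. With the default, a padding problem surfaces immediately as an exception.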

class nemo_rl.models.generation.interfaces.GenerationConfig

Bases: typing.TypedDict

Configuration for generation.

backend: str

max_new_tokens: int

temperature: float

top_p: float

top_k: int

model_name: str

stop_token_ids: List[int]

pad_token_id: int
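Because GenerationConfig is a TypedDict, a config is an ordinary dict at runtime. The sketch below redefines a local mirror of the documented fields so that it runs without nemo_rl installed. Every value, including the backend and model names, is a placeholder chosen for illustration and not a recommended or verified setting.

```python
from typing import List, TypedDict


class GenerationConfigSketch(TypedDict):
    """Local mirror of the documented fields (hypothetical stand-in)."""

    backend: str
    max_new_tokens: int
    temperature: float
    top_p: float
    top_k: int
    model_name: str
    stop_token_ids: List[int]
    pad_token_id: int


# All values are placeholders, not defaults taken from the library.
example_config: GenerationConfigSketch = {
    "backend": "example-backend",
    "max_new_tokens": 128,
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": 50,
    "model_name": "example-org/example-model",
    "stop_token_ids": [2],
    "pad_token_id": 0,
}

# A TypedDict records its required keys, which helps catch missing fields.
missing = GenerationConfigSketch.__required_keys__ - set(example_config)
```

TypedDict annotations are not enforced at runtime, so a check like the last line (or a static type checker) is what actually catches an incomplete config.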

nemo_rl.models.generation.interfaces.configure_generation_config(
config: nemo_rl.models.generation.interfaces.GenerationConfig,
tokenizer: transformers.AutoTokenizer,
is_eval=False,
)

Apply specific configurations to generation config.

class nemo_rl.models.generation.interfaces.GenerationDatumSpec

Bases: typing.TypedDict

Specification for input data required by generation models.

  • input_ids: Tensor of token IDs representing the input sequences (right padded)

  • input_lengths: Tensor containing the actual length of each sequence (without padding)

  • stop_strings: Optional list of strings to stop generation (per sample)

  • __extra__: Additional model-specific data fields

Example of a batch with 4 entries with different sequence lengths:

# Batch of 4 sequences with lengths [3, 5, 2, 4]

input_ids (padded):
[
  [101, 2054, 2003,    0,    0],  # Length 3
  [101, 2054, 2003, 2001, 1996],  # Length 5
  [101, 2054,    0,    0,    0],  # Length 2
  [101, 2054, 2003, 2001,    0],  # Length 4
]

input_lengths:
[3, 5, 2, 4]

All functions receiving or returning GenerationDatumSpec should ensure right padding is maintained. Use verify_right_padding() to check.
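A batch like the one above can be produced from variable-length sequences as follows. This is a plain-Python sketch: the helper right_pad is hypothetical, and a real pipeline would hold these values in torch.Tensor objects, as the field types specify.

```python
from typing import List, Tuple


def right_pad(
    sequences: List[List[int]], pad_value: int = 0
) -> Tuple[List[List[int]], List[int]]:
    """Pad every sequence on the right to the longest length in the batch."""
    lengths = [len(seq) for seq in sequences]
    width = max(lengths)
    padded = [seq + [pad_value] * (width - len(seq)) for seq in sequences]
    return padded, lengths


input_ids, input_lengths = right_pad(
    [
        [101, 2054, 2003],
        [101, 2054, 2003, 2001, 1996],
        [101, 2054],
        [101, 2054, 2003, 2001],
    ]
)
```

The lengths are recorded before padding, so input_lengths always describes the real tokens in each row, whatever pad_value is.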

input_ids: torch.Tensor

input_lengths: torch.Tensor

stop_strings: Optional[List[str]]

__extra__: Any

class nemo_rl.models.generation.interfaces.GenerationOutputSpec

Bases: typing.TypedDict

Specification for output data returned by generation models.

  • output_ids: Tensor of token IDs representing the generated sequences (right padded)

  • generation_lengths: Tensor containing the actual length of each generated sequence

  • unpadded_sequence_lengths: Tensor containing the actual length of each input + generated sequence (without padding)

  • logprobs: Tensor of log probabilities for each generated token (right padded with zeros)

  • __extra__: Additional model-specific data fields

Example of a batch with 2 sequences:

# Sample batch with 2 examples
# - Example 1: Input length 3, generated response length 4
# - Example 2: Input length 5, generated response length 2

output_ids (right-padded):
[
  [101, 2054, 2003, 2023, 2003, 1037, 2200,    0],  # 7 valid tokens (3 input + 4 output)
  [101, 2054, 2003, 2001, 1996, 3014, 2005,    0],  # 7 valid tokens (5 input + 2 output)
]

generation_lengths:
[4, 2]  # Length of just the generated response part

unpadded_sequence_lengths:
[7, 7]  # Length of full valid sequence (input + generated response)

logprobs (right-padded with zeros):
[
  [0.0, 0.0, 0.0, -1.2, -0.8, -2.1, -1.5, 0.0],  # First 3 are 0 (input tokens), next 4 are actual logprobs
  [0.0, 0.0, 0.0, 0.0, 0.0, -0.9, -1.7, 0.0],     # First 5 are 0 (input tokens), next 2 are actual logprobs
]

All functions receiving or returning GenerationOutputSpec should ensure right padding is maintained. Use verify_right_padding() to check.
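The relationship between the fields in the example can be made concrete with a short plain-Python sketch. It assumes, as in the example, that each row holds the input tokens followed directly by the generated tokens, so the input length equals unpadded_sequence_lengths minus generation_lengths. The helper name split_generated is hypothetical.

```python
from typing import List, TypeVar

T = TypeVar("T")


def split_generated(
    rows: List[List[T]],
    generation_lengths: List[int],
    unpadded_sequence_lengths: List[int],
) -> List[List[T]]:
    """Return only the generated part of each row, without input or padding."""
    generated = []
    for row, gen_len, total_len in zip(
        rows, generation_lengths, unpadded_sequence_lengths
    ):
        input_len = total_len - gen_len  # tokens that came from the input
        generated.append(row[input_len:total_len])
    return generated


output_ids = [
    [101, 2054, 2003, 2023, 2003, 1037, 2200, 0],
    [101, 2054, 2003, 2001, 1996, 3014, 2005, 0],
]
logprobs = [
    [0.0, 0.0, 0.0, -1.2, -0.8, -2.1, -1.5, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, -0.9, -1.7, 0.0],
]

generated_ids = split_generated(output_ids, [4, 2], [7, 7])
generated_logprobs = split_generated(logprobs, [4, 2], [7, 7])
```

The same slice applies to output_ids and logprobs because both are aligned position by position and right-padded to the same width.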

output_ids: torch.Tensor

generation_lengths: torch.Tensor

unpadded_sequence_lengths: torch.Tensor

logprobs: torch.Tensor

__extra__: Any

class nemo_rl.models.generation.interfaces.GenerationInterface

Bases: abc.ABC

Abstract base class defining the interface for RL policies.

abstractmethod generate(
data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
greedy: bool,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationOutputSpec]
abstractmethod prepare_for_generation(*args, **kwargs)
abstractmethod finish_generation(*args, **kwargs)
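
A backend plugs in by subclassing the interface and implementing all three abstract methods. The sketch below defines a local stand-in ABC with the same method names so that it runs without nemo_rl. EchoGenerator, the plain-dict batch, and the fixed token it appends are all hypothetical and exist only to show the shape of an implementation.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class GenerationInterfaceSketch(ABC):
    """Local stand-in mirroring the documented abstract methods."""

    @abstractmethod
    def generate(self, data: Dict[str, Any], greedy: bool) -> Dict[str, Any]: ...

    @abstractmethod
    def prepare_for_generation(self, *args: Any, **kwargs: Any) -> Any: ...

    @abstractmethod
    def finish_generation(self, *args: Any, **kwargs: Any) -> Any: ...


class EchoGenerator(GenerationInterfaceSketch):
    """Toy backend: appends one fixed token (7) to every unpadded input."""

    def __init__(self) -> None:
        self.ready = False

    def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool:
        self.ready = True
        return True

    def generate(self, data: Dict[str, Any], greedy: bool) -> Dict[str, Any]:
        # Strip the padding, append the "generated" token, then re-pad on the right.
        rows = [
            ids[:length] + [7]
            for ids, length in zip(data["input_ids"], data["input_lengths"])
        ]
        width = max(len(row) for row in rows)
        return {
            "output_ids": [row + [0] * (width - len(row)) for row in rows],
            "generation_lengths": [1] * len(rows),
            "unpadded_sequence_lengths": [len(row) for row in rows],
            # logprobs are omitted here for brevity.
        }

    def finish_generation(self, *args: Any, **kwargs: Any) -> bool:
        self.ready = False
        return True


backend = EchoGenerator()
backend.prepare_for_generation()
batch = {"input_ids": [[101, 2054, 0], [101, 2054, 2003]], "input_lengths": [2, 3]}
result = backend.generate(batch, greedy=True)
backend.finish_generation()
```

Because all three methods are abstract, instantiating the base class directly, or a subclass that leaves one of them unimplemented, raises TypeError.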