nemo_rl.models.generation.interfaces
Module Contents
Classes
GenerationConfig | Configuration for generation.
GenerationDatumSpec | Specification for input data required by generation models.
GenerationOutputSpec | Specification for output data returned by generation models.
GenerationInterface | Abstract base class defining the interface for RL policies.
Functions
verify_right_padding | Verify that a tensor is right-padded according to the provided lengths.
configure_generation_config | Apply specific configurations to generation config.
API
- nemo_rl.models.generation.interfaces.verify_right_padding(
    data: Union[nemo_rl.distributed.batched_data_dict.BatchedDataDict[GenerationDatumSpec], nemo_rl.distributed.batched_data_dict.BatchedDataDict[GenerationOutputSpec]],
    pad_value: int = 0,
    raise_error: bool = True,
)
Verify that a tensor is right-padded according to the provided lengths.
- Parameters:
data – The BatchedDataDict to check. It must contain either:
  For GenerationDatumSpec: input_ids and input_lengths
  For GenerationOutputSpec: output_ids and unpadded_sequence_lengths
pad_value – The expected padding value (default: 0)
raise_error – Whether to raise an error if incorrect padding is detected (default: True)
- Returns:
Tuple of (is_right_padded, error_message), where:
  is_right_padded: True if right padding is confirmed, False otherwise
  error_message: None if the data is properly padded, otherwise a description of the issue
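The padding rule can be sketched in plain Python. The helper below, check_right_padding, is hypothetical and is not the library's implementation: nested lists stand in for torch.Tensor, and raising ValueError is an assumption, since the entry above only says that an error is raised. It returns the same (is_right_padded, error_message) pair described above and inspects only the positions at or after each row's length, because pad_value can also be a legitimate token ID.

```python
from typing import Optional, Sequence, Tuple


def check_right_padding(
    ids: Sequence[Sequence[int]],
    lengths: Sequence[int],
    pad_value: int = 0,
    raise_error: bool = True,
) -> Tuple[bool, Optional[str]]:
    """Hypothetical sketch: return (is_right_padded, error_message) for a batch."""

    def fail(msg: str) -> Tuple[bool, Optional[str]]:
        # Assumption: ValueError is used when raise_error is True.
        if raise_error:
            raise ValueError(msg)
        return False, msg

    if len(ids) != len(lengths):
        return fail(f"batch size mismatch: {len(ids)} rows vs {len(lengths)} lengths")

    for row_idx, (row, length) in enumerate(zip(ids, lengths)):
        if not 0 <= length <= len(row):
            return fail(f"row {row_idx}: length {length} outside [0, {len(row)}]")
        # Every position from `length` to the end of the row must hold pad_value.
        bad = [pos for pos in range(length, len(row)) if row[pos] != pad_value]
        if bad:
            return fail(f"row {row_idx}: expected pad_value {pad_value} at positions {bad}")

    return True, None


input_ids = [[101, 2054, 2003, 0, 0], [101, 2054, 0, 0, 0]]
print(check_right_padding(input_ids, [3, 2]))  # (True, None)
ok, msg = check_right_padding(input_ids, [2, 2], raise_error=False)
print(ok, msg)  # False row 0: expected pad_value 0 at positions [2]
```

In the failing call, the stated length of row 0 is 2, so the token 2003 at position 2 falls in the padding region and is reported.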
- class nemo_rl.models.generation.interfaces.GenerationConfig
Bases: typing.TypedDict
Configuration for generation.
- backend: str
- max_new_tokens: int
- temperature: float
- top_p: float
- top_k: int
- model_name: str
- stop_token_ids: List[int]
- pad_token_id: int
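Because GenerationConfig subclasses typing.TypedDict, a config is an ordinary dict at runtime, and its annotations are enforced only by static type checkers. The sketch below declares a local mirror, GenerationConfigSketch (a hypothetical name, used so the example runs without nemo_rl installed), and fills it with illustrative values; the backend, model_name, and stop_token_ids values are placeholders, not documented defaults.

```python
from typing import List, TypedDict


class GenerationConfigSketch(TypedDict):
    """Local mirror of the documented GenerationConfig fields (illustration only)."""

    backend: str
    max_new_tokens: int
    temperature: float
    top_p: float
    top_k: int
    model_name: str
    stop_token_ids: List[int]
    pad_token_id: int


# Illustrative values only; backend, model_name and stop_token_ids are placeholders.
config: GenerationConfigSketch = {
    "backend": "vllm",
    "max_new_tokens": 512,
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": 50,
    "model_name": "my-org/my-model",
    "stop_token_ids": [2],
    "pad_token_id": 0,
}

# A TypedDict instance is a plain dict at runtime; no field validation happens here.
print(type(config) is dict)  # True
print(sorted(GenerationConfigSketch.__annotations__))
```

A static type checker such as mypy would flag a missing key or a value of the wrong type in this literal; the interpreter itself would not.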
- nemo_rl.models.generation.interfaces.configure_generation_config(
    config: nemo_rl.models.generation.interfaces.GenerationConfig,
    tokenizer: transformers.AutoTokenizer,
    is_eval=False,
)
Apply specific configurations to the generation config, based on the provided tokenizer and the is_eval flag.
- class nemo_rl.models.generation.interfaces.GenerationDatumSpec
Bases: typing.TypedDict
Specification for input data required by generation models.
input_ids: Tensor of token IDs representing the input sequences (right-padded)
input_lengths: Tensor containing the actual length of each sequence (without padding)
stop_strings: Optional list of strings to stop generation (per sample)
extra: Additional model-specific data fields
Example of a batch with 4 entries with different sequence lengths:
    # Batch of 4 sequences with lengths [3, 5, 2, 4]
    input_ids (padded):
    [
        [101, 2054, 2003,    0,    0],  # Length 3
        [101, 2054, 2003, 2001, 1996],  # Length 5
        [101, 2054,    0,    0,    0],  # Length 2
        [101, 2054, 2003, 2001,    0],  # Length 4
    ]
    input_lengths: [3, 5, 2, 4]
All functions receiving or returning GenerationDatumSpec should ensure right padding is maintained. Use verify_right_padding() to check.
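To produce data in this layout from variable-length prompts, each sequence can be right-padded to the length of the longest one. The following is a minimal sketch using a hypothetical right_pad helper over plain lists (it is not part of nemo_rl); it rebuilds the four-sequence example above.

```python
from typing import List, Sequence, Tuple


def right_pad(
    sequences: Sequence[Sequence[int]], pad_value: int = 0
) -> Tuple[List[List[int]], List[int]]:
    """Right-pad variable-length token sequences to a common length.

    Returns (input_ids, input_lengths): real tokens first, pad_value after them.
    """
    input_lengths = [len(seq) for seq in sequences]
    max_len = max(input_lengths, default=0)
    input_ids = [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]
    return input_ids, input_lengths


batch = [
    [101, 2054, 2003],
    [101, 2054, 2003, 2001, 1996],
    [101, 2054],
    [101, 2054, 2003, 2001],
]
input_ids, input_lengths = right_pad(batch)
print(input_lengths)  # [3, 5, 2, 4]
```

With tensors, torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=pad_value) yields the same right-padded layout, and input_lengths can be taken from each tensor's length before padding.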
- input_ids: torch.Tensor
- input_lengths: torch.Tensor
- stop_strings: Optional[List[str]]
- __extra__: Any
- class nemo_rl.models.generation.interfaces.GenerationOutputSpec
Bases: typing.TypedDict
Specification for output data returned by generation models.
output_ids: Tensor of token IDs for each full sequence, input tokens followed by generated tokens (right-padded)
generation_lengths: Tensor containing the length of the generated part of each sequence (excluding input tokens and padding)
unpadded_sequence_lengths: Tensor containing the actual length of each input + generated sequence (without padding)
logprobs: Tensor of log probabilities aligned with output_ids: actual log probabilities at generated-token positions, 0.0 at input-token positions, and right-padded with zeros
extra: Additional model-specific data fields
Example of a batch with 2 sequences:
    # Sample batch with 2 examples
    # - Example 1: Input length 3, generated response length 4
    # - Example 2: Input length 5, generated response length 2
    output_ids (right-padded):
    [
        [101, 2054, 2003, 2023, 2003, 1037, 2200, 0],  # 7 valid tokens (3 input + 4 output)
        [101, 2054, 2003, 2001, 1996, 3014, 2005, 0],  # 7 valid tokens (5 input + 2 output)
    ]
    generation_lengths: [4, 2]  # Length of just the generated response part
    unpadded_sequence_lengths: [7, 7]  # Length of full valid sequence (input + generated response)
    logprobs (right-padded with zeros):
    [
        [0.0, 0.0, 0.0, -1.2, -0.8, -2.1, -1.5, 0.0],  # First 3 are 0 (input tokens), next 4 are actual logprobs
        [0.0, 0.0, 0.0, 0.0, 0.0, -0.9, -1.7, 0.0],    # First 5 are 0 (input tokens), next 2 are actual logprobs
    ]
All functions receiving or returning GenerationOutputSpec should ensure right padding is maintained. Use verify_right_padding() to check.
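Because logprobs is aligned with output_ids, downstream code may need to know which positions hold generated tokens. Since unpadded_sequence_lengths counts input plus generated tokens, the generated span of row i is [unpadded_sequence_lengths[i] - generation_lengths[i], unpadded_sequence_lengths[i]). The sketch below uses a hypothetical generated_token_mask helper over plain lists (not a nemo_rl function) to build that mask for the two-sequence example above and to sum each response's per-token log probabilities.

```python
from typing import List, Sequence


def generated_token_mask(
    generation_lengths: Sequence[int],
    unpadded_sequence_lengths: Sequence[int],
    seq_len: int,
) -> List[List[bool]]:
    """Return a per-position mask that is True only for generated tokens.

    Each row's generated span ends at its unpadded length and starts
    generation_lengths[i] positions earlier; input tokens precede it and
    padding follows it.
    """
    return [
        [end - gen <= pos < end for pos in range(seq_len)]
        for gen, end in zip(generation_lengths, unpadded_sequence_lengths)
    ]


# Values from the two-sequence example above.
logprobs = [
    [0.0, 0.0, 0.0, -1.2, -0.8, -2.1, -1.5, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, -0.9, -1.7, 0.0],
]
mask = generated_token_mask([4, 2], [7, 7], seq_len=8)

# Total log probability of each generated response (input and padding excluded).
response_logprobs = [
    sum(lp for lp, keep in zip(row, row_mask) if keep)
    for row, row_mask in zip(logprobs, mask)
]
print([round(x, 4) for x in response_logprobs])  # [-5.6, -2.6]
```

Deriving the span from the two output fields means the mask can be built from a GenerationOutputSpec alone, without carrying the original input_lengths along.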
- output_ids: torch.Tensor
- generation_lengths: torch.Tensor
- unpadded_sequence_lengths: torch.Tensor
- logprobs: torch.Tensor
- __extra__: Any
- class nemo_rl.models.generation.interfaces.GenerationInterface
Bases: abc.ABC
Abstract base class defining the interface for RL policies.