nemo_rl.algorithms.logits_sampling_utils#
Module Contents#
Classes#
TrainingSamplingParams: Training-specific sampling parameters to match generation parameters.
_ApplyTopKTopP: Autograd function for top-k and top-p filtering with proper gradient handling.
Functions#
_need_top_k_filtering: Check if top-k filtering is needed.
_need_top_p_filtering: Check if top-p filtering is needed.
need_top_k_or_top_p_filtering: Check if top-k or top-p filtering is needed.
_apply_top_k_only_fn: Apply top-k mask to the logits.
_apply_top_k_top_p_fn: Apply top-k and top-p masks to the logits with chunking for memory efficiency.
apply_top_k_top_p: Apply top-k and top-p masks to the logits with proper gradient handling.
Data#
API#
- nemo_rl.algorithms.logits_sampling_utils.TOP_K_TOP_P_CHUNK_SIZE: int#
256
- class nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams#
Training-specific sampling parameters to match generation parameters.
Used to ensure consistency between training and inference by applying the same sampling strategy during logprob computation. This class does not use vLLM's SamplingParams directly, to avoid a dependency on vLLM in this environment.
.. attribute:: top_k
Top-k filtering parameter (None or -1 to disable)
.. attribute:: top_p
Top-p filtering parameter (1.0 to disable)
.. attribute:: temperature
Temperature for scaling logits (default: 1.0)
- top_k: int | None#
None
- top_p: float#
1.0
- temperature: float#
1.0
- nemo_rl.algorithms.logits_sampling_utils._need_top_k_filtering(top_k: int | None) bool#
Check if top-k filtering is needed.
- nemo_rl.algorithms.logits_sampling_utils._need_top_p_filtering(top_p: float | None) bool#
Check if top-p filtering is needed.
- nemo_rl.algorithms.logits_sampling_utils.need_top_k_or_top_p_filtering(
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams],
) bool#
Check if top-k or top-p filtering is needed.
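Given the documented disable values (None or -1 for top_k, 1.0 for top_p), the predicate logic can be sketched as follows; function names are shortened here for illustration.

```python
from types import SimpleNamespace


def need_top_k(top_k):
    # top-k is active only when a usable k was provided
    return top_k is not None and top_k != -1


def need_top_p(top_p):
    # top-p is active only when it actually truncates the distribution
    return top_p is not None and top_p < 1.0


def need_top_k_or_top_p(params):
    # no filtering is needed when no sampling params were provided
    if params is None:
        return False
    return need_top_k(params.top_k) or need_top_p(params.top_p)


# SimpleNamespace stands in for a TrainingSamplingParams instance.
active = need_top_k_or_top_p(SimpleNamespace(top_k=None, top_p=0.9))
```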
- nemo_rl.algorithms.logits_sampling_utils._apply_top_k_only_fn(
- logits: torch.Tensor,
- top_k: int | None,
)#
Apply top-k mask to the logits.
Simplified version of vLLM's implementation for scalar parameters. This implementation does not sort the entire vocabulary.
Based on VLLM’s implementation: https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py SPDX-License-Identifier: Apache-2.0 Copyright contributors to the vLLM project
- Parameters:
logits – Input logits tensor of shape [*, vocab_size].
top_k – Top-k sampling parameter.
- Returns:
filtered_logits: Filtered logits tensor with the same shape as the input logits.
keep_mask: Mask tensor with the same shape as the input logits, where True marks tokens to keep and False marks tokens to filter out. None if top-k filtering is not needed.
- Return type:
tuple of (filtered_logits, keep_mask)
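The core idea, that the k-th largest logit becomes the keep threshold, can be sketched in plain Python for a single token's logits. This is an illustrative stand-in: the real implementation operates on torch tensors and uses torch.topk rather than this list-based helper.

```python
import heapq


def top_k_mask(logits, k):
    """Return (filtered_logits, keep_mask) for one token's logits (sketch)."""
    if k is None or k == -1 or k >= len(logits):
        return list(logits), None  # nothing to filter
    # The k-th largest logit is the keep threshold; ties at the
    # threshold are kept, so more than k tokens may survive.
    threshold = heapq.nlargest(k, logits)[-1]
    keep_mask = [x >= threshold for x in logits]
    filtered = [x if keep else float("-inf") for x, keep in zip(logits, keep_mask)]
    return filtered, keep_mask


filtered, mask = top_k_mask([2.0, 0.5, 1.5, -1.0], k=2)
```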
- nemo_rl.algorithms.logits_sampling_utils._apply_top_k_top_p_fn(
- logits: torch.Tensor,
- top_k: int | None,
- top_p: float,
- chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE,
)#
Apply top-k and top-p masks to the logits with chunking for memory efficiency.
The sort operation in top-p filtering is memory intensive because it creates intermediate tensors of shape [num_tokens, vocab_size] for both sorted values and indices. For large vocab sizes (e.g., 152K) and many tokens, this can cause OOM. This function flattens the input to 2D and processes in chunks along the token dimension (controlled by chunk_size) to reduce peak memory.
Based on vLLM's implementation: https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py SPDX-License-Identifier: Apache-2.0 Copyright contributors to the vLLM project
- Parameters:
logits – Input logits tensor of shape [*, vocab_size] (e.g., [batch_size, seq_len, vocab_size] or [batch_size, vocab_size]). Internally flattened to [num_tokens, vocab_size] for processing.
top_k – Top-k sampling parameter. Set to -1 or None to consider all tokens.
top_p – Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens.
chunk_size – Number of tokens to process per chunk for memory efficiency. Defaults to TOP_K_TOP_P_CHUNK_SIZE.
- Returns:
filtered_logits: Filtered logits tensor with the same shape as the input logits.
keep_mask: Mask tensor with the same shape as the input logits, where True marks tokens to keep and False marks tokens to filter out.
- Return type:
tuple of (filtered_logits, keep_mask)
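The nucleus step and the chunking idea can be sketched in plain Python. This is an illustrative stand-in with hypothetical helper names: the real implementation sorts torch tensors per chunk, which is exactly the memory-intensive step that chunking bounds.

```python
import math


def top_p_mask(logits, top_p):
    """Keep the smallest set of tokens whose cumulative probability
    reaches top_p, for one token's logits (sketch)."""
    probs = [math.exp(x) for x in logits]
    total = sum(probs)
    # Visit tokens in descending probability order.
    order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    keep, cum = [False] * len(logits), 0.0
    for i in order:
        keep[i] = True              # token is inside the nucleus
        cum += probs[i] / total
        if cum >= top_p:            # nucleus reached the target mass
            break
    filtered = [x if k else float("-inf") for x, k in zip(logits, keep)]
    return filtered, keep


def top_p_mask_chunked(batch_logits, top_p, chunk_size=256):
    """Process rows in chunks so sort intermediates stay bounded (sketch)."""
    out = []
    for start in range(0, len(batch_logits), chunk_size):
        for row in batch_logits[start:start + chunk_size]:
            out.append(top_p_mask(row, top_p))
    return out
```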
- class nemo_rl.algorithms.logits_sampling_utils._ApplyTopKTopP#
Bases: torch.autograd.Function
Autograd function for top-k and top-p filtering with proper gradient handling.
- static forward(
- ctx,
- logits: torch.Tensor,
- top_k: Optional[int],
- top_p: float,
- chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE,
)#
Apply top-k/top-p filtering and save masks for backward.
- Parameters:
logits – Input logits tensor of shape [*, vocab_size].
top_k – Top-k sampling parameter. Set to -1 or None to consider all tokens.
top_p – Top-p sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens.
chunk_size – Number of tokens to process per chunk. Defaults to TOP_K_TOP_P_CHUNK_SIZE.
- static backward(ctx, *grad_outputs: torch.Tensor)#
Backward pass: mask out gradients for filtered tokens.
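The pattern of saving the keep mask in forward and zeroing filtered gradients in backward can be sketched with a minimal torch.autograd.Function. The class name is hypothetical, and the actual _ApplyTopKTopP also computes the mask itself (with chunking) rather than receiving it as an input.

```python
import torch


class MaskedFilterSketch(torch.autograd.Function):
    """Forward masks filtered logits to -inf; backward zeroes their grads."""

    @staticmethod
    def forward(ctx, logits, keep_mask):
        ctx.save_for_backward(keep_mask)
        return logits.masked_fill(~keep_mask, float("-inf"))

    @staticmethod
    def backward(ctx, grad_output):
        (keep_mask,) = ctx.saved_tensors
        # Gradients flow only through tokens that survived filtering;
        # the mask input itself gets no gradient.
        return grad_output * keep_mask, None


logits = torch.tensor([2.0, 0.5, 1.5], requires_grad=True)
keep = torch.tensor([True, False, True])
out = MaskedFilterSketch.apply(logits, keep)
```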
- nemo_rl.algorithms.logits_sampling_utils.apply_top_k_top_p(
- logits: torch.Tensor,
- top_k: int | None,
- top_p: float,
- chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE,
)#
Apply top-k and top-p masks to the logits with proper gradient handling.
Simplified version of vLLM's implementation for scalar parameters.
When top_p < 1.0, sorting is required, which is memory intensive for large vocab sizes. Processing is done in chunks (controlled by chunk_size) to reduce peak memory.
Based on vLLM's implementation: https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py SPDX-License-Identifier: Apache-2.0 Copyright contributors to the vLLM project
- Parameters:
logits – Input logits tensor of shape [*, vocab_size].
top_k – Top-k sampling parameter. Set to -1 or None to consider all tokens.
top_p – Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens.
chunk_size – Number of tokens to process per chunk. Defaults to TOP_K_TOP_P_CHUNK_SIZE.
- Returns:
filtered_logits: Filtered logits tensor with the same shape as the input logits.
keep_mask: Mask tensor with the same shape as the input logits, where True marks tokens to keep and False marks tokens to filter out.
- Return type:
tuple of (filtered_logits, keep_mask)
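Putting the pieces together, here is an end-to-end plain-Python sketch for one token's logits, assuming temperature scaling precedes filtering (as in typical samplers) and that top-k runs before top-p. The function name and ordering are illustrative, not the module's actual implementation.

```python
import heapq
import math


def filter_logits(logits, top_k=None, top_p=1.0, temperature=1.0):
    """Apply temperature, then top-k, then top-p to one token's logits (sketch)."""
    logits = [x / temperature for x in logits]
    # Top-k: keep the k largest logits (ties at the threshold survive).
    if top_k is not None and top_k != -1 and top_k < len(logits):
        thresh = heapq.nlargest(top_k, logits)[-1]
        logits = [x if x >= thresh else float("-inf") for x in logits]
    # Top-p: keep the smallest nucleus reaching probability mass top_p.
    if top_p < 1.0:
        exps = [math.exp(x) for x in logits]  # exp(-inf) == 0.0
        total = sum(exps)
        order = sorted(range(len(logits)), key=lambda i: exps[i], reverse=True)
        keep, cum = [False] * len(logits), 0.0
        for i in order:
            keep[i] = True
            cum += exps[i] / total
            if cum >= top_p:
                break
        logits = [x if k else float("-inf") for x, k in zip(logits, keep)]
    return logits
```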