nemo_rl.algorithms.logits_sampling_utils#

Module Contents#

Classes#

TrainingSamplingParams

Training-specific sampling parameters to match generation parameters.

_ApplyTopKTopP

Autograd function for top-k and top-p filtering with proper gradient handling.

Functions#

_need_top_k_filtering

Check if top-k filtering is needed.

_need_top_p_filtering

Check if top-p filtering is needed.

need_top_k_or_top_p_filtering

Check if top-k or top-p filtering is needed.

_apply_top_k_only_fn

Apply top-k mask to the logits.

_apply_top_k_top_p_fn

Apply top-k and top-p masks to the logits with chunking for memory efficiency.

apply_top_k_top_p

Apply top-k and top-p masks to the logits with proper gradient handling.

Data#

API#

nemo_rl.algorithms.logits_sampling_utils.TOP_K_TOP_P_CHUNK_SIZE: int#

256

class nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams#

Training-specific sampling parameters to match generation parameters.

Used to ensure consistency between training and inference by applying the same sampling strategy during logprob computation. This class deliberately does not reuse vLLM's SamplingParams, so that the training environment does not depend on vLLM.

.. attribute:: top_k

Top-k filtering parameter (None or -1 to disable)

.. attribute:: top_p

Top-p filtering parameter (1.0 to disable)

.. attribute:: temperature

Temperature for scaling logits (default: 1.0)

top_k: int | None#

None

top_p: float#

1.0

temperature: float#

1.0

nemo_rl.algorithms.logits_sampling_utils._need_top_k_filtering(top_k: int | None) bool#

Check if top-k filtering is needed.

nemo_rl.algorithms.logits_sampling_utils._need_top_p_filtering(top_p: float | None) bool#

Check if top-p filtering is needed.

nemo_rl.algorithms.logits_sampling_utils.need_top_k_or_top_p_filtering(
sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams],
) bool#

Check if top-k or top-p filtering is needed.
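The three predicates above share the module's disable conventions: a top_k of None or -1 disables top-k, and a top_p of 1.0 disables top-p. A minimal sketch of the combined check, where SamplingParamsSketch is a hypothetical stand-in for TrainingSamplingParams (not the module's actual class):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingParamsSketch:
    # Same disable conventions as TrainingSamplingParams:
    # top_k of None or -1 disables top-k; top_p of 1.0 disables top-p.
    top_k: Optional[int] = None
    top_p: float = 1.0
    temperature: float = 1.0

    def filtering_needed(self) -> bool:
        need_top_k = self.top_k is not None and self.top_k != -1
        need_top_p = self.top_p < 1.0
        return need_top_k or need_top_p
```

Note that the real need_top_k_or_top_p_filtering also accepts sampling_params=None, in which case no filtering is needed.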

nemo_rl.algorithms.logits_sampling_utils._apply_top_k_only_fn(
logits: torch.Tensor,
top_k: int | None,
) tuple[torch.Tensor, torch.Tensor | None]#

Apply top-k mask to the logits.

A simplified version of vLLM's implementation, specialized for scalar parameters. It does not sort the entire vocabulary.

Based on vLLM's implementation: https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py (SPDX-License-Identifier: Apache-2.0; Copyright contributors to the vLLM project).

Parameters:
  • logits – Input logits tensor of shape [*, vocab_size].

  • top_k – Top-k sampling parameter.

Returns:

  • filtered_logits – Filtered logits tensor with the same shape as the input logits.

  • keep_mask – Mask tensor with the same shape as the input logits, where 1 (True) marks tokens to keep and 0 (False) marks tokens to mask out. None if top-k filtering is not needed.

Return type:

tuple[torch.Tensor, torch.Tensor | None]
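The sort-free approach can be sketched as below. top_k_mask_sketch is a hypothetical illustration, not the module's function: it uses torch.topk to find the per-row threshold instead of sorting the full vocab, and on ties at the threshold it may keep more than k tokens (a simplification):

```python
import torch


def top_k_mask_sketch(logits: torch.Tensor, top_k: int):
    # k-th largest logit per row; torch.topk avoids a full-vocab sort.
    kth_values = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    keep_mask = logits >= kth_values  # True for tokens that survive top-k
    filtered = logits.masked_fill(~keep_mask, float("-inf"))
    return filtered, keep_mask
```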

nemo_rl.algorithms.logits_sampling_utils._apply_top_k_top_p_fn(
logits: torch.Tensor,
top_k: int | None,
top_p: float,
chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE,
) tuple[torch.Tensor, torch.Tensor | None]#

Apply top-k and top-p masks to the logits with chunking for memory efficiency.

The sort operation in top-p filtering is memory intensive because it creates intermediate tensors of shape [num_tokens, vocab_size] for both sorted values and indices. For large vocab sizes (e.g., 152K) and many tokens, this can cause OOM. This function flattens the input to 2D and processes in chunks along the token dimension (controlled by chunk_size) to reduce peak memory.

Based on vLLM's implementation: https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py (SPDX-License-Identifier: Apache-2.0; Copyright contributors to the vLLM project).

Parameters:
  • logits – Input logits tensor of shape [*, vocab_size] (e.g., [batch_size, seq_len, vocab_size] or [batch_size, vocab_size]). Internally flattened to [num_tokens, vocab_size] for processing.

  • top_k – Top-k sampling parameter. Set to -1 or None to consider all tokens.

  • top_p – Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens.

  • chunk_size – Number of tokens to process per chunk for memory efficiency. Defaults to TOP_K_TOP_P_CHUNK_SIZE.

Returns:

  • filtered_logits – Filtered logits tensor with the same shape as the input logits.

  • keep_mask – Mask tensor with the same shape as the input logits, where 1 (True) marks tokens to keep and 0 (False) marks tokens to mask out.

Return type:

tuple[torch.Tensor, torch.Tensor | None]
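The chunked nucleus filtering described above can be sketched as follows. top_p_mask_chunked_sketch is a hypothetical simplification (top-p only, no keep_mask) of the same idea: sort each chunk, drop the low-probability tail whose cumulative mass stays at or below 1 - top_p, and scatter back to the original token order:

```python
import torch


def top_p_mask_chunked_sketch(logits: torch.Tensor, top_p: float,
                              chunk_size: int = 256) -> torch.Tensor:
    # Flatten leading dims to [num_tokens, vocab_size], as the real function does.
    orig_shape = logits.shape
    flat = logits.reshape(-1, orig_shape[-1])
    out = torch.empty_like(flat)
    for start in range(0, flat.shape[0], chunk_size):
        chunk = flat[start:start + chunk_size]
        # Ascending sort: the smallest logits come first.
        sorted_logits, sorted_idx = chunk.sort(dim=-1, descending=False)
        cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Drop tokens whose cumulative mass (from the smallest up) stays at or
        # below 1 - top_p; the survivors form the nucleus.
        drop = cum_probs <= 1.0 - top_p
        drop[..., -1] = False  # always keep the most likely token
        sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
        # Scatter back to the original (unsorted) token order.
        out[start:start + chunk_size] = torch.empty_like(chunk).scatter(
            -1, sorted_idx, sorted_logits)
    return out.reshape(orig_shape)
```

Only one chunk's sorted values and indices are alive at a time, which is what bounds the peak memory of the sort.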

class nemo_rl.algorithms.logits_sampling_utils._ApplyTopKTopP#

Bases: torch.autograd.Function

Autograd function for top-k and top-p filtering with proper gradient handling.

static forward(
ctx,
logits: torch.Tensor,
top_k: Optional[int],
top_p: float,
chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE,
) tuple[torch.Tensor, torch.Tensor | None]#

Apply top-k/top-p filtering and save masks for backward.

Parameters:
  • logits – Input logits tensor of shape [*, vocab_size].

  • top_k – Top-k sampling parameter. Set to -1 or None to consider all tokens.

  • top_p – Top-p sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens.

  • chunk_size – Number of tokens to process per chunk. Defaults to TOP_K_TOP_P_CHUNK_SIZE.

static backward(ctx, *grad_outputs: torch.Tensor)#

Backward pass: mask out gradients for filtered tokens.

nemo_rl.algorithms.logits_sampling_utils.apply_top_k_top_p(
logits: torch.Tensor,
top_k: int | None,
top_p: float,
chunk_size: int | None = TOP_K_TOP_P_CHUNK_SIZE,
) tuple[torch.Tensor, torch.Tensor | None]#

Apply top-k and top-p masks to the logits with proper gradient handling.

A simplified version of vLLM's implementation, specialized for scalar parameters.

When top_p < 1.0, a sort over the vocabulary is required, which is memory-intensive for large vocab sizes. Processing is therefore done in chunks (controlled by chunk_size) to reduce peak memory.

Based on vLLM's implementation: https://github.com/vllm-project/vllm/blob/34a20c49b3f81f64133428b3a0d62309db1256f9/vllm/v1/sample/ops/topk_topp_sampler.py (SPDX-License-Identifier: Apache-2.0; Copyright contributors to the vLLM project).

Parameters:
  • logits – Input logits tensor of shape [*, vocab_size].

  • top_k – Top-k sampling parameter. Set to -1 or None to consider all tokens.

  • top_p – Top-p (nucleus) sampling parameter. Must be in (0, 1]. Set to 1 to consider all tokens.

  • chunk_size – Number of tokens to process per chunk. Defaults to TOP_K_TOP_P_CHUNK_SIZE.

Returns:

  • filtered_logits – Filtered logits tensor with the same shape as the input logits.

  • keep_mask – Mask tensor with the same shape as the input logits, where 1 (True) marks tokens to keep and 0 (False) marks tokens to mask out.

Return type:

tuple[torch.Tensor, torch.Tensor | None]
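In training, this function is intended to make logprob computation match the filtered distribution the inference engine sampled from. A hypothetical, self-contained sketch of that flow (filtered_logprobs_sketch is illustrative, not the module's API; only top-k is shown, with top-p handled as in _apply_top_k_top_p_fn):

```python
from typing import Optional

import torch


def filtered_logprobs_sketch(logits: torch.Tensor, token_ids: torch.Tensor,
                             top_k: Optional[int] = None,
                             temperature: float = 1.0) -> torch.Tensor:
    # 1. Match generation-time temperature scaling.
    scaled = logits / temperature
    # 2. Apply the same top-k filter used at sampling time.
    if top_k is not None and top_k != -1:
        kth = torch.topk(scaled, top_k, dim=-1).values[..., -1, None]
        scaled = scaled.masked_fill(scaled < kth, float("-inf"))
    # 3. Logprobs are now taken over the filtered distribution only.
    logprobs = scaled.log_softmax(dim=-1)
    return logprobs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)
```

Because masked logits are -inf, log_softmax renormalizes over the surviving tokens, so the training-time logprobs agree with the truncated sampling distribution.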