nemo_automodel.components.models.common.inbatch_neg_utils


Distributed in-batch negative utilities for bi-encoder contrastive training.

Architecture-agnostic helpers used by the bi-encoder trainer to expand the negative pool with passages gathered across DP ranks. Backbones (Llama, Ministral3, Qwen3, …) do not import these directly; the trainer wires them in around BiEncoderModel.encode.
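As a rough illustration of where these helpers sit, the sketch below computes one rank's contrastive loss after passages have been gathered from every DP rank. It is not the trainer's code. The function name inbatch_negative_loss, the temperature argument, and the per-query grouping of passages as (positive, negative_1, ..., negative_{n-1}) are assumptions made for this sketch.

```python
import torch
import torch.nn.functional as F


def inbatch_negative_loss(q_local, p_global, rank, train_n_passages, temperature=1.0):
    """Hypothetical contrastive loss over passages gathered from all DP ranks.

    q_local:  [local_batch_size, d] query embeddings owned by this rank.
    p_global: [B_global * train_n_passages, d] passages gathered rank by rank,
              grouped per query as (positive, negatives...) -- an assumption.
    """
    local_batch_size = q_local.size(0)
    # Score every local query against every passage in the global batch.
    scores = q_local @ p_global.T / temperature  # [local_bs, B_global * n]
    # Global index of each local query, then the column holding its positive.
    global_rows = rank * local_batch_size + torch.arange(local_batch_size, device=q_local.device)
    labels = global_rows * train_n_passages
    # All other columns (other queries' passages too) act as negatives.
    return F.cross_entropy(scores, labels)
```

Every passage from every rank ends up in the softmax denominator, which is what "expanding the negative pool" means here.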

Module Contents

Functions

Name | Description
--- | ---
_all_gather_tensor | All-gather t along dim 0, preserving autograd only when needed.
dist_gather_tensor | All-gather t along dim 0 across the default process group.
dist_gather_tensor_with_dim1_padding | All-gather t after padding dim 1 to the maximum length across ranks.
mask_gathered_passages_same_doc_as_positive | In-place mask passages sharing a doc id with this row’s positive.

API

nemo_automodel.components.models.common.inbatch_neg_utils._all_gather_tensor(
t: torch.Tensor,
preserve_grad: bool = False
) -> torch.Tensor

All-gather t along dim 0, preserving autograd only when needed.

nemo_automodel.components.models.common.inbatch_neg_utils.dist_gather_tensor(
t: typing.Optional[torch.Tensor],
preserve_grad: bool = False
) -> typing.Optional[torch.Tensor]

All-gather t along dim 0 across the default process group.

When preserve_grad is True, tensors that require gradients use an autograd-aware gather, so distributed in-batch-negative losses can send passage gradients back to the owning rank. Otherwise, remote slices are detached and only the local slice keeps gradient flow. Tensors that do not require gradients, such as masks or IDs, always use a regular detached gather. Returns t unchanged when torch.distributed is not available or not initialized, or when the world size is 1.

nemo_automodel.components.models.common.inbatch_neg_utils.dist_gather_tensor_with_dim1_padding(
t: typing.Optional[torch.Tensor],
padding_value: int | float | bool = 0,
preserve_grad: bool = False
) -> typing.Optional[torch.Tensor]

All-gather t after padding dim 1 to the maximum length across ranks.
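dist.all_gather expects every rank to contribute a tensor of the same shape, so tensors whose dim 1 differs across ranks (for example, sequences of different lengths) must be padded first. The sketch below emulates the gathered result in a single process. Right-padding and the helper names pad_dim1 and simulate_padded_gather are assumptions. A real distributed version would first agree on the maximum length across ranks, for example with dist.all_reduce(..., op=dist.ReduceOp.MAX).

```python
import torch


def pad_dim1(t, target_len, padding_value=0):
    """Right-pad `t` along dim 1 to `target_len` (padding side is an assumption)."""
    pad_len = target_len - t.size(1)
    if pad_len <= 0:
        return t
    shape = list(t.shape)
    shape[1] = pad_len
    # new_full keeps the dtype and device of `t` (works for int, float, bool).
    return torch.cat([t, t.new_full(shape, padding_value)], dim=1)


def simulate_padded_gather(per_rank, padding_value=0):
    """Hypothetical single-process stand-in for the padded all-gather result."""
    # In a distributed run, this max would come from an all-reduce with MAX.
    max_len = max(t.size(1) for t in per_rank)
    return torch.cat([pad_dim1(t, max_len, padding_value) for t in per_rank], dim=0)
```

For example, gathering [[1, 2, 3]] from rank 0 and [[4, 5], [6, 7]] from rank 1 yields [[1, 2, 3], [4, 5, 0], [6, 7, 0]].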

nemo_automodel.components.models.common.inbatch_neg_utils.mask_gathered_passages_same_doc_as_positive(
scores: torch.Tensor,
passage_doc_ids: torch.Tensor,
train_n_passages: int,
rank: int,
local_batch_size: int
) -> None

In-place mask passages sharing a doc id with this row’s positive.

After all-gather, passages are laid out rank by rank, so the query at local row r has global index i = rank * local_batch_size + r and its positive sits at column i * train_n_passages of the gathered passage tensor. For each local query row, scores is set to torch.finfo(scores.dtype).min on every other column whose passage_doc_ids entry matches the positive’s id, so duplicates of the positive elsewhere in the global batch are not treated as negatives. The true positive column is left intact.
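The masking rule above can be written with a few vectorized tensor operations. This is a minimal sketch assuming the column layout stated in the description (the query with global index i has its positive at column i * train_n_passages, where i = rank * local_batch_size + local row). The name mask_same_doc_as_positive is hypothetical; it is not this module's function.

```python
import torch


def mask_same_doc_as_positive(scores, passage_doc_ids, train_n_passages, rank, local_batch_size):
    """Hypothetical in-place mask of duplicate-positive columns."""
    fill = torch.finfo(scores.dtype).min
    rows = torch.arange(local_batch_size, device=scores.device)
    # Column of each local query's positive in the gathered passage tensor.
    pos_cols = (rank * local_batch_size + rows) * train_n_passages
    pos_ids = passage_doc_ids[pos_cols]  # [local_batch_size]
    # True wherever a passage shares the row's positive doc id.
    same = passage_doc_ids.unsqueeze(0) == pos_ids.unsqueeze(1)
    same[rows, pos_cols] = False  # never mask the true positive itself
    scores.masked_fill_(same, fill)
```

Filling with torch.finfo(dtype).min instead of -inf keeps every score finite, which plausibly avoids NaNs from inf arithmetic in a downstream softmax.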

Parameters:

scores
torch.Tensor

Score matrix of shape [local_batch_size, B_global * train_n_passages], already sliced to the local rank’s query rows and modified in place. B_global is the number of queries across all ranks.

passage_doc_ids
torch.Tensor

Doc ids of shape [B_global * train_n_passages] (int64), one per gathered passage.

train_n_passages
int

Number of passages per query (1 positive + negatives).

rank
int

Caller’s DP rank.

local_batch_size
int

Number of queries per rank.