nemo_automodel.components.models.common.inbatch_neg_utils

Distributed in-batch negative utilities for bi-encoder contrastive training.

Architecture-agnostic helpers that the bi-encoder trainer uses to expand the negative pool with passages gathered across data-parallel (DP) ranks. Backbones (Llama, Ministral3, Qwen3, …) do not import these helpers directly; the trainer wires them in around BiEncoderModel.encode.

Module Contents

Functions

dist_gather_tensor

All-gather t along dim 0 across the default process group.

mask_gathered_passages_same_doc_as_positive

In-place mask passages sharing a doc id with this row’s positive.

API

nemo_automodel.components.models.common.inbatch_neg_utils.dist_gather_tensor(
t: Optional[torch.Tensor],
) → Optional[torch.Tensor]

All-gather t along dim 0 across the default process group.

The local rank’s slice of the gathered result is replaced with the original t, so gradients flow back only through the local portion of the gathered tensor (other ranks’ slices are detached). Returns t unchanged when torch.distributed is not available or not initialized, or when the world size is 1.
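The gradient behaviour described above can be sketched roughly as follows. This is an illustrative re-implementation, not the library's source: the name gather_with_local_grad is hypothetical, returning None for a None input is an assumption based on the Optional signature, and using torch.distributed.all_gather for the collective is an assumption about how the gather is performed.

```python
from typing import Optional

import torch
import torch.distributed as dist


def gather_with_local_grad(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Hypothetical sketch of an all-gather that keeps gradients for the local slice."""
    if t is None:
        # Assumption: a missing input is passed through unchanged.
        return None
    # Fall back to the input when there is nothing to gather from.
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
        return t

    t = t.contiguous()
    # Buffers filled by all_gather carry no autograd history (they are "detached").
    gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, t)
    # Put the original tensor back in this rank's slot so its gradient path survives.
    gathered[dist.get_rank()] = t
    return torch.cat(gathered, dim=0)


# Single-process usage: no process group is initialized, so the input comes back as-is.
x = torch.randn(4, 8, requires_grad=True)
out = gather_with_local_grad(x)
```

In a multi-rank run, the returned tensor would have world_size times as many rows as t, with only this rank's block connected to the autograd graph.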

nemo_automodel.components.models.common.inbatch_neg_utils.mask_gathered_passages_same_doc_as_positive(
scores: torch.Tensor,
passage_doc_ids: torch.Tensor,
train_n_passages: int,
rank: int,
local_batch_size: int,
) → None

In-place mask passages sharing a doc id with this row’s positive.

After the all-gather, the positive for query i, where i is the query’s index in the global batch, sits at column i * train_n_passages of the gathered passage tensor. For each local query row, the function sets scores to finfo(dtype).min in every other column whose passage_doc_ids entry matches the positive’s doc id, so that duplicates of the positive elsewhere in the global batch are not treated as negatives. The true positive column is left intact.

Parameters:
  • scores – Score matrix of shape [local_batch_size, B_global * train_n_passages], already sliced to the local rank’s query rows. Modified in place.

  • passage_doc_ids – [B_global * train_n_passages] int64 doc ids for every gathered passage.

  • train_n_passages – Number of passages per query (1 positive + negatives).

  • rank – The caller’s data-parallel (DP) rank.

  • local_batch_size – Number of queries per rank.
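The masking rule can be illustrated with a small, self-contained sketch. It is not the library's implementation: the function name mask_same_doc_sketch is hypothetical, and it assumes that local row r on a given rank corresponds to global query rank * local_batch_size + r, which is inferred from the parameter list rather than stated in the documentation.

```python
import torch


def mask_same_doc_sketch(scores, passage_doc_ids, train_n_passages, rank, local_batch_size):
    """Hypothetical sketch: mask duplicates of each row's positive, in place."""
    rows = torch.arange(local_batch_size, device=scores.device)
    # Assumption: local row r maps to global query rank * local_batch_size + r.
    pos_cols = (rank * local_batch_size + rows) * train_n_passages
    pos_ids = passage_doc_ids[pos_cols]  # doc id of each row's positive
    # same_doc[r, c] is True when gathered passage c shares row r's positive doc id.
    same_doc = passage_doc_ids.unsqueeze(0) == pos_ids.unsqueeze(1)
    # Leave the true positive column untouched.
    same_doc[rows, pos_cols] = False
    scores.masked_fill_(same_doc, torch.finfo(scores.dtype).min)


# Toy setup: 2 ranks, 2 queries per rank, 2 passages per query (8 gathered passages).
doc_ids = torch.tensor([10, 11, 12, 13, 10, 14, 15, 12])
scores = torch.zeros(2, 8)
# Viewed from rank 1: its rows are global queries 2 and 3, with positives at columns 4 and 6.
mask_same_doc_sketch(scores, doc_ids, train_n_passages=2, rank=1, local_batch_size=2)
```

In this toy batch, the first local row's positive (doc id 10, column 4) also appears at column 0, so only scores[0, 0] is masked; the second row's positive has no duplicate and that row is unchanged.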