nemo_automodel.components.models.common.inbatch_neg_utils
Distributed in-batch negative utilities for bi-encoder contrastive training.
Architecture-agnostic helpers used by the bi-encoder trainer to expand the
negative pool with passages gathered across DP ranks. Backbones (Llama,
Ministral3, Qwen3, …) do not import these directly; the trainer wires them
in around BiEncoderModel.encode.
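To make "expanding the negative pool" concrete, here is a minimal sketch of an in-batch contrastive loss that scores this rank's queries against passages gathered from every DP rank. It assumes every rank holds the same number of queries and that each query's passages are stored contiguously with the positive first. The names inbatch_contrastive_loss and temperature (and its default) are hypothetical and are not part of this module's API.

```python
import torch
import torch.nn.functional as F


def inbatch_contrastive_loss(
    q_local: torch.Tensor,   # [B_local, D] query embeddings on this rank
    p_global: torch.Tensor,  # [B_global * n_passages, D] passages gathered from all ranks
    rank: int,
    n_passages: int,
    temperature: float = 0.05,  # hypothetical default, not taken from the source
) -> torch.Tensor:
    b_local = q_local.size(0)
    # Score every local query against the whole gathered passage pool, so
    # passages owned by other ranks act as extra negatives.
    scores = q_local @ p_global.T / temperature
    # Global query index i = rank * B_local + local_row; its positive is at column i * n_passages.
    global_rows = rank * b_local + torch.arange(b_local, device=q_local.device)
    labels = global_rows * n_passages
    return F.cross_entropy(scores, labels)
```

In a real training step the gathered passages would come from an autograd-aware all-gather, and duplicates of each positive would be masked in the score matrix before the cross-entropy.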
Module Contents
Functions
API
All-gather t along dim 0, preserving autograd only when needed.
All-gather t along dim 0 across the default process group.
When preserve_grad is true, tensors that require gradients use an
autograd-aware gather so distributed in-batch-negative losses can send
passage gradients back to the owning rank. Otherwise, remote slices are
detached and only the local slice keeps gradient flow. Non-gradient tensors,
such as masks or IDs, always use a regular detached gather.
Returns t unchanged when torch.distributed is not available or not initialized,
or when the world size is 1.
All-gather t after padding dim 1 to the maximum length across ranks.
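The padding step is needed because dist.all_gather expects same-sized tensors from every rank, while sequence lengths usually differ between ranks. A minimal sketch, assuming dim 0 already matches across ranks and a pad value of 0 (both assumptions). pad_dim1 and gather_padded_dim1 are hypothetical names, and this version uses a plain detached gather rather than an autograd-aware one.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def pad_dim1(t: torch.Tensor, target_len: int, pad_value: float = 0) -> torch.Tensor:
    """Right-pad dim 1 of `t` (rank >= 2) up to `target_len`."""
    extra = target_len - t.size(1)
    if extra <= 0:
        return t
    # F.pad lists pads from the last dim backwards: leave dims 2.. untouched,
    # then extend dim 1 on the right.
    pad = (0, 0) * (t.dim() - 2) + (0, extra)
    return F.pad(t, pad, value=pad_value)


def gather_padded_dim1(t: torch.Tensor, pad_value: float = 0) -> torch.Tensor:
    """Hypothetical sketch: pad dim 1 to the global max, then all-gather along dim 0."""
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
        return t
    # Agree on the longest dim-1 length across ranks.
    max_len = torch.tensor([t.size(1)], device=t.device)
    dist.all_reduce(max_len, op=dist.ReduceOp.MAX)
    t = pad_dim1(t, int(max_len.item()), pad_value).contiguous()
    pieces = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(pieces, t)
    return torch.cat(pieces, dim=0)
```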
In-place mask passages sharing a doc id with this row’s positive.
After the all-gather, the positive for the query with global index i sits at
column i * train_n_passages of the gathered passage tensor, where i is the
local row plus the caller's rank times the number of queries per rank. For each
local query row, set the score to finfo(dtype).min in every other column whose
passage_doc_ids entry matches the positive's doc id, so duplicates of the
positive elsewhere in the global batch are not treated as negatives. The true
positive column is left intact.
Parameters:
Score matrix of shape [local_batch_size, B_global * train_n_passages], already sliced to the local rank's query rows; modified in place.
Tensor of shape [B_global * train_n_passages] holding the int64 doc id of every gathered passage.
Number of passages per query (1 positive + negatives).
The caller's DP rank.
Number of queries per rank.
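The masking rule can be sketched with vectorized indexing. This assumes the layout described above (the positive of global query i at column i * train_n_passages, equal queries per rank). The function name mask_duplicate_positives and the argument names scores, rank and local_batch_size are hypothetical.

```python
import torch


def mask_duplicate_positives(
    scores: torch.Tensor,           # [local_batch_size, B_global * train_n_passages], edited in place
    passage_doc_ids: torch.Tensor,  # [B_global * train_n_passages] int64
    train_n_passages: int,
    rank: int,
    local_batch_size: int,
) -> torch.Tensor:
    rows = torch.arange(scores.size(0), device=scores.device)
    # Column of each local row's own positive in the gathered passage tensor.
    pos_cols = (rank * local_batch_size + rows) * train_n_passages
    pos_ids = passage_doc_ids[pos_cols]  # [local_batch_size]
    # True wherever a gathered passage shares the row's positive doc id.
    dup = passage_doc_ids.unsqueeze(0) == pos_ids.unsqueeze(1)
    # Never mask the true positive itself.
    dup[rows, pos_cols] = False
    scores.masked_fill_(dup, torch.finfo(scores.dtype).min)
    return scores
```

In a contrastive loss this runs on the score matrix before the cross-entropy, so masked duplicates contribute essentially nothing to the softmax denominator.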