nemo_automodel.components.models.common.inbatch_neg_utils
Distributed in-batch negative utilities for bi-encoder contrastive training.
Architecture-agnostic helpers used by the bi-encoder trainer to expand the
negative pool with passages gathered across DP ranks. Backbones (Llama,
Ministral3, Qwen3, …) do not import these directly; the trainer wires them
in around BiEncoderModel.encode.
Module Contents
Functions
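dist_gather_tensor | All-gather t along dim 0 across the default process group. |
mask_gathered_passages_same_doc_as_positive | In-place mask passages sharing a doc id with this row’s positive. |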
API
- nemo_automodel.components.models.common.inbatch_neg_utils.dist_gather_tensor(
- t: Optional[torch.Tensor],
All-gather t along dim 0 across the default process group.
The local-rank slice is replaced with the original t so that gradients flow back only to the local portion of the gathered tensor (other ranks’ slices are detached). Returns t unchanged when distributed is not available, not initialized, or world size is 1.
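The behavior described above can be sketched as follows. This is a minimal illustration, not the library source: the function name gather_keep_local_grad is hypothetical, returning None for a None input is an assumption (the entry only shows that t is Optional), and the sketch assumes every rank passes a tensor of the same shape.

```python
from typing import Optional

import torch
import torch.distributed as dist


def gather_keep_local_grad(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Hypothetical sketch of a gradient-preserving all-gather along dim 0."""
    if t is None:
        # Assumption: a missing tensor is passed through untouched.
        return None
    # Nothing to gather when distributed is unavailable, uninitialized,
    # or there is only one rank; return the input unchanged.
    if not (dist.is_available() and dist.is_initialized()):
        return t
    world_size = dist.get_world_size()
    if world_size == 1:
        return t

    t = t.contiguous()
    # Receive buffers: tensors filled by all_gather carry no autograd history,
    # so the slices coming from other ranks are effectively detached.
    gathered = [torch.empty_like(t) for _ in range(world_size)]
    dist.all_gather(gathered, t)
    # Put the original, graph-attached tensor back into the local slot so
    # gradients flow only into this rank's portion.
    gathered[dist.get_rank()] = t
    return torch.cat(gathered, dim=0)
```

In a single-process run the function falls through and returns its input, which makes the same training code usable with or without a launcher.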
- nemo_automodel.components.models.common.inbatch_neg_utils.mask_gathered_passages_same_doc_as_positive(
- scores: torch.Tensor,
- passage_doc_ids: torch.Tensor,
- train_n_passages: int,
- rank: int,
- local_batch_size: int,
In-place mask passages sharing a doc id with this row’s positive.
After all-gather, each query’s positive sits at column i * train_n_passages of the gathered passage tensor. For each local query row, set scores to finfo(dtype).min on any other column whose passage_doc_ids matches the positive’s id, so duplicates of the positive elsewhere in the global batch are not treated as negatives. The true positive column is left intact.
- Parameters:
scores – [local_batch_size, B_global * train_n_passages] (already sliced to the local rank’s query rows).
passage_doc_ids – [B_global * train_n_passages] int64 doc ids for every gathered passage.
train_n_passages – Number of passages per query (1 positive + negatives).
rank – Caller’s DP rank.
local_batch_size – Number of queries per rank.
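A small sketch may make the masking rule concrete. It is an illustration under stated assumptions, not the library implementation: the function name mask_same_doc_sketch is hypothetical, and it assumes that local row r on a given rank corresponds to global query index rank * local_batch_size + r, which is how rank and local_batch_size would locate the positive column.

```python
import torch


def mask_same_doc_sketch(
    scores: torch.Tensor,
    passage_doc_ids: torch.Tensor,
    train_n_passages: int,
    rank: int,
    local_batch_size: int,
) -> torch.Tensor:
    """Hypothetical sketch: mask duplicates of each row's positive in place."""
    rows = torch.arange(local_batch_size, device=scores.device)
    # Assumption: global query index = rank * local_batch_size + local row.
    pos_cols = (rank * local_batch_size + rows) * train_n_passages
    pos_ids = passage_doc_ids[pos_cols]  # [local_batch_size]
    # True wherever a gathered passage shares the positive's doc id.
    same_doc = passage_doc_ids.unsqueeze(0) == pos_ids.unsqueeze(1)
    # Leave the true positive column intact.
    same_doc[rows, pos_cols] = False
    scores.masked_fill_(same_doc, torch.finfo(scores.dtype).min)
    return scores


# Toy setup: 2 ranks, 2 queries per rank, 2 passages per query (8 passages).
doc_ids = torch.tensor([10, 11, 20, 21, 10, 31, 40, 10])
toy_scores = torch.zeros(2, 8)
mask_same_doc_sketch(toy_scores, doc_ids, train_n_passages=2, rank=1, local_batch_size=2)
```

In the toy setup, rank 1's first query has its positive at column 4 (doc id 10), so columns 0 and 7, which carry the same doc id, are masked while column 4 keeps its score. The second query's positive (doc id 40) has no duplicates, so its row is unchanged.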