nemo_automodel.components.models.ministral_bidirectional.model


Bidirectional Ministral3 model for embedding tasks.

This module provides a modified Ministral3Model that uses bidirectional (non-causal) attention, suitable for generating embeddings where each token should attend to all other tokens in the sequence.

Module Contents

Classes

| Name | Description |
| --- | --- |
| Ministral3BidirectionalConfig | Configuration for Ministral3BidirectionalModel with pooling and temperature settings. |
| Ministral3BidirectionalModel | Ministral3Model modified to use bidirectional (non-causal) attention. |

Functions

| Name | Description |
| --- | --- |
| _register_with_hf_auto_classes | Register bidirectional Ministral3 with HuggingFace Auto classes. |

Data

ModelClass

__all__

logger

API

class nemo_automodel.components.models.ministral_bidirectional.model.Ministral3BidirectionalConfig(
pooling: str = 'avg',
temperature: float = 1.0,
kwargs = {}
)

Bases: Ministral3Config

Configuration for Ministral3BidirectionalModel with pooling and temperature settings.
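This page does not say how the pooling and temperature settings are consumed downstream. As a rough sketch of what such settings commonly mean in embedding models, the pure-Python example below averages token vectors over non-padding positions (one reading of pooling = 'avg') and divides cosine similarities by a temperature. The helper names average_pool, cosine, and scaled_scores are hypothetical and are not part of nemo_automodel.

```python
import math


def average_pool(hidden_states, attention_mask):
    """Mean of the token vectors at real (mask == 1) positions only."""
    dim = len(hidden_states[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(hidden_states, attention_mask):
        if keep:
            count += 1
            for k in range(dim):
                total[k] += vec[k]
    # Guard against an all-padding sequence.
    return [t / max(count, 1) for t in total]


def cosine(a, b):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)


def scaled_scores(query, docs, temperature=1.0):
    """Similarity scores divided by temperature (assumed usage, not confirmed by the source)."""
    return [cosine(query, d) / temperature for d in docs]


# Three token vectors; the last one is padding and is ignored by the pool.
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
embedding = average_pool(tokens, [1, 1, 0])
```

Under this reading, a temperature below 1.0 spreads the scores further apart before any softmax-based loss, while the default of 1.0 leaves them unchanged.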

model_type = 'ministral3_bidirec'

class nemo_automodel.components.models.ministral_bidirectional.model.Ministral3BidirectionalModel(
config
)

Bases: Ministral3Model

Ministral3Model modified to use bidirectional (non-causal) attention.

In standard Ministral3, each token can only attend to previous tokens (causal attention). This model removes that restriction, allowing each token to attend to all tokens in the sequence, which is useful for embedding tasks.
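The difference between the two attention patterns can be pictured as a visibility matrix. The sketch below is a framework-free illustration, not the model's actual masking code; causal_mask, bidirectional_mask, and render are hypothetical helpers. It also assumes that padding positions stay hidden in the bidirectional case, which is the usual convention but is not stated on this page.

```python
def causal_mask(seq_len):
    """mask[i][j] is True when query position i may attend to key position j."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def bidirectional_mask(attention_mask):
    """Every query sees every real token, regardless of order.

    attention_mask holds 1 for real tokens and 0 for padding. Only the
    left-to-right restriction is dropped; padding keys remain hidden
    (an assumption of this sketch).
    """
    n = len(attention_mask)
    return [[bool(attention_mask[j]) for j in range(n)] for _ in range(n)]


def render(mask):
    """Draw the matrix with 'x' for visible and '.' for hidden."""
    return "\n".join(" ".join("x" if v else "." for v in row) for row in mask)


print(render(causal_mask(4)))
print()
print(render(bidirectional_mask([1, 1, 1, 0])))
```

The first matrix is lower-triangular, so the first token sees only itself. In the second, every row is identical: each token sees all three real tokens and none sees the padded slot.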

Loading a Mistral3 VLM checkpoint (e.g. mistralai/Ministral-3-3B-Base-2512 or mistralai/Ministral-3-3B-Instruct-2512) requires extracting the language tower. The recipe YAML requests this with extract_submodel: language_model, and nemo_automodel._transformers.retrieval.build_encoder_backbone performs the extraction.

Text-only checkpoints (e.g. mistralai/Ministral-3B-Instruct) load directly via the standard from_pretrained path with no extraction needed.

nemo_automodel.components.models.ministral_bidirectional.model.Ministral3BidirectionalModel.forward(
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: transformers.cache_utils.Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
cache_position: torch.LongTensor | None = None,
kwargs = {}
) -> transformers.modeling_outputs.BaseModelOutputWithPast

Forward pass with bidirectional attention.

Identical to Ministral3Model.forward() except the causal mask is replaced with a bidirectional mask, allowing all tokens to attend to each other.
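To see why swapping the mask matters for embeddings, consider the attention weights of the first token. The sketch below is a simplified single-query softmax in plain Python, not the library's implementation; attention_weights is a hypothetical helper, and the equal raw scores are chosen only to make the arithmetic easy to follow.

```python
import math


def attention_weights(scores, allowed):
    """Softmax over one query's scores, with disallowed keys set to -inf.

    Assumes at least one key is allowed.
    """
    masked = [s if ok else float("-inf") for s, ok in zip(scores, allowed)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(m - peak) for m in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Raw scores of the first token against a three-token sequence.
scores_for_first_token = [0.0, 0.0, 0.0]

# Causal: the first token may look only at itself.
causal = attention_weights(scores_for_first_token, [True, False, False])

# Bidirectional: the first token may look at the whole sequence.
bidirectional = attention_weights(scores_for_first_token, [True, True, True])
```

With the causal mask the first token puts all of its weight on itself, so its representation cannot reflect later context. With the bidirectional mask the weight is shared across the whole sequence.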

nemo_automodel.components.models.ministral_bidirectional.model._register_with_hf_auto_classes() -> None

Register bidirectional Ministral3 with HuggingFace Auto classes.

This registration is needed so that AutoModel.from_config(Ministral3BidirectionalConfig) and checkpoint reload paths that rely on Auto resolution behave consistently.
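Conceptually, Auto resolution relies on two lookups: a model_type string maps to a config class, and a config class maps to a model class. In transformers, custom classes are normally added to these lookups with AutoConfig.register and AutoModel.register; whether this function calls exactly those is an assumption, since its body is not shown here. The toy registry below is a stand-in that only illustrates the idea. Every name in it is hypothetical, and it is not the transformers implementation.

```python
# Toy stand-in for the Auto-class lookup tables (illustration only).
CONFIG_BY_MODEL_TYPE = {}
MODEL_BY_CONFIG = {}


def register(model_type, config_cls, model_cls):
    """Record both mappings that Auto-style resolution depends on."""
    CONFIG_BY_MODEL_TYPE[model_type] = config_cls
    MODEL_BY_CONFIG[config_cls] = model_cls


def model_from_config(config):
    """Analogue of AutoModel.from_config: pick the model class by config type."""
    return MODEL_BY_CONFIG[type(config)](config)


def model_from_saved_config(saved):
    """Analogue of a reload path: resolve the config class from model_type first."""
    config_cls = CONFIG_BY_MODEL_TYPE[saved["model_type"]]
    return model_from_config(config_cls())


class ToyBidirectionalConfig:
    model_type = "ministral3_bidirec"


class ToyBidirectionalModel:
    def __init__(self, config):
        self.config = config


register(ToyBidirectionalConfig.model_type, ToyBidirectionalConfig, ToyBidirectionalModel)
```

Without the register call, both lookups would raise KeyError, which mirrors the inconsistency the registration is meant to prevent.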

nemo_automodel.components.models.ministral_bidirectional.model.ModelClass = [Ministral3BidirectionalModel]
nemo_automodel.components.models.ministral_bidirectional.model.__all__ = ['Ministral3BidirectionalModel', 'Ministral3BidirectionalConfig', 'ModelClass']
nemo_automodel.components.models.ministral_bidirectional.model.logger = logging.get_logger(__name__)