nemo_automodel.components.models.nemotron_v3.model#
Module Contents#
Classes#
| NemotronV3Model | NemotronV3 base model (without LM head). |
| NemotronHForCausalLM | NemotronV3 model with language modeling head. |
Data#
API#
- class nemo_automodel.components.models.nemotron_v3.model.NemotronV3Model(
- config,
- backend: nemo_automodel.components.models.common.BackendConfig | None = None,
- *,
- moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
- )
Bases:
torch.nn.Module
NemotronV3 base model (without LM head).
This is a hybrid architecture with Mamba2, Attention, MLP, and MoE layers.
Initialization
Initialize NemotronV3Model.
- Parameters:
config – NemotronH config with model parameters
backend – Backend configuration for MoE and other components
moe_config – MoE configuration (optional, will create default if None)
- forward(
- input_ids: torch.LongTensor | None = None,
- *,
- attention_mask: torch.Tensor | None = None,
- causal_mask_mapping: dict[str, torch.Tensor] | None = None,
- inputs_embeds: torch.Tensor | None = None,
- past_key_values=None,
- cache_position: torch.LongTensor | None = None,
- **kwargs: Any,
- )
Forward pass through the model.
- Parameters:
input_ids – Input token IDs [batch_size, seq_len] (optional)
attention_mask – 2D padding mask [batch_size, seq_len] (1=real, 0=padding)
causal_mask_mapping – Dict with precomputed 4D causal masks for attention layers
inputs_embeds – Input embeddings [batch_size, seq_len, hidden_size] (optional)
past_key_values – Optional NemotronHybridCache for incremental decoding.
cache_position – Token position indices for cache updates.
**kwargs – Additional arguments (ignored)
- Returns:
Hidden states tensor [batch_size, seq_len, hidden_size]
- initialize_weights(buffer_device: torch.device | None = None) -> None#
Initialize model weights according to NemotronV3 spec.
- Parameters:
buffer_device – Device to use for buffer initialization
- class nemo_automodel.components.models.nemotron_v3.model.NemotronHForCausalLM(
- config,
- backend: nemo_automodel.components.models.common.BackendConfig | None = None,
- **kwargs,
- )
Bases:
nemo_automodel.components.models.common.HFCheckpointingMixin, transformers.generation.GenerationMixin, torch.nn.Module, nemo_automodel.components.moe.fsdp_mixin.MoEFSDPSyncMixin
NemotronV3 model with language modeling head.
Supports .generate() from transformers.generation.GenerationMixin with O(1) per-step KV caching for attention layers and recurrent state caching for Mamba2 layers.
Initialization
Initialize NemotronHForCausalLM.
- Parameters:
config – NemotronH config
backend – Backend configuration
**kwargs – Additional arguments
- _is_stateful: bool#
True
- main_input_name: str#
'input_ids'
- classmethod from_config(
- config,
- backend: nemo_automodel.components.models.common.BackendConfig | None = None,
- **kwargs,
- )
Create model from config.
- Parameters:
config – NemotronH config
backend – Backend configuration
**kwargs – Additional arguments
- Returns:
NemotronHForCausalLM instance
- classmethod from_pretrained(
- pretrained_model_name_or_path: str,
- *model_args,
- **kwargs,
- )
Load pretrained model.
- Parameters:
pretrained_model_name_or_path – Path or name of pretrained model
*model_args – Additional positional arguments
**kwargs – Additional keyword arguments
- Returns:
NemotronHForCausalLM instance
- property device: torch.device#
Return the device of the first model parameter (required by GenerationMixin).
- property dtype: torch.dtype#
Return the dtype of the first model parameter (used by cache construction).
- get_input_embeddings()#
- set_input_embeddings(value)#
- get_output_embeddings()#
- set_output_embeddings(new_embeddings)#
- forward(
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- causal_mask_mapping: Optional[dict[str, torch.Tensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Any] = None,
- use_cache: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- logits_to_keep: Union[int, torch.Tensor] = 0,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **kwargs: Any,
- )
Forward pass with optional loss computation.
- Parameters:
input_ids – Input token IDs [batch_size, seq_len] (optional)
attention_mask – 2D padding mask [batch_size, seq_len]
causal_mask_mapping – Dict with precomputed 4D causal masks
inputs_embeds – Pre-computed input embeddings (optional)
labels – Token IDs for loss computation [batch_size, seq_len] (optional)
past_key_values – Optional NemotronHybridCache for incremental decoding.
use_cache – Whether to return past_key_values for subsequent steps.
cache_position – Token position indices for cache updates.
position_ids – Unused; accepted for API compatibility with GenerationMixin.
logits_to_keep – If > 0, only compute logits for the last logits_to_keep token positions (avoids materialising the full logit matrix during generation).
output_hidden_states – Whether to return hidden states
return_dict – Accepted for API compatibility (always returns CausalLMOutputWithPast)
**kwargs – Additional arguments forwarded to the base model.
- Returns:
transformers.modeling_outputs.CausalLMOutputWithPast with logits (float32, [batch_size, seq_len, vocab_size]), optional loss, past_key_values, and hidden_states.
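The logits_to_keep optimisation described above can be sketched in isolation. This is an illustrative standalone snippet, not the model's actual implementation; the names hidden, lm_head, and the toy sizes are assumptions:

```python
import torch

# Toy stand-ins for the final hidden states and the LM head projection.
batch_size, seq_len, hidden_size, vocab_size = 2, 128, 64, 1000
hidden = torch.randn(batch_size, seq_len, hidden_size)
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

logits_to_keep = 1  # during generation, only the last position is needed
if logits_to_keep > 0:
    # Slice BEFORE the projection so the full [seq_len, vocab_size]
    # logit matrix is never materialised.
    hidden = hidden[:, -logits_to_keep:, :]
logits = lm_head(hidden).float()  # logits are returned in float32

print(logits.shape)  # torch.Size([2, 1, 1000])
```

Slicing before the projection is what makes the saving: the matmul runs over 1 position instead of 128.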
- static _make_causal_mask(
- query_len: int,
- kv_len: int,
- batch_size: int,
- dtype: torch.dtype,
- device: torch.device,
- )
Build a 4D SDPA-compatible causal mask.
Prefill (query_len == kv_len): standard lower-triangular causal mask.
Decode (query_len == 1): all-zeros row allowing attention to all cached positions.
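The two mask regimes above can be reproduced with a minimal re-implementation. This is a sketch of the described semantics (additive SDPA mask: 0.0 = attend, -inf = blocked), not the module's actual code:

```python
import torch

def make_causal_mask(query_len, kv_len, batch_size, dtype, device):
    """Build a 4D additive mask [batch_size, 1, query_len, kv_len]."""
    mask = torch.zeros(query_len, kv_len, dtype=dtype, device=device)
    if query_len > 1:
        # Prefill: block everything above the diagonal (future positions).
        future = torch.triu(
            torch.ones(query_len, kv_len, dtype=torch.bool, device=device),
            diagonal=1,
        )
        mask = mask.masked_fill(future, float("-inf"))
    # Decode (query_len == 1): the all-zeros row attends to every cached position.
    return mask[None, None].expand(batch_size, 1, query_len, kv_len)

cpu = torch.device("cpu")
prefill = make_causal_mask(4, 4, 2, torch.float32, cpu)  # lower-triangular
decode = make_causal_mask(1, 5, 2, torch.float32, cpu)   # all zeros
```

A zero entry passes attention through unchanged, while -inf entries become zero weight after softmax, which is why the decode row of zeros allows attention to the whole cache.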
- prepare_inputs_for_generation(
- input_ids: torch.LongTensor,
- attention_mask: Optional[torch.Tensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- past_key_values: Optional[Any] = None,
- cache_position: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = True,
- **kwargs,
- )
Prepare model inputs for each generation step.
On the first call (prefill), creates a NemotronHybridCache and forwards the full prompt. On subsequent calls (decode), only the newly generated token is forwarded.
- Parameters:
input_ids – Accumulated token ids [batch_size, current_seq_len].
attention_mask – Padding mask [batch_size, current_seq_len].
inputs_embeds – Pre-computed embeddings for the first step (optional).
past_key_values – NemotronHybridCache from the previous step (None on first call).
cache_position – Token position indices.
use_cache – Whether to use caching (default True).
**kwargs – Remaining model kwargs.
- Returns:
Dict of keyword arguments to pass to forward.
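The prefill/decode split described above can be illustrated with a toy loop. This sketch uses a plain list as a stand-in for NemotronHybridCache and a fixed token as a stand-in for sampling; only the input-slicing pattern is taken from the description:

```python
import torch

prompt = torch.tensor([[5, 9, 2]])       # [batch_size=1, prompt_len=3]
input_ids = prompt.clone()
tokens_per_step = []                     # how many tokens each step forwards
cache = None

for step in range(3):                    # one prefill + two decode steps
    if cache is None:
        step_ids = input_ids             # prefill: forward the full prompt
        cache = object()                 # stand-in for NemotronHybridCache
    else:
        step_ids = input_ids[:, -1:]     # decode: only the newest token
    tokens_per_step.append(step_ids.shape[1])
    next_token = torch.tensor([[step]])  # stand-in for argmax over logits
    input_ids = torch.cat([input_ids, next_token], dim=-1)

print(tokens_per_step)  # [3, 1, 1]
```

The [3, 1, 1] pattern is the O(1)-per-step property: after prefill, each decode step forwards a single token and relies on the cache for past state.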
- initialize_weights(
- buffer_device: torch.device | None = None,
- dtype: torch.dtype = torch.bfloat16,
- )
Initialize model weights.
- Parameters:
buffer_device – Device to use for buffer initialization
dtype – Target dtype for model weights
- nemo_automodel.components.models.nemotron_v3.model.ModelClass#
None