nemo_automodel.components.models.ernie4_5.model


Module Contents

Classes

| Name | Description |
| --- | --- |
| Ernie4_5Attention | ERNIE 4.5 GQA attention with interleaved RoPE. |
| Ernie4_5Block | Dense ERNIE 4.5 decoder block. |
| Ernie4_5ForCausalLM | Dense ERNIE 4.5 causal language model. |
| Ernie4_5Model | Dense ERNIE 4.5 transformer body. |
| Ernie4_5MoeBlock | ERNIE 4.5 MoE decoder block. |
| Ernie4_5_MoeForCausalLM | ERNIE 4.5 MoE causal language model with AutoModel EP support. |
| Ernie4_5_MoeModel | ERNIE 4.5 MoE transformer body. |

Functions

| Name | Description |
| --- | --- |
| _config_dtype | - |

Data

ModelClass

API

class nemo_automodel.components.models.ernie4_5.model.Ernie4_5Attention(
config: transformers.models.ernie4_5.configuration_ernie4_5.Ernie4_5Config | transformers.models.ernie4_5_moe.configuration_ernie4_5_moe.Ernie4_5_MoeConfig,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: Module

ERNIE 4.5 GQA attention with interleaved RoPE.
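
The two ideas named in this description can be illustrated in plain Python. This is a minimal sketch, not the class's implementation: `kv_head_for_query_head` and `interleaved_rope` are hypothetical helpers, the rotary base of 10000 is an assumed default, and "interleaved" is taken in its usual sense of rotating adjacent element pairs `(x[0], x[1]), (x[2], x[3]), ...` rather than pairing the first half of the vector with the second half.

```python
import math


def kv_head_for_query_head(query_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Map a query head index to the key/value head it shares under GQA."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads  # query heads per KV head
    return query_head // group_size


def interleaved_rope(vec: list[float], position: int, base: float = 10000.0) -> list[float]:
    """Rotate adjacent pairs (vec[2i], vec[2i+1]) by a position-dependent angle.

    Hypothetical helper: the base value and pairing layout are assumptions.
    """
    dim = len(vec)
    if dim % 2 != 0:
        raise ValueError("head dimension must be even")
    out = [0.0] * dim
    for i in range(dim // 2):
        theta = position * base ** (-2 * i / dim)  # frequency for pair i
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        x_even, x_odd = vec[2 * i], vec[2 * i + 1]
        out[2 * i] = x_even * cos_t - x_odd * sin_t
        out[2 * i + 1] = x_even * sin_t + x_odd * cos_t
    return out
```

With 8 query heads and 2 key/value heads, for example, query heads 0 to 3 would read KV head 0 and query heads 4 to 7 would read KV head 1. Because each pair undergoes a pure rotation, the vector norm is unchanged at every position.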

head_dim
k_proj
num_heads
= config.num_attention_heads
num_kv_heads
= config.num_key_value_heads
o_proj
q_proj
v_proj
nemo_automodel.components.models.ernie4_5.model.Ernie4_5Attention.forward(
x: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
class nemo_automodel.components.models.ernie4_5.model.Ernie4_5Block(
config: transformers.models.ernie4_5.configuration_ernie4_5.Ernie4_5Config | transformers.models.ernie4_5_moe.configuration_ernie4_5_moe.Ernie4_5_MoeConfig,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: Module

Dense ERNIE 4.5 decoder block.

input_layernorm
mlp
post_attention_layernorm
self_attn
= Ernie4_5Attention(config, backend)
nemo_automodel.components.models.ernie4_5.model.Ernie4_5Block.forward(
x: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
class nemo_automodel.components.models.ernie4_5.model.Ernie4_5ForCausalLM(
config: transformers.models.ernie4_5.configuration_ernie4_5.Ernie4_5Config,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
kwargs = {}
)

Bases: HFCheckpointingMixin, Module

Dense ERNIE 4.5 causal language model.

_nemo_tied_weights_keys
= {'lm_head.weight': 'model.embed_tokens.weight'}
_pp_plan
= {'lm_head': (['hidden_states'], ['logits'])}
_tp_plan
= {'lm_head': 'colwise_rep'}
backend
= backend or BackendConfig()
lm_head
model
= Ernie4_5Model(config, self.backend)
state_dict_adapter
= Ernie4_5StateDictAdapter(config)
vocab_size
= config.vocab_size
nemo_automodel.components.models.ernie4_5.model.Ernie4_5ForCausalLM.forward(
input_ids: torch.Tensor,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
logits_to_keep: typing.Union[int, torch.Tensor] = 0,
output_hidden_states: typing.Optional[bool] = None,
attn_kwargs: typing.Any = {}
) -> transformers.modeling_outputs.CausalLMOutputWithPast
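
The `logits_to_keep` parameter has the same name, type, and default as in Hugging Face causal LM classes. Assuming this class follows that convention (the signature alone does not confirm it), an integer `n` keeps the last `n` sequence positions, `0` keeps all of them, and a tensor selects explicit positions. The sketch below models that convention on a plain list; `select_positions` is a hypothetical helper.

```python
def select_positions(hidden_states: list, logits_to_keep=0) -> list:
    """Pick the sequence positions whose logits would be computed.

    Assumes the Hugging Face convention; a list of indices stands in
    for the tensor form of `logits_to_keep`.
    """
    if isinstance(logits_to_keep, int):
        # slice(-0, None) equals slice(0, None), so 0 keeps every position.
        return hidden_states[slice(-logits_to_keep, None)]
    return [hidden_states[i] for i in logits_to_keep]
```

Under this convention, passing `logits_to_keep=1` during generation avoids projecting every position through `lm_head` when only the final token's logits are needed.
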
nemo_automodel.components.models.ernie4_5.model.Ernie4_5ForCausalLM.from_config(
config: transformers.models.ernie4_5.configuration_ernie4_5.Ernie4_5Config,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
kwargs = {}
)
classmethod
nemo_automodel.components.models.ernie4_5.model.Ernie4_5ForCausalLM.from_pretrained(
pretrained_model_name_or_path: str,
model_args = (),
kwargs = {}
)
classmethod
nemo_automodel.components.models.ernie4_5.model.Ernie4_5ForCausalLM.get_input_embeddings()
nemo_automodel.components.models.ernie4_5.model.Ernie4_5ForCausalLM.get_output_embeddings()
nemo_automodel.components.models.ernie4_5.model.Ernie4_5ForCausalLM.set_input_embeddings(
value
)
nemo_automodel.components.models.ernie4_5.model.Ernie4_5ForCausalLM.set_output_embeddings(
new_embeddings
)
nemo_automodel.components.models.ernie4_5.model.Ernie4_5ForCausalLM.tie_weights()
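
As a rough illustration of what the `_nemo_tied_weights_keys` mapping expresses, the sketch below makes each key path an alias of its value path on a toy object tree. It assumes the key is the tied parameter and the value is its source, which the page does not state; `resolve` and `tie` are hypothetical helpers and not the behavior of `tie_weights()` itself.

```python
from types import SimpleNamespace


def resolve(root, dotted: str):
    """Follow a dotted attribute path such as 'model.embed_tokens.weight'."""
    obj = root
    for part in dotted.split("."):
        obj = getattr(obj, part)
    return obj


def tie(root, tied_keys: dict) -> None:
    """Make each target path refer to the same object as its source path."""
    for target, source in tied_keys.items():
        parent_path, _, attr = target.rpartition(".")
        parent = resolve(root, parent_path) if parent_path else root
        setattr(parent, attr, resolve(root, source))


# Toy stand-in for a model: lists play the role of weight tensors.
toy = SimpleNamespace(
    lm_head=SimpleNamespace(weight=[0.0]),
    model=SimpleNamespace(embed_tokens=SimpleNamespace(weight=[1.0, 2.0])),
)
tie(toy, {"lm_head.weight": "model.embed_tokens.weight"})
```

After the call, `toy.lm_head.weight` and `toy.model.embed_tokens.weight` are the same object, so an update through one name is visible through the other.
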
class nemo_automodel.components.models.ernie4_5.model.Ernie4_5Model(
config: transformers.models.ernie4_5.configuration_ernie4_5.Ernie4_5Config,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: Module

Dense ERNIE 4.5 transformer body.

embed_tokens
layers
norm
rotary_emb
= Ernie4_5RotaryEmbedding(config)
nemo_automodel.components.models.ernie4_5.model.Ernie4_5Model.forward(
input_ids: torch.Tensor,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
class nemo_automodel.components.models.ernie4_5.model.Ernie4_5MoeBlock(
layer_idx: int,
config: transformers.models.ernie4_5_moe.configuration_ernie4_5_moe.Ernie4_5_MoeConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: Module

ERNIE 4.5 MoE decoder block.

input_layernorm
mlp
= MoE(moe_config, backend)
post_attention_layernorm
self_attn
= Ernie4_5Attention(config, backend)
nemo_automodel.components.models.ernie4_5.model.Ernie4_5MoeBlock.forward(
x: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
class nemo_automodel.components.models.ernie4_5.model.Ernie4_5_MoeForCausalLM(
config: transformers.models.ernie4_5_moe.configuration_ernie4_5_moe.Ernie4_5_MoeConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
kwargs = {}
)

Bases: HFCheckpointingMixin, Module, MoEFSDPSyncMixin

ERNIE 4.5 MoE causal language model with AutoModel EP support.

_nemo_tied_weights_keys
= {'lm_head.weight': 'model.embed_tokens.weight'}
_pp_plan
= {'lm_head': (['hidden_states'], ['logits'])}
_tp_plan
= {'lm_head': 'colwise_rep'}
backend
= backend or BackendConfig()
lm_head
model
state_dict_adapter
vocab_size
= config.vocab_size
nemo_automodel.components.models.ernie4_5.model.Ernie4_5_MoeForCausalLM.forward(
input_ids: torch.Tensor,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
logits_to_keep: typing.Union[int, torch.Tensor] = 0,
output_hidden_states: typing.Optional[bool] = None,
attn_kwargs: typing.Any = {}
) -> transformers.modeling_outputs.CausalLMOutputWithPast
nemo_automodel.components.models.ernie4_5.model.Ernie4_5_MoeForCausalLM.from_config(
config: transformers.models.ernie4_5_moe.configuration_ernie4_5_moe.Ernie4_5_MoeConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
kwargs = {}
)
classmethod
nemo_automodel.components.models.ernie4_5.model.Ernie4_5_MoeForCausalLM.from_pretrained(
pretrained_model_name_or_path: str,
model_args = (),
kwargs = {}
)
classmethod
nemo_automodel.components.models.ernie4_5.model.Ernie4_5_MoeForCausalLM.get_capabilities(
config
) -> nemo_automodel._transformers.model_capabilities.ModelCapabilities
classmethod

Return parallelism capabilities for a specific ERNIE-4.5 config.

ERNIE-4.5 ships in two flavors that share this class file but exercise different code paths:

  1. baidu/ERNIE-4.5-21B-A3B-PT — MoE variant (this NeMo custom class). moe_num_experts > 0 in the HF config. Demonstrated by examples/llm_finetune/ernie4_5/ernie4_5_21b_a3b_hellaswag.yaml (ep_size=8).
  2. baidu/ERNIE-4.5-0.3B-PT — dense variant. No expert config. Demonstrated by examples/llm_finetune/ernie4_5/ernie4_5_0p3b_hellaswag.yaml (tp/cp/pp/ep all 1).
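
The distinction between the two variants can be sketched as a check on the config. This is an illustration under stated assumptions rather than the body of `get_capabilities`: `is_moe_config` is a hypothetical helper, a missing or `None` `moe_num_experts` is treated as dense, and the expert count used below is a made-up value.

```python
from types import SimpleNamespace


def is_moe_config(config) -> bool:
    """Return True when the config declares at least one expert.

    Hypothetical helper: a missing or None `moe_num_experts` counts as dense.
    """
    num_experts = getattr(config, "moe_num_experts", None) or 0
    return num_experts > 0


# Stand-in configs; the expert count is illustrative only.
moe_like = SimpleNamespace(moe_num_experts=8)
dense_like = SimpleNamespace()
```

Here `is_moe_config(moe_like)` is true and `is_moe_config(dense_like)` is false, mirroring the "moe_num_experts > 0" versus "no expert config" split described above.
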
nemo_automodel.components.models.ernie4_5.model.Ernie4_5_MoeForCausalLM.get_input_embeddings()
nemo_automodel.components.models.ernie4_5.model.Ernie4_5_MoeForCausalLM.get_output_embeddings()
nemo_automodel.components.models.ernie4_5.model.Ernie4_5_MoeForCausalLM.set_input_embeddings(
value
)
nemo_automodel.components.models.ernie4_5.model.Ernie4_5_MoeForCausalLM.set_output_embeddings(
new_embeddings
)
nemo_automodel.components.models.ernie4_5.model.Ernie4_5_MoeForCausalLM.tie_weights()
class nemo_automodel.components.models.ernie4_5.model.Ernie4_5_MoeModel(
config: transformers.models.ernie4_5_moe.configuration_ernie4_5_moe.Ernie4_5_MoeConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
moe_overrides: dict | None = None
)

Bases: Module

ERNIE 4.5 MoE transformer body.

embed_tokens
layers
moe_config
= moe_config or MoEConfig(**moe_defaults)
norm
rotary_emb
= Ernie4_5RotaryEmbedding(config)
nemo_automodel.components.models.ernie4_5.model.Ernie4_5_MoeModel.forward(
input_ids: torch.Tensor,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
nemo_automodel.components.models.ernie4_5.model._config_dtype(
config: typing.Any
) -> torch.dtype
nemo_automodel.components.models.ernie4_5.model.ModelClass = Ernie4_5_MoeForCausalLM