nemo_automodel.components.models.ernie4_5.model

Module Contents

Classes

Ernie4_5Attention

ERNIE 4.5 grouped-query attention (GQA) with interleaved rotary position embeddings (RoPE).

Ernie4_5Block

Dense ERNIE 4.5 decoder block.

Ernie4_5MoeBlock

ERNIE 4.5 MoE decoder block.

Ernie4_5Model

Dense ERNIE 4.5 transformer body.

Ernie4_5_MoeModel

ERNIE 4.5 MoE transformer body.

Ernie4_5ForCausalLM

Dense ERNIE 4.5 causal language model.

Ernie4_5_MoeForCausalLM

ERNIE 4.5 MoE causal language model with AutoModel expert-parallel (EP) support.

Functions

_config_dtype

Data

ModelClass

API

nemo_automodel.components.models.ernie4_5.model._config_dtype(config: Any) -> torch.dtype
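This page gives only the signature of `_config_dtype`. Judging by its name, it maps a model config to a `torch.dtype`; the sketch below shows one plausible way to do that and is not the module's implementation. The helper name `resolve_config_dtype`, the lookup order of the `dtype` and `torch_dtype` attributes (Hugging Face configs have used both names), and the `bfloat16` fallback are all assumptions.

```python
from types import SimpleNamespace

import torch


def resolve_config_dtype(config, default: torch.dtype = torch.bfloat16) -> torch.dtype:
    """Hypothetical helper: turn a config's dtype field into a torch.dtype."""
    # Assumption: check `dtype` first, then the older `torch_dtype` attribute.
    value = getattr(config, "dtype", None)
    if value is None:
        value = getattr(config, "torch_dtype", None)
    if value is None:
        return default  # assumption: fall back to a caller-chosen default
    if isinstance(value, torch.dtype):
        return value
    if isinstance(value, str):
        # Accept both "bfloat16" and "torch.bfloat16".
        candidate = getattr(torch, value.removeprefix("torch."), None)
        if isinstance(candidate, torch.dtype):
            return candidate
    raise ValueError(f"cannot interpret dtype {value!r}")


print(resolve_config_dtype(SimpleNamespace(torch_dtype="bfloat16")))  # torch.bfloat16
print(resolve_config_dtype(SimpleNamespace(dtype=torch.float16)))  # torch.float16
print(resolve_config_dtype(SimpleNamespace()))  # torch.bfloat16
```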

class nemo_automodel.components.models.ernie4_5.model.Ernie4_5Attention(
    config: transformers.models.ernie4_5.configuration_ernie4_5.Ernie4_5Config | transformers.models.ernie4_5_moe.configuration_ernie4_5_moe.Ernie4_5_MoeConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
)

Bases: torch.nn.Module

ERNIE 4.5 grouped-query attention (GQA) with interleaved rotary position embeddings (RoPE).

Initialization

forward(
    x: torch.Tensor,
    *,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: torch.Tensor | None = None,
    **attn_kwargs: Any,
) -> torch.Tensor
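A minimal sketch of grouped-query attention with interleaved RoPE, in plain PyTorch. It assumes that `position_embeddings` is a `(cos, sin)` pair laid out per channel, that "interleaved" means rotating adjacent channel pairs `(2i, 2i+1)` rather than the two halves of the head dimension, and that the RoPE base is 10000 (the real value would come from the model config). All function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def rope_cos_sin(positions: torch.Tensor, head_dim: int, base: float = 10000.0):
    """Build cos/sin tables in the interleaved layout: [B, T, head_dim]."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    angles = positions.float()[..., None] * inv_freq  # [B, T, head_dim // 2]
    # Each frequency covers two adjacent channels (2i, 2i + 1).
    angles = angles.repeat_interleave(2, dim=-1)  # [B, T, head_dim]
    return angles.cos(), angles.sin()


def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    """Map each adjacent channel pair (a, b) to (-b, a)."""
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)


def apply_interleaved_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [B, H, T, D]; cos/sin: [B, T, D], broadcast over the head dimension.
    cos, sin = cos[:, None], sin[:, None]
    return x * cos + rotate_every_two(x) * sin


def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand [B, n_kv, T, D] to [B, n_kv * n_rep, T, D]; KV head j serves query heads j*n_rep .. j*n_rep + n_rep - 1."""
    b, n_kv, t, d = kv.shape
    return kv[:, :, None].expand(b, n_kv, n_rep, t, d).reshape(b, n_kv * n_rep, t, d)


def gqa_attention(q, k, v, cos, sin):
    # q: [B, n_heads, T, D]; k, v: [B, n_kv_heads, T, D]
    q, k = apply_interleaved_rope(q, cos, sin), apply_interleaved_rope(k, cos, sin)
    n_rep = q.shape[1] // k.shape[1]
    k, v = repeat_kv(k, n_rep), repeat_kv(v, n_rep)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)


torch.manual_seed(0)
B, T, n_heads, n_kv_heads, D = 2, 5, 8, 2, 16
q = torch.randn(B, n_heads, T, D)
k = torch.randn(B, n_kv_heads, T, D)
v = torch.randn(B, n_kv_heads, T, D)
positions = torch.arange(T).expand(B, T)
cos, sin = rope_cos_sin(positions, D)
out = gqa_attention(q, k, v, cos, sin)
print(out.shape)  # torch.Size([2, 8, 5, 16])
```

Because each key/value head is shared by `n_heads // n_kv_heads` query heads, the key/value projections and any KV cache shrink by that factor.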

class nemo_automodel.components.models.ernie4_5.model.Ernie4_5Block(
    config: transformers.models.ernie4_5.configuration_ernie4_5.Ernie4_5Config | transformers.models.ernie4_5_moe.configuration_ernie4_5_moe.Ernie4_5_MoeConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
)

Bases: torch.nn.Module

Dense ERNIE 4.5 decoder block.

Initialization

forward(
    x: torch.Tensor,
    *,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    **attn_kwargs: Any,
) -> torch.Tensor
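The dense block's internals are not documented on this page. A common decoder layout, and the one sketched below, is pre-norm residual: RMSNorm, attention, residual add, then RMSNorm, a SwiGLU MLP, residual add. The attribute names follow the usual Hugging Face naming (`input_layernorm`, `post_attention_layernorm`, `gate_proj`, `up_proj`, `down_proj`) but are assumptions here. A stand-in `nn.Linear` replaces the attention module so the example stays self-contained; the real block also passes `position_embeddings`, `attention_mask` and `padding_mask` through.

```python
import torch
import torch.nn.functional as F
from torch import nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale each vector to unit root-mean-square, then apply a learned gain.
        inv_rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight


class SwiGLUMLP(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden, bias=False)
        self.up_proj = nn.Linear(dim, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class ToyDecoderBlock(nn.Module):
    """Pre-norm residual block: x + attn(norm(x)), then x + mlp(norm(x))."""

    def __init__(self, dim: int, hidden: int, attn: nn.Module):
        super().__init__()
        self.input_layernorm = RMSNorm(dim)
        self.self_attn = attn  # stand-in for the real attention module
        self.post_attention_layernorm = RMSNorm(dim)
        self.mlp = SwiGLUMLP(dim, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.self_attn(self.input_layernorm(x))
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x


torch.manual_seed(0)
block = ToyDecoderBlock(dim=32, hidden=64, attn=nn.Linear(32, 32, bias=False))
x = torch.randn(2, 5, 32)
y = block(x)
print(y.shape)  # torch.Size([2, 5, 32])
```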

class nemo_automodel.components.models.ernie4_5.model.Ernie4_5MoeBlock(
    layer_idx: int,
    config: transformers.models.ernie4_5_moe.configuration_ernie4_5_moe.Ernie4_5_MoeConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
)

Bases: torch.nn.Module

ERNIE 4.5 MoE decoder block.

Initialization

forward(
    x: torch.Tensor,
    *,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    **attn_kwargs: Any,
) -> torch.Tensor
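The MoE block replaces the dense MLP with a mixture of experts. The sketch below shows generic token-choice top-k routing (softmax router, keep the `top_k` highest-scoring experts per token, renormalize their weights). The actual ERNIE 4.5 router may differ, for example in its scoring function, bias correction, or shared experts. Treating `padding_mask` as `True` at padded positions and zeroing those tokens' expert weights is also an assumption, and `ToyTopKMoE` is a hypothetical name.

```python
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


class ToyTopKMoE(nn.Module):
    """Token-choice top-k MoE: softmax router, keep k experts per token, renormalize."""

    def __init__(self, dim: int, hidden: int, n_experts: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(dim, n_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, hidden), nn.SiLU(), nn.Linear(hidden, dim))
            for _ in range(n_experts)
        )

    def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        b, t, d = x.shape
        tokens = x.reshape(-1, d)  # [N, d]
        probs = F.softmax(self.gate(tokens), dim=-1)  # [N, n_experts]
        weights, idx = probs.topk(self.top_k, dim=-1)  # [N, top_k]
        weights = weights / weights.sum(-1, keepdim=True)
        if padding_mask is not None:
            # Assumption: True marks padding; padded tokens get zero expert weight.
            weights = weights * (~padding_mask.reshape(-1, 1)).to(weights.dtype)
        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            token_ids, slot = (idx == e).nonzero(as_tuple=True)
            if token_ids.numel() == 0:
                continue  # no token picked this expert
            w = weights[token_ids, slot].unsqueeze(-1)
            out.index_add_(0, token_ids, w * expert(tokens[token_ids]))
        return out.reshape(b, t, d)


torch.manual_seed(0)
moe = ToyTopKMoE(dim=16, hidden=32, n_experts=8, top_k=2)
x = torch.randn(2, 4, 16)
y = moe(x)
print(y.shape)  # torch.Size([2, 4, 16])
```

The per-expert loop is for readability; production MoE layers typically group tokens by expert and run batched or grouped matrix multiplies.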

class nemo_automodel.components.models.ernie4_5.model.Ernie4_5Model(
    config: transformers.models.ernie4_5.configuration_ernie4_5.Ernie4_5Config,
    backend: nemo_automodel.components.models.common.BackendConfig,
)

Bases: torch.nn.Module

Dense ERNIE 4.5 transformer body.

Initialization

forward(
    input_ids: torch.Tensor,
    *,
    position_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    **attn_kwargs: Any,
) -> torch.Tensor
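`Ernie4_5Model.forward` accepts optional `position_ids`, while each block receives a precomputed `position_embeddings` tuple. A plausible flow, sketched below with stand-in modules, is: default `position_ids` to `0..T-1` per row when none are given, embed the tokens, build the `(cos, sin)` tables once, pass the same tuple to every block, and apply a final norm. `body_forward`, `RecordingLayer` and the `rotary` callable are hypothetical, and `nn.LayerNorm` stands in for whatever final norm the real model uses.

```python
import torch
from torch import nn


def default_position_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Positions 0..T-1 for every sequence in the batch, on input_ids' device."""
    batch, seq_len = input_ids.shape
    return torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch, -1)


class RecordingLayer(nn.Module):
    """Stand-in decoder block that remembers the position_embeddings it received."""

    def __init__(self):
        super().__init__()
        self.seen = None

    def forward(self, h: torch.Tensor, *, position_embeddings) -> torch.Tensor:
        self.seen = position_embeddings
        return h + 1.0


def body_forward(input_ids, embed, rotary, layers, norm, position_ids=None):
    if position_ids is None:
        position_ids = default_position_ids(input_ids)
    hidden = embed(input_ids)
    # The (cos, sin) tables are built once per forward and shared by all blocks.
    position_embeddings = rotary(position_ids)
    for layer in layers:
        hidden = layer(hidden, position_embeddings=position_embeddings)
    return norm(hidden)


torch.manual_seed(0)
head_dim = 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))


def rotary(pos: torch.Tensor):
    # Interleaved layout, as assumed for the attention sketch: one frequency per channel pair.
    angles = (pos.float()[..., None] * inv_freq).repeat_interleave(2, dim=-1)
    return angles.cos(), angles.sin()


layers = nn.ModuleList(RecordingLayer() for _ in range(3))
input_ids = torch.randint(0, 100, (2, 6))
out = body_forward(input_ids, nn.Embedding(100, 16), rotary, layers, nn.LayerNorm(16))
print(out.shape)  # torch.Size([2, 6, 16])
```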

class nemo_automodel.components.models.ernie4_5.model.Ernie4_5_MoeModel(
    config: transformers.models.ernie4_5_moe.configuration_ernie4_5_moe.Ernie4_5_MoeConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    *,
    moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
    moe_overrides: dict | None = None,
)

Bases: torch.nn.Module

ERNIE 4.5 MoE transformer body.

Initialization

forward(
    input_ids: torch.Tensor,
    *,
    position_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    **attn_kwargs: Any,
) -> torch.Tensor
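`Ernie4_5_MoeModel` takes an optional `moe_config` and a `moe_overrides` dict, but this page does not say how they combine. The sketch below assumes one plausible rule: an explicit `moe_config` replaces the default derived from the Hugging Face config, and `moe_overrides` then patches individual fields, with unknown keys rejected. `ToyMoEConfig`, its fields, and `resolve_moe_config` are hypothetical stand-ins; the real `MoEConfig` fields are not listed here.

```python
from dataclasses import dataclass, fields, replace
from typing import Optional


@dataclass(frozen=True)
class ToyMoEConfig:
    """Hypothetical stand-in for MoEConfig."""

    n_experts: int = 8
    top_k: int = 2
    dim: int = 16
    inter_dim: int = 32


def resolve_moe_config(
    default: ToyMoEConfig,
    moe_config: Optional[ToyMoEConfig] = None,
    moe_overrides: Optional[dict] = None,
) -> ToyMoEConfig:
    # Assumption: an explicit config wins over the default built from the HF config.
    cfg = moe_config if moe_config is not None else default
    if moe_overrides:
        known = {f.name for f in fields(cfg)}
        unknown = set(moe_overrides) - known
        if unknown:
            raise ValueError(f"unknown MoE override(s): {sorted(unknown)}")
        # Overrides are applied last, field by field.
        cfg = replace(cfg, **moe_overrides)
    return cfg


base = ToyMoEConfig()
print(resolve_moe_config(base, moe_overrides={"top_k": 4}))
```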

class nemo_automodel.components.models.ernie4_5.model.Ernie4_5ForCausalLM(
    config: transformers.models.ernie4_5.configuration_ernie4_5.Ernie4_5Config,
    backend: nemo_automodel.components.models.common.BackendConfig | None = None,
    **kwargs,
)

Bases: nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin, torch.nn.Module

Dense ERNIE 4.5 causal language model.

Initialization

supports_gradient_checkpointing

True

_skip_init_weights_on_load

True

_nemo_tied_weights_keys

None

_tp_plan

None

_pp_plan

None

None

classmethod from_config(
    config: transformers.models.ernie4_5.configuration_ernie4_5.Ernie4_5Config,
    backend: nemo_automodel.components.models.common.BackendConfig | None = None,
    **kwargs,
)

classmethod from_pretrained(
    pretrained_model_name_or_path: str,
    *model_args,
    **kwargs,
)

get_input_embeddings()

set_input_embeddings(value)

get_output_embeddings()

set_output_embeddings(new_embeddings)

tie_weights()
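`get_input_embeddings`, `get_output_embeddings` and `tie_weights` resemble the familiar Hugging Face pattern for sharing the token-embedding matrix with the output projection. A minimal sketch, assuming a `tie_word_embeddings` flag and the attribute names `embed_tokens` and `lm_head` (all assumptions for this page; `ToyLM` is hypothetical):

```python
import torch
from torch import nn


class ToyLM(nn.Module):
    def __init__(self, vocab: int, dim: int, tie_word_embeddings: bool):
        super().__init__()
        self.tie_word_embeddings = tie_word_embeddings
        self.embed_tokens = nn.Embedding(vocab, dim)  # weight: [vocab, dim]
        self.lm_head = nn.Linear(dim, vocab, bias=False)  # weight: [vocab, dim]
        self.tie_weights()

    def get_input_embeddings(self) -> nn.Module:
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.embed_tokens = value

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def tie_weights(self) -> None:
        # Share one Parameter: the embedding matrix doubles as the output projection.
        if self.tie_word_embeddings:
            self.get_output_embeddings().weight = self.get_input_embeddings().weight


tied = ToyLM(vocab=10, dim=4, tie_word_embeddings=True)
print(sum(p.numel() for p in tied.parameters()))  # 40: the shared matrix is counted once
```

In this sketch, after swapping embeddings with `set_input_embeddings`, calling `tie_weights()` again makes `lm_head` point at the new matrix.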

forward(
    input_ids: torch.Tensor,
    *,
    position_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    logits_to_keep: int | torch.Tensor = 0,
    **attn_kwargs: Any,
) -> torch.Tensor
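`logits_to_keep` appears to follow the Hugging Face transformers convention: an `int` keeps the last N sequence positions (with `0` meaning all of them), and a tensor selects explicit positions. The sketch below assumes that convention; `select_hidden` is a hypothetical helper name.

```python
import torch
from torch import nn


def select_hidden(hidden: torch.Tensor, logits_to_keep) -> torch.Tensor:
    # int: keep the last N positions; 0 gives slice(0, None), i.e. every position.
    # tensor: explicit sequence indices to keep.
    index = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden[:, index, :]


torch.manual_seed(0)
hidden = torch.randn(2, 7, 16)  # [batch, seq, dim]
lm_head = nn.Linear(16, 50, bias=False)  # toy vocabulary of 50

print(lm_head(select_hidden(hidden, 0)).shape)  # torch.Size([2, 7, 50])
print(lm_head(select_hidden(hidden, 1)).shape)  # torch.Size([2, 1, 50])
print(lm_head(select_hidden(hidden, torch.tensor([0, 3]))).shape)  # torch.Size([2, 2, 50])
```

Passing `logits_to_keep=1` during generation avoids projecting every position onto the full vocabulary when only the next-token logits are needed.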

class nemo_automodel.components.models.ernie4_5.model.Ernie4_5_MoeForCausalLM(
    config: transformers.models.ernie4_5_moe.configuration_ernie4_5_moe.Ernie4_5_MoeConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
    backend: nemo_automodel.components.models.common.BackendConfig | None = None,
    **kwargs,
)

Bases: nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin, torch.nn.Module, nemo_automodel.components.moe.fsdp_mixin.MoEFSDPSyncMixin

ERNIE 4.5 MoE causal language model with AutoModel expert-parallel (EP) support.

Initialization

supports_gradient_checkpointing

True

_skip_init_weights_on_load

True

_nemo_tied_weights_keys

None

_tp_plan

None

_pp_plan

None

None

classmethod from_config(
    config: transformers.models.ernie4_5_moe.configuration_ernie4_5_moe.Ernie4_5_MoeConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
    backend: nemo_automodel.components.models.common.BackendConfig | None = None,
    **kwargs,
)

classmethod from_pretrained(
    pretrained_model_name_or_path: str,
    *model_args,
    **kwargs,
)

get_input_embeddings()

set_input_embeddings(value)

get_output_embeddings()

set_output_embeddings(new_embeddings)

tie_weights()

forward(
    input_ids: torch.Tensor,
    *,
    position_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    logits_to_keep: int | torch.Tensor = 0,
    **attn_kwargs: Any,
) -> torch.Tensor
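This page does not describe how `Ernie4_5_MoeForCausalLM` lays out experts under expert parallelism (EP). A common scheme, sketched below as an assumption, gives each EP rank a contiguous block of `n_experts // ep_size` experts; each token is then sent to the ranks that own its top-k experts. `experts_for_rank` and `owner_rank` are hypothetical helpers.

```python
import torch


def experts_for_rank(n_experts: int, ep_size: int, ep_rank: int) -> range:
    """Contiguous block of global expert ids owned by one EP rank (assumed layout)."""
    if n_experts % ep_size != 0:
        raise ValueError("n_experts must be divisible by ep_size in this sketch")
    per_rank = n_experts // ep_size
    return range(ep_rank * per_rank, (ep_rank + 1) * per_rank)


def owner_rank(expert_id: int, n_experts: int, ep_size: int) -> int:
    """EP rank that holds a given global expert id under the same layout."""
    return expert_id // (n_experts // ep_size)


n_experts, ep_size = 8, 4
for rank in range(ep_size):
    print(rank, list(experts_for_rank(n_experts, ep_size, rank)))

# Top-2 routing decisions for three tokens: count (token, expert) pairs sent to each rank.
topk_ids = torch.tensor([[0, 5], [7, 2], [3, 4]])
dest_rank = topk_ids // (n_experts // ep_size)
send_counts = torch.bincount(dest_rank.flatten(), minlength=ep_size)
print(send_counts.tolist())  # [1, 2, 2, 1]
```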

nemo_automodel.components.models.ernie4_5.model.ModelClass

None