# nemo_automodel.components.models.gpt_oss.model

## Module Contents

### Classes

- `Block`
- `GptOssModel`
- `GptOssForCausalLM`

## API
```python
class nemo_automodel.components.models.gpt_oss.model.Block(
    layer_idx: int,
    config: transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig,
    moe_config: nemo_automodel.components.moe.layers.MoEConfig,
    backend: nemo_automodel.components.moe.utils.BackendConfig,
)
```

Bases: `torch.nn.Module`
Initialization
```python
forward(
    x: torch.Tensor,
    *,
    freqs_cis: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
)
```
```python
_mlp(
    x: torch.Tensor,
    padding_mask: torch.Tensor | None,
)
```
```python
init_weights(buffer_device: torch.device)
```
```python
class nemo_automodel.components.models.gpt_oss.model.GptOssModel(
    config: transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig,
    backend: nemo_automodel.components.moe.utils.BackendConfig,
    *,
    moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
)
```

Bases: `torch.nn.Module`
Initialization
```python
forward(
    input_ids: torch.Tensor,
    *,
    position_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    **attn_kwargs: Any,
)
```
```python
init_weights(buffer_device: torch.device | None = None) -> None
```
```python
class nemo_automodel.components.models.gpt_oss.model.GptOssForCausalLM(
    config: transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig,
    moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
    backend: nemo_automodel.components.moe.utils.BackendConfig | None = None,
)
```

Bases: `torch.nn.Module`
Initialization
```python
@classmethod
from_config(
    pretrained_model_name_or_path: str | transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig,
    moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
    backend: nemo_automodel.components.moe.utils.BackendConfig | None = None,
    trust_remote_code: bool = False,
)
```
```python
forward(
    input_ids: torch.Tensor,
    *,
    position_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    **attn_kwargs: Any,
)
```
- initialize_weights(
- buffer_device: torch.device | None = None,
- dtype: torch.dtype = torch.bfloat16,