nemo_automodel.components.models.gpt_oss.model#

Module Contents#

Classes#

Block, GptOssModel, GptOssForCausalLM

API#

class nemo_automodel.components.models.gpt_oss.model.Block(
    layer_idx: int,
    config: transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig,
    moe_config: nemo_automodel.components.moe.layers.MoEConfig,
    backend: nemo_automodel.components.moe.utils.BackendConfig,
)#

Bases: torch.nn.Module

Initialization

forward(
    x: torch.Tensor,
    *,
    freqs_cis: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
) → tuple[torch.Tensor, torch.Tensor | None]#

_mlp(
    x: torch.Tensor,
    padding_mask: torch.Tensor | None,
) → torch.Tensor#

init_weights(buffer_device: torch.device)#
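A minimal sketch of constructing and calling a single Block. The GptOssConfig, MoEConfig, and BackendConfig instances, the tensor shapes, and the precomputed freqs_cis tensor are illustrative assumptions; their construction is not documented on this page.

```python
import torch

from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from nemo_automodel.components.models.gpt_oss.model import Block
from nemo_automodel.components.moe.layers import MoEConfig
from nemo_automodel.components.moe.utils import BackendConfig

# Assumed to be built elsewhere; their constructor arguments are not part of this page.
config: GptOssConfig = ...
moe_config: MoEConfig = ...
backend: BackendConfig = ...

block = Block(layer_idx=0, config=config, moe_config=moe_config, backend=backend)
block.init_weights(buffer_device=torch.device("cpu"))

x = torch.randn(2, 128, config.hidden_size)  # assumed [batch, seq_len, hidden_size]
freqs_cis: torch.Tensor = ...                # precomputed rotary frequencies (layout assumed)

# forward takes keyword-only arguments after x and returns
# (hidden_states, optional auxiliary tensor) per the signature above.
hidden_states, aux = block(x, freqs_cis=freqs_cis, attention_mask=None, padding_mask=None)
```
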

class nemo_automodel.components.models.gpt_oss.model.GptOssModel(
    config: transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig,
    backend: nemo_automodel.components.moe.utils.BackendConfig,
    *,
    moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
)#

Bases: torch.nn.Module

Initialization

forward(
    input_ids: torch.Tensor,
    *,
    position_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    **attn_kwargs: Any,
) → torch.Tensor#

init_weights(buffer_device: torch.device | None = None) → None#
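A hedged sketch of using the backbone directly: GptOssModel consumes token IDs and returns a torch.Tensor (presumably the final hidden states; this page does not state what the tensor contains). The config and backend objects and the device choice are assumptions for illustration.

```python
import torch

from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from nemo_automodel.components.models.gpt_oss.model import GptOssModel
from nemo_automodel.components.moe.utils import BackendConfig

config: GptOssConfig = ...    # assumed to be built elsewhere
backend: BackendConfig = ...  # assumed to be built elsewhere

model = GptOssModel(config, backend)  # moe_config is keyword-only and optional
model.init_weights(buffer_device=torch.device("cpu"))

input_ids = torch.randint(0, config.vocab_size, (1, 64))
position_ids = torch.arange(64).unsqueeze(0)

# Returns a torch.Tensor per the signature (presumably hidden states, not logits).
hidden = model(input_ids, position_ids=position_ids, attention_mask=None)
```
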

class nemo_automodel.components.models.gpt_oss.model.GptOssForCausalLM(
    config: transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig,
    moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
    backend: nemo_automodel.components.moe.utils.BackendConfig | None = None,
)#

Bases: torch.nn.Module

Initialization

classmethod from_config(
    pretrained_model_name_or_path: str | transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig,
    moe_config: nemo_automodel.components.moe.layers.MoEConfig | None = None,
    backend: nemo_automodel.components.moe.utils.BackendConfig | None = None,
    trust_remote_code: bool = False,
)#

forward(
    input_ids: torch.Tensor,
    *,
    position_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    **attn_kwargs: Any,
) → torch.Tensor#

initialize_weights(
    buffer_device: torch.device | None = None,
    dtype: torch.dtype = torch.bfloat16,
) → None#
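
A hedged end-to-end sketch using the causal-LM wrapper. from_config accepts either a Hugging Face model ID or an already-built GptOssConfig; the checkpoint name below is purely illustrative, and the assumption that forward returns vocabulary logits is not stated on this page.

```python
import torch

from nemo_automodel.components.models.gpt_oss.model import GptOssForCausalLM

# Build from a Hugging Face model ID (illustrative) or a GptOssConfig instance.
model = GptOssForCausalLM.from_config("openai/gpt-oss-20b")

# Materialize parameters; dtype defaults to torch.bfloat16 per the signature.
model.initialize_weights(buffer_device=torch.device("cpu"), dtype=torch.bfloat16)

input_ids = torch.randint(0, 1000, (1, 32))
position_ids = torch.arange(32).unsqueeze(0)

# Returns a torch.Tensor (presumably logits from the causal-LM head).
logits = model(input_ids, position_ids=position_ids)
```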