nemo_rl.models.megatron.setup#

Module Contents#

Classes#

MoEFloat16Module

Float16Module with the ability to keep the expert bias in float32.

Functions#

destroy_parallel_state

Safely destroy parallel state and reset async call tracking.

setup_distributed

Handle NCCL settings, dtype mapping, and basic config setup.

validate_and_set_config

validate_model_paths

Validate and set up model paths.

setup_model_config

Handle all the model configuration logic.

_apply_parallelism_config

Apply tensor/pipeline/context parallelism configuration.

_apply_moe_config

Apply Mixture of Experts configuration.

_apply_precision_config

Apply precision and dtype configuration.

_apply_performance_config

Apply performance optimization configuration.

_validate_optimizer_config

Validate optimizer configuration.

_validate_chunking_config

Validate chunking configuration.

_create_checkpoint_config

Create checkpoint configurations.

_validate_training_config

Validate training configuration.

_validate_dtype_config

_create_megatron_config

Create the final Megatron configuration container.

setup_model_and_optimizer

handle_model_import

Handle HF model import if checkpoint doesn’t exist.

setup_reference_model_state

Set up the reference model for inference and return its state dict.

finalize_megatron_setup

Finalize the setup with remaining configurations.

Data#

API#

nemo_rl.models.megatron.setup.TokenizerType#

‘TypeVar(…)’

nemo_rl.models.megatron.setup.destroy_parallel_state()#

Safely destroy parallel state and reset async call tracking.

This function is called during initialization to clean up temporary distributed state from model import operations. Resetting async call tracking ensures that when the main Megatron distributed context is created, all ranks start with consistent call_idx values for async checkpointing.
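
A minimal sketch of where such a call might sit in a worker's initialization; the surrounding import and initialization steps are elided:

```python
from nemo_rl.models.megatron.setup import destroy_parallel_state

# ... temporary distributed state may exist here after a model-import step ...

# Tear it down and reset async call tracking so that every rank enters the
# main Megatron distributed context with consistent call_idx values.
destroy_parallel_state()
```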

nemo_rl.models.megatron.setup.setup_distributed() None#

Handle NCCL settings, dtype mapping, and basic config setup.

nemo_rl.models.megatron.setup.validate_and_set_config(
config,
rank,
hf_model_name,
pretrained_path,
weights_path,
tokenizer,
)#
nemo_rl.models.megatron.setup.validate_model_paths(
config: nemo_rl.models.policy.PolicyConfig,
) tuple[str, str, bool]#

Validate and set up model paths.
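
A minimal usage sketch. The signature only promises a tuple[str, str, bool]; the element names below are inferred from the parameters of handle_model_import and are therefore assumptions:

```python
from nemo_rl.models.megatron.setup import validate_model_paths
from nemo_rl.models.policy import PolicyConfig

config: PolicyConfig = ...  # assumed to be provided by the calling worker

# Element names are an assumption based on handle_model_import's parameters.
hf_model_name, pretrained_path, pt_checkpoint_exists = validate_model_paths(config)
```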

nemo_rl.models.megatron.setup.setup_model_config(
config: nemo_rl.models.policy.PolicyConfig,
rank,
dtype,
hf_model_name: str,
pretrained_path: str,
weights_path: Optional[str] = None,
) tuple[megatron.bridge.training.config.ConfigContainer, Any]#

Handle all the model configuration logic.
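
A hedged sketch of a call site. The dtype value, the placeholder variables, and the name given to the second element of the returned tuple (annotated only as Any) are assumptions:

```python
import torch

from nemo_rl.models.megatron.setup import setup_model_config

rank: int = ...  # current process rank, assumed to come from the worker

megatron_cfg, model_cfg = setup_model_config(
    config,                       # PolicyConfig from the sketch above
    rank,
    dtype=torch.bfloat16,         # assumed training dtype
    hf_model_name=hf_model_name,
    pretrained_path=pretrained_path,
    weights_path=None,            # keep the default (no separate weights checkpoint)
)
# The first element is the ConfigContainer; the second is typed `Any`, so the
# name `model_cfg` is only a guess.
```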

nemo_rl.models.megatron.setup._apply_parallelism_config(
model_cfg: Any,
config: nemo_rl.models.policy.PolicyConfig,
) None#

Apply tensor/pipeline/context parallelism configuration.

nemo_rl.models.megatron.setup._apply_moe_config(
model_cfg: Any,
config: nemo_rl.models.policy.PolicyConfig,
) None#

Apply Mixture of Experts configuration.

nemo_rl.models.megatron.setup._apply_precision_config(
model_cfg: Any,
config: nemo_rl.models.policy.PolicyConfig,
dtype: torch.dtype,
) None#

Apply precision and dtype configuration.

nemo_rl.models.megatron.setup._apply_performance_config(
model_cfg: Any,
config: nemo_rl.models.policy.PolicyConfig,
) None#

Apply performance optimization configuration.

nemo_rl.models.megatron.setup._validate_optimizer_config(
config: nemo_rl.models.policy.PolicyConfig,
) None#

Validate optimizer configuration.

nemo_rl.models.megatron.setup._validate_chunking_config(
config: nemo_rl.models.policy.PolicyConfig,
) None#

Validate chunking configuration.

nemo_rl.models.megatron.setup._create_checkpoint_config(
pretrained_path: str,
weights_path: Optional[str],
) megatron.bridge.training.config.CheckpointConfig#

Create checkpoint configurations.

nemo_rl.models.megatron.setup._validate_training_config(
config: nemo_rl.models.policy.PolicyConfig,
model_cfg: Any,
) None#

Validate training configuration.

nemo_rl.models.megatron.setup._validate_dtype_config(
dtype: torch.dtype,
model_cfg: Any,
optimizer_cfg: Any,
) None#
nemo_rl.models.megatron.setup._create_megatron_config(
model_cfg: Any,
checkpoint_config: megatron.bridge.training.config.CheckpointConfig,
config: nemo_rl.models.policy.PolicyConfig,
hf_model_name: str,
dtype: torch.dtype,
) megatron.bridge.training.config.ConfigContainer#

Create the final Megatron configuration container.

nemo_rl.models.megatron.setup.setup_model_and_optimizer(
policy_cfg: nemo_rl.models.policy.PolicyConfig,
megatron_cfg: megatron.bridge.training.config.ConfigContainer,
load_optimizer: bool = True,
get_embedding_ranks=None,
get_position_embedding_ranks=None,
)#
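
A minimal sketch of invoking this entry point; the structure of the return value is not documented on this page, so it is left opaque:

```python
from nemo_rl.models.megatron.setup import setup_model_and_optimizer

# `config` (a PolicyConfig) and `megatron_cfg` come from the sketches above.
# Passing load_optimizer=False presumably skips optimizer construction
# (an assumption based on the parameter name).
result = setup_model_and_optimizer(config, megatron_cfg, load_optimizer=True)
```
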
nemo_rl.models.megatron.setup.handle_model_import(
config: nemo_rl.models.policy.PolicyConfig,
hf_model_name: str,
pretrained_path: str,
pt_checkpoint_exists: bool,
) None#

Handle HF model import if checkpoint doesn’t exist.
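
A sketch showing how this might follow validate_model_paths; the variable names carry the same assumptions noted there:

```python
from nemo_rl.models.megatron.setup import handle_model_import

# Per the docstring, the HF -> Megatron import only runs when the converted
# checkpoint does not exist yet.
handle_model_import(config, hf_model_name, pretrained_path, pt_checkpoint_exists)
```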

nemo_rl.models.megatron.setup.setup_reference_model_state(
config: nemo_rl.models.policy.PolicyConfig,
megatron_cfg: megatron.bridge.training.config.ConfigContainer,
pretrained_path: str,
) dict#

Set up the reference model for inference and return its state dict.
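
A minimal sketch, assuming config, megatron_cfg, and pretrained_path were produced by the earlier steps:

```python
from nemo_rl.models.megatron.setup import setup_reference_model_state

# Build the reference model for inference and keep its state dict around.
reference_state_dict = setup_reference_model_state(config, megatron_cfg, pretrained_path)
```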

nemo_rl.models.megatron.setup.finalize_megatron_setup(
config: nemo_rl.models.policy.PolicyConfig,
megatron_cfg: megatron.bridge.training.config.ConfigContainer,
hf_model_name: str,
worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding,
model,
optimizer,
) tuple#

Finalize the setup with remaining configurations.

Returns:

Tuple of (megatron_tokenizer, megatron_bridge, should_disable_forward_pre_hook, dp_size)
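
A sketch that unpacks the documented 4-tuple; worker_sharding_annotations, model, and optimizer are assumed to come from the worker's sharding setup and from setup_model_and_optimizer:

```python
from nemo_rl.models.megatron.setup import finalize_megatron_setup

(
    megatron_tokenizer,
    megatron_bridge,
    should_disable_forward_pre_hook,
    dp_size,
) = finalize_megatron_setup(
    config,
    megatron_cfg,
    hf_model_name,
    worker_sharding_annotations,
    model,
    optimizer,
)
```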

class nemo_rl.models.megatron.setup.MoEFloat16Module(
config: megatron.core.transformer.transformer_config.TransformerConfig,
module: torch.nn.Module,
)#

Bases: megatron.core.transformer.module.Float16Module

Float16Module with the ability to keep the expert bias in float32.

Attributes:

config (TransformerConfig) – Transformer config

fp16 (bool) – Specifies if the model runs in fp16 mode

bf16 (bool) – Specifies if the model runs in bf16 mode

Parameters:

config (TransformerConfig) – The transformer config used to initialize the model

Initialization

re_enable_float32_expert_bias() None#

Ensure MoE router expert bias stays in float32 for numerical stability.

Walks the wrapped module to find MoE routers and invokes the _maintain_float32_expert_bias() helper, which recreates or casts the expert bias tensors to float32 as required by Megatron-LM.
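
A minimal sketch of wrapping a Megatron module; the placeholders below stand in for objects constructed elsewhere:

```python
from megatron.core.transformer.transformer_config import TransformerConfig

from nemo_rl.models.megatron.setup import MoEFloat16Module

transformer_config: TransformerConfig = ...  # built elsewhere
megatron_module = ...                        # the torch.nn.Module to wrap

# Run the model in fp16/bf16 while keeping the MoE router expert bias in float32.
wrapped = MoEFloat16Module(transformer_config, megatron_module)

# Re-assert the float32 expert bias; when this needs to be called again
# (e.g. after weights are reloaded or recast) is up to the caller.
wrapped.re_enable_float32_expert_bias()
```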