nemo_rl.models.megatron.setup

Module Contents

Classes

MoEFloat16Module

Float16Module that can keep the MoE router expert bias in float32.

Functions

destroy_parallel_state

Safely destroy parallel state and reset async call tracking.

setup_distributed

Handle NCCL settings, dtype mapping, and basic config setup.

validate_and_set_config

_canonicalize_hf_config_overrides

Return a stable JSON string for hf_config_overrides.

_get_hf_config_overrides_hash

Return a short stable hash for hf_config_overrides.

_resolve_iter_dir_from_root

Resolve the latest iteration directory under path.

validate_model_paths

Validate and set up model paths.

setup_model_config

Handle all the model configuration logic.

_apply_parallelism_config

Apply tensor/pipeline/context parallelism configuration.

_apply_moe_config

Apply Mixture of Experts configuration.

_apply_mtp_config

_apply_precision_config

Apply precision and dtype configuration.

_apply_performance_config

Apply performance optimization configuration.

_validate_optimizer_config

Validate optimizer configuration.

_validate_chunking_config

Validate chunking configuration.

_create_checkpoint_config

Create checkpoint configurations.

_validate_training_config

Validate training configuration.

_validate_dtype_config

_create_megatron_config

Create the final Megatron configuration container.

_create_draft_pre_wrap_hook

Create the hook that attaches draft weights before mixed-precision/DDP wrapping.

setup_model_and_optimizer

handle_model_import

Convert and cache the initial model checkpoint if it does not yet exist.

setup_reference_model_state

Set up the reference model for inference and return its state dict.

finalize_megatron_setup

Finalize the setup with remaining configurations.

Data

API

nemo_rl.models.megatron.setup.TokenizerType

‘TypeVar(…)’

nemo_rl.models.megatron.setup.destroy_parallel_state()

Safely destroy parallel state and reset async call tracking.

This function is called during initialization to clean up temporary distributed state from model import operations. Resetting async call tracking ensures that when the main Megatron distributed context is created, all ranks start with consistent call_idx values for async checkpointing.

nemo_rl.models.megatron.setup.setup_distributed() -> None

Handle NCCL settings, dtype mapping, and basic config setup.

nemo_rl.models.megatron.setup.validate_and_set_config(
config,
rank,
hf_model_name,
pretrained_path,
weights_path,
optimizer_path,
)

nemo_rl.models.megatron.setup._canonicalize_hf_config_overrides(
overrides: dict[str, Any],
) -> str

Return a stable JSON string for hf_config_overrides.

nemo_rl.models.megatron.setup._get_hf_config_overrides_hash(overrides: dict[str, Any]) -> str

Return a short stable hash for hf_config_overrides.
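The page does not show how these two helpers are implemented. A minimal sketch, assuming that "stable" means json.dumps with sorted keys and compact separators, and that the short hash is a truncated SHA-256 hex digest (the 8-character length is also an assumption), could look like this. The names canonicalize_overrides and overrides_hash are hypothetical and not part of nemo_rl.

```python
import hashlib
import json
from typing import Any


def canonicalize_overrides(overrides: dict[str, Any]) -> str:
    """Hypothetical: serialize overrides so equal dicts always give the same string."""
    # sort_keys orders keys at every nesting level, so insertion order does not
    # matter; compact separators remove whitespace differences. default=str is
    # an assumed fallback for values json cannot serialize natively.
    return json.dumps(overrides, sort_keys=True, separators=(",", ":"), default=str)


def overrides_hash(overrides: dict[str, Any], length: int = 8) -> str:
    """Hypothetical: short hex digest of the canonical JSON form."""
    canonical = canonicalize_overrides(overrides)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:length]


if __name__ == "__main__":
    a = {"rope_theta": 1000000.0, "max_position_embeddings": 32768}
    b = {"max_position_embeddings": 32768, "rope_theta": 1000000.0}
    print(overrides_hash(a) == overrides_hash(b))  # True
```

Hashing the canonical string rather than the dict itself is what makes the result stable: two override dicts that differ only in key order map to the same string and therefore to the same hash.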

nemo_rl.models.megatron.setup._resolve_iter_dir_from_root(path: str, not_found_msg: str) -> str

Resolve the latest iteration directory under path.

Checks latest_checkpointed_iteration.txt first; falls back to scanning for iter_* subdirectories and taking the last one (lexicographic order).
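The lookup order described above can be sketched with the standard library. This is an illustrative approximation, not the module's code. The helper name resolve_iter_dir is hypothetical, and the sketch assumes Megatron-style directory names that zero-pad the iteration to seven digits (for example iter_0000100) and a tracker file that contains only that number.

```python
import os


def resolve_iter_dir(path: str, not_found_msg: str) -> str:
    """Hypothetical: return the latest iter_* directory under a checkpoint root."""
    tracker = os.path.join(path, "latest_checkpointed_iteration.txt")
    if os.path.isfile(tracker):
        with open(tracker) as f:
            iteration = f.read().strip()
        # Assumption: directory names zero-pad the iteration to 7 digits.
        if iteration.isdigit():
            candidate = os.path.join(path, f"iter_{int(iteration):07d}")
            if os.path.isdir(candidate):
                return candidate
    # Fallback: scan for iter_* subdirectories and take the last one in
    # lexicographic order.
    iter_dirs = sorted(
        name
        for name in os.listdir(path)
        if name.startswith("iter_") and os.path.isdir(os.path.join(path, name))
    )
    if not iter_dirs:
        raise FileNotFoundError(not_found_msg)
    return os.path.join(path, iter_dirs[-1])
```

Zero padding is what makes the lexicographic fallback safe: without it, iter_9 would sort after iter_10.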

nemo_rl.models.megatron.setup.validate_model_paths(
config: nemo_rl.models.policy.PolicyConfig,
) -> tuple[str, str, bool]

Validate and set up model paths.

Returns:

A (hf_model_name, pretrained_path, pt_checkpoint_exists) tuple, where:

  • hf_model_name is the HuggingFace model name or path used for architecture config resolution and tokenizer setup.

  • pretrained_path is the path of the checkpoint used as the pretrained starting point. For the megatron_bridge format, it is resolved to the specific iteration directory containing run_config.yaml. For the megatron_lm format, it is also resolved to the specific iteration directory (via latest_checkpointed_iteration.txt, or by scanning iter_* subdirectories if a root directory is provided), because the bridge does not resolve iterations itself. For the default HF path, it is the Megatron-Bridge cache directory.

  • pt_checkpoint_exists is True when the checkpoint at pretrained_path is already present and does not need to be created.

Return type:

tuple[str, str, bool]

nemo_rl.models.megatron.setup.setup_model_config(
config: nemo_rl.models.policy.PolicyConfig,
rank,
dtype,
hf_model_name: str,
pretrained_path: str,
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
) -> tuple[megatron.bridge.training.config.ConfigContainer, Any]

Handle all the model configuration logic.

nemo_rl.models.megatron.setup._apply_parallelism_config(
model_cfg: Any,
config: nemo_rl.models.policy.PolicyConfig,
) -> None

Apply tensor/pipeline/context parallelism configuration.

nemo_rl.models.megatron.setup._apply_moe_config(
model_cfg: Any,
config: nemo_rl.models.policy.PolicyConfig,
) -> None

Apply Mixture of Experts configuration.

nemo_rl.models.megatron.setup._apply_mtp_config(
model_cfg: Any,
config: nemo_rl.models.policy.PolicyConfig,
) -> None

nemo_rl.models.megatron.setup._apply_precision_config(
model_cfg: Any,
config: nemo_rl.models.policy.PolicyConfig,
dtype: torch.dtype,
) -> None

Apply precision and dtype configuration.

nemo_rl.models.megatron.setup._apply_performance_config(
model_cfg: Any,
config: nemo_rl.models.policy.PolicyConfig,
) -> None

Apply performance optimization configuration.

nemo_rl.models.megatron.setup._validate_optimizer_config(
config: nemo_rl.models.policy.PolicyConfig,
) -> None

Validate optimizer configuration.

nemo_rl.models.megatron.setup._validate_chunking_config(
config: nemo_rl.models.policy.PolicyConfig,
) -> None

Validate chunking configuration.

nemo_rl.models.megatron.setup._create_checkpoint_config(
pretrained_path: str,
weights_path: Optional[str],
optimizer_path: Optional[str],
) -> megatron.bridge.training.config.CheckpointConfig

Create checkpoint configurations.

nemo_rl.models.megatron.setup._validate_training_config(
config: nemo_rl.models.policy.PolicyConfig,
model_cfg: Any,
) -> None

Validate training configuration.

nemo_rl.models.megatron.setup._validate_dtype_config(
dtype: torch.dtype,
model_cfg: Any,
optimizer_cfg: Any,
) -> None

nemo_rl.models.megatron.setup._create_megatron_config(
model_cfg: Any,
checkpoint_config: megatron.bridge.training.config.CheckpointConfig,
config: nemo_rl.models.policy.PolicyConfig,
hf_model_name: str,
dtype: torch.dtype,
) -> megatron.bridge.training.config.ConfigContainer

Create the final Megatron configuration container.

nemo_rl.models.megatron.setup._create_draft_pre_wrap_hook(
policy_cfg: nemo_rl.models.policy.PolicyConfig,
megatron_cfg: megatron.bridge.training.config.ConfigContainer,
state: megatron.bridge.training.state.GlobalState,
*,
preload_policy_from_pretrained: bool,
) -> Callable[[list[megatron.core.transformer.MegatronModule]], list[megatron.core.transformer.MegatronModule]]

Create the hook that attaches draft weights before mixed-precision/DDP wrapping.

nemo_rl.models.megatron.setup.setup_model_and_optimizer(
policy_cfg: nemo_rl.models.policy.PolicyConfig,
megatron_cfg: megatron.bridge.training.config.ConfigContainer,
load_optimizer: bool = True,
get_embedding_ranks=None,
get_position_embedding_ranks=None,
pre_load_checkpoint_hook: Optional[Callable] = None,
)

nemo_rl.models.megatron.setup.handle_model_import(
config: nemo_rl.models.policy.PolicyConfig,
hf_model_name: str,
pretrained_path: str,
pt_checkpoint_exists: bool,
model_post_wrap_hook: Optional[Callable] = None,
transformer_layer_spec: Optional[Any] = None,
) -> None

Convert and cache the initial model checkpoint if it does not yet exist.

Behaviour depends on policy.pretrained_checkpoint.format:

  • "megatron_bridge": The checkpoint is already in the correct format; no conversion is performed.

  • "megatron_lm": Megatron-Bridge can load torch_dist MLM checkpoints directly (the bridge falls back to extracting config from the state dict when run_config.yaml is absent), so no conversion is performed.

  • No pretrained_checkpoint (default): The HuggingFace model identified by hf_model_name is converted to Megatron-Bridge format (existing behaviour).

The force_reconvert_from_hf flag forces the HF conversion to run again even if the output already exists. It has no effect for megatron_bridge or megatron_lm formats.

Parameters:
  • config – Policy config used for pretrained_checkpoint, hf_config_overrides, and megatron_cfg.

  • hf_model_name – HF model id (or local path) to import.

  • pretrained_path – Output directory for the Megatron checkpoint.

  • pt_checkpoint_exists – Whether a Megatron checkpoint already exists at pretrained_path. If True and force_reconvert_from_hf is False, the import is skipped.

  • model_post_wrap_hook – Optional callable forwarded to import_model_from_hf_name. Invoked on each Megatron model chunk after it is built (and before DDP wrapping).

  • transformer_layer_spec – Optional Megatron ModuleSpec (or callable returning one) overriding the default layer spec from the model provider.
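The format-dependent behaviour described above reduces to a small decision rule. The sketch below restates it as a hypothetical helper, needs_hf_conversion, which is not part of nemo_rl; raising ValueError for an unrecognized format is an added assumption, not documented behaviour.

```python
from typing import Optional


def needs_hf_conversion(
    checkpoint_format: Optional[str],
    pt_checkpoint_exists: bool,
    force_reconvert_from_hf: bool = False,
) -> bool:
    """Hypothetical: decide whether the HF -> Megatron-Bridge import should run."""
    if checkpoint_format in ("megatron_bridge", "megatron_lm"):
        # Both formats are loaded directly by Megatron-Bridge, so no conversion
        # runs and force_reconvert_from_hf has no effect.
        return False
    if checkpoint_format is not None:
        # Assumption: treat any other format value as a configuration error.
        raise ValueError(f"Unknown pretrained_checkpoint format: {checkpoint_format!r}")
    # Default path: convert the HF model unless a converted checkpoint already
    # exists, or reconversion is explicitly forced.
    return force_reconvert_from_hf or not pt_checkpoint_exists
```

For example, needs_hf_conversion(None, pt_checkpoint_exists=True) is False, while passing force_reconvert_from_hf=True makes it True again.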

nemo_rl.models.megatron.setup.setup_reference_model_state(
config: nemo_rl.models.policy.PolicyConfig,
megatron_cfg: megatron.bridge.training.config.ConfigContainer,
pretrained_path: str,
pre_load_checkpoint_hook: Optional[Callable] = None,
) -> dict

Set up the reference model for inference and return its state dict.

nemo_rl.models.megatron.setup.finalize_megatron_setup(
config: nemo_rl.models.policy.PolicyConfig,
megatron_cfg: megatron.bridge.training.config.ConfigContainer,
hf_model_name: str,
worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding,
model,
optimizer,
) -> tuple

Finalize the setup with remaining configurations.

Returns:

Tuple of (megatron_tokenizer, megatron_bridge, should_disable_forward_pre_hook, dp_size)

class nemo_rl.models.megatron.setup.MoEFloat16Module(
config: megatron.core.transformer.transformer_config.TransformerConfig,
module: torch.nn.Module,
)

Bases: megatron.core.transformer.module.Float16Module

Float16Module that can keep the MoE router expert bias in float32.

Attributes:

  • config (TransformerConfig) – Transformer config.

  • fp16 (bool) – Whether the model runs in fp16 mode.

  • bf16 (bool) – Whether the model runs in bf16 mode.

Parameters:

config (TransformerConfig) – The transformer config used to initialize the model.


re_enable_float32_expert_bias() -> None

Ensure MoE router expert bias stays in float32 for numerical stability.

Walks the wrapped module to find MoE routers and invokes the _maintain_float32_expert_bias() helper which recreates or casts the expert bias tensors to float32 as required by Megatron-LM.
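The reason such a re-cast is needed: casting a module to bfloat16 also casts its floating-point buffers, so an expert-bias buffer ends up in low precision unless it is restored afterwards. The sketch below illustrates the module-walk pattern in plain PyTorch. ToyRouter and keep_expert_bias_fp32 are hypothetical stand-ins, not Megatron's router classes, and the sketch simply re-casts any tensor attribute named expert_bias; Megatron-LM's _maintain_float32_expert_bias() may differ in detail.

```python
import torch
from torch import nn


class ToyRouter(nn.Module):
    """Hypothetical stand-in for an MoE router that owns an expert-bias buffer."""

    def __init__(self, num_experts: int) -> None:
        super().__init__()
        self.register_buffer("expert_bias", torch.zeros(num_experts))


def keep_expert_bias_fp32(module: nn.Module) -> None:
    """Walk all submodules and cast any expert_bias tensor back to float32."""
    for sub in module.modules():
        bias = getattr(sub, "expert_bias", None)
        if isinstance(bias, torch.Tensor) and bias.dtype != torch.float32:
            # Assigning a tensor to a registered buffer name replaces the
            # buffer in place, so it stays part of the module's state dict.
            sub.expert_bias = bias.float()
```

Casting back cannot recover precision already lost in bfloat16, so a helper like this is most useful immediately after the module is wrapped, before any bias updates happen in low precision.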