nemo_rl.models.megatron.setup#
Module Contents#
Classes#
| Class | Description |
|---|---|
| MoEFloat16Module | Float 16 Module with the ability to keep the expert bias in float32. |
Functions#
| Function | Description |
|---|---|
| destroy_parallel_state | Safely destroy parallel state and reset async call tracking. |
| setup_distributed | Handle NCCL settings, dtype mapping, and basic config setup. |
| validate_and_set_config | |
| validate_model_paths | Validate and setup model paths. |
| setup_model_config | Handle all the model configuration logic. |
| _apply_parallelism_config | Apply tensor/pipeline/context parallelism configuration. |
| _apply_moe_config | Apply Mixture of Experts configuration. |
| _apply_precision_config | Apply precision and dtype configuration. |
| _apply_performance_config | Apply performance optimization configuration. |
| _validate_optimizer_config | Validate optimizer configuration. |
| _validate_chunking_config | Validate chunking configuration. |
| _create_checkpoint_config | Create checkpoint configurations. |
| _validate_training_config | Validate training configuration. |
| _validate_dtype_config | |
| _create_megatron_config | Create the final Megatron configuration container. |
| setup_model_and_optimizer | |
| handle_model_import | Handle HF model import if checkpoint doesn't exist. |
| setup_reference_model_state | Setup the reference model for inference and return its state dict. |
| finalize_megatron_setup | Finalize the setup with remaining configurations. |
Data#
API#
- nemo_rl.models.megatron.setup.TokenizerType#
‘TypeVar(…)’
- nemo_rl.models.megatron.setup.destroy_parallel_state()#
Safely destroy parallel state and reset async call tracking.
This function is called during initialization to clean up temporary distributed state from model import operations. Resetting async call tracking ensures that when the main Megatron distributed context is created, all ranks start with consistent call_idx values for async checkpointing.
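For example, a caller might invoke it immediately before building the main distributed context. This is a minimal sketch of an assumed call order, not the module's actual orchestration:

```python
# Clear any parallel state left over from a temporary model-import step, then
# build the main Megatron distributed context with consistent async-call tracking.
from nemo_rl.models.megatron.setup import destroy_parallel_state, setup_distributed

destroy_parallel_state()  # reset parallel groups and async checkpoint call_idx tracking
setup_distributed()       # NCCL settings, dtype mapping, basic config setup
```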
- nemo_rl.models.megatron.setup.setup_distributed() → None#
Handle NCCL settings, dtype mapping, and basic config setup.
- nemo_rl.models.megatron.setup.validate_and_set_config(
- config,
- rank,
- hf_model_name,
- pretrained_path,
- weights_path,
- tokenizer,
)#
- nemo_rl.models.megatron.setup.validate_model_paths(
- config: nemo_rl.models.policy.PolicyConfig,
)#
Validate and setup model paths.
- nemo_rl.models.megatron.setup.setup_model_config(
- config: nemo_rl.models.policy.PolicyConfig,
- rank,
- dtype,
- hf_model_name: str,
- pretrained_path: str,
- weights_path: Optional[str] = None,
)#
Handle all the model configuration logic.
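A hedged usage example follows; the model name, paths, return-value name, and the `policy_config` object (assumed to be a PolicyConfig built elsewhere) are illustrative placeholders, not values defined by this module:

```python
# Hypothetical call to setup_model_config; every literal below is a placeholder.
import torch
from nemo_rl.models.megatron.setup import setup_model_config

model_cfg = setup_model_config(                    # return-value name is an assumption
    config=policy_config,                          # a nemo_rl.models.policy.PolicyConfig built elsewhere
    rank=0,
    dtype=torch.bfloat16,
    hf_model_name="meta-llama/Llama-3.1-8B",       # placeholder HF model id
    pretrained_path="/ckpts/llama31-8b-megatron",  # placeholder checkpoint directory
    weights_path=None,
)
```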
- nemo_rl.models.megatron.setup._apply_parallelism_config(
- model_cfg: Any,
- config: nemo_rl.models.policy.PolicyConfig,
)#
Apply tensor/pipeline/context parallelism configuration.
- nemo_rl.models.megatron.setup._apply_moe_config(
- model_cfg: Any,
- config: nemo_rl.models.policy.PolicyConfig,
)#
Apply Mixture of Experts configuration.
- nemo_rl.models.megatron.setup._apply_precision_config(
- model_cfg: Any,
- config: nemo_rl.models.policy.PolicyConfig,
- dtype: torch.dtype,
)#
Apply precision and dtype configuration.
- nemo_rl.models.megatron.setup._apply_performance_config(
- model_cfg: Any,
- config: nemo_rl.models.policy.PolicyConfig,
)#
Apply performance optimization configuration.
- nemo_rl.models.megatron.setup._validate_optimizer_config(
- config: nemo_rl.models.policy.PolicyConfig,
)#
Validate optimizer configuration.
- nemo_rl.models.megatron.setup._validate_chunking_config(
- config: nemo_rl.models.policy.PolicyConfig,
)#
Validate chunking configuration.
- nemo_rl.models.megatron.setup._create_checkpoint_config(
- pretrained_path: str,
- weights_path: Optional[str],
)#
Create checkpoint configurations.
- nemo_rl.models.megatron.setup._validate_training_config(
- config: nemo_rl.models.policy.PolicyConfig,
- model_cfg: Any,
)#
Validate training configuration.
- nemo_rl.models.megatron.setup._validate_dtype_config(
- dtype: torch.dtype,
- model_cfg: Any,
- optimizer_cfg: Any,
)#
- nemo_rl.models.megatron.setup._create_megatron_config(
- model_cfg: Any,
- checkpoint_config: megatron.bridge.training.config.CheckpointConfig,
- config: nemo_rl.models.policy.PolicyConfig,
- hf_model_name: str,
- dtype: torch.dtype,
)#
Create the final Megatron configuration container.
- nemo_rl.models.megatron.setup.setup_model_and_optimizer(
- policy_cfg: nemo_rl.models.policy.PolicyConfig,
- megatron_cfg: megatron.bridge.training.config.ConfigContainer,
- load_optimizer: bool = True,
- get_embedding_ranks=None,
- get_position_embedding_ranks=None,
)#
- nemo_rl.models.megatron.setup.handle_model_import(
- config: nemo_rl.models.policy.PolicyConfig,
- hf_model_name: str,
- pretrained_path: str,
- pt_checkpoint_exists: bool,
)#
Handle HF model import if checkpoint doesn’t exist.
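A sketch of the intended conditional is shown below; the directory check used to derive `pt_checkpoint_exists`, the paths, and `policy_config` are assumptions for illustration:

```python
# Hypothetical caller: convert HF weights to a Megatron checkpoint only if the
# Megatron checkpoint directory is not already populated.
import os
from nemo_rl.models.megatron.setup import handle_model_import

pretrained_path = "/ckpts/llama31-8b-megatron"     # placeholder path
pt_checkpoint_exists = os.path.isdir(pretrained_path)
handle_model_import(
    config=policy_config,                          # PolicyConfig built elsewhere
    hf_model_name="meta-llama/Llama-3.1-8B",       # placeholder HF model id
    pretrained_path=pretrained_path,
    pt_checkpoint_exists=pt_checkpoint_exists,
)
```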
- nemo_rl.models.megatron.setup.setup_reference_model_state(
- config: nemo_rl.models.policy.PolicyConfig,
- megatron_cfg: megatron.bridge.training.config.ConfigContainer,
- pretrained_path: str,
)#
Setup the reference model for inference and return its state dict.
- nemo_rl.models.megatron.setup.finalize_megatron_setup(
- config: nemo_rl.models.policy.PolicyConfig,
- megatron_cfg: megatron.bridge.training.config.ConfigContainer,
- hf_model_name: str,
- worker_sharding_annotations: nemo_rl.distributed.named_sharding.NamedSharding,
- model,
- optimizer,
)#
Finalize the setup with remaining configurations.
- Returns:
Tuple of (megatron_tokenizer, megatron_bridge, should_disable_forward_pre_hook, dp_size)
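Putting the pieces together, a rough end-to-end flow might look like the following. This is a sketch based only on the signatures documented above; the return unpacking of setup_model_and_optimizer, the exact ordering, and objects such as `policy_config`, `megatron_cfg`, and `sharding` are assumptions:

```python
# Hypothetical orchestration of the setup functions on this page.
from nemo_rl.models.megatron.setup import (
    destroy_parallel_state,
    setup_distributed,
    setup_model_and_optimizer,
    finalize_megatron_setup,
)

destroy_parallel_state()
setup_distributed()

model, optimizer = setup_model_and_optimizer(   # return unpacking is an assumption
    policy_cfg=policy_config,                   # PolicyConfig built elsewhere
    megatron_cfg=megatron_cfg,                  # ConfigContainer built elsewhere
    load_optimizer=True,
)

megatron_tokenizer, megatron_bridge, should_disable_forward_pre_hook, dp_size = (
    finalize_megatron_setup(
        config=policy_config,
        megatron_cfg=megatron_cfg,
        hf_model_name="meta-llama/Llama-3.1-8B",   # placeholder HF model id
        worker_sharding_annotations=sharding,      # a NamedSharding instance
        model=model,
        optimizer=optimizer,
    )
)
```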
- class nemo_rl.models.megatron.setup.MoEFloat16Module(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- module: torch.nn.Module,
)#
Bases:
megatron.core.transformer.module.Float16Module

Float 16 Module with the ability to keep the expert bias in float32.
- Attributes:
config (TransformerConfig) – Transformer config
fp16 (bool) – Specifies if the model runs in fp16 mode
bf16 (bool) – Specifies if the model runs in bf16 mode
- Parameters:
config (TransformerConfig) – The transformer config used to initialize the model
Initialization
- re_enable_float32_expert_bias() → None#
Ensure MoE router expert bias stays in float32 for numerical stability.
Walks the wrapped module to find MoE routers and invokes the _maintain_float32_expert_bias() helper, which recreates or casts the expert bias tensors to float32 as required by Megatron-LM.
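A minimal usage sketch follows, assuming `transformer_config` and `core_module` come from the usual Megatron model-building path (both are placeholders, not objects provided by this module):

```python
# Hypothetical usage: wrap a Megatron module so that MoE router expert bias
# remains float32 while the rest of the model runs in bf16/fp16.
from nemo_rl.models.megatron.setup import MoEFloat16Module

wrapped = MoEFloat16Module(config=transformer_config, module=core_module)
wrapped.re_enable_float32_expert_bias()  # re-cast router expert bias tensors to float32
```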