nemo_rl.models.automodel.config
Configuration classes for automodel-based training in NeMo RL.
Module Contents
Classes
DistributedContext | Distributed context returned by setup_distributed().
RuntimeConfig | Runtime configuration for model training and inference.
ModelAndOptimizerState | Container for model and optimizer state.
API
- class nemo_rl.models.automodel.config.DistributedContext
Bases: typing.NamedTuple
Distributed context returned by setup_distributed().
Contains the device meshes and distributed configuration needed for model parallelization and training.
- device_mesh: Any
- moe_mesh: Any
- fsdp2_config: Any
- moe_config: Any
- dp_size: int
- tp_size: int
- cp_size: int
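Because DistributedContext is a typing.NamedTuple, callers can read its fields by name or unpack them positionally. The sketch below uses a hypothetical local stand-in that mirrors the documented field names. It is not imported from nemo_rl, and `None` takes the place of the real mesh and config objects, so it runs without the library installed.

```python
from typing import Any, NamedTuple


# Hypothetical stand-in mirroring the documented DistributedContext fields.
# It is NOT the nemo_rl class; real meshes/configs are library objects.
class DistributedContext(NamedTuple):
    device_mesh: Any
    moe_mesh: Any
    fsdp2_config: Any
    moe_config: Any
    dp_size: int
    tp_size: int
    cp_size: int


# Placeholder values; None stands in for the real mesh and config objects.
ctx = DistributedContext(
    device_mesh=None,
    moe_mesh=None,
    fsdp2_config=None,
    moe_config=None,
    dp_size=4,
    tp_size=2,
    cp_size=1,
)

# Fields are readable by name...
print(ctx.dp_size, ctx.tp_size, ctx.cp_size)  # 4 2 1

# ...or by tuple unpacking, in declaration order.
_, _, _, _, dp, tp, cp = ctx

# NamedTuples are immutable, so derive a modified copy instead of mutating.
ctx_tp4 = ctx._replace(tp_size=4)
print(ctx.tp_size, ctx_tp4.tp_size)  # 2 4
```

Because the tuple is immutable, one context object can be passed around safely. Any variant must be created explicitly with `_replace`.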
- class nemo_rl.models.automodel.config.RuntimeConfig
Bases: typing.NamedTuple
Runtime configuration for model training and inference.
This contains all validated runtime settings needed for model initialization, parallelization, and training.
- model_class: type
- model_config: Any
- hf_config_overrides: dict[str, Any]
- allow_flash_attn_args: bool
- attn_impl: Optional[str]
- dtype: torch.dtype
- enable_seq_packing: bool
- max_grad_norm: float
- cpu_offload: bool
- offload_optimizer_for_logprob: bool
- is_generation_colocated: Optional[bool]
- sampling_params: Optional[nemo_rl.algorithms.logits_sampling_utils.TrainingSamplingParams]
- is_reward_model: bool
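RuntimeConfig holds already-validated settings in an immutable tuple. A per-run override therefore produces a new object rather than mutating the shared one. The following is a minimal sketch built on a hypothetical stand-in with the documented field names. `dtype` holds a string and `sampling_params` holds `None` instead of a `torch.dtype` and a `TrainingSamplingParams`, and `DummyModel` is a made-up placeholder, all so the example runs without torch or nemo_rl.

```python
from typing import Any, NamedTuple, Optional


# Hypothetical stand-in mirroring the documented RuntimeConfig fields.
# In nemo_rl, `dtype` is a torch.dtype and `sampling_params` is a
# TrainingSamplingParams; simpler types are used here so the sketch runs.
class RuntimeConfig(NamedTuple):
    model_class: type
    model_config: Any
    hf_config_overrides: dict[str, Any]
    allow_flash_attn_args: bool
    attn_impl: Optional[str]
    dtype: Any
    enable_seq_packing: bool
    max_grad_norm: float
    cpu_offload: bool
    offload_optimizer_for_logprob: bool
    is_generation_colocated: Optional[bool]
    sampling_params: Optional[Any]
    is_reward_model: bool


class DummyModel:  # hypothetical placeholder for a real model class
    pass


cfg = RuntimeConfig(
    model_class=DummyModel,
    model_config={"hidden_size": 16},  # placeholder for a real config object
    hf_config_overrides={},
    allow_flash_attn_args=False,
    attn_impl=None,
    dtype="bfloat16",  # the real class would hold e.g. torch.bfloat16
    enable_seq_packing=False,
    max_grad_norm=1.0,
    cpu_offload=False,
    offload_optimizer_for_logprob=False,
    is_generation_colocated=None,
    sampling_params=None,
    is_reward_model=False,
)

# Immutability means an override yields a new config; cfg is unchanged.
packed = cfg._replace(enable_seq_packing=True)

# _asdict() is convenient for logging the effective settings.
settings = packed._asdict()
print(settings["enable_seq_packing"], settings["max_grad_norm"])  # True 1.0
```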
- class nemo_rl.models.automodel.config.ModelAndOptimizerState
Bases: typing.NamedTuple
Container for model and optimizer state.
This named tuple holds all model-related state: the model itself, the optimizer, the scheduler, and metadata about the model type and configuration.
- model: torch.nn.Module
- optimizer: Optional[torch.optim.Optimizer]
- scheduler: Optional[Any]
- is_hf_model: bool
- is_moe_model: bool
- is_reward_model: bool
- model_class: type
- model_config: Any
- peft_config: Optional[nemo_automodel.components._peft.lora.PeftConfig]
- autocast_enabled: bool
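Since `optimizer` and `scheduler` are typed as Optional, code that consumes a ModelAndOptimizerState should check for `None` before stepping them. The sketch below assumes a hypothetical stand-in tuple with the documented field names, a made-up `CountingStepper` in place of real torch optimizers and schedulers, and a hypothetical `optimizer_step` helper. None of these come from nemo_rl; they only illustrate the guard pattern.

```python
from typing import Any, NamedTuple, Optional


# Hypothetical stand-in mirroring the documented ModelAndOptimizerState fields.
# In nemo_rl, `model` is a torch.nn.Module, `optimizer` a torch.optim.Optimizer,
# and `peft_config` a PeftConfig; generic types keep this sketch runnable.
class ModelAndOptimizerState(NamedTuple):
    model: Any
    optimizer: Optional[Any]
    scheduler: Optional[Any]
    is_hf_model: bool
    is_moe_model: bool
    is_reward_model: bool
    model_class: type
    model_config: Any
    peft_config: Optional[Any]
    autocast_enabled: bool


class CountingStepper:
    """Hypothetical placeholder with a step() method, like an optimizer."""

    def __init__(self) -> None:
        self.steps = 0

    def step(self) -> None:
        self.steps += 1


def optimizer_step(state: ModelAndOptimizerState) -> bool:
    """Hypothetical helper: step optimizer and scheduler only if present.

    Returns True if an optimizer step was taken.
    """
    if state.optimizer is None:
        return False
    state.optimizer.step()
    if state.scheduler is not None:
        state.scheduler.step()
    return True


opt, sched = CountingStepper(), CountingStepper()
train_state = ModelAndOptimizerState(
    model=object(), optimizer=opt, scheduler=sched,
    is_hf_model=True, is_moe_model=False, is_reward_model=False,
    model_class=object, model_config=None, peft_config=None,
    autocast_enabled=False,
)
# A hypothetical state with no optimizer or scheduler attached.
eval_state = train_state._replace(optimizer=None, scheduler=None)

print(optimizer_step(train_state), optimizer_step(eval_state))  # True False
print(opt.steps, sched.steps)  # 1 1
```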