nemo_automodel._transformers.infrastructure#

Infrastructure instantiation and application.

Distributed manager instantiation, sharding, PEFT/quantization application, and checkpoint loading utilities. These free functions operate on an already-instantiated nn.Module and have no coupling to the _BaseNeMoAutoModelClass hierarchy.

MeshContext (from mesh) is the single source of truth for device meshes, parallelism sizes, and axis names.

Module Contents#

Functions#

_apply_peft_and_lower_precision

_shard_pp

_shard_ep_fsdp

Apply EP + FSDP sharding (non-PP path).

_instantiate_distributed

Instantiate the appropriate distributed manager from config.

_instantiate_pipeline

Instantiate AutoPipeline from config.

_instantiate_qat

parallelize_for_pp

Parallelize model for pipeline parallelism (non-MoE case).

instantiate_infrastructure

Instantiate infrastructure objects from config classes.

apply_model_infrastructure

Apply sharding, PEFT, quantization, and checkpoint loading to a model.

Data#

API#

nemo_automodel._transformers.infrastructure.logger#

'getLogger(…)'

nemo_automodel._transformers.infrastructure._apply_peft_and_lower_precision(
model,
tp_size,
autopipeline,
peft_config,
quantization_config,
fp8_config,
qat_quantizer,
)#
nemo_automodel._transformers.infrastructure._shard_pp(autopipeline, model, loss_fn, parallelize_fn)#
nemo_automodel._transformers.infrastructure._shard_ep_fsdp(
model,
model_wrapper,
parallelize_fn,
mesh: nemo_automodel.components.distributed.mesh.MeshContext,
)#

Apply EP + FSDP sharding (non-PP path).

nemo_automodel._transformers.infrastructure._instantiate_distributed(
config: nemo_automodel.components.distributed.config.DistributedConfig,
mesh: nemo_automodel.components.distributed.mesh.MeshContext,
) → Union[nemo_automodel.components.distributed.fsdp2.FSDP2Manager, nemo_automodel.components.distributed.megatron_fsdp.MegatronFSDPManager, nemo_automodel.components.distributed.ddp.DDPManager, None]#

Instantiate the appropriate distributed manager from config.

Parameters:
  • config – Distributed config (FSDP2Config, MegatronFSDPConfig, or DDPConfig).

  • mesh – MeshContext holding device_mesh and moe_mesh references.

Returns:

The instantiated manager, or None if config is None.

Raises:

ValueError – If device_mesh is required but not provided.
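
A minimal sketch of the documented dispatch contract follows. The manager import paths come from the return annotation above; the constructor arguments and the type-name dispatch are illustrative assumptions, not the real implementation. Only the None-config and ValueError behavior is taken from the docs.

```python
from nemo_automodel.components.distributed.ddp import DDPManager
from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
from nemo_automodel.components.distributed.megatron_fsdp import MegatronFSDPManager


def _instantiate_distributed_sketch(config, mesh):
    """Hypothetical re-statement of the documented contract, not the real code."""
    if config is None:
        return None  # documented: no config means no manager
    if mesh.device_mesh is None:  # MeshContext holds device_mesh/moe_mesh references
        raise ValueError("device_mesh is required but was not provided")
    # Dispatch on the concrete config type; constructor args are illustrative.
    dispatch = {
        "FSDP2Config": FSDP2Manager,
        "MegatronFSDPConfig": MegatronFSDPManager,
        "DDPConfig": DDPManager,
    }
    manager_cls = dispatch.get(type(config).__name__)
    if manager_cls is None:
        raise TypeError(f"Unsupported distributed config: {type(config).__name__}")
    return manager_cls(config=config, mesh=mesh)
```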

nemo_automodel._transformers.infrastructure._instantiate_pipeline(
config: Optional[nemo_automodel.components.distributed.pipelining.config.PipelineConfig],
mesh: nemo_automodel.components.distributed.mesh.MeshContext,
device: Optional[torch.device] = None,
) → Optional[nemo_automodel.components.distributed.pipelining.autopipeline.AutoPipeline]#

Instantiate AutoPipeline from config.

Parameters:
  • config – Pipeline config. If None or pp_size <= 1, returns None.

  • mesh – MeshContext holding device_mesh, moe_mesh, and axis names.

  • device – Target device for pipeline computation.

Returns:

AutoPipeline instance, or None if pipeline parallelism is not enabled.
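
Because the return value doubles as the "is PP enabled" signal, callers can branch on None instead of re-inspecting the config. A hedged usage sketch, where pipeline_config and mesh are assumed to be built elsewhere:

```python
import torch

from nemo_automodel._transformers.infrastructure import _instantiate_pipeline

autopipeline = _instantiate_pipeline(
    config=pipeline_config,  # returns None if this is None or pp_size <= 1
    mesh=mesh,               # supplies device_mesh, moe_mesh, and axis names
    device=torch.device("cuda"),
)
if autopipeline is None:
    # Pipeline parallelism disabled; take the EP/FSDP sharding path instead.
    ...
```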

nemo_automodel._transformers.infrastructure._instantiate_qat(
config: Optional[nemo_automodel.components.quantization.qat.QATConfig],
) → Optional[Union[torchao.quantization.qat.linear.Int4WeightOnlyQATQuantizer, torchao.quantization.qat.linear.Int8DynActInt4WeightQATQuantizer]]#
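
There is no docstring here, but the return annotation pins down the contract: the result is one of torchao's QAT quantizers, or None. Those quantizers follow torchao's two-step prepare/convert flow, sketched below; qat_config and model are assumed to exist, and the claim that a None config yields None is an inference from the Optional types.

```python
from nemo_automodel._transformers.infrastructure import _instantiate_qat

quantizer = _instantiate_qat(qat_config)  # assumed: None config -> None quantizer
if quantizer is not None:
    model = quantizer.prepare(model)  # torchao QAT: insert fake-quant for training
    # ... run QAT fine-tuning ...
    model = quantizer.convert(model)  # torchao QAT: swap in real quantized weights
```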
nemo_automodel._transformers.infrastructure.parallelize_for_pp(
model: torch.nn.Module,
*,
model_wrapper: Optional[Union[nemo_automodel.components.distributed.fsdp2.FSDP2Manager, nemo_automodel.components.distributed.megatron_fsdp.MegatronFSDPManager, nemo_automodel.components.distributed.ddp.DDPManager]] = None,
**kwargs,
) → torch.nn.Module#

Parallelize model for pipeline parallelism (non-MoE case).

This function adapts the pipeline parallelism interface to use model_wrapper.parallelize(). For MoE models, use parallelize_model from nemo_automodel.components.moe.parallelizer directly.

Parameters:
  • model – The model to parallelize.

  • model_wrapper – Distributed manager instance.

  • **kwargs – Additional arguments (world_mesh, moe_mesh, axis names) passed by AutoPipeline but unused for non-MoE parallelization.

Returns:

The parallelized model.
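
Given the description, the function is essentially a thin adapter: it accepts AutoPipeline's mesh-related kwargs, drops them, and defers to the manager. A plausible reading of that contract (the exact parallelize() signature is an assumption):

```python
def parallelize_for_pp_sketch(model, *, model_wrapper=None, **kwargs):
    # kwargs (world_mesh, moe_mesh, axis names) are supplied by AutoPipeline
    # but intentionally unused for the non-MoE case.
    return model_wrapper.parallelize(model)
```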

nemo_automodel._transformers.infrastructure.instantiate_infrastructure(
*,
distributed_config: Optional[nemo_automodel.components.distributed.config.DistributedConfig] = None,
pipeline_config: Optional[nemo_automodel.components.distributed.pipelining.config.PipelineConfig] = None,
qat_config: Optional[nemo_automodel.components.quantization.qat.QATConfig] = None,
moe_config: Optional[nemo_automodel.components.moe.config.MoEParallelizerConfig] = None,
activation_checkpointing: bool = False,
device: Optional[torch.device] = None,
mesh: Optional[nemo_automodel.components.distributed.mesh.MeshContext] = None,
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
ep_size: int = 1,
) → tuple#

Instantiate infrastructure objects from config classes.

This function converts config objects into the runtime objects needed by apply_model_infrastructure. It provides a cleaner, more HuggingFace-like API where users pass config objects instead of constructing runtime objects directly.

Parameters:
  • distributed_config – Distributed training config (FSDP2Config, MegatronFSDPConfig, or DDPConfig).

  • pipeline_config – Pipeline parallelism config.

  • qat_config – Quantization-aware training config.

  • moe_config – MoE parallelizer config (for expert parallel models).

  • activation_checkpointing – Enable activation checkpointing for transformer blocks. Defaults to False.

  • device – Target device for model.

  • mesh – MeshContext holding device meshes, sizes, and axis names. If None, built from the legacy device_mesh / moe_mesh params.

  • device_mesh – (deprecated) Device mesh for distributed operations.

  • moe_mesh – (deprecated) Optional MOE mesh for expert parallelism.

  • ep_size – (deprecated) Expert parallelism size. Ignored when mesh is provided.

Returns:

A 4-tuple (model_wrapper, autopipeline, parallelize_fn, qat_quantizer):

  • model_wrapper – Distributed manager instance (or None).

  • autopipeline – AutoPipeline instance (or None).

  • parallelize_fn – Parallelization function (or None); built for EP (MoE-specific parallelizer when ep_size > 1) or PP (via model_wrapper).

  • qat_quantizer – QAT quantizer instance (or None).

Return type:

tuple
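
A usage sketch grounded in the signature above; the distributed_config and mesh objects are assumed to be constructed elsewhere (e.g., an FSDP2Config plus a MeshContext wrapping torch DeviceMeshes):

```python
import torch

from nemo_automodel._transformers.infrastructure import instantiate_infrastructure

model_wrapper, autopipeline, parallelize_fn, qat_quantizer = instantiate_infrastructure(
    distributed_config=distributed_config,
    pipeline_config=None,          # PP disabled, so autopipeline comes back None
    qat_config=None,               # no QAT, so qat_quantizer comes back None
    activation_checkpointing=True,
    device=torch.device("cuda"),
    mesh=mesh,                     # preferred over the deprecated device_mesh/moe_mesh
)
```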

nemo_automodel._transformers.infrastructure.apply_model_infrastructure(
model,
*,
is_meta_device,
device,
model_wrapper=None,
mesh=None,
peft_config=None,
quantization_config=None,
fp8_config=None,
qat_quantizer=None,
loss_fn=None,
autopipeline=None,
parallelize_fn=None,
compile_config=None,
load_base_model=False,
cache_dir=None,
pretrained_model_name_or_path='',
**_kwargs,
)#

Apply sharding, PEFT, quantization, and checkpoint loading to a model.

This function contains the common post-init logic shared between from_pretrained and from_config methods. It can also be called directly for models built via custom builder functions (e.g., build_gpt2_model). It handles:

  • PEFT and lower precision application (LoRA, FP8, QAT)

  • Loss function setup

  • Pipeline parallelism or EP/FSDP sharding

  • Device placement and compilation

  • Checkpoint loading for meta device models

Parameters:
  • model – The model to apply infrastructure to

  • is_meta_device – Whether model was initialized on meta device

  • device – Target device for model

  • model_wrapper – Model wrapper (FSDP2Manager, DDPManager, etc.). Default: None

  • mesh – MeshContext with parallelism sizes (tp_size, cp_size, etc.) and mesh references. Default: None (treated as single-GPU defaults).

  • peft_config – PEFT/LoRA configuration dict. Default: None

  • quantization_config – Quantization configuration. Default: None

  • fp8_config – FP8 configuration. Default: None

  • qat_quantizer – QAT quantizer instance. Default: None


  • loss_fn – Loss function (may be replaced with MaskedCrossEntropy). Default: None

  • autopipeline – AutoPipeline instance for pipeline parallelism. Default: None

  • parallelize_fn – Function to apply parallelization (EP + FSDP2). Default: None

  • compile_config – Compilation configuration. Default: None

  • load_base_model – Whether to load base model weights (True for from_pretrained). Default: False

  • cache_dir – Cache directory for model weights. Default: None

  • pretrained_model_name_or_path – Model name or path for checkpoint loading. Default: ""

  • **_kwargs – Additional keyword arguments (ignored, allows passing extra kwargs)

Returns:

The model with all infrastructure applied
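
Putting the two public entry points together for a model built via a custom builder; the builder name and the distributed_config / mesh objects are placeholders for illustration:

```python
import torch

from nemo_automodel._transformers.infrastructure import (
    apply_model_infrastructure,
    instantiate_infrastructure,
)

model = build_my_model()  # hypothetical builder, analogous to build_gpt2_model

model_wrapper, autopipeline, parallelize_fn, qat_quantizer = instantiate_infrastructure(
    distributed_config=distributed_config,  # assumed to exist
    mesh=mesh,                              # assumed MeshContext
    device=torch.device("cuda"),
)

model = apply_model_infrastructure(
    model,
    is_meta_device=False,        # model was materialized on a real device
    device=torch.device("cuda"),
    model_wrapper=model_wrapper,
    mesh=mesh,
    autopipeline=autopipeline,
    parallelize_fn=parallelize_fn,
    qat_quantizer=qat_quantizer,
    load_base_model=False,       # True only on the from_pretrained path
)
```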