bridge.models.megatron_mimo.build_model

Shared model-construction entry point for MegatronMIMO.

Used by both setup_megatron_mimo (training) and the conversion CLI. It composes provider.finalize, build_infra, per-module RNG initialisation, the distributed-model build, and the parallel_state global bridge into a single call.
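The dependencies between these steps can be sketched as a small pipeline in which each step checks for the output of the one before it. This is a hypothetical model, not the module's code: BuildContext, run_pipeline and the five step functions are illustrative names, the steps only record placeholders, and the precondition checks encode an assumed ordering (infra needs a finalized provider, RNG seeding needs the rank grid from the infra, and model construction needs seeded RNG).

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class BuildContext:
    """Hypothetical container threaded through the build steps."""
    seed: int
    state: Dict[str, Any] = field(default_factory=dict)
    trace: List[str] = field(default_factory=list)


def _require(ctx: BuildContext, key: str, why: str) -> None:
    # Fail loudly if an earlier step has not produced what this step needs.
    if key not in ctx.state:
        raise RuntimeError(f"missing {key!r}: {why}")


def finalize_provider(ctx: BuildContext) -> None:
    ctx.state["finalized"] = True
    ctx.trace.append("finalize")


def build_infra(ctx: BuildContext) -> None:
    _require(ctx, "finalized", "provider must be finalized first")
    # Placeholder rank grid and process-group collection for this rank.
    ctx.state["infra"] = {"tp_rank": 0, "pp_rank": 0, "pg_collection": object()}
    ctx.trace.append("build_infra")


def init_rng(ctx: BuildContext) -> None:
    _require(ctx, "infra", "seeding uses TP/PP ranks from the infra")
    infra = ctx.state["infra"]
    ctx.state["rng_ranks"] = (infra["tp_rank"], infra["pp_rank"])
    ctx.trace.append("init_rng")


def build_model(ctx: BuildContext) -> None:
    _require(ctx, "rng_ranks", "weights should be initialised from seeded RNG")
    ctx.state["model"] = "placeholder-model"
    ctx.trace.append("build_model")


def bridge_globals(ctx: BuildContext) -> None:
    _require(ctx, "infra", "globals come from the rank-local pg_collection")
    ctx.state["globals_source"] = ctx.state["infra"]["pg_collection"]
    ctx.trace.append("bridge_globals")


PIPELINE: List[Callable[[BuildContext], None]] = [
    finalize_provider, build_infra, init_rng, build_model, bridge_globals,
]


def run_pipeline(seed: int) -> BuildContext:
    """Run every step in order on a fresh context."""
    ctx = BuildContext(seed=seed)
    for step in PIPELINE:
        step(ctx)
    return ctx
```

Calling run_pipeline(0) records the steps in the order listed above, while calling build_model on a fresh BuildContext raises RuntimeError because no RNG state has been set up yet.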

Module Contents

Functions

build_megatron_mimo_model

Build a distributed MegatronMIMO model and return (model, infra).

_set_per_module_random_seeds

Initialise per-module RNG using TP/PP ranks from the rank’s grid.

_bridge_parallel_state_globals

Set parallel_state globals from the rank-local ProcessGroupCollection.

Data

API

bridge.models.megatron_mimo.build_model.logger

‘getLogger(…)’

bridge.models.megatron_mimo.build_model.build_megatron_mimo_model(
provider: megatron.bridge.models.megatron_mimo.megatron_mimo_provider.MegatronMIMOProvider,
*,
ddp_config: Optional[megatron.core.distributed.DistributedDataParallelConfig] = None,
fp16: bool = False,
bf16: bool = True,
seed: int = 0,
wrap_with_ddp: bool = True,
data_parallel_random_init: bool = True,
) -> tuple[megatron.core.models.mimo.MimoModel, megatron.bridge.models.megatron_mimo.megatron_mimo_provider.MegatronMIMOInfra]

Build a distributed MegatronMIMO model and return (model, infra).

Side effects: initialises the Python, NumPy, torch, and Megatron-Core RNGs, and sets the parallel_state._*_GROUP globals from the rank-local pg_collection.
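Every option after provider is keyword-only (note the bare * in the signature), so callers must pass options by name. The stub below is hypothetical: build_megatron_mimo_model_stub copies the documented parameter list and defaults but echoes the resolved options instead of building a (model, infra) tuple, and the provider argument and the chosen option values are arbitrary placeholders.

```python
from typing import Any, Dict, Optional


def build_megatron_mimo_model_stub(
    provider: Any,
    *,
    ddp_config: Optional[Any] = None,
    fp16: bool = False,
    bf16: bool = True,
    seed: int = 0,
    wrap_with_ddp: bool = True,
    data_parallel_random_init: bool = True,
) -> Dict[str, Any]:
    """Hypothetical stand-in mirroring the documented signature; echoes its options."""
    return {
        "provider": provider,
        "ddp_config": ddp_config,
        "fp16": fp16,
        "bf16": bf16,
        "seed": seed,
        "wrap_with_ddp": wrap_with_ddp,
        "data_parallel_random_init": data_parallel_random_init,
    }


# Options must be named; only `provider` may be passed positionally.
opts = build_megatron_mimo_model_stub("provider-placeholder", seed=1234, wrap_with_ddp=False)

# Passing ddp_config positionally is rejected by the keyword-only marker.
try:
    build_megatron_mimo_model_stub("provider-placeholder", None)
except TypeError:
    positional_rejected = True
else:
    positional_rejected = False
```

Keyword-only flags are a common way to keep call sites readable when several boolean options sit next to each other, as fp16, bf16, wrap_with_ddp and data_parallel_random_init do here.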

bridge.models.megatron_mimo.build_model._set_per_module_random_seeds(
infra: megatron.bridge.models.megatron_mimo.megatron_mimo_provider.MegatronMIMOInfra,
*,
seed: int,
) -> None

Initialise per-module RNG using TP/PP ranks from the rank’s grid.

Mirrors _set_megatron_mimo_random_seeds in setup_megatron_mimo, but takes a raw integer seed instead of reading cfg.rng.seed, so that callers outside the training loop (such as the conversion CLI) can use it.
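One way to derive distinct per-rank seeds from a single raw seed and a module's TP/PP coordinates is sketched below. It assumes the offset convention used by Megatron-LM's _set_random_seed and model_parallel_cuda_manual_seed helpers (100 per pipeline rank for the base seed, and 2718 plus the tensor-parallel rank for the tensor-parallel CUDA stream); _set_per_module_random_seeds may use different offsets. ModuleRanks, derive_seeds and apply_seeds are hypothetical names, and only Python's random module is actually seeded.

```python
import random
from dataclasses import dataclass
from typing import Dict

# Assumed offsets, borrowed from Megatron-LM's seeding helpers; the values used by
# _set_per_module_random_seeds are not documented here and may differ.
PP_STRIDE = 100
TP_CUDA_OFFSET = 2718


@dataclass(frozen=True)
class ModuleRanks:
    """Hypothetical view of this rank's coordinates in one module's grid."""
    tp_rank: int
    pp_rank: int


def derive_seeds(seed: int, ranks: ModuleRanks) -> Dict[str, int]:
    """Derive per-rank seeds for one module from a raw integer seed."""
    base = seed + PP_STRIDE * ranks.pp_rank  # differs across pipeline stages
    return {
        "cpu": base,  # Python / NumPy / torch CPU generators
        "tp_cuda": base + TP_CUDA_OFFSET + ranks.tp_rank,  # differs across TP ranks
        "dp_cuda": base,  # shared by all TP ranks of a stage
    }


def apply_seeds(seed: int, ranks: ModuleRanks) -> Dict[str, int]:
    """Seed Python's RNG for this rank and return the derived seeds."""
    seeds = derive_seeds(seed, ranks)
    # Only `random` is seeded in this sketch; a full implementation would also
    # seed NumPy, torch and the Megatron-Core CUDA RNG tracker.
    random.seed(seeds["cpu"])
    return seeds
```

Because the helper takes a plain int, a training-side caller can pass cfg.rng.seed while the conversion CLI passes its own value. Under these assumed offsets, derive_seeds(1234, ModuleRanks(tp_rank=1, pp_rank=2)) gives a CPU seed of 1434 and a tensor-parallel CUDA seed of 4153.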

bridge.models.megatron_mimo.build_model._bridge_parallel_state_globals(local_pg_collection) -> None

Set parallel_state globals from the rank-local ProcessGroupCollection.
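Bridging of this kind amounts to copying process-group handles from the collection onto module-level globals, so that code which still reads parallel_state sees the same groups as the rank-local collection. The sketch below uses a SimpleNamespace in place of megatron.core.parallel_state and a plain dataclass (FakePGCollection, hypothetical) in place of ProcessGroupCollection; the attribute names tp, pp, dp and cp, the global names in GLOBAL_FOR_ATTR, and the choice to skip absent groups are assumptions, and strings stand in for torch.distributed process groups.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Optional

# Stand-in for megatron.core.parallel_state: group globals start unset.
parallel_state = SimpleNamespace(
    _TENSOR_MODEL_PARALLEL_GROUP=None,
    _PIPELINE_MODEL_PARALLEL_GROUP=None,
    _DATA_PARALLEL_GROUP=None,
    _CONTEXT_PARALLEL_GROUP=None,
)


@dataclass
class FakePGCollection:
    """Hypothetical stand-in for a rank-local ProcessGroupCollection."""
    tp: Optional[Any] = None
    pp: Optional[Any] = None
    dp: Optional[Any] = None
    cp: Optional[Any] = None


# Assumed mapping from collection attribute to parallel_state global name.
GLOBAL_FOR_ATTR = {
    "tp": "_TENSOR_MODEL_PARALLEL_GROUP",
    "pp": "_PIPELINE_MODEL_PARALLEL_GROUP",
    "dp": "_DATA_PARALLEL_GROUP",
    "cp": "_CONTEXT_PARALLEL_GROUP",
}


def bridge_globals(ps: SimpleNamespace, pgs: FakePGCollection) -> None:
    """Copy each group present in the collection onto the matching global."""
    for attr, global_name in GLOBAL_FOR_ATTR.items():
        group = getattr(pgs, attr)
        if group is not None:  # leave globals untouched for absent groups
            setattr(ps, global_name, group)
```

After bridging a collection that has tp, pp and dp groups but no cp group, this sketch leaves the context-parallel global as None.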