bridge.training.eval_context_parallel_rebinding

Runtime CP-group rebinding for eval-time context parallelism.

This module provides helpers that switch a model’s cached process-group references from the training CP layout to a different eval CP layout, run evaluation, and then restore the training layout. It works only with the decentralized process-group path (use_decentralized_pg=True) and requires no changes to Megatron-Core.

Typical usage::

with eval_cp_context(model, eval_pgs, train_pgs):
    evaluate_and_print_results(..., pg_collection=eval_pgs)

Module Contents

Functions

  • install_pg_collection – Rebind all CP-affected process groups on every module of model.

  • eval_cp_context – Context manager: install eval_pgs for the duration of the block.

  • _iter_all_modules – Yield every nn.Module across all virtual-PP chunks.

Data

API

bridge.training.eval_context_parallel_rebinding._GROUP_ATTRS: dict[str, str]

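The page does not list what _GROUP_ATTRS contains. The sketch below assumes it maps a module attribute name (such as cp_group) to the name of the matching field on a ProcessGroupCollection. Both entries and the field names "cp" and "tp_cp" are hypothetical, and a SimpleNamespace with strings stands in for a real collection of process groups.

```python
from types import SimpleNamespace

# Hypothetical mapping: module attribute name -> ProcessGroupCollection field.
# These entries are assumptions for illustration, not the real _GROUP_ATTRS.
GROUP_ATTRS_EXAMPLE: dict[str, str] = {
    "cp_group": "cp",
    "tp_cp_group": "tp_cp",
}


def resolve_groups(pgs, attrs: dict[str, str]) -> dict[str, object]:
    """Look up each mapped field on a collection, skipping fields it lacks."""
    resolved = {}
    for module_attr, pg_field in attrs.items():
        group = getattr(pgs, pg_field, None)
        if group is not None:
            resolved[module_attr] = group
    return resolved


# Stand-in collection: strings replace real torch.distributed process groups.
eval_pgs = SimpleNamespace(cp="eval-cp", tp_cp="eval-tp-cp")
print(resolve_groups(eval_pgs, GROUP_ATTRS_EXAMPLE))
# {'cp_group': 'eval-cp', 'tp_cp_group': 'eval-tp-cp'}
```

Keeping the mapping as data means a new CP-bearing attribute only needs one extra dictionary entry rather than new rebinding code.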

bridge.training.eval_context_parallel_rebinding.install_pg_collection(
model: Union[list, torch.nn.Module],
target: megatron.core.process_groups_config.ProcessGroupCollection,
) → None

Rebind all CP-affected process groups on every module of model.

Walks every sub-module across all virtual-PP chunks and rebinds:

  • pg_collection (used by TransformerLayer, DotProductAttention, …)

  • Named CP-bearing group attributes (cp_group, tp_cp_group, …)

  • TEDotProductAttention’s internal CP communication state, via set_context_parallel_group

TP, PP, and EP groups are never touched, because they do not change between training and evaluation.

Parameters:
  • model – Single model chunk or a list of virtual-PP chunks.

  • target – The ProcessGroupCollection to install.

bridge.training.eval_context_parallel_rebinding.eval_cp_context(
model: Union[list, torch.nn.Module],
eval_pgs: megatron.core.process_groups_config.ProcessGroupCollection,
train_pgs: megatron.core.process_groups_config.ProcessGroupCollection,
) → Iterator[None]

Context manager: install eval_pgs for the duration of the block.

On entry, rebinds all CP-affected module attributes to eval_pgs. On exit (including exceptions), restores train_pgs.

Parameters:
  • model – Single model chunk or a list of virtual-PP chunks.

  • eval_pgs – ProcessGroupCollection for evaluation (with a different CP degree than training).

  • train_pgs – ProcessGroupCollection for training (restored on exit).

Example::

with eval_cp_context(model, eval_pgs, train_pgs):
    evaluate_and_print_results(..., pg_collection=eval_pgs)
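The restore-on-exit behaviour of eval_cp_context follows the usual contextlib pattern. The sketch below is not the module's implementation: pg_swap and record_install are hypothetical names, and install is passed in as a callable so the example runs without Megatron-Core.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator


@contextmanager
def pg_swap(model: Any, eval_pgs: Any, train_pgs: Any,
            install: Callable[[Any, Any], None]) -> Iterator[None]:
    try:
        # Entry: bind eval groups. Kept inside try so that a failure
        # partway through still triggers the restore below.
        install(model, eval_pgs)
        yield
    finally:
        # Exit, normal or via exception: restore the training groups.
        install(model, train_pgs)


installed = []


def record_install(model: Any, pgs: Any) -> None:
    # Hypothetical install function that only records what was bound.
    installed.append(pgs)


try:
    with pg_swap(None, "eval", "train", record_install):
        raise RuntimeError("evaluation failed")
except RuntimeError:
    pass

print(installed)  # ['eval', 'train']
```

The exception still propagates to the caller; the finally clause only guarantees that training resumes with the training CP groups.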

bridge.training.eval_context_parallel_rebinding._iter_all_modules(
model: Union[list, torch.nn.Module],
) → Iterator[torch.nn.Module]

Yield every nn.Module across all virtual-PP chunks.
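A minimal sketch of this traversal, assuming it normalizes a single chunk to a one-element list and then chains each chunk’s modules() iterator. FakeModule is a hypothetical stand-in for torch.nn.Module so the example runs without PyTorch; unlike the real nn.Module.modules(), it does not de-duplicate shared sub-modules.

```python
from typing import Iterator, List, Union


class FakeModule:
    """Stand-in for torch.nn.Module exposing only modules()."""

    def __init__(self, name: str, *children: "FakeModule") -> None:
        self.name = name
        self.children_ = list(children)

    def modules(self) -> Iterator["FakeModule"]:
        # Pre-order: the module itself, then every descendant.
        yield self
        for child in self.children_:
            yield from child.modules()


def iter_all_modules(model: Union[List[FakeModule], FakeModule]) -> Iterator[FakeModule]:
    # Treat a single chunk as a one-element list of virtual-PP chunks.
    chunks = model if isinstance(model, list) else [model]
    for chunk in chunks:
        yield from chunk.modules()


chunk0 = FakeModule("chunk0", FakeModule("layer0"))
chunk1 = FakeModule("chunk1", FakeModule("layer1", FakeModule("attn1")))
print([m.name for m in iter_all_modules([chunk0, chunk1])])
# ['chunk0', 'layer0', 'chunk1', 'layer1', 'attn1']
```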