bridge.training.eval_context_parallel_rebinding
Runtime CP-group rebinding for eval-time context parallelism.
This module provides helpers to switch a model's cached process-group references from the training CP layout to a different eval CP layout, run evaluation, then restore the training layout. It works only with the decentralized PG path (use_decentralized_pg=True) and requires no changes to Megatron-Core.
Typical usage::
    with eval_cp_context(model, eval_pgs, train_pgs):
        evaluate_and_print_results(..., pg_collection=eval_pgs)
Module Contents
Functions
install_pg_collection | Rebind all CP-affected process groups on every module of model.
eval_cp_context | Context manager: install eval_pgs for the duration of the block.
_iter_all_modules | Yield every nn.Module across all virtual-PP chunks.
Data
API
- bridge.training.eval_context_parallel_rebinding._GROUP_ATTRS: dict[str, str]
None
- bridge.training.eval_context_parallel_rebinding.install_pg_collection(model: Union[list, torch.nn.Module], target: megatron.core.process_groups_config.ProcessGroupCollection)
Rebind all CP-affected process groups on every module of model.
Walks every sub-module across all virtual-PP chunks and rebinds:
- pg_collection (used by TransformerLayer, DotProductAttention, …)
- Named CP-bearing group attributes (cp_group, tp_cp_group, …)
- TEDotProductAttention internal CP comm state, via set_context_parallel_group
TP/PP/EP groups are never touched because those do not change between train and eval.
- Parameters:
model – Single model chunk or a list of virtual-PP chunks.
target – The ProcessGroupCollection to install.
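The rebinding pass described above can be pictured with a minimal sketch. It is an illustration only, not the module's implementation: `FakeModule` is a plain-Python stand-in for `torch.nn.Module`, `GROUP_ATTRS_SKETCH` is a hypothetical mapping (the real contents of `_GROUP_ATTRS` are not shown on this page), and the collection field names `cp` and `tp_cp` are assumptions. The TEDotProductAttention step is left out because the arguments passed to `set_context_parallel_group` are not documented here.

```python
from types import SimpleNamespace

# Hypothetical mapping: module attribute name -> field on the collection.
# The real _GROUP_ATTRS may hold different or additional entries.
GROUP_ATTRS_SKETCH = {"cp_group": "cp", "tp_cp_group": "tp_cp"}


class FakeModule:
    """Stand-in for torch.nn.Module that only provides modules()."""

    def __init__(self, *children, **attrs):
        self._children = list(children)
        self.__dict__.update(attrs)

    def modules(self):
        # Like nn.Module.modules(): yield self, then every descendant.
        yield self
        for child in self._children:
            yield from child.modules()


def install_sketch(model, target):
    """Rebind CP-related references on every sub-module of every chunk."""
    chunks = model if isinstance(model, list) else [model]
    for chunk in chunks:
        for module in chunk.modules():
            # Modules that cache the whole collection get the new one.
            if hasattr(module, "pg_collection"):
                module.pg_collection = target
            # Only attributes named in the mapping are rebound, so
            # TP/PP/EP references are left untouched.
            for attr, field in GROUP_ATTRS_SKETCH.items():
                if hasattr(module, attr):
                    setattr(module, attr, getattr(target, field))


# Strings stand in for real torch.distributed process groups.
train = SimpleNamespace(cp="train_cp", tp_cp="train_tp_cp")
eval_pgs = SimpleNamespace(cp="eval_cp", tp_cp="eval_tp_cp")
attn = FakeModule(pg_collection=train, cp_group="train_cp")
layer = FakeModule(attn, pg_collection=train, tp_group="tp")
install_sketch([layer], eval_pgs)
```

After the call, `attn.cp_group` and both `pg_collection` references point at the eval layout, while `layer.tp_group` keeps its original value.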
- bridge.training.eval_context_parallel_rebinding.eval_cp_context(model: Union[list, torch.nn.Module], eval_pgs: megatron.core.process_groups_config.ProcessGroupCollection, train_pgs: megatron.core.process_groups_config.ProcessGroupCollection)
Context manager: install eval_pgs for the duration of the block.
On entry, rebinds all CP-affected module attributes to eval_pgs. On exit (including exceptions), restores train_pgs.
- Parameters:
model – Single model chunk or list of virtual-PP chunks.
eval_pgs – ProcessGroupCollection for eval (different CP degree).
train_pgs – ProcessGroupCollection for training (restored on exit).
Example::
    with eval_cp_context(model, eval_pgs, train_pgs):
        evaluate_and_print_results(..., pg_collection=eval_pgs)
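The restore-on-exit behaviour is the usual `try`/`finally` pattern for context managers. The sketch below shows that pattern under stated assumptions: `rebinding_context_sketch` and `record_install` are hypothetical names, and the install function is passed in as a parameter so the example runs without Megatron or PyTorch. It is not the module's actual code.

```python
from contextlib import contextmanager


@contextmanager
def rebinding_context_sketch(install, model, eval_pgs, train_pgs):
    """Install eval_pgs on entry and reinstall train_pgs on exit."""
    install(model, eval_pgs)
    try:
        yield
    finally:
        # Runs on normal exit and also when the block raises.
        install(model, train_pgs)


# Hypothetical recorder standing in for the real rebinding function.
calls = []


def record_install(model, pgs):
    calls.append(pgs)


try:
    with rebinding_context_sketch(record_install, "model", "eval", "train"):
        raise RuntimeError("evaluation failed")
except RuntimeError:
    pass  # the exception still propagates; the layout was restored first
```

Here `calls` ends up recording the eval layout followed by the training layout, even though the block raised.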
- bridge.training.eval_context_parallel_rebinding._iter_all_modules(model: Union[list, torch.nn.Module])
Yield every nn.Module across all virtual-PP chunks.
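One plausible shape for such a helper is to normalise the input to a list of chunks and delegate to each chunk's `modules()` iterator, which on a real `torch.nn.Module` yields the module itself followed by its descendants. The sketch below assumes that shape; `Node` is a hypothetical stand-in so the example runs without PyTorch.

```python
class Node:
    """Stand-in for torch.nn.Module with a modules() iterator."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

    def modules(self):
        # Yield self first, then recurse into children.
        yield self
        for child in self.children:
            yield from child.modules()


def iter_all_modules_sketch(model):
    """Yield every module from a single chunk or a list of chunks."""
    chunks = model if isinstance(model, list) else [model]
    for chunk in chunks:
        yield from chunk.modules()


chunk0 = Node("chunk0", [Node("attn0")])
chunk1 = Node("chunk1", [Node("attn1")])
names = [m.name for m in iter_all_modules_sketch([chunk0, chunk1])]
```

Accepting either a single chunk or a list keeps callers free of special cases for runs with and without virtual pipeline parallelism.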