core.transformer.moe.router_replay#
Module Contents#
Classes#
| Class | Description |
|---|---|
| RouterReplayAction | An Enum to define the actions for router replay. |
| RouterReplay | A class to manage the recording and replaying of MoE routing decisions. It holds all router instances and provides static methods to globally control recording and replaying. |
API#
- class core.transformer.moe.router_replay.RouterReplayAction(*args, **kwds)#
Bases: enum.Enum

An Enum to define the actions for router replay.
Initialization
- RECORD#
‘record’
- REPLAY_FORWARD#
‘replay_forward’
- REPLAY_BACKWARD#
‘replay_backward’
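The three actions map to the string values listed above. A minimal sketch, assuming the module is importable under the documented path:

```python
from core.transformer.moe.router_replay import RouterReplayAction

# The three supported actions and their string values.
print(RouterReplayAction.RECORD.value)           # 'record'
print(RouterReplayAction.REPLAY_FORWARD.value)   # 'replay_forward'
print(RouterReplayAction.REPLAY_BACKWARD.value)  # 'replay_backward'
```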
- class core.transformer.moe.router_replay.RouterReplay#
A class to manage the recording and replaying of MoE routing decisions. It holds all router instances and provides static methods to globally control recording and replaying.
Initialization
Initializes a RouterReplay instance for a specific layer.
- global_router_replay_instances: List[core.transformer.moe.router_replay.RouterReplay]#
[]
- static set_replay_data(all_layers_topk_indices: List[torch.Tensor])#
Distributes the topk indices for all layers to their respective RouterReplay instances.
- Parameters:
all_layers_topk_indices – A list of tensors, where each tensor contains the topk indices for a specific layer. The order must match the instantiation order of the routers.
- static get_recorded_data() List[torch.Tensor]#
Collects the recorded topk indices from all RouterReplay instances.
- Returns:
A list of tensors, each containing the recorded topk indices for a layer.
- static clear_global_indices()#
Clears the recorded and target topk indices in all instances.
- static set_global_router_replay_action(router_replay_action: core.transformer.moe.router_replay.RouterReplayAction)#
Sets the router replay action for all router instances.
- static clear_global_router_replay_action()#
Clears the router replay action for all router instances.
- static clear_global_router_replay_instances()#
Clear the global list of router replay instances to prevent memory leaks.
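A sketch of a record-then-replay round trip using the static methods above. It assumes each MoE router has already created and registered its RouterReplay instance (so the global list is populated); the forward passes are placeholders.

```python
from core.transformer.moe.router_replay import RouterReplay, RouterReplayAction

# 1) Record routing decisions during a reference forward pass.
RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD)
# ... run the MoE model forward here (placeholder) ...
recorded = RouterReplay.get_recorded_data()   # one topk-indices tensor per layer
RouterReplay.clear_global_indices()

# 2) Replay the recorded decisions in a later forward pass.
RouterReplay.set_replay_data(recorded)        # order must match router instantiation order
RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
# ... run the model forward again; routers reuse the recorded topk indices ...

# 3) Reset global state when finished.
RouterReplay.clear_global_router_replay_action()
RouterReplay.clear_global_indices()
RouterReplay.clear_global_router_replay_instances()
```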
- set_target_indices(topk_indices: torch.Tensor)#
Sets the target topk indices for replay.
- get_recorded_indices() Optional[torch.Tensor]#
Returns the recorded topk indices.
- record_indices(topk_indices: torch.Tensor)#
Records the topk indices.
- clear_indices()#
Clears the recorded and target topk indices.
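The same bookkeeping is available on a single instance. A short sketch, assuming at least one router has registered its instance in global_router_replay_instances, and using a dummy indices tensor:

```python
import torch
from core.transformer.moe.router_replay import RouterReplay

# Assumes routers have already registered their per-layer instances.
layer_replay = RouterReplay.global_router_replay_instances[0]

dummy_indices = torch.tensor([[0, 3], [1, 2]])   # hypothetical topk indices for one layer
layer_replay.record_indices(dummy_indices)       # store as the recorded indices
recorded = layer_replay.get_recorded_indices()   # the tensor recorded above, or None

layer_replay.set_target_indices(dummy_indices)   # stage indices as the replay target
layer_replay.clear_indices()                     # drop both recorded and target indices
```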
- set_router_replay_action(router_replay_action: core.transformer.moe.router_replay.RouterReplayAction)#
Sets the router replay action for this layer.
- clear_router_replay_action()#
Clears the router replay action for this layer.
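Per-layer control mirrors the global setter; a sketch assuming a registered instance:

```python
from core.transformer.moe.router_replay import RouterReplay, RouterReplayAction

layer_replay = RouterReplay.global_router_replay_instances[0]  # assumes registration happened
layer_replay.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD)
# ... run the computation that should use this action for this layer (placeholder) ...
layer_replay.clear_router_replay_action()
```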
- get_replay_topk(scores: torch.Tensor, topk: int, num_groups: Optional[int] = None, group_topk: Optional[int] = None, default_compute_topk: Callable[[torch.Tensor, int, Optional[int], Optional[int]], Tuple[torch.Tensor, torch.Tensor]] = None)#
A wrapper for top-k computation that handles different replay actions.
- Parameters:
scores (torch.Tensor) – The scores to compute top-k on.
topk (int) – The number of top elements to select.
num_groups (Optional[int]) – Number of expert groups for group-limited routing.
group_topk (Optional[int]) – Number of groups to select for each token.
default_compute_topk (Callable) – The default top-k computation function, which should return a tuple of (values, indices).
- Returns:
A tuple containing the top-k values and indices.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
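Sketch of how a routing step might delegate its top-k selection to get_replay_topk. The score tensor is dummy data, torch.topk stands in for the default computation, and the instance is taken from the global list (an assumption about how routers register themselves):

```python
import torch
from typing import Optional, Tuple
from core.transformer.moe.router_replay import RouterReplay

def default_topk(scores: torch.Tensor, topk: int,
                 num_groups: Optional[int] = None,
                 group_topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    # Plain top-k; group-limited routing is ignored in this sketch.
    return torch.topk(scores, k=topk, dim=-1)

replay = RouterReplay.global_router_replay_instances[0]  # assumes a router registered it

scores = torch.rand(8, 64)  # dummy [num_tokens, num_experts] router scores
values, indices = replay.get_replay_topk(
    scores,
    topk=2,
    default_compute_topk=default_topk,
)
# Presumably: with no action set this falls through to default_topk, under RECORD it also
# stores the resulting indices, and under REPLAY_* it returns the staged target indices.
```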