core.transformer.moe.router_replay#

Module Contents#

Classes#

RouterReplayAction

An Enum defining the actions for router replay.

RouterReplay

A class to manage the recording and replaying of MoE routing decisions. It holds all router instances and provides static methods to globally control recording and replaying.

API#

class core.transformer.moe.router_replay.RouterReplayAction(*args, **kwds)#

Bases: enum.Enum

An Enum defining the actions for router replay.

Initialization

RECORD#

‘record’

REPLAY_FORWARD#

‘replay_forward’

REPLAY_BACKWARD#

‘replay_backward’
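As an illustration only, the three documented members can be mirrored with a plain `enum.Enum`; this stand-in reproduces just the names and string values listed above, not the real class:

```python
import enum

# Simplified stand-in mirroring the documented members of
# core.transformer.moe.router_replay.RouterReplayAction.
class RouterReplayAction(enum.Enum):
    RECORD = "record"
    REPLAY_FORWARD = "replay_forward"
    REPLAY_BACKWARD = "replay_backward"

# A router is switched between modes by passing one of these members.
action = RouterReplayAction.REPLAY_FORWARD
print(action.value)  # replay_forward
```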

class core.transformer.moe.router_replay.RouterReplay#

A class to manage the recording and replaying of MoE routing decisions. It holds all router instances and provides static methods to globally control recording and replaying.

Initialization

Initializes a RouterReplay instance for a specific layer.

global_router_replay_instances: List[core.transformer.moe.router_replay.RouterReplay]#

[]

static set_replay_data(all_layers_topk_indices: List[torch.Tensor])#

Distributes the topk indices for all layers to their respective RouterReplay instances.

Parameters:

all_layers_topk_indices – A list of tensors, where each tensor contains the topk indices for a specific layer. The order must match the instantiation order of the routers.

static get_recorded_data() List[torch.Tensor]#

Collects the recorded topk indices from all RouterReplay instances.

Returns:

A list of tensors, each containing the recorded topk indices for a layer.

static clear_global_indices()#

Clears the recorded and target topk indices in all instances.

static set_global_router_replay_action(
router_replay_action: core.transformer.moe.router_replay.RouterReplayAction,
)#

Sets the router replay action for all router instances.

static clear_global_router_replay_action()#

Clears the router replay action for all router instances.

static clear_global_router_replay_instances()#

Clears the global list of router replay instances to prevent memory leaks.
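Taken together, the static methods above suggest a record-then-replay lifecycle over a global registry of per-layer instances. The sketch below is a minimal, torch-free stand-in (plain lists in place of `torch.Tensor`, and only the documented bookkeeping) meant to illustrate how that registry is driven; it is an assumption about the flow, not the actual implementation:

```python
from typing import List, Optional

class RouterReplay:
    """Minimal stand-in: one instance per MoE layer, registered globally."""
    global_router_replay_instances: List["RouterReplay"] = []

    def __init__(self):
        # Plain lists stand in for torch.Tensor here.
        self.recorded_indices: Optional[list] = None
        self.target_indices: Optional[list] = None
        RouterReplay.global_router_replay_instances.append(self)

    @staticmethod
    def set_replay_data(all_layers_topk_indices):
        # The order must match the instantiation order of the routers.
        for inst, idx in zip(RouterReplay.global_router_replay_instances,
                             all_layers_topk_indices):
            inst.target_indices = idx

    @staticmethod
    def get_recorded_data():
        return [inst.recorded_indices
                for inst in RouterReplay.global_router_replay_instances]

    @staticmethod
    def clear_global_indices():
        for inst in RouterReplay.global_router_replay_instances:
            inst.recorded_indices = None
            inst.target_indices = None

    @staticmethod
    def clear_global_router_replay_instances():
        RouterReplay.global_router_replay_instances.clear()

# Two "layers" record their routing decisions ...
layer0, layer1 = RouterReplay(), RouterReplay()
layer0.recorded_indices = [[0, 2]]
layer1.recorded_indices = [[1, 3]]
recorded = RouterReplay.get_recorded_data()
# ... and the same decisions are fed back before a replayed run.
RouterReplay.set_replay_data(recorded)
print(layer1.target_indices)  # [[1, 3]]
```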

set_target_indices(topk_indices: torch.Tensor)#

Sets the target topk indices for replay.

get_recorded_indices() Optional[torch.Tensor]#

Returns the recorded topk indices.

record_indices(topk_indices: torch.Tensor)#

Records the topk indices.

clear_indices()#

Clears the recorded and target topk indices.

set_router_replay_action(
router_replay_action: core.transformer.moe.router_replay.RouterReplayAction,
)#

Sets the router replay action for this layer.

clear_router_replay_action()#

Clears the router replay action for this layer.

get_replay_topk(
scores: torch.Tensor,
topk: int,
num_groups: Optional[int] = None,
group_topk: Optional[int] = None,
default_compute_topk: Callable[[torch.Tensor, int, Optional[int], Optional[int]], Tuple[torch.Tensor, torch.Tensor]] = None,
) Tuple[torch.Tensor, torch.Tensor]#

A wrapper for top-k computation that handles different replay actions.

Parameters:
  • scores (torch.Tensor) – The scores to compute top-k on.

  • topk (int) – The number of top elements to select.

  • num_groups (Optional[int]) – Number of expert groups for group-limited routing.

  • group_topk (Optional[int]) – Number of groups to select for each token.

  • default_compute_topk (Callable) – The default top-k computation function, which should return a tuple of (values, indices).

Returns:

A tuple containing the top-k values and indices.

Return type:

Tuple[torch.Tensor, torch.Tensor]
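The dispatch described above can be pictured roughly as: compute and store the indices under RECORD, and substitute the previously stored target indices under REPLAY_FORWARD. The sketch below is an assumption based solely on the documented actions, not the actual source; plain lists stand in for tensors, grouping parameters are ignored, and `default_compute_topk` is a naive argsort:

```python
from enum import Enum
from typing import List, Optional, Tuple

class RouterReplayAction(Enum):
    RECORD = "record"
    REPLAY_FORWARD = "replay_forward"
    REPLAY_BACKWARD = "replay_backward"

def default_compute_topk(scores: List[float], topk: int,
                         num_groups: Optional[int] = None,
                         group_topk: Optional[int] = None
                         ) -> Tuple[List[float], List[int]]:
    # Naive stand-in: the `topk` highest-scoring experts, ignoring grouping.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    indices = order[:topk]
    return [scores[i] for i in indices], indices

class LayerReplay:
    """Per-layer stand-in holding the action and recorded/target indices."""
    def __init__(self):
        self.action: Optional[RouterReplayAction] = None
        self.recorded_indices = None
        self.target_indices = None

    def get_replay_topk(self, scores, topk, compute=default_compute_topk):
        if self.action == RouterReplayAction.REPLAY_FORWARD:
            # Reuse the stored routing decision instead of recomputing it.
            indices = self.target_indices
            return [scores[i] for i in indices], indices
        values, indices = compute(scores, topk)
        if self.action == RouterReplayAction.RECORD:
            self.recorded_indices = indices
        return values, indices

layer = LayerReplay()
layer.action = RouterReplayAction.RECORD
values, indices = layer.get_replay_topk([0.1, 0.7, 0.2, 0.5], topk=2)
print(indices)  # [1, 3]

# Replay: the stored indices win even though the new scores differ.
layer.action = RouterReplayAction.REPLAY_FORWARD
layer.target_indices = layer.recorded_indices
print(layer.get_replay_topk([0.9, 0.0, 0.0, 0.0], topk=2)[1])  # [1, 3]
```

The key property is that under REPLAY_FORWARD the routing decision is independent of the current scores, which is what makes a forward pass bit-reproducible with respect to expert assignment.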