nemo_rl.models.megatron.draft.eagle#

Module Contents#

Classes#

EagleModel

API#

class nemo_rl.models.megatron.draft.eagle.EagleModel(config: megatron.core.transformer.TransformerConfig)#

Bases: megatron.core.transformer.MegatronModule

Initialization

sharded_state_dict(
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int], ...] = (),
metadata: Optional[dict] = None,
) → megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Override to fix a bug in modelopt < 0.42.0.

In modelopt < 0.42.0, EagleTransformerBlock.sharded_state_dict omits tp_group when calling sharded_state_dict_default for non-layer children (e.g. final_layernorm). make_sharded_tensors_for_checkpoint therefore receives tp_group=None while dp_cp_group is set, so the guard that fires only when both tp_group and dp_cp_group are None is skipped, and get_pg_rank(None) = 0 is used as the TP component of replica_id on every TP rank. With TP > 1 and DP > 1, two ranks then share replica_id=(0, 0, 0), which raises a CheckpointingException at checkpoint save time.
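
The body of the override is not reproduced on this page. As a rough sketch of the shape such a fix can take (not the actual NeMo RL implementation), the override can delegate to the parent's sharded_state_dict and then rewrite the TP component of the affected replica_id triples. The helper name _restore_tp_replica_ids and the final_layernorm key filter below are illustrative assumptions:

```python
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict


def _restore_tp_replica_ids(sharded_sd: ShardedStateDict) -> ShardedStateDict:
    """Hypothetical repair pass for the modelopt < 0.42.0 bug.

    make_sharded_tensors_for_checkpoint builds replica_id as the triple
    (0, tp_rank, dp_cp_rank) for tensors replicated across TP; with
    tp_group=None the tp_rank slot is always 0. This sketch writes the
    true TP rank back for the affected non-layer children, assuming a
    flat state dict (full parameter name -> ShardedTensor).
    """
    tp_rank = parallel_state.get_tensor_model_parallel_rank()
    for key, sh_ten in sharded_sd.items():
        # final_layernorm is the non-layer child named above; this key
        # filter is illustrative, not the real condition.
        if 'final_layernorm' in key and isinstance(getattr(sh_ten, 'replica_id', None), tuple):
            prepend, _, dp_cp = sh_ten.replica_id
            sh_ten.replica_id = (prepend, tp_rank, dp_cp)
    return sharded_sd
```

With a helper like this, the override reduces to patching the parent's result:

```python
    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        sharded_sd = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        return _restore_tp_replica_ids(sharded_sd)
```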

forward(
hidden_states: torch.Tensor,
input_embeds: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
bootstrap_hidden_states: bool = True,
) → torch.Tensor#
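
For orientation, a hypothetical call sketch (not from the NeMo RL sources): the shapes assume Megatron's usual [sequence, batch, hidden] layout, the config values are placeholders, and constructing a MegatronModule requires Megatron's model-parallel state to be initialized first.

```python
import torch
from megatron.core.transformer import TransformerConfig

from nemo_rl.models.megatron.draft.eagle import EagleModel

# Placeholder config; real values come from the base model being drafted.
cfg = TransformerConfig(num_layers=1, hidden_size=1024, num_attention_heads=16)
model = EagleModel(cfg)  # assumes Megatron parallel state is already initialized

seq_len, batch = 128, 2
# Assumed layout: [sequence, batch, hidden] (Megatron convention).
hidden_states = torch.randn(seq_len, batch, cfg.hidden_size)  # base-model activations
input_embeds = torch.randn(seq_len, batch, cfg.hidden_size)   # draft-token embeddings

# bootstrap_hidden_states defaults to True per the signature above.
draft_hidden = model(hidden_states, input_embeds, attention_mask=None)
```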