nemo_rl.models.megatron.draft.eagle#
Module Contents#
Classes#
API#
- class nemo_rl.models.megatron.draft.eagle.EagleModel(config: megatron.core.transformer.TransformerConfig)#
Bases:
megatron.core.transformer.MegatronModule
Initialization#
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: Tuple[Tuple[int, int, int], ...] = (),
- metadata: Optional[dict] = None,
- )#
Override to fix a bug in modelopt < 0.42.0.
In modelopt < 0.42.0, EagleTransformerBlock.sharded_state_dict omits tp_group when calling sharded_state_dict_default for non-layer children (e.g. final_layernorm). As a result, make_sharded_tensors_for_checkpoint receives tp_group=None while dp_cp_group is set, so the `tp_group is None and dp_cp_group is None` guard never fires and get_pg_rank(None) == 0 is used for every TP rank. With TP > 1 and DP > 1, two ranks end up with replica_id=(0, 0, 0), triggering a CheckpointingException.
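The failure mode above can be reproduced in isolation. The sketch below mimics Megatron's replica-id construction with plain dicts standing in for process groups; get_pg_rank and the (pp, tp, dp) replica-id layout mirror the convention described above, but this is an illustration, not the real checkpointing API.

```python
def get_pg_rank(group):
    # Mirrors the Megatron helper's behavior: a None process group maps to rank 0.
    return 0 if group is None else group["rank"]

def replica_id(tp_group, dp_cp_group):
    # make_sharded_tensors_for_checkpoint builds a (pp, tp, dp)-style replica id.
    # The guard `tp_group is None and dp_cp_group is None` only fires when BOTH
    # groups are missing, so a dropped tp_group with a present dp_cp_group
    # silently yields tp index 0 on every TP rank.
    return (0, get_pg_rank(tp_group), get_pg_rank(dp_cp_group))

# Buggy path (modelopt < 0.42.0): tp_group is dropped for non-layer children,
# so two TP ranks in the same data-parallel replica collide on (0, 0, 0),
# which is what triggers the CheckpointingException.
buggy_tp0 = replica_id(None, {"rank": 0})
buggy_tp1 = replica_id(None, {"rank": 0})
print(buggy_tp0, buggy_tp1)  # (0, 0, 0) (0, 0, 0) -- duplicate replica ids

# Fixed path: with tp_group forwarded, the TP ranks get distinct replica ids.
fixed_tp0 = replica_id({"rank": 0}, {"rank": 0})
fixed_tp1 = replica_id({"rank": 1}, {"rank": 0})
print(fixed_tp0, fixed_tp1)  # (0, 0, 0) (0, 1, 0)
```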
- forward(
- hidden_states: torch.Tensor,
- input_embeds: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- bootstrap_hidden_states: bool = True,
- )#
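The two main inputs reflect the EAGLE draft-model recipe: the draft block consumes the target model's hidden states together with token embeddings, fused along the feature dimension before being projected back down inside the draft transformer. The sketch below (plain Python, no torch) illustrates that input convention; the per-position concatenation is the published EAGLE layout, while the exact fusion this class performs internally is an assumption.

```python
def eagle_draft_inputs(hidden_states, input_embeds):
    """Fuse target-model hidden states with token embeddings, EAGLE-style.

    hidden_states: list of length-H feature vectors from the target model
                   (the `hidden_states` argument of forward above)
    input_embeds:  list of length-H token-embedding vectors, same sequence
                   length (the `input_embeds` argument)
    returns:       list of length-2H vectors, one per position, as would be
                   fed into the draft transformer block
    """
    assert len(hidden_states) == len(input_embeds), "sequence lengths must match"
    # Concatenate along the feature dimension, position by position.
    return [h + e for h, e in zip(hidden_states, input_embeds)]

fused = eagle_draft_inputs(
    [[0.1, 0.2], [0.3, 0.4]],   # hidden states for 2 positions, H=2
    [[1.0, 1.0], [2.0, 2.0]],   # embeddings for the same 2 positions
)
print(fused)  # [[0.1, 0.2, 1.0, 1.0], [0.3, 0.4, 2.0, 2.0]]
```

The `bootstrap_hidden_states` flag suggests the hidden states can be seeded from the target model on the first draft step; how later steps reuse the draft model's own hidden states is not specified here.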