nemo_rl.models.megatron.draft.utils#

Module Contents#

Classes#

Functions#

Data#

API#

nemo_rl.models.megatron.draft.utils.StateDict#

None

nemo_rl.models.megatron.draft.utils.CheckpointLoader#

None

nemo_rl.models.megatron.draft.utils._CHECKPOINT_CANDIDATE_NAMES#

('model.safetensors', 'model.safetensors.index.json', 'pytorch_model.bin', 'pytorch_model.bin.index….

nemo_rl.models.megatron.draft.utils._HF_SNAPSHOT_ALLOW_PATTERNS#

['model.safetensors', 'model-*.safetensors', 'model.safetensors.index.json', 'pytorch_model.bin', 'p…

nemo_rl.models.megatron.draft.utils._HF_SNAPSHOT_IGNORE_PATTERNS#

['*.pt', '*.pth', '*.ckpt']

nemo_rl.models.megatron.draft.utils._MODEL_LAYER_QKV_KEY_PATTERN#

'compile(…)'

nemo_rl.models.megatron.draft.utils._CHECKPOINT_LAYER_KEY_PATTERN#

'compile(…)'

class nemo_rl.models.megatron.draft.utils._EagleLayerLayout#
layer_index: int#

None

model_prefix: str#

None

checkpoint_prefix: str#

None

hidden_norm_key: str | None#

None

input_layernorm_key: str | None#

None

post_attention_layernorm_key: str | None#

None

property qkv_weight_key: str#
property proj_weight_key: str#
property fc1_weight_key: str#
property fc2_weight_key: str#
nemo_rl.models.megatron.draft.utils._resolve_optional_key(
model_keys: set[str],
*candidates: str | None,
) → str | None#
class nemo_rl.models.megatron.draft.utils._EagleModelLayout#
layers: tuple[nemo_rl.models.megatron.draft.utils._EagleLayerLayout, ...]#

None

final_norm_key: str | None#

None

lm_head_key: str | None#

None

classmethod detect(
model_state: Mapping[str, torch.Tensor],
) → nemo_rl.models.megatron.draft.utils._EagleModelLayout#
property layer_by_index: dict[int, nemo_rl.models.megatron.draft.utils._EagleLayerLayout]#
nemo_rl.models.megatron.draft.utils._combine_or_shard_weight_parts(
*,
parameter_name: str,
fused_weight: torch.Tensor | None,
component_weights: tuple[torch.Tensor | None, ...],
target: torch.Tensor | None,
tp_rank: int,
incomplete_error: str,
) → torch.Tensor | None#
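
The helper's exact behavior is internal, but the signature suggests a common pattern when importing HF weights into a TP-sharded model: use the fused tensor when the checkpoint provides one, otherwise concatenate the per-component tensors, then slice out the local rank's shard. A minimal sketch under those assumptions (all names below are hypothetical):

```python
import torch

def combine_or_shard_sketch(
    fused_weight: torch.Tensor | None,
    component_weights: tuple[torch.Tensor | None, ...],
    target: torch.Tensor | None,
    tp_rank: int,
    incomplete_error: str,
) -> torch.Tensor | None:
    if fused_weight is None:
        present = [w for w in component_weights if w is not None]
        if not present:
            return None  # nothing provided for this parameter
        if len(present) != len(component_weights):
            raise ValueError(incomplete_error)  # only some parts present
        fused_weight = torch.cat(present, dim=0)
    if target is None:
        return fused_weight  # no local parameter to match; keep the full tensor
    # Shard along dim 0 so the slice matches the target parameter's shape.
    rows = target.shape[0]
    return fused_weight.narrow(0, tp_rank * rows, rows)
```
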
class nemo_rl.models.megatron.draft.utils._PendingLayerWeights#
qkv_weight: torch.Tensor | None#

None

q_weight: torch.Tensor | None#

None

k_weight: torch.Tensor | None#

None

v_weight: torch.Tensor | None#

None

fc1_weight: torch.Tensor | None#

None

gate_weight: torch.Tensor | None#

None

up_weight: torch.Tensor | None#

None

apply_to(
mapped_state: nemo_rl.models.megatron.draft.utils.StateDict,
layer: nemo_rl.models.megatron.draft.utils._EagleLayerLayout,
model_state: Mapping[str, torch.Tensor],
tp_rank: int,
) → None#
nemo_rl.models.megatron.draft.utils._get_num_aux_hidden_states(
config: megatron.core.transformer.TransformerConfig,
) → int#
nemo_rl.models.megatron.draft.utils._all_gather_tp_shards(
local_weight: torch.Tensor,
) → list[torch.Tensor]#
nemo_rl.models.megatron.draft.utils._gather_tp_qkv_weight(
local_fused_weight: torch.Tensor,
q_dim: int,
kv_dim: int,
) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]#
nemo_rl.models.megatron.draft.utils._gather_tp_gate_up_weight(
local_fused_weight: torch.Tensor,
ffn_hidden_size: int,
) → tuple[torch.Tensor, torch.Tensor]#
nemo_rl.models.megatron.draft.utils._gather_tp_weight_if_needed(
local_weight: torch.Tensor,
expected_shape_or_tp_group: tuple[int, ...] | torch.distributed.ProcessGroup | None,
split_axis: int | None = None,
) → torch.Tensor#
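
These `_gather_tp_*` helpers undo tensor-parallel sharding before weights can be remapped or exported. A minimal sketch of the underlying all-gather pattern, assuming a resolved TP process group and treating `q_dim`/`kv_dim` as per-rank shard sizes (both assumptions; the real helpers read them from Megatron parallel state):

```python
import torch
import torch.distributed as dist

def all_gather_tp_shards_sketch(
    local_weight: torch.Tensor,
    tp_group: dist.ProcessGroup,
) -> list[torch.Tensor]:
    # One buffer per TP rank, filled by a collective all-gather.
    shards = [torch.empty_like(local_weight) for _ in range(dist.get_world_size(tp_group))]
    dist.all_gather(shards, local_weight.contiguous(), group=tp_group)
    return shards

def gather_tp_qkv_sketch(
    local_fused: torch.Tensor,
    q_dim: int,
    kv_dim: int,
    tp_group: dist.ProcessGroup,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Each rank holds a fused [q; k; v] slab on dim 0, so re-splitting every
    # gathered shard and concatenating per projection recovers full Q, K, V.
    qs, ks, vs = [], [], []
    for shard in all_gather_tp_shards_sketch(local_fused, tp_group):
        q, k, v = shard.split([q_dim, kv_dim, kv_dim], dim=0)
        qs.append(q)
        ks.append(k)
        vs.append(v)
    return torch.cat(qs), torch.cat(ks), torch.cat(vs)
```
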
nemo_rl.models.megatron.draft.utils._extract_tensor_state_dict(
checkpoint_obj: object,
checkpoint_path: pathlib.Path,
) → nemo_rl.models.megatron.draft.utils.StateDict#
nemo_rl.models.megatron.draft.utils._load_safetensors_file(
checkpoint_path: pathlib.Path,
) → nemo_rl.models.megatron.draft.utils.StateDict#
nemo_rl.models.megatron.draft.utils._load_torch_file(
checkpoint_path: pathlib.Path,
) → nemo_rl.models.megatron.draft.utils.StateDict#
nemo_rl.models.megatron.draft.utils._merge_checkpoint_shards(
checkpoint_dir: pathlib.Path,
shard_names: list[str],
shard_loader: nemo_rl.models.megatron.draft.utils.CheckpointLoader,
source_name: str,
) → nemo_rl.models.megatron.draft.utils.StateDict#
nemo_rl.models.megatron.draft.utils._load_index_checkpoint(
index_path: pathlib.Path,
) → nemo_rl.models.megatron.draft.utils.StateDict#
nemo_rl.models.megatron.draft.utils._load_checkpoint_file(
checkpoint_path: pathlib.Path,
) → nemo_rl.models.megatron.draft.utils.StateDict#
nemo_rl.models.megatron.draft.utils._load_checkpoint_from_directory(
checkpoint_dir: pathlib.Path,
) → nemo_rl.models.megatron.draft.utils.StateDict#
nemo_rl.models.megatron.draft.utils._load_checkpoint_state(
checkpoint_source: str,
) → nemo_rl.models.megatron.draft.utils.StateDict#
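
Together these loaders resolve a checkpoint source (single file, sharded index, directory, or Hub snapshot) into a flat `StateDict`. A simplified sketch of the per-file step, assuming the payload is either a safetensors file or a pickled dict, optionally nested under `"state_dict"` (the real helpers also handle index files and snapshot downloads):

```python
import pathlib
import torch
from safetensors.torch import load_file

def load_checkpoint_file_sketch(checkpoint_path: pathlib.Path) -> dict[str, torch.Tensor]:
    # Dispatch on extension; weights_only=True keeps torch.load from
    # executing arbitrary pickled code.
    if checkpoint_path.suffix == ".safetensors":
        return load_file(str(checkpoint_path))
    obj = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    state = obj.get("state_dict", obj)
    return {k: v for k, v in state.items() if isinstance(v, torch.Tensor)}
```
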
nemo_rl.models.megatron.draft.utils._normalize_hf_key(raw_hf_key: str) → str#
nemo_rl.models.megatron.draft.utils._parse_layer_checkpoint_key(hf_key: str) → tuple[int, str] | None#
nemo_rl.models.megatron.draft.utils._get_tp_rank() → int#
nemo_rl.models.megatron.draft.utils._build_split_axis_by_parameter(
layout: nemo_rl.models.megatron.draft.utils._EagleModelLayout,
) → dict[str, int]#
nemo_rl.models.megatron.draft.utils._shard_to_local_tp(
parameter_name: str,
tensor: torch.Tensor,
model_state: Mapping[str, torch.Tensor],
split_axis_by_parameter: Mapping[str, int],
tp_rank: int,
) → torch.Tensor#
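
`_shard_to_local_tp` goes the other direction: it cuts a full HF tensor down to the slice owned by the local TP rank along the parameter's split axis. A sketch of the core operation, with `tp_size` passed explicitly for illustration (the real helper derives it and the split axis from the model state and parallel state):

```python
import torch

def shard_to_local_tp_sketch(
    tensor: torch.Tensor,
    split_axis: int,
    tp_rank: int,
    tp_size: int,
) -> torch.Tensor:
    # Keep only the slice this TP rank owns along the split axis.
    if tensor.shape[split_axis] % tp_size != 0:
        raise ValueError("tensor not evenly divisible across TP ranks")
    return tensor.chunk(tp_size, dim=split_axis)[tp_rank].contiguous()
```
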
nemo_rl.models.megatron.draft.utils._assign_optional_layer_weight(
*,
model_key: str | None,
hf_weight: torch.Tensor,
mapped_state: nemo_rl.models.megatron.draft.utils.StateDict,
) → bool#
nemo_rl.models.megatron.draft.utils._map_layer_hf_weight(
layer_key: str,
hf_weight: torch.Tensor,
layer: nemo_rl.models.megatron.draft.utils._EagleLayerLayout,
mapped_state: nemo_rl.models.megatron.draft.utils.StateDict,
pending_weights: nemo_rl.models.megatron.draft.utils._PendingLayerWeights,
) → None#
nemo_rl.models.megatron.draft.utils._map_hf_state_to_eagle_state(
hf_state_dict: Mapping[str, torch.Tensor],
model_state: Mapping[str, torch.Tensor],
layout: nemo_rl.models.megatron.draft.utils._EagleModelLayout,
checkpoint_source: str,
) → nemo_rl.models.megatron.draft.utils.StateDict#
nemo_rl.models.megatron.draft.utils.load_hf_weights_to_eagle(
model: torch.nn.Module,
model_name: str,
) → tuple[list[str], list[str]]#

Load HF Eagle weights from a local path or Hub repo into a draft model.
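
A hypothetical call site; `draft_model` is assumed to already exist, and reading the two returned lists as missing/unexpected keys (in the spirit of `load_state_dict`) is an assumption of this sketch:

```python
from nemo_rl.models.megatron.draft.utils import load_hf_weights_to_eagle

# draft_model is assumed to be the already-built Megatron draft module.
missing, unexpected = load_hf_weights_to_eagle(
    model=draft_model,
    model_name="my-org/my-eagle-draft",  # local path or Hub repo id (placeholder)
)
if missing or unexpected:
    print(f"missing: {missing}\nunexpected: {unexpected}")
```
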

nemo_rl.models.megatron.draft.utils._require_state_tensor(
source_state: Mapping[str, torch.Tensor],
parameter_name: str,
) → torch.Tensor#
nemo_rl.models.megatron.draft.utils.find_draft_owner_chunk(
model: list[megatron.core.transformer.MegatronModule],
) → megatron.core.transformer.MegatronModule | None#

Return the post-process chunk that should own the nested draft model.

nemo_rl.models.megatron.draft.utils.get_attached_draft_model(
model: list[megatron.core.transformer.MegatronModule],
) → megatron.core.transformer.MegatronModule | None#

Find an already attached draft model after Megatron wrapping has been applied.
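
A hypothetical usage of the pair with a Megatron (virtual-)pipeline model chunk list; `model_chunks` is assumed to come from the parent setup:

```python
from nemo_rl.models.megatron.draft.utils import (
    find_draft_owner_chunk,
    get_attached_draft_model,
)

# model_chunks is assumed to be the list of Megatron model chunks on this rank.
owner = find_draft_owner_chunk(model_chunks)    # post-process chunk, or None
draft = get_attached_draft_model(model_chunks)  # already-attached draft, or None
if owner is not None and draft is None:
    # This stage should own a draft model but none has been attached yet.
    ...
```
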

nemo_rl.models.megatron.draft.utils._export_layer_weights_to_hf(
*,
source_state: Mapping[str, torch.Tensor],
layer: nemo_rl.models.megatron.draft.utils._EagleLayerLayout,
q_dim: int,
kv_dim: int,
hidden_size: int,
ffn_hidden_size: int,
) → list[tuple[str, torch.Tensor]]#
nemo_rl.models.megatron.draft.utils.export_eagle_weights_to_hf(
model: torch.nn.Module,
) → list[tuple[str, torch.Tensor]]#

Export the standalone Eagle draft model to HF naming.
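
A hypothetical export flow that persists the returned HF-named tensors with safetensors (the output path is illustrative, and `draft_model` is assumed):

```python
from safetensors.torch import save_file
from nemo_rl.models.megatron.draft.utils import export_eagle_weights_to_hf

named_weights = export_eagle_weights_to_hf(draft_model)
# save_file expects contiguous tensors keyed by name.
save_file({k: v.contiguous() for k, v in named_weights}, "eagle_draft/model.safetensors")
```
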

nemo_rl.models.megatron.draft.utils.get_policy_lm_head_weight(
policy_model_chunk: megatron.core.transformer.MegatronModule,
) → torch.Tensor#

Return the local policy LM-head shard for draft initialization.

nemo_rl.models.megatron.draft.utils._get_draft_output_layer(
draft_model: megatron.core.transformer.MegatronModule,
)#
nemo_rl.models.megatron.draft.utils._get_draft_to_target_token_mapping(
draft_model: megatron.core.transformer.MegatronModule,
device: torch.device,
) → torch.Tensor#
nemo_rl.models.megatron.draft.utils.copy_policy_lm_head_to_draft(
*,
draft_model: megatron.core.transformer.MegatronModule,
policy_model_chunk: megatron.core.transformer.MegatronModule,
) → None#

Initialize the draft LM head from the policy LM head shard.
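
A hypothetical initialization sequence; `draft_model` and `policy_chunk` are assumed to be the unwrapped Megatron modules:

```python
from nemo_rl.models.megatron.draft.utils import (
    copy_policy_lm_head_to_draft,
    get_policy_lm_head_weight,
)

# Seed the draft LM head from the policy's local LM-head shard.
copy_policy_lm_head_to_draft(
    draft_model=draft_model,
    policy_model_chunk=policy_chunk,
)
# The shard itself is also available directly, e.g. for sanity checks:
print(get_policy_lm_head_weight(policy_chunk).shape)
```
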

nemo_rl.models.megatron.draft.utils.build_draft_model(
model_provider,
draft_config: dict[str, Any],
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
policy_model_chunk: megatron.core.transformer.MegatronModule,
) → megatron.core.transformer.MegatronModule | None#

Build an Eagle draft model before parent mixed-precision/DDP wrapping.
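
A hypothetical wiring sketch; `model_provider`, `pg_collection`, and `policy_chunk` come from the parent Megatron setup, and the `draft_config` key shown is a placeholder, not the real schema:

```python
from nemo_rl.models.megatron.draft.utils import (
    build_draft_model,
    copy_policy_lm_head_to_draft,
)

# Build the draft model before the parent applies mixed-precision/DDP wrapping.
draft = build_draft_model(
    model_provider,
    draft_config={"num_layers": 1},  # placeholder config
    pg_collection=pg_collection,
    policy_model_chunk=policy_chunk,
)
if draft is not None:
    copy_policy_lm_head_to_draft(draft_model=draft, policy_model_chunk=policy_chunk)
```
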