nemo_rl.models.megatron.converters.common#

Module Contents#

Classes#

Functions#

get_local_layer_num

Assumes the layer number is preceded by 'layers.'.

get_local_expert_num

Assumes experts have 'experts.' in their name. The expert number follows '.weight'.

get_global_layer_num

Assumes the layer number is preceded by 'layers.'.

get_global_expert_num

Assumes experts have 'experts.' in their name. The expert number follows '.weight'.

get_global_key_from_local_key

split_fc1_tp

split_fc1_etp

split_qkv_gpu

Split interleave-concatenated qkv to q, k, v.

split_qkv_bias_gpu

Split interleave-concatenated qkv bias to separate q, k, v bias.

update_transforms_for_nemorl

Data#

API#

nemo_rl.models.megatron.converters.common._GROUP_TO_RANKS_CACHE#

None

nemo_rl.models.megatron.converters.common.get_local_layer_num(s)#

Assumes the layer number is preceded by 'layers.'.

nemo_rl.models.megatron.converters.common.get_local_expert_num(s)#

Assumes experts have 'experts.' in their name. The expert number follows '.weight'.

nemo_rl.models.megatron.converters.common.get_global_layer_num(s, cfg) int#

Assumes the layer number is preceded by 'layers.'.

Assumes the pipeline model parallel size is set. In the state dict, the layer number is the local layer number (PP-local). This function converts the local layer number to the global layer number.
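
A rough sketch of that conversion, assuming layers are distributed evenly across pipeline stages; the helper and argument names below are illustrative, not the actual implementation, which reads these values from the model config.

```python
import re

def global_layer_num_sketch(key: str, pp_rank: int, layers_per_stage: int) -> int:
    # The local layer number follows "layers." in the key.
    local_layer = int(re.search(r"layers\.(\d+)", key).group(1))
    # Offset by the layers owned by earlier pipeline stages (assumes an even split).
    return pp_rank * layers_per_stage + local_layer

# Pipeline stage 2, 8 layers per stage: local layer 3 -> global layer 19.
assert global_layer_num_sketch("decoder.layers.3.self_attention.linear_qkv.weight", 2, 8) == 19
```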

nemo_rl.models.megatron.converters.common.get_global_expert_num(s, cfg)#

Assumes experts have 'experts.' in their name. The expert number follows '.weight'.

Assumes the expert model parallel size is set. In the state dict, the expert number is the local expert number (expert-local). This function converts the local expert number to the global expert number.
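
A corresponding sketch for experts, assuming experts are distributed evenly across expert-parallel ranks; names are illustrative, not the real signature.

```python
import re

def global_expert_num_sketch(key: str, ep_rank: int, experts_per_rank: int) -> int:
    # In Megatron grouped-expert keys, the expert number follows ".weight".
    local_expert = int(re.search(r"\.weight(\d+)$", key).group(1))
    # Offset by the experts owned by earlier expert-parallel ranks (assumes an even split).
    return ep_rank * experts_per_rank + local_expert

# Expert-parallel rank 1, 4 experts per rank: local expert 2 -> global expert 6.
assert global_expert_num_sketch("decoder.layers.0.mlp.experts.linear_fc1.weight2", 1, 4) == 6
```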

nemo_rl.models.megatron.converters.common.get_global_key_from_local_key(local_key, model_cfg)#
nemo_rl.models.megatron.converters.common.split_fc1_tp(
ctx: nemo.lightning.io.state.TransformCTX,
linear_fc1: torch.Tensor,
)#
nemo_rl.models.megatron.converters.common.split_fc1_etp(
ctx: nemo.lightning.io.state.TransformCTX,
linear_fc1: torch.Tensor,
)#
nemo_rl.models.megatron.converters.common.split_qkv_gpu(
ctx: nemo.lightning.io.state.TransformCTX,
linear_qkv: torch.Tensor,
)#

Split interleave-concatenated qkv to q, k, v.

Example: export layer linear_qkv to HF {q|k|v}_proj
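
A simplified sketch of the de-interleaving, assuming a grouped-query layout in which each query group's rows are stored as [q heads..., k, v]; sizes below are illustrative only and this is not the actual implementation.

```python
import torch

# Illustrative sizes only.
hidden_size, head_dim = 8, 2
num_query_groups, heads_per_group = 2, 3  # GQA: 3 query heads share each KV head

# The fused weight is interleaved per group: [q_0 .. q_{h-1}, k, v] rows for
# group 0, then the same block for group 1, and so on.
qkv = torch.randn((heads_per_group + 2) * num_query_groups * head_dim, hidden_size)

per_group = qkv.reshape(num_query_groups, (heads_per_group + 2) * head_dim, hidden_size)
q = per_group[:, : heads_per_group * head_dim].reshape(-1, hidden_size)
k = per_group[:, heads_per_group * head_dim : (heads_per_group + 1) * head_dim].reshape(-1, hidden_size)
v = per_group[:, (heads_per_group + 1) * head_dim :].reshape(-1, hidden_size)

assert q.shape[0] == heads_per_group * num_query_groups * head_dim  # 12 rows
assert k.shape[0] == v.shape[0] == num_query_groups * head_dim      # 4 rows each
```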

nemo_rl.models.megatron.converters.common.split_qkv_bias_gpu(
ctx: nemo.lightning.io.state.TransformCTX,
qkv_bias: torch.Tensor,
)#

Split interleave-concatenated qkv bias to separate q, k, v bias.

Example: export layer linear_qkv bias to HF {q|k|v}_proj bias
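
The bias case is the one-dimensional analogue of the weight split sketched above (again with illustrative sizes, not the actual implementation).

```python
import torch

# Same interleaved per-group layout as the weights, but one-dimensional.
head_dim, num_query_groups, heads_per_group = 2, 2, 3
qkv_bias = torch.randn((heads_per_group + 2) * num_query_groups * head_dim)

per_group = qkv_bias.reshape(num_query_groups, (heads_per_group + 2) * head_dim)
q_bias = per_group[:, : heads_per_group * head_dim].reshape(-1)
k_bias = per_group[:, heads_per_group * head_dim : (heads_per_group + 1) * head_dim].reshape(-1)
v_bias = per_group[:, (heads_per_group + 1) * head_dim :].reshape(-1)
```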

nemo_rl.models.megatron.converters.common.update_transforms_for_nemorl(export_transforms)#
class nemo_rl.models.megatron.converters.common.MegatronToHFConverter(hf_model_name, megatron_model)#

Initialization

_get_empty_state_dict(source_keys=None)#
_group(
state_dict,
key,
item,
main_state_dict_keys,
main_items,
exception_state_dict_keys_list,
exception_items,
)#
_get_groups(state_dict)#

This function is used to group mappings and transforms together.

Iterates over the mappings and transforms once to collect mapping and transform groups, [(mapping, state_dict_keys)] and [(transforms, state_dict_keys)], that can be converted together.

This is necessary because:

  1. If the mapping or transform expression has 2 wildcard expressions, _match_keys assumes the matches for each wildcard are the same size. For example, if the mapping is "layers.*.mlp.experts.*.linear_fc1.weight", where the first wildcard matches the layer number and the second wildcard matches the expert number, it assumes the number of experts is the same for every layer. This fails when we are doing batched streaming refit and the current state dict is missing experts from some layers. To handle this, we separate out the partial keys (e.g. the ones corresponding to fewer experts) into a separate group and run them through the mappings and transforms separately (see the sketch after this list).

    NOTE: this function currently only handles expressions with up to 2 wildcards and will fail if a mapping or transform expression has more than 2.

  2. An expression may match zero keys in the current state dict. This can happen during batched streaming refit if the current state dict doesn't have any keys that match the expression. To handle this, we skip these mappings/transforms.
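
A simplified sketch of the grouping idea from point 1, using hypothetical keys and helper names rather than the actual implementation.

```python
from collections import defaultdict
import re

# Hypothetical partial state dict seen during batched streaming refit:
# layer 0 contributes both experts, layer 1 only expert 0.
keys = [
    "decoder.layers.0.mlp.experts.linear_fc1.weight0",
    "decoder.layers.0.mlp.experts.linear_fc1.weight1",
    "decoder.layers.1.mlp.experts.linear_fc1.weight0",
]

# Bucket keys by layer, then group layers that contribute the same number of
# experts, so each group is "rectangular" and safe for a 2-wildcard mapping.
keys_by_layer = defaultdict(list)
for key in keys:
    layer = re.search(r"layers\.(\d+)\.", key).group(1)
    keys_by_layer[layer].append(key)

groups = defaultdict(list)
for layer_keys in keys_by_layer.values():
    groups[len(layer_keys)].extend(layer_keys)

# groups[2] -> keys from layers with both experts present,
# groups[1] -> keys from layers with only one expert present;
# each group is run through the mapping and transforms separately.
```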

convert(state_dict, megatron_config)#
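
A minimal usage sketch based only on the signatures listed above; the model name and surrounding variables are placeholders, not values from the source.

```python
# Hypothetical usage; all values are placeholders.
from nemo_rl.models.megatron.converters.common import MegatronToHFConverter

converter = MegatronToHFConverter(
    hf_model_name="meta-llama/Meta-Llama-3-8B",  # assumed HF model id
    megatron_model=megatron_model,               # an already-built Megatron model
)

# Convert a (possibly partial) Megatron state dict to HF-keyed tensors,
# e.g. one batch of weights during streaming refit.
hf_state_dict = converter.convert(megatron_state_dict, megatron_config)
```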