nemo_rl.models.megatron.converters.common#
Module Contents#
Classes#
| `MegatronToHFConverter` | |

Functions#
| `get_local_layer_num` | Assumes the layer number is preceded by 'layers.'. |
| `get_local_expert_num` | Assumes experts have 'experts.' in their name; the expert number follows '.weight'. |
| `get_global_layer_num` | Assumes the layer number is preceded by 'layers.'. |
| `get_global_expert_num` | Assumes experts have 'experts.' in their name; the expert number follows '.weight'. |
| `split_qkv_gpu` | Split interleave-concatenated qkv into separate q, k, v. |
| `split_qkv_bias_gpu` | Split interleave-concatenated qkv bias into separate q, k, v biases. |
Data#
API#
- nemo_rl.models.megatron.converters.common._GROUP_TO_RANKS_CACHE#
None
- nemo_rl.models.megatron.converters.common.get_local_layer_num(s)#
Assumes the layer number is preceded by 'layers.'.
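The contract is pure string parsing. A minimal sketch of that rule (the helper name is illustrative, not the actual implementation):
```python
import re

# Hypothetical sketch: capture the integer directly following "layers.".
def local_layer_num(key: str):
    m = re.search(r"layers\.(\d+)", key)
    return int(m.group(1)) if m else None

assert local_layer_num("decoder.layers.3.self_attention.linear_qkv.weight") == 3
```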
- nemo_rl.models.megatron.converters.common.get_local_expert_num(s)#
Assumes experts have 'experts.' in their name; the expert number follows '.weight'.
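A minimal sketch of that rule, assuming grouped-expert keys end in `.weight<n>` (illustrative, not the actual implementation):
```python
import re

# Hypothetical sketch: the expert number is assumed to sit directly
# after ".weight" at the end of the key.
def local_expert_num(key: str):
    if "experts." not in key:
        return None
    m = re.search(r"\.weight(\d+)$", key)
    return int(m.group(1)) if m else None

assert local_expert_num("decoder.layers.0.mlp.experts.linear_fc1.weight7") == 7
```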
- nemo_rl.models.megatron.converters.common.get_global_layer_num(s, cfg) → int#
Assumes the layer number is preceded by 'layers.'.
Assumes the pipeline model parallel size is set. In the state dict, the layer number is the local (PP-local) layer number; this function converts it to the global layer number.
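A sketch of the arithmetic, assuming layers are split evenly across pipeline stages and ignoring virtual pipeline parallelism (the real function derives these values from `cfg`):
```python
# Illustrative arithmetic only; not the actual implementation.
def global_layer_num(local_layer: int, num_layers: int,
                     pp_size: int, pp_rank: int) -> int:
    # Each pipeline stage owns a contiguous block of layers,
    # offset by its pipeline rank.
    layers_per_stage = num_layers // pp_size
    return pp_rank * layers_per_stage + local_layer

# 32 layers on 4 stages: local layer 2 on PP rank 3 is global layer 26.
assert global_layer_num(2, num_layers=32, pp_size=4, pp_rank=3) == 26
```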
- nemo_rl.models.megatron.converters.common.get_global_expert_num(s, cfg)#
Assumes experts have 'experts.' in their name; the expert number follows '.weight'.
Assumes the expert model parallel size is set. In the state dict, the expert number is the local (expert-parallel-local) expert number; this function converts it to the global expert number.
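A sketch of the analogous arithmetic for experts, assuming an even split across expert-parallel ranks (the real function derives these values from `cfg`):
```python
# Illustrative arithmetic only; not the actual implementation.
def global_expert_num(local_expert: int, num_experts: int,
                      ep_size: int, ep_rank: int) -> int:
    # Each expert-parallel rank owns a contiguous block of experts.
    experts_per_rank = num_experts // ep_size
    return ep_rank * experts_per_rank + local_expert

# 8 experts over 2 EP ranks: local expert 1 on EP rank 1 is global expert 5.
assert global_expert_num(1, num_experts=8, ep_size=2, ep_rank=1) == 5
```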
- nemo_rl.models.megatron.converters.common.get_global_key_from_local_key(local_key, model_cfg)#
- nemo_rl.models.megatron.converters.common.split_fc1_tp(
- ctx: nemo.lightning.io.state.TransformCTX,
- linear_fc1: torch.Tensor,
- )#
- nemo_rl.models.megatron.converters.common.split_fc1_etp(
- ctx: nemo.lightning.io.state.TransformCTX,
- linear_fc1: torch.Tensor,
- )#
- nemo_rl.models.megatron.converters.common.split_qkv_gpu(
- ctx: nemo.lightning.io.state.TransformCTX,
- linear_qkv: torch.Tensor,
- )#
Split interleave-concatenated qkv into separate q, k, v.
Example: export layer linear_qkv to HF {q|k|v}_proj
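A hedged sketch of the de-interleaving, assuming Megatron packs, per KV group, `q_per_kv` query heads followed by one K head and one V head along dim 0 (all names and shapes here are illustrative, not the function's actual signature):
```python
import torch

# Illustrative sketch: HF expects contiguous q_proj / k_proj / v_proj
# weights, so the packed groups are reshaped apart and re-flattened.
def split_qkv(linear_qkv: torch.Tensor, num_query_groups: int,
              num_attention_heads: int, head_dim: int):
    hidden = linear_qkv.shape[-1]
    q_per_kv = num_attention_heads // num_query_groups
    qkv = linear_qkv.reshape(num_query_groups, q_per_kv + 2, head_dim, hidden)
    q = qkv[:, :q_per_kv].reshape(-1, hidden)      # all query heads
    k = qkv[:, q_per_kv].reshape(-1, hidden)       # one K head per group
    v = qkv[:, q_per_kv + 1].reshape(-1, hidden)   # one V head per group
    return q, k, v

# 8 query heads in 2 KV groups, head_dim 4, hidden 16.
q, k, v = split_qkv(torch.randn((8 + 2 + 2) * 4, 16), 2, 8, 4)
assert q.shape == (32, 16) and k.shape == (8, 16) and v.shape == (8, 16)
```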
- nemo_rl.models.megatron.converters.common.split_qkv_bias_gpu(
- ctx: nemo.lightning.io.state.TransformCTX,
- qkv_bias: torch.Tensor,
- )#
Split interleave-concatenated qkv bias into separate q, k, v biases.
Example: export layer linear_qkv bias to HF {q|k|v}_proj bias
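The bias case is the same regrouping applied to a 1-D tensor; a sketch under the same illustrative assumptions as above:
```python
import torch

# Same regrouping as the weight case, applied to a 1-D bias
# (illustrative shapes and names).
def split_qkv_bias(qkv_bias: torch.Tensor, num_query_groups: int,
                   num_attention_heads: int, head_dim: int):
    q_per_kv = num_attention_heads // num_query_groups
    bias = qkv_bias.reshape(num_query_groups, q_per_kv + 2, head_dim)
    return (bias[:, :q_per_kv].reshape(-1),     # q bias
            bias[:, q_per_kv].reshape(-1),      # k bias
            bias[:, q_per_kv + 1].reshape(-1))  # v bias

q_b, k_b, v_b = split_qkv_bias(torch.randn(48), 2, 8, 4)
assert q_b.shape == (32,) and k_b.shape == (8,) and v_b.shape == (8,)
```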
- nemo_rl.models.megatron.converters.common.update_transforms_for_nemorl(export_transforms)#
- class nemo_rl.models.megatron.converters.common.MegatronToHFConverter(hf_model_name, megatron_model)#
Initialization
- _get_empty_state_dict(source_keys=None)#
- _group(
- state_dict,
- key,
- item,
- main_state_dict_keys,
- main_items,
- exception_state_dict_keys_list,
- exception_items,
- )#
- _get_groups(state_dict)#
This function is used to group mappings and transforms together. It goes through the mappings and transforms once to collect groups [(mapping, state_dict_keys)] and [(transforms, state_dict_keys)] that can be converted together.
This is necessary because:
1. If a mapping or transform expression has two wildcards, _match_keys assumes the matches for each wildcard are the same size. For example, for the mapping `layers.*.mlp.experts.*.linear_fc1.weight`, where the first wildcard matches the layer number and the second matches the expert number, it assumes the number of experts is the same for every layer. This fails during batched streaming refit when the current state dict is missing experts from some layers. To handle this, the partial keys (e.g. the ones corresponding to fewer experts) are separated into their own group and run through the mappings and transforms separately; see the sketch below.
NOTE: this function currently only handles expressions with up to two wildcards and will fail if a mapping or transform expression has more than two.
2. An expression may match zero keys in the current state dict. This can happen during batched streaming refit if the current state dict has no keys that match the expression. Such mappings/transforms are skipped.
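A minimal sketch of the grouping idea for a single two-wildcard expression (the helper name and key pattern are illustrative; the real implementation also carries the associated mappings and transforms):
```python
import re
from collections import defaultdict

# Hypothetical illustration: bucket keys so that every returned group has
# a uniform number of expert matches per layer.
def group_by_expert_count(keys):
    pattern = re.compile(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.linear_fc1\.weight")
    per_layer = defaultdict(list)
    for key in keys:
        if m := pattern.fullmatch(key):
            per_layer[m.group(1)].append(key)
    if not per_layer:
        return []  # the expression matched no keys: skip this mapping entirely
    by_count = defaultdict(list)  # expert count -> keys with that count
    for layer_keys in per_layer.values():
        by_count[len(layer_keys)].extend(layer_keys)
    return list(by_count.values())

groups = group_by_expert_count([
    "layers.0.mlp.experts.0.linear_fc1.weight",
    "layers.0.mlp.experts.1.linear_fc1.weight",
    "layers.1.mlp.experts.0.linear_fc1.weight",  # layer 1 is missing an expert
])
assert len(groups) == 2  # complete layer 0 vs. partial layer 1
```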
- convert(state_dict, megatron_config)#
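A hedged usage sketch based only on the signatures documented on this page; the model name, model object, state dict, and config are placeholders you must supply:
```python
from nemo_rl.models.megatron.converters.common import MegatronToHFConverter

converter = MegatronToHFConverter(
    hf_model_name="meta-llama/Llama-3.1-8B",  # placeholder HF model id
    megatron_model=megatron_model,            # your initialized Megatron model
)
# Convert a (possibly partial) Megatron state dict to HF-keyed tensors.
hf_state_dict = converter.convert(megatron_state_dict, megatron_config)
```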