nemo_rl.models.megatron.refit_utils

Module Contents

Functions

get_tp_dim

gather_params

get_param_info

get_local_key_to_global_keys

Get the local key to global keys mapping.

API

nemo_rl.models.megatron.refit_utils.get_tp_dim(model, param_name, named_modules_dict)

nemo_rl.models.megatron.refit_utils.gather_params(model, keys, key_to_global_keys: Dict[str, List[str]])

nemo_rl.models.megatron.refit_utils.get_param_info(model, dtype)

nemo_rl.models.megatron.refit_utils.get_local_key_to_global_keys(model, state_dict_info: List[Tuple[Any, int]])

Get the local key to global keys mapping.
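
The reference above does not describe the mapping's contents. The `key_to_global_keys: Dict[str, List[str]]` parameter of `gather_params` only suggests its shape: each local parameter name maps to a list of global names. The sketch below is a toy illustration of that shape, not the library's implementation. It assumes one common reason local and global names differ, a per-rank layer-index offset. The helper `local_to_global_layer_keys`, the `layer_offset` argument, and the example parameter names are all hypothetical.

```python
import re
from typing import Dict, List

# Hypothetical pattern: matches the layer index in names like "decoder.layers.3.mlp.weight".
_LAYER_INDEX = re.compile(r"\blayers\.(\d+)\.")


def local_to_global_layer_keys(
    local_keys: List[str], layer_offset: int
) -> Dict[str, List[str]]:
    """Toy mapping from local parameter names to lists of global names.

    Assumption (not taken from the reference): a rank holds a contiguous
    slice of layers numbered from 0 locally, and the global index is the
    local index plus `layer_offset`. Names without a layer index are
    mapped to themselves.
    """
    mapping: Dict[str, List[str]] = {}
    for key in local_keys:
        global_key = _LAYER_INDEX.sub(
            lambda m: f"layers.{int(m.group(1)) + layer_offset}.",
            key,
            count=1,  # rewrite only the first layer index in the name
        )
        mapping[key] = [global_key]
    return mapping


example_mapping = local_to_global_layer_keys(
    ["decoder.layers.0.mlp.weight", "embedding.weight"],
    layer_offset=4,
)
```

Each value is a one-element list here. The list type leaves room for a local name that corresponds to several global names, but whether and when that happens in `get_local_key_to_global_keys` is not stated in this reference.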