core.transformer.fsdp_dtensor_checkpoint#

Module Contents#

Functions#

get_ep_layer_offset

Get the expert layer offset for the current model.

get_expert_index_from_key

Extract expert index from various expert key formats.

handle_experts_in_state_dict

Rewrite expert keys in state dict.

expert_param_local_key

Get the local module parameter key corresponding to the given key.

handle_swiglu_in_state_dict

Handle SWiGLU in model and optimizer state dicts.

handle_fp8_extra_state_case

Handle the case where FP8 extra state is present in the model state dict.

flatten_state_dict

Recursively flattens a nested state dict into a single-level dict, joining nested key paths with a separator.

print_diff_in_state_dicts

Print the differences between two state dicts (the metadata state dict and the load state dict), comparing the keys and the shapes of the tensors in both.

validate_loaded_state_dict

Validate the loaded state dict against the expected structure and types.

get_global_unique_param_name

Get the global unique parameter name for a given model and parameter.

Data#

API#

core.transformer.fsdp_dtensor_checkpoint.logger#

'getLogger(…)'

core.transformer.fsdp_dtensor_checkpoint.get_ep_layer_offset(num_experts: int | None = None) → int#

Get the expert layer offset for the current model.

Parameters:

num_experts – Total number of experts in the model. If None, returns 0.

Returns:

The expert layer offset for the current EP rank.
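The offset computation can be sketched as follows. This is a hypothetical illustration: the real get_ep_layer_offset reads the expert-parallel (EP) rank and world size from Megatron's parallel state, whereas here they are explicit arguments, and the even-split assumption (num_experts divisible by ep_size) is mine.

```python
# Hypothetical sketch of the expert-layer-offset computation; the real
# function derives ep_rank and ep_size from the parallel state.
def ep_layer_offset(num_experts, ep_rank, ep_size):
    if num_experts is None:
        return 0
    # Each EP rank owns a contiguous slice of num_experts // ep_size experts,
    # so its first local expert maps to global index ep_rank * slice_size.
    return ep_rank * (num_experts // ep_size)
```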

core.transformer.fsdp_dtensor_checkpoint.get_expert_index_from_key(key)#

Extract expert index from various expert key formats.

Supported formats:

  • GroupedMLP: 'mlp.experts.linear_fc1.weight0', 'mlp.experts.linear_fc2.weight0'

  • SequentialMLP: 'mlp.experts.local_experts.0.linear_fc1.weight', 'mlp.experts.local_experts.0.linear_fc2.weight'

Returns:

Expert index if found, None otherwise.

Return type:

int | None
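A minimal sketch of the extraction for the two documented key layouts; the regular expressions are my own reconstruction, and the actual implementation may differ in detail.

```python
import re

# Sketch: pull the expert index out of the two key formats listed above.
def expert_index_from_key(key):
    # SequentialMLP layout: '...local_experts.<idx>.linear_fc{1,2}...'
    m = re.search(r"local_experts\.(\d+)\.", key)
    if m:
        return int(m.group(1))
    # GroupedMLP layout: '...experts.linear_fc{1,2}.weight<idx>'
    m = re.search(r"experts\.linear_fc\d\.weight(\d+)$", key)
    if m:
        return int(m.group(1))
    return None  # not an expert key
```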

core.transformer.fsdp_dtensor_checkpoint.handle_experts_in_state_dict(
state_dict,
num_experts: int | None = None,
)#

Rewrite expert keys in state dict.

Parameters:
  • state_dict – The state dictionary to process.

  • num_experts – Total number of experts in the model. If None, no expert processing occurs.

Returns:

The processed state dictionary with rewritten expert keys.
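The rewriting step can be sketched like this, assuming the SequentialMLP key layout and taking the EP offset as an explicit argument (the real function computes it internally and handles more key formats):

```python
import re

# Hypothetical sketch: shift rank-local expert indices in checkpoint keys
# to global indices by adding the expert-parallel offset.
def rewrite_expert_keys(state_dict, offset):
    out = {}
    for key, value in state_dict.items():
        new_key = re.sub(
            r"local_experts\.(\d+)\.",
            lambda m: f"local_experts.{int(m.group(1)) + offset}.",
            key,
        )
        out[new_key] = value
    return out
```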

core.transformer.fsdp_dtensor_checkpoint.expert_param_local_key(
key: str,
num_experts: int | None = None,
) → str#

Get the local module parameter key corresponding to the given key.

Parameters:
  • key – The parameter key to process.

  • num_experts – Total number of experts in the model. If None, no expert processing occurs.

Returns:

The local parameter key with adjusted expert indices.
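This is the inverse of the global rewrite: a global expert index in a checkpoint key is mapped back to the rank-local index. A sketch, again with the offset as an explicit (hypothetical) argument:

```python
import re

# Sketch: convert a global expert index in a key back to the rank-local
# index by subtracting the expert-parallel offset.
def to_local_expert_key(key, offset):
    return re.sub(
        r"local_experts\.(\d+)\.",
        lambda m: f"local_experts.{int(m.group(1)) - offset}.",
        key,
    )
```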

core.transformer.fsdp_dtensor_checkpoint.handle_swiglu_in_state_dict(
model,
model_state_dict,
optimizer_state_dict,
)#

Handle SWiGLU in model and optimizer state dicts.
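With SwiGLU, the gate and up projections are typically fused into a single linear_fc1 weight, so checkpoint handling must split (or re-fuse) that weight along the output dimension. The following is an illustrative sketch only, with weight rows modeled as plain lists; the real function operates on (D)Tensors in the model and optimizer state dicts.

```python
# Illustrative only: split a fused SwiGLU linear_fc1 weight into its two
# halves (gate projection and up projection) along the output dimension.
def split_swiglu_fc1(fused_rows):
    half = len(fused_rows) // 2
    return fused_rows[:half], fused_rows[half:]
```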

core.transformer.fsdp_dtensor_checkpoint.handle_fp8_extra_state_case(model_state_dict)#

Handle the case where FP8 extra state is present in the model state dict.

core.transformer.fsdp_dtensor_checkpoint.flatten_state_dict(obj, parent_key='', sep='.')#

Recursively flattens a nested state dict into a single-level dict whose keys are the nested key paths joined with sep.
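A minimal sketch of this flattening, following the signature above (the behavior for non-dict leaves is an assumption on my part):

```python
# Collapse nested dicts into one level, joining key path components with sep.
def flatten(obj, parent_key="", sep="."):
    items = {}
    for key, value in obj.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            items.update(flatten(value, full_key, sep=sep))
        else:
            items[full_key] = value
    return items
```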

core.transformer.fsdp_dtensor_checkpoint.print_diff_in_state_dicts(
state_dict_metadata,
load_state_dict,
limit=100,
)#

Print the differences between two state dicts (the metadata state dict and the load state dict), comparing the keys and the shapes of the tensors in both.
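The comparison logic can be sketched as below. Shapes are plain tuples here, and the message strings are hypothetical; the real function reads shapes from tensor metadata.

```python
# Sketch: report keys present in only one dict and keys whose tensor
# shapes disagree, printing at most `limit` lines.
def diff_state_dicts(metadata, loaded, limit=100):
    lines = []
    for key in sorted(set(metadata) | set(loaded)):
        if key not in loaded:
            lines.append(f"missing in load: {key}")
        elif key not in metadata:
            lines.append(f"unexpected in load: {key}")
        elif metadata[key] != loaded[key]:
            lines.append(f"shape mismatch at {key}: {metadata[key]} vs {loaded[key]}")
    for line in lines[:limit]:
        print(line)
    return lines
```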

core.transformer.fsdp_dtensor_checkpoint.validate_loaded_state_dict(state_dict, checkpoint_path)#

Validate the loaded state dict against the expected structure and types.

core.transformer.fsdp_dtensor_checkpoint.get_global_unique_param_name(model_chunks, param)#

Get the global unique parameter name for a given model and parameter.

Parameters:
  • model_chunks – List of model chunks to search for the parameter.

  • param – The parameter to find the name for.

Returns:

The global unique parameter name.
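A hypothetical sketch of the lookup: find the parameter by identity across the model chunks and prefix its name with the chunk index so the result is globally unique. Chunks are modeled as simple name-to-parameter mappings rather than nn.Modules, and the "module{i}." prefix is an assumption.

```python
# Sketch: locate `param` by identity in the model chunks and build a
# globally unique name from the chunk index and the local name.
def global_unique_param_name(model_chunks, param):
    for chunk_idx, chunk in enumerate(model_chunks):
        for name, candidate in chunk.items():
            if candidate is param:
                return f"module{chunk_idx}.{name}"
    raise ValueError("parameter not found in any model chunk")
```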