core.transformer.fsdp_dtensor_checkpoint#
Module Contents#
Functions#
| Function | Description |
|---|---|
| `get_ep_layer_offset` | Get the expert layer offset for the current model. |
| `get_expert_index_from_key` | Extract expert index from various expert key formats. |
| `handle_experts_in_state_dict` | Rewrite expert keys in state dict. |
| `expert_param_local_key` | Get the module parameter corresponding to the key. |
| `handle_swiglu_in_state_dict` | Handle SWiGLU in model and optimizer state dicts. |
| `handle_fp8_extra_state_case` | Handle the case where FP8 extra state is present in the model state dict. |
| `flatten_state_dict` | Recursively flatten a nested state dict into a single-level dict. |
| `print_diff_in_state_dicts` | Print the differences between two state dicts, comparing keys and tensor shapes. |
| `validate_loaded_state_dict` | Validate the loaded state dict against the expected structure and types. |
| `get_global_unique_param_name` | Get the global unique parameter name for a given model and parameter. |
Data#
API#
- core.transformer.fsdp_dtensor_checkpoint.logger#
'getLogger(…)'
- core.transformer.fsdp_dtensor_checkpoint.get_ep_layer_offset(num_experts: int | None = None) -> int#
Get the expert layer offset for the current model.
- Parameters:
num_experts – Total number of experts in the model. If None, returns 0.
- Returns:
The expert layer offset for the current EP rank.
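The offset is typically derived from the expert-parallel rank and the number of experts hosted per rank. A minimal sketch of that arithmetic, where `ep_rank` and `ep_size` are illustrative stand-ins for the values the real function reads from the parallel state:

```python
def ep_layer_offset_sketch(num_experts, ep_rank, ep_size):
    # With no experts configured, the offset is defined as 0.
    if num_experts is None:
        return 0
    # Each EP rank hosts a contiguous block of num_experts // ep_size experts,
    # so the first local expert on this rank has this global index.
    experts_per_rank = num_experts // ep_size
    return ep_rank * experts_per_rank

# e.g. 8 experts over 4 EP ranks: rank 2 starts at global expert index 4
```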
- core.transformer.fsdp_dtensor_checkpoint.get_expert_index_from_key(key)#
Extract expert index from various expert key formats.
Supported formats:
GroupedMLP: `mlp.experts.linear_fc1.weight0`, `mlp.experts.linear_fc2.weight0`
SequentialMLP: `mlp.experts.local_experts.0.linear_fc1.weight`, `mlp.experts.local_experts.0.linear_fc2.weight`
- Returns:
Expert index if found, None otherwise.
- Return type:
int | None
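A sketch of how the two documented key formats can be matched; the regexes here are illustrative, not the module's actual implementation:

```python
import re

def expert_index_sketch(key):
    # SequentialMLP keys embed the index as a path component:
    #   'mlp.experts.local_experts.0.linear_fc1.weight'
    m = re.search(r"local_experts\.(\d+)\.", key)
    if m:
        return int(m.group(1))
    # GroupedMLP keys append the index to the parameter name:
    #   'mlp.experts.linear_fc1.weight0'
    m = re.search(r"experts\.linear_fc[12]\.weight(\d+)$", key)
    if m:
        return int(m.group(1))
    # Not an expert parameter key.
    return None
```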
- core.transformer.fsdp_dtensor_checkpoint.handle_experts_in_state_dict(state_dict, num_experts: int | None = None)#
Rewrite expert keys in state dict.
- Parameters:
state_dict – The state dictionary to process.
num_experts – Total number of experts in the model. If None, no expert processing occurs.
- Returns:
The processed state dictionary with rewritten expert keys.
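One plausible form of this rewrite is shifting each rank-local expert index to its global index by adding the EP layer offset, so keys are unique across expert-parallel ranks. A hedged sketch (the function name and the exact rewrite rule are assumptions for illustration):

```python
import re

def rewrite_expert_keys_sketch(state_dict, offset):
    # Shift each local expert index by the EP offset so the keys become
    # globally unique across expert-parallel ranks (illustrative only).
    out = {}
    for key, value in state_dict.items():
        new_key = re.sub(
            r"local_experts\.(\d+)\.",
            lambda m: f"local_experts.{int(m.group(1)) + offset}.",
            key,
        )
        out[new_key] = value
    return out
```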
- core.transformer.fsdp_dtensor_checkpoint.expert_param_local_key(key: str, num_experts: int | None = None)#
Get the module parameter corresponding to the key.
- Parameters:
key – The parameter key to process.
num_experts – Total number of experts in the model. If None, no expert processing occurs.
- Returns:
The local parameter key with adjusted expert indices.
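Mapping in the other direction, a global expert index in a checkpoint key can be converted back to this rank's local index by subtracting the same offset. A minimal sketch under that assumption:

```python
import re

def expert_local_key_sketch(key, offset):
    # Map a global expert index back to this rank's local index
    # (the inverse of the checkpoint-side rewrite; illustrative only).
    return re.sub(
        r"local_experts\.(\d+)\.",
        lambda m: f"local_experts.{int(m.group(1)) - offset}.",
        key,
    )
```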
- core.transformer.fsdp_dtensor_checkpoint.handle_swiglu_in_state_dict(model, model_state_dict, optimizer_state_dict)#
Handle SWiGLU in model and optimizer state dicts.
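SWiGLU layers commonly fuse the gate and up projections into a single `linear_fc1` weight of shape `[2 * ffn_hidden, hidden]`, which checkpoint code may need to split or re-fuse. A sketch of the split, using plain lists of rows to stand in for the tensor (the real handler would operate on torch tensors, e.g. with `torch.chunk`):

```python
def split_swiglu_fc1_sketch(fused_rows):
    # fused_rows stands in for a [2 * ffn_hidden, hidden] fused weight;
    # the first half of the rows is the gate projection and the second
    # half the up projection (illustrative of what a SWiGLU-aware
    # checkpoint handler must account for).
    half = len(fused_rows) // 2
    return fused_rows[:half], fused_rows[half:]
```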
- core.transformer.fsdp_dtensor_checkpoint.handle_fp8_extra_state_case(model_state_dict)#
Handle the case where FP8 extra state is present in the model state dict.
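FP8-enabled Transformer Engine modules store scaling metadata under `_extra_state` keys, which are not ordinary tensors. One plausible handling, sketched here, is to separate those entries from the plain tensor entries; whether the real function drops, transforms, or passes them through is not specified by this docstring:

```python
def split_extra_state_sketch(state_dict):
    # Separate TE '_extra_state' entries (FP8 scaling metadata, etc.)
    # from plain tensor entries (illustrative only).
    tensors = {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}
    extra = {k: v for k, v in state_dict.items() if k.endswith("_extra_state")}
    return tensors, extra
```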
- core.transformer.fsdp_dtensor_checkpoint.flatten_state_dict(obj, parent_key='', sep='.')#
Recursively flattens a nested state dict into a single-level dict with keys joined by `sep`.
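A minimal sketch mirroring the documented signature, so the key-joining behavior is concrete (the real implementation may differ in details such as non-dict container handling):

```python
def flatten_sketch(obj, parent_key="", sep="."):
    # Collapse nested dicts into one level, joining path components
    # with `sep`, matching the documented parameters.
    flat = {}
    for k, v in obj.items():
        key = f"{parent_key}{sep}{k}" if parent_key else str(k)
        if isinstance(v, dict):
            flat.update(flatten_sketch(v, key, sep))
        else:
            flat[key] = v
    return flat
```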
- core.transformer.fsdp_dtensor_checkpoint.print_diff_in_state_dicts(state_dict_metadata, load_state_dict, limit=100)#
Print the differences between two state dicts: metadata state dict and load state dict. This function compares the keys and shapes of the tensors in both dicts.
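The comparison described above amounts to three checks: keys present only in the metadata, keys present only in the loaded dict, and shape mismatches on shared keys. A hedged sketch that returns the report lines instead of printing them (output wording is hypothetical):

```python
def diff_state_dicts_sketch(meta, loaded, limit=100):
    # Collect key-set differences plus shape mismatches on shared keys,
    # capped at `limit` lines, mirroring the documented comparison.
    lines = []
    for k in sorted(meta.keys() - loaded.keys()):
        lines.append(f"only in metadata: {k}")
    for k in sorted(loaded.keys() - meta.keys()):
        lines.append(f"only in load: {k}")
    for k in sorted(meta.keys() & loaded.keys()):
        ms = getattr(meta[k], "shape", None)
        ls = getattr(loaded[k], "shape", None)
        if ms != ls:
            lines.append(f"shape mismatch at {k}: {ms} vs {ls}")
    return lines[:limit]
```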
- core.transformer.fsdp_dtensor_checkpoint.validate_loaded_state_dict(state_dict, checkpoint_path)#
Validate the loaded state dict against the expected structure and types.
- core.transformer.fsdp_dtensor_checkpoint.get_global_unique_param_name(model_chunks, param)#
Get the global unique parameter name for a given model and parameter.
- Parameters:
model_chunks – List of model chunks to search for the parameter.
param – The parameter to find the name for.
- Returns:
The global unique parameter name.
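One way to implement this lookup is to scan each chunk's parameters for an identity match and qualify the name with the chunk index. The sketch below illustrates that idea; the actual naming scheme used by the module may differ:

```python
def global_param_name_sketch(model_chunks, param):
    # Find `param` by identity in any chunk's named_parameters and
    # qualify the name with the chunk index so it is unique across
    # chunks (illustrative naming scheme).
    for chunk_idx, chunk in enumerate(model_chunks):
        for name, p in chunk.named_parameters():
            if p is param:
                return f"chunk{chunk_idx}.{name}"
    return None
```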