Frequently Asked Questions (FAQ)
FP8 checkpoint compatibility
Transformer Engine added support for FP8 attention in version 1.6. It stores the FP8 metadata, i.e. scaling factors and amax histories, under a ._extra_state key in the checkpoint. As FP8 attention support has expanded from one backend to multiple backends, the location of the ._extra_state key has also shifted.
Here, we take the MultiheadAttention module as an example. Its FP8 attention metadata in Transformer Engine 1.11 is stored as core_attention._extra_state as shown below.
>>> import torch
>>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
>>> with fp8_model_init(enabled=True):
...     mha = MultiheadAttention(
...         hidden_size=1024,
...         num_attention_heads=16,
...         bias=True,
...         params_dtype=torch.bfloat16,
...         input_layernorm=False,
...         fuse_qkv_params=True,
...         attention_type="self",
...         qkv_weight_interleaved=True,
...     ).to(dtype=torch.bfloat16, device="cuda")
...
>>> state_dict = mha.state_dict()
>>> print(state_dict.keys())
odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state'])
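The FP8 attention metadata travels with the rest of the state dict through a regular save/load round trip. The snippet below is a minimal sketch; the in-memory buffer and the weights_only=False flag are illustrative choices, not requirements of Transformer Engine.
>>> import io
>>> # Save the full state dict, including the FP8 attention metadata under
>>> # core_attention._extra_state, and load it back into the same module.
>>> buffer = io.BytesIO()
>>> torch.save(mha.state_dict(), buffer)
>>> _ = buffer.seek(0)
>>> mha.load_state_dict(torch.load(buffer, weights_only=False))
<All keys matched successfully>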
The checkpoint save/load behavior differs across Transformer Engine versions, which fall into the following groups:

Version: <= 1.5
Version: 1.6, 1.7
Version: >= 1.8, <= 1.11
Version: >= 1.12
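If the version that saved a checkpoint stored the FP8 attention metadata under a different ._extra_state key than the version loading it expects, the key can be remapped in the state dict before calling load_state_dict. The sketch below is illustrative only: the old key name core_attention.fused_attention._extra_state and the checkpoint path are assumptions, so inspect the keys of your own checkpoint first.
>>> # Minimal sketch: remap the FP8 attention metadata key from an older
>>> # checkpoint layout to the one the current module expects.
>>> # The old key name and file path below are illustrative assumptions.
>>> old_key = "core_attention.fused_attention._extra_state"
>>> new_key = "core_attention._extra_state"
>>> ckpt = torch.load("mha_checkpoint_old.pt", weights_only=False)
>>> if old_key in ckpt:
...     ckpt[new_key] = ckpt.pop(old_key)
...
>>> mha.load_state_dict(ckpt)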