Frequently Asked Questions (FAQ)

FP8 checkpoint compatibility

Transformer Engine started supporting FP8 attention in version 1.6. It stores the FP8 metadata, i.e. scaling factors and amax histories, under a ._extra_state key in the checkpoint. As FP8 attention support expanded from one backend to multiple backends, the location of the ._extra_state key also shifted.

Here, we take the MultiheadAttention module as an example. In Transformer Engine 1.11, its FP8 attention metadata is stored under core_attention._extra_state, as shown below.

>>> import torch
>>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
>>> with fp8_model_init(enabled=True):
...     mha = MultiheadAttention(
...         hidden_size=1024,
...         num_attention_heads=16,
...         bias=True,
...         params_dtype=torch.bfloat16,
...         input_layernorm=False,
...         fuse_qkv_params=True,
...         attention_type="self",
...         qkv_weight_interleaved=True,
...     ).to(dtype=torch.bfloat16, device="cuda")
>>> state_dict = mha.state_dict()
>>> print(state_dict.keys())
odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state'])
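Because the key location differs between releases, it can help to check which layout a checkpoint uses before loading it. The sketch below is a hypothetical helper (fp8_attention_key_layout is not part of Transformer Engine). It only inspects key names, and it assumes the attention metadata sits under a submodule named core_attention, as in the MultiheadAttention example above.

```python
OLD_SUFFIX = "core_attention.fused_attention._extra_state"  # written by 1.6, 1.7
NEW_SUFFIX = "core_attention._extra_state"                  # written by >= 1.8


def _has_key(keys, suffix):
    # Match the full key, or a key ending in ".<suffix>" (a nested module).
    return any(k == suffix or k.endswith("." + suffix) for k in keys)


def fp8_attention_key_layout(keys):
    """Report which FP8 attention key layout a state dict uses (hypothetical helper)."""
    keys = list(keys)
    if _has_key(keys, OLD_SUFFIX):
        return "1.6/1.7"
    if _has_key(keys, NEW_SUFFIX):
        return ">= 1.8"
    # <= 1.5 saves no FP8 attention metadata, so no match is expected there.
    return "none found"
```

For the state dict printed above, fp8_attention_key_layout(state_dict.keys()) returns ">= 1.8".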

The full list of checkpoint save/load behaviors across all Transformer Engine versions is given below.

Version: <= 1.5

  • Saves no FP8 metadata since FP8 attention is not supported

  • Loading behavior for checkpoints created by the following versions:

    <= 1.5:

    Loads no FP8 metadata

    > 1.5:

    Error: unexpected key

Version: 1.6, 1.7

  • Saves FP8 metadata to core_attention.fused_attention._extra_state

  • Loading behavior for checkpoints created by the following versions:

    <= 1.5:

    Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes

    1.6, 1.7:

    Loads FP8 metadata from checkpoint

    >= 1.8:

    Error: unexpected key

Version: >= 1.8, <= 1.11

  • Saves FP8 metadata to core_attention._extra_state

  • Loading behavior for checkpoints created by the following versions:

    <= 1.5:

    Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes

    1.6, 1.7:

    This save/load combination relies on the user to map the 1.6/1.7 key to the 1.8-1.11 key. Without that mapping, FP8 metadata is initialized to the default, i.e. 1s for scaling factors and 0s for amaxes. In this MultiheadAttention example, the mapping can be done by:

    >>> state_dict["core_attention._extra_state"] = \
    ...     state_dict["core_attention.fused_attention._extra_state"]
    >>> del state_dict["core_attention.fused_attention._extra_state"]

    >= 1.8:

    Loads FP8 metadata from checkpoint

Version: >= 1.12

  • Saves FP8 metadata to core_attention._extra_state

  • Loading behavior for checkpoints created by the following versions:

    <= 1.5:

    Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes

    >= 1.6:

    Loads FP8 metadata from checkpoint
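The manual mapping for loading 1.6/1.7 checkpoints into 1.8-1.11 covers a single MultiheadAttention. When attention modules are nested inside a larger model, the same key appears under a longer prefix, possibly several times. The following is a minimal sketch of a hypothetical helper, remap_fp8_attention_keys (not a Transformer Engine API), that applies the rename to every matching key by suffix. Per the list above, it should only be needed on 1.8-1.11, since 1.12 and later load 1.6/1.7 checkpoints directly.

```python
from collections import OrderedDict

# Key locations taken from the version list above.
OLD_SUFFIX = "core_attention.fused_attention._extra_state"  # saved by 1.6, 1.7
NEW_SUFFIX = "core_attention._extra_state"                  # expected by 1.8-1.11


def remap_fp8_attention_keys(state_dict):
    """Return a new OrderedDict in which every 1.6/1.7 FP8 attention key is
    renamed to the 1.8-1.11 location. Values and key order are kept."""
    remapped = OrderedDict()
    for key, value in state_dict.items():
        if key == OLD_SUFFIX or key.endswith("." + OLD_SUFFIX):
            # Keep the module prefix (e.g. "layers.0.self_attention.") and
            # swap only the trailing part of the key.
            new_key = key[: len(key) - len(OLD_SUFFIX)] + NEW_SUFFIX
            if new_key in state_dict:
                raise KeyError(f"{key!r} and {new_key!r} are both present")
            remapped[new_key] = value
        else:
            remapped[key] = value
    # PyTorch state dicts carry per-module metadata in a _metadata attribute;
    # copy it over so load_state_dict sees the same information.
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        remapped._metadata = metadata
    return remapped
```

Usage would look like mha.load_state_dict(remap_fp8_attention_keys(old_state_dict)), where old_state_dict stands for a state dict saved by 1.6 or 1.7. Returning a new dict rather than editing in place leaves the original checkpoint data untouched.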