Transformer Engine
1.11.0

Getting Started

  • Installation
    • Prerequisites
    • Transformer Engine in NGC Containers
    • pip - from PyPI
    • pip - from GitHub
      • Additional Prerequisites
      • Installation (stable release)
      • Installation (development build)
      • Installation (from source)
  • Getting Started
    • Overview
    • Let’s build a Transformer layer!
    • Meet Transformer Engine
    • Fused TE Modules
    • Enabling FP8

Python API documentation

  • Common API
    • Format
    • DelayedScaling
  • Framework-specific API
    • PyTorch
      • Linear
        • forward
        • set_tensor_parallel_group
      • GroupedLinear
        • forward
        • set_tensor_parallel_group
      • LayerNorm
      • RMSNorm
      • LayerNormLinear
        • forward
        • set_tensor_parallel_group
      • LayerNormMLP
        • forward
        • set_tensor_parallel_group
      • DotProductAttention
        • forward
        • set_context_parallel_group
      • MultiheadAttention
        • forward
        • set_context_parallel_group
        • set_tensor_parallel_group
      • TransformerLayer
        • forward
        • set_context_parallel_group
        • set_tensor_parallel_group
      • InferenceParams
      • CudaRNGStatesTracker
        • add
        • fork
        • get_states
        • reset
        • set_states
      • fp8_autocast
      • fp8_model_init
      • checkpoint
      • onnx_export
      • make_graphed_callables
      • get_cpu_offload_context
      • moe_permute
      • moe_unpermute
    • JAX
      • Pre-defined Variable of Logical Axes
      • Modules
        • TransformerLayerType
        • MeshResource
        • fp8_autocast
        • update_collections
        • LayerNorm
        • DenseGeneral
        • LayerNormDenseGeneral
        • LayerNormMLP
        • RelativePositionBiases
        • DotProductAttention
        • MultiHeadAttention
        • TransformerLayer
        • extend_logical_axis_rules
    • Paddle
      • Linear
        • forward
      • LayerNorm
      • LayerNormLinear
        • forward
      • LayerNormMLP
        • forward
      • FusedScaleMaskSoftmax
        • forward
      • DotProductAttention
        • forward
      • MultiHeadAttention
        • forward
      • TransformerLayer
        • forward
      • fp8_autocast
      • recompute

Examples and Tutorials

  • Using FP8 with Transformer Engine
    • Introduction to FP8
      • Structure
      • Mixed precision training - a quick introduction
      • Mixed precision training with FP8
    • Using FP8 with Transformer Engine
      • FP8 recipe
      • FP8 autocasting
      • Handling backward pass
      • Precision
  • Performance Optimizations
    • Multi-GPU training
    • Gradient accumulation fusion
    • FP8 weight caching
  • Accelerating Hugging Face Llama 2 and Llama 3 models with Transformer Engine
    • Dependencies for this tutorial
    • Table of contents
    • From “Transformer” to “Llama”
    • Hugging Face’s LlamaModel
      • Hugging Face’s LlamaDecoderLayer
        • Self_Attn Layer
        • MLP Layer
    • [Baseline] Running HF LlamaModel (Precision: BF16)
    • [Improvement 1] Replace HF’s LlamaDecoderLayer with TE’s TransformerLayer (Precision: BF16)
      • Transformer Engine’s TransformerLayer
      • TransformerLayer options explained
      • Mapping weights from HF’s LlamaDecoderLayer to TE’s TransformerLayer
    • [Improvement 2] Replace HF’s LlamaDecoderLayer with TE’s TransformerLayer (Precision: FP8)
      • How to run the model in FP8 precision
      • Llama 3 performance results
    • Conclusion

Advanced

  • C/C++ API
    • activation.h
      • enum class NVTE_Activation_Type
        • enumerator GELU
        • enumerator GEGLU
        • enumerator SILU
        • enumerator SWIGLU
        • enumerator RELU
        • enumerator REGLU
        • enumerator QGELU
        • enumerator QGEGLU
        • enumerator SRELU
        • enumerator SREGLU
      • void nvte_gelu
      • void nvte_silu
      • void nvte_relu
      • void nvte_qgelu
      • void nvte_srelu
      • void nvte_dgelu
      • void nvte_dsilu
      • void nvte_drelu
      • void nvte_dqgelu
      • void nvte_dsrelu
      • void nvte_geglu
      • void nvte_swiglu
      • void nvte_reglu
      • void nvte_qgeglu
      • void nvte_sreglu
      • void nvte_dgeglu
      • void nvte_dswiglu
      • void nvte_dreglu
      • void nvte_dqgeglu
      • void nvte_dsreglu
    • cast.h
      • void nvte_fp8_quantize
      • void nvte_fp8_dequantize
    • gemm.h
      • void nvte_cublas_gemm
      • void nvte_cublas_atomic_gemm
      • void nvte_multi_stream_cublas_gemm
      • namespace transformer_engine
        • constexpr int num_streams = 4
    • fused_attn.h
      • enum NVTE_QKV_Layout
        • enumerator NVTE_SB3HD
        • enumerator NVTE_SBH3D
        • enumerator NVTE_SBHD_SB2HD
        • enumerator NVTE_SBHD_SBH2D
        • enumerator NVTE_SBHD_SBHD_SBHD
        • enumerator NVTE_BS3HD
        • enumerator NVTE_BSH3D
        • enumerator NVTE_BSHD_BS2HD
        • enumerator NVTE_BSHD_BSH2D
        • enumerator NVTE_BSHD_BSHD_BSHD
        • enumerator NVTE_T3HD
        • enumerator NVTE_TH3D
        • enumerator NVTE_THD_T2HD
        • enumerator NVTE_THD_TH2D
        • enumerator NVTE_THD_THD_THD
      • enum NVTE_QKV_Layout_Group
        • enumerator NVTE_3HD
        • enumerator NVTE_H3D
        • enumerator NVTE_HD_2HD
        • enumerator NVTE_HD_H2D
        • enumerator NVTE_HD_HD_HD
      • enum NVTE_QKV_Format
        • enumerator NVTE_SBHD
        • enumerator NVTE_BSHD
        • enumerator NVTE_THD
      • enum NVTE_Bias_Type
        • enumerator NVTE_NO_BIAS
        • enumerator NVTE_PRE_SCALE_BIAS
        • enumerator NVTE_POST_SCALE_BIAS
        • enumerator NVTE_ALIBI
      • enum NVTE_Mask_Type
        • enumerator NVTE_NO_MASK
        • enumerator NVTE_PADDING_MASK
        • enumerator NVTE_CAUSAL_MASK
        • enumerator NVTE_PADDING_CAUSAL_MASK
        • enumerator NVTE_CAUSAL_BOTTOM_RIGHT_MASK
        • enumerator NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK
      • enum NVTE_Fused_Attn_Backend
        • enumerator NVTE_No_Backend
        • enumerator NVTE_F16_max512_seqlen
        • enumerator NVTE_F16_arbitrary_seqlen
        • enumerator NVTE_FP8
      • NVTE_QKV_Layout_Group nvte_get_qkv_layout_group
      • NVTE_QKV_Format nvte_get_qkv_format
      • NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend
      • void nvte_fused_attn_fwd_qkvpacked
      • void nvte_fused_attn_bwd_qkvpacked
      • void nvte_fused_attn_fwd_kvpacked
      • void nvte_fused_attn_bwd_kvpacked
      • void nvte_fused_attn_fwd
      • void nvte_fused_attn_bwd
    • layer_norm.h
      • void nvte_layernorm_fwd
      • void nvte_layernorm1p_fwd
      • void nvte_layernorm_bwd
      • void nvte_layernorm1p_bwd
    • rmsnorm.h
      • void nvte_rmsnorm_fwd
      • void nvte_rmsnorm1p_fwd
      • void nvte_rmsnorm_bwd
      • void nvte_rmsnorm1p_bwd
    • softmax.h
      • void nvte_scaled_softmax_forward
      • void nvte_scaled_softmax_backward
      • void nvte_scaled_masked_softmax_forward
      • void nvte_scaled_masked_softmax_backward
      • void nvte_scaled_upper_triang_masked_softmax_forward
      • void nvte_scaled_upper_triang_masked_softmax_backward
      • void nvte_scaled_aligned_causal_masked_softmax_forward
      • void nvte_scaled_aligned_causal_masked_softmax_backward
    • transformer_engine.h
      • typedef void *NVTETensor
      • enum NVTEDType
        • enumerator kNVTEByte
        • enumerator kNVTEInt32
        • enumerator kNVTEInt64
        • enumerator kNVTEFloat32
        • enumerator kNVTEFloat16
        • enumerator kNVTEBFloat16
        • enumerator kNVTEFloat8E4M3
        • enumerator kNVTEFloat8E5M2
        • enumerator kNVTENumTypes
      • NVTETensor nvte_create_tensor
      • void nvte_destroy_tensor
      • NVTEDType nvte_tensor_type
      • NVTEShape nvte_tensor_shape
      • void *nvte_tensor_data
      • float *nvte_tensor_amax
      • float *nvte_tensor_scale
      • float *nvte_tensor_scale_inv
      • void nvte_tensor_pack_create
      • void nvte_tensor_pack_destroy
      • struct NVTEShape
        • const size_t *data
        • size_t ndim
      • struct NVTETensorPack
        • NVTETensor tensors[MAX_SIZE]
        • size_t size = 0
        • static const int MAX_SIZE = 10
      • namespace transformer_engine
        • enum class DType
        • struct TensorWrapper
    • transpose.h
      • void nvte_cast_transpose
      • void nvte_transpose
      • void nvte_cast_transpose_dbias
      • void nvte_fp8_transpose_dbias
      • void nvte_multi_cast_transpose
      • void nvte_cast_transpose_dbias_dgelu
      • void nvte_cast_transpose_dbias_dsilu
      • void nvte_cast_transpose_dbias_drelu
      • void nvte_cast_transpose_dbias_dqgelu
      • void nvte_cast_transpose_dbias_dsrelu
      • void nvte_dgeglu_cast_transpose
      • void nvte_dswiglu_cast_transpose
      • void nvte_dreglu_cast_transpose
      • void nvte_dqgeglu_cast_transpose
      • void nvte_dsreglu_cast_transpose
  • Attention Is All You Need!
    • 1. Attention Backends
      • 1.1 Flash vs. Non-Flash
      • 1.2 flash-attention
      • 1.3 cuDNN Attention
    • 2. Backend Selection
      • 2.1 Debug Information
      • 2.2 User Control
      • 2.3 Example Tests
    • 3. Backend Support
      • 3.1 QKV Layout
      • 3.2 Attention Mask
      • 3.3 Attention Bias
      • 3.4 FP8 Attention


© Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.