Transformer Engine
0.7.0 - 770e968
Framework-specific API
pyTorch
    Linear
        forward
    LayerNorm
    LayerNormLinear
        forward
    LayerNormMLP
        forward
    DotProductAttention
        forward
    TransformerLayer
        forward
    fp8_autocast
    checkpoint
Jax
    MajorShardingType
    ShardingType
    TransformerLayerType
    ShardingResource
    LayerNorm
        __call__
    DenseGeneral
        __call__
    LayerNormDenseGeneral
        __call__
    LayerNormMLP
        __call__
    RelativePositionBiases
        __call__
    MultiHeadAttention
        __call__
    TransformerLayer
        __call__
    extend_logical_axis_rules
    fp8_autocast
    update_collections
    update_fp8_metas