Transformer Engine
0.7.0 -770e968
Version select:
Current release
Older releases
Home
Getting Started
Installation
Prerequisites
Transformer Engine in NGC Containers
pip - from GitHub
Additional Prerequisites
Installation (stable release)
Installation (development build)
Getting Started
Overview
Let’s build a Transformer layer!
Meet Transformer Engine
Fused TE Modules
Enabling FP8
Python API documentation
Common API
Format
DelayedScaling
Framework-specific API
pyTorch
Linear
forward
LayerNorm
LayerNormLinear
forward
LayerNormMLP
forward
DotProductAttention
forward
TransformerLayer
forward
fp8_autocast
checkpoint
Jax
MajorShardingType
ShardingType
TransformerLayerType
ShardingResource
LayerNorm
__call__
DenseGeneral
__call__
LayerNormDenseGeneral
__call__
LayerNormMLP
__call__
RelativePositionBiases
__call__
MultiHeadAttention
__call__
TransformerLayer
__call__
extend_logical_axis_rules
fp8_autocast
update_collections
update_fp8_metas
Examples and Tutorials
Using FP8 with Transformer Engine
Introduction to FP8
Structure
Mixed precision training - a quick introduction
Mixed precision training with FP8
Using FP8 with Transformer Engine
FP8 recipe
FP8 autocasting
Handling backward pass
Precision
Performance Optimizations
Multi-GPU training
Gradient accumulation fusion
FP8 weight caching
Advanced
C/C++ API
activation.h
void nvte_gelu
void nvte_geglu
void nvte_dgeglu
cast.h
void nvte_fp8_quantize
void nvte_fp8_dequantize
gemm.h
void nvte_cublas_gemm
layer_norm.h
void nvte_layernorm_fwd
void nvte_layernorm1p_fwd
void nvte_layernorm_bwd
void nvte_layernorm1p_bwd
softmax.h
void nvte_scaled_softmax_forward
void nvte_scaled_softmax_backward
void nvte_scaled_masked_softmax_forward
void nvte_scaled_masked_softmax_backward
void nvte_scaled_upper_triang_masked_softmax_forward
void nvte_scaled_upper_triang_masked_softmax_backward
transformer_engine.h
typedef void *NVTETensor
enum NVTEDType
enumerator kNVTEByte
enumerator kNVTEInt32
enumerator kNVTEFloat32
enumerator kNVTEFloat16
enumerator kNVTEBFloat16
enumerator kNVTEFloat8E4M3
enumerator kNVTEFloat8E5M2
enumerator kNVTENumTypes
NVTETensor nvte_create_tensor
void nvte_destroy_tensor
NVTEDType nvte_tensor_type
NVTEShape nvte_tensor_shape
void *nvte_tensor_data
float *nvte_tensor_amax
float *nvte_tensor_scale
float *nvte_tensor_scale_inv
struct NVTEShape
const size_t *data
size_t ndim
namespace transformer_engine
enum class DType
struct TensorWrapper
transpose.h
void nvte_cast_transpose
void nvte_transpose
void nvte_cast_transpose_dbias
void nvte_cast_transpose_dbias_dgelu
void nvte_multi_cast_transpose
void nvte_dgeglu_cast_transpose
Transformer Engine » Index
Index
_ | C | D | E | F | L | M | N | R | S | T | U
_
__call__() (transformer_engine.jax.DenseGeneral method)
(transformer_engine.jax.LayerNorm method)
(transformer_engine.jax.LayerNormDenseGeneral method)
(transformer_engine.jax.LayerNormMLP method)
(transformer_engine.jax.MultiHeadAttention method)
(transformer_engine.jax.RelativePositionBiases method)
(transformer_engine.jax.TransformerLayer method)
C
checkpoint() (in module transformer_engine.pytorch)
D
DelayedScaling (class in transformer_engine.common.recipe)
DenseGeneral (class in transformer_engine.jax)
DotProductAttention (class in transformer_engine.pytorch)
E
extend_logical_axis_rules() (in module transformer_engine.jax)
F
Format (class in transformer_engine.common.recipe)
forward() (transformer_engine.pytorch.DotProductAttention method)
(transformer_engine.pytorch.LayerNormLinear method)
(transformer_engine.pytorch.LayerNormMLP method)
(transformer_engine.pytorch.Linear method)
(transformer_engine.pytorch.TransformerLayer method)
fp8_autocast() (in module transformer_engine.jax)
(in module transformer_engine.pytorch)
L
LayerNorm (class in transformer_engine.jax)
(class in transformer_engine.pytorch)
LayerNormDenseGeneral (class in transformer_engine.jax)
LayerNormLinear (class in transformer_engine.pytorch)
LayerNormMLP (class in transformer_engine.jax)
(class in transformer_engine.pytorch)
Linear (class in transformer_engine.pytorch)
M
MajorShardingType (class in transformer_engine.jax)
MultiHeadAttention (class in transformer_engine.jax)
N
nvte_cast_transpose (C++ function)
nvte_cast_transpose_dbias (C++ function)
nvte_cast_transpose_dbias_dgelu (C++ function)
nvte_create_tensor (C++ function)
nvte_cublas_gemm (C++ function)
nvte_destroy_tensor (C++ function)
nvte_dgeglu (C++ function)
nvte_dgeglu_cast_transpose (C++ function)
nvte_fp8_dequantize (C++ function)
nvte_fp8_quantize (C++ function)
nvte_geglu (C++ function)
nvte_gelu (C++ function)
nvte_layernorm1p_bwd (C++ function)
nvte_layernorm1p_fwd (C++ function)
nvte_layernorm_bwd (C++ function)
nvte_layernorm_fwd (C++ function)
nvte_multi_cast_transpose (C++ function)
nvte_scaled_masked_softmax_backward (C++ function)
nvte_scaled_masked_softmax_forward (C++ function)
nvte_scaled_softmax_backward (C++ function)
nvte_scaled_softmax_forward (C++ function)
nvte_scaled_upper_triang_masked_softmax_backward (C++ function)
nvte_scaled_upper_triang_masked_softmax_forward (C++ function)
nvte_tensor_amax (C++ function)
nvte_tensor_data (C++ function)
nvte_tensor_scale (C++ function)
nvte_tensor_scale_inv (C++ function)
nvte_tensor_shape (C++ function)
nvte_tensor_type (C++ function)
nvte_transpose (C++ function)
NVTEDType (C++ enum)
NVTEDType::kNVTEBFloat16 (C++ enumerator)
NVTEDType::kNVTEByte (C++ enumerator)
NVTEDType::kNVTEFloat16 (C++ enumerator)
NVTEDType::kNVTEFloat32 (C++ enumerator)
NVTEDType::kNVTEFloat8E4M3 (C++ enumerator)
NVTEDType::kNVTEFloat8E5M2 (C++ enumerator)
NVTEDType::kNVTEInt32 (C++ enumerator)
NVTEDType::kNVTENumTypes (C++ enumerator)
NVTEShape (C++ struct)
NVTEShape::data (C++ member)
NVTEShape::ndim (C++ member)
NVTETensor (C++ type)
R
RelativePositionBiases (class in transformer_engine.jax)
S
ShardingResource (class in transformer_engine.jax)
ShardingType (class in transformer_engine.jax)
T
transformer_engine (C++ type)
transformer_engine::DType (C++ enum)
transformer_engine::DType::kBFloat16 (C++ enumerator)
transformer_engine::DType::kByte (C++ enumerator)
transformer_engine::DType::kFloat16 (C++ enumerator)
transformer_engine::DType::kFloat32 (C++ enumerator)
transformer_engine::DType::kFloat8E4M3 (C++ enumerator)
transformer_engine::DType::kFloat8E5M2 (C++ enumerator)
transformer_engine::DType::kInt32 (C++ enumerator)
transformer_engine::DType::kNumTypes (C++ enumerator)
transformer_engine::TensorWrapper (C++ struct)
transformer_engine::TensorWrapper::amax (C++ function)
transformer_engine::TensorWrapper::data (C++ function)
transformer_engine::TensorWrapper::dptr (C++ function)
transformer_engine::TensorWrapper::dtype (C++ function)
transformer_engine::TensorWrapper::operator= (C++ function), [1]
transformer_engine::TensorWrapper::scale (C++ function)
transformer_engine::TensorWrapper::scale_inv (C++ function)
transformer_engine::TensorWrapper::shape (C++ function)
transformer_engine::TensorWrapper::tensor_ (C++ member)
transformer_engine::TensorWrapper::TensorWrapper (C++ function), [1], [2], [3], [4]
transformer_engine::TensorWrapper::~TensorWrapper (C++ function)
TransformerLayer (class in transformer_engine.jax)
(class in transformer_engine.pytorch)
TransformerLayerType (class in transformer_engine.jax)
U
update_collections() (in module transformer_engine.jax)
update_fp8_metas() (in module transformer_engine.jax)