fusions package
This package provides modules for commonly fused operations. Fusing operations improves compute efficiency by increasing the amount of work done each time a tensor is read from memory. To perform the fusion, modules in this package either rely on PyTorch's just-in-time compilation functionality (i.e. torch.jit.script in older PyTorch versions or torch.compile in recent versions), or call into custom kernels in external libraries such as Apex or TransformerEngine.
This module uses PyTorch JIT to fuse the bias add and dropout operations. Since dropout is not used during inference, different functions are used in train mode and in inference mode; a usage sketch follows the function list below.
- core.fusions.fused_bias_dropout.bias_dropout_add_fused_inference(x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float) → torch.Tensor
- core.fusions.fused_bias_dropout.bias_dropout_add_fused_train(x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float) → torch.Tensor
- core.fusions.fused_bias_dropout.bias_dropout_add_unfused(training)
- core.fusions.fused_bias_dropout.get_bias_dropout_add(training, fused)
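To illustrate what these helpers compute, here is a minimal sketch of a bias-dropout-add and a torch.jit.script-fused train-mode variant. The function names and the exact fusion structure are illustrative assumptions, not the module's internal definitions.

```python
from typing import Optional, Tuple

import torch


def _bias_dropout_add(x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]],
                      residual: torch.Tensor,
                      prob: float,
                      training: bool) -> torch.Tensor:
    # Add the bias (if present), apply dropout, then add the residual connection.
    x, bias = x_with_bias
    if bias is not None:
        x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    return residual + out


# Scripting the train-mode variant lets the JIT fuse the elementwise
# add + dropout + add into fewer kernel launches.
@torch.jit.script
def _bias_dropout_add_fused_train(
        x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]],
        residual: torch.Tensor,
        prob: float) -> torch.Tensor:
    return _bias_dropout_add(x_with_bias, residual, prob, True)
```

In practice the appropriate callable is obtained via get_bias_dropout_add(training, fused) and invoked as fn((x, bias), residual, prob), matching the signatures listed above.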
This module uses PyTorch JIT to fuse the bias add and GeLU nonlinearity operations.
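A minimal sketch of the same idea for bias + GeLU, assuming a torch.jit.script-based fusion. The function name is illustrative; the constants are those of the standard tanh approximation of GeLU.

```python
import torch


@torch.jit.script
def bias_gelu_fused(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Add the bias, then apply the tanh approximation of GeLU.
    # Scripting lets the JIT fuse these elementwise ops into fewer kernels.
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))
```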
This module provides a wrapper around various fused LayerNorm implementations in Apex; a usage sketch follows the class entry below.
- class core.fusions.fused_layer_norm.FusedLayerNorm(*args: Any, **kwargs: Any)
Bases:
torch.nn.Module
Layer Norm, fused into a single CUDA kernel.
- Parameters
hidden_size (int) – Transformer hidden dimension.
eps (float) – Epsilon added to denominator, for numerical stability.
persist_layer_norm (bool) – Use persistent fused layer norm kernel. This kernel supports only a set of hidden sizes. Please check persist_ln_hidden_sizes if your hidden size is supported.
sequence_parallel (bool) – Apply sequence parallelism optimization.
zero_centered_gamma (bool) – Adjust LayerNorm weights such that they are centered around zero. This improves numerical stability.
config (TransformerConfig) – Transformer config. Include to match custom layer norm interfaces.
normalization (str) – Normalization type, used for Transformer Engine. Must equal 'LayerNorm' here.
- forward(input: torch.Tensor) → torch.Tensor
- reset_parameters()
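A hedged usage sketch, assuming Apex is installed and that TransformerConfig is importable from core.transformer.transformer_config (the import paths are assumptions based on this page's module naming); exact constructor defaults may differ between versions.

```python
import torch

from core.fusions.fused_layer_norm import FusedLayerNorm
from core.transformer.transformer_config import TransformerConfig  # assumed path

config = TransformerConfig(num_layers=2, hidden_size=1024, num_attention_heads=16)

# Keyword names follow the parameter list above; persist_layer_norm only
# applies to the hidden sizes supported by the persistent kernel
# (see persist_ln_hidden_sizes).
norm = FusedLayerNorm(
    config=config,
    hidden_size=config.hidden_size,
    eps=1e-5,
    persist_layer_norm=True,
    zero_centered_gamma=False,
)

x = torch.randn(128, 2, config.hidden_size, device="cuda")
y = norm(x)  # same shape as the input
```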
This module provides wrappers around variations of Softmax in Apex.
- class core.fusions.fused_softmax.FusedScaleMaskSoftmax(*args: Any, **kwargs: Any)
Bases:
torch.nn.Module
Fused operation: scaling + mask + softmax.
- Parameters
input_in_fp16 – flag to indicate whether the input is in fp16 data format.
input_in_bf16 – flag to indicate whether the input is in bf16 data format.
attn_mask_type – attention mask type (pad or causal).
scaled_masked_softmax_fusion – flag to indicate whether the user wants to use softmax fusion.
mask_func – mask function to be applied.
softmax_in_fp32 – if true, softmax is performed at fp32 precision.
scale – scaling factor used in input tensor scaling.
- forward(input: torch.Tensor, mask: Optional[torch.Tensor])
Forward pass of softmax with masked input.
If attn_mask_type is causal, the mask is generated internally and None can be passed; a user-defined mask is only needed when attn_mask_type is not causal. A usage sketch for the class follows the method list below.
- forward_fused_softmax(input, mask)
- forward_torch_softmax(input, mask)
- static get_batch_per_block(sq, sk, b, np)
- is_kernel_available(mask, b, np, sq, sk)
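A usage sketch under stated assumptions: AttnMaskType is assumed to come from core.transformer.enums, and attention_mask_func is a user-supplied masking function written here for illustration; the constructor keywords follow the parameter list above.

```python
import torch

from core.fusions.fused_softmax import FusedScaleMaskSoftmax
from core.transformer.enums import AttnMaskType  # assumed path


def attention_mask_func(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Example mask function: set masked positions to a large negative value.
    return scores.masked_fill(mask, -10000.0)


softmax = FusedScaleMaskSoftmax(
    input_in_fp16=True,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.causal,
    scaled_masked_softmax_fusion=True,
    mask_func=attention_mask_func,
    softmax_in_fp32=True,
    scale=None,
)

# [batch, heads, query_len, key_len] attention scores in fp16.
scores = torch.randn(2, 16, 128, 128, device="cuda", dtype=torch.float16)

# With a causal mask type the mask can be None (it is generated internally).
probs = softmax(scores, None)
```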
- class core.fusions.fused_softmax.ScaledMaskedSoftmax(*args: Any, **kwargs: Any)
Bases:
torch.autograd.Function
Fused operation which performs the following three operations in sequence: 1. scale the tensor, 2. apply the mask, 3. perform softmax. A plain PyTorch reference sketch follows the method list below.
- static backward(ctx, output_grads)
- static forward(ctx, inputs, mask, scale)
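For reference, an unfused PyTorch equivalent of the forward computation, assuming mask is a boolean tensor that is True at positions to be masked out (the function name and the -10000.0 fill value are illustrative):

```python
import torch


def scaled_masked_softmax_reference(inputs: torch.Tensor,
                                    mask: torch.Tensor,
                                    scale: float) -> torch.Tensor:
    # 1. Scale the tensor, 2. apply the mask, 3. softmax over the key dimension.
    scores = inputs * scale
    scores = scores.masked_fill(mask, -10000.0)
    return torch.softmax(scores, dim=-1)
```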
- class core.fusions.fused_softmax.ScaledSoftmax(*args: Any, **kwargs: Any)
Bases:
torch.autograd.Function
Fused operation which performs the following two operations in sequence: 1. scale the tensor, 2. perform softmax. An unfused reference sketch follows the method list below.
- static backward(ctx, output_grads)
- static forward(ctx, inputs, scale)
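The unfused equivalent here is simply a scale followed by softmax (function name illustrative):

```python
import torch


def scaled_softmax_reference(inputs: torch.Tensor, scale: float) -> torch.Tensor:
    # 1. Scale the tensor, 2. softmax over the last (key) dimension.
    return torch.softmax(inputs * scale, dim=-1)
```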
- class core.fusions.fused_softmax.ScaledUpperTriangMaskedSoftmax(*args: Any, **kwargs: Any)
Bases:
torch.autograd.Function
Fused operation which performs the following three operations in sequence: 1. scale the tensor, 2. apply an upper triangular (causal) mask, typically used in GPT models, 3. perform softmax. An unfused reference sketch follows the method list below.
- static backward(ctx, output_grads)
- static forward(ctx, inputs, scale)
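And the corresponding unfused reference for the causal (upper-triangular) variant, again a sketch rather than the kernel's actual implementation (function name and fill value illustrative):

```python
import torch


def scaled_causal_softmax_reference(inputs: torch.Tensor, scale: float) -> torch.Tensor:
    # inputs: [..., query_len, key_len]; mask out keys that lie in the future.
    scores = inputs * scale
    sq, sk = scores.shape[-2], scores.shape[-1]
    causal_mask = torch.triu(
        torch.ones(sq, sk, dtype=torch.bool, device=scores.device), diagonal=1
    )
    scores = scores.masked_fill(causal_mask, -10000.0)
    return torch.softmax(scores, dim=-1)
```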