core.tensor_parallel.inference_layers#

Module Contents#

Classes#

InferenceLayerNormColumnParallelLinear

Inference optimized version of TELayerNormColumnParallelLinear.

InferenceRowParallelLinear

Inference optimized version of TERowParallelLinear.

Functions#

_te_rms_norm_kernel

RMS-norm kernel used by the inference-optimized layers.

API#

core.tensor_parallel.inference_layers._te_rms_norm_kernel(x: torch.Tensor, weight: torch.Tensor, eps: float)#
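No docstring is attached to this kernel; from the name and signature it implements an RMS normalization intended to match Transformer Engine's RMSNorm. A minimal PyTorch reference of the assumed semantics (the actual kernel is a fused/compiled variant):

```python
import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * weight, computed in fp32 for stability.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    normed = x.float() * torch.rsqrt(variance + eps)
    return (normed * weight.float()).to(x.dtype)
```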
class core.tensor_parallel.inference_layers.InferenceLayerNormColumnParallelLinear(
    input_size: int,
    output_size: int,
    *,
    config: megatron.core.transformer.transformer_config.TransformerConfig,
    init_method: Callable,
    gather_output: bool,
    bias: bool,
    skip_bias_add: bool,
    is_expert: bool,
    stride: int = 1,
    skip_weight_param_allocation: bool = False,
    tp_comm_buffer_name: Optional[str] = None,
    tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear

Inference optimized version of TELayerNormColumnParallelLinear.

Initialization
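A hedged construction sketch. The sizes and the `qkv_proj` name are illustrative, and it assumes Megatron's parallel state is already initialized and Transformer Engine is installed:

```python
import torch
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.tensor_parallel.inference_layers import (
    InferenceLayerNormColumnParallelLinear,
)

# Illustrative config values only.
config = TransformerConfig(num_layers=1, hidden_size=1024, num_attention_heads=8)

qkv_proj = InferenceLayerNormColumnParallelLinear(
    input_size=1024,
    output_size=3 * 1024,
    config=config,
    init_method=torch.nn.init.xavier_uniform_,
    gather_output=False,
    bias=False,
    skip_bias_add=True,
    is_expert=False,
)
```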

_maybe_allocate_symmetric_buffer(x: torch.Tensor)#

Attempt to allocate a symmetric-memory buffer for the all-gather.
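A sketch of the allocate-or-fall-back pattern using PyTorch's experimental `torch.distributed._symmetric_memory` API; module path and signatures vary across PyTorch releases, so treat this as illustrative rather than the layer's actual code:

```python
from typing import Optional

import torch
import torch.distributed as dist


def maybe_allocate_symmetric_buffer(
    x: torch.Tensor, group: dist.ProcessGroup
) -> Optional[dict]:
    """Try to allocate a peer-accessible buffer sized for the gathered output."""
    world_size = dist.get_world_size(group)
    try:
        # Experimental API; only present in recent PyTorch builds.
        import torch.distributed._symmetric_memory as symm_mem

        buf = symm_mem.empty(
            (world_size * x.shape[0], *x.shape[1:]),
            dtype=x.dtype,
            device=x.device,
        )
        handle = symm_mem.rendezvous(buf, group=group)  # exchange peer handles
        return {"buffer": buf, "handle": handle}
    except (ImportError, RuntimeError):
        # No symmetric memory (older PyTorch, no NVLink/NVLS support, ...).
        return None
```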

_all_gather(x: torch.Tensor, symm_mem_buffer: dict) → None#

Attempt an NVLS all-gather into symmetric memory; if that is not possible, fall back to a torch.distributed (NCCL) all-gather.
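In sketch form, using the buffer dict from the allocation sketch above. The NVLS branch is illustrative only (a real implementation synchronizes peers via the symmetric-memory handle); the fallback uses the standard NCCL collective:

```python
from typing import Optional

import torch
import torch.distributed as dist


def all_gather_with_fallback(
    x: torch.Tensor, symm_mem_buffer: Optional[dict], group: dist.ProcessGroup
) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    if symm_mem_buffer is not None:
        # Illustrative NVLS path: each rank deposits its shard at its offset;
        # with multicast-backed symmetric memory the write is visible to peers.
        buf = symm_mem_buffer["buffer"]
        rank = dist.get_rank(group)
        buf[rank * x.shape[0] : (rank + 1) * x.shape[0]].copy_(x)
        return buf
    # Fallback: standard NCCL all-gather.
    out = torch.empty(
        (world_size * x.shape[0], *x.shape[1:]), dtype=x.dtype, device=x.device
    )
    dist.all_gather_into_tensor(out, x, group=group)
    return out
```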

forward(x: torch.Tensor) → torch.Tensor#

Forward pass.
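Usage mirrors the parent class; a sketch assuming the `qkv_proj` instance constructed above:

```python
# x: per-rank activation of shape [num_tokens, hidden_size].
x = torch.randn(128, 1024, dtype=torch.bfloat16, device="cuda")
y = qkv_proj(x)  # fused layer norm + column-parallel GEMM
```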

class core.tensor_parallel.inference_layers.InferenceRowParallelLinear(
    input_size: int,
    output_size: int,
    *,
    config: megatron.core.model_parallel_config.ModelParallelConfig,
    init_method: Callable,
    bias: bool,
    input_is_parallel: bool,
    skip_bias_add: bool,
    is_expert: bool,
    tp_comm_buffer_name: Optional[str] = None,
    tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: megatron.core.extensions.transformer_engine.TERowParallelLinear

Inference optimized version of TERowParallelLinear.

Initialization

_matmul_reduce_scatter(x, residual=None)#

Multiplies x by the weight matrix and performs a reduce-scatter. It first tries to write the matmul output to symmetric memory and perform an NVLS multicast reduce-scatter; if that is not possible, it falls back to a torch.distributed (NCCL) reduce-scatter.
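Only the NCCL fallback branch is easy to sketch portably; the NVLS branch would instead write the GEMM output into symmetric memory and multicast-reduce it. A reference for the fallback semantics:

```python
import torch
import torch.distributed as dist


def matmul_reduce_scatter_fallback(
    x: torch.Tensor, weight: torch.Tensor, group: dist.ProcessGroup
) -> torch.Tensor:
    # Each TP rank holds an input shard [tokens, in_features/tp] and the
    # matching weight shard [out_features, in_features/tp]; the local GEMM
    # produces partial sums over the full output features.
    partial = x @ weight.t()  # [tokens, out_features], partial sums
    # Reduce-scatter sums partials across ranks and leaves each rank a
    # [tokens/tp, out_features] slice along the token dimension.
    tokens_per_rank = partial.shape[0] // dist.get_world_size(group)
    out = torch.empty(
        (tokens_per_rank, partial.shape[1]),
        dtype=partial.dtype,
        device=partial.device,
    )
    dist.reduce_scatter_tensor(out, partial, group=group)
    return out
```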

_set_next_layer_norm_weights(weights: torch.Tensor)#

Set the next layer's norm weights for the fused reduce-scatter + add + RMS-norm + all-gather path.

_set_residual(residual: torch.Tensor)#

Set the residual for the fused reduce-scatter + add + RMS-norm + all-gather path.
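Together, the two setters stage the state the fused epilogue consumes. The unfused sequence that the fusion replaces, in reference form (reusing `matmul_reduce_scatter_fallback` and `rms_norm_reference` from the sketches above; `residual_shard`, `next_ln_weight`, and `eps` are illustrative names):

```python
# Unfused reference for reduce-scatter + add + rms-norm + all-gather:
rs_out = matmul_reduce_scatter_fallback(x, weight, group)  # 1. reduce-scatter
hidden = rs_out + residual_shard                           # 2. add residual (_set_residual)
normed = rms_norm_reference(hidden, next_ln_weight, eps)   # 3. next layer's RMS-norm
full = torch.empty(                                        # 4. all-gather the result
    (normed.shape[0] * dist.get_world_size(group), normed.shape[1]),
    dtype=normed.dtype,
    device=normed.device,
)
dist.all_gather_into_tensor(full, normed, group=group)
```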

forward(
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) → torch.Tensor#

Forward pass.
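A hedged end-to-end usage sketch; sizes and names are illustrative, and parallel state is assumed to be initialized:

```python
import torch
from megatron.core.tensor_parallel.inference_layers import InferenceRowParallelLinear

proj = InferenceRowParallelLinear(
    input_size=4096,
    output_size=1024,
    config=config,  # the TransformerConfig from the earlier sketch
    init_method=torch.nn.init.xavier_uniform_,
    bias=False,
    input_is_parallel=True,
    skip_bias_add=True,
    is_expert=False,
)

# Optionally stage the fused RMS-norm epilogue before calling forward:
proj._set_next_layer_norm_weights(next_ln_weight)

out = proj(x, residual=residual)  # residual is added after the reduce-scatter
```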