core.tensor_parallel.inference_layers#
Module Contents#
Classes#
- InferenceLayerNormColumnParallelLinear: Inference optimized version of TELayerNormColumnParallelLinear.
- InferenceRowParallelLinear: Inference optimized version of TERowParallelLinear.
Functions#
- _te_rms_norm_kernel
API#
- core.tensor_parallel.inference_layers._te_rms_norm_kernel(x: torch.Tensor, weight: torch.Tensor, eps: float)#
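This kernel has no docstring here; its name suggests a fused RMS-norm used by the inference layers below. A plain-PyTorch sketch of the computation such a kernel would perform (illustrative reference only, not the fused kernel itself):

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Root-mean-square normalization over the hidden dimension, followed by a
    # learned elementwise scale. Statistics are computed in fp32 for stability.
    x_f = x.float()
    normed = x_f * torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (weight * normed).to(x.dtype)
```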
- class core.tensor_parallel.inference_layers.InferenceLayerNormColumnParallelLinear(
- input_size: int,
- output_size: int,
- *,
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- init_method: Callable,
- gather_output: bool,
- bias: bool,
- skip_bias_add: bool,
- is_expert: bool,
- stride: int = 1,
- skip_weight_param_allocation: bool = False,
- tp_comm_buffer_name: Optional[str] = None,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
- )#
Bases:
megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear
Inference optimized version of TELayerNormColumnParallelLinear.
Initialization
- _maybe_allocate_symmetric_buffer(x: torch.Tensor)#
Attempt to allocate symmetric memory buffer for all-gather.
- _all_gather(x: torch.Tensor, symm_mem_buffer: dict) -> None#
Attempt an NVLS all-gather into symmetric memory. If that is not possible, fall back to a torch.distributed (NCCL) all-gather.
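A minimal sketch of the fallback structure described above. Only the NCCL path is spelled out; the NVLS symmetric-memory path depends on PyTorch-version-specific ops and is left as a comment (function name and shapes are illustrative assumptions):

```python
import torch
import torch.distributed as dist

def all_gather_with_fallback(x: torch.Tensor, tp_group, symm_mem_buffer=None):
    # NVLS path (assumption): if a symmetric-memory buffer was allocated, copy the
    # local shard into it and issue a multicast all-gather. The exact symmetric-memory
    # ops are PyTorch-version dependent, so only the NCCL fallback is shown here.
    world_size = dist.get_world_size(tp_group)
    out = torch.empty((world_size * x.shape[0], *x.shape[1:]),
                      dtype=x.dtype, device=x.device)
    # NCCL fallback: plain all-gather along the first (sequence) dimension.
    dist.all_gather_into_tensor(out, x.contiguous(), group=tp_group)
    return out
```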
- forward(x: torch.Tensor) -> torch.Tensor#
Forward pass.
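A hedged, unfused reference of the data flow this forward pass is expected to follow, assuming RMSNorm normalization and a sequence-parallel input shard; the real layer fuses these steps and prefers the symmetric-memory all-gather when available:

```python
import torch
import torch.distributed as dist

def ln_column_parallel_forward_reference(x, ln_weight, weight_shard, eps, tp_group):
    # Normalize the local [s/tp, b, h] input shard (RMSNorm assumed),
    # all-gather it to [s, b, h] across the tensor-parallel group, then apply
    # this rank's column shard of the weight to get [s, b, h_out/tp].
    x_f = x.float()
    normed = (ln_weight * x_f * torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + eps)).to(x.dtype)
    world_size = dist.get_world_size(tp_group)
    gathered = torch.empty((world_size * normed.shape[0], *normed.shape[1:]),
                           dtype=normed.dtype, device=normed.device)
    dist.all_gather_into_tensor(gathered, normed.contiguous(), group=tp_group)
    return torch.matmul(gathered, weight_shard.t())
```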
- class core.tensor_parallel.inference_layers.InferenceRowParallelLinear(
- input_size: int,
- output_size: int,
- *,
- config: megatron.core.model_parallel_config.ModelParallelConfig,
- init_method: Callable,
- bias: bool,
- input_is_parallel: bool,
- skip_bias_add: bool,
- is_expert: bool,
- tp_comm_buffer_name: Optional[str] = None,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
- )#
Bases:
megatron.core.extensions.transformer_engine.TERowParallelLinear
Inference optimized version of TERowParallelLinear.
Initialization
- _matmul_reduce_scatter(x, residual=None)#
Multiplies x by the weight matrix and performs a reduce-scatter. It first tries to write the matmul output to symmetric memory and perform an NVLS multicast reduce-scatter; if that is not possible, it falls back to a torch.distributed (NCCL) reduce-scatter.
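A sketch of the NCCL fallback path of this method; the optimized path would instead write the GEMM output into symmetric memory and issue an NVLS multicast reduce-scatter (function name and shapes are illustrative assumptions):

```python
import torch
import torch.distributed as dist

def matmul_reduce_scatter_fallback(x: torch.Tensor, weight_shard: torch.Tensor, tp_group):
    # The local GEMM over this rank's input shard produces a partial sum of the
    # full output; the reduce-scatter sums the partials across ranks and leaves
    # each rank with a 1/tp slice along the sequence dimension.
    partial = torch.matmul(x, weight_shard.t())
    world_size = dist.get_world_size(tp_group)
    out = torch.empty((partial.shape[0] // world_size, *partial.shape[1:]),
                      dtype=partial.dtype, device=partial.device)
    dist.reduce_scatter_tensor(out, partial.contiguous(), group=tp_group)
    return out
```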
- _set_next_layer_norm_weights(weights: torch.Tensor)#
Set next layer norm weights for fused reduce-scatter + add + rms-norm + all-gather.
- _set_residual(residual: torch.Tensor)#
Set residual for fused reduce-scatter + add + rms-norm + all-gather.
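The two setters above stage tensors for the fused reduce-scatter + add + rms-norm + all-gather epilogue. An unfused, plain-PyTorch reference of that sequence, shown only to make the fusion's semantics concrete (helper name, shapes, and return value are assumptions):

```python
import torch
import torch.distributed as dist

def fused_epilogue_reference(partial_out, residual_shard, next_ln_weight, eps, tp_group):
    # 1) Reduce-scatter the partial GEMM output across the tensor-parallel group.
    world_size = dist.get_world_size(tp_group)
    shard = torch.empty((partial_out.shape[0] // world_size, *partial_out.shape[1:]),
                        dtype=partial_out.dtype, device=partial_out.device)
    dist.reduce_scatter_tensor(shard, partial_out.contiguous(), group=tp_group)
    # 2) Add the (sequence-sharded) residual set via _set_residual.
    shard = shard + residual_shard
    # 3) Apply the next layer's RMS norm using the weights set via
    #    _set_next_layer_norm_weights.
    shard_f = shard.float()
    normed = (next_ln_weight * shard_f
              * torch.rsqrt(shard_f.pow(2).mean(-1, keepdim=True) + eps)).to(shard.dtype)
    # 4) All-gather the normalized shards back to the full sequence length.
    out = torch.empty_like(partial_out)
    dist.all_gather_into_tensor(out, normed.contiguous(), group=tp_group)
    return out
```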
- forward(
- x: torch.Tensor,
- residual: Optional[torch.Tensor] = None,
- )#
Forward pass.