core.tensor_parallel.inference_layers#
Module Contents#
Classes#
- InferenceLayerNormColumnParallelLinear: Inference optimized version of TELayerNormColumnParallelLinear.
- InferenceRowParallelLinear: Inference optimized version of TERowParallelLinear.
Functions#
- _te_rms_norm_kernel
API#
- core.tensor_parallel.inference_layers._te_rms_norm_kernel(x: torch.Tensor, weight: torch.Tensor, eps: float)#
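The actual kernel is a fused device implementation, but the math it computes is standard RMSNorm. A minimal NumPy sketch (the function name `rms_norm_reference` and the default `eps` are illustrative, not part of the API):

```python
import numpy as np

def rms_norm_reference(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Reference RMSNorm: divide by the root-mean-square over the last
    dimension (no mean subtraction, no bias), then scale by `weight`."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight

# Hidden size 4: with weight = 1, each output row has RMS close to 1.
x = np.array([[1.0, 2.0, 3.0, 4.0]])
y = rms_norm_reference(x, np.ones(4))
```

Unlike LayerNorm, RMSNorm skips the mean-centering step, which is why the fused kernel only needs `x`, `weight`, and `eps`.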
- class core.tensor_parallel.inference_layers.InferenceLayerNormColumnParallelLinear(
- input_size: int,
- output_size: int,
- *,
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- init_method: Callable,
- gather_output: bool,
- bias: bool,
- skip_bias_add: bool,
- is_expert: bool,
- skip_weight_param_allocation: bool = False,
- tp_comm_buffer_name: Optional[str] = None,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#
Bases:
megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear

Inference optimized version of TELayerNormColumnParallelLinear.
Initialization
- _all_gather(x: torch.Tensor) -> None#
Attempt an NVLS all-gather into symmetric memory; if that is not possible, fall back to a torch.distributed (NCCL) all-gather.
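The fast path needs NVLS-capable symmetric memory, so the method is structured as try-fast-path-then-fall-back. A pure-Python sketch of that control flow, with the hypothetical helpers `nvls_all_gather` and `nccl_all_gather` standing in for the real symmetric-memory and `torch.distributed` calls, and rank communication simulated by concatenation:

```python
import numpy as np

def nvls_all_gather(shard):
    # Hypothetical fast path: the real layer writes into NVLS symmetric
    # memory; here we just signal that it is unavailable.
    raise RuntimeError("symmetric memory not available")

def nccl_all_gather(shards):
    # Stand-in for torch.distributed.all_gather: every rank ends up with
    # the concatenation of all ranks' shards along dim 0.
    return np.concatenate(shards, axis=0)

def all_gather_with_fallback(shards):
    try:
        return [nvls_all_gather(s) for s in shards]
    except RuntimeError:
        # Fallback mirrors _all_gather: revert to the NCCL-style gather.
        full = nccl_all_gather(shards)
        return [full for _ in shards]

# Two "ranks", each holding a (1, 2) shard of a (2, 2) activation.
shards = [np.array([[0.0, 1.0]]), np.array([[2.0, 3.0]])]
gathered = all_gather_with_fallback(shards)
```

Either path leaves every rank with the same full tensor; only the transport differs.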
- forward(x: torch.Tensor) -> torch.Tensor#
Forward pass.
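In a column-parallel linear layer the weight is split along the output dimension, each rank computes its slice of the output, and (with `gather_output`) the slices are all-gathered back together. A NumPy sketch of that partitioning, with the all-gather shown as concatenation:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8))   # full activations, replicated on every rank
w = rng.standard_normal((8, 6))   # full weight, output dim = 6

# Column parallelism over 2 ranks: each rank owns 3 output columns.
w_shards = np.split(w, 2, axis=1)
partial_outputs = [x @ w_rank for w_rank in w_shards]   # (3, 3) on each rank

# gather_output=True: all-gather the column slices into the full (3, 6) output.
y = np.concatenate(partial_outputs, axis=1)
```

No reduction is needed on this path: each rank's columns are already exact, which is why column-parallel layers pair with an all-gather rather than a reduce.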
- class core.tensor_parallel.inference_layers.InferenceRowParallelLinear(
- input_size: int,
- output_size: int,
- *,
- config: megatron.core.model_parallel_config.ModelParallelConfig,
- init_method: Callable,
- bias: bool,
- input_is_parallel: bool,
- skip_bias_add: bool,
- is_expert: bool,
- tp_comm_buffer_name: Optional[str] = None,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#
Bases:
megatron.core.extensions.transformer_engine.TERowParallelLinear

Inference optimized version of TERowParallelLinear.
Initialization
- _matmul_reduce_scatter(x)#
Multiplies x by the weight matrix and performs a reduce-scatter. It first tries to write the matmul output to symmetric memory and perform an NVLS multicast reduce-scatter; if that is not possible, it falls back to a torch.distributed (NCCL) reduce-scatter.
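In the row-parallel case each rank holds a slice of the input (along the hidden dimension) and the matching rows of the weight, so each local matmul yields a partial sum of the full output; the reduce-scatter sums those partials and leaves each rank with one block of the result. A NumPy sketch of that arithmetic (communication simulated by `sum` and `split`):

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((4, 8))   # full input, hidden dim = 8
w = rng.standard_normal((8, 6))   # full weight

world_size = 2
x_shards = np.split(x, world_size, axis=1)   # input is parallel along hidden dim
w_shards = np.split(w, world_size, axis=0)   # weight is split along its rows

# Each rank's local matmul is a partial sum of the full (4, 6) output.
partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]

# Reduce-scatter: sum the partials (reduce), then each rank keeps one
# row-block of the summed result (scatter).
reduced = sum(partials)
scattered = np.split(reduced, world_size, axis=0)  # rank r keeps scattered[r]
```

Fusing the matmul output directly into symmetric memory lets the reduction start without a separate copy; the NCCL fallback performs the same reduce-scatter over regular buffers.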
- forward(x: torch.Tensor) -> torch.Tensor#
Forward pass.