core.tensor_parallel.inference_layers#

Module Contents#

Classes#

InferenceLayerNormColumnParallelLinear

Inference optimized version of TELayerNormColumnParallelLinear.

InferenceRowParallelLinear

Inference optimized version of TERowParallelLinear.

Functions#

_te_rms_norm_kernel

RMS normalization kernel used by the inference layers.

API#

core.tensor_parallel.inference_layers._te_rms_norm_kernel(x: torch.Tensor, weight: torch.Tensor, eps: float)#

Applies RMS normalization to x, scaled by weight, using eps for numerical stability.
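The signature matches a standard RMS norm. As a point of reference (this is a plain-Python sketch of the math, not the fused Transformer Engine kernel, which operates on tensors):

```python
import math

def rms_norm(x, weight, eps):
    """Reference RMS norm: y_i = x_i / sqrt(mean(x^2) + eps) * w_i."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

print(rms_norm([3.0, 4.0], [1.0, 1.0], 0.0))  # ≈ [0.8485, 1.1314]
```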
class core.tensor_parallel.inference_layers.InferenceLayerNormColumnParallelLinear(
input_size: int,
output_size: int,
*,
config: megatron.core.transformer.transformer_config.TransformerConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear

Inference optimized version of TELayerNormColumnParallelLinear.

Initialization

_all_gather(x: torch.Tensor) → None#

Attempts an NVLS all-gather into symmetric memory; if that is not possible, falls back to a torch.distributed (NCCL) all-gather.
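The try-fast-path-then-fall-back control flow can be sketched as follows. The helper names are hypothetical stand-ins: the real method uses symmetric-memory buffers and torch.distributed collectives, which require an initialized process group and are omitted here.

```python
def all_gather_with_fallback(x, nvls_all_gather, nccl_all_gather):
    """Try the NVLS all-gather into symmetric memory; on failure,
    fall back to the torch.distributed (NCCL) collective.
    Both callables are hypothetical stand-ins for the real paths."""
    try:
        return nvls_all_gather(x)   # fast path: NVLS multicast gather
    except RuntimeError:
        return nccl_all_gather(x)   # fallback: plain NCCL all-gather

def nvls_unavailable(_):
    raise RuntimeError("symmetric memory unavailable")

# Simulated 2-rank gather: the NVLS path fails, so the NCCL path runs.
result = all_gather_with_fallback([1, 2], nvls_unavailable, lambda x: x * 2)
print(result)  # [1, 2, 1, 2]
```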

forward(x: torch.Tensor) → torch.Tensor#

Forward pass.

class core.tensor_parallel.inference_layers.InferenceRowParallelLinear(
input_size: int,
output_size: int,
*,
config: megatron.core.model_parallel_config.ModelParallelConfig,
init_method: Callable,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: megatron.core.extensions.transformer_engine.TERowParallelLinear

Inference optimized version of TERowParallelLinear.

Initialization

_matmul_reduce_scatter(x)#

Multiplies x by the weight matrix and performs a reduce-scatter. It first tries to write the matmul output to symmetric memory and perform an NVLS multicast reduce-scatter; if that is not possible, it falls back to a torch.distributed (NCCL) reduce-scatter.
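Whichever path is taken, the collective has reduce-scatter semantics: each rank contributes a full-length partial matmul output and receives only its own chunk of the elementwise sum. A plain-Python simulation (illustration only, not the distributed implementation):

```python
def reduce_scatter_sim(per_rank_outputs):
    """Simulate reduce-scatter across world_size ranks.
    per_rank_outputs[r] is rank r's full-length partial result;
    rank r ends up with chunk r of the elementwise sum."""
    world_size = len(per_rank_outputs)
    summed = [sum(vals) for vals in zip(*per_rank_outputs)]  # reduce step
    chunk = len(summed) // world_size
    return [summed[r * chunk:(r + 1) * chunk] for r in range(world_size)]  # scatter step

# Two ranks, each holding a length-4 partial matmul output.
out = reduce_scatter_sim([[1, 2, 3, 4], [10, 20, 30, 40]])
print(out)  # [[11, 22], [33, 44]]
```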

forward(x: torch.Tensor) → torch.Tensor#

Forward pass.