Important

You are viewing the NeMo 2.0 documentation. This release introduces significant changes to the API and a new library, NeMo Run. We are currently porting all features from NeMo 1.0 to 2.0. For documentation on previous versions or features not yet available in 2.0, please refer to the NeMo 24.07 documentation.

Context Parallelism

Context Parallelism (CP) is a method for parallelizing the processing of neural network activations across multiple GPUs by partitioning the network inputs and activations along the sequence dimension. Unlike Sequence Parallelism (SP), which partitions only the activations of specific layers (such as LayerNorm and Dropout), CP partitions the activations of all layers.

CP is critical for training long-context models because it distributes sequence activations across multiple GPUs, which lets the model handle sequences that would not fit in the memory of a single GPU. This reduces the per-GPU memory footprint and computational cost of processing long sequences.
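To make the sequence split concrete, the sketch below partitions a token sequence into `context_parallel_size` contiguous chunks, one per GPU, and estimates the size of a single `[sequence, hidden]` activation held by each rank. The helper names `partition_sequence` and `activation_bytes_per_gpu` are hypothetical, the 2-byte element size assumes BF16/FP16 activations, and the plain contiguous split is a simplification: the actual token-to-rank assignment in Megatron Core may differ.

```python
def partition_sequence(tokens, cp_size):
    """Hypothetical helper: split a sequence into cp_size contiguous chunks.

    Rank r of the context-parallel group would own chunk r.
    """
    if len(tokens) % cp_size != 0:
        raise ValueError("sequence length must be divisible by cp_size")
    chunk_len = len(tokens) // cp_size
    return [tokens[r * chunk_len:(r + 1) * chunk_len] for r in range(cp_size)]


def activation_bytes_per_gpu(seq_len, hidden_size, cp_size, bytes_per_elem=2):
    """Hypothetical helper: bytes for one [seq_len / cp_size, hidden_size] tensor per rank.

    bytes_per_elem=2 assumes BF16/FP16 activations.
    """
    return (seq_len // cp_size) * hidden_size * bytes_per_elem


print(partition_sequence(list(range(8)), 4))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(activation_bytes_per_gpu(32768, 4096, cp_size=1))  # 268435456 (256 MiB)
print(activation_bytes_per_gpu(32768, 4096, cp_size=4))  # 67108864 (64 MiB)
```

With a 32K-token sequence and a hidden size of 4096, one such tensor shrinks from 256 MiB on a single GPU to 64 MiB per GPU when the sequence is spread over four CP ranks.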

Enable Context Parallelism

To enable CP in NeMo Framework, set the context_parallel_size parameter of MegatronStrategy. This parameter specifies the number of GPUs across which each sequence's activations are distributed.

Set context_parallel_size to a value greater than 1 to enable context parallelism; the default value of 1 leaves it disabled.

from nemo import lightning as nl

# TP=2 x PP=2 x CP=2: the job needs a multiple of 8 GPUs.
strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=2,
    virtual_pipeline_model_parallel_size=None,
    context_parallel_size=2,  # Values greater than 1 enable Context Parallelism
    sequence_parallel=True,
    expert_model_parallel_size=1,
)
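Because CP ranks come from the same pool of GPUs as the other parallel dimensions, it can help to check the layout before launching. The sketch below assumes, as in Megatron Core, that the world size must be divisible by TP × PP × CP and that the remaining factor becomes the data-parallel size. The helper `data_parallel_size` is hypothetical and ignores expert parallelism.

```python
def data_parallel_size(world_size, tp, pp, cp):
    """Hypothetical helper: data-parallel size left after TP, PP, and CP.

    Assumes world_size = TP * PP * CP * DP; expert parallelism is ignored.
    """
    model_parallel = tp * pp * cp
    if world_size % model_parallel != 0:
        raise ValueError(
            f"world size {world_size} is not divisible by TP*PP*CP = {model_parallel}"
        )
    return world_size // model_parallel


# With TP=2, PP=2, and CP=2, GPUs must come in multiples of 8.
print(data_parallel_size(8, tp=2, pp=2, cp=2))   # 1
print(data_parallel_size(16, tp=2, pp=2, cp=2))  # 2
```

On a fixed number of GPUs, raising context_parallel_size therefore trades data-parallel replicas for longer sequences per sample.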

For more details on configuring model training settings, see NeMo 2.0 Pretraining.

Implement Context Parallelism

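NeMo Framework uses functionality from both Megatron Core and Transformer Engine to implement CP efficiently. Layers without inter-token operations, such as linear layers and LayerNorm, run unchanged on each GPU's chunk of the sequence. Attention is the exception: the query of each token must attend to the keys and values (KV) of all tokens in the sequence. During forward propagation, each GPU therefore collects the KV of the other chunks, but to save memory it stores only the KV of its own chunk. In the backward pass, the KV are gathered across GPUs again and the KV gradients are reduce-scattered. Under the hood, these all-gather and reduce-scatter operations are implemented as point-to-point communications in a ring topology. This approach significantly reduces the activation memory footprint while maintaining computational efficiency.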
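The ring exchange described above can be simulated in a single process. In the sketch below, each simulated rank keeps its own query chunk, attends to the KV chunk it currently holds, and then passes that chunk to its neighbor, so after `cp` steps every rank has seen every KV chunk. Partial results are merged with an online (running-max) softmax, a standard technique for blockwise attention. This is an illustrative sketch under stated assumptions, not the Megatron Core or Transformer Engine implementation: it covers only a non-causal forward pass, omits the backward reduce-scatter, and does not overlap communication with computation. All function names are hypothetical.

```python
import math


def attention_full(q, k, v):
    """Reference: unmasked softmax attention over the whole sequence."""
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(len(qi)) for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * vj[d] for wj, vj in zip(w, v)) / z for d in range(len(v[0]))])
    return out


def ring_attention(q_chunks, k_chunks, v_chunks):
    """Simulate cp ranks in a ring; rank r owns q_chunks[r] and starts with KV chunk r."""
    cp = len(q_chunks)
    dim_v = len(v_chunks[0][0])
    # Online-softmax state per query: (running max score, normalizer, weighted sum of V).
    state = [[(-math.inf, 0.0, [0.0] * dim_v) for _ in q_chunks[r]] for r in range(cp)]
    held = list(zip(k_chunks, v_chunks))  # held[r]: KV chunk currently on rank r
    for _ in range(cp):
        for r in range(cp):
            k_blk, v_blk = held[r]
            for i, qi in enumerate(q_chunks[r]):
                m, z, acc = state[r][i]
                for kj, vj in zip(k_blk, v_blk):
                    s = sum(a * b for a, b in zip(qi, kj)) / math.sqrt(len(qi))
                    m_new = max(m, s)
                    scale = math.exp(m - m_new)  # rescale earlier partial sums
                    p = math.exp(s - m_new)
                    z = z * scale + p
                    acc = [x * scale + p * y for x, y in zip(acc, vj)]
                    m = m_new
                state[r][i] = (m, z, acc)
        # One ring hop: rank r receives the KV chunk previously held by rank r - 1
        # (a point-to-point send/recv in a real multi-GPU implementation).
        held = [held[(r - 1) % cp] for r in range(cp)]
    return [[[x / z for x in acc] for _, z, acc in rank] for rank in state]


def split(x, cp):
    """Contiguous split along the sequence dimension."""
    n = len(x) // cp
    return [x[r * n:(r + 1) * n] for r in range(cp)]


seq, dim, cp = 8, 4, 4
q = [[math.sin(0.3 * i + d) for d in range(dim)] for i in range(seq)]
k = [[math.cos(0.7 * i - d) for d in range(dim)] for i in range(seq)]
v = [[0.1 * i + d for d in range(dim)] for i in range(seq)]

per_rank = ring_attention(split(q, cp), split(k, cp), split(v, cp))
ring_out = [row for rank in per_rank for row in rank]
ref_out = attention_full(q, k, v)
max_err = max(abs(a - b) for ra, rb in zip(ring_out, ref_out) for a, b in zip(ra, rb))
print(f"max abs difference vs. full attention: {max_err:.2e}")
```

Concatenating the per-rank outputs in rank order reproduces full-sequence attention up to floating-point rounding, even though no rank ever holds more than one KV chunk at a time.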

Visit our source code for more insights into the implementation: