Release Notes – Release 2.10
Key Features and Enhancements
[PyTorch] Added support for the NVFP4 training recipe for the GroupedLinear module.
[PyTorch] Added support for CUDA graphs when using quantized weights with Tensor Parallelism.
[PyTorch] Added support for CUDA graphs when using delay_wgrad_compute.
[PyTorch] Expanded debug tools to support more statistics.
[PyTorch] Reduced the overhead of using debug tools.
[PyTorch] Added support for clamped SwiGLU in the TransformerLayer module.
[PyTorch] Added backwards compatibility for older Megatron-Core versions by introducing a keep_columnwise parameter to cast_master_weights_to_fp8 and related helper functions.
[PyTorch] Added a reset interface to make_graphed_callables that clears internal CUDA graphs before distributed process group cleanup, preventing hangs.
[PyTorch] Added support for FSDP2 with quantized weights.
[PyTorch] Added support for Sliding Window Attention (SWA) with Context Parallelism when using the THD input format.
[PyTorch] Integrated Flash Attention’s num_splits parameter into the attention backend.
[PyTorch] Made various improvements to mitigate CPU overhead, especially for the GroupedLinear module.
[C][PyTorch] Enabled RoPE (Rotary Position Embedding) application with position offsets during training, removing the previous restriction that start_positions could only be used with cp_size=1 (context parallelism disabled).
[Jax] Added options to disable Stochastic Rounding, Randomized Hadamard Transform, and 2D weight quantization in the NVFP4 training recipe.
[Jax] Improved performance by using Transformer Engine quantization when fused normalization or fused activation is disabled.
[Jax] Improved NVFP4 performance by using TE kernels for scaling factor swizzles.
[Jax] Added support for checkpointing quantization operations in JAX.
[Jax] Added support for sink attention.
[Jax] Added support for concurrent use of Data Parallelism (DP) and Fully-Sharded Data Parallelism (FSDP).
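To illustrate the RoPE item above, the following is a minimal pure-Python sketch of why position offsets matter once a sequence is split into chunks. It is not Transformer Engine code and does not reproduce the library's kernels or its context-parallel chunk layout; the function names, the interleaved channel pairing, and the base of 10000 are assumptions chosen for the illustration. It shows only that rotating each chunk with the right starting offset gives the same result as rotating the whole sequence at once.

```python
import math


def rope_angles(position, dim, base=10000.0):
    """Rotation angles for one token position, one angle per channel pair."""
    return [position / (base ** (2 * i / dim)) for i in range(dim // 2)]


def apply_rope(x, start_position=0, base=10000.0):
    """Rotate a sequence of vectors, numbering tokens from start_position.

    x is a list of tokens; each token is a list of floats of even length.
    Channel pairs (2i, 2i+1) are rotated together (interleaved layout,
    an assumption of this sketch).
    """
    out = []
    for t, vec in enumerate(x):
        rotated = []
        angles = rope_angles(start_position + t, len(vec), base)
        for i, theta in enumerate(angles):
            a, b = vec[2 * i], vec[2 * i + 1]
            c, s = math.cos(theta), math.sin(theta)
            rotated.extend([a * c - b * s, a * s + b * c])
        out.append(rotated)
    return out


# A sequence of 8 tokens, split into two contiguous chunks of 4 tokens each.
full = [[float(t + 1), 0.5, -1.0, 2.0] for t in range(8)]

whole = apply_rope(full)                         # one pass over all tokens
chunk0 = apply_rope(full[:4], start_position=0)  # first chunk starts at 0
chunk1 = apply_rope(full[4:], start_position=4)  # second chunk starts at 4

# True when the stitched chunks reproduce the single-pass result.
matches = all(
    math.isclose(a, b, abs_tol=1e-12)
    for tok_w, tok_s in zip(whole, chunk0 + chunk1)
    for a, b in zip(tok_w, tok_s)
)
```

Without the offset, the second chunk would be rotated as if it began at position 0, so its tokens would carry the wrong positional phase.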
Fixed Issues
Fixed an occasional crash when loading the cuDNN library at runtime.
[C] Fixed an out-of-bounds access in the NVFP4 dequantization kernel.
[C] Fixed a numerical error in the amax computation in normalization kernels.
[PyTorch] Fixed a crash in the permute kernel when using Triton v3.5.
[PyTorch] Fixed a numerical issue when using gradient accumulation fusion with FSDP.
[PyTorch] Fixed a crash when exporting modules via ONNX when using RMSNorm.
[Jax] Fixed a partitioning issue for the NVFP4 training recipe with a 1D mesh.
[Jax] Fixed a bug where the bias parameter could be added twice when using the unfused attention backend.
[Jax] Fixed a sharding bug in ring attention primitives with packed sequences, where segment position tensors were not sharded to match their corresponding segment ID tensors.
[PyTorch][Jax] Fixed various logical issues in the backend selection process for attention.
Known Issues in This Release
There are no known issues in this release.
Breaking Changes in This Release
[Jax] Default value for intermediate_dropout changed from 0.1 to 0.0.
[Jax] Default value for return_layernorm_output changed from True to False.
[Jax] Default activation changed from ReLU to GeLU.
[Jax] Default input type for DotProductAttention changed to BSHD.
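Code that relied on the earlier defaults may want to pass them explicitly after upgrading. The sketch below shows one way to do that under stated assumptions: it pins only the two parameters these notes name (intermediate_dropout and return_layernorm_output) and applies each one only to classes whose constructor accepts it. The helper pin_legacy_defaults and the class StandInLayer are hypothetical names introduced for illustration; they are not part of Transformer Engine, and which modules expose which parameter should be checked against the API documentation.

```python
import inspect
from dataclasses import dataclass

# Pre-2.10 defaults taken from the list above. Only the two parameters that
# the notes name explicitly are included.
LEGACY_DEFAULTS = {
    "intermediate_dropout": 0.1,
    "return_layernorm_output": True,
}


def pin_legacy_defaults(module_cls, **overrides):
    """Build kwargs restoring pre-2.10 defaults for parameters module_cls accepts."""
    accepted = inspect.signature(module_cls).parameters
    kwargs = {k: v for k, v in LEGACY_DEFAULTS.items() if k in accepted}
    kwargs.update(overrides)  # explicit caller choices win over pinned defaults
    return kwargs


@dataclass
class StandInLayer:
    """Hypothetical stand-in for a layer class; not a Transformer Engine module."""

    hidden_size: int = 512
    intermediate_dropout: float = 0.0  # mirrors the new default listed above


# Only intermediate_dropout is pinned, because StandInLayer has no
# return_layernorm_output parameter.
pinned = pin_legacy_defaults(StandInLayer)
layer = StandInLayer(**pinned)
```

Filtering by constructor signature keeps the helper from passing a parameter to a module that does not define it, which would otherwise raise a TypeError.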
Deprecated Features
There are no deprecated features in this release.