# Ring Attention Benchmark
Benchmarks PyTorch's `scaled_dot_product_attention` on a single GPU against distributed ring attention using `ShardTensor`.
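Ring attention splits the sequence across ranks. Each rank keeps its own query chunk while the key/value chunks are passed around a ring, and the partial results are merged with an online (streaming) softmax so that the final output matches full attention. The NumPy sketch below simulates this in a single process for one head with no masking. It only illustrates the technique under those assumptions. It is not the `ShardTensor` implementation, and `full_attention` and `ring_attention` are hypothetical names.

```python
import numpy as np


def full_attention(q, k, v):
    """Reference: softmax(q k^T / sqrt(d)) v over the whole sequence (one head, no mask)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v


def ring_attention(q, k, v, world_size):
    """Simulated ring attention: each 'rank' keeps one query chunk and
    visits every K/V chunk once as the chunks rotate around the ring."""
    q_chunks = np.split(q, world_size)
    kv_chunks = list(zip(np.split(k, world_size), np.split(v, world_size)))
    scale = 1.0 / np.sqrt(q.shape[-1])
    outputs = []
    for rank, q_blk in enumerate(q_chunks):
        # Running statistics for the online softmax.
        row_max = np.full((q_blk.shape[0], 1), -np.inf)
        row_sum = np.zeros((q_blk.shape[0], 1))
        acc = np.zeros((q_blk.shape[0], v.shape[-1]))
        for step in range(world_size):
            # If chunks move to rank+1 each step, rank r holds the chunk from rank (r - step) mod P.
            k_blk, v_blk = kv_chunks[(rank - step) % world_size]
            scores = q_blk @ k_blk.T * scale
            new_max = np.maximum(row_max, scores.max(axis=-1, keepdims=True))
            correction = np.exp(row_max - new_max)  # rescale earlier partial sums (0 on first step)
            p = np.exp(scores - new_max)
            row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
            acc = acc * correction + p @ v_blk
            row_max = new_max
        outputs.append(acc / row_sum)  # normalize once all K/V chunks are seen
    return np.concatenate(outputs)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((256, 64)) for _ in range(3))
print(np.allclose(ring_attention(q, k, v, world_size=4), full_attention(q, k, v)))  # True
```

Because of the rescaling by `correction`, no rank ever needs the full score matrix, yet the merged result agrees with full attention up to floating-point rounding.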
## Quick start
Single GPU:
```bash
python benchmark_sharded_attention.py \
    --seq_len 4096 --num_heads 16 --head_dim 64
```
Distributed (ring attention):
```bash
torchrun --nproc-per-node 4 benchmark_sharded_attention.py \
    --seq_len 4096 --num_heads 16 --head_dim 64
```
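In the distributed run the sequence is presumably split evenly across the `--nproc-per-node` ranks, so with the command above each of the 4 ranks would hold 4096 / 4 = 1024 tokens. The sketch below shows that arithmetic and checks the sequence-length constraint listed under Key options. It reads "chunk multiple of 32" as applying to the per-rank chunk length and assumes a `(batch, heads, seq, head_dim)` tensor layout. `per_rank_chunk` is a hypothetical helper, not part of the benchmark script.

```python
def per_rank_chunk(seq_len: int, world_size: int, multiple: int = 32) -> int:
    """Hypothetical helper: return the per-rank sequence chunk, enforcing
    (a) seq_len divisible by world_size and (b) chunk a multiple of `multiple`."""
    if seq_len % world_size != 0:
        raise ValueError(f"seq_len={seq_len} is not divisible by world_size={world_size}")
    chunk = seq_len // world_size
    if chunk % multiple != 0:
        raise ValueError(f"per-rank chunk {chunk} is not a multiple of {multiple}")
    return chunk


batch, num_heads, head_dim = 1, 16, 64  # defaults from the options table
for world_size in (1, 4):
    chunk = per_rank_chunk(4096, world_size)
    # Assumed (batch, heads, seq, head_dim) layout for each rank's local Q/K/V.
    print(world_size, chunk, (batch, num_heads, chunk, head_dim))
# 1 4096 (1, 16, 4096, 64)
# 4 1024 (1, 16, 1024, 64)

for bad in [(4096, 3), (1000, 4)]:  # not divisible; chunk 250 not a multiple of 32
    try:
        per_rank_chunk(*bad)
    except ValueError as err:
        print(err)
```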
## Key options
| Flag | Default | Description |
|---|---|---|
| `--seq_len` | 4096 | Sequence length; must be divisible by the world size, and each per-rank chunk must be a multiple of 32 |
| `--num_heads` | 16 | Number of attention heads |
| `--head_dim` | 64 | Dimension per head |
| | 1 | Batch size |
| | 5 | Warmup iterations |
| | 10 | Timed iterations |
| | (none) | Path to write JSON results |
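The warmup and timed-iteration options follow the usual benchmarking pattern. Untimed warmup runs absorb one-off costs, and the timed runs are then averaged. The sketch below is a minimal, framework-agnostic version of that pattern built around a hypothetical `time_fn` helper. On a GPU you would pass a synchronization callable such as `torch.cuda.synchronize` so the clock stops only after the asynchronous kernels finish. The benchmark script's own timing code may differ, for example by using CUDA events.

```python
import statistics
import time
from typing import Callable, Optional


def time_fn(fn: Callable[[], object], warmup: int = 5, iters: int = 10,
            sync: Optional[Callable[[], None]] = None) -> dict:
    """Hypothetical helper: run `fn` `warmup` times untimed, then `iters`
    timed runs; return latency statistics in milliseconds."""
    for _ in range(warmup):
        fn()  # warmup: absorbs one-off costs such as kernel selection or allocator growth
    if sync is not None:
        sync()  # ensure warmup work is done before the first timed run
    times_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for asynchronous work before stopping the clock
        times_ms.append((time.perf_counter() - start) * 1e3)
    return {
        "mean_ms": statistics.mean(times_ms),
        "std_ms": statistics.stdev(times_ms) if iters > 1 else 0.0,
        "iters": iters,
    }


# CPU-only usage; on a GPU: time_fn(step, sync=torch.cuda.synchronize)
stats = time_fn(lambda: sum(range(10_000)), warmup=5, iters=10)
print(sorted(stats))  # ['iters', 'mean_ms', 'std_ms']
```

Without the `sync` call, a GPU timing loop would mostly measure kernel-launch overhead, because CUDA launches return before the kernels complete.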
## Plotting results
After collecting JSON results in results/, generate scaling plots:
```bash
python plot_scaling_results.py
```
This reads every `results/*.json` file and writes one latency plot per mode (e.g., `ring_attention_shard_tensor_inference.png`).
The module also exposes helpers for custom analysis:
```python
import plot_scaling_results as psr

df = psr.load_results()                        # DataFrame, one row per JSON file
train = psr.filter(df, mode="train", gpus=4)   # filter by mode / GPUs / seq_len
df = psr.add_efficiency(df)                    # adds speedup & parallel_efficiency columns
print(psr.summary_table(df))                   # pivot table of mean latency (ms)
```
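The `speedup` and `parallel_efficiency` columns are conventionally defined relative to the single-GPU run: speedup(P) = latency(1 GPU) / latency(P GPUs), and efficiency(P) = speedup(P) / P. The pure-Python sketch below applies those textbook definitions. It is a hypothetical stand-in, not `psr.add_efficiency`'s actual code. The row keys `mode`, `seq_len` and `gpus` mirror the filter arguments above, while `latency_ms` is an assumed field name. The latencies are made-up inputs, not measured results.

```python
def add_efficiency_sketch(rows):
    """Hypothetical: attach speedup and parallel efficiency to each result row,
    using the 1-GPU row with the same (mode, seq_len) as the baseline."""
    baseline = {(r["mode"], r["seq_len"]): r["latency_ms"]
                for r in rows if r["gpus"] == 1}
    out = []
    for r in rows:
        base = baseline.get((r["mode"], r["seq_len"]))
        speedup = base / r["latency_ms"] if base is not None else None
        efficiency = speedup / r["gpus"] if speedup is not None else None
        out.append({**r, "speedup": speedup, "parallel_efficiency": efficiency})
    return out


# Made-up latencies purely to show the arithmetic (not benchmark output).
rows = [
    {"mode": "inference", "seq_len": 4096, "gpus": 1, "latency_ms": 100.0},
    {"mode": "inference", "seq_len": 4096, "gpus": 4, "latency_ms": 40.0},
]
for r in add_efficiency_sketch(rows):
    print(r["gpus"], r["speedup"], r["parallel_efficiency"])
# 1 1.0 1.0
# 4 2.5 0.625
```

Rows without a matching single-GPU baseline get `None` rather than a guessed value.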