This repository includes:

- Usage and benchmarks of `SymmetricMemory`-based multi-GPU algorithms in PyTorch.
- Examples and benchmarks of multi-GPU algorithms built with `SymmetricMemory` + Triton.
This script demonstrates the usage of `SymmetricMemory`-based NVLink all-reduce implementations and benchmarks their performance. The available variants are:

- `multimem_all_reduce` (PyTorch op, available in nightly)
- `one_shot_all_reduce` (PyTorch op, available in nightly)
- `two_shot_all_reduce` (PyTorch op, available in nightly)
- `triton_multimem_all_reduce` (Triton kernel defined in this repo)
- `triton_one_shot_all_reduce` (Triton kernel defined in this repo)
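The one-shot and two-shot variants differ mainly in communication pattern. As a rough algorithmic illustration only (plain Python, with each rank's symmetric-memory buffer simulated as a list; the real kernels read peer GPU buffers directly over NVLink, and the function names here are hypothetical):

```python
# Algorithmic sketch of one-shot vs. two-shot all-reduce.
# "buffers" simulates one symmetric-memory buffer per rank.

def one_shot_all_reduce(buffers):
    """One shot: every rank reads every peer's full buffer and
    reduces the whole result locally."""
    n = len(buffers[0])
    return [[sum(buf[i] for buf in buffers) for i in range(n)]
            for _ in buffers]

def two_shot_all_reduce(buffers):
    """Two shots: reduce-scatter, then all-gather."""
    world = len(buffers)
    n = len(buffers[0])
    chunk = n // world  # assumes n is divisible by world size
    # Shot 1 (reduce-scatter): rank r reduces only its own slice.
    slices = []
    for r in range(world):
        lo, hi = r * chunk, (r + 1) * chunk
        slices.append([sum(buf[i] for buf in buffers) for i in range(lo, hi)])
    # Shot 2 (all-gather): every rank collects the reduced slices.
    full = [x for s in slices for x in s]
    return [list(full) for _ in range(world)]
```

Both produce the same result; the trade-off is that one-shot does a single communication phase while two-shot moves less data per rank at large message sizes.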
Usage:

```shell
torchrun \
  --nnodes 1 --nproc-per-node 8 \
  --rdzv-backend c10d --rdzv-endpoint localhost:0 \
  --no_python python3 symm_mem_all_reduce.py --impl multimem_all_reduce
```
Some benchmarks on 8xH100 with NVSwitch:
*(Benchmark plots: all-reduce performance of each variant on 8xH100 with NVSwitch.)*
This is a fused all-gather matmul example using Triton + `SymmetricMemory`, based on the `tma_persistent` Triton tutorial with slight modifications. This example requires PyTorch nightly and Triton 3.0.0+ to run.
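Conceptually, the fused kernel computes `all_gather(A) @ B`: each rank owns a row-shard of `A` and a full copy of `B`, and the fused version overlaps gathering remote shards with computing output tiles rather than running the gather and the matmul back to back. A minimal unfused reference in plain Python (names hypothetical, small pure-Python matmul for illustration only):

```python
# Unfused reference for what the fused Triton kernel computes.

def matmul(a, b):
    """Naive dense matmul for small illustration matrices."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(inner))
             for j in range(cols)] for i in range(rows)]

def all_gather_matmul(a_shards, b):
    """Each rank contributes a row-shard of A; every rank ends up
    with the full product all_gather(A) @ B. The fused kernel
    overlaps the gather with tile computation instead."""
    a_full = [row for shard in a_shards for row in shard]
    return matmul(a_full, b)
```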
Usage:

```shell
torchrun \
  --nnodes 1 --nproc-per-node 8 \
  --rdzv-backend c10d --rdzv-endpoint localhost:0 \
  --no_python python3 triton_all_gather_matmul.py \
  --M 16384 --N 6656 --K 16384 --BLOCK_SIZE_M 128 --BLOCK_SIZE_N 256 --BLOCK_SIZE_K 64
```
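Like the `tma_persistent` tutorial it is based on, the kernel uses the persistent-kernel pattern: a fixed grid of programs loops over output tiles rather than launching one program per tile. The scheduling idea can be sketched in plain Python (function name and round-robin assignment are illustrative assumptions, not the kernel's exact logic):

```python
# Sketch of persistent-kernel tile scheduling: a fixed number of
# programs (roughly one per SM) each claims output tiles round-robin
# by flattened tile id.

def persistent_tile_schedule(m, n, block_m, block_n, num_programs):
    """Map output tiles of an (m x n) result to persistent programs."""
    tiles_m = (m + block_m - 1) // block_m  # tiles along rows
    tiles_n = (n + block_n - 1) // block_n  # tiles along columns
    schedule = {pid: [] for pid in range(num_programs)}
    for tile_id in range(tiles_m * tiles_n):
        schedule[tile_id % num_programs].append(
            (tile_id // tiles_n, tile_id % tiles_n))  # (tile row, tile col)
    return schedule
```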
Some benchmarks on 8xH100 (special version with HBM2e, at 650W) with NVSwitch:
| Problem Size (M) | Config¹ | cuBLAS MM Only (µs) | Triton MM Only (µs) | cuBLAS + NCCL (µs) | Triton Fused (µs) | Speedup |
|---|---|---|---|---|---|---|
| 4096 | 64,128,128,4 | 100 | 142 | 223 | 211 | 1.05x² |
| 8192 | 128,128,64,6 | 186 | 198 | 393 | 293 | 1.34x |
| 16384 | 128,256,64,3 | 363 | 363 | 748 | 485 | 1.54x |
| Problem Size (M) | Config¹ | cuBLAS MM Only (µs) | Triton MM Only (µs) | cuBLAS + NCCL (µs) | Triton Fused (µs) | Speedup |
|---|---|---|---|---|---|---|
| 4096 | 128,128,64,6 | 376 | 392 | 587 | 453 | 1.29x |
| 8192 | 128,256,64,3 | 746 | 706 | 1168 | 821 | 1.42x |
| 16384 | 128,256,64,3 | 1502 | 1403 | 2306 | 1566 | 1.47x |
| Problem Size (M) | Config¹ | cuBLAS MM Only (µs) | Triton MM Only (µs) | cuBLAS + NCCL (µs) | Triton Fused (µs) | Speedup |
|---|---|---|---|---|---|---|
| 4096 | 128,256,64,3 | 1358 | 1425 | 1858 | 1615 | 1.15x |
| 8192 | 128,256,64,3 | 2567 | 2656 | 3533 | 2907 | 1.22x |
| 16384 | 128,256,64,3 | 5249 | 5375 | 6982 | 5814 | 1.20x |
¹ Config refers to `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, `BLOCK_SIZE_K`, and `num_stages`.

² For this problem size, using multicast all-gather would be a more suitable optimization.
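The Speedup column appears to be the cuBLAS + NCCL latency divided by the Triton fused latency (e.g. 393 µs / 293 µs ≈ 1.34x for M=8192 in the first table). A quick consistency check against that table:

```python
def speedup(baseline_us, fused_us):
    """Speedup of the fused kernel over the unfused baseline."""
    return baseline_us / fused_us

# (cuBLAS + NCCL, Triton Fused, reported speedup) from the first table.
rows = [(223, 211, 1.05), (393, 293, 1.34), (748, 485, 1.54)]
for baseline, fused, reported in rows:
    assert abs(speedup(baseline, fused) - reported) < 0.01
```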