Commit
[feat] cutlass FlashAttention bias+dropout support (facebookresearch#587)
* [feat] cutlass FlashAttention bias+dropout support

Adds attention bias (including the bias gradient) and dropout support to the CUTLASS FlashAttention implementation. Benchmark rows below are labeled (batch, seq_len, num_heads, head_dim, dtype, bias shape, bias requires_grad, dropout p).

[-------------------------------------------- attn --------------------------------------------]
                                                                  |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)   |  12.7  |   7.5
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)   |  15.5  |   9.1
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)    |  12.7  |   7.6
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)    |  15.6  |   9.1
(8, 512, 64, 128, torch.float16, None, False, 0.0)              |  10.1  |   6.0
(8, 512, 64, 128, torch.float16, None, False, 0.5)              |  12.7  |   7.5
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |  44.3  |  29.1
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |  55.0  |  35.1
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |  45.1  |  29.4
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |  55.6  |  35.3
(8, 1024, 64, 128, torch.float16, None, False, 0.0)               |  37.0  |  22.6
(8, 1024, 64, 128, torch.float16, None, False, 0.5)               |  46.8  |  29.0

Times are in milliseconds (ms).

[------------------------------------------ attn-bwd ------------------------------------------]
                                                                  |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)   |  19.3  |   24.1
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)   |  19.4  |   24.6
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)    |  22.3  |   28.7
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)    |  22.4  |   29.0
(8, 512, 64, 128, torch.float16, None, False, 0.0)              |  19.5  |   22.7
(8, 512, 64, 128, torch.float16, None, False, 0.5)              |  19.5  |   23.4
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |  62.7  |   91.1
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |  63.4  |   93.7
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |  74.8  |  109.8
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |  75.1  |  111.1
(8, 1024, 64, 128, torch.float16, None, False, 0.0)               |  63.2  |   85.5
(8, 1024, 64, 128, torch.float16, None, False, 0.5)               |  64.0  |   90.1
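For context, a minimal usage sketch of the new feature (not code from this commit): it drives the bias and dropout arguments through xformers' public memory_efficient_attention API. Tensor layouts mirror the benchmark labels above; the (B, M, H, K) interpretation and the exact bias layout the kernel accepts are assumptions here.

```python
# Hedged usage sketch: exercising attn_bias + dropout through the public API.
# Shapes follow the benchmark labels above: (B, M, H, K) = (8, 512, 64, 128),
# additive bias of shape (B * H, M, M) = (512, 512, 512).
import torch
import xformers.ops as xops

B, M, H, K = 8, 512, 64, 128
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16, requires_grad=True)

# requires_grad=True on the bias exercises the new bias-gradient path
# (the True/False column in the benchmark rows above).
bias = torch.randn(B * H, M, M, device="cuda", dtype=torch.float16,
                   requires_grad=True)

# p=0.5 matches the dropout probability used in the benchmark rows.
out = xops.memory_efficient_attention(q, k, v, attn_bias=bias, p=0.5)
out.sum().backward()  # populates bias.grad alongside the q/k/v gradients
```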
* benchmark fixes

* add more conditions to reduce dOi @ Vj to 2 stages (a sketch of where this product appears in the backward pass is at the end of this message)

BEFORE

[------------------------------------------ attn-bwd ------------------------------------------]
                                                                  |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
(8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.0)    |   2.8  |   2.4
(8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.5)    |   2.8  |   3.3
(8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.0)     |   3.4  |   3.2
(8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.5)     |   3.4  |   4.2
(8, 512, 64, 64, torch.float16, None, False, 0.0)               |   2.8  |   2.0
(8, 512, 64, 64, torch.float16, None, False, 0.5)               |   2.8  |   2.9
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)   |   3.6  |   3.9
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)   |   3.6  |   4.8
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)    |   4.2  |   4.8
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)    |   4.2  |   5.6
(8, 512, 64, 128, torch.float16, None, False, 0.0)              |   3.6  |   3.4
(8, 512, 64, 128, torch.float16, None, False, 0.5)              |   3.6  |   4.4
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.0)   |   9.7  |   8.8
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.5)   |   9.7  |  12.6
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.0)    |  12.0  |  12.1
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.5)    |  12.1  |  16.1
(8, 1024, 64, 64, torch.float16, None, False, 0.0)                |   9.7  |   7.4
(8, 1024, 64, 64, torch.float16, None, False, 0.5)                |   9.7  |  10.8
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |  11.3  |  14.0
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |  11.3  |  17.4
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |  13.6  |  17.8
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |  13.6  |  20.9
(8, 1024, 64, 128, torch.float16, None, False, 0.0)               |  11.3  |  12.1
(8, 1024, 64, 128, torch.float16, None, False, 0.5)               |  11.3  |  15.8

AFTER

[------------------------------------------ attn-bwd ------------------------------------------]
                                                                  |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
(8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.0)    |   2.8  |   2.4
(8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.5)    |   2.8  |   3.0
(8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.0)     |   3.4  |   3.2
(8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.5)     |   3.4  |   3.8
(8, 512, 64, 64, torch.float16, None, False, 0.0)               |   2.8  |   2.0
(8, 512, 64, 64, torch.float16, None, False, 0.5)               |   2.8  |   2.6
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)   |   3.6  |   3.9
(8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)   |   3.6  |   4.8
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)    |   4.2  |   4.8
(8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)    |   4.2  |   5.6
(8, 512, 64, 128, torch.float16, None, False, 0.0)              |   3.6  |   3.4
(8, 512, 64, 128, torch.float16, None, False, 0.5)              |   3.6  |   4.4
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.0)   |   9.7  |   8.8
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.5)   |   9.7  |  11.4
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.0)    |  12.0  |  12.1
(8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.5)    |  12.1  |  14.6
(8, 1024, 64, 64, torch.float16, None, False, 0.0)                |   9.7  |   7.4
(8, 1024, 64, 64, torch.float16, None, False, 0.5)                |   9.7  |   9.6
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |  11.3  |  14.1
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |  11.3  |  17.4
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |  13.6  |  17.8
(8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |  13.6  |  20.9
(8, 1024, 64, 128, torch.float16, None, False, 0.0)               |  11.3  |  12.1
(8, 1024, 64, 128, torch.float16, None, False, 0.5)               |  11.3  |  15.8

* fix mypy error

* fix windows build

* rename cutlass rand uniform file name

* black reformat
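Context for the "dOi @ Vj" item above, as a minimal sketch rather than kernel code from this commit: in the FlashAttention backward pass, each (query tile i, key/value tile j) pair recomputes the attention probabilities and needs the product dP_ij = dO_i @ V_j^T before the softmax gradient; the change widens the set of configurations where the CUTLASS kernel can run that GEMM with a 2-stage pipeline. The function and tile names below are illustrative only.

```python
# Plain-PyTorch sketch of one (i, j) tile of the attention backward,
# showing where the dOi @ Vj GEMM sits in the computation.
import torch

def softmax_bwd_tile(dO_i, V_j, P_ij, D_i):
    """Compute the softmax-gradient tile dS_ij.

    dO_i : (Br, K)  gradient w.r.t. the output rows of query tile i
    V_j  : (Bc, K)  value tile j
    P_ij : (Br, Bc) recomputed attention probabilities for the tile pair
    D_i  : (Br, 1)  precomputed rowsum(dO_i * O_i)

    From dS_ij the remaining gradients follow:
    dQ_i += dS_ij @ K_j, dK_j += dS_ij.T @ Q_i, and dV_j += P_ij.T @ dO_i.
    """
    dP_ij = dO_i @ V_j.T          # the dOi @ Vj product named in the bullet
    dS_ij = P_ij * (dP_ij - D_i)  # softmax Jacobian applied row-wise
    return dS_ij
```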