fMHA: Make dropout/bias configurable at build-time
# Performance on A100 (FW)

```
[------------------------------ attention (attn_bias=<class 'NoneType'>) -----------------------------]
                              | disable_bias_dropout |     main |    eager | pr587_d1b0fa
1 threads: --------------------------------------------------------------------------------------------
f16 B=1, M=4096, H=160, K=128 |              14929.6 |  14853.7 |  21510.6 |      15245.8
f32 B=1, M=4096, H=160, K=128 |              56159.3 |  56691.3 |  91407.5 |      57963.3
f16 B=2, M=4096, H=160, K=128 |              29784.5 |  29633.5 |  43562.2 |      30427.9
f32 B=2, M=4096, H=160, K=128 |             112125.7 | 113237.0 |          |     115784.8
f16 B=1, M=8192, H=160, K=128 |              59600.3 |  59240.6 |          |      60820.3
f32 B=1, M=8192, H=160, K=128 |             224734.5 | 226792.5 |          |     232806.2
f16 B=2, M=8192, H=160, K=128 |             119145.7 | 118414.6 |          |     121599.1
f32 B=2, M=8192, H=160, K=128 |             449501.3 | 453321.1 |          |     466094.3
f16 B=1, M=4096, H=16, K=40   |                863.6 |    870.6 |   1958.3 |        857.0
f32 B=1, M=4096, H=16, K=40   |               2771.4 |   2765.6 |   6914.2 |       2866.1
f16 B=1, M=16384, H=16, K=40  |              12251.9 |  12346.3 |  30460.6 |      12159.2
f32 B=1, M=16384, H=16, K=40  |              40643.3 |  40533.2 | 123282.2 |      41947.4
f16 B=256, M=4096, H=16, K=64 |             182110.1 | 183488.9 |          |     181204.2
f32 B=256, M=4096, H=16, K=64 |             644969.5 | 642960.3 |          |     665520.1
f16 B=64, M=1024, H=16, K=16  |               2577.2 |   2593.0 |   7293.8 |       2611.3
f32 B=64, M=1024, H=16, K=16  |               8440.7 |   8424.4 |  24237.0 |       8722.5
f16 B=64, M=1024, H=16, K=32  |               2596.5 |   2612.0 |   7675.3 |       2634.5
f32 B=64, M=1024, H=16, K=32  |               8513.5 |   8494.5 |  26228.7 |       8811.5
f16 B=64, M=1024, H=16, K=64  |               2972.8 |   2991.8 |   8669.3 |       2993.2
f32 B=64, M=1024, H=16, K=64  |              10304.0 |  10271.7 |  30612.5 |      10667.6
f16 B=64, M=1024, H=16, K=128 |               5282.0 |   5266.6 |  10313.9 |       5398.9
f32 B=64, M=1024, H=16, K=128 |              22730.2 |  22940.7 |  39341.6 |      23517.9
f16 B=64, M=1024, H=16, K=256 |              12337.4 |  12354.8 |  13558.9 |      12272.2
f32 B=64, M=1024, H=16, K=256 |              45885.8 |  46488.2 |  71460.8 |      48531.7
```

ghstack-source-id: dd55058ad3e1715515475029c8cd360be43d374d
Pull Request resolved: https://github.com/fairinternal/xformers/pull/441

__original_commit__ = fairinternal/xformers@a264f71c0de5c5474e6ec63b7a3b5d9e4a2ec98c
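
As a reading aid for the benchmark table above, the sketch below computes each column's time relative to `main` for three rows copied from the table. The names `ROWS`, `COLUMNS`, and `relative_to_main` are illustrative only and are not part of xformers. The commit message does not state the timing unit, so the ratios are unitless. Assuming the values are run times, a ratio above 1.0 means slower than `main`.

```python
# Illustrative helper, not part of xformers: compare benchmark columns to `main`.
# Values are copied from the table in the commit message; None marks a blank cell.
COLUMNS = ("disable_bias_dropout", "main", "eager", "pr587_d1b0fa")

ROWS = {
    "f16 B=1, M=4096, H=160, K=128": (14929.6, 14853.7, 21510.6, 15245.8),
    "f32 B=2, M=4096, H=160, K=128": (112125.7, 113237.0, None, 115784.8),
    "f16 B=1, M=4096, H=16, K=40": (863.6, 870.6, 1958.3, 857.0),
}


def relative_to_main(row):
    """Return {column name: time / main time}, keeping None for blank cells."""
    main = row[COLUMNS.index("main")]
    return {
        name: (None if value is None else value / main)
        for name, value in zip(COLUMNS, row)
    }


if __name__ == "__main__":
    for label, row in ROWS.items():
        ratios = relative_to_main(row)
        cells = ", ".join(
            f"{name}={'n/a' if ratio is None else format(ratio, '.3f')}"
            for name, ratio in ratios.items()
        )
        print(f"{label}: {cells}")
```

Normalizing by `main` makes rows with very different absolute timings comparable at a glance. Blank `eager` cells are carried through as `None` rather than being treated as zero.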