torch.compile support, bug fixes & more

@lw released this 25 Jul 11:59
· 115 commits to main since this release

Pre-built binary wheels require PyTorch 2.4.0

Added

  • fMHA: PagedBlockDiagonalGappyKeysMask
  • fMHA: heterogeneous queries in triton_splitk
  • fMHA: support for paged attention in flash
  • fMHA: Added backwards pass for merge_attentions
  • fMHA: Added torch.compile support for 3 biases (LowerTriangularMask, LowerTriangularMaskWithTensorBias and BlockDiagonalMask) - some might require PyTorch 2.4
  • fMHA: Added torch.compile support in memory_efficient_attention when passing the flash operator explicitly (e.g. memory_efficient_attention(..., op=(flash.FwOp, flash.BwOp)))
  • fMHA: memory_efficient_attention now expects its attn_bias argument to be on the same device as the other input tensors. Previously, it would move the bias to the right device.
  • fMHA: AttentionBias subclasses are now constructed by default on the CUDA device if available; they used to be created on the CPU device
  • 2:4 sparsity: Added xformers.ops.sp24.sparsify24_ste for Straight Through Estimator (STE) with options to rescale the gradient differently for masked out/kept values
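
To illustrate the idea behind the 2:4 sparsity entry, here is a minimal pure-Python sketch of a Straight Through Estimator over a 2:4 pattern. It is not the xformers implementation: the helper names `sparsify24` and `ste_backward`, the parameters `scale_kept` and `scale_masked`, and the selection rule (keep the 2 largest-magnitude values in each group of 4) are all assumptions made for illustration. The actual kernel may select values differently.

```python
from typing import List, Tuple


def sparsify24(values: List[float]) -> Tuple[List[float], List[bool]]:
    """Forward pass: keep 2 of every 4 consecutive values, zero the rest.

    Assumed selection rule: largest magnitude wins, earlier index wins ties.
    """
    if len(values) % 4 != 0:
        raise ValueError("length must be a multiple of 4")
    mask = [False] * len(values)
    for start in range(0, len(values), 4):
        group = range(start, start + 4)
        # sorted() is stable, so equal magnitudes keep their index order.
        keep = sorted(group, key=lambda i: -abs(values[i]))[:2]
        for i in keep:
            mask[i] = True
    sparse = [v if m else 0.0 for v, m in zip(values, mask)]
    return sparse, mask


def ste_backward(
    grad_out: List[float],
    mask: List[bool],
    scale_kept: float = 1.0,    # hypothetical name: gradient scale for kept values
    scale_masked: float = 1.0,  # hypothetical name: gradient scale for masked-out values
) -> List[float]:
    """Backward pass: let the gradient flow to every input as if the op were
    the identity, optionally rescaled depending on whether the value was kept."""
    return [g * (scale_kept if m else scale_masked) for g, m in zip(grad_out, mask)]


x = [0.1, -2.0, 0.5, 3.0, 1.0, 1.0, -1.0, 0.2]
sparse, mask = sparsify24(x)
grad_in = ste_backward([1.0] * len(x), mask, scale_kept=1.0, scale_masked=0.5)
```

With both scales left at 1.0 this reduces to the plain STE, where masked-out values receive the same gradient as kept ones. Lowering `scale_masked` dampens updates to values that were pruned in the forward pass.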

Improved

  • fMHA: Fixed out-of-bounds reading in the Split-K Triton implementation
  • Profiler: Fixed a bug with modules that take a single tuple as argument
  • Profiler: Added manual trigger for a profiling step, by creating a trigger file in the profiling directory

Removed

  • Removed support for PyTorch versions older than 2.2.0
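
The two version constraints in these notes (2.2.0 as the oldest supported PyTorch, 2.4.0 for the pre-built wheels) can be checked with a small helper before installing. This is only a sketch: `parse_version`, `classify_torch_version` and the returned labels are hypothetical names, and the version string would typically come from `torch.__version__`.

```python
import re
from typing import Tuple

MIN_SUPPORTED = (2, 2, 0)  # older PyTorch versions are no longer supported
WHEEL_TARGET = (2, 4, 0)   # version required by the pre-built binary wheels


def parse_version(version: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from strings such as '2.4.0+cu121'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def classify_torch_version(version: str) -> str:
    """Return a label describing how a PyTorch version relates to this release."""
    parsed = parse_version(version)
    if parsed < MIN_SUPPORTED:
        return "unsupported"
    if parsed == WHEEL_TARGET:
        return "matches-wheels"
    # Assumption: supported, but not the version the pre-built wheels require.
    return "supported-no-wheel"


status = classify_torch_version("2.4.0+cu121")
```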