---
title: ToDo
emoji: ⚡️
app_file: app.py
sdk: gradio
sdk_version: 4.19.2
---

# ToDo: Token Downsampling for Efficient Generation of High-Resolution Images


We provide a Hugging Face Spaces demo for our recently proposed method, "ToDo: Token Downsampling for Efficient Generation of High-Resolution Images", and compare it against ToMe, a popular token merging method.

If you find our research helpful, please consider citing us:

```bibtex
@misc{smith2024todo,
      title={ToDo: Token Downsampling for Efficient Generation of High-Resolution Images},
      author={Ethan Smith and Nayan Saxena and Aninda Saha},
      year={2024},
      eprint={2402.13573},
      archivePrefix={arXiv}
}
```


Blog post: https://sweet-hall-e72.notion.site/ToDo-Token-Downsampling-for-Efficient-Generation-of-High-Resolution-Images-b41be1ac8ddc46be8cd687e67dee2d84?pvs=4

HF demo: https://huggingface.co/spaces/aningineer/ToDo

This work is heavily inspired by https://github.com/dbolya/tomesd by @dbolya; a big thanks to the original authors.

This project aims to address some of the shortcomings of Token Merging for Stable Diffusion, namely to achieve consistently faster inference without quality loss. With the original method, I found you had to use a high merging ratio to get any real speedup at all, and by then quality was noticeably degraded. Benchmarks here: dbolya/tomesd#19 (comment)

I propose two changes to the original to address this:

1. Merging Method
   - The original calculates a similarity matrix over the input tokens and merges those with the highest similarity.
   - The issue is that the similarity calculation takes O(n²) time. In ViTs, where token merging was first proposed, it only had to be done a few times, so it was quite efficient.
   - Here it needs to be done at every step, and the computation ends up nearly as costly as attention itself.
   - We can leverage a simple observation: nearby tokens tend to be similar to each other.
   - Therefore we can merge tokens via downsampling, which is very cheap and seems to be a good approximation.
   - This can be analogized to grid-based subsampling of an image with a nearest-neighbor downsample. It is similar to what DiNAT (dilated neighborhood attention) does, except that we still make use of global context.
2. Merge Targets
   - The original merges the input tokens to attention and then "unmerges" the resulting tokens back to the original size.
   - This operation seems to be quite lossy.
   - Instead, I propose simply downsampling the keys and values of the attention operation. Both the QK computation and the (QK)V product can still be drastically reduced from the typical O(n²) scaling of attention, without needing to unmerge anything.
   - Queries are left fully intact; they just attend more sparsely to the image.
   - Attention for images, especially at larger resolutions, seems to be very sparse in general (the QK matrix is low rank), so we do not appear to lose much from this. A minimal sketch of both changes follows this list.
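To make the two changes concrete, here is a minimal PyTorch sketch of the core idea (the actual implementation in this repo differs in details): tokens are merged by nearest-neighbor grid subsampling, and only the keys and values are downsampled while the queries stay at full resolution.

```python
import torch
import torch.nn.functional as F

def downsample_tokens(x: torch.Tensor, h: int, w: int, factor: int = 2) -> torch.Tensor:
    """Nearest-neighbor grid subsampling of flattened (B, N, C) image tokens."""
    b, n, c = x.shape
    assert n == h * w, "token count must match the spatial grid"
    x = x.reshape(b, h, w, c)[:, ::factor, ::factor, :]  # keep every `factor`-th row/col
    return x.reshape(b, -1, c)

def todo_attention(q, k, v, h, w, factor=2):
    """Self-attention with downsampled K/V and full-resolution Q (no unmerge step).

    q, k, v: (B, heads, N, head_dim) projections of the same flattened token grid.
    """
    b, heads, n, d = k.shape
    # flatten heads into the batch so the (B, N, C) helper applies per head
    k = downsample_tokens(k.flatten(0, 1), h, w, factor).unflatten(0, (b, heads))
    v = downsample_tokens(v.flatten(0, 1), h, w, factor).unflatten(0, (b, heads))
    # SDPA can dispatch to flash attention; all N query tokens are preserved,
    # they just attend to a sparser set of keys/values
    return F.scaled_dot_product_attention(q, k, v)

# example: a 96x96 token grid (e.g. a 1536px image at 16x total downsampling)
q = torch.randn(1, 8, 96 * 96, 64)
k, v = torch.randn_like(q), torch.randn_like(q)
out = todo_attention(q, k, v, h=96, w=96, factor=2)
print(out.shape)  # torch.Size([1, 8, 9216, 64]) -- full token count, like normal attention
```

In the method itself this replaces the token merging step inside the UNet's self-attention layers; cross-attention to the text tokens is left alone.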

Putting this all together, and in combination with flash attention, we get tangible speedups of ~1.5x at typical sizes like 768 to 1024, and up to 3x and beyond in the 1536 to 2048 range.
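As a rough illustration of where those numbers come from (this is not the paper's benchmark), a micro-benchmark like the one below compares full self-attention against attention with 2x2-downsampled keys/values. Both paths go through PyTorch's SDPA, which can dispatch to flash attention on supported GPUs; the shapes approximate the largest self-attention in SD1.5 at a 2048px output.

```python
import time
import torch
import torch.nn.functional as F

def bench(q, k, v, iters=20):
    for _ in range(3):  # warmup
        F.scaled_dot_product_attention(q, k, v)
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        F.scaled_dot_product_attention(q, k, v)
    torch.cuda.synchronize()
    return (time.time() - t0) / iters

# 2048px output -> 256x256 latent token grid at the highest-resolution attention
b, heads, n, d = 2, 8, 256 * 256, 40
q = torch.randn(b, heads, n, d, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

full = bench(q, k, v)
# keep every 4th key/value token -- same count as a 2x2 grid downsample
# (the exact selection pattern doesn't affect the timing)
merged = bench(q, k[:, :, ::4, :].contiguous(), v[:, :, ::4, :].contiguous())
print(f"full: {full * 1e3:.1f} ms | downsampled K/V: {merged * 1e3:.1f} ms | {full / merged:.2f}x")
```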

## Setup 🛠

```bash
pip install -r requirements.txt
```

## Inference 🚀

See the provided notebook, or the Gradio demo, which you can run with `python app.py`.
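If you want to experiment outside the provided code, here is a hypothetical sketch of wiring the same K/V downsampling into a diffusers pipeline through a custom attention processor. The class name `ToDoAttnProcessor` and the patch-every-layer approach are illustrative only (this repo patches attention its own way, and merging is best applied selectively at the high-resolution layers); it assumes square latents and no attention masks.

```python
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline

class ToDoAttnProcessor:  # hypothetical name, not part of this repo or diffusers
    def __init__(self, factor: int = 2):
        self.factor = factor

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, temb=None, **kwargs):
        b, n, _ = hidden_states.shape
        query = attn.to_q(hidden_states)
        kv = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        if encoder_hidden_states is None:
            # self-attention: grid-downsample the K/V source, assuming a square token grid
            h = w = int(n ** 0.5)
            kv = kv.reshape(b, h, w, -1)[:, ::self.factor, ::self.factor, :].reshape(b, -1, kv.shape[-1])
        key, value = attn.to_k(kv), attn.to_v(kv)
        head_dim = query.shape[-1] // attn.heads
        query, key, value = (t.view(b, -1, attn.heads, head_dim).transpose(1, 2)
                             for t in (query, key, value))
        out = F.scaled_dot_product_attention(query, key, value)
        out = out.transpose(1, 2).reshape(b, -1, attn.heads * head_dim)
        return attn.to_out[1](attn.to_out[0](out))  # output projection + dropout

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# crude hookup: patches every attention layer; the paper merges only where it is cheap to do so
pipe.unet.set_attn_processor(ToDoAttnProcessor(factor=2))
image = pipe("a photo of an astronaut riding a horse", height=1024, width=1024).images[0]
```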
