
Accelerate SDXL VAE using NATTEN local neighbourhood attention #3

Merged
merged 4 commits into main on Oct 30, 2023

Conversation

Birch-san (Owner) commented Oct 28, 2023

Accelerate SDXL VAE using NATTEN local neighbourhood attention

Thanks @crowsonkb for the idea!

Install natten from source like so:

git clone https://github.com/SHI-Labs/NATTEN.git
cd NATTEN
pip install cmake==3.20.3
CUDACXX=/usr/local/cuda/bin/nvcc make CUDA_ARCH="8.9" WORKERS=2

or install the latest stable wheel from pip:

pip install natten -f https://shi-labs.com/natten/wheels/cu121/torch2.1.0/index.html

Input image:
(image: in)

Output image after VAE round-trip (global self-attention):
(image: out expected)

Output image after VAE round-trip (local neighbourhood attention, kernel size 17):
(image: out outproj17fused)
This looks identical to global self-attention, whilst requiring far less memory and compute.

Output image after VAE round-trip (local neighbourhood attention, kernel size 3):
(image: out natten3)
This looks nearly identical to global self-attention, requiring even less memory and compute.
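NATTEN implements this as fused 2D neighbourhood attention in CUDA; as a rough illustration of the core idea (not NATTEN's actual implementation), here's a pure-Python 1D sketch where each query attends only to its `kernel_size` nearest tokens, with the window clamped at the edges so every query still sees exactly `kernel_size` keys:

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def neighbourhood_attention_1d(q, k, v, kernel_size):
    # q, k, v: lists of n vectors of dimension d (a single attention head).
    # Each query token i attends only to the kernel_size tokens nearest to i;
    # near the sequence edges the window is shifted inward rather than shrunk.
    n, d = len(q), len(q[0])
    scale = d ** -0.5
    half = kernel_size // 2
    out = []
    for i, qi in enumerate(q):
        lo = max(0, min(i - half, n - kernel_size))
        window = range(lo, lo + kernel_size)
        w = softmax([sum(a * b for a, b in zip(qi, k[j])) * scale for j in window])
        out.append([sum(wj * v[j][c] for wj, j in zip(w, window)) for c in range(d)])
    return out
```

With `kernel_size` equal to the sequence length this reduces exactly to global self-attention, which is consistent with kernel size 17 looking identical to global attention above; cost scales with `n * kernel_size` instead of `n * n`.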

Null attention

It looks like there's not actually much similarity-based mixing of information between tokens, so what happens if we just drop scaled dot product attention altogether?

So instead of:
softmax(Q•K.mT*scale)•V•O

we just do:
V•O

Output image after VAE round-trip (null attention):
(image: out null)

It's not identical to global self-attention, but it's pretty close, and far cheaper and more scalable (in compute and in memory) than global self-attention.
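To make the comparison concrete, here's a pure-Python single-head sketch (illustrative, not the repo's code) of the two variants: full global attention softmax(Q•Kᵀ•scale)•V followed by the output projection O, versus the null-attention shortcut that keeps only V•O:

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def matmul(a, b):
    # (n x k) @ (k x m) for plain lists of lists
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def global_attention(q, k, v, o):
    # softmax(Q.Kt * scale) . V . O  -- full global self-attention
    scale = len(q[0]) ** -0.5
    scores = [[sum(a * b for a, b in zip(qi, kj)) * scale for kj in k] for qi in q]
    weights = [softmax(row) for row in scores]
    return matmul(matmul(weights, v), o)

def null_attention(q, k, v, o):
    # Drop the similarity-based mixing entirely: just V . O.
    # Q and K go unused; each token keeps its own value row.
    return matmul(v, o)
```

Null attention is linear in sequence length and needs no attention matrix at all, which is where the memory and compute savings come from.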

@Birch-san Birch-san changed the title Neighbourhood vae Accelerate SDXL VAE using NATTEN local neighbourhood attention Oct 28, 2023
@Birch-san Birch-san merged commit 5ad5bdd into main Oct 30, 2023
Birch-san (Owner, Author) commented Aug 5, 2024

Okay, so installing this again was a wild ride. Maybe it got harder in the latest version of cmake (v3.30.2)? Maybe specifically because FindCUDA got removed?

CMake Warning (dev) at CMakeLists.txt:11 (find_package):
  Policy CMP0146 is not set: The FindCUDA module is removed.  Run "cmake
  --help-policy CMP0146" for policy details.  Use the cmake_policy command to
  set the policy and suppress this warning.

This warning is for project developers.  Use -Wno-dev to suppress it.

Anyway, cmake was failing the build on pretty much the very first line, project(natten LANGUAGES CXX CUDA). It would say CMAKE_CUDA_ARCHITECTURES must be non-empty if set.
This error message is misleading, and probably just means nvcc compilation failed.

To get a look at the true error, I modified setup.py to pass in a CMAKE_CUDA_ARCHITECTURES option:

  f"-DNATTEN_CUDA_ARCH_LIST={cuda_arch_list_str}",
+ f"-DCMAKE_CUDA_ARCHITECTURES={cuda_arch_list_str}",

This got it to stop complaining about an empty CMAKE_CUDA_ARCHITECTURES variable, and attempt an nvcc build. Which failed with:

  /usr/local/cuda/bin/../targets/x86_64-linux/include/crt/host_config.h:136:2:
  error: #error -- unsupported GNU version! gcc versions later than 12 are
  not supported! The nvcc flag '-allow-unsupported-compiler' can be used to
  override this version check; however, using an unsupported host compiler
  may cause compilation failure or incorrect run time execution.  Use at your
  own risk.

nvcc compilation failed because my gcc and g++ didn't point anywhere. So I used update-alternatives to point to GCC/G++ version 12:
https://askubuntu.com/questions/26498/how-to-choose-the-default-gcc-and-g-version

sudo update-alternatives --config gcc
sudo update-alternatives --config g++

Each of these commands told me that the link group had been broken until now:

update-alternatives: warning: forcing reinstallation of alternative /usr/bin/gcc-12 because link group gcc is broken
update-alternatives: warning: forcing reinstallation of alternative /usr/bin/g++-12 because link group g++ is broken

On the next build attempt, it complained of:

CMake Error at /home/birch/git/sdxl-play/venv-311/lib/python3.11/site-packages/cmake/data/share/cmake-3.30/Modules/FindPackageHandleStandardArgs.cmake:233 (message):
  Could NOT find CUDA (missing: CUDA_CUDART_LIBRARY) (found suitable version
  "12.2", minimum required is "11.0")

So I made a different change to setup.py (I didn't need to keep the -DCMAKE_CUDA_ARCHITECTURES change now that I'd gotten past that error):

  f"-DNATTEN_CUDA_ARCH_LIST={cuda_arch_list_str}",
+ f"-DCUDA_CUDART_LIBRARY=/usr/local/cuda/lib64/libcudart.so",

This finally got it compiling!
Oh. Linking failed.

/home/birch/git/sdxl-play/venv-311/lib/python3.11/site-packages/cmake/data/bin/cmake -E cmake_link_script CMakeFiles/natten.dir/link.txt --verbose=1
/usr/bin/c++ -fPIC  -std=c++17 -shared -Wl,-soname,natten/libnatten.cpython-311-x86_64-linux-gnu.so -o natten/libnatten.cpython-311-x86_64-linux-gnu.so … -lcudart /usr/local/cuda/lib64/libcudart.so /usr/local/cuda/lib64/libnvToolsExt.so -lcudadevrt -lcudart_static -lrt -lpthread -ldl
/usr/bin/ld: cannot find -lcudart: No such file or directory
/usr/bin/ld: cannot find -lcudadevrt: No such file or directory
/usr/bin/ld: cannot find -lcudart_static: No such file or directory

=====

To get a better look at the command that python setup.py runs, I built it the way they build official distributions. First I needed wheel:

pip install wheel

Then I could use python setup.py bdist_wheel:

NATTEN_CUDA_ARCH=8.9 NATTEN_VERBOSE=0 NATTEN_IS_BUILDING_DIST=1 NATTEN_WITH_CUDA=1 NATTEN_N_WORKERS=4 python setup.py bdist_wheel -d out/wheels/cu121/torch/240

The CUDA and torch versions were determined like so:

from torch import version, __version__

# https://github.com/SHI-Labs/NATTEN/blob/7bc099de2a307e23903bb4f8ca1ca36c9df54cef/setup.py#L121
cuda_tag = "".join(version.cuda.split(".")[:2])  # e.g. "121" for CUDA 12.1
torch_tag = "".join(__version__.split("+", maxsplit=1)[0].split("."))  # e.g. "210" for torch 2.1.0
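For example (hypothetical version strings standing in for the real torch attributes, so torch isn't needed to run this), a CUDA 12.1 build of torch 2.1.0 produces the same directory layout as the wheel index in the pip install command earlier:

```python
# Hypothetical example values standing in for torch.version.cuda / torch.__version__
cuda_version = "12.1"          # torch.version.cuda
torch_version = "2.1.0+cu121"  # torch.__version__

cuda_tag = "".join(cuda_version.split(".")[:2])
torch_tag = "".join(torch_version.split("+", maxsplit=1)[0].split("."))

print(f"out/wheels/cu{cuda_tag}/torch/{torch_tag}")  # out/wheels/cu121/torch/210
```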
