
XFormers massively increases performance and improves memory usage #80

Closed · hafriedlander opened this issue Dec 24, 2022 · 15 comments
@hafriedlander (Collaborator) commented Dec 24, 2022

Without xformers, I can only do a batch size of 1 (and no prior preservation) when training a unet + text encoder on my 24GB 4090, even with all the other tricks (fp16, 8-bit Adam; gradient checkpointing is out because it currently doesn't work correctly).

With xformers enabled, I can do a batch of 2 w/ prior preservation (so 4 images total per batch), and the performance is still the same as the batch of 1 without it.

So, the bad news: backward-pass support in xformers is poor. It mostly only works on A100s; on other cards it's variable. On my 4090 it works fine for SD 2.1 models, but not for SD 1.5 models (although I've got it working via a nasty hack: facebookresearch/xformers@27dadbf). It also doesn't work on Colab.

Anyway, I'll add the option to turn it on in the dreambooth scripts after Christmas, but there'll be plenty of issues to sort out. Just wanted to document this in the meantime.
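For reference, a minimal sketch of the memory-saving knobs mentioned above (fp16, 8-bit Adam, xformers attention), assuming a diffusers UNet2DConditionModel plus bitsandbytes and accelerate; this is not this repo's actual training script, just an illustration.

import bitsandbytes as bnb
from accelerate import Accelerator
from diffusers import UNet2DConditionModel

# Load just the unet half of an SD pipeline (model id is only an example)
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

accelerator = Accelerator(mixed_precision="fp16")            # fp16 training
optimizer = bnb.optim.AdamW8bit(unet.parameters(), lr=1e-6)  # 8-bit Adam
unet.enable_xformers_memory_efficient_attention()            # the switch this issue is about
# unet.enable_gradient_checkpointing()  # skipped: reported broken above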

hafriedlander changed the title from "XFormers massively increases performance and memory usage" to "XFormers massively increases performance and improves memory usage" on Dec 24, 2022
@brian6091 (Collaborator) commented:

I'm running with xformers and SD 1.5 on Colab (or at least I think so). What is the problem you're seeing?

@hafriedlander (Collaborator, Author) commented:

Without the hack, CUDA throws an "Operator doesn't exist" error. Interesting that it works on Colab; what's the GPU there?

Are you turning it on manually in a custom dreambooth script, @brian6091, or using one of the Diffusers versions that turns it on automatically?

@hafriedlander (Collaborator, Author) commented:

It's definitely an xformers issue, BTW: facebookresearch/xformers#563. (A lot of users of my SD server run into it, because CLIP guidance runs backwards to collect a guidance gradient.)
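(For anyone trying to reproduce, a minimal sketch of the failure mode, assuming xformers is installed and using a head dim above 64 as in the larger SD 1.x attention blocks; the forward pass succeeds, and backward() is where unsupported cards/head sizes error out.)

import torch
import xformers.ops as xops

# batch, sequence length, heads, head dim (> 64 is the problematic case)
B, M, H, K = 1, 4096, 8, 80
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = xops.memory_efficient_attention(q, k, v)  # forward works
out.sum().backward()                            # backward raises on unsupported configs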

@brian6091 (Collaborator) commented:

I've been running on T4 and A100. I've added the following:

if is_xformers_available():
    try:
        unet.enable_xformers_memory_efficient_attention()
    except Exception as e:
        logger.warning(
            "Could not enable memory efficient attention. Make sure xformers is installed"
            f" correctly and a GPU is available: {e}"
        )

which I believe was added to the official diffusers examples.

@brian6091 (Collaborator) commented:

Are you building the xformers wheel yourself?

@hafriedlander (Collaborator, Author) commented:

Yes, on Windows and Linux. On an A100 you shouldn't have a problem (that's the only card officially supported as "this will work"). Interesting that the T4 works too.

@brian6091 (Collaborator) commented:

Hmm, it's been a while since I've tried with a T4. I'll give it another shot tonight.

@Thomas-MMJ commented Dec 24, 2022

@hafriedlander - xformers backward with the 'hack' (force kBlockSizeI and kBlockSizeJ to 64, kPreloadMmas to false) to enable the cutlass backward works here on a 3060 for SD 1.5 using LoRA in sd_dreambooth_extension. It should therefore work on all sm86 devices, since they have the same shared memory sizes.

kBlockSizeJ = 64;
kBlockSizeI = 64;
kPreloadMmas = false;

I haven't figured out how to get the cudaDeviceProp information properly at compile time; neither the C nor the C++ access method seems to be working for me (though my C is extremely rusty, my C++ is almost non-existent, and my CUDA knowledge is non-existent).

Once I figure that out I can do:

  static constexpr bool kIsAmpere = p->major * 10 + p->minor == 86;
  static constexpr bool kIsLoveLace = p->major * 10 + p->minor == 89;

and then use those in the checks to set kBlockSizeI, kBlockSizeJ, and kPreloadMmas.

Another way I considered is getting the shared memory availability directly.
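(As a rough sketch of the runtime side of that idea: PyTorch already exposes the compute capability, so an sm86/sm89 check is straightforward outside the kernel templates, even though it doesn't solve the compile-time problem described above.)

import torch

props = torch.cuda.get_device_properties(0)
sm = props.major * 10 + props.minor  # e.g. 86 for RTX 30xx, 89 for RTX 40xx
is_sm86 = sm == 86
is_lovelace = sm == 89
print(f"sm_{sm}:", is_sm86 or is_lovelace)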

@78Alpha commented Dec 25, 2022

Using the optimizations from https://github.com/d8ahazard/sd_dreambooth_extension, I tend to run LoRA with about 6-6.4 GB of VRAM, so there are a lot of savings available.

@hafriedlander (Collaborator, Author) commented Dec 25, 2022

@Thomas-MMJ I actually ran into a problem doing it that way - the code "worked" (stopped throwing CUDA errors), but the result was bad. The eventual result was just latent noise.

I figured out another way to get most of the benefit, which is to only enable xformers on CrossAttention modules where K <= 64. This is basically the top & bottom of the unet (about 1/3rd of the CrossAttention modules in total) but enough that I can run batch_size 4 on unet+text_encoder in 24GB without gradient_checkpointing.

@Thomas-MMJ commented Dec 25, 2022

> @Thomas-MMJ I actually ran into a problem doing it that way - the code "worked" (stopped throwing CUDA errors), but the result was bad. The eventual result was just latent noise.

Did you run the pytests? They should detect whether there is a significant difference between the result with xformers and the result without:

pytest -k "test_backward[cutlass" ./tests/test_mem_eff_attention.py

Also, did you try with kPreloadMmas = false?

Also, did you run python setup.py clean and remove all __pycache__ directories?

@hafriedlander (Collaborator, Author) commented:

No, I didn't run the tests. I'll do it later, and try with kPreloadMmas = false too. I'm doing everything in fresh Docker containers, so there are no pycaches / old builds.

It's lower priority for me now, since the approach below is working nicely for me (more than 4x the max images per batch) and doesn't need a recompile:

# imports assumed for this snippet (paths as of late-2022 diffusers / lora_diffusion)
import torch
from diffusers.models.attention import BasicTransformerBlock
from lora_diffusion import LoraInjectedLinear


def set_use_memory_efficient_attention_xformers(
    module: torch.nn.Module, valid: bool, dim_head_max: int = 0
) -> None:
    def fn_set_mem_eff_lim(module: torch.nn.Module):
        if isinstance(module, BasicTransformerBlock):
            # dim_head isn't stored anywhere, so back-calculate it from to_v
            source = module.attn1.to_v
            if isinstance(source, LoraInjectedLinear):
                source = source.linear

            dim_head = source.out_features // module.attn1.heads

            # If dim_head > dim_head_max, turn xformers off for this block
            if dim_head > dim_head_max:
                module.set_use_memory_efficient_attention_xformers(False)

        for child in module.children():
            fn_set_mem_eff_lim(child)

    # Enable (or disable) xformers everywhere first, then selectively disable
    # it again on blocks whose head dim exceeds dim_head_max
    module.set_use_memory_efficient_attention_xformers(valid)
    if dim_head_max:
        fn_set_mem_eff_lim(module)

(and setting dim_head_max to 64)
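(Hypothetical usage, assuming unet is a diffusers UNet2DConditionModel:)

# enable xformers everywhere, then disable it again on blocks with dim_head > 64
set_use_memory_efficient_attention_xformers(unet, True, dim_head_max=64)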

@brian6091 (Collaborator) commented:

> Yes, on Windows and Linux. On an A100 you shouldn't have a problem (that's the only card officially supported as "this will work"). Interesting that the T4 works too.

Just following up to confirm that enabling xformers works fine on the T4 in Colab.

@hafriedlander (Collaborator, Author) commented:

Added #103, which tests the head dimension of each CrossAttention block and only enables xformers for that block if the backward pass works at that size.
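(Not the actual code in #103, just a sketch of the idea: probe whether the xformers backward kernel handles a given head dim on the current GPU, and gate each block on the result.)

import torch
import xformers.ops as xops

def backward_works(dim_head: int, heads: int = 8) -> bool:
    # Try a tiny forward + backward through memory-efficient attention
    try:
        q = torch.randn(1, 64, heads, dim_head, device="cuda",
                        dtype=torch.float16, requires_grad=True)
        k = torch.randn_like(q, requires_grad=True)
        v = torch.randn_like(q, requires_grad=True)
        xops.memory_efficient_attention(q, k, v).sum().backward()
        return True
    except Exception:
        return False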

@hafriedlander (Collaborator, Author) commented:

Closing now that #103 is merged.
