
--upcast-sampling support for CUDA #8782

Merged 7 commits into AUTOMATIC1111:master on Mar 25, 2023

Conversation

@FNSpd (Contributor) commented Mar 21, 2023

Should speed up image generation on GTX 10XX and 16XX cards when using --upcast-sampling, and fix the error that occurs when using "Upcast cross attention layer to float32" with xformers.
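
For context on the mechanism: the upcast-sampling path (visible in modules/sd_hijack_unet.py in the traceback below) keeps the sampler in float32 and casts tensors to the unet's dtype at the boundary. A minimal sketch of that pattern, with hypothetical names, not the actual webui code:

    import torch

    dtype_unet = torch.float16  # assumption: unet weights kept in half precision

    def apply_model_upcast(orig_apply_model, model, x_noisy, t, cond, **kwargs):
        # Cast the sampler's float32 tensors down to the unet dtype, run the
        # unet, then cast the result back to float32 for the sampler.
        out = orig_apply_model(model, x_noisy.to(dtype_unet), t.to(dtype_unet), cond, **kwargs)
        return out.float()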

@FNSpd requested a review from AUTOMATIC1111 as a code owner on March 21, 2023 11:02
@fractal-fumbler commented Mar 21, 2023

Not on a 10xx or 16xx NVIDIA card, but I'm getting the error mixed dtype (CPU): expect parameter to have scalar type of Float with --xformers --xformers-flash-attention --opt-channelslast --listen --upcast-sampling --no-half-vae --api --enable-insecure-extension-access:

Traceback (most recent call last):                                                                                      
  File "/tmp/stable-diffusion-webui/modules/call_queue.py", line 56, in f                                               
    res = list(func(*args, **kwargs))                                                                                   
  File "/tmp/stable-diffusion-webui/modules/call_queue.py", line 37, in f                                               
    res = func(*args, **kwargs)                                                                                         
  File "/tmp/stable-diffusion-webui/modules/txt2img.py", line 59, in txt2img                                            
    processed = process_images(p)                                                                                       
  File "/tmp/stable-diffusion-webui/modules/processing.py", line 486, in process_images                                 
    res = process_images_inner(p)                                                                                       
  File "/tmp/stable-diffusion-webui/modules/processing.py", line 678, in process_images_inner                           
    samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
  File "/tmp/stable-diffusion-webui/modules/processing.py", line 892, in sample                                         
    samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
  File "/tmp/stable-diffusion-webui/modules/sd_samplers_kdiffusion.py", line 353, in sample                             
    samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={                        
  File "/tmp/stable-diffusion-webui/modules/sd_samplers_kdiffusion.py", line 229, in launch_sampling                    
    return func()                                                                                                       
  File "/tmp/stable-diffusion-webui/modules/sd_samplers_kdiffusion.py", line 353, in <lambda>                           
    samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={                        
  File "/usr/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context                    
    return func(*args, **kwargs)                                                                                        
  File "/tmp/stable-diffusion-webui/repositories/k-diffusion/k_diffusion/sampling.py", line 145, in sample_euler_ancestr
al                                                                                                                      
    denoised = model(x, sigmas[i] * s_in, **extra_args)                                                                 
  File "/usr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl                         
    return forward_call(*args, **kwargs)                                                                                
  File "/tmp/stable-diffusion-webui/modules/sd_samplers_kdiffusion.py", line 121, in forward                            
    x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})              
  File "/usr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl                         
    return forward_call(*args, **kwargs)                                                                                
  File "/tmp/stable-diffusion-webui/repositories/k-diffusion/k_diffusion/external.py", line 167, in forward
    return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
  File "/tmp/stable-diffusion-webui/repositories/k-diffusion/k_diffusion/external.py", line 177, in get_v
    return self.inner_model.apply_model(x, t, cond)
  File "/tmp/stable-diffusion-webui/modules/sd_hijack_utils.py", line 17, in <lambda>
    setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
  File "/tmp/stable-diffusion-webui/modules/sd_hijack_utils.py", line 26, in __call__
    return self.__sub_func(self.__orig_func, *args, **kwargs)
  File "/tmp/stable-diffusion-webui/modules/sd_hijack_unet.py", line 45, in apply_model
    return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
  File "/tmp/stable-diffusion-webui/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/ddpm.py", line 858, 
in apply_model
    x_recon = self.model(x_noisy, t, **cond)
  File "/usr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
    out = self.diffusion_model(x, t, context=cc)
  File "/usr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/stable-diffusion-webui/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py", line 778, in forward
    h = module(h, emb, context)
  File "/usr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/stable-diffusion-webui/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py", line 82, in forward
    x = layer(x, emb)
  File "/usr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/stable-diffusion-webui/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py", line 249, in forward
    return checkpoint(
  File "/tmp/stable-diffusion-webui/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/util.py", line 114, in checkpoint
    return CheckpointFunction.apply(func, len(inputs), *args)
  File "/usr/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/tmp/stable-diffusion-webui/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/util.py", line 129, in forward
    output_tensors = ctx.run_function(*ctx.input_tensors)
  File "/tmp/stable-diffusion-webui/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py", line 262, in _forward
    h = self.in_layers(x)
  File "/usr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/usr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/stable-diffusion-webui/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/util.py", line 219, in forward
    return super().forward(x.float()).type(x.dtype)
  File "/usr/lib/python3.10/site-packages/torch/nn/modules/normalization.py", line 273, in forward
    return F.group_norm(
  File "/usr/lib/python3.10/site-packages/torch/nn/functional.py", line 2530, in group_norm
    return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: mixed dtype (CPU): expect parameter to have scalar type of Float
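
For what it's worth, the last frames point at GroupNorm: util.py upcasts the input with x.float() while the layer's parameters apparently ended up in half precision, and PyTorch's CPU group_norm kernel only accepts mixed dtypes when the parameters are float32. A minimal repro of that combination (assuming recent PyTorch CPU kernels, e.g. 1.13/2.0):

    import torch

    # A float32 input meeting float16 GroupNorm parameters on CPU reproduces
    # the same error message.
    gn = torch.nn.GroupNorm(32, 64).half()   # parameters become float16
    x = torch.randn(1, 64, 8, 8)             # input stays float32
    gn(x)  # RuntimeError: mixed dtype (CPU): expect parameter to have scalar type of Float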

@FNSpd (Contributor, Author) commented Mar 22, 2023

Is this with my changes, or with the current version using --upcast-sampling? Also, which PyTorch version do you have? I only tested on 1.13.1, not 2.0.

@Shadowghost

I can confirm the error on current master with PyTorch 2.0 and --upcast-sampling on a GTX 1070.

@FNSpd (Contributor, Author) commented Mar 22, 2023

I'm a few commits behind master. I'll test with PyTorch 2 to see if it works on my version. Maybe something in recent updates breaks it.

@brkirch (Collaborator) commented Mar 24, 2023

Looks good, except fully disabling autocast shouldn't be necessary. If there is some problematic section of code, autocast can be disabled for that section specifically, but otherwise fully disabling it should be optional (with --precision full). I use MPS, so I can't test CUDA autocast and verify everything is working as intended, but with --upcast-sampling, autocast is supposed to be disabled outside the unet (for sampling) and enabled inside the unet.
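
A hedged sketch of that split (illustrative only, not the actual webui code): the sampler math runs in plain float32 with autocast off, and autocast is enabled only around the unet call:

    import torch

    def call_unet(unet, x, t, **cond):
        # Autocast enabled only for the unet forward pass.
        with torch.autocast(device_type="cuda", enabled=True):
            return unet(x, t, **cond)

    # Sampler side stays in float32, outside any autocast context, e.g.:
    # denoised = call_unet(model, x, sigmas[i] * s_in, **extra_args)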

@FNSpd (Contributor, Author) commented Mar 24, 2023

> Looks good, except fully disabling autocast shouldn't be necessary. If there is some problematic section of code, autocast can be disabled for that section specifically, but otherwise fully disabling it should be optional (with --precision full). I use MPS, so I can't test CUDA autocast and verify everything is working as intended, but with --upcast-sampling, autocast is supposed to be disabled outside the unet (for sampling) and enabled inside the unet.

I changed it so as not to confuse people and make one argument dependent on the other, but yeah, --precision full works too. Without it, the GTX 1650 used to emulate FP16 (2 times slower than with --precision full).

@FNSpd (Contributor, Author) commented Mar 24, 2023

So, it looks like I found a fix for the PyTorch 2 issue, but the speed is the same as without upcast. I'll include it just in case, but I would suggest that GTX 10XX and 16XX users stick with Torch 1.13.1.

@brkirch (Collaborator) commented Mar 24, 2023

> I changed it so as not to confuse people and make one argument dependent on the other, but yeah, --precision full works too. Without it, the GTX 1650 used to emulate FP16 (2 times slower than with --precision full).

It's more an issue of whether this breaks ROCm support for some users; extensive testing was done in #6510, and it is important that we don't break support for other users if all that is otherwise needed is a note in the documentation that says "try with --precision full for better performance on some GPUs".

@FNSpd (Contributor, Author) commented Mar 24, 2023

> I changed it so as not to confuse people and make one argument dependent on the other, but yeah, --precision full works too. Without it, the GTX 1650 used to emulate FP16 (2 times slower than with --precision full).

> It's more an issue of whether this breaks ROCm support for some users; extensive testing was done in #6510, and it is important that we don't break support for other users if all that is otherwise needed is a note in the documentation that says "try with --precision full for better performance on some GPUs".

Oh, I didn't know that. Then I'll revert it, yeah.

@Shadowghost

I tested the current code with 1.13 instead of 2.0; basic generation works, but LoRAs crash the pipeline with an error saying they provide half while float is expected (I'll post the specific message later).

@FNSpd (Contributor, Author) commented Mar 24, 2023

> I tested the current code with 1.13 instead of 2.0; basic generation works, but LoRAs crash the pipeline with an error saying they provide half while float is expected (I'll post the specific message later).

Did you add my changes to lora.py?

@Shadowghost

Seems like something went wrong while merging your changes. On a clean clone it works as intended, sorry about that.

@AUTOMATIC1111 merged commit 03c8eef into AUTOMATIC1111:master on Mar 25, 2023
brkirch pushed a commit to brkirch/stable-diffusion-webui that referenced this pull request on Apr 5, 2023: --upcast-sampling support for CUDA