infer_auto_device_map inefficiently allocates GPU memory for models with imbalanced module sizes #3041

Open · 2 of 4 tasks
Nech-C opened this issue Aug 25, 2024 · 7 comments
Labels: wip (Work in progress)

Nech-C commented Aug 25, 2024

System Info

- `Accelerate` version: 0.33.0
- Platform: Windows-10-10.0.22631-SP0
- `accelerate` bash location: C:\Users\Nech\anaconda3\envs\transformer-multi-device\Scripts\accelerate.exe
- Python version: 3.11.9
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.4.0+cu118 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch MUSA available: False
- System RAM: 15.86 GB
- GPU type: NVIDIA GeForce RTX 3060 Laptop GPU
- `Accelerate` default config:
        Not found

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

Steps to Reproduce

  1. Create a Segformer model from the Transformers library, or any model with one disproportionately large layer.
  2. Compute the model's module sizes.
  3. Set max_memory for the CUDA device(s) to a value smaller than the largest layer but larger than several of the smaller layers combined.
  4. Run infer_auto_device_map with `no_split_module_classes=[]` and the defined max_memory.
  5. Observe the resulting device map.

Code example

from transformers import SegformerConfig, SegformerForSemanticSegmentation
from accelerate import infer_auto_device_map
from accelerate.utils.modeling import compute_module_sizes

config = SegformerConfig(
    image_size=64,
    num_channels=3,
    num_encoder_blocks=4,
    depths=[1, 1, 1, 1],
    sr_ratios=[8, 4, 2, 1],
    hidden_sizes=[8, 8, 16, 16],
    num_attention_heads=[1, 1, 2, 2],
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    initializer_range=0.02,
)

# Segformer has a huge single layer that cannot be split
model = SegformerForSemanticSegmentation(config)
model_sizes = compute_module_sizes(model)
model_size = model_sizes[""]

# 0.7 is one of the split ratios defined for device offload tests in the Transformers library
split_ratio = 0.7
max_memory = {0: int(split_ratio * model_size), "cpu": model_size * 2}

print(f"model size: {model_size}, max memory: {max_memory}")
print(
    infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=[]
    )
)

Output:

model size: 1195632, max memory: {0: 836942, 'cpu': 2391264}
OrderedDict([('', 'cpu')])

As shown above, no module is allocated to the GPU even though it has enough free memory to hold several of the smaller modules.

Expected behavior

The function should allocate modules to the GPU whenever possible, and when it does, it should use the available GPU memory efficiently even for models with imbalanced module sizes.

Here is a simple breakdown of the module sizes of the Segformer model:

: 1195632
segformer: 87648
segformer.encoder: 87648
decode_head: 1107984
decode_head.linear_c: 53248
decode_head.linear_fuse: 1048576
decode_head.batch_norm: 4104
decode_head.classifier: 2056
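
For reference, a breakdown like the one above can be printed from the compute_module_sizes output already computed in the reproduction script (a rough sketch; the depth filter is only there to keep the listing short):

# Reuses `model_sizes = compute_module_sizes(model)` from the script above.
for name, size in model_sizes.items():
    # Keep only the root module ("") and the first two levels, for readability.
    if name.count(".") <= 1:
        print(f"{name}: {size}")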

When we increase the GPU's max_memory by raising the split ratio to 0.9 (another split ratio used in the tests), some modules are allocated to the GPU, but inefficiently:

Total model size: 1195632, max memory: {0: 1076068, 'cpu': 2391264}
OrderedDict([('segformer.encoder.patch_embeddings', 0),
             ('segformer.encoder.block.0.0.layer_norm_1', 0),
             ('segformer.encoder.block.0.0.attention.self.query', 0),
             ('segformer.encoder.block.0.0.attention.self.key', 0),
             ('segformer.encoder.block.0.0.attention.self.value', 0),
             ('segformer.encoder.block.0.0.attention.self.dropout', 0),
             ('segformer.encoder.block.0.0.attention.self.sr', 'cpu'),
             ('segformer.encoder.block.0.0.attention.self.layer_norm', 'cpu'),
             ('segformer.encoder.block.0.0.attention.output', 'cpu'),
             ('segformer.encoder.block.0.0.drop_path', 'cpu'),
             ('segformer.encoder.block.0.0.layer_norm_2', 'cpu'),
             ('segformer.encoder.block.0.0.mlp', 'cpu'),
             ('segformer.encoder.block.1', 'cpu'),
             ('segformer.encoder.block.2', 'cpu'),
             ('segformer.encoder.block.3', 'cpu'),
             ('segformer.encoder.layer_norm', 'cpu'),
             ('decode_head', 'cpu')])

In both the 0.7 and 0.9 split cases, the memory actually used on the GPU is significantly less than its defined max_memory.

After looking into the infer_auto_device_map function, I believe the logic might not be working as intended for models with highly imbalanced module sizes like Segformer:

max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)

while len(modules_to_treat) > 0:
    name, module = modules_to_treat.pop(0)
    module_size = module_sizes[name]

    device = devices[current_device]
    current_max_size = max_memory[device] if device != "disk" else None
    current_memory_reserved = 0

    if devices[current_device] in main_devices:
        current_max_size = current_max_size - max_layer_size
        current_memory_reserved = max_layer_size

    if current_max_size is not None and current_memory_used + module_size > current_max_size:
        # Split or not split?
        modules_children = (
            []
            if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor)
            else list(module.named_children())
        )
        if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
            # -> no split, we go to the next device
            device_memory_used[device] = current_memory_used + current_memory_reserved
            current_device += 1
            modules_to_treat = [(name, module)] + modules_to_treat
            current_memory_used = 0
        else:
            # -> split, we replace the module studied by its children + parameters
            modules_children = list(module.named_parameters(recurse=False)) + modules_children
            modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat

While deciding whether to allocate a module to the current device, this code reserves space for the largest layer on the current main device. In other words, the current device needs more free memory than the size of the module plus the size of the largest layer just for the module to be allocated to it. For Segformer, where decode_head (1,107,984 bytes) is significantly larger than the other top-level modules, this approach may be too conservative, leaving little room for other layers on the GPU.
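
To make this concrete, here is the budget arithmetic for the two split ratios, using the numbers from the outputs above (an illustration only; it assumes get_max_layer_size ends up picking decode_head.linear_fuse, the largest leaf module, as the largest layer):

# Illustration of the largest-layer reservation, with numbers from the outputs above.
largest_layer = 1_048_576            # decode_head.linear_fuse

# 0.7 split: the reservation alone already exceeds the whole GPU budget,
# so no module can ever be placed on the GPU.
print(836_942 - largest_layer)       # -211634

# 0.9 split: only ~27 KB of effective budget remains, which is why just a
# handful of tiny encoder submodules land on the GPU.
print(1_076_068 - largest_layer)     # 27492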

A module is allocated to the current device only when it fits within this reduced budget. Otherwise, the function tries to split the module, or moves to the next device if the module cannot be split. However, once it moves to the next device (i.e., the CPU), it never goes back to the GPU, even if there is still space available. This could explain why smaller modules aren't being allocated to the GPU after decode_head is moved to the CPU.

I encountered this issue while working to enable device_map='auto' for some models in the Transformers library. Offload tests for those models fail because the entire model is allocated to the CPU or disk. I have reported this problem in this issue. Since I am unfamiliar with this library, I don't know if this is the expected behavior of the function. Thank you for reading this!

Nech-C (Author) commented Aug 27, 2024

I'm sorry if my original issue was unclear. This is my first bug report. I have modified it; hopefully, it's easier to understand now. I will explain the problem in a few sentences.
When infer_auto_device_map is called with a max_memory that gives a CUDA device (or the CPU) less memory than the model's largest layer, no layer is allocated to that device at all. More generally, the amount of memory used on a main device is <= (defined max_memory) - (largest layer size). While this may not be an issue in most cases, it causes the offload/parallelism tests in test_modeling_common.py from the Transformers library to fail, since they use the model's size as a reference.
Can you tell me whether this is the expected behavior so that we can start enabling device_map="auto" for more models? Thank you so much!

BenjaminBossan (Member) commented:

Thanks for reporting this issue. I agree that this looks like there is room for improvement in order to allow as many modules as possible to be loaded on the fastest device. To me, this looks like a Knapsack problem, so finding an optimal solution could become quite interesting. But I'll let @muellerzr and @SunMarc comment on this, who have more background knowledge.

SunMarc (Member) commented Aug 28, 2024

Hey @Nech-C, thanks for the detailed report! You have a very good understanding of the situation. This could indeed be improved: infer_auto_device_map doesn't perform well for unbalanced models, and this is something I'm aware of.
As you said, there are two points that can be improved. Here are my thoughts:

  1. The current device needs to have more memory than the size of the module plus the size of the largest layer just for the module to be allocated on it.

This is required in case we perform CPU/disk offloading, as we need to bring the largest offloaded layer onto the GPU. One way to solve this is to create the device_map without the hypothesis that we will have offloaded layers; if we end up with offloaded layers, we redo the calculation with that hypothesis. Another solution would be to check if the memory of the model < memory of the GPUs. If that's not the case, we do the calculation without the hypothesis. However, we might still face issues with unbalanced models.
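
A minimal sketch of that first idea, assuming a hypothetical internal helper _infer_map with a reserve_largest_layer flag (this is not the current accelerate API, just an outline of the two-pass approach):

def infer_device_map_two_pass(model, max_memory, no_split_module_classes):
    # Pass 1: assume nothing will be offloaded, so don't reserve space
    # for the largest layer on the main devices.
    device_map = _infer_map(
        model, max_memory, no_split_module_classes, reserve_largest_layer=False
    )
    # If something was offloaded after all, the reservation is actually needed:
    # redo the calculation under the offload hypothesis.
    if any(device in ("cpu", "disk") for device in device_map.values()):
        device_map = _infer_map(
            model, max_memory, no_split_module_classes, reserve_largest_layer=True
        )
    return device_map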

  2. However, once it moves to the next device (i.e., CPU), it never goes back to the GPU, even if there's available space.

We can indeed improve that part. Since most transformers have balanced modules, it was working fine so far. For example, we could still consider coming back to the previous device if it has at least 10% of its space available. The reasoning behind moving only to the next device was to limit movement across devices, as this would make inference slower: 1->2->3 and not 1->2->1->2->3.
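
A rough sketch of that fallback check (the names follow the quoted excerpt, e.g. device_memory_used and max_memory, and the 10% threshold is the one mentioned above; this is illustrative, not the actual implementation):

def previous_device_with_room(devices, current_device, device_memory_used,
                              max_memory, module_size, threshold=0.10):
    # Look at devices we already moved past and check whether one of them
    # still has at least `threshold` of its budget free and can fit the module.
    for device in devices[:current_device]:
        free = max_memory[device] - device_memory_used.get(device, 0)
        if free >= threshold * max_memory[device] and module_size <= free:
            return device
    return None  # no earlier device has enough room; stay on the current one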

If you are up to the challenge, feel free to open a PR to fix those two points! I can have a look later!

Nech-C (Author) commented Aug 29, 2024

@SunMarc Thank you so much for your detailed response! I appreciate your insights into the problem. I'd love to take on this challenge. It might take me a little time to get up to speed with the library, but I'm excited to give it a try. Can I reach out with any questions as I work on this?

SunMarc (Member) commented Aug 29, 2024

Nice, thanks for helping! Yes, feel free to ask any questions!

Nech-C (Author) commented Aug 31, 2024

Hi @SunMarc,

I've been digging into the code, and this is more complicated than I first thought. I agree that conditionally calculating the device_map may not fully solve the problem. So, I think we can address this in two separate PRs.

PR no. 1 (quick fix):

  1. Add warnings when no modules are assigned to a main device due to low max_memory.
  2. Report the minimum memory needed for at least one module assignment along with the warning. For example, under the current logic, this value would be the size of the (first immediate non-splittable module) + (the largest layer) for the first device (a rough sketch of the warning follows after this list).
  3. Add a new parameter fallback_allocation. When set to True, it will attempt an alternative assignment if max_memory is sufficient for some (non-splittable module) + (largest layer) but insufficient for the default assignment attempt. This makes sure at least one module is assigned to the potential execution device and likely won't break other code.
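
A rough sketch of the warning part of PR no. 1 (the function name and structure are made up for illustration, not the actual accelerate implementation):

import warnings

def warn_if_main_devices_empty(device_map, max_memory):
    # Main devices are everything in max_memory except "cpu" and "disk".
    main_devices = [d for d in max_memory if d not in ("cpu", "disk")]
    used_devices = set(device_map.values())
    if main_devices and not any(d in used_devices for d in main_devices):
        warnings.warn(
            "No module was assigned to a main device because max_memory is too low. "
            "It needs to cover at least (first non-splittable module) + (largest layer) "
            "for one module to be placed there."
        )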

PR no. 2 (optimization):
Work on your idea about utilizing space on main devices more efficiently. We'd add a new parameter so users can choose to maximize main device use.

Does this approach sound good to you? If so, I can start working on the first PR soon. Let me know if you have any suggestions or concerns about this plan!


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

BenjaminBossan added the wip (Work in progress) label on Sep 26, 2024