
Add warnings and fallback for unassigned devices in infer_auto_device_map #3066

Open · wants to merge 13 commits into base: main

Conversation

@Nech-C (Author) commented Sep 1, 2024

What does this PR do?

This PR proposes changes to the infer_auto_device_map function discussed in #3041. It makes the following improvements:

  1. Add warnings when no modules are assigned to a main device because max_memory is too low.
  2. Report the minimum memory needed for at least one module assignment alongside these warnings. For example, under the current logic, this value is (the first immediate non-splittable module) + (the largest layer) for the first device.
  3. Add a new parameter, fallback_allocation. When set to True, it attempts an alternative assignment whenever max_memory is sufficient for some (non-splittable module) + (largest layer) but insufficient for the default assignment attempt. This ensures that at least one module is assigned to the potential execution device and is unlikely to break other code (see the usage sketch below).
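
A minimal usage sketch of the proposed behavior follows; the model and memory values are placeholders, and fallback_allocation is the keyword introduced by this PR:

```python
import warnings

import torch.nn as nn
from accelerate import infer_auto_device_map

# Placeholder model; any nn.Module works for the sketch.
model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512), nn.Linear(512, 10))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    device_map = infer_auto_device_map(
        model,
        max_memory={0: 2048, "cpu": 10 * 1024**3},  # deliberately tiny budget (in bytes) for device 0
        fallback_allocation=True,  # parameter proposed in this PR
    )

# With a budget this tight, the new logic warns and reports the minimum
# memory needed for at least one module assignment on device 0.
for w in caught:
    print(w.message)
print(device_map)
```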

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

The fallback allocation will be reintroduced once the branching logic is fully refactored. This commit prepares the function infer_auto_device_map for further refactoring.
Implemented fallback allocation to allow modules to be allocated to devices using BFS when regular allocation fails. This ensures that at least one module is assigned to the device, even under tight memory constraints.
@Nech-C marked this pull request as ready for review on October 14, 2024 01:40
@BenjaminBossan (Member) left a comment

Thanks for the updates and the style fix. I'm not very knowledgeable about the whole logic being applied here, so I won't comment on that.

Personally, I find the use of continue in addition to many nested conditionals makes the logic super hard to follow. Usually, I would try to stick to either if + continue or if ... elif ... without continue. Not sure if the code could be simplified here.

One thing I believe we should ensure is that the new logic does not add any unnecessary warnings. Right now, we have some unit tests to ensure that specific warnings are there, but AFAICT we don't have tests to ensure that for other cases, there are no warnings. Maybe it would be good to add tests for the "happy path" and show that there is no warning. Potentially, we can even use existing tests and just add a check there is no warning. WDYT?
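
A hedged sketch of what such a happy-path check could look like (the model, memory values, and test name below are placeholders, not the actual test code in this PR):

```python
import warnings

import torch.nn as nn
from accelerate import infer_auto_device_map


def test_no_warning_on_happy_path():
    # Memory is generous enough that every module gets a device,
    # so no warning should be emitted.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    with warnings.catch_warnings():
        warnings.simplefilter("error")  # escalate any warning to a test failure
        device_map = infer_auto_device_map(model, max_memory={"cpu": 10 * 1024**3})
    assert all(device == "cpu" for device in device_map.values())
```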

Review threads on src/accelerate/utils/modeling.py and tests/test_modeling_utils.py (outdated; resolved).
test_infer_auto_device_map and test_infer_auto_device_map_with_fallback_allocation now each have a no-warning test case.

Simplified and rewrote code sections that were made unreadable by the linter.
Added complete return type hinting for _init_infer_auto_device_map
@Nech-C (Author) commented Oct 14, 2024

> Thanks for the updates and the style fix. I'm not very knowledgeable about the whole logic being applied here, so I won't comment on that.
>
> Personally, I find the use of continue in addition to many nested conditionals makes the logic super hard to follow. Usually, I would try to stick to either if + continue or if ... elif ... without continue. Not sure if the code could be simplified here.
>
> One thing I believe we should ensure is that the new logic does not add any unnecessary warnings. Right now, we have some unit tests to ensure that specific warnings are there, but AFAICT we don't have tests to ensure that for other cases, there are no warnings. Maybe it would be good to add tests for the "happy path" and show that there is no warning. Potentially, we can even use existing tests and just add a check there is no warning. WDYT?

Hey @BenjaminBossan, I appreciate your feedback!

Regarding the use of continue and nested conditionals, I've tried simplifying the logic where possible. Now the while loop in infer_auto_device_map uses if + continue for branching. However, there are more continue statements in the code, many of which come from the original implementation. If you think it's necessary to address those as well, I'll take a look and see what I can do.
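
For illustration, here is a self-contained toy loop in the if + continue style being discussed; the data and placement rules are made up and this is not the real infer_auto_device_map body:

```python
# Toy queue of (name, module); a dict stands for a splittable module, an int for its size.
modules_to_treat = [("block", {"a": 3, "b": 1}), ("head", 2)]
device_budget = 4
device_map = {}

while modules_to_treat:
    name, module = modules_to_treat.pop(0)

    # Case 1: a splittable module -> enqueue its children and move on.
    if isinstance(module, dict):
        children = [(f"{name}.{child}", size) for child, size in module.items()]
        modules_to_treat = children + modules_to_treat
        continue

    # Case 2: the module does not fit on the device -> offload it to the CPU.
    if module > device_budget:
        device_map[name] = "cpu"
        continue

    # Default case: assign the module to the device and spend the budget.
    device_map[name] = 0
    device_budget -= module

print(device_map)  # {'block.a': 0, 'block.b': 0, 'head': 'cpu'}
```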

I completely agree with your point about avoiding unnecessary warnings. I've added checks in both test_infer_auto_device_map_with_fallback_allocation and test_infer_auto_device_map to verify that no unexpected warnings are raised in the 'happy path' cases.

@BenjaminBossan (Member) left a comment

Thanks for cleaning the code up and extending the tests.

I agree that the logic was already complex beforehand, so it's not just because of this PR. But I think your recent changes helped a bit to make it easier to understand, even if the overall complexity is still high and I can't say I understand everything that's going on.

@muellerzr (Collaborator) left a comment

Nice work! This will be very handy. cc @SunMarc for a final look since it's big model inference :)

@SunMarc (Member) left a comment

Thanks for the PR @Nech-C! I really appreciate that you are putting a lot of effort into this PR! I will review it soon, but first I have a question: could you explain, with an example, what this fallback allocation will do? From our conversation last time, the biggest issue with infer_auto_device_map is that we reserve memory for the largest layer in case we need to offload it to the CPU. I think that in your case, you are trying to find a module that fits within (device memory - largest layer)?

@Nech-C (Author) commented Oct 15, 2024

> Thanks for the PR @Nech-C! I really appreciate that you are putting a lot of effort into this PR! I will review it soon, but first I have a question: could you explain, with an example, what this fallback allocation will do? From our conversation last time, the biggest issue with infer_auto_device_map is that we reserve memory for the largest layer in case we need to offload it to the CPU. I think that in your case, you are trying to find a module that fits within (device memory - largest layer)?

Hi @SunMarc, sure thing!
Let's consider a model on which we run infer_auto_device_map. The model has three non-splittable modules: A, B, and C, with sizes 2, 1, and 3, respectively. max_memory is defined as {0: 4, "cpu": 6}. Without fallback allocation, the entire model is allocated to the CPU (main memory): the combined size of A and C (2 + 3 = 5) exceeds the memory limit of device 0 (4), so device 0 is skipped. When fallback_allocation is set to True, it uses DFS to find a module in the model that satisfies the size constraint (module size + largest layer size < max memory). The resulting device map looks like this: {A: 0, B: "cpu", C: "cpu"}. This way, the faster hardware can be used during inference without causing an OOM.

You are right. My code doesn't directly address the issue that, when there are multiple execution devices, the function may reserve space on a device for a module that will never be loaded onto it during inference. I have tried to come up with new allocation strategies, but the task is really complex. If possible, I would like to open a separate PR to address this once I come up with a reasonable solution.

While this PR doesn't solve the most significant concern, it does alleviate the problem. The constraint for allocating a module to a device is roughly module size + max layer size <= device memory. The aforementioned issue focuses on lowering max layer size, and this PR focuses on lowering module size by looking for a smaller module in the module list. When the regular allocation logic fails, it tries to assign a module to a device that would otherwise receive no assignment. In theory, a device can be used during execution if it has more memory than the largest layer, even with no module assigned to it, so we could achieve the same result without such a roundabout approach. However, I believe that would be a breaking change: the returned device_map would need to include this information, and it would also require considerable changes in other code, such as the dispatch_model function.
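
As a hedged illustration of that constraint (the function below is made up for the explanation, not the accelerate implementation):

```python
def fits_on_device(module_size: int, largest_layer_size: int, remaining_device_memory: int) -> bool:
    """A module can be placed on a device only if its size plus the reserved
    largest-layer size still fits in the device's remaining memory."""
    return module_size + largest_layer_size <= remaining_device_memory


# The issue above aims to shrink the largest_layer_size term; this PR's fallback
# instead searches for a smaller module_size that still satisfies the inequality.
print(fits_on_device(module_size=2, largest_layer_size=3, remaining_device_memory=4))  # False
print(fits_on_device(module_size=1, largest_layer_size=3, remaining_device_memory=4))  # True
```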

Thanks for your feedback. I'm open to further suggestions or clarifications if needed.

@SunMarc (Member) commented Oct 16, 2024

Nice explanation @Nech-C! Thanks for confirming!

> The constraint for allocating a module to a device is roughly module size + max layer size <= device memory. The aforementioned issue focuses on lowering max layer size, and this PR focuses on lowering module size by looking for a smaller module in the module list.

I think that a quick solution to the max_layer size issue would be the following algorithm:

    1. Run infer_auto_device_map with max layer size = 0.
    2. Check whether we have offloaded layers.
       a) If not, we have our final device_map.
       b) Otherwise, redo the computation without removing the max layer size.

We can add your fallback option each time we run infer_auto_device_map if wanted.

This will help fix the following issue, which I have seen a couple of times:
The model has three non-splittable modules: A, B, and C, with sizes 2, 1, and 3, respectively. max_memory is defined as {0: 4, 1: 10, "cpu": 6}. With the current flow, the device_map will be {A: 1, B: 1, C: 1} since A + C = 5 > 4, whereas with the above algorithm, we will have {A: 0, B: 0, C: 1}. Your fallback option could help in the first iteration if A's size is 5; in that case, the end device_map would be {A: 1, B: 0, C: 0} instead of {A: 1, B: 1, C: 1}.
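
A rough sketch of that two-pass idea, assuming a hypothetical reserve_largest_layer flag; infer_auto_device_map has no such argument today, it only stands in for "run with max layer size = 0":

```python
from accelerate import infer_auto_device_map


def two_pass_device_map(model, max_memory):
    # Pass 1: pretend no memory needs to be reserved for the largest layer.
    # `reserve_largest_layer` is a hypothetical flag, not an existing argument.
    device_map = infer_auto_device_map(model, max_memory=max_memory, reserve_largest_layer=False)

    # If nothing spilled to CPU or disk, no offloading will happen, so the
    # largest-layer reservation was never needed: keep this map.
    if all(device not in ("cpu", "disk") for device in device_map.values()):
        return device_map

    # Pass 2: some modules are offloaded, so redo the computation with the
    # usual largest-layer reservation (today's default behavior).
    return infer_auto_device_map(model, max_memory=max_memory)
```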

Let me know what you think!

Nevertheless, I think it would be nice to first merge this PR before moving on to the max_layer size fix.

@Nech-C (Author) commented Oct 16, 2024

> Nice explanation @Nech-C! Thanks for confirming!
>
> > The constraint for allocating a module to a device is roughly module size + max layer size <= device memory. The aforementioned issue focuses on lowering max layer size, and this PR focuses on lowering module size by looking for a smaller module in the module list.
>
> I think that a quick solution to the max_layer size issue would be the following algorithm:
>
>     1. Run infer_auto_device_map with max layer size = 0.
>     2. Check whether we have offloaded layers.
>        a) If not, we have our final device_map.
>        b) Otherwise, redo the computation without removing the max layer size.
>
> We can add your fallback option each time we run infer_auto_device_map if wanted.

When running infer_auto_device_map, we can of course add the fallback option.

> This will help fix the following issue, which I have seen a couple of times: The model has three non-splittable modules: A, B, and C, with sizes 2, 1, and 3, respectively. max_memory is defined as {0: 4, 1: 10, "cpu": 6}. With the current flow, the device_map will be {A: 1, B: 1, C: 1} since A + C = 5 > 4, whereas with the above algorithm, we will have {A: 0, B: 0, C: 1}. Your fallback option could help in the first iteration if A's size is 5; in that case, the end device_map would be {A: 1, B: 0, C: 0} instead of {A: 1, B: 1, C: 1}.
>
> Let me know what you think!
>
> Nevertheless, I think it would be nice to first merge this PR before moving on to the max_layer size fix.

Ohhh, now I get it @SunMarc. Thanks for breaking it down. Working on the code really helped me understand your idea. TBH, I didn't fully understand it when you first mentioned it in the issue 😅. Your algorithm idea sounds solid. I'm on board with merging this PR first, then tackling the max_layer size fix.

And how should I proceed with the max_layer fix? Do I just open a new PR referencing the original issue, or do we need a new issue for this?

Also, just a heads up, I've got a couple of busy weeks coming up, so I may not be able to start working on this right away. But I'll definitely get to it as soon as I can.

Any tweaks you want me to make to this PR before we move on?

@SunMarc (Member) commented Oct 17, 2024

> Also, just a heads up, I've got a couple of busy weeks coming up, so I may not be able to start working on this right away. But I'll definitely get to it as soon as I can.
>
> Any tweaks you want me to make to this PR before we move on?

Sounds good! I'll try to review this today!
