Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make env variables optional for FSDP #2998

Merged
merged 10 commits into from
Aug 12, 2024
Merged

Make env variables optional for FSDP #2998

merged 10 commits into from
Aug 12, 2024

Conversation

muellerzr
Copy link
Collaborator

Remove env variable requirement for FSDP

What does this add?

This PR removes the need to lock the user into using accelerate launch when using FSDP + accelerate. Instead, the user can now fully create a FullyShardedDataParallelPlugin manually.

Who is it for?

Users of accelerate who want the flexibility of FSDP without needing accelerate launch

Closes #2973

Why is it needed?

Limiting users to using global variables long-term is not good because it hides much of what's going on in the env, with no way to pull them down easily.

What parts of the API does this impact?

User-facing:

A user can now use FullyShardedDataParallelPlugin manually:

from accelerate import FullyShardedDataParallelPlugin
plugin = FullyShardedDataParallelPlugin(
	strategy="FULL_SHARD"
)

Internal structure:

This was an entire refactor to the internals of the dataclass, allowing for every variable to be passed in, new documentation, and new checks/type hints towards what's happening underneath.

When would I use it, and when wouldn't I?

If a user wants to avoid using accelerate launch, or knows they just want FSDP and manually set it, this avoids needing (too many) hacky env variables

Anticipated maintenance burden? (What will happen in say, 3 months if something changes)

Still todo, we require users to still have the env variable set for doing efficient loading. I need to test it still, but I intend on making a util such as from accelerate.utils import enable_fsdp_low_cpu_mem to trigger this flag, which lets us then set and use it inside of transformers manually if a user passes it in/wants to utilize it without needing to set it first in the env.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice refactor @muellerzr ! This looks good on my side ! The only nit I have and maybe for future users is that we don't really know the default value for each of these param unless we dig inside the post_init ? Maybe it can make sense to tell the user which will be the default used if the variable is not set + env variable not set.

src/accelerate/utils/dataclasses.py Outdated Show resolved Hide resolved
src/accelerate/utils/dataclasses.py Outdated Show resolved Hide resolved
src/accelerate/utils/dataclasses.py Show resolved Hide resolved
Copy link
Member

@matthewdouglas matthewdouglas left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice rewrite, this is more readable and better documented than previously. I didn't check each and every line to verify the logic, but this should be covered by the tests. Still I have a few comments, but nothing major.

Btw. will you rewrite all of this to use pattern matching once Python 3.10 is the min version? ;-)

src/accelerate/utils/dataclasses.py Outdated Show resolved Hide resolved
default=None,
metadata={"help": "A callable specifying a policy to recursively wrap layers with FSDP"},
metadata={
"help": "A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about clarifying what is expected of the callable?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Best I can really get for this is to have users look at one of the FSDP versions for now (no docs discussing this really well)

src/accelerate/utils/dataclasses.py Outdated Show resolved Hide resolved
Comment on lines 1347 to 1357
self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", "NO_PREFETCH")
if isinstance(self.backward_prefetch, str):
if self.backward_prefetch.upper() == "NO_PREFETCH":
self.backward_prefetch = None
else:
if self.backward_prefetch in FSDP_BACKWARD_PREFETCH:
self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1
if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit():
self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch))
else:
self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be simplified or am I missing something?

Suggested change
self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", "NO_PREFETCH")
if isinstance(self.backward_prefetch, str):
if self.backward_prefetch.upper() == "NO_PREFETCH":
self.backward_prefetch = None
else:
if self.backward_prefetch in FSDP_BACKWARD_PREFETCH:
self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1
if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit():
self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch))
else:
self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()]
self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None)
if isinstance(self.backward_prefetch, str):
if self.backward_prefetch in FSDP_BACKWARD_PREFETCH:
self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1
if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit():
self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch))
else:
self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're missing the fact that NO_PREFETCH is in FSDP_BACKWARD_PREFETCH

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this should work?

            self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None)
                if isinstance(self.backward_prefetch, str) and (self.backward_prefetch != "NO_PREFETCH"):
	            if self.backward_prefetch in FSDP_BACKWARD_PREFETCH:
	                self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1
	            if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit():
	                self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch))
	            else:
	                self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()]

My main concern is to avoid setting self.backward_prefetch from None to "NO_PREFETCH" back to None, which is confusing.

self.state_dict_type = StateDictType(int(self.state_dict_type))
else:
self.state_dict_type = StateDictType[self.state_dict_type.upper()]
self.set_state_dict_type()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this do?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't like that we hid the rest of our earlier patterns behind set_state_dict_type because it did a lot more than that, it's more of a set_state_dict_type_and_config. It's far reaching to get rid of it now, so for now I reverted the refactor

src/accelerate/utils/dataclasses.py Outdated Show resolved Hide resolved
src/accelerate/utils/dataclasses.py Outdated Show resolved Hide resolved
Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just one nit about potentially simplifying some logic.

Comment on lines 1347 to 1357
self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", "NO_PREFETCH")
if isinstance(self.backward_prefetch, str):
if self.backward_prefetch.upper() == "NO_PREFETCH":
self.backward_prefetch = None
else:
if self.backward_prefetch in FSDP_BACKWARD_PREFETCH:
self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1
if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit():
self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch))
else:
self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this should work?

            self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None)
                if isinstance(self.backward_prefetch, str) and (self.backward_prefetch != "NO_PREFETCH"):
	            if self.backward_prefetch in FSDP_BACKWARD_PREFETCH:
	                self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1
	            if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit():
	                self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch))
	            else:
	                self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()]

My main concern is to avoid setting self.backward_prefetch from None to "NO_PREFETCH" back to None, which is confusing.

@muellerzr muellerzr merged commit 3bde615 into main Aug 12, 2024
28 checks passed
@muellerzr muellerzr deleted the make-dataclass-optional branch August 12, 2024 15:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
5 participants