-
Notifications
You must be signed in to change notification settings - Fork 955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make env variables optional for FSDP #2998
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice refactor @muellerzr ! This looks good on my side ! The only nit I have and maybe for future users is that we don't really know the default value for each of these param unless we dig inside the post_init
? Maybe it can make sense to tell the user which will be the default used if the variable is not set + env variable not set.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice rewrite, this is more readable and better documented than previously. I didn't check each and every line to verify the logic, but this should be covered by the tests. Still I have a few comments, but nothing major.
Btw. will you rewrite all of this to use pattern matching once Python 3.10 is the min version? ;-)
src/accelerate/utils/dataclasses.py
Outdated
default=None, | ||
metadata={"help": "A callable specifying a policy to recursively wrap layers with FSDP"}, | ||
metadata={ | ||
"help": "A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about clarifying what is expected of the callable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Best I can really get for this is to have users look at one of the FSDP versions for now (no docs discussing this really well)
src/accelerate/utils/dataclasses.py
Outdated
self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", "NO_PREFETCH") | ||
if isinstance(self.backward_prefetch, str): | ||
if self.backward_prefetch.upper() == "NO_PREFETCH": | ||
self.backward_prefetch = None | ||
else: | ||
if self.backward_prefetch in FSDP_BACKWARD_PREFETCH: | ||
self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1 | ||
if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit(): | ||
self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch)) | ||
else: | ||
self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this be simplified or am I missing something?
self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", "NO_PREFETCH") | |
if isinstance(self.backward_prefetch, str): | |
if self.backward_prefetch.upper() == "NO_PREFETCH": | |
self.backward_prefetch = None | |
else: | |
if self.backward_prefetch in FSDP_BACKWARD_PREFETCH: | |
self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1 | |
if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit(): | |
self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch)) | |
else: | |
self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()] | |
self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None) | |
if isinstance(self.backward_prefetch, str): | |
if self.backward_prefetch in FSDP_BACKWARD_PREFETCH: | |
self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1 | |
if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit(): | |
self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch)) | |
else: | |
self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're missing the fact that NO_PREFETCH
is in FSDP_BACKWARD_PREFETCH
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this should work?
self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None)
if isinstance(self.backward_prefetch, str) and (self.backward_prefetch != "NO_PREFETCH"):
if self.backward_prefetch in FSDP_BACKWARD_PREFETCH:
self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1
if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit():
self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch))
else:
self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()]
My main concern is to avoid setting self.backward_prefetch
from None
to "NO_PREFETCH"
back to None
, which is confusing.
self.state_dict_type = StateDictType(int(self.state_dict_type)) | ||
else: | ||
self.state_dict_type = StateDictType[self.state_dict_type.upper()] | ||
self.set_state_dict_type() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does this do?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't like that we hid the rest of our earlier patterns behind set_state_dict_type
because it did a lot more than that, it's more of a set_state_dict_type_and_config
. It's far reaching to get rid of it now, so for now I reverted the refactor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just one nit about potentially simplifying some logic.
src/accelerate/utils/dataclasses.py
Outdated
self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", "NO_PREFETCH") | ||
if isinstance(self.backward_prefetch, str): | ||
if self.backward_prefetch.upper() == "NO_PREFETCH": | ||
self.backward_prefetch = None | ||
else: | ||
if self.backward_prefetch in FSDP_BACKWARD_PREFETCH: | ||
self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1 | ||
if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit(): | ||
self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch)) | ||
else: | ||
self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this should work?
self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None)
if isinstance(self.backward_prefetch, str) and (self.backward_prefetch != "NO_PREFETCH"):
if self.backward_prefetch in FSDP_BACKWARD_PREFETCH:
self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1
if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit():
self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch))
else:
self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()]
My main concern is to avoid setting self.backward_prefetch
from None
to "NO_PREFETCH"
back to None
, which is confusing.
Remove env variable requirement for FSDP
What does this add?
This PR removes the need to lock the user into using
accelerate launch
when using FSDP + accelerate. Instead, the user can now fully create aFullyShardedDataParallelPlugin
manually.Who is it for?
Users of accelerate who want the flexibility of FSDP without needing
accelerate launch
Closes #2973
Why is it needed?
Limiting users to using global variables long-term is not good because it hides much of what's going on in the env, with no way to pull them down easily.
What parts of the API does this impact?
User-facing:
A user can now use
FullyShardedDataParallelPlugin
manually:Internal structure:
This was an entire refactor to the internals of the dataclass, allowing for every variable to be passed in, new documentation, and new checks/type hints towards what's happening underneath.
When would I use it, and when wouldn't I?
If a user wants to avoid using
accelerate launch
, or knows they just want FSDP and manually set it, this avoids needing (too many) hacky env variablesAnticipated maintenance burden? (What will happen in say, 3 months if something changes)
Still todo, we require users to still have the env variable set for doing efficient loading. I need to test it still, but I intend on making a util such as
from accelerate.utils import enable_fsdp_low_cpu_mem
to trigger this flag, which lets us then set and use it inside oftransformers
manually if a user passes it in/wants to utilize it without needing to set it first in the env.