Add default v5 flags #6168
Conversation
Do we need these for the 2.2 release?
It would be nice to include them so v5 is performant out-of-the-box on 2.2.
Thanks Jon for making this change. I have a few questions regarding some of the flags.
torch_xla/__init__.py
Outdated
'xla_enable_async_all_gather': 'true',
'xla_enable_async_collective_permute': 'true',
# Limit compiler-injected rematerialization
'xla_jf_rematerialization_percent_shared_memory_limit': '10000',
Do we need this? I don't think MaxText adds this one either.
I tried a run without all three flags mentioned, and the absolute MFU reduction was 1.5%. Including this one increases MFU by 1%. I have another run to identify which of the other two flags accounts for the other 0.5% gain.
torch_xla/__init__.py
Outdated
'xla_tpu_enable_async_collective_fusion_fuse_all_gather': 'true',
'xla_tpu_enable_async_collective_fusion_multiple_steps': 'true',
# Disable net router
'xla_tpu_enable_net_router_in_all_gather': 'false',
Should we drop this one?
torch_xla/__init__.py
Outdated
# Disable net router
'xla_tpu_enable_net_router_in_all_gather': 'false',
# Disable experimental Reduce+Broadcast->ReduceWindow-Conv fusion
'xla_tpu_rwb_fusion': 'false',
Does this really give us any performance boost?
BTW, do we need flags for MultiSlice?
Thanks @alanwaketan - let me try matching the MaxText flags exactly. A lot of these were recommended by the XLA team prior to your optimization barrier change, which may make them unnecessary.
I saw ~1.5% performance degradation when using the MaxText flags with the three specified flags omitted. I think the main driver is the rematerialization flag. I'm trying another run with the remat flag set.
Compared to the MaxText flags, the only MultiSlice-related flag set to a non-default value relates only to scanned compilation. I don't think we need to include any in the defaults here, though there can still be gains from tuning flags for a given workload's specific MultiSlice environment.
@jonb377 Have you figured out where the ~1.5% comes from?
I found the following:
I'll do one more test removing only |
This is not really a bug fix, so if we want to get this in for 2.2, I would like to get it merged ASAP so we have enough time for testing.
cc @alanwaketan, let's be conservative on the flags then and land today - I'll update this PR to only enable async collectives. There's also still an issue with async AG fusion in the current libtpu pin for v5p. Users can always specify the remaining flags themselves if needed.
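For reference, a minimal sketch of how a user could supply libtpu flags themselves before torch_xla initializes; this assumes the usual LIBTPU_INIT_ARGS environment-variable mechanism, and the flag values are taken from the diff above rather than from any final recommendation:

```python
# Sketch: manually supplying libtpu flags before torch_xla initializes.
# Assumes flags are passed to libtpu via the LIBTPU_INIT_ARGS environment variable.
import os

user_flags = [
    '--xla_enable_async_all_gather=true',
    '--xla_enable_async_collective_permute=true',
]
os.environ['LIBTPU_INIT_ARGS'] = ' '.join(
    [os.environ.get('LIBTPU_INIT_ARGS', '')] + user_flags).strip()

import torch_xla  # import only after the environment is configured
```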
Adds default values for TPU-specific XLA flags on v5e and v5p.
Verifying performance with a v5e run of Llama2 70B.
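A rough sketch of how such defaults might be applied in torch_xla/__init__.py without clobbering flags the user has already set; the helper name and the exact flag set here are illustrative assumptions, not necessarily the final code from this PR:

```python
import os

# Illustrative defaults (a subset of the flags discussed above).
_DEFAULT_TPU_FLAGS = {
    'xla_enable_async_all_gather': 'true',
    'xla_enable_async_collective_permute': 'true',
}


def _set_missing_tpu_flags(flags=_DEFAULT_TPU_FLAGS):
  """Append default libtpu flags, keeping any values the user already set."""
  current = os.environ.get('LIBTPU_INIT_ARGS', '')
  for name, value in flags.items():
    # Only inject a default if the user hasn't specified the flag themselves.
    if name not in current:
      current += f' --{name}={value}'
  os.environ['LIBTPU_INIT_ARGS'] = current.strip()
```

Merging into LIBTPU_INIT_ARGS rather than overwriting it keeps the escape hatch mentioned above: anything a user exports before import wins over the library defaults.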