
Add default v5 flags #6168

Merged (3 commits into master on Dec 18, 2023)
Conversation

jonb377 (Collaborator) commented Dec 14, 2023

Adds default values for TPU-specific XLA flags on v5e and v5p.

Verifying performance with a v5e run of Llama2 70B.

JackCaoG (Collaborator) commented:

do we need these for 2.2 release?

jonb377 (Collaborator, Author) commented Dec 14, 2023

> do we need these for 2.2 release?

It would be nice to include these so v5 is performant out of the box on 2.2.

alanwaketan (Collaborator) left a review:

Thanks Jon for making this change. I have a few questions regarding some of the flags.

'xla_enable_async_all_gather': 'true',
'xla_enable_async_collective_permute': 'true',
# Limit compiler-injected rematerialization
'xla_jf_rematerialization_percent_shared_memory_limit': '10000',
Review comment (Collaborator) on the snippet above:

Do we need this? I don't think MaxText adds this one either.

jonb377 (Collaborator, Author) replied:

I tried a run without all three flags mentioned, and the absolute MFU reduction was 1.5%. Including this one increases MFU by 1%. I have another run to identify which of the other two flags accounts for the other 0.5% gain.

'xla_tpu_enable_async_collective_fusion_fuse_all_gather': 'true',
'xla_tpu_enable_async_collective_fusion_multiple_steps': 'true',
# Disable net router
'xla_tpu_enable_net_router_in_all_gather': 'false',
Review comment (Collaborator) on the snippet above:

Should we drop this one?

# Disable net router
'xla_tpu_enable_net_router_in_all_gather': 'false',
# Disable experimental Reduce+Broadcast->ReduceWindow-Conv fusion
'xla_tpu_rwb_fusion': 'false',
Review comment (Collaborator) on the snippet above:

Does this really give us any performance boost?

alanwaketan (Collaborator) commented:

BTW, do we need flags for MultiSlice?

jonb377 (Collaborator, Author) commented Dec 15, 2023

Thanks @alanwaketan - let me try matching the MaxText flags exactly. A lot of these were recommended by the XLA team prior to your optimization barrier change, which may make them unnecessary.

jonb377 (Collaborator, Author) commented Dec 15, 2023

I saw ~1.5% performance degradation using the MaxText flags, i.e. when omitting the three flags called out above. I think the main driver is the rematerialization flag. Trying another run with the remat flag set.

jonb377 (Collaborator, Author) commented Dec 16, 2023

> BTW, do we need flags for MultiSlice?

Comparing against the MaxText flags, the only MultiSlice-related flag with a non-default value relates to scanned compilation. I don't think we need to include any MultiSlice flags in the defaults here, but there can still be gains from tuning flags for a given workload's specific MultiSlice environment.

alanwaketan (Collaborator) commented:

> I saw ~1.5% performance degradation using the MaxText flags and when omitting the three specified. I think the main driver of this is the rematerialization flag. Trying another run with the remat flag set.

@jonb377 Have you figured out where the ~1.5% comes from?

jonb377 (Collaborator, Author) commented Dec 18, 2023

I found the following:

  • Adding only xla_jf_rematerialization_percent_shared_memory_limit back yields 0.5% degradation over baseline
  • Removing only xla_tpu_enable_net_router_in_all_gather from the original set yields ~1.5% improvement over baseline
  • xla_tpu_rwb_fusion is inconclusive from the tests I've run so far, but doesn't seem to do much.

I'll do one more test removing only xla_tpu_rwb_fusion.

JackCaoG (Collaborator) commented:

This is not really a bug fix. If we want to get this in for 2.2, I would like to get it merged ASAP so we have enough time for testing.

jonb377 (Collaborator, Author) commented Dec 18, 2023

cc @alanwaketan, let's be conservative on the flags then and land today - I'll update this PR to only enable async collectives. There's also still an issue with async AG fusion in the current libtpu pin for v5p. Users can always specify LIBTPU_INIT_ARGS for the best perf.
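
Overriding the defaults this way amounts to setting the environment variable before libtpu is initialized. A minimal sketch of the mechanism (the flag and value shown are only an example of the syntax, not a tuning recommendation):

```python
import os

# Set workload-specific XLA flags before the TPU runtime initializes;
# the specific flag/value below is purely illustrative.
os.environ['LIBTPU_INIT_ARGS'] = ' '.join([
    '--xla_jf_rematerialization_percent_shared_memory_limit=10000',
])
print(os.environ['LIBTPU_INIT_ARGS'])
```

This must run before the process touches the TPU, since libtpu reads the variable once at initialization.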

@jonb377 jonb377 merged commit 2e6e183 into master Dec 18, 2023
@jonb377 jonb377 deleted the jonbolin/v5-flags branch December 18, 2023 23:43
jonb377 added a commit that referenced this pull request Dec 18, 2023
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Jan 3, 2024
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024