Skip to content

Commit

Permalink
Add default v5 flags (pytorch#6168)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 authored and mbzomowski committed Jan 3, 2024
1 parent d44cc2f commit a7c7a5f
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ def _setup_libtpu_flags():
# and thus worse performance.
flags = _set_missing_flags(flags,
(('xla_latency_hiding_scheduler_rerun', '1'),))

if tpu.version() == 5:
default_v5_flags = {
# Enable async collectives
'xla_enable_async_all_gather': 'true',
'xla_enable_async_collective_permute': 'true',
}
flags = _set_missing_flags(flags, default_v5_flags.items())

os.environ['LIBTPU_INIT_ARGS'] = ' '.join(flags)


Expand Down

0 comments on commit a7c7a5f

Please sign in to comment.