[SPMD] Add API to disable global SPMD config #8717

Open · wants to merge 1 commit into master from lsiyuan/disable-spmd

Conversation

@lsy323 (Collaborator) commented on Feb 18, 2025:

Before this PR, xr.use_spmd() set an irreversible global SPMD state in torch_xla (e.g. the user cannot access individual devices after use_spmd(), which sets a "virtual device" for the SPMD code path). This one-time SPMD setting limits the flexibility of having some code regions run in SPMD mode and others in non-SPMD mode.

This PR relaxes the above constraint by:

  • Introducing a new API, disable_spmd(), to revert the global SPMD setting made by use_spmd() (see the usage sketch below).
  • Adding an option to the current use_spmd() logic to not replicate all live non-SPMD tensors on the virtual device. This keeps normal tensors on their designated devices after use_spmd() is called.
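
A minimal usage sketch of the intended flow. The disable_spmd() name and the force_tensors_on_spmd_device option come from this PR; the exact placement of disable_spmd() under torch_xla.runtime is an assumption:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Non-SPMD region: tensors live on their individual physical devices.
t = torch.ones(4, 4, device=xm.xla_device())

# Enter SPMD mode without replicating existing live tensors
# (force_tensors_on_spmd_device is the option added in this PR).
xr.use_spmd(force_tensors_on_spmd_device=False)
# ... code running under SPMD mode ...

# New in this PR: revert the global SPMD setting so that
# subsequent code runs in non-SPMD mode again.
xr.disable_spmd()
```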

Implementation notes:

  • In the current implementation, the device information is stored in static variables whose values depend on the UseVirtualDevice() query. Since the global device state changes as the user switches between SPMD and non-SPMD mode, those values need to change accordingly (example1, example2).
  • In the current implementation, _xla_force_spmd_device does two things: 1) it moves all live non-SPMD tensors onto the virtual device, and 2) it sets the global SPMD config. This PR splits the logic of 2) into a new API, _set_spmd_mode(bool use_spmd), to manage the global SPMD config (a sketch of the resulting wiring follows below).
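
A hedged sketch of how the Python-side wiring might look after the split. The _xla_force_spmd_device and _set_spmd_mode binding names come from this PR's description; the exact call sites and final signatures are assumptions:

```python
from typing import Optional

import torch_xla


def use_spmd(auto: Optional[bool] = False,
             force_tensors_on_spmd_device: Optional[bool] = False):
  if force_tensors_on_spmd_device:
    # 1) Move all live non-SPMD tensors onto the virtual device
    #    (the pre-existing behavior of _xla_force_spmd_device).
    torch_xla._XLAC._xla_force_spmd_device()
  # 2) Flip only the global SPMD config, leaving other live tensors
  #    on their designated devices (the new _set_spmd_mode API).
  torch_xla._XLAC._set_spmd_mode(True)


def disable_spmd():
  # New in this PR: revert the global SPMD config.
  torch_xla._XLAC._set_spmd_mode(False)
```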

Test:

  • Add a test that switches between SPMD and non-SPMD mode, checking that no unexpected data transfers happen in between (the metrics-based check is sketched below).
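
A hedged sketch of the kind of metrics check the test relies on. torch_xla.debug.metrics and the TransferToDeviceTime metric are existing torch_xla facilities (the metric appears in the reviewed excerpt below); the surrounding assertions are assumptions:

```python
import torch_xla.debug.metrics as met

# The first element of metric_data() is the number of recorded samples,
# i.e. how many host-to-device transfers have happened so far.
transfers_before = met.metric_data('TransferToDeviceTime')[0]

# ... switch between SPMD and non-SPMD mode and run a step ...

# Mode switching should not trigger extra transfers for tensors
# that stay on their designated device.
assert met.metric_data('TransferToDeviceTime')[0] == transfers_before
```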

@lsy323 force-pushed the lsiyuan/disable-spmd branch from 359bbab to de27158 on February 19, 2025 00:28
@lsy323 lsy323 marked this pull request as ready for review February 19, 2025 00:54
Reviewed test excerpt:

```python
met.metric_data('TransferToDeviceTime')[0],
expected_transfer_to_device_counter[i])
spmd_output = self._run_spmd(spmd_model, spmd_input_shape, mesh)
spmd_outputs.append(spmd_output)
```
Collaborator:

Is there any interaction between the SPMD and non-SPMD parts?

Collaborator (author):

No, I cannot do that, since we cannot create a global tensor from device shards until #8716 is merged. But I did test them together locally, and it works.

Seems we should land #8716 first.

```diff
 # TODO(yeounoh) introduce SPMD configuration.
-def use_spmd(auto: Optional[bool] = False):
+def use_spmd(auto: Optional[bool] = False,
+             force_tensors_on_spmd_device: Optional[bool] = False):
```
Collaborator:

nit: the word "force" is meaningless here IMO. I think a more descriptive name could be replicate_existing_tensors.

Collaborator (on the same use_spmd() signature change):

If force_tensors_on_spmd_device defaults to False, would this end up being a backward-compatibility-breaking change? IIUC, we used to always replicate existing tensors unconditionally.

On the use_spmd() docstring:

```python
"""API to enable SPMD mode. This is a recommended way to enable SPMD.

This forces SPMD mode if some tensors are already initialized on non-SPMD
```
Collaborator:

This comment is probably out of date. Also, I think whoever wrote it had in mind the idea of forcefully enabling SPMD mode, as if that were an aggressive act. But with your PR there's nothing forceful about this anymore: IIUC, people now have a clear option of replicating the existing tensors, or not replicating them (and keeping them on non-SPMD devices).

```diff
-  if (!g_current_device) {
-    g_current_device = *GetDefaultDevice();
-  }
+  g_current_device = *GetDefaultDevice();
```
Collaborator:

If we're always overwriting g_current_device, I'm wondering whether it's possible to get rid of this global variable?

```diff
@@ -78,19 +78,19 @@ torch::lazy::BackendDevice GetVirtualDevice() {
 }
 
 bool ShouldUseVirtualDevice() {
-  bool use_virtual_device =
+  bool g_use_virtual_device =
```
Collaborator:

This g_use_virtual_device variable shadows a global variable of the same name.

@lsy323 (Collaborator, author) commented on Feb 24, 2025:

test_runtime_spmd_api is failing; I did some investigation and here are the findings:

AtenXlaDeviceMapper has two states: a) under non-SPMD mode, it contains all local devices (e.g. TPU:0, TPU:1, ...); b) under SPMD mode, it contains SPMD:0.

In this PR, we switch between the two states as use_spmd and disable_spmd are called. However, in the current use_spmd logic, existing live tensors are moved onto the SPMD virtual device: the BackendDataHandle is moved to the SPMD virtual device, but the underlying backend device in the LazyTensor cannot be updated due to a const reference.

In a concrete example:

  1. Create some tensors under non-SPMD mode; they have device XLA:0, and AtenXlaDeviceMapper is initialized with devices XLA:0, XLA:1, ...
  2. use_spmd is called and AtenXlaDeviceMapper is switched to state b); the LazyTensor state of the non-SPMD tensors is still XLA:0, while the underlying BackendDataHandle is on SPMD:0.

Currently, this works because physical devices remain in AtenXlaDeviceMapper even if SPMD is turned on after AtenXlaDeviceMapper is initialized. This doesn't seem to be the expected behavior, but it works for now.

My conclusion from the above:

Our current implementation relies on the state of AtenXlaDeviceMapper not changing once it is initialized, which contradicts the scenario of switching between SPMD and non-SPMD mode. To move forward, I think we need to make use_spmd work properly with a stateful AtenXlaDeviceMapper.
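
A minimal repro sketch of the mismatch described above, based on the concrete example; the exact device strings depend on the hardware and are assumptions:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# A tensor created in non-SPMD mode lives on a physical device, e.g. xla:0,
# and AtenXlaDeviceMapper is initialized with the local physical devices.
t = torch.ones(2, 2, device=xm.xla_device())
print(t.device)  # xla:0

# Entering SPMD mode moves the tensor's BackendDataHandle to the virtual
# SPMD:0 device, but the LazyTensor still reports its original device.
xr.use_spmd()
print(t.device)  # still xla:0, even though the data handle is on SPMD:0
```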

Comment on lines +11 to +12:

```python
import torch_xla.runtime as xr
local_bs = 4096 * 8
```
Collaborator:

NIT: It is clearer if there is a blank line between the imports and the global variables.
