[SPMD] Add API to disable global SPMD config #8717
base: master
Conversation
Force-pushed from 359bbab to de27158.
met.metric_data('TransferToDeviceTime')[0],
expected_transfer_to_device_counter[i])
spmd_output = self._run_spmd(spmd_model, spmd_input_shape, mesh)
spmd_outputs.append(spmd_output)
Is there any interaction between the SPMD and non-SPMD parts?
# TODO(yeounoh) introduce SPMD configuration.
def use_spmd(auto: Optional[bool] = False):
def use_spmd(auto: Optional[bool] = False,
             force_tensors_on_spmd_device: Optional[bool] = False):
nit: the word "force" is meaningless here IMO. I think a more descriptive name could be `replicate_existing_tensors`.
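For concreteness, a minimal sketch of what the suggested signature might look like; the parameter name `replicate_existing_tensors` is only the reviewer's proposal, not the API in this PR.

```python
from typing import Optional

# Hypothetical signature following the rename suggestion above; the parameter
# name replicate_existing_tensors is a proposal, not the merged API.
def use_spmd(auto: Optional[bool] = False,
             replicate_existing_tensors: Optional[bool] = False):
  ...
```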
# TODO(yeounoh) introduce SPMD configuration.
def use_spmd(auto: Optional[bool] = False):
def use_spmd(auto: Optional[bool] = False,
             force_tensors_on_spmd_device: Optional[bool] = False):
If `force_tensors_on_spmd_device` defaults to False, would this end up being a backward-compatibility-breaking change? IIUC we used to always replicate existing tensors unconditionally.
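A small illustration of the compatibility concern being raised, as a hedged sketch: `force_tensors_on_spmd_device` is the parameter under review in this PR, and the exact old/new replication behavior is as described in this discussion.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# A tensor created before enabling SPMD lives on a regular (non-SPMD) device.
t = torch.ones(8, 8).to(xm.xla_device())

# Old behavior: use_spmd() always replicated existing tensors onto the SPMD
# virtual device. With the new default (force_tensors_on_spmd_device=False),
# t stays where it is, which is the potential behavior change questioned above.
xr.use_spmd()

# Opting back into the old behavior would require the explicit flag:
# xr.use_spmd(force_tensors_on_spmd_device=True)
```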
# TODO(yeounoh) introduce SPMD configuration.
def use_spmd(auto: Optional[bool] = False):
def use_spmd(auto: Optional[bool] = False,
             force_tensors_on_spmd_device: Optional[bool] = False):
  """API to enable SPMD mode. This is a recommended way to enable SPMD.

  This forces SPMD mode if some tensors are already initialized on non-SPMD
This comment is probably out of date. Also I think whoever wrote this comment had in mind an idea of forcefully enabling SPMD mode, like that's an aggressive act. But with your PR there's nothing forceful about this anymore. IIUC people now have a clear option of replicating the existing tensors, or not replicating them (and keeping them on non-SPMD devices).
if (!g_current_device) {
  g_current_device = *GetDefaultDevice();
}
g_current_device = *GetDefaultDevice();
If we're always overwriting `g_current_device`, I'm wondering whether it's possible to get rid of this global variable?
@@ -78,19 +78,19 @@ torch::lazy::BackendDevice GetVirtualDevice() {
}

bool ShouldUseVirtualDevice() {
bool use_virtual_device =
bool g_use_virtual_device =
This `g_use_virtual_device` variable shadows a global variable of the same name.
In this PR, we switch the 2 states as […] In a concrete example: […]
Currently, it works by having physical devices in […]
My conclusion from the above: our current implementation is based on the fact that the state of […]
import torch_xla.runtime as xr
local_bs = 4096 * 8
NIT: it is clearer if there is a blank line between the imports and the global variables.
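A minimal sketch of the suggested layout, using the values from the snippet above:

```python
import torch_xla.runtime as xr

# A blank line separates the imports from the module-level globals.
local_bs = 4096 * 8
```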
Before this PR, `xr.use_spmd()` set an un-revertible global SPMD state in torch_xla (e.g. the user cannot access devices after `use_spmd`, which sets a "Virtual Device" for the SPMD code path). This one-time SPMD setting limits the flexibility of having some code regions run under SPMD mode while others run in non-SPMD mode.

This PR relaxes the above constraint by:

- Adding `disable_spmd()` to revert the global SPMD setting from `use_spmd`.
- Changing the `use_spmd()` logic to not replicate all the non-SPMD live tensors on the virtual device. This keeps normal tensors on their designated device after `use_spmd()` is called.

Implementation notes:

- The `UseVirtualDevice()` querying function: since the global device state will change as the user switches between SPMD and non-SPMD mode, those values need to change accordingly (example1, example2).
- `_set_spmd_mode(bool use_spmd)` to manage the global SPMD config.

Test:
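A minimal end-to-end sketch of the workflow this PR enables; `disable_spmd()` is the new API proposed here, and the mesh and sharding calls are only illustrative.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Non-SPMD region: this tensor stays on its regular physical device,
# even after use_spmd() is called (per the new default behavior in this PR).
t = torch.randn(4, 4).to(xm.xla_device())

# Enter SPMD mode and shard a tensor across the device mesh.
xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (n, 1), ('data', 'model'))
x = torch.randn(4096, 4096).to(xm.xla_device())
xs.mark_sharding(x, mesh, ('data', 'model'))

# Leave SPMD mode (new API in this PR) and continue on non-SPMD devices.
xr.disable_spmd()
```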