
Introduce virtual device #4091

Merged: 1 commit into master on Nov 19, 2022
Conversation

@steventk-g (Collaborator) commented on Oct 12, 2022

Changes in this PR

  • Expose the virtual device via a flag, XLA_USE_SPMD.
  • Use the virtual device to conditionally delay a tensor's data transfer. This works by setting the device on the backend based on the flag, so that the PJRT computation client can check for the virtual device before transferring data (see the sketch after this list). Note: we need to check the device on the backend rather than checking the flag directly, so that we still have a way to transfer sharded data later on.
  • When the flag is enabled, transfer sharded data without re-downloading it from an XLA device. This is done in _xla_mark_sharding. The re-downloading path is preserved so that XLA_USE_SPMD=0 still works.
  • When the flag is enabled, ensure that the user gets xla:0 from xm.xla_device(). At this point, users should expect all tensors to be treated as if they are on the virtual device whenever SPMD is enabled.
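
A minimal sketch of the backend-side flag check described above. The helper names (UseVirtualDevice, GetDefaultDevice) and the "SPMD:0" device string are illustrative assumptions, not necessarily the PR's actual symbols:

    // Sketch only: derive the default device from the XLA_USE_SPMD flag.
    // Helper names here are hypothetical.
    #include <cstdlib>
    #include <string>

    bool UseVirtualDevice() {
      // Read the flag once at backend initialization so all later device
      // queries agree with it.
      const char* flag = std::getenv("XLA_USE_SPMD");
      return flag != nullptr && std::string(flag) == "1";
    }

    std::string GetDefaultDevice(const std::string& physical_device) {
      // With SPMD enabled, everything is addressed through a single virtual
      // device (surfaced to users as xla:0 by xm.xla_device()); otherwise
      // the physical device is used.
      return UseVirtualDevice() ? "SPMD:0" : physical_device;
    }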

@yeounoh (Contributor) commented on Oct 12, 2022

Let's make sure that we cover the explicitly sharded cases, where we want to avoid the initial unpartitioned data transfer. We will have to double-check, but "Modify XLATensor::Compile to begin data transfer on implicitly sharded tensors" may not be needed.

@steventk-g (Collaborator, Author)

Notes after chat with Yeounoh:

  • We need to locate the place where data transfer to the backend device is initiated; this is probably in upstream code. That is where we can check the device type and potentially skip the transfer.
  • We need to determine how to check the device type of an at::Tensor or XLATensor. The XlaDeviceType of tensors to shard will be "SPMD", while the torch device type stays XLA, just as for physical XLA devices (TPU, CPU, GPU). See the sketch after this list.
  • Explicitly sharded tensors on an SPMD device will be transferred to the backend device by a call to CreateTensorData in _xla_mark_sharding.
  • We need to decide when to transfer data for implicitly sharded tensors, if not in XLATensor::Compile.
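
A sketch of the device-type check described in these notes. The enum below is a simplified stand-in for the real XlaDeviceType in the codebase, so treat the exact values as assumptions:

    // Sketch: tensors whose XlaDeviceType is SPMD stay on the virtual
    // device until explicitly sharded; physical backends transfer eagerly.
    enum class XlaDeviceType { CPU, GPU, TPU, SPMD };

    bool ShouldDeferTransfer(XlaDeviceType hw_type) {
      // Virtual-device tensors skip the eager (unpartitioned) transfer;
      // _xla_mark_sharding later moves them via CreateTensorData.
      return hw_type == XlaDeviceType::SPMD;
    }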

@JackCaoG (Collaborator)

Re "We need to locate the place where data transfer is initiated to the backend device": add a log to

std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(

This is the only entry point for transferring data to a device (see the sketch below).
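
A sketch of that instrumentation as a standalone helper; in the PR it would live inside PjRtComputationClient::TransferToServer, and std::cerr stands in for whatever logging facility the codebase actually uses:

    #include <cstddef>
    #include <iostream>

    // Log every host-to-device transfer at its single entry point, so we
    // can see exactly when (and how often) data transfer is initiated.
    void LogTransferToServer(std::size_t num_tensors) {
      std::cerr << "TransferToServer called with " << num_tensors
                << " tensor(s)" << std::endl;
    }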

@steventk-g force-pushed the virtual-device branch 14 times, most recently from 6a326ec to 32ff407, on October 18, 2022
@steventk-g (Collaborator, Author)

Remaining implementation details before I can start testing:

  • Determine how to filter tensors and devices in the CreateTensorsData methods. We want all devices passed into the sharded method to be SPMD, and we want to stop data transfer to backend devices when a tensor with an SPMD device is passed into the non-sharded method. Can we simply remove the non-SPMD tensors in the first case and the SPMD tensors in the second? (See the split sketch after this list.)
  • Figure out what to return from TensorToXlaData when we don't transfer data to a real backend device.
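
A sketch of the split asked about in the first bullet. The types and names (TensorAndDevice, IsSpmdDevice, SplitBySpmd) are simplified stand-ins, not the real CreateTensorsData signatures:

    #include <string>
    #include <utility>
    #include <vector>

    struct TensorAndDevice {
      // Placeholder for an at::Tensor paired with its target device string.
      std::string device;
    };

    bool IsSpmdDevice(const std::string& device) {
      // True when the device string starts with "SPMD".
      return device.rfind("SPMD", 0) == 0;
    }

    // Route SPMD-device tensors to the sharded path and everything else to
    // the regular path, so the non-sharded method never initiates a backend
    // transfer for an SPMD tensor.
    std::pair<std::vector<TensorAndDevice>, std::vector<TensorAndDevice>>
    SplitBySpmd(const std::vector<TensorAndDevice>& inputs) {
      std::vector<TensorAndDevice> spmd, regular;
      for (const auto& t : inputs) {
        (IsSpmdDevice(t.device) ? spmd : regular).push_back(t);
      }
      return {std::move(spmd), std::move(regular)};
    }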

@steventk-g force-pushed the virtual-device branch 8 times, most recently from f64c1a5 to 01fa14d, on October 26, 2022
@steventk-g force-pushed the virtual-device branch 14 times, most recently from df56ba8 to ede9395, on November 16, 2022
@steventk-g force-pushed the virtual-device branch 3 times, most recently from 17cf0a0 to 5b7ac6f, on November 18, 2022
@steventk-g requested a review from yeounoh on November 18, 2022
@yeounoh (Contributor) left a comment:


LGTM, thank you @steventk-g 👍

@steventk-g merged commit b2bd721 into master on Nov 19, 2022
@yeounoh added the arm label on Dec 21, 2022