Placeholder tensors #8612

Open
tengyifei opened this issue Jan 22, 2025 · 3 comments
Labels: enhancement (New feature or request)
@tengyifei (Collaborator):

🚀 Feature

This proposal is to create a variant of torch.empty(...) that does not allocate memory on the accelerator, here called a placeholder tensor. Reading the data of this tensor will crash. But it should still be possible to abstractly stage out an HLO from an IR graph by feeding this tensor as an input.
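A rough interface sketch of how such a placeholder could be used; make_placeholder is a hypothetical name for illustration, not an existing torch or torch_xla API:

  # Hypothetical interface sketch; `make_placeholder` does not exist today.
  import torch
  import torch_xla.core.xla_model as xm

  device = xm.xla_device()

  # Carries only shape/dtype/device metadata; allocates no device buffer.
  p = make_placeholder((1024, 1024), dtype=torch.bfloat16, device=device)

  # Tracing ops on the placeholder should build IR/HLO as usual...
  y = torch.ops.aten.t.default(p)

  # ...but materializing its data (e.g. y.cpu()) is expected to crash,
  # since there is no backing buffer on the device.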

Motivation

When comparing the metrics of a model run with a for loop versus one run with scan, the scan version transfers far more data to the device: 23 GiB compared to 182 MiB for the for loop. The IrValueTensorToXlaData metric has a large count (>200) under scan compared to 0 under the for loop. The profile timeline also shows additional tensors being sent to the TPU during LazyTensor tracing.
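For reference, a minimal sketch of how these numbers can be collected from torch_xla's debug metrics; run_one_step_with_scan is a stand-in for the actual training step:

  import torch_xla.debug.metrics as met

  met.clear_metrics()
  met.clear_counters()

  run_one_step_with_scan()  # stand-in for the actual model step

  # The full report lists IrValueTensorToXlaData alongside transfer sizes.
  print(met.metrics_report())
  # Or query the single metric directly.
  print(met.metric_data('IrValueTensorToXlaData'))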

When I list out the operations that cause a tensor to be sent to the device, specifically via the IrValueTensorToXlaData path, it appears that any graph operation that touches an input tensor (e.g. primals, tangents) causes a tensor to be sent to the device: see 3. For example, the t = torch.ops.aten.t.default(primals_2) operation transposes primals_2, and that likely causes the data backing primals_2 to be sent to the device.

I have narrowed this down to the make_fake_tensor 4 function in scan, which creates device data inputs that are used to trace the combine function:

  def make_fake_tensor(v: torch.Tensor) -> torch.Tensor:
    return torch.empty(
        v.size(), dtype=v.dtype).to(device).requires_grad_(v.requires_grad)

I believe this function associates a CPU tensor with the result as its backing data, but the resulting tensor does not have an IR value. Later, when some operation (such as torch.ops.aten.t) uses the tensor, its IR value is required. PyTorch/XLA then creates the IR value by transferring the backing data tensor to the device and minting a DeviceData IR node referencing the uploaded data. This behavior is documented in Virtual Devices for GSPMD on PyTorch/XLA 1.
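A rough repro sketch of that sequence, assuming an XLA device is available (exact metric values depend on the build):

  import torch
  import torch_xla.core.xla_model as xm
  import torch_xla.debug.metrics as met

  device = xm.xla_device()

  # Same pattern as make_fake_tensor: a CPU tensor re-homed onto the XLA
  # device. At this point it is only backed by host data; no IR value yet.
  x = torch.empty(1024, 1024, dtype=torch.bfloat16).to(device)

  met.clear_metrics()

  # Any op that needs x's IR value (here a transpose) should force the
  # backing host tensor to be uploaded, showing up as IrValueTensorToXlaData.
  y = torch.ops.aten.t.default(x)
  print(met.metric_data('IrValueTensorToXlaData'))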

The end result is that even if we later throw away all the outputs of the combine function after extracting its HLO, simply creating the inputs needed to run the combine function still has the side effect of sending many empty tensors to the device. These tensors are not zero-sized, despite what the name torch.empty suggests. This slows down training and should be avoided. For now I'll attempt to work around this by caching the HLO computation corresponding to the combine function. At least this way we only trace and send the unnecessary data once, after which the HLO is reused in subsequent scan invocations involving the same combine function.
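A minimal sketch of that caching workaround, assuming a hypothetical trace_combine_fn_to_hlo helper that stages out the combine function's HLO:

  from typing import Callable, Dict, Tuple

  _hlo_cache: Dict[Tuple, object] = {}

  def cached_combine_hlo(fn: Callable, input_specs: Tuple) -> object:
    """Trace `fn` at most once per (function, input shapes/dtypes) key."""
    key = (id(fn), input_specs)
    if key not in _hlo_cache:
      # The unwanted transfers still happen here, but only on the first call.
      _hlo_cache[key] = trace_combine_fn_to_hlo(fn, input_specs)  # hypothetical
    return _hlo_cache[key]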

This is not a problem in JAX, since scan in JAX traces the combine function abstractly, i.e. abstract tensors are passed to the function instead of concrete tensors holding device data. PyTorch/XLA currently doesn't have the concept of an abstract tensor, so all lazy-tensor-based tracing operations are prone to side effects, from accidentally sending data to the device to accidentally materializing the tensor.
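For contrast, a small JAX example: tracing against jax.ShapeDtypeStruct specs via jax.eval_shape never touches device memory, which is the behavior we'd like scan's tracing path to approximate:

  import jax
  import jax.numpy as jnp

  def combine(carry, x):
    return carry + x, carry * x

  carry_spec = jax.ShapeDtypeStruct((128, 128), jnp.float32)
  x_spec = jax.ShapeDtypeStruct((128, 128), jnp.float32)

  # Runs `combine` on abstract values only; no buffers are allocated or
  # transferred, and we still learn the output shapes/dtypes.
  print(jax.eval_shape(combine, carry_spec, x_spec))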

In Llama 405B, the extra transfers become significant enough to OOM the TPU just from staging out the HLO of the combine function in scan. We'll need to find a representation of these input tensors used during tracing such that they are lowered into HLO parameters but don't transfer data to the device.

Pitch

On Dec 20, 2024, I got some suggestions from jackcao@: we could create a tensor without a data handle, also called a "placeholder". That tensor can still be lowered, and it won't eagerly allocate buffers on the device. When I implemented this suggestion, it turned out that LoweringContext uses the data handle to deduplicate parameters, so we still need the data handle or at least some kind of unique identifier. For now I worked around this by falling back to the torch::lazy::BackendData pointer when the BackendData object does not hold a handle. It's unknown whether this approach has holes; it probably does and needs to be properly fixed.
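An illustrative Python sketch of that fallback (the real logic lives in the C++ LoweringContext; backend_data and its handle attribute here are stand-ins for torch::lazy::BackendData):

  def parameter_dedup_key(backend_data):
    # Real device data: the opaque data handle uniquely identifies the buffer.
    handle = getattr(backend_data, 'handle', None)
    if handle is not None:
      return handle
    # Placeholder data holds no handle; fall back to object identity so the
    # same placeholder still deduplicates to a single HLO parameter.
    return id(backend_data)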

@tengyifei (Collaborator, Author):

@rpsilva-aws is interested in taking a stab

@rpsilva-aws (Collaborator):

I won't be able to get to this in the upcoming weeks (at least this week and the next) due to other priorities, so if there is another tentative owner in the meantime, feel free to reassign. Otherwise, I'll provide an updated status by then. cc: @tengyifei @miladm

@tengyifei (Collaborator, Author):

@zpcore let's check out this issue, since it unblocks scan and host offloading. In short, scan needs to get the HLO of the input function without allocating memory on the TPU (e.g. via torch.empty).

zpcore assigned zpcore and unassigned rpsilva-aws on Feb 4, 2025
ysiraichi added the enhancement (New feature or request) label on Feb 6, 2025