🚀 Feature
This proposal is to create a variant of torch.empty(...) that does not allocate memory on the accelerator, here called a placeholder tensor. Reading the data of such a tensor will crash, but it should still be possible to abstractly stage out an HLO from an IR graph by feeding it in as an input.
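To make the intended semantics concrete, here is a hypothetical usage sketch; make_placeholder is an illustrative name for the proposed constructor, not an existing torch or torch_xla API:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Hypothetical: behaves like torch.empty(..., device=device) during lazy
# tracing, but never allocates a buffer on the accelerator.
a = make_placeholder((1024, 1024), dtype=torch.bfloat16, device=device)

b = a.t() + 1.0  # Fine: only IR nodes are built; `a` lowers to an HLO parameter.
# b.cpu()        # Crashes: forcing execution would require reading `a`'s data.
```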
Motivation
When comparing the metrics of a model run using a for loop vs. run using scan, the scan version transfers a lot more data to the device: 23 GiB compared to 182 MiB with for loops. The IrValueTensorToXlaData metric has a large count (>200) under scan compared to 0 under the for loop. The profile timeline also shows additional tensors being sent to the TPU during LazyTensor tracing.
When I list out the operations that cause a tensor to be sent to the device, specifically via the IrValueTensorToXlaData path, it appears that any graph operation that touches an input tensor (e.g. primals, tangents) causes a tensor to be sent to the device (see [3]). For example, the t = torch.ops.aten.t.default(primals_2) operation transposes primals_2, and that likely causes the data backing primals_2 to be sent to the device.
I have narrowed this down to the make_fake_tensor function [4] in scan, which creates the device data inputs that are used to trace the combine function.
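For context, a rough sketch of what make_fake_tensor amounts to, reconstructed from the behavior described here rather than copied from the source:

```python
import torch

def make_fake_tensor(v: torch.Tensor, requires_grad: bool = True) -> torch.Tensor:
  # Mirror the shape/dtype/device of the real input. The result is only meant
  # to serve as a tracing input to the combine function, yet its
  # (uninitialized) storage still ends up being uploaded to the device once
  # an operation forces an IR value for it.
  return torch.empty(
      v.size(), dtype=v.dtype, device=v.device, requires_grad=requires_grad)
```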
I believe this function somehow associates a CPU tensor with the result as its backing data, but the resulting tensor does not have an IR value. Later, when some operation (such as torch.ops.aten.t) uses the tensor, its IR value is required. PyTorch/XLA then creates the IR value by transferring the data tensor to the device and minting a DeviceData IR node referencing the data tensor. This behavior is documented in Virtual Devices for GSPMD on PyTorch/XLA [1].
The end result is that even if we later throw away all the outputs of the combine function after extracting its HLO, simply creating the inputs to run the combine function still has the side effect of sending lots of empty tensors to the device. The sizes of these tensors do not appear to be zero, despite what the name torch.empty suggests. This slows down training and should be avoided. For now I'll attempt to work around this by caching the HLO computation corresponding to the combine function. At least this way we only trace and send unnecessary data once, after which the HLO is reused in subsequent scan invocations involving the same combine function.
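A minimal sketch of that caching workaround, where stage_combine_fn_hlo is a hypothetical stand-in for whatever traces the combine function and extracts its HLO:

```python
from typing import Callable, Dict, Tuple

import torch

# Cache of staged HLO computations, keyed by the combine function and the
# shapes/dtypes of its inputs.
_HLO_CACHE: Dict[Tuple, object] = {}

def cached_combine_hlo(combine_fn: Callable, inputs: Tuple[torch.Tensor, ...]):
  key = (combine_fn, tuple((tuple(t.shape), t.dtype) for t in inputs))
  if key not in _HLO_CACHE:
    # Tracing here still pays the unwanted device transfers, but only once
    # per (function, input signature) pair; later scan calls reuse the HLO.
    _HLO_CACHE[key] = stage_combine_fn_hlo(combine_fn, inputs)  # hypothetical helper
  return _HLO_CACHE[key]
```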
This is not a problem in JAX since scan in JAX traces the combine function abstractly, i.e. abstract tensors are passed to the function instead of concrete tensors holding device data. PyTorch/XLA currently doesn't have the concept of an abstract tensor, so all lazy-tensor based tracing operations are prone to side effects, ranging from accidentally sending data to the device to accidentally materializing the tensor.
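For comparison, JAX can stage out a scan purely from abstract shape/dtype descriptions, so no input buffers ever exist on the device:

```python
import jax
import jax.numpy as jnp

def combine(carry, x):
  # Arbitrary per-step computation.
  return carry + x, carry * x

def run_scan(init, xs):
  return jax.lax.scan(combine, init, xs)

# Abstract inputs: shapes and dtypes only, no device buffers behind them.
init = jax.ShapeDtypeStruct((128,), jnp.float32)
xs = jax.ShapeDtypeStruct((10, 128), jnp.float32)

# Stages out the whole scan abstractly; nothing is allocated on or
# transferred to the device.
out = jax.eval_shape(run_scan, init, xs)
print(out)
```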
In Llama 405B, the extra transfers become significant enough to OOM the TPU just from staging out the HLO of the combine function in scan. We'll need to find a representation of these input tensors used during tracing such that they are lowered into HLO parameters but don't transfer data to the device.
Pitch
On Dec 20, 2024, I got a suggestion from jackcao@ that we could create a tensor without a data handle, also called a "placeholder". That tensor can still be lowered, and it won't eagerly allocate buffers on the device. When I implemented this suggestion, it turned out that LoweringContext uses the data handle to deduplicate parameters, so it appears we still need the data handle or at least some kind of unique identifier. For now I worked around this by falling back to the torch::lazy::BackendData pointer when the BackendData object does not hold a handle. It is unknown whether this approach has holes; it probably does and needs to be fixed properly.
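For illustration, the fallback amounts to something like the following, written as Python pseudocode even though the real deduplication lives in the C++ LoweringContext; the attribute names are assumptions:

```python
def parameter_dedup_key(backend_data) -> tuple:
  # Placeholder tensors have no device handle, so fall back to the identity
  # of the BackendData object itself as a unique key.
  handle = getattr(backend_data, "handle", None)  # assumed attribute name
  if handle is not None:
    return ("handle", handle)
  return ("object", id(backend_data))
```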
I won't be able to get to this in the upcoming weeks (at least this week and the next) due to other priorities, so if there is another tentative owner in the meantime, feel free to reassign. Otherwise, I'll provide an updated status by then. cc: @tengyifei @miladm
@zpcore let's check out this issue since it unblocks scan and host offloading. In short, scan needs to get the HLO of the input function without allocating memory on the TPU (e.g. torch.empty)