Trace user_fn with placeholder tensors to reduce device memory usage
This way we don't allocate those input tensors on the device just to
stage out an HLO.
tengyifei committed Dec 23, 2024
1 parent 0219443 commit 8158d94
Showing 5 changed files with 56 additions and 14 deletions.
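In short, inputs that exist only to be traced are now modeled as buffer-less device data rather than real allocations. A condensed sketch of the mechanism, using the Python API this commit introduces (the real call site is in the scan.py hunk below):

import torch
import torch_xla
import torch_xla.core.xla_builder as xb

def make_placeholder(v: torch.Tensor) -> torch.Tensor:
  # Build an XLA shape from the tensor's metadata, then create a tensor
  # whose backing data has no device buffer. Tracing through it stages
  # out HLO without any host-to-device transfer.
  dtype = xb.Op.from_torch_type(v.dtype)
  shape = xb.mkshape(dtype, v.shape)
  return torch_xla._XLAC._xla_create_placeholder_tensor(shape.shape)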
19 changes: 19 additions & 0 deletions torch_xla/core/xla_builder.py
@@ -38,6 +38,21 @@ class Type:
    Type.PRED: torch.bool,
}

+_PT_XLA_TYPE_MAP = {
+    torch.float32: Type.F32,
+    torch.float64: Type.F64,
+    torch.bfloat16: Type.BF16,
+    torch.float16: Type.F16,
+    torch.uint8: Type.U8,
+    torch.int8: Type.S8,
+    torch.int16: Type.S16,
+    torch.int32: Type.S32,
+    torch.int64: Type.S64,
+    torch.complex64: Type.C64,
+    torch.complex128: Type.C128,
+    torch.bool: Type.PRED,
+}
+

class Shape(object):
  """Wraps a core XLA shape object to provide a more friendly API."""
@@ -749,6 +764,10 @@ def map(cls, ops, computation, dimensions, static_operands=(), builder=None):
  def to_torch_type(cls, dtype):
    return _XLA_PT_TYPE_MAP[dtype] if dtype else torch.float32

+  @classmethod
+  def from_torch_type(cls, dtype):
+    return _PT_XLA_TYPE_MAP[dtype]
+

def create_builder(name):
  return torch_xla._XLAC._xla_op_create_builder(name)
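Since _PT_XLA_TYPE_MAP is the inverse of the existing _XLA_PT_TYPE_MAP, the two classmethods round-trip for every mapped dtype. A quick illustration (hypothetical snippet, not part of the diff):

import torch
import torch_xla.core.xla_builder as xb

xla_type = xb.Op.from_torch_type(torch.bfloat16)  # Type.BF16
assert xb.Op.to_torch_type(xla_type) is torch.bfloat16

Note one asymmetry: to_torch_type falls back to torch.float32 when given a falsy dtype, while from_torch_type raises KeyError for any dtype outside the map.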
28 changes: 25 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1148,13 +1148,19 @@ class PyLoweringContext {

    // Convert lazy node data into opaque handle id
    torch::lazy::BackendDataPtr data = DeviceData::Cast(node)->data();
-    torch::lazy::BackendData::Handle handle = data->GetHandle();
+    torch::lazy::BackendData::Handle handle =
+        data->HasValue() ? data->GetHandle()
+                         : reinterpret_cast<std::uintptr_t>(&*data);

    // Linearly search parameters and compare opaque handles
    const std::vector<torch::lazy::BackendDataPtr>& device_data =
        lowering_ctx.GetParametersData();
    for (int i = 0; i < device_data.size(); ++i) {
-      if (device_data[i]->GetHandle() == handle) {
+      torch::lazy::BackendData::Handle device_handle =
+          device_data[i]->HasValue()
+              ? device_data[i]->GetHandle()
+              : reinterpret_cast<std::uintptr_t>(&*device_data[i]);
+      if (device_handle == handle) {
        std::optional param_id = lowering_ctx.GetParameterId(device_data[i]);
        XLA_CHECK(param_id.has_value());
        return param_id.value();
@@ -1843,6 +1849,19 @@ void InitXlaModuleBindings(py::module m) {
        return bridge::AtenFromXlaTensor(output);
      });

+  // Creates a placeholder tensor that does not hold any device buffer.
+  // This is primarily useful for staging out the HLO of a user computation.
+  // Accessing the value of the tensor will panic.
+  //
+  // TODO: write tests.
+  m.def("_xla_create_placeholder_tensor", [](py::object py_shape) {
+    xla::Shape shape = op_builder::PyShapeToShape(py_shape);
+    auto xla_tensor = XLATensor::Create(
+        torch_xla::runtime::GetComputationClient()->CreateDataPlaceholder(
+            bridge::GetCurrentDevice().toString(), std::move(shape)));
+    return bridge::AtenFromXlaTensor(xla_tensor);
+  });
+
  m.def("_xla_set_default_device", [](const std::string& device) {
    return SetCurrentThreadDevice(device);
  });
@@ -2884,7 +2903,10 @@ void InitXlaModuleBindings(py::module m) {
        }

        // Dedup by handle
-        torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
+        torch::lazy::BackendData::Handle handle =
+            backend_data->HasValue()
+                ? backend_data->GetHandle()
+                : reinterpret_cast<std::uintptr_t>(&*backend_data);
        if (!data_handles.insert(handle).second) {
          continue;
        }
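To make the new binding concrete, here is how it might be exercised from Python (a hedged sketch, not part of the commit; the shape-wrapper calls mirror the scan.py hunk below):

import torch_xla
import torch_xla.core.xla_builder as xb

# Wrap an f32[4,8] XLA shape and create a tensor with no device buffer.
shape = xb.mkshape(xb.Type.F32, [4, 8])
t = torch_xla._XLAC._xla_create_placeholder_tensor(shape.shape)

# Shape and dtype metadata are available without materializing anything,
# but reading the values (e.g. t.cpu()) should hit the panic documented
# in the comment above.
print(t.shape, t.dtype)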
10 changes: 8 additions & 2 deletions torch_xla/csrc/lowering_context.cpp
@@ -113,7 +113,10 @@ static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();
xla::XlaOp LoweringContext::GetParameter(
    const std::shared_ptr<torch::lazy::BackendData>& backend_data,
    const std::unordered_set<uint32_t>& unbounded_dynamic_dims) {
-  torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
+  torch::lazy::BackendData::Handle handle =
+      backend_data->HasValue()
+          ? backend_data->GetHandle()
+          : reinterpret_cast<std::uintptr_t>(&*backend_data);
  auto it = parameters_map_.find(handle);
  if (it == parameters_map_.end()) {
    auto data = std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
@@ -147,7 +150,10 @@ xla::XlaOp LoweringContext::GetParameter(

std::optional<size_t> LoweringContext::GetParameterId(
    const std::shared_ptr<torch::lazy::BackendData>& backend_data) const {
-  torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
+  torch::lazy::BackendData::Handle handle =
+      backend_data->HasValue()
+          ? backend_data->GetHandle()
+          : reinterpret_cast<std::uintptr_t>(&*backend_data);
  auto it = parameters_map_.find(handle);
  if (it == parameters_map_.end()) {
    return std::nullopt;
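The same fallback appears three times in this commit (the PyLoweringContext parameter search, the dedup loop above, and the two lookups in this file): a placeholder's BackendData has no runtime buffer, so its object address stands in for the opaque handle as a unique map key. In Python terms the keying strategy is roughly the following (illustrative pseudocode; has_value and get_handle are stand-ins for the C++ methods HasValue and GetHandle):

def param_key(backend_data) -> int:
  # Real device data exposes a unique opaque runtime handle. A placeholder
  # has no buffer, so its object identity serves as the key instead, the
  # Python analogue of reinterpret_cast<std::uintptr_t>(&*backend_data).
  if backend_data.has_value():
    return backend_data.get_handle()
  return id(backend_data)

The key stays unique for as long as the placeholder object is alive, which is presumably guaranteed for the duration of a single trace.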
2 changes: 0 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -29,8 +29,6 @@ class PjRtComputationClient : public ComputationClient {
  PjRtComputationClient();
  ~PjRtComputationClient();

-  // TODO: buffer ptr is null.
-
  DataPtr CreateDataPlaceholder(
      std::string device, xla::Shape shape,
      std::optional<xla::OpSharding> sharding = std::nullopt) override;
11 changes: 4 additions & 7 deletions torch_xla/experimental/scan.py
@@ -672,13 +672,10 @@ def fn(carry, x):
  # Abstractly trace and lower `fn`.
  # Later we will include `fn_computation` within the while loop body.
  def make_fake_tensor(v: torch.Tensor) -> torch.Tensor:
-    # TODO: this will transfer data to the device when `fn` is traced with
-    # such a tensor. However, if we tried to avoid that via
-    # `torch.empty(v.size(), dtype=v.dtype, device=v.device)`, then that lowers
-    # into a constant IR node instead of a device data IR node.
-    return torch.empty(
-        v.size(),
-        dtype=v.dtype).to(device=v.device).requires_grad_(v.requires_grad)
+    dtype = xb.Op.from_torch_type(v.dtype)
+    shape = xb.mkshape(dtype, v.shape)
+    t = torch_xla._XLAC._xla_create_placeholder_tensor(shape.shape)
+    return t.requires_grad_(v.requires_grad)

  device = torch_xla.device()
  fake_carry = tree_map(make_fake_tensor, init)
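The deleted TODO records the constraint this change sidesteps: torch.empty(..., device=...) lowers to a constant IR node, and building on host then calling .to(device) uploads a real buffer. The placeholder keeps a device-data IR node with nothing behind it. A small check of the intended behavior (hypothetical, reusing make_fake_tensor from the hunk above):

import torch
import torch_xla

v = torch.ones(3, 5, dtype=torch.bfloat16, requires_grad=True)
fake = make_fake_tensor(v)

# The fake carries the same trace-relevant metadata as the original,
# but no device allocation happened and reading its values errors out.
assert fake.shape == v.shape and fake.dtype == v.dtype
assert fake.requires_grad == v.requires_grad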
