diff --git a/OP_LOWERING_GUIDE.md b/OP_LOWERING_GUIDE.md index b8ef3ca4804..ac13d8c75e2 100644 --- a/OP_LOWERING_GUIDE.md +++ b/OP_LOWERING_GUIDE.md @@ -4,7 +4,11 @@ PyTorch wraps the C++ ATen tensor library that offers a wide range of operations implemented on GPU and CPU. Pytorch/XLA is a PyTorch extension; one of its purposes is to convert PyTorch operations to XLA operations. Lowering defines a process of converting a higher-level representation to a lower-level representation. In this document, I will refer to the process of converting PyTorch operation to XLA operation as the lowering. XLA Compiler will also lower XlaOp to HLO, but that’s beyond the scope of this documentation. We will forward operations that we haven’t provided an XLA lowering yet to CPU and call ATen implementations. Operations that are forwarded to the CPU will cause a significant slowdown. We must lower all operations used in the model to achieve the best performance. ## Before you start -You should follow the instructions in [here](https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md) to install required dependencies and build pytorch and pytorch/XLA from the source. You do not need access to TPU to implement the lowering. It is recommended to experiment on a workstation and configure it to use XLA:CPU. +You should follow the instructions in [here](https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md) to install required dependencies and build pytorch and pytorch/XLA from the source. You do not need access to TPU to implement the lowering. It is recommended to experiment on a workstation and configure it to use XLA:CPU. You can configure Pytorch/XLA to use XLA:CPU by running + +``` +export XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0" XRT_WORKERS="localservice:0;grpc://localhost:51011" +``` ## Understanding the operation You can find the definition of the C++ ATen operations in [native_functions.yaml](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml). After you build Pytorch/XLA from source, you will also find our default implementation (forward to PyTorch native CPU) in `xla/torch_xla/csrc/aten_xla_type_default.h/cpp`. Pytorch operations can usually be mapped to [PyTorch tensor api](https://pytorch.org/docs/stable/index.html) easily. If that is not the case searching the PyTorch native implementation under [PyTorch repo](https://github.com/pytorch/pytorch) is recommended. The goal is to lower the PyTorch operations into a sequence of XLA operations defined in [here](https://www.tensorflow.org/xla/operation_semantics). @@ -24,4 +28,55 @@ All file mentioned below lives under the `xla/torch_xla/csrc` folder, with the e Our CircleCI runs PyTorch native python tests for every change and every day. Those tests will use XLA implementation if we provide a lowering. We usually don’t need to add additional python tests for PyTorch/XLA unless we want to verify some xla behaviors(like dynamic shape) or we skipped the pytorch native test for some reason. The python test should be added to `xla/test/test_operations.py` if it is required. We also need to add CPP tests in `xla/test/cpp/test_aten_xla_tensor.cpp`. This test should call PyTorch c++ API and verify our implementation yields the same result as PyTorch native implementation. We also need to verify if the xla implementation is called when the tensor is a XLA tensor by checking the `aten::op` and `xla::op` counters. ## Tips -The process of lowering is breaking down the PyTorch operations into a sequence of XlaOp. To provide a good lowering of the PyTorch operation, one needs to have a good grasp of what XLA is capable of. Reading the XlaOp document and looking into how similar ops is lowered is the best way to achieve that. You can find a minimal Op lowering example in [this pr](https://github.com/pytorch/xla/pull/2969). You can also find a slightly more complicated example with backward lowering in [this pr](https://github.com/pytorch/xla/pull/1940). +The process of lowering is breaking down the PyTorch operations into a sequence of XlaOp. To provide a good lowering of the PyTorch operation, one needs to have a good grasp of what XLA is capable of. Reading the XlaOp document and looking into how similar ops is lowered is the best way to achieve that. You can find a minimal Op lowering example in [this pr](https://github.com/pytorch/xla/pull/2969). You can also find a slightly more complicated example with backward lowering in [this pr](https://github.com/pytorch/xla/pull/2972). + +We have auto-generated wrapper implementations of `out=` and `inplace` operators for some operators in `RegisterXLA.cpp`. We only need to lower the vanilla op in this case. An example would be `lerp` operator which has 6 variants in `native_functions.yaml`, they are + +``` + - lerp_.Scalar + - lerp_.Tensor + - lerp.Scalar_out + - lerp.Tensor_out + - lerp.Scalar + - lerp.Tensor +``` + +and will generate function prototypes + +``` +at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight); +at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight); +at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight); +at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out); +at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Tensor & weight); +at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out); +``` + +in `XLANativeFunctions.h` if we add all of them to the `xla_native_functions.yaml`. However if we only lower `lerp.Scalar` and `lerp.Tensor` and check `RegisterXLA.cpp`, we will see + +``` +namespace { + +at::Tensor wrapper_Scalar_lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) { + // No device check + + + // DeviceGuard omitted + return torch_xla::lerp(self, end, weight); +} + +} // anonymous namespace + +at::Tensor & wrapper_Scalar_lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) { + auto wrapper_Scalar_lerp__tmp = wrapper_Scalar_lerp(self, end, weight); + at::_copy_from(wrapper_Scalar_lerp__tmp, self); + return self; +} + +... + m.impl("lerp_.Scalar", + TORCH_FN(wrapper_Scalar_lerp_)); + +``` + +`lerp_.Scalar` will use our `lerp.Scalar` implementation without us providing explictly lowering. diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index 8f9f39a78d2..f6c5e4e2555 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -10285,5 +10285,114 @@ TEST_F(AtenXlaTensorTest, TestEarlySyncLiveTensors) { cpp_test::GetIgnoredCounters()); } +TEST_F(AtenXlaTensorTest, TestLerp) { + torch::Tensor start = + torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor weight = + torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor res = torch::lerp(start, end, weight); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_start = CopyToDevice(start, device); + torch::Tensor xla_end = CopyToDevice(end, device); + torch::Tensor xla_weight = CopyToDevice(weight, device); + torch::Tensor xla_res = torch::lerp(xla_start, xla_end, xla_weight); + AllClose(res, xla_res); + }); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters()); +} + +TEST_F(AtenXlaTensorTest, TestLerpScalar) { + torch::Tensor start = + torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Scalar weight = torch::Scalar(3.0); + torch::Tensor res = torch::lerp(start, end, weight); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_start = CopyToDevice(start, device); + torch::Tensor xla_end = CopyToDevice(end, device); + torch::Tensor xla_res = torch::lerp(xla_start, xla_end, weight); + AllClose(res, xla_res); + }); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters()); +} + +TEST_F(AtenXlaTensorTest, TestLerpInplace) { + torch::Tensor input = + torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor weight = + torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor input_copy = input.clone(); + input.lerp_(end, weight); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_input = CopyToDevice(input_copy, device); + torch::Tensor xla_end = CopyToDevice(end, device); + torch::Tensor xla_weight = CopyToDevice(weight, device); + xla_input.lerp_(xla_end, xla_weight); + AllClose(xla_input, input); + }); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters()); +} + +TEST_F(AtenXlaTensorTest, TestLerpScalarInplace) { + torch::Tensor input = + torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Scalar weight = torch::Scalar(3.0); + torch::Tensor input_copy = input.clone(); + input.lerp_(end, weight); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_input = CopyToDevice(input_copy, device); + torch::Tensor xla_end = CopyToDevice(end, device); + xla_input.lerp_(xla_end, weight); + AllClose(xla_input, input); + }); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters()); +} + +TEST_F(AtenXlaTensorTest, TestLerpOut) { + torch::Tensor start = + torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor weight = + torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor res = torch::empty({3, 4}, torch::TensorOptions(torch::kFloat)); + ; + torch::lerp_out(res, start, end, weight); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_start = CopyToDevice(start, device); + torch::Tensor xla_end = CopyToDevice(end, device); + torch::Tensor xla_weight = CopyToDevice(weight, device); + torch::Tensor xla_res = torch::empty({3, 4}, xla_start.options()); + torch::lerp_out(xla_res, xla_start, xla_end, xla_weight); + AllClose(res, xla_res); + }); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters()); +} + +TEST_F(AtenXlaTensorTest, TestLerpScalarOut) { + torch::Tensor start = + torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Tensor end = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::Scalar weight = torch::Scalar(3.0); + torch::Tensor res = torch::empty({3, 4}, torch::TensorOptions(torch::kFloat)); + torch::lerp_out(res, start, end, weight); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_start = CopyToDevice(start, device); + torch::Tensor xla_end = CopyToDevice(end, device); + torch::Tensor xla_res = torch::empty({3, 4}, xla_start.options()); + torch::lerp_out(xla_res, xla_start, xla_end, weight); + AllClose(res, xla_res); + }); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::lerp", cpp_test::GetIgnoredCounters()); +} + } // namespace cpp_test } // namespace torch_xla diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 4652c26628e..c28e39d2111 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1519,6 +1519,30 @@ at::Tensor leaky_relu_backward(const at::Tensor& grad_output, negative_slope.to())); } +at::Tensor lerp(const at::Tensor& self, const at::Tensor& end, + const at::Tensor& weight) { + XLA_FN_COUNTER("xla::"); + XLA_CHECK_EQ(self.dtype(), end.dtype()) + << "expected dtype " << self.dtype() << " for `end` but got dtype " + << end.dtype(); + XLA_CHECK_EQ(self.dtype(), weight.dtype()) + << "expected dtype " << self.dtype() << " for `weight` but got dtype " + << weight.dtype(); + return bridge::AtenFromXlaTensor( + XLATensor::lerp(bridge::GetXlaTensor(self), bridge::GetXlaTensor(end), + bridge::GetXlaTensor(weight))); +} + +at::Tensor lerp(const at::Tensor& self, const at::Tensor& end, + const at::Scalar& weight) { + XLA_FN_COUNTER("xla::"); + XLA_CHECK_EQ(self.dtype(), end.dtype()) + << "expected dtype " << self.dtype() << " for `end` but got dtype " + << end.dtype(); + return bridge::AtenFromXlaTensor(XLATensor::lerp( + bridge::GetXlaTensor(self), bridge::GetXlaTensor(end), weight)); +} + at::Tensor log(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::log(bridge::GetXlaTensor(self))); diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index 08d1d4c7399..9649cbd0be4 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -738,6 +738,11 @@ NodePtr BaddBmm(const Value& lhs, const Value& rhs, const Value& bias, std::move(lower_fn)); } +NodePtr Lerp(const Value& start, const Value& end, const Value& weight) { + ScopePusher ir_scope(at::aten::lerp.toQualString()); + return start + weight * (end - start); +} + } // namespace ops } // namespace ir } // namespace torch_xla diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index 6d5add9a0fa..53c476e0153 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -212,6 +212,8 @@ NodePtr IsNan(const Value& input); NodePtr BaddBmm(const Value& lhs, const Value& rhs, const Value& bias, const Value& product_multiplier, const Value& bias_multiplier); +NodePtr Lerp(const Value& start, const Value& end, const Value& weight); + } // namespace ops } // namespace ir } // namespace torch_xla diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index abd9dd3d616..16747e8c8df 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -651,6 +651,11 @@ class XLATensor { const XLATensor& input, double negative_slope); + static XLATensor lerp(const XLATensor& input, const XLATensor& end, + const XLATensor& weight); + static XLATensor lerp(const XLATensor& input, const XLATensor& end, + const at::Scalar& weight); + static XLATensor log(const XLATensor& input); static XLATensor log_base(const XLATensor& input, ir::OpKind op, double base); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 20bb8ab9ce0..c62dea6672a 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -1458,6 +1458,20 @@ XLATensor XLATensor::leaky_relu_backward(const XLATensor& grad_output, grad_output.GetIrValue(), input.GetIrValue(), negative_slope)); } +XLATensor XLATensor::lerp(const XLATensor& input, const XLATensor& end, + const XLATensor& weight) { + return input.CreateFrom( + ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight.GetIrValue())); +} + +XLATensor XLATensor::lerp(const XLATensor& input, const XLATensor& end, + const at::Scalar& weight) { + ir::Value weight_val = GetIrValueForScalar( + weight, input.shape().get().element_type(), input.GetDevice()); + return input.CreateFrom( + ir::ops::Lerp(input.GetIrValue(), end.GetIrValue(), weight_val)); +} + XLATensor XLATensor::log(const XLATensor& input) { return input.CreateFrom(ir::ops::Log(input.GetIrValue())); } diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml index 6d85368030b..7e5b35004be 100644 --- a/xla_native_functions.yaml +++ b/xla_native_functions.yaml @@ -307,6 +307,8 @@ supported: - sigmoid_backward - tanh_backward - ger + - lerp.Scalar + - lerp.Tensor autograd: - max_pool2d - max_pool3d