From 0a7cb901b432395154eb3595f7928c9142f268ac Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Mon, 10 Jan 2022 23:15:09 +0800 Subject: [PATCH 01/15] add retry on pull dense sync (#38793) --- paddle/fluid/framework/fleet/fleet_wrapper.cc | 56 ++++++++++++++++++- paddle/fluid/framework/fleet/heter_context.h | 1 - 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc index 225c2656fbfd1..f90027556342d 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.cc +++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc @@ -632,6 +632,7 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim, if (ret != 0) { LOG(ERROR) << "fleet pull sparse failed, status[" << ret << "]"; sleep(sleep_seconds_before_fail_exit_); + exit(-1); } #else for (size_t index = 0; index < inputs->size(); ++index) { @@ -685,9 +686,36 @@ void FleetWrapper::PullDenseVarsSync( paddle::ps::Region reg(w, tensor->numel()); regions.emplace_back(std::move(reg)); } - auto status = - pslib_ptr_->_worker_ptr->pull_dense(regions.data(), regions.size(), tid); - status.wait(); + int32_t status = -1; + int32_t cnt = 0; + while (true) { + auto tt = pslib_ptr_->_worker_ptr->pull_dense(regions.data(), + regions.size(), tid); + bool flag = true; + + tt.wait(); + + try { + status = tt.get(); + } catch (const std::future_error& e) { + VLOG(0) << "Caught a future_error with code" << e.code() + << ", Message:" << e.what(); + } + if (status != 0) { + VLOG(0) << "fleet pull dense sync failed, status[" << status << "]"; + sleep(sleep_seconds_before_fail_exit_); + flag = false; + cnt++; + } + if (cnt > 3) { + VLOG(0) << "fleet pull dense sync failed, retry 3 times"; + exit(-1); + } + + if (flag) { + break; + } + } #endif } @@ -1248,6 +1276,7 @@ void FleetWrapper::LoadModelOneTable(const uint64_t table_id, if (ret.get() != 0) { LOG(ERROR) << "load model of table id: " << table_id << ", from path: " << path << " failed"; + exit(-1); } #else VLOG(0) << "FleetWrapper::LoadModel does nothing when no pslib"; @@ -1263,6 +1292,7 @@ void FleetWrapper::LoadWithWhitelist(const uint64_t table_id, if (ret.get() != 0) { LOG(ERROR) << "load model of table id: " << table_id << ", from path: " << path << " failed"; + exit(-1); } #else VLOG(0) << "FleetWrapper::LoadWhitelist does nothing when no pslib"; @@ -1311,6 +1341,7 @@ void FleetWrapper::SaveModelOneTable(const uint64_t table_id, if (ret.get() != 0) { LOG(ERROR) << "save model of table id: " << table_id << ", to path: " << path << " failed"; + exit(-1); } #else VLOG(0) << "FleetWrapper::SaveModelOneTable does nothing when no pslib"; @@ -1328,6 +1359,7 @@ void FleetWrapper::SaveModelOneTablePrefix(const uint64_t table_id, if (ret.get() != 0) { LOG(ERROR) << "save model (with prefix) of table id: " << table_id << ", to path: " << path << " failed"; + exit(-1); } #else VLOG(0) << "FleetWrapper::SaveModelOneTablePrefix does nothing when no pslib"; @@ -1351,6 +1383,7 @@ void FleetWrapper::SetDate(const uint64_t table_id, const std::string& date) { ret.wait(); if (ret.get() != 0) { LOG(ERROR) << "setdate : " << date << " failed"; + exit(-1); } #else VLOG(0) << "FleetWrapper::SetDate does nothing when no pslib-gpu"; @@ -1463,6 +1496,11 @@ void FleetWrapper::ShrinkSparseTable(int table_id) { #ifdef PADDLE_WITH_PSLIB auto ret = pslib_ptr_->_worker_ptr->shrink(table_id); ret.wait(); + int32_t err_code = ret.get(); + if (err_code == -1) { + LOG(ERROR) << "Shrink Sparse Table failed"; + exit(-1); + } #else VLOG(0) << "FleetWrapper::ShrinkSparseTable does nothing when no pslib"; #endif @@ -1472,6 +1510,10 @@ void FleetWrapper::ClearModel() { #ifdef PADDLE_WITH_PSLIB auto ret = pslib_ptr_->_worker_ptr->clear(); ret.wait(); + int32_t err_code = ret.get(); + if (err_code == -1) { + LOG(ERROR) << "Clear Model failed"; + } #else VLOG(0) << "FleetWrapper::ClearModel does nothing when no pslib"; #endif @@ -1481,6 +1523,10 @@ void FleetWrapper::ClearOneTable(const uint64_t table_id) { #ifdef PADDLE_WITH_PSLIB auto ret = pslib_ptr_->_worker_ptr->clear(table_id); ret.wait(); + int32_t err_code = ret.get(); + if (err_code == -1) { + LOG(ERROR) << "Clear One Table failed table_id: " << table_id; + } #else VLOG(0) << "FleetWrapper::ClearOneTable does nothing when no pslib"; #endif @@ -1541,6 +1587,10 @@ void FleetWrapper::ClientFlush() { #ifdef PADDLE_WITH_PSLIB auto ret = pslib_ptr_->_worker_ptr->flush(); ret.wait(); + int32_t err_code = ret.get(); + if (err_code == -1) { + LOG(ERROR) << "Client Flush failed"; + } #else VLOG(0) << "FleetWrapper::ServerFlush does nothing when no pslib"; #endif diff --git a/paddle/fluid/framework/fleet/heter_context.h b/paddle/fluid/framework/fleet/heter_context.h index 45f9b04383944..3e8b0cfbc31f3 100644 --- a/paddle/fluid/framework/fleet/heter_context.h +++ b/paddle/fluid/framework/fleet/heter_context.h @@ -235,7 +235,6 @@ class HeterContext { } VLOG(3) << "heter_context unique keys with dynamic mf dimention"; } - for (std::thread& t : threads) { t.join(); } From ffbc2122afb24f5ec0a173283c78e11ad8cd9966 Mon Sep 17 00:00:00 2001 From: fengkuangxiaxia Date: Tue, 11 Jan 2022 10:43:01 +0800 Subject: [PATCH 02/15] roi_align fix (#38788) --- paddle/fluid/inference/tensorrt/op_teller.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 8504474168d53..878eef016e7d1 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -13,7 +13,9 @@ // limitations under the License. #include "paddle/fluid/inference/tensorrt/op_teller.h" + #include + #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/data_layout.h" @@ -1283,7 +1285,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, return false; } std::vector attrs{"pooled_height", "pooled_width", - "spatial_scale", "sampling_ratio"}; + "spatial_scale", "sampling_ratio", + "aligned"}; for (auto const attr : attrs) { if (!desc.HasAttr(attr)) return false; } From d368647112e3194298f521196a2ff3df453ec6be Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Tue, 11 Jan 2022 12:30:07 +0800 Subject: [PATCH 03/15] [Eager] fix some eager logic (#38576) * Rearranged Eager AutoCodeGen directory structure * Removed USE_OP in Eager AutoCodeGen * Enabled generation for Operators without Grad/Inputs/Outputs * Resolved operators without input * Fixed merge conflicts * Enabled Eager AutoCodeGen for 10+ more operators * Refactored Eager AutoCodeGen with more organized helper objects * Enabled Eager AutoCodeGen for operators with multiple OpBases * Adjusted Eager AutoCodeGen to Enable Passing Output Tensor as Input Argument * Handled Dispensable Inputs/Outputs in Eager AutoCodeGen * Adjusted function generation/call between Python-C API & Dygraph API * Synchronized auto-generated Python-C API with Dygraph Forward Functions * support more eager tensor api * fix merge compile error * fix compile error and fit develop code * support pure CPU * fix some logic error in eager_mode * support _varbase_creator in eager mode * Added safe_initialized interface to EagerTensor for use in processing dispensable inputs * for eager mode * refine * support multiple constructor for eager tensor * add place related code * polish code * specific randint with dtype of int64 * Support pure cpu test * eager logic * refine test in pure cpu * eager logic * eager logic * eager logic, test=develop * skip core.eager when in inference, test=develop * refine, test=develop * refine, test=develop * call RetainGrad after run forward kernel, test=develop * refine, test=develop * support dygraph util, meta, guard test * eager test case * support inference test * refine test and fix initializer failed * modify eagertensor patch method * add eagertensor.clear_grandint, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * support create varbase and fix retain grad error * call monkey_patch_varbase in _test_eager_guard, test=develop * fix windows error * split clear_gradient to clear_gradient and zero_grads, test=develop * refine, test=develop * refine, test=develop * support test_imperative_basic test in eager mode * remove additional log in variable.h * remove additional log in variable.h * remove additional code create in merge * eager * fix some eager logic, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * refine, test=develop Co-authored-by: jim19930609 Co-authored-by: JiabinYang <360788950@qq.com> --- .../eager/accumulation/accumulation_node.cc | 3 + .../eager/accumulation/accumulation_node.h | 2 +- paddle/fluid/eager/eager_tensor.h | 50 ++++++++-------- .../data_structure_tests/eager_tensor_test.cc | 2 +- .../performance_tests/benchmark_utils.cc | 8 +-- .../eager/tests/task_tests/generated_test.cc | 10 ++-- paddle/fluid/pybind/eager_method.cc | 58 +++++++++++++++---- paddle/fluid/pybind/eager_properties.cc | 28 +++++++-- .../fluid/dygraph/varbase_patch_methods.py | 32 +++++++--- 9 files changed, 134 insertions(+), 59 deletions(-) diff --git a/paddle/fluid/eager/accumulation/accumulation_node.cc b/paddle/fluid/eager/accumulation/accumulation_node.cc index 69628d9b40021..ed1146eed0fb0 100644 --- a/paddle/fluid/eager/accumulation/accumulation_node.cc +++ b/paddle/fluid/eager/accumulation/accumulation_node.cc @@ -28,6 +28,9 @@ static void CopyOrAddTensor(egr::EagerTensor* tensor, const egr::EagerTensor& t) { + if (t.Var().IsInitialized()) { + const_cast(&t)->SyncToTensor(); + } if (!tensor->defined() || !tensor->initialized()) { // Simply copy tensor->impl *tensor = t; diff --git a/paddle/fluid/eager/accumulation/accumulation_node.h b/paddle/fluid/eager/accumulation/accumulation_node.h index a2683db75e92c..9578924b783f5 100644 --- a/paddle/fluid/eager/accumulation/accumulation_node.h +++ b/paddle/fluid/eager/accumulation/accumulation_node.h @@ -32,7 +32,7 @@ class GradNodeAccumulation : public GradNodeBase { void RetainGrad( const std::function& hook); - egr::EagerTensor Grad() { return accumulated_grad; } + egr::EagerTensor* Grad() { return &accumulated_grad; } private: egr::EagerTensor accumulated_grad; diff --git a/paddle/fluid/eager/eager_tensor.h b/paddle/fluid/eager/eager_tensor.h index 0bcef2253f993..72fe5732e9620 100644 --- a/paddle/fluid/eager/eager_tensor.h +++ b/paddle/fluid/eager/eager_tensor.h @@ -239,8 +239,8 @@ class EagerTensor final { auto tensor_dense = std::dynamic_pointer_cast(tensor_->impl()); if (tensor_dense) { - paddle::experimental::MovesStorage(tensor_dense.get(), - framework_tensor); + paddle::experimental::SharesStorage(tensor_dense.get(), + framework_tensor); } else { PADDLE_THROW(paddle::platform::errors::Fatal( "Unrecognized egr::EagerTensor type, only " @@ -258,27 +258,23 @@ class EagerTensor final { /** Part 11: Sync paddle::framework::Variable with pten::Tensor **/ void SyncToTensor() { // Synchronize allocation only once. - if (!this->defined() || !this->initialized()) { - // TODO(jiabin): Support selected rows later. - if (var_.IsInitialized()) { - if (var_.IsType()) { - SetImplWithLegacyTensor(); - } else if (var_.IsType()) { - SetImplWithLegacyTensor(); - } else { - PADDLE_THROW(paddle::platform::errors::Fatal( - "Unable to fetch underlying tensor " - "from VarBase, only LoDTensor and " - "Tensor are supported for now")); - } + if (var_.IsInitialized()) { + if (var_.IsType()) { + SetImplWithLegacyTensor(); + } else if (var_.IsType()) { + SetImplWithLegacyTensor(); } else { - PADDLE_THROW(paddle::platform::errors::Fatal( - "Can not Sync EagerTensor %s whose paddle::framework::Variable is " - "not initialized!", - name())); + PADDLE_THROW( + paddle::platform::errors::Fatal("Unable to fetch underlying tensor " + "from VarBase, only LoDTensor and " + "Tensor are supported for now")); } + } else { + PADDLE_THROW(paddle::platform::errors::Fatal( + "Can not Sync EagerTensor %s whose paddle::framework::Variable is " + "not initialized!", + name())); } } @@ -296,8 +292,16 @@ class EagerTensor final { template void SetImplWithLegacyTensor() { const auto& framework_tensor = var_.Get(); - this->set_impl( - std::move(paddle::experimental::MakePtenDenseTensor(framework_tensor))); + if (this->initialized()) { + VLOG(8) << "Sync Var to initialized tensor for: " << name(); + paddle::experimental::ReMakePtenDenseTensor( + framework_tensor, + static_cast(this->impl().get())); + } else { + VLOG(8) << "Sync Var to uninitialized tensor for: " << name(); + this->set_impl(std::move( + paddle::experimental::MakePtenDenseTensor(framework_tensor))); + } var_.Clear(); } diff --git a/paddle/fluid/eager/tests/data_structure_tests/eager_tensor_test.cc b/paddle/fluid/eager/tests/data_structure_tests/eager_tensor_test.cc index a02f0bec456bf..84daf4eac4ce6 100644 --- a/paddle/fluid/eager/tests/data_structure_tests/eager_tensor_test.cc +++ b/paddle/fluid/eager/tests/data_structure_tests/eager_tensor_test.cc @@ -118,7 +118,7 @@ TEST(EagerTensor, MemberFunction) { CHECK_EQ(et3.Var().Get().data()[1], 10.0f); VLOG(6) << "SyncToTensor"; - CHECK(et3.initialized() == false); + CHECK(et3.initialized() == true); et3.SyncToTensor(); CHECK(et3.initialized() == true); VLOG(6) << "Check Tensor"; diff --git a/paddle/fluid/eager/tests/performance_tests/benchmark_utils.cc b/paddle/fluid/eager/tests/performance_tests/benchmark_utils.cc index baa99dc93c2dd..e05a63a69d002 100644 --- a/paddle/fluid/eager/tests/performance_tests/benchmark_utils.cc +++ b/paddle/fluid/eager/tests/performance_tests/benchmark_utils.cc @@ -87,8 +87,8 @@ void benchmark_eager_intermediate_matmul(const EagerTensor& X, // Examine Forward Grad (w.r.t max_num_runs = 2) eager_test::CompareVariableWithValue(input_tensor0, 16); // Examine Backward Grad (w.r.t max_num_runs = 2) - eager_test::CompareGradVariableWithValue(X, 16); - eager_test::CompareGradVariableWithValue(Y, 16); + eager_test::CompareGradTensorWithValue(X, 16); + eager_test::CompareGradTensorWithValue(Y, 16); } } @@ -121,8 +121,8 @@ void benchmark_eager_intermediate_mlp(const EagerTensor& X, eager_test::CompareVariableWithValue(Out, result["Out"]); // Examine Backward Grad (w.r.t max_num_runs = 2) - eager_test::CompareGradVariableWithValue(X, result["GradX"]); - eager_test::CompareGradVariableWithValue(Ws[0], result["GradW"]); + eager_test::CompareGradTensorWithValue(X, result["GradX"]); + eager_test::CompareGradTensorWithValue(Ws[0], result["GradW"]); } } diff --git a/paddle/fluid/eager/tests/task_tests/generated_test.cc b/paddle/fluid/eager/tests/task_tests/generated_test.cc index a06091247bf7a..b5ce9223f6c97 100644 --- a/paddle/fluid/eager/tests/task_tests/generated_test.cc +++ b/paddle/fluid/eager/tests/task_tests/generated_test.cc @@ -54,7 +54,7 @@ TEST(Generated, Sigmoid) { RunBackward(target_tensors, {}); VLOG(6) << "Finish Backward"; - eager_test::CompareGradVariableWithValue(tensor, 0.25); + eager_test::CompareGradTensorWithValue(tensor, 0.25); } TEST(Generated, Matmul_v2) { @@ -85,8 +85,8 @@ TEST(Generated, Matmul_v2) { std::vector target_tensors = {output_tensor}; RunBackward(target_tensors, {}); - eager_test::CompareGradVariableWithValue(X, 2.0 * 20); - eager_test::CompareGradVariableWithValue(Y, 3.0 * 4); + eager_test::CompareGradTensorWithValue(X, 2.0 * 20); + eager_test::CompareGradTensorWithValue(Y, 3.0 * 4); } TEST(Generated, ElementwiseAdd) { @@ -116,8 +116,8 @@ TEST(Generated, ElementwiseAdd) { std::vector target_tensors = {output_tensor}; RunBackward(target_tensors, {}); - eager_test::CompareGradVariableWithValue(X, 1.0); - eager_test::CompareGradVariableWithValue(Y, 1.0); + eager_test::CompareGradTensorWithValue(X, 1.0); + eager_test::CompareGradTensorWithValue(Y, 1.0); } } // namespace egr diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 7f131f9ccd742..c56fe5be4da69 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -35,7 +35,7 @@ limitations under the License. */ namespace paddle { namespace pybind { -extern PyTypeObject* pEagerTensorType; +extern PyTypeObject* p_eager_tensor_type; static PyObject* eager_tensor_method_numpy(EagerTensorObject* self, PyObject* args, PyObject* kwargs) { @@ -167,7 +167,7 @@ static PyObject* eager_tensor__clear_gradient(EagerTensorObject* self, EAGER_SYNC_TRY VLOG(4) << "ClearGradient " << self->eager_tensor.name(); - egr::EagerTensor grad; + egr::EagerTensor* grad; if (egr::egr_utils_api::IsLeafTensor(self->eager_tensor)) { // Add RetainGrad as PostHook to AccumulationNode std::shared_ptr grad_node = @@ -182,14 +182,14 @@ static PyObject* eager_tensor__clear_gradient(EagerTensorObject* self, grad = accumulation_grad_node->Grad(); } else { auto meta = egr::EagerUtils::unsafe_autograd_meta(self->eager_tensor); - grad = meta->Grad(); + grad = meta->MutableGrad(); } - if (grad.initialized()) { + if (grad->initialized()) { VLOG(4) << "Gradient of " << self->eager_tensor.name() << " is initialized, will be released."; auto dense_tensor = - std::dynamic_pointer_cast(grad.impl()); + std::dynamic_pointer_cast(grad->impl()); dense_tensor->release(); } Py_INCREF(Py_None); @@ -202,7 +202,6 @@ static PyObject* eager_tensor__zero_grads(EagerTensorObject* self, EAGER_TRY VLOG(4) << "ZeroGrads " << self->eager_tensor.name(); - egr::EagerTensor grad; if (egr::egr_utils_api::IsLeafTensor(self->eager_tensor)) { // Add RetainGrad as PostHook to AccumulationNode std::shared_ptr grad_node = @@ -214,21 +213,54 @@ static PyObject* eager_tensor__zero_grads(EagerTensorObject* self, "with type: GradNodeAccumulation")); auto accumulation_grad_node = std::dynamic_pointer_cast(grad_node); - grad = accumulation_grad_node->Grad(); + if (accumulation_grad_node->Grad()->initialized()) { + accumulation_grad_node->Grad()->set_tensor( + std::make_shared( + paddle::experimental::zeros_like( + *(accumulation_grad_node->Grad()->Tensor().get())))); + } } else { auto meta = egr::EagerUtils::unsafe_autograd_meta(self->eager_tensor); - grad = meta->Grad(); + if (meta->MutableGrad()->initialized()) { + meta->MutableGrad()->set_tensor( + std::make_shared( + paddle::experimental::zeros_like( + *(meta->MutableGrad()->Tensor().get())))); + } } - if (grad.initialized()) { - grad.set_tensor(std::make_shared( - paddle::experimental::zeros_like(*(grad.Tensor().get())))); - } Py_INCREF(Py_None); return Py_None; EAGER_CATCH_AND_THROW_RETURN_NULL } +static PyObject* eager_tensor_method_detach(EagerTensorObject* self, + PyObject* args, PyObject* kwargs) { + EAGER_SYNC_TRY + PADDLE_ENFORCE_EQ( + self->eager_tensor.initialized(), true, + platform::errors::InvalidArgument("Tensor %s has not been initialized!", + self->eager_tensor.name())); + + PyObject* obj = p_eager_tensor_type->tp_alloc(p_eager_tensor_type, 0); + if (obj) { + auto v = reinterpret_cast(obj); + new (&(v->eager_tensor)) egr::EagerTensor(); + v->eager_tensor.set_impl(self->eager_tensor.impl()); + v->eager_tensor.set_name(egr::Controller::Instance().GenerateUniqueName()); + auto autograd_meta_src = + egr::EagerUtils::autograd_meta(&(self->eager_tensor)); + auto autograd_meta = egr::EagerUtils::autograd_meta(&(v->eager_tensor)); + autograd_meta->SetPersistable(autograd_meta_src->Persistable()); + } else { + PADDLE_THROW(platform::errors::Fatal( + "tp_alloc return null, can not new a PyObject.")); + } + + return obj; + EAGER_CATCH_AND_THROW_RETURN_NULL +} + PyMethodDef variable_methods[] = { {"numpy", (PyCFunction)(void (*)(void))eager_tensor_method_numpy, METH_VARARGS | METH_KEYWORDS, NULL}, @@ -246,6 +278,8 @@ PyMethodDef variable_methods[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"_zero_grads", (PyCFunction)(void (*)(void))eager_tensor__zero_grads, METH_VARARGS | METH_KEYWORDS, NULL}, + {"detach", (PyCFunction)(void (*)(void))eager_tensor_method_detach, + METH_VARARGS | METH_KEYWORDS, NULL}, {NULL, NULL, 0, NULL}}; } // namespace pybind diff --git a/paddle/fluid/pybind/eager_properties.cc b/paddle/fluid/pybind/eager_properties.cc index b147d5fbad0ed..71b8bbbb1a283 100644 --- a/paddle/fluid/pybind/eager_properties.cc +++ b/paddle/fluid/pybind/eager_properties.cc @@ -63,7 +63,6 @@ PyObject* eager_tensor_properties_get_grad(EagerTensorObject* self, void* closure) { EAGER_SYNC_TRY if (egr::egr_utils_api::IsLeafTensor(self->eager_tensor)) { - // Add RetainGrad as PostHook to AccumulationNode std::shared_ptr grad_node = egr::EagerUtils::grad_node(self->eager_tensor); PADDLE_ENFORCE( @@ -73,7 +72,7 @@ PyObject* eager_tensor_properties_get_grad(EagerTensorObject* self, "with type: GradNodeAccumulation")); auto accumulation_grad_node = std::dynamic_pointer_cast(grad_node); - return ToPyObject(accumulation_grad_node->Grad()); + return ToPyObject(*accumulation_grad_node->Grad()); } else { VLOG(6) << "Get grad for tensor: " << self->eager_tensor.name(); auto meta = egr::EagerUtils::unsafe_autograd_meta(self->eager_tensor); @@ -82,6 +81,27 @@ PyObject* eager_tensor_properties_get_grad(EagerTensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } +int eager_tensor_properties_set_grad(EagerTensorObject* self, PyObject* value, + void* closure) { + EAGER_SYNC_TRY + auto src = CastPyArg2EagerTensor(value, 0); + PADDLE_ENFORCE( + egr::egr_utils_api::IsLeafTensor(self->eager_tensor), + paddle::platform::errors::Fatal("Only leaf Tensor can be set grad.")); + std::shared_ptr grad_node = + egr::EagerUtils::grad_node(self->eager_tensor); + PADDLE_ENFORCE( + grad_node.get() != nullptr, + paddle::platform::errors::Fatal("Detected NULL grad_node" + "Leaf tensor should have had grad_node " + "with type: GradNodeAccumulation")); + auto accumulation_grad_node = + std::dynamic_pointer_cast(grad_node); + accumulation_grad_node->Grad()->copy_(src, true); + return 0; + EAGER_CATCH_AND_THROW_RETURN_ZERO +} + int eager_tensor_properties_set_stop_gradient(EagerTensorObject* self, PyObject* value, void* closure) { EAGER_SYNC_TRY @@ -147,8 +167,8 @@ PyObject* eager_tensor_properties_get_dtype(EagerTensorObject* self, } struct PyGetSetDef variable_properties[] = { - {"grad", (getter)eager_tensor_properties_get_grad, nullptr, nullptr, - nullptr}, + {"grad", (getter)eager_tensor_properties_get_grad, + (setter)eager_tensor_properties_set_grad, nullptr, nullptr}, {"name", (getter)eager_tensor_properties_get_name, (setter)eager_tensor_properties_set_name, nullptr, nullptr}, {"stop_gradient", (getter)eager_tensor_properties_get_stop_gradient, diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index c61f87ccf9089..e06e7f52dd671 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -22,7 +22,7 @@ from .. import framework from .. import core from .. import unique_name -from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_, _in_eager_mode +from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_, _in_eager_mode, EagerParamBase from .base import switch_to_static_graph from .math_op_patch import monkey_patch_math_varbase from .parallel import scale_loss @@ -149,7 +149,7 @@ def set_value(self, value): out = linear(t) # call with different weight """ - if _in_eager_mode(): + if core._in_eager_mode(): base_tensor = core.eager.EagerTensor else: base_tensor = core.VarBase @@ -238,7 +238,7 @@ def backward(self, grad_tensor=None, retain_graph=False): """ if framework.in_dygraph_mode(): if grad_tensor is not None: - if _in_eager_mode(): + if core._in_eager_mode(): assert isinstance( grad_tensor, core.eager.EagerTensor ), "The type of grad_tensor must be paddle.Tensor" @@ -250,7 +250,7 @@ def backward(self, grad_tensor=None, retain_graph=False): "Tensor shape not match, Tensor of grad_tensor [ {} ] with shape {} mismatch Tensor [ {} ] with shape {}".format( grad_tensor.name, grad_tensor.shape, self.name, self.shape) - if _in_eager_mode(): + if core._in_eager_mode(): if grad_tensor is None: grad_tensor = [] else: @@ -258,7 +258,7 @@ def backward(self, grad_tensor=None, retain_graph=False): if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_npu(): # TODO(liuyuhui): Currently only for xpu. Will be removed in the future. scaled_loss = scale_loss(self) - if _in_eager_mode(): + if core._in_eager_mode(): core.eager.run_backward([scaled_loss], grad_tensor, retain_graph) else: @@ -266,7 +266,7 @@ def backward(self, grad_tensor=None, retain_graph=False): retain_graph, framework._dygraph_tracer()) else: - if _in_eager_mode(): + if core._in_eager_mode(): core.eager.run_backward([self], grad_tensor, retain_graph) else: core.dygraph_run_backward([self], [grad_tensor], @@ -305,7 +305,7 @@ def gradient(self): # [500.] """ - if _in_eager_mode(): + if core._in_eager_mode(): if not self.grad._is_initialized(): return None # TODO(wanghuancoder) support SELECTED_ROWS @@ -587,7 +587,7 @@ def __str__(self): # [[0.30574632, 0.55739117, 0.30902600, 0.39413780, 0.44830436], # [0.79010487, 0.53972793, 0.09495186, 0.44267157, 0.72112119]]) """ - if _in_eager_mode(): + if core._in_eager_mode(): from paddle.tensor.to_string import eager_tensor_to_string return eager_tensor_to_string(self) else: @@ -619,7 +619,7 @@ def __deepcopy__(self, memo): raise RuntimeError( "Only Leaf Tensor support the deepcopy at the moment, non-Leaf Tensors contains graph information that does't support deepcopy" ) - if _in_eager_mode(): + if core._in_eager_mode(): new_varbase = core.eager.EagerTensor() else: new_varbase = core.VarBase() @@ -763,6 +763,14 @@ def _grad_ivar(self): else: return None + @framework.dygraph_only + def _set_grad_ivar(self, value): + if isinstance(self, EagerParamBase): + self.grad = value + else: + raise TypeError( + "_set_grad_ivar is only supported for Parameter Tensor") + @framework.dygraph_only def clear_gradient(self, set_to_zero=True): if set_to_zero: @@ -770,6 +778,10 @@ def clear_gradient(self, set_to_zero=True): else: self._clear_gradient() + @framework.dygraph_only + def clone(self): + return _C_ops_.assign(self) + if core._in_eager_mode() and not hasattr(core, "eager"): return @@ -790,7 +802,9 @@ def clear_gradient(self, set_to_zero=True): if core._in_eager_mode(): setattr(core.eager.EagerTensor, "_grad_ivar", _grad_ivar) + setattr(core.eager.EagerTensor, "_set_grad_ivar", _set_grad_ivar) setattr(core.eager.EagerTensor, "clear_gradient", clear_gradient) + setattr(core.eager.EagerTensor, "clone", clone) else: setattr(core.VarBase, "__name__", "Tensor") setattr(core.VarBase, "grad", grad) From e91f7c02b61017486e2c24f023165d92e1988a8f Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Tue, 11 Jan 2022 14:02:45 +0800 Subject: [PATCH 04/15] Jit pre save hook (#38186) * Pre-save hooks of jit.save 1. Added pre_save_hooks features to jit.save. 2. Added related unittests * Added jit pre_save_hooks functions's alias to paddle.jit and copyright. * Make jit.save_pre_hook style be consisent with Paddle's rule. * Fixed arguments passing bug in run_save_pre_hooks * Added API Documents * Move clear and run_pre_save_hooks as internal methonds only. * Made register_save_pre_hook as an internal function. --- python/paddle/fluid/dygraph/jit.py | 101 ++++++++++++++++++ .../unittests/test_jit_pre_save_hooks.py | 59 ++++++++++ python/paddle/jit/__init__.py | 3 +- 3 files changed, 162 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/test_jit_pre_save_hooks.py diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 2db9fb5d76a58..4bfdc3c27fad6 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -1,4 +1,5 @@ # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +21,7 @@ import functools from collections import OrderedDict import inspect +import threading import six import paddle @@ -525,6 +527,105 @@ def _build_load_path_and_config(path, config): return model_path, config +_save_pre_hooks_lock = threading.Lock() +_save_pre_hooks = [] + + +class HookRemoveHelper(object): + """ A HookRemoveHelper that can be used to remove hook. """ + + def __init__(self, hook): + self._hook = hook + + def remove(self): + _remove_save_pre_hook(self._hook) + + +def _register_save_pre_hook(hook): + """ + Register a save pre-hook for `paddle.jit.save`. + This hook will be executed before `save` function has been invoked. + + hook(layer, input_spec, configs) -> None + - layer (Layer|function): This argument is corresponding to `layer` in `paddle.jit.save`. + - input_spec (list or tuple[InputSpec|Tensor|Python built-in variable]): This argument is corresponding to `input_spec` in `paddle.jit.save`. + - configs (dict): This argument is corresponding to `configs` in `paddle.jit.save`. + + Args: + hook(function): a function registered as a save pre-hook + + Returns: + HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()`. + + Examples: + .. code-block:: python + + import numpy as np + import paddle + + IMAGE_SIZE = 256 + CLASS_NUM = 10 + + class LinearNet(paddle.nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = paddle.nn.Linear(IMAGE_SIZE, CLASS_NUM) + + def forward(self, x): + return self._linear(x) + + saving_count = 0 + def save_pre_hook(layer, input_spec, configs): + global saving_count + saving_count += 1 + + remove_handler = paddle.jit.register_save_pre_hook(save_pre_hook) + + layer = LinearNet() + paddle.jit.save(layer, "/tmp", [paddle.static.InputSpec(shape=[-1, IMAGE_SIZE])]) + # saving_count == 1 + + remove_handler.remove() + paddle.jit.save(layer, "/tmp", [paddle.static.InputSpec(shape=[-1, IMAGE_SIZE])]) + # saving_count == 1 + """ + global _save_pre_hooks_lock + global _save_pre_hooks + _save_pre_hooks_lock.acquire() + if hook not in _save_pre_hooks: + _save_pre_hooks.append(hook) + _save_pre_hooks_lock.release() + return HookRemoveHelper(hook) + + +def _clear_save_pre_hooks(): + global _save_pre_hooks_lock + global _save_pre_hooks + _save_pre_hooks_lock.acquire() + _save_pre_hooks.clear() + _save_pre_hooks_lock.release() + + +def _remove_save_pre_hook(hook): + global _save_pre_hooks_lock + global _save_pre_hooks + _save_pre_hooks_lock.acquire() + if hook in _save_pre_hooks: + _save_pre_hooks.remove(hook) + _save_pre_hooks_lock.release() + + +def _run_save_pre_hooks(func): + def wrapper(layer, path, input_spec=None, **configs): + global _save_pre_hooks + for hook in _save_pre_hooks: + hook(layer, input_spec, configs) + func(layer, path, input_spec, **configs) + + return wrapper + + +@_run_save_pre_hooks @switch_to_static_graph def save(layer, path, input_spec=None, **configs): """ diff --git a/python/paddle/fluid/tests/unittests/test_jit_pre_save_hooks.py b/python/paddle/fluid/tests/unittests/test_jit_pre_save_hooks.py new file mode 100644 index 0000000000000..a938024e3c9b4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_jit_pre_save_hooks.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import paddle +from paddle.fluid.dygraph.jit import _run_save_pre_hooks, _clear_save_pre_hooks, _register_save_pre_hook + +_counter = 0 + + +class TestPreSaveHooks(unittest.TestCase): + def test_pre_save_hook_functions(self): + def fake_func(*args, **kwgs): + global _counter + _counter += 1 + + remove_handler = _register_save_pre_hook(fake_func) + self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 1) + self.assertTrue( + paddle.fluid.dygraph.jit._save_pre_hooks[0] is fake_func) + + # Test of avoiding redundancy hanging + remove_handler = _register_save_pre_hook(fake_func) + self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 1) + self.assertTrue( + paddle.fluid.dygraph.jit._save_pre_hooks[0] is fake_func) + + remove_handler.remove() + self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 0) + + remove_handler = _register_save_pre_hook(fake_func) + _clear_save_pre_hooks() + self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 0) + + global _counter + _counter = 0 + remove_handler = _register_save_pre_hook(fake_func) + func_with_hook = _run_save_pre_hooks(fake_func) + func_with_hook(None, None) + self.assertEqual(_counter, 2) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/jit/__init__.py b/python/paddle/jit/__init__.py index 576989e8e0d2a..a2af493faca11 100644 --- a/python/paddle/jit/__init__.py +++ b/python/paddle/jit/__init__.py @@ -1,4 +1,5 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 0ad363b1527461bf2a1c6c674f6202b8b6c0a48c Mon Sep 17 00:00:00 2001 From: Sing_chan <51314274+betterpig@users.noreply.github.com> Date: Tue, 11 Jan 2022 14:27:51 +0800 Subject: [PATCH 05/15] support vs2019 compilation in windows (#38719) * support vs2019 compilation in windows * not modify pow_op's original compute logic --- cmake/external/protobuf.cmake | 4 ++ .../elementwise/elementwise_functor.h | 42 +++++++++++++++++++ .../elementwise/elementwise_pow_op.cu | 3 +- .../elementwise/elementwise_pow_op.h | 17 +++++++- paddle/fluid/operators/svd_helper.h | 6 +-- paddle/scripts/paddle_build.bat | 7 +++- paddle/utils/small_vector.h | 1 + 7 files changed, 73 insertions(+), 7 deletions(-) mode change 100755 => 100644 paddle/fluid/operators/elementwise/elementwise_pow_op.h diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake index 2a028b8dc7e7f..f7cb7716969f5 100644 --- a/cmake/external/protobuf.cmake +++ b/cmake/external/protobuf.cmake @@ -207,6 +207,10 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) elseif(WITH_IPU) SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git) SET(PROTOBUF_TAG d750fbf648256c7c631f51ffdbf67d7c18b0114e) + elseif(WIN32) + SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git) + # Change the tag to support building with vs2019 + SET(PROTOBUF_TAG 01a05a53f40ca2ac5f0af10c6cc0810bee39b792) else() SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git) SET(PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546) diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h index a62c531ff0733..0a6866f578d01 100644 --- a/paddle/fluid/operators/elementwise/elementwise_functor.h +++ b/paddle/fluid/operators/elementwise/elementwise_functor.h @@ -174,6 +174,27 @@ struct FMaxFunctor { } }; +template <> +struct FMaxFunctor { + inline HOSTDEVICE int operator()(const int& a, const int& b) const { + float float_a = static_cast(a); + float float_b = static_cast(b); + auto result = std::fmax(float_a, float_b); + return std::lrint(result); + } +}; + +template <> +struct FMaxFunctor { + inline HOSTDEVICE int64_t operator()(const int64_t& a, + const int64_t& b) const { + double double_a = static_cast(a); + double double_b = static_cast(b); + auto result = std::fmax(double_a, double_b); + return std::llrint(result); + } +}; + // Fmin template struct FMinFunctor { @@ -194,6 +215,27 @@ struct FMinFunctor { } }; +template <> +struct FMinFunctor { + inline HOSTDEVICE int operator()(const int& a, const int& b) const { + float float_a = static_cast(a); + float float_b = static_cast(b); + auto result = std::fmin(float_a, float_b); + return std::lrint(result); + } +}; + +template <> +struct FMinFunctor { + inline HOSTDEVICE int64_t operator()(const int64_t& a, + const int64_t& b) const { + double double_a = static_cast(a); + double double_b = static_cast(b); + auto result = std::fmin(double_a, double_b); + return std::llrint(result); + } +}; + template struct MulGradFunctor { inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; } diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu index 5335f274ef126..a5570f2cb85d5 100644 --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu @@ -31,7 +31,8 @@ struct CudaPowFunctor< // when cast to int by default and it is wrong. // Use llrint to cast it to the nearest integer, which is 3. inline HOSTDEVICE T operator()(const T args[]) const { - return std::llrint(std::pow(args[0], args[1])); + return std::llrint( + std::pow(static_cast(args[0]), static_cast(args[1]))); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.h b/paddle/fluid/operators/elementwise/elementwise_pow_op.h old mode 100755 new mode 100644 index ee718a3ecd1ec..256ab31ead69c --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.h @@ -31,7 +31,8 @@ struct PowFunctor { // when cast to int by default and it is wrong. // Use llrint to cast it to the nearest integer, which is 3. if (std::is_integral::value) { - return std::llrint(std::pow(a, b)); + return std::llrint( + std::pow(static_cast(a), static_cast(b))); } #endif return std::pow(a, b); @@ -60,13 +61,25 @@ class ElementwisePowKernel : public framework::OpKernel { template struct PowGradDX { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) + if (std::is_integral::value) { + return dout * y * + std::pow(static_cast(x), static_cast(y - 1)); + } +#endif return dout * y * std::pow(x, y - 1); } }; -template +template struct PowGradDY { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) + if (std::is_integral::value) { + return dout * std::log(static_cast(x)) * + std::pow(static_cast(x), static_cast(y)); + } +#endif return dout * std::log(x) * std::pow(x, y); } }; diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index 8d17ddec6fbb4..8a3622a6b1b5e 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -84,7 +84,7 @@ void BatchSvd(const T* X, T* U, T* VH, T* S, int rows, int cols, int batches, template struct PowFunctor { - PowFunctor(const T* input, T* output, int64_t numel, float exp) + PowFunctor(const T* input, T* output, int64_t numel, T exp) : input_(input), output_(output), numel_(numel), exp_(exp) {} HOSTDEVICE void operator()(int64_t idx) const { @@ -93,7 +93,7 @@ struct PowFunctor { const T* input_; T* output_; int64_t numel_; - float exp_; + T exp_; }; template @@ -297,7 +297,7 @@ struct DeviceIndependenceTensorOperations { const framework::ExecutionContext& context) : context(context) {} - framework::Tensor Pow(const framework::Tensor& x, float exp) { + framework::Tensor Pow(const framework::Tensor& x, T exp) { framework::Tensor out; auto for_range = GetForRange(x.numel()); int numel = x.numel(); diff --git a/paddle/scripts/paddle_build.bat b/paddle/scripts/paddle_build.bat index 8bb21fa4ef2e1..f64acbeb72307 100644 --- a/paddle/scripts/paddle_build.bat +++ b/paddle/scripts/paddle_build.bat @@ -261,6 +261,7 @@ set ON_INFER=ON set WITH_TESTING=ON set WITH_TENSORRT=ON set WITH_INFERENCE_API_TEST=ON +set WITH_TPCACHE=OFF call :cmake || goto cmake_error call :build || goto build_error @@ -325,7 +326,11 @@ echo ======================================== rem set vs language to english to block showIncludes, this need vs has installed English language package. set VSLANG=1033 rem Configure the environment for 64-bit builds. 'DISTUTILS_USE_SDK' indicates that the user has selected the compiler. -call "C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvars64.bat" +echo %task_name%|findstr wincheck_inference >nul && ( + call "D:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvars64.bat" +) || ( + call "C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvars64.bat" +) set DISTUTILS_USE_SDK=1 rem Windows 10 Kit bin dir set PATH=C:\Program Files (x86)\Windows Kits\10\bin\10.0.17763.0\x64;%PATH% diff --git a/paddle/utils/small_vector.h b/paddle/utils/small_vector.h index e9e7996babcf7..48af2491b89f8 100644 --- a/paddle/utils/small_vector.h +++ b/paddle/utils/small_vector.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include From 9f34a0702213ada872c04ddbc367db2ceedfc697 Mon Sep 17 00:00:00 2001 From: limingshu <61349199+JamesLim-sy@users.noreply.github.com> Date: Tue, 11 Jan 2022 14:38:25 +0800 Subject: [PATCH 06/15] Remove useless headers for some grad ops (#38823) * fix the wrong filename * first commit * first commit * remove rest useless headers * for ci approval --- .../fluid/operators/elementwise/elementwise_add_op.cu | 10 ---------- .../fluid/operators/elementwise/elementwise_add_op.h | 6 ------ .../fluid/operators/elementwise/elementwise_div_op.h | 11 ----------- .../operators/elementwise/elementwise_floordiv_op.cu | 1 - .../operators/elementwise/elementwise_floordiv_op.h | 3 --- .../fluid/operators/elementwise/elementwise_functor.h | 3 --- .../fluid/operators/elementwise/elementwise_max_op.cu | 1 - .../fluid/operators/elementwise/elementwise_max_op.h | 3 --- .../fluid/operators/elementwise/elementwise_min_op.cu | 1 - .../fluid/operators/elementwise/elementwise_min_op.h | 4 ---- .../fluid/operators/elementwise/elementwise_mod_op.cu | 3 +-- .../fluid/operators/elementwise/elementwise_mod_op.h | 2 -- .../fluid/operators/elementwise/elementwise_mul_op.cu | 8 -------- .../fluid/operators/elementwise/elementwise_mul_op.h | 6 +----- paddle/fluid/operators/elementwise/elementwise_op.h | 2 -- .../fluid/operators/elementwise/elementwise_pow_op.cu | 2 +- .../fluid/operators/elementwise/elementwise_pow_op.h | 1 - .../fluid/operators/elementwise/elementwise_sub_op.cu | 4 ---- .../fluid/operators/elementwise/elementwise_sub_op.h | 5 ----- 19 files changed, 3 insertions(+), 73 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu index 7b153a4bce86a..b5c19a3edb818 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu @@ -12,17 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" -#include "paddle/fluid/platform/complex.h" -#include "paddle/fluid/platform/float16.h" - -// only can include the headers in paddle/top/api dirs -#include "paddle/pten/api/lib/utils/tensor_utils.h" -#include "paddle/pten/include/core.h" -#include "paddle/pten/include/math.h" namespace ops = paddle::operators; namespace plat = paddle::platform; diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h index d6d79d166d00a..35807d7c57d47 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h @@ -17,14 +17,8 @@ limitations under the License. */ #include #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/math_function.h" - -#include "paddle/fluid/framework/pten_utils.h" // only can include the headers in paddle/pten/include dirs -#include "paddle/pten/api/lib/utils/tensor_utils.h" -#include "paddle/pten/include/core.h" #include "paddle/pten/kernels/math_kernel.h" namespace paddle { diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index b13a0539ec6ad..d9f7bbc56a902 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -14,21 +14,10 @@ limitations under the License. */ #pragma once -#include #include #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" -#include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.h" -#include "paddle/fluid/framework/pten_utils.h" - -// only can include the headers in paddle/pten/include dirs -#include "paddle/pten/api/lib/utils/tensor_utils.h" -#include "paddle/pten/include/core.h" -#include "paddle/pten/kernels/math_kernel.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu index 41a0ae00f270d..3202b0a7d254b 100644 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h index ae8d2d8625c58..fc8f18161990d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h @@ -14,10 +14,7 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h index 0a6866f578d01..e2689cefd43a7 100644 --- a/paddle/fluid/operators/elementwise/elementwise_functor.h +++ b/paddle/fluid/operators/elementwise/elementwise_functor.h @@ -16,9 +16,6 @@ limitations under the License. */ #include "paddle/fluid/framework/array.h" #include "paddle/fluid/platform/complex.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/fluid/platform/hostdevice.h" #include "paddle/pten/kernels/funcs/elementwise_functor.h" namespace paddle { diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu index eb6f78bf270ad..760429200889b 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_max_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h index acb212e992a1d..a7a49fed87151 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h @@ -15,10 +15,7 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/platform/eigen_ext.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu index 59f1c51bce266..b51dbcd883608 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_min_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.h b/paddle/fluid/operators/elementwise/elementwise_min_op.h index ebd8f4477d8cf..ffb8c965357a3 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.h @@ -15,11 +15,7 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/platform/eigen_ext.h" -#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise/elementwise_mod_op.cu b/paddle/fluid/operators/elementwise/elementwise_mod_op.cu index bb49fdbf12dfa..d2106645a4727 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mod_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mod_op.cu @@ -11,9 +11,8 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "paddle/fluid/operators/elementwise/elementwise_mod_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; namespace plat = paddle::platform; diff --git a/paddle/fluid/operators/elementwise/elementwise_mod_op.h b/paddle/fluid/operators/elementwise/elementwise_mod_op.h index 03884f2a45883..66c3e553c141f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mod_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mod_op.h @@ -14,9 +14,7 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index cdf376fd6a8cc..a8b6c2abe3bf9 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -13,15 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" -#include "paddle/fluid/platform/complex.h" -#include "paddle/fluid/platform/float16.h" -// only can include the headers in paddle/top/api dirs -#include "paddle/pten/api/lib/utils/tensor_utils.h" -#include "paddle/pten/include/core.h" -#include "paddle/pten/include/math.h" namespace ops = paddle::operators; namespace plat = paddle::platform; diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index 5cff3173e8115..385c7549e07f2 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -15,16 +15,12 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/cpu_info.h" // only can include the headers in paddle/pten/include dirs -#include "paddle/pten/api/lib/utils/tensor_utils.h" -#include "paddle/pten/include/core.h" #include "paddle/pten/kernels/math_kernel.h" + namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index b7df9bb864db1..e1d9655e293a3 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -21,9 +21,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/data_layout.h" -#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/common_infer_shape_functions.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu index a5570f2cb85d5..0f3aa8c3e1b9b 100644 --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu @@ -8,7 +8,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" + #include "paddle/fluid/operators/elementwise/elementwise_pow_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.h b/paddle/fluid/operators/elementwise/elementwise_pow_op.h index 256ab31ead69c..c1fecab8aba1c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.h @@ -15,7 +15,6 @@ limitations under the License. */ #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu index cba261a394732..2ff4033ffe194 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu @@ -12,11 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" -#include "paddle/fluid/platform/complex.h" -#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; namespace plat = paddle::platform; diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.h b/paddle/fluid/operators/elementwise/elementwise_sub_op.h index 6a51d7c2a45ad..09818380d8ea7 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.h @@ -14,14 +14,9 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/operators/math/blas.h" // only can include the headers in paddle/pten/include dirs -#include "paddle/pten/api/lib/utils/tensor_utils.h" -#include "paddle/pten/include/core.h" #include "paddle/pten/kernels/math_kernel.h" namespace paddle { namespace operators { From fbb4028148cf3a87f4fd464b452597c94e321374 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Tue, 11 Jan 2022 15:14:26 +0800 Subject: [PATCH 07/15] [AMP] Check call order of paddle.amp.decorate and paddle.DataParallel (#38785) * check amp.decorate and DataParallel * refine coverage * fix layer dtype * refine code --- python/paddle/fluid/dygraph/amp/auto_cast.py | 4 ++++ python/paddle/fluid/dygraph/layers.py | 2 ++ .../unittests/test_imperative_auto_mixed_precision.py | 8 ++++++++ 3 files changed, 14 insertions(+) diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 15adf4cb6faaf..f09e210c3c161 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -145,6 +145,10 @@ def check_models(models): raise RuntimeError( "Current train mode is pure fp16, models should be paddle.nn.Layer, but receive {}.". format(type(model))) + if isinstance(model, paddle.DataParallel): + raise RuntimeError( + "For distributed AMP training, you should first use paddle.amp.decorate() to decotate origin model, and then call paddle.DataParallel get distributed model." + ) def check_optimizers(optimizers): diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 4c37a378e0aae..6a65b3bd9c684 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -1569,6 +1569,8 @@ def _apply(self, func, device, dtype, blocking, include_sublayers=True): for key, buf in self._buffers.items(): self._buffers[key] = func(buf, device, dtype, blocking) + self._dtype = dtype + def _to_impl(self, device=None, dtype=None, diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index a8ed23f5938c0..62b40f88571d4 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -536,6 +536,14 @@ def __init__(self): self.assertRaises(TypeError, test_error_model) + def test_error_distributed_model(): + model = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None) + model = paddle.DataParallel(model) + with fluid.dygraph.guard(): + model = paddle.amp.decorate(models=model, level='O2') + + self.assertRaises(RuntimeError, test_error_distributed_model) + def test_error_optimizer(): class MyOptimizer(object): def __init__(self): From d3ba189548b8e5ca01da310e2945fe9ee4d53b63 Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Tue, 11 Jan 2022 16:49:42 +0800 Subject: [PATCH 08/15] =?UTF-8?q?=E3=80=90Auto=20Parallel=E3=80=91New=20lo?= =?UTF-8?q?cal=20tensor=20(#38747)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update dist tensor * add unitest * update unitest * refactor dist tensor * update dist tensor and unitest --- .../distributed/auto_parallel/dist_context.py | 18 +- .../distributed/auto_parallel/dist_tensor.py | 283 +++++++++++++++++- .../fluid/tests/unittests/CMakeLists.txt | 3 + .../test_auto_parallel_dist_tensor.py | 222 ++++++++++++++ 4 files changed, 523 insertions(+), 3 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index 12bf14fcce5bd..b194bcc3de6b5 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -62,6 +62,10 @@ def __init__(self, program=None): self._dist_op_context = DistributedOperatorContext() self._process_meshes = [] + # Distributed programs + self._dist_main_programs = {} + self._dist_startup_programs = {} + @property def serial_program(self): return self._serial_program @@ -84,6 +88,14 @@ def process_meshes(self): def dist_op_context(self): return self._dist_op_context + @property + def dist_main_programs(self): + return self._dist_main_programs + + @property + def dist_startup_programs(self): + return self._dist_startup_programs + def add_process_mesh(self, process_mesh): assert isinstance(process_mesh, ProcessMesh), \ 'The type of dim_mapping must be ProcessMesh.' @@ -371,10 +383,14 @@ def __deepcopy__(self, memo): result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k == "_serial_program" or k == "_serial_graph": + if k == "_serial_program" or k == "_serial_graph" or k == "_dist_main_programs" or k == "_dist_startup_programs": setattr(result, k, v) else: setattr(result, k, copy.deepcopy(v, memo)) + + # update dist tensor's dist_context + for key in result._dist_tensors_for_program.keys(): + result._dist_tensors_for_program[key]._dist_context = result return result diff --git a/python/paddle/distributed/auto_parallel/dist_tensor.py b/python/paddle/distributed/auto_parallel/dist_tensor.py index f46c6e86d6870..5e3c852699ab6 100644 --- a/python/paddle/distributed/auto_parallel/dist_tensor.py +++ b/python/paddle/distributed/auto_parallel/dist_tensor.py @@ -13,18 +13,155 @@ # limitations under the License import copy +import inspect + +import paddle from paddle.fluid import core +from paddle.fluid.framework import Parameter, Block, Variable from .dist_attribute import TensorDistributedAttribute from .dist_attribute import get_tensor_dist_attr_field_keys +from .utils import _linear_idx2coordinate class DistributedTensor: - def __init__(self, serial_tensor, dist_attr=None): + """ + DistributedTensor represents the distribution of tensor on the process group and + local tensors can be created by DistributedTensor. + Only support even sharding now and uneven sharding will be supported in the future. + Local tensor information can be obtained from the DistributedTensor instance object, + or obtained by the static methods provided by DistributedTensor, + including shard (i.e. the index in the serial tensor), offsets, and sizes. + """ + + @staticmethod + def _validate_sizes_and_dist_attr(sizes, + dims_mapping, + topology, + processes, + rank=None, + shard_sizes=None): + if not (isinstance(sizes, (list, tuple)) and + all(map(lambda x: isinstance(x, int) and x > 0, sizes))): + raise ValueError( + "The sizes must be list or tuple and item in sizes must be non-negative integer, but got {}". + format(sizes)) + if not (isinstance(dims_mapping, (list, tuple)) and all( + map(lambda x: isinstance(x, int) and x >= -1, dims_mapping))): + raise ValueError( + "The dims_mapping must be list or tuple and item in dims_mapping must >= -1, but got {}". + format(dims_mapping)) + if not (isinstance(processes, (list, tuple)) and + all(map(lambda x: isinstance(x, int) and x >= 0, processes))): + raise ValueError( + "The processes must be list or tuple and item in processes must be integer, but got {}". + format(processes)) + if not (isinstance(topology, (list, tuple)) and + all(map(lambda x: isinstance(x, int) and x > 0, topology))): + raise ValueError( + "The topology must be list or tuple and item in topology must be non-negative integer, but got {}". + format(topology)) + if rank is not None and not (isinstance(rank, int) and rank >= 0): + raise ValueError("The rank must >= 0, but got {}".format(rank)) + + # NOTE: Only support even sharding now + if shard_sizes is not None: + raise ValueError("Only support even sharding now.") + + @staticmethod + def get_local_sizes(global_sizes, + dims_mapping, + topology, + processes, + rank=None, + shard_sizes=None): + DistributedTensor._validate_sizes_and_dist_attr( + global_sizes, dims_mapping, topology, processes, rank, shard_sizes) + + local_sizes = [] + # for even sharding, the local sizes of every rank are equal + for idx, item in enumerate(global_sizes): + if dims_mapping[idx] == -1: + local_sizes.append(item) + else: + local_sizes.append(item // topology[dims_mapping[idx]]) + + return local_sizes + + @staticmethod + def get_local_offsets(global_sizes, + dims_mapping, + topology, + processes, + rank, + shard_sizes=None): + local_sizes = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes, rank, shard_sizes) + local_offsets = [] + rank_relatvie = processes.index(rank) + coordinate = _linear_idx2coordinate(topology, rank_relatvie) + + for i in range(len(global_sizes)): + if dims_mapping[i] == -1: + local_offsets.append(0) + else: + local_offsets.append(coordinate[dims_mapping[i]] * + local_sizes[i]) + return local_offsets + + @staticmethod + def get_global_sizes(local_sizes, + dims_mapping, + topology, + processes, + rank=None, + shard_sizes=None): + DistributedTensor._validate_sizes_and_dist_attr( + local_sizes, dims_mapping, topology, processes, rank, shard_sizes) + global_sizes = [] + for idx, item in enumerate(local_sizes): + if dims_mapping[idx] == -1: + global_sizes.append(item) + else: + global_sizes.append(item * topology[dims_mapping[idx]]) + return global_sizes + + @staticmethod + def get_local_shard(global_sizes, + dims_mapping, + topology, + processes, + rank, + shard_sizes=None): + local_offsets = DistributedTensor.get_local_offsets( + global_sizes, dims_mapping, topology, processes, rank, shard_sizes) + local_sizes = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes, rank, shard_sizes) + assert len(local_sizes) == len( + local_offsets + ), "The length of local_sizes must be equal to local_offsets, but got {} and {}.".format( + len(local_sizes), len(local_offsets)) + + local_end_offsets = list( + map(lambda x: x[0] + x[1], zip(local_offsets, local_sizes))) + local_shard = list(zip(local_offsets, local_end_offsets)) + return local_shard + + def __init__(self, serial_tensor, dist_attr=None, dist_context=None): self._serial_tensor = serial_tensor self._dist_attr = None self._batch_dim = 0 # Reuse the dist_attr setter to initialize _dist_attr self.dist_attr = dist_attr + self._local_sizes_map = {} + self._local_offsets_map = {} + self._local_shard_map = {} + self._local_tensor_map = {} + + from .dist_context import get_default_distributed_context + self._dist_context = dist_context if dist_context is not None else get_default_distributed_context( + ) + # TODO: Add Automatically to dist_context after initialized and it will be adapted in the future. + # self._dist_context.add_dist_tensor_for_program(self) @property def serial_tensor(self): @@ -34,6 +171,10 @@ def serial_tensor(self): def dist_attr(self): return self._dist_attr + @property + def dist_context(self): + return self._dist_context + @dist_attr.setter def dist_attr(self, dist_attr): if self._dist_attr is None: @@ -66,12 +207,150 @@ def validate_dist_attr(self): return False return True + def local_sizes(self, rank=None): + rank = paddle.distributed.get_rank() if rank is None else rank + local_sizes = None + if rank in self._local_sizes_map.keys(): + local_sizes = self._local_sizes_map[rank] + else: + global_sizes = self.serial_tensor.shape + dims_mapping = self.dist_attr.dims_mapping + shard_sizes = self.dist_attr.shard_sizes + processes = self.dist_attr.process_mesh.processes + topology = self.dist_attr.process_mesh.topology + local_sizes = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes, rank, + shard_sizes) + self._local_sizes_map[rank] = local_sizes + + return local_sizes + + def local_offsets(self, rank=None): + rank = paddle.distributed.get_rank() if rank is None else rank + local_offsets = None + if rank in self._local_offsets_map.keys(): + local_offsets = self._local_offsets_map[rank] + else: + global_sizes = self.serial_tensor.shape + dims_mapping = self.dist_attr.dims_mapping + shard_sizes = self.dist_attr.shard_sizes + processes = self.dist_attr.process_mesh.processes + topology = self.dist_attr.process_mesh.topology + local_offsets = DistributedTensor.get_local_offsets( + global_sizes, dims_mapping, topology, processes, rank, + shard_sizes) + self._local_offsets_map[rank] = local_offsets + + return local_offsets + + def global_sizes(self): + return self.serial_tensor.shape + + def local_shard(self, rank=None): + rank = paddle.distributed.get_rank() if rank is None else rank + local_shard = None + if rank in self._local_shard_map.keys(): + local_shard = self._local_shard_map[rank] + else: + global_sizes = self.serial_tensor.shape + dims_mapping = self.dist_attr.dims_mapping + shard_sizes = self.dist_attr.shard_sizes + processes = self.dist_attr.process_mesh.processes + topology = self.dist_attr.process_mesh.topology + local_shard = DistributedTensor.get_local_shard( + global_sizes, dims_mapping, topology, processes, rank, + shard_sizes) + self._local_shard_map[rank] = local_shard + + return local_shard + + def new_local_tensor(self, block=None, rank=None, name=None): + """ + Create a new local tensor of serial tensor corresponding to rank. + + Args: + block (Block): The block contains the new tensor. Default value is recommend and it will be created in the block of dist main program corresponding to the serial tensor block id. Default: None. + rank (int): The rank id. Default value is recommend and it will be the current rank. Default: None. + """ + + def _copy_kwargs(serial_tensor): + kwargs = {} + no_need_copy_args = ["self", "block", "shape", "name"] + arg_spec = inspect.getargspec(Variable.__init__) + + for key in arg_spec.args: + # TODO: Check the copied attribute from serial tensor whether valid + if key in no_need_copy_args: + continue + elif key not in kwargs: + if key == "type": + kwargs[key] = serial_tensor.desc.type() + elif key == "dtype": + kwargs[key] = serial_tensor.desc.dtype() + elif key == "lod_level": + kwargs[key] = serial_tensor.desc.lod_level() + elif key == "persistable": + kwargs[key] = serial_tensor.desc.persistable() + elif key == "stop_gradient": + kwargs[key] = serial_tensor.desc.stop_gradient() + elif key == "need_check_feed": + kwargs[key] = serial_tensor.desc.need_check_feed() + # TODO: Get capacity by framework + elif key == "capacity": + continue + else: + kwargs[key] = self.serial_tensor.__dict__[key] + + if isinstance(serial_tensor, Parameter): + kwargs["trainable"] = serial_tensor.trainable + kwargs["optimize_attr"] = serial_tensor.trainable + kwargs["regularizer"] = serial_tensor.regularizer + kwargs["do_model_average"] = serial_tensor.do_model_average + kwargs["need_clip"] = serial_tensor.need_clip + kwargs["is_distributed"] = serial_tensor.is_distributed + kwargs["is_parameter"] = serial_tensor.is_parameter + + return kwargs + + if rank is not None and not (isinstance(rank, int) and rank >= 0): + raise ValueError("The rank must >= 0, but got {}".format(rank)) + if block is not None and not isinstance(block, Block): + raise TypeError("The block must be Block, but got {}.".format( + type(block))) + rank = paddle.distributed.get_rank() if rank is None else rank + + if block is None: + block_id = self.serial_tensor.block.idx + block = self.dist_context.dist_main_programs[rank].block(block_id) + + # copy serial tensor attribute + kwargs = _copy_kwargs(self.serial_tensor) + kwargs["name"] = name + kwargs["shape"] = self.local_sizes(rank) + + if isinstance(self.serial_tensor, Parameter): + kwargs.pop("persistable") + local_tensor = Parameter(block=block, **kwargs) + else: + local_tensor = block.create_var(**kwargs) + + # TODO: Set original id when set original_id is approved + local_tensor.desc.set_original_id(self.serial_tensor.desc.id()) + self._local_tensor_map[rank] = local_tensor + return local_tensor + + def local_tensor(self, rank=None): + rank = paddle.distributed.get_rank() if rank is None else rank + assert rank in self._local_tensor_map, "The rank {} local tensor has not been created.".format( + rank) + return self._local_tensor_map[rank] + def __deepcopy__(self, memo): cls = self.__class__ result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k == "_serial_tensor": + if k == "_serial_tensor" or k == "_local_tensor_map": setattr(result, k, v) else: setattr(result, k, copy.deepcopy(v, memo)) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 64c247e56d1d3..b46a10c8c79d8 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -94,6 +94,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_searcher) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard) +list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_dist_tensor) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_dpmppp) @@ -262,6 +263,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_searcher) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard) + LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_dist_tensor) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_dpmppp) @@ -649,6 +651,7 @@ if(WITH_DISTRIBUTE) py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_searcher MODULES test_auto_parallel_searcher ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard MODULES test_auto_parallel_reshard ENVS ${dist_ENVS}) + py_test_modules(test_auto_parallel_dist_tensor MODULES test_auto_parallel_dist_tensor ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard_serial MODULES test_auto_parallel_reshard_serial ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard_mppp MODULES test_auto_parallel_reshard_mppp ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard_dpmppp MODULES test_auto_parallel_reshard_dpmppp ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py new file mode 100644 index 0000000000000..b21cbb5ae78bc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py @@ -0,0 +1,222 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import unittest + +import paddle +from paddle.fluid import core +import paddle.distributed.auto_parallel as auto +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer +from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.dist_context import DistributedContext +from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor +from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute +import test_auto_parallel_reshard +from test_auto_parallel_reshard import mlp_forward + + +def get_dist_prog(train_program, + startup_program, + dist_context, + rank_id, + complete_train_program=None): + loss, train_program, startup_program = mlp_forward(train_program, + startup_program) + + fleet._user_defined_strategy = fleet.DistributedStrategy() + fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer() + parallelizer = AutoParallelizer(fleet) + parallelizer._dist_context = dist_context + + # serial forward & backward completion + complete_train_program = auto.complete_annotation( + train_program, dist_context + ) if complete_train_program is None else complete_train_program + + # parallelizer._apply_serial_forward_pass(complete_train_program, + # startup_program) + + params_grads = parallelizer._generate_backward( + complete_train_program, + startup_program, + loss, + parameter_list=None, + no_grad_set=None, + callbacks=None) + + # logical partition + partitioner = Partitioner(dist_context, rank_id) + auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition( + complete_train_program, startup_program, params_grads) + + partitioned_optimize_ops = parallelizer._apply_optimize( + auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads) + + return auto_parallel_main_prog, auto_parallel_startup_prog, complete_train_program + + +class TestDistributedTensor(unittest.TestCase): + def test_new_local_tensor(self): + test_auto_parallel_reshard._global_process_mesh = auto.ProcessMesh( + mesh=[0, 1]) + test_auto_parallel_reshard._global_parallel_strategy = "dp" + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + dist_context = DistributedContext() + rank_id = 0 + dist_main_prog, dist_startup_prog, complete_train_program = get_dist_prog( + train_program, startup_program, dist_context, rank_id) + dist_context.dist_main_programs[rank_id] = dist_main_prog + dist_context.dist_startup_programs[rank_id] = dist_startup_prog + name = "layer_norm_1.tmp_2" + dist_tensor = dist_context.get_dist_tensor_for_program( + complete_train_program.global_block().vars[name]) + dist_tensor._dist_context = dist_context + intermediate_var_0 = dist_tensor.new_local_tensor( + name="intermediate_var_0") + self.assertEqual(intermediate_var_0.shape, (2, 1024)) + self.assertEqual(intermediate_var_0.name, "intermediate_var_0") + + rank_id = 1 + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + dist_main_prog, dist_startup_prog, _ = get_dist_prog( + train_program, startup_program, dist_context, rank_id, + complete_train_program) + dist_context.dist_main_programs[rank_id] = dist_main_prog + dist_context.dist_startup_programs[rank_id] = dist_startup_prog + name = "layer_norm_1.tmp_2" + dist_tensor = dist_context.get_dist_tensor_for_program( + complete_train_program.global_block().vars[name]) + dist_tensor._dist_context = dist_context + intermediate_var_1 = dist_tensor.new_local_tensor( + rank=rank_id, name="intermediate_var_1") + self.assertEqual(intermediate_var_0.shape, (2, 1024)) + self.assertEqual(intermediate_var_1.name, "intermediate_var_1") + + name = "linear_0.w_0" + dist_tensor = dist_context.get_dist_tensor_for_program( + complete_train_program.global_block().vars[name]) + dist_tensor._dist_context = dist_context + intermediate_var_1 = dist_tensor.new_local_tensor( + rank=rank_id, name="linear_0.w_0_intermediate") + self.assertEqual(intermediate_var_1.shape, (1024, 4096)) + self.assertEqual(intermediate_var_1.name, "linear_0.w_0_intermediate") + + copied_dist_context = copy.deepcopy(dist_context) + self.assertIsNotNone(copied_dist_context) + self.assertEqual( + id(copied_dist_context), + id( + copied_dist_context.get_dist_tensor_for_program( + dist_tensor.serial_tensor).dist_context)) + + def test_static_method(self): + dims_mapping = [1, 0] + processes = [0, 1, 2, 3, 4, 5, 6] + topology = [2, 3] + global_sizes = [6, 6] + + # rank 0 [(0, 2), (0, 3)] + # rank 1 [(2, 4), (0, 3)] + # rank 4 [(2, 4), (3, 6)] + rank = 0 + local_sizes = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes) + self.assertEqual(local_sizes, [2, 3]) + local_offsets = DistributedTensor.get_local_offsets( + global_sizes, dims_mapping, topology, processes, rank) + self.assertEqual(local_offsets, [0, 0]) + local_shard = DistributedTensor.get_local_shard( + global_sizes, dims_mapping, topology, processes, rank) + self.assertEqual(local_shard, [(0, 2), (0, 3)]) + + rank = 1 + local_sizes = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes) + self.assertEqual(local_sizes, [2, 3]) + local_offsets = DistributedTensor.get_local_offsets( + global_sizes, dims_mapping, topology, processes, rank) + self.assertEqual(local_offsets, [2, 0]) + local_shard = DistributedTensor.get_local_shard( + global_sizes, dims_mapping, topology, processes, rank) + self.assertEqual(local_shard, [(2, 4), (0, 3)]) + + rank = 4 + local_sizes = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes) + self.assertEqual(local_sizes, [2, 3]) + local_offsets = DistributedTensor.get_local_offsets( + global_sizes, dims_mapping, topology, processes, rank) + self.assertEqual(local_offsets, [2, 3]) + local_shard = DistributedTensor.get_local_shard( + global_sizes, dims_mapping, topology, processes, rank) + self.assertEqual(local_shard, [(2, 4), (3, 6)]) + + # global sizes + local_sizes = [2, 3] + global_sizes = DistributedTensor.get_global_sizes( + local_sizes, dims_mapping, topology, processes) + self.assertEqual(global_sizes, [6, 6]) + + def test_instance_method(self): + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.dims_mapping = [1, 0] + tensor_dist_attr.process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2], [3, 4, 5]]) + serial_tensor = paddle.static.data( + name="data", shape=[6, 6], dtype='float32') + dist_tensor = DistributedTensor(serial_tensor, tensor_dist_attr) + + # rank 0 [(0, 2), (0, 3)] + # rank 1 [(2, 4), (0, 3)] + # rank 4 [(2, 4), (3, 6)] + rank = 0 + local_sizes = dist_tensor.local_sizes(rank) + self.assertEqual(local_sizes, [2, 3]) + local_offsets = dist_tensor.local_offsets(rank) + self.assertEqual(local_offsets, [0, 0]) + local_shard = dist_tensor.local_shard(rank) + self.assertEqual(local_shard, [(0, 2), (0, 3)]) + self.assertEqual(local_sizes, dist_tensor.local_sizes(rank)) + self.assertEqual(local_offsets, dist_tensor.local_offsets(rank)) + self.assertEqual(local_shard, dist_tensor.local_shard(rank)) + self.assertEqual(local_sizes, dist_tensor.local_sizes()) + self.assertEqual(local_offsets, dist_tensor.local_offsets()) + self.assertEqual(local_shard, dist_tensor.local_shard()) + + rank = 1 + local_sizes = dist_tensor.local_sizes(rank) + self.assertEqual(local_sizes, [2, 3]) + local_offsets = dist_tensor.local_offsets(rank) + self.assertEqual(local_offsets, [2, 0]) + local_shard = dist_tensor.local_shard(rank) + self.assertEqual(local_shard, [(2, 4), (0, 3)]) + + rank = 4 + local_sizes = dist_tensor.local_sizes(rank) + self.assertEqual(local_sizes, [2, 3]) + local_offsets = dist_tensor.local_offsets(rank) + self.assertEqual(local_offsets, [2, 3]) + local_shard = dist_tensor.local_shard(rank) + self.assertEqual(local_shard, [(2, 4), (3, 6)]) + + global_sizes = dist_tensor.global_sizes() + self.assertEqual(global_sizes, (6, 6)) + + +if __name__ == "__main__": + unittest.main() From 29c211ee079c03b14929f9354002ade6752e2238 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Tue, 11 Jan 2022 17:43:51 +0800 Subject: [PATCH 09/15] Support test_numpy_bridge and thread_local_has_grad (#38835) --- .../unittests/test_imperative_numpy_bridge.py | 14 ++++++++++++-- .../test_imperative_thread_local_has_grad.py | 8 +++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_numpy_bridge.py b/python/paddle/fluid/tests/unittests/test_imperative_numpy_bridge.py index 772dd913e4d20..4f3089baffdd3 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_numpy_bridge.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_numpy_bridge.py @@ -16,10 +16,11 @@ import numpy as np import paddle.fluid as fluid import warnings +from paddle.fluid.framework import _test_eager_guard, _in_eager_mode class TestImperativeNumpyBridge(unittest.TestCase): - def test_tensor_from_numpy(self): + def func_tensor_from_numpy(self): data_np = np.array([[2, 3, 1]]).astype('float32') with fluid.dygraph.guard(fluid.CPUPlace()): with warnings.catch_warnings(record=True) as w: @@ -39,9 +40,18 @@ def test_tensor_from_numpy(self): self.assertTrue(np.array_equal(var2.numpy(), data_np)) data_np[0][0] = -1 self.assertEqual(data_np[0][0], -1) - self.assertNotEqual(var2[0][0].numpy()[0], -1) + if _in_eager_mode(): + # eager_mode, var2 is EagerTensor, is not subscriptable + self.assertNotEqual(var2.numpy()[0][0], -1) + else: + self.assertNotEqual(var2[0][0].numpy()[0], -1) self.assertFalse(np.array_equal(var2.numpy(), data_np)) + def test_func_tensor_from_numpy(self): + with _test_eager_guard(): + self.func_tensor_from_numpy() + self.func_tensor_from_numpy() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py index d81849725d75a..f54e50953f131 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py @@ -18,6 +18,7 @@ import paddle.nn as nn import numpy as np import threading +from paddle.fluid.framework import _test_eager_guard, _in_eager_mode class SimpleNet(nn.Layer): @@ -44,7 +45,7 @@ def thread_2_main(self): x = net(x) self.assertFalse(x.stop_gradient) - def test_main(self): + def func_main(self): threads = [] for _ in range(10): threads.append(threading.Thread(target=self.thread_1_main)) @@ -54,6 +55,11 @@ def test_main(self): for t in threads: t.join() + def test_main(self): + with _test_eager_guard(): + self.func_main() + self.func_main() + if __name__ == "__main__": unittest.main() From 2bed9b9c5970497cfbbff197d6eb7a4b87680dd2 Mon Sep 17 00:00:00 2001 From: Wilber Date: Tue, 11 Jan 2022 19:16:15 +0800 Subject: [PATCH 10/15] [PTEN] Add pten::Place data structure. (#38844) * add pten::Place data structure. * update ci problem * fix ci problem * update --- paddle/pten/api/lib/utils/CMakeLists.txt | 4 +- paddle/pten/api/lib/utils/place_utils.cc | 62 ------------ paddle/pten/api/lib/utils/place_utils.h | 28 ------ paddle/pten/common/CMakeLists.txt | 2 +- paddle/pten/common/device.cc | 65 ------------- paddle/pten/common/device.h | 70 -------------- paddle/pten/common/place.cc | 57 +++++++++-- paddle/pten/common/place.h | 109 +++++++++++++++------- paddle/pten/tests/api/CMakeLists.txt | 1 - paddle/pten/tests/api/test_place_utils.cc | 77 --------------- paddle/pten/tests/common/CMakeLists.txt | 1 + paddle/pten/tests/common/test_place.cc | 53 +++++++++++ 12 files changed, 184 insertions(+), 345 deletions(-) delete mode 100644 paddle/pten/api/lib/utils/place_utils.cc delete mode 100644 paddle/pten/api/lib/utils/place_utils.h delete mode 100644 paddle/pten/common/device.cc delete mode 100644 paddle/pten/common/device.h delete mode 100644 paddle/pten/tests/api/test_place_utils.cc create mode 100644 paddle/pten/tests/common/test_place.cc diff --git a/paddle/pten/api/lib/utils/CMakeLists.txt b/paddle/pten/api/lib/utils/CMakeLists.txt index 06178dad43767..4a44ad7758b56 100644 --- a/paddle/pten/api/lib/utils/CMakeLists.txt +++ b/paddle/pten/api/lib/utils/CMakeLists.txt @@ -1,2 +1,2 @@ -cc_library(pten_api_utils SRCS allocator.cc storage.cc tensor_utils.cc place_utils.cc DEPS -tensor_base convert_utils dense_tensor lod_tensor selected_rows place var_type_traits pten_common) +cc_library(pten_api_utils SRCS allocator.cc storage.cc tensor_utils.cc DEPS +tensor_base convert_utils dense_tensor lod_tensor selected_rows place var_type_traits) diff --git a/paddle/pten/api/lib/utils/place_utils.cc b/paddle/pten/api/lib/utils/place_utils.cc deleted file mode 100644 index af4f84b1ad836..0000000000000 --- a/paddle/pten/api/lib/utils/place_utils.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/pten/api/lib/utils/place_utils.h" -#include "paddle/pten/api/ext/exception.h" - -namespace paddle { -namespace experimental { - -Place ConvertToPtenPlace(const platform::Place& src) { - Place place; - if (platform::is_cpu_place(src)) { - place.Reset(Device(DeviceType::kHost, 0)); - } else if (platform::is_gpu_place(src)) { - place.Reset( - Device(DeviceType::kCuda, - BOOST_GET_CONST(platform::CUDAPlace, src).GetDeviceId())); - } else if (platform::is_cuda_pinned_place(src)) { - place.Reset(Device(DeviceType::kCuda, 0), true); - } else if (platform::is_xpu_place(src)) { - place.Reset(Device(DeviceType::kXpu, - BOOST_GET_CONST(platform::XPUPlace, src).GetDeviceId())); - } else { - PD_THROW("Invalid platform place type."); - } - return place; -} - -platform::Place ConvertToPlatformPlace(const Place& src) { - switch (src.device().type()) { - case DeviceType::kHost: { - return platform::CPUPlace(); - } - case DeviceType::kCuda: { - if (src.is_pinned()) { - return platform::CUDAPinnedPlace(); - } else { - return platform::CUDAPlace(src.device().id()); - } - } - case DeviceType::kXpu: { - return platform::XPUPlace(src.device().id()); - } - default: - PD_THROW("Invalid pten place type."); - } - return {}; -} - -} // namespace experimental -} // namespace paddle diff --git a/paddle/pten/api/lib/utils/place_utils.h b/paddle/pten/api/lib/utils/place_utils.h deleted file mode 100644 index 9ac10158040b2..0000000000000 --- a/paddle/pten/api/lib/utils/place_utils.h +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/platform/place.h" -#include "paddle/pten/common/place.h" - -namespace paddle { -namespace experimental { - -Place ConvertToPtenPlace(const platform::Place& src); - -platform::Place ConvertToPlatformPlace(const Place& src); - -} // namespace experimental -} // namespace paddle diff --git a/paddle/pten/common/CMakeLists.txt b/paddle/pten/common/CMakeLists.txt index c4083d7f0d756..feaf0e12bdb16 100644 --- a/paddle/pten/common/CMakeLists.txt +++ b/paddle/pten/common/CMakeLists.txt @@ -1 +1 @@ -cc_library(pten_common SRCS device.cc place.cc DEPS enforce) +cc_library(pten_place SRCS place.cc) diff --git a/paddle/pten/common/device.cc b/paddle/pten/common/device.cc deleted file mode 100644 index 55130067ae200..0000000000000 --- a/paddle/pten/common/device.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/pten/common/device.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/pten/api/ext/exception.h" - -namespace paddle { -namespace experimental { - -const char* DeviceTypeStr(DeviceType type) { - switch (type) { - case DeviceType::kUndef: - return "kUndef"; - case DeviceType::kHost: - return "kHost"; - case DeviceType::kXpu: - return "kXpu"; - case DeviceType::kCuda: - return "kCuda"; - case DeviceType::kHip: - return "kHip"; - case DeviceType::kNpu: - return "kNpu"; - default: - PD_THROW("Invalid pten device type."); - } - return {}; -} - -Device::Device(DeviceType type, int8_t id) : type_(type), id_(id) { - PADDLE_ENFORCE_GE( - id, - 0, - platform::errors::InvalidArgument( - "The device id needs to start from zero, but you passed in %d.", id)); -} - -Device::Device(DeviceType type) : type_(type), id_(0) { - PADDLE_ENFORCE_EQ( - type, - DeviceType::kHost, - platform::errors::InvalidArgument( - "The device id needs to start from zero, but you passed in %s.", - DeviceTypeStr(type))); -} - -std::string Device::DebugString() const { - std::string str{"DeviceType:"}; - return str + DeviceTypeStr(type_) + ", id: " + std::to_string(id_); -} - -} // namespace experimental -} // namespace paddle diff --git a/paddle/pten/common/device.h b/paddle/pten/common/device.h deleted file mode 100644 index eddb71bce16da..0000000000000 --- a/paddle/pten/common/device.h +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include - -namespace paddle { -namespace experimental { - -enum class DeviceType : int8_t { - kUndef = 0, - kHost = 1, - kXpu = 2, - kCuda = 3, - kHip = 4, - kNpu = 5, -}; - -const char* DeviceTypeStr(DeviceType type); - -/// \brief The device is used to store hardware information. It has not yet -/// stored information related to the math acceleration library. -struct Device final { - public: - Device() = default; - - Device(DeviceType type, int8_t id); - - Device(DeviceType type); - - DeviceType type() const noexcept { return type_; } - - /// \brief Returns the index of the device. Here, -1 is used to indicate an - /// invalid value, and 0 to indicate a default value. - /// \return The index of the device. - int8_t id() const noexcept { return id_; } - - void set_type(DeviceType type) noexcept { type_ = type; } - - void set_id(int8_t id) noexcept { id_ = id; } - - std::string DebugString() const; - - private: - friend bool operator==(const Device&, const Device&) noexcept; - - private: - DeviceType type_{DeviceType::kUndef}; - int8_t id_{-1}; -}; - -inline bool operator==(const Device& lhs, const Device& rhs) noexcept { - return (lhs.type_ == rhs.type_) && (lhs.id_ == rhs.id_); -} - -} // namespace experimental -} // namespace paddle diff --git a/paddle/pten/common/place.cc b/paddle/pten/common/place.cc index ba34c5d0f9222..2d33bb508af44 100644 --- a/paddle/pten/common/place.cc +++ b/paddle/pten/common/place.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,14 +13,57 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/pten/common/place.h" -#include "paddle/fluid/platform/enforce.h" -namespace paddle { -namespace experimental { +#include +#include + +#include "paddle/pten/api/ext/exception.h" + +namespace pten { + +const char *AllocationTypeStr(AllocationType type) { + switch (type) { + case AllocationType::UNDEF: + return "undef"; + case AllocationType::CPU: + return "cpu"; + case AllocationType::GPU: + return "gpu"; + case AllocationType::GPUPINNED: + return "gpu pinned"; + case AllocationType::XPU: + return "xpu"; + case AllocationType::NPU: + return "npu"; + case AllocationType::NPUPINNED: + return "npu pinned"; + case AllocationType::IPU: + return "ipu"; + case AllocationType::MLU: + return "mlu"; + default: + PD_THROW("Invalid pten device type."); + return {}; + } +} std::string Place::DebugString() const { - return device_.DebugString() + ", is_pinned: " + std::to_string(is_pinned_); + std::ostringstream os; + os << "Place("; + os << AllocationTypeStr(alloc_type_); + if (alloc_type_ == AllocationType::GPUPINNED || + alloc_type_ == AllocationType::NPUPINNED || + alloc_type_ == AllocationType::CPU) { + os << ")"; + } else { + os << ":" << std::to_string(device) << ")"; + } + return os.str(); +} + +std::ostream &operator<<(std::ostream &os, const Place &p) { + os << p.DebugString(); + return os; } -} // namespace experimental -} // namespace paddle +} // namespace pten diff --git a/paddle/pten/common/place.h b/paddle/pten/common/place.h index fdc948734934b..24d24305202cf 100644 --- a/paddle/pten/common/place.h +++ b/paddle/pten/common/place.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,52 +16,97 @@ limitations under the License. */ #include -#include "paddle/pten/common/device.h" +namespace pten { + +enum class AllocationType : int8_t { + UNDEF = 0, + CPU = 1, + GPU = 2, + GPUPINNED = 3, + XPU = 4, + NPU = 5, + NPUPINNED = 6, + IPU = 7, + MLU = 8, +}; -namespace paddle { -namespace experimental { +const char *AllocationTypeStr(AllocationType type); /// \brief The place is used to specify where the data is stored. -class Place final { +class Place { public: - Place() = default; - - explicit Place(const Device& device) : device_(device) {} + Place() : device(0), alloc_type_(AllocationType::UNDEF) {} - Place(DeviceType type, int8_t id) : device_(type, id) {} + explicit Place(AllocationType type, int8_t id) + : device(id), alloc_type_(type) {} - Place(DeviceType type) : device_(type) {} + explicit Place(AllocationType type) : device(0), alloc_type_(type) {} - Place(const Device& device, bool is_pinned) noexcept : device_(device), - is_pinned_(is_pinned) { + void Reset(AllocationType type, int8_t device_id = 0) noexcept { + alloc_type_ = type; + device = device_id; } - const Device& device() const noexcept { return device_; } + AllocationType GetType() const { return alloc_type_; } - /// \brief Returns whether the memory is a locked page. The page lock - /// memory is actually located in the host memory, but it can only be - /// used by certain devices and can be directly transferred by DMA. - /// \return Whether the memory is a locked page. - bool is_pinned() const noexcept { return is_pinned_; } - - void Reset(const Device& device, bool is_pinned = false) noexcept { - device_ = device; - is_pinned_ = is_pinned; - } + int8_t GetDeviceId() const { return device; } std::string DebugString() const; - private: - friend bool operator==(const Place&, const Place&) noexcept; + public: + // TODO(wilber): Just because of backward compatibility, it needs to be + // changed to private in the future. + int8_t device; private: - Device device_; - bool is_pinned_{false}; + AllocationType alloc_type_; +}; + +class CPUPlace : public Place { + public: + CPUPlace() : Place(AllocationType::CPU, 0) {} +}; + +class GPUPlace : public Place { + public: + GPUPlace() : Place(AllocationType::GPU, 0) {} + explicit GPUPlace(int device_id) : Place(AllocationType::GPU, device_id) {} +}; + +class GPUPinnedPlace : public Place { + public: + GPUPinnedPlace() : Place(AllocationType::GPUPINNED) {} +}; + +class XPUPlace : public Place { + public: + XPUPlace() : Place(AllocationType::XPU, 0) {} + explicit XPUPlace(int device_id) : Place(AllocationType::XPU, device_id) {} +}; + +class NPUPlace : public Place { + public: + NPUPlace() : Place(AllocationType::NPU, 0) {} + explicit NPUPlace(int device_id) : Place(AllocationType::XPU, device_id) {} +}; + +class NPUPinnedPlace : public Place { + public: + NPUPinnedPlace() : Place(AllocationType::NPUPINNED) {} +}; + +class IPUPlace : public Place { + public: + IPUPlace() : Place(AllocationType::XPU, 0) {} + explicit IPUPlace(int device_id) : Place(AllocationType::XPU, device_id) {} +}; + +class MLUPlace : public Place { + public: + MLUPlace() : Place(AllocationType::MLU, 0) {} + explicit MLUPlace(int device_id) : Place(AllocationType::MLU, device_id) {} }; -inline bool operator==(const Place& lhs, const Place& rhs) noexcept { - return (lhs.device_ == rhs.device_) && (lhs.is_pinned_ == rhs.is_pinned_); -} +std::ostream &operator<<(std::ostream &, const Place &); -} // namespace experimental -} // namespace paddle +} // namespace pten diff --git a/paddle/pten/tests/api/CMakeLists.txt b/paddle/pten/tests/api/CMakeLists.txt index bb1eab2c09551..ffbc551843148 100644 --- a/paddle/pten/tests/api/CMakeLists.txt +++ b/paddle/pten/tests/api/CMakeLists.txt @@ -7,7 +7,6 @@ endif() cc_test(test_pten_exception SRCS test_pten_exception.cc DEPS gtest) cc_test(test_framework_storage SRCS test_storage.cc DEPS pten_api_utils) cc_test(test_framework_tensor_utils SRCS test_tensor_utils.cc DEPS pten_api_utils) -cc_test(test_framework_place_utils storage SRCS test_place_utils.cc DEPS pten_api_utils) cc_test(test_mean_api SRCS test_mean_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_dot_api SRCS test_dot_api.cc DEPS pten_tensor pten_api pten_api_utils) diff --git a/paddle/pten/tests/api/test_place_utils.cc b/paddle/pten/tests/api/test_place_utils.cc deleted file mode 100644 index 4db1f59d83786..0000000000000 --- a/paddle/pten/tests/api/test_place_utils.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "gtest/gtest.h" - -#include "paddle/pten/api/lib/utils/place_utils.h" - -namespace paddle { -namespace experimental { -namespace tests { - -TEST(place_utils, cpu_place) { - auto pd_place = platform::CPUPlace(); - Place pten_place = ConvertToPtenPlace(pd_place); - CHECK_EQ(pten_place.device().id(), 0); - CHECK(pten_place.device().type() == DeviceType::kHost); - CHECK(pten_place.is_pinned() == false); - - auto pd_place_1 = ConvertToPlatformPlace(pten_place); - CHECK(platform::is_cpu_place(pd_place_1)); - CHECK(pd_place == BOOST_GET_CONST(platform::CPUPlace, pd_place_1)); - CHECK(pten_place == ConvertToPtenPlace(pd_place_1)); -} - -TEST(place_utils, cuda_place) { - auto pd_place = platform::CUDAPlace(1); - Place pten_place = ConvertToPtenPlace(pd_place); - CHECK_EQ(pten_place.device().id(), 1); - CHECK(pten_place.device().type() == DeviceType::kCuda); - CHECK(pten_place.is_pinned() == false); - - auto pd_place_1 = ConvertToPlatformPlace(pten_place); - CHECK(platform::is_gpu_place(pd_place_1)); - CHECK(pd_place == BOOST_GET_CONST(platform::CUDAPlace, pd_place_1)); - CHECK(pten_place == ConvertToPtenPlace(pd_place_1)); -} - -TEST(place_utils, cuda_pinned_place) { - auto pd_place = platform::CUDAPinnedPlace(); - Place pten_place = ConvertToPtenPlace(pd_place); - CHECK_EQ(pten_place.device().id(), 0); - CHECK(pten_place.device().type() == DeviceType::kCuda); - CHECK(pten_place.is_pinned() == true); - - auto pd_place_1 = ConvertToPlatformPlace(pten_place); - CHECK(platform::is_cuda_pinned_place(pd_place_1)); - CHECK(pd_place == BOOST_GET_CONST(platform::CUDAPinnedPlace, pd_place_1)); - CHECK(pten_place == ConvertToPtenPlace(pd_place_1)); -} - -TEST(place_utils, xpu_place) { - auto pd_place = platform::XPUPlace(1); - Place pten_place = ConvertToPtenPlace(pd_place); - CHECK_EQ(pten_place.device().id(), 1); - CHECK(pten_place.device().type() == DeviceType::kXpu); - CHECK(pten_place.is_pinned() == false); - - auto pd_place_1 = ConvertToPlatformPlace(pten_place); - CHECK(platform::is_xpu_place(pd_place_1)); - CHECK(pd_place == BOOST_GET_CONST(platform::XPUPlace, pd_place_1)); - CHECK(pten_place == ConvertToPtenPlace(pd_place_1)); -} - -} // namespace tests -} // namespace experimental -} // namespace paddle diff --git a/paddle/pten/tests/common/CMakeLists.txt b/paddle/pten/tests/common/CMakeLists.txt index c0a5414d53e47..f54b37cb976c5 100644 --- a/paddle/pten/tests/common/CMakeLists.txt +++ b/paddle/pten/tests/common/CMakeLists.txt @@ -1,3 +1,4 @@ cc_test(pten_test_backend SRCS test_backend.cc DEPS gtest) cc_test(pten_test_data_layout SRCS test_data_layout.cc DEPS gtest) cc_test(pten_test_data_type SRCS test_data_type.cc DEPS gtest) +cc_test(pten_test_place SRCS test_place.cc DEPS pten_place) diff --git a/paddle/pten/tests/common/test_place.cc b/paddle/pten/tests/common/test_place.cc new file mode 100644 index 0000000000000..0bbd8f1d42273 --- /dev/null +++ b/paddle/pten/tests/common/test_place.cc @@ -0,0 +1,53 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/common/place.h" + +#include "gtest/gtest.h" + +namespace pten { +namespace tests { + +TEST(PtenPlace, place) { + pten::Place place; + EXPECT_EQ(place.GetType(), pten::AllocationType::UNDEF); + + place.Reset(pten::AllocationType::GPU, 1); + EXPECT_EQ(place.GetType(), pten::AllocationType::GPU); + EXPECT_EQ(place.GetDeviceId(), 1); +} + +TEST(Place, cpu_place) { + pten::CPUPlace place; + EXPECT_EQ(place.GetType(), pten::AllocationType::CPU); + std::cout << "cpu place repr: " << place << std::endl; +} + +TEST(Place, gpu_place) { + pten::GPUPlace place; + EXPECT_EQ(place.GetType(), pten::AllocationType::GPU); + EXPECT_EQ(place.GetDeviceId(), 0); + + pten::GPUPlace place1(2); + EXPECT_EQ(place1.GetType(), pten::AllocationType::GPU); + EXPECT_EQ(place1.GetDeviceId(), 2); + std::cout << "gpu place repr: " << place1 << std::endl; + + pten::GPUPinnedPlace place2; + EXPECT_EQ(place2.GetType(), pten::AllocationType::GPUPINNED); + std::cout << "gpu pinned place repr: " << place2 << std::endl; +} + +} // namespace tests +} // namespace pten From 3eaf8d2cead9fc3d7b82c5c928c331917ea687b6 Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Tue, 11 Jan 2022 19:49:01 +0800 Subject: [PATCH 11/15] Modified Kernel Primitive API and elementwise for xpu2 #38688 --- .../elementwise/elementwise_op_broadcast.cu.h | 8 +- .../elementwise/elementwise_op_impl.cu.h | 3 +- .../datamover_primitives_xpu2.h | 172 +++++++++--------- .../kernel_primitives/kernel_primitives.h | 15 +- paddle/fluid/platform/hostdevice.h | 9 +- paddle/pten/kernels/gpu/elementwise.h | 104 +++++------ 6 files changed, 164 insertions(+), 147 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h index 25c983566b371..e3d4607b7130c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h @@ -25,8 +25,7 @@ namespace kps = paddle::operators::kernel_primitives; template void LaunchBroadcastElementwiseCudaKernel( - const platform::CUDADeviceContext &ctx, - const std::vector &ins, + const KPDevice &ctx, const std::vector &ins, std::vector *outs, int axis, Functor func) { std::vector pt_inputs; std::vector pt_outputs; @@ -58,8 +57,7 @@ void LaunchBroadcastElementwiseCudaKernel( template void LaunchElementwiseCudaKernel( - const platform::CUDADeviceContext &cuda_ctx, - const std::vector &ins, + const KPDevice &ctx, const std::vector &ins, std::vector *outs, int axis, Functor func) { std::vector pt_inputs; std::vector pt_outputs; @@ -85,7 +83,7 @@ void LaunchElementwiseCudaKernel( pt_outputs.push_back(pt_outputs_tmp[i].get()); } pten::LaunchElementwiseCudaKernel( - cuda_ctx, pt_inputs, &pt_outputs, axis, func); + ctx, pt_inputs, &pt_outputs, axis, func); } } // namespace operators diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h index 1d8acd5eca5d9..36ff1ae254d20 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h @@ -35,8 +35,7 @@ using ElementwiseType = pten::ElementwiseType; template void LaunchSameDimsElementwiseCudaKernel( - const platform::CUDADeviceContext &ctx, - const std::vector &ins, + const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func) { std::vector pt_inputs; std::vector pt_outputs; diff --git a/paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h b/paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h index b27ba27b3c6f1..333899535894e 100644 --- a/paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h +++ b/paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h @@ -32,42 +32,50 @@ struct alignas(sizeof(T) * VecSize) VectorType { * index of the output data. if input or output shape is [dim0, dim1] then dims * must be [dim1, dim0]. */ +#pragma pack(4) template struct BroadcastConfig { - uint32_t stride_in[framework::DDim::kMaxRank]; - uint32_t stride_out[framework::DDim::kMaxRank]; - uint32_t shape_in[framework::DDim::kMaxRank]; + int strides_in[framework::DDim::kMaxRank]; + int strides_out[framework::DDim::kMaxRank]; + int in_dim[framework::DDim::kMaxRank]; HOSTDEVICE BroadcastConfig() {} HOSTDEVICE BroadcastConfig(const std::vector& out_dims, const std::vector& in_dims, int dim_size) { - std::vector strides_in; - std::vector strides_out; - std::vector shapes_in; - - strides_out.resize(dim_size, 1); - strides_in.resize(dim_size, 1); - shapes_in.resize(dim_size, 1); - - for (int i = 0; i < dim_size; ++i) { - shape_in[i] = in_dims[dim_size - i - 1]; + std::vector strides_in_tmp; + std::vector strides_out_tmp; + std::vector dim_tmp; + strides_in_tmp.resize(dim_size, 1); + strides_out_tmp.resize(dim_size, 1); + dim_tmp.resize(dim_size, 1); + for (int i = 1; i < dim_size; i++) { + strides_in_tmp[i] = strides_in_tmp[i - 1] * in_dims[i - 1]; + strides_out_tmp[i] = strides_out_tmp[i - 1] * out_dims[i - 1]; } - for (int i = 1; i < dim_size - 1; ++i) { - strides_out[dim_size - i - 1] = std::accumulate( - out_dims.begin(), out_dims.begin() + i, 1, std::multiplies()) - strides_in[dim_size - i - 1] = - std::accumulate(in_dims.begin(), in_dims.begin() + i, 1, - std::multiplies()) + for (int i = 0; i < dim_size; i++) { + dim_tmp[i] = in_dims[i]; } - memcpy(stride_in, strides_in.data(), kDims * sizeof(uint32_t)); - memcpy(stride_out, strides_out.data(), kDims * sizeof(uint32_t)); - memcpy(shape_in, shapes_in.data(), kDims * sizeof(uint32_t)); + memcpy(strides_in, strides_in_tmp.data(), kDims * sizeof(int)); + memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int)); + memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int)); + } + + __device__ inline int operator()(int index_output) const { + int index_src = 0; +#pragma unroll + for (int i = kDims - 1; i >= 0; --i) { + int tmp_index = (index_output / strides_out[i]); + index_output = index_output - tmp_index * strides_out[i]; + index_src += (tmp_index % in_dim[i]) * strides_in[i]; + } + return index_src; } }; +#pragma pack() } // namespace details @@ -99,12 +107,12 @@ struct BroadcastConfig { */ template -__device__ __forceinline__ void ReadData(Ty* dst, const Tx _global_ptr_* src, - int size_nx, int size_ny, - int stride_nx, int stride_ny) { +__device__ __inline__ void ReadData(Ty* dst, const Tx _global_ptr_* src, + int size_nx, int size_ny, int stride_nx, + int stride_ny) { int thread_offset = core_id(); int left_size_nx = size_nx - thread_offset; - __local__ T in_temp[1]; + __local__ Tx in_temp[1]; // Each branch is added for better performance if (NX == 1 && NY == 1) { // for NX == 1 and NY == 1 if (IsBoundary) { @@ -168,7 +176,7 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx _global_ptr_* src, * init_data: Initial value. */ template -__device__ __forceinline__ void Init(T* dst, T init_data) { +__device__ __inline__ void Init(T* dst, T init_data) { #pragma unroll for (int i = 0; i < NX; i++) { dst[i] = init_data; @@ -197,8 +205,8 @@ __device__ __forceinline__ void Init(T* dst, T init_data) { * size: The current block needs to load size data continuously. */ template -__device__ __forceinline__ void ReadData(T* dst, const T _global_ptr_* src, - int num) { +__device__ __inline__ void ReadData(T* dst, const T _global_ptr_* src, + int num) { int thread_offset = core_id() * NX; __local__ T in_temp[1]; if (IsBoundary) { // core_num() * NX > num @@ -241,10 +249,11 @@ __device__ __forceinline__ void ReadData(T* dst, const T _global_ptr_* src, */ template -__device__ __forceinline__ void ReadDataBc( - T* dst, const T _global_ptr_* src, uint32_t block_offset, - details::BroadcastConfig config, int total_num_output, int stride_nx, - int stride_ny) { +__device__ __inline__ void ReadDataBc(T* dst, const T _global_ptr_* src, + uint32_t block_offset, + details::BroadcastConfig config, + int total_num_output, int stride_nx, + int stride_ny) { uint32_t thread_offset = block_offset + core_id(); uint32_t index_src = 0; __local__ T in_temp[1]; @@ -256,16 +265,11 @@ __device__ __forceinline__ void ReadDataBc( uint32_t index_output = thread_offset + ny * stride_ny + nx * stride_nx; index_src = 0; if (IsBoundary) { - if (index_output >= total_num_output) { + if (index_output >= (uint32_t)total_num_output) { break; } } -#pragma unroll - for (int i = 0; i < Rank; ++i) { - uint32_t tmp = index_output / config.stride_out[i]; - index_output = index_output - tmp * config.stride_out[i]; - index_src += (tmp % config.shape_in[i]) * config.stride_in[i]; - } + index_src = config(index_output); GM2LM(src + index_src, in_temp, sizeof(T)); dst[nx + ny * NX] = in_temp[0]; } @@ -305,33 +309,34 @@ __device__ __forceinline__ void ReadDataBc( */ template -__device__ __forceinline__ void ReadDataReduce( - T* dst, const T _global_ptr_* src, int block_offset, - const IndexCal& index_cal, int size_nx, int size_ny, int stride_nx, - int stride_ny, bool reduce_last_dim) { - __local__ T in_temp[1]; +__device__ __inline__ void ReadDataReduce(T* dst, const T _global_ptr_* src, + int block_offset, + const IndexCal& index_cal, + int size_nx, int size_ny, + int stride_nx, int stride_ny, + bool reduce_last_dim) { + __local__ Tx in_temp[1]; int thread_offset = 0; - int left_size_nx = size_nx; - int left_size_ny = size_ny; + int left_idx = 0; if (reduce_last_dim) { - thread_offset = block_offset + core_id(); - left_size_nx -= thread_offset; + thread_offset = core_id(); + left_idx = 0; } else { - thread_offset = block_offset + core_id(); - left_size_ny -= thread_offset; + thread_offset = 0; + left_idx = 0; } if (NX == 1) { #pragma unroll for (int ny = 0; ny < NY; ++ny) { if (IsBoundary) { - if (ny * stride_ny >= left_size_ny) { + if (thread_offset >= size_ny) { break; } } - uint32_t index_src = index_cal(thread_offset); - GM2LM(src + index_src, in_temp, sizeof(T)); - dst[ny] = in_temp[0]; + uint32_t index_src = index_cal(thread_offset + block_offset); + GM2LM(src + index_src, in_temp, sizeof(Tx)); + dst[ny] = static_cast(func(in_temp[0])); thread_offset += stride_ny; } } else { @@ -340,17 +345,16 @@ __device__ __forceinline__ void ReadDataReduce( #pragma unroll for (int ny = 0; ny < NY; ++ny) { if (IsBoundary) { - if ((ny * stride_ny >= left_size_ny) || - (nx * stride_nx >= left_size_nx)) { + if ((thread_offset >= size_ny) || + (left_idx + nx * stride_nx >= size_nx)) { break; } } - uint32_t index_src = index_cal(thread_offset); - GM2LM(src + index_src, in_temp, sizeof(T)); - dst[nx + ny * NX] = in_temp[0]; + uint32_t index_src = index_cal(thread_offset + block_offset); + GM2LM(src + index_src, in_temp, sizeof(Tx)); + dst[nx + ny * NX] = static_cast(func(in_temp[0])); thread_offset += stride_ny; } - thread_offset += stride_nx; } } } @@ -421,9 +425,9 @@ __device__ void WriteData(T _global_ptr_* dst, const T* src, int num) { */ template -__device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, - int size_nx, int size_ny, - int stride_nx, int stride_ny) { +__device__ __inline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, + int size_nx, int size_ny, int stride_nx, + int stride_ny) { int thread_offset = core_id(); int left_size_nx = size_nx - thread_offset; __local__ Ty in_temp[1]; @@ -433,11 +437,11 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, if (IsBoundary) { if (left_size_nx > 0) { in_temp[0] = static_cast(src[0]); - LM2GM(in_temp, dst + thread_offset, sizeof(T)); + LM2GM(in_temp, dst + thread_offset, sizeof(Ty)); } } else { in_temp[0] = static_cast(src[0]); - LM2GM(in_temp, dst + thread_offset, sizeof(T)); + LM2GM(in_temp, dst + thread_offset, sizeof(Ty)); } } else if (NX == 1) { #pragma unroll @@ -449,7 +453,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, } in_temp[0] = static_cast(src[idy]); - LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(T)); + LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(Ty)); } } else if (NY == 1) { // for NY == 1 and NX != 1 #pragma unroll @@ -461,7 +465,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, } in_temp[0] = static_cast(src[idx]); - LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(T)); + LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(Ty)); } } else { // for NX != 1 and NY != 1 #pragma unroll @@ -480,7 +484,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, } in_temp[0] = static_cast(src[idx + idy * NX]); LM2GM(in_temp, dst + thread_offset + idx * stride_nx + idy * stride_ny, - sizeof(T)); + sizeof(Ty)); } } } @@ -498,7 +502,7 @@ __device__ __forceinline__ void WriteData(Ty _global_ptr_* dst, const Tx* src, * init_data: The register pointer of init data, the size is NX. */ template -__device__ __forceinline__ void Init(T* dst, T* init_data, int num) { +__device__ __inline__ void Init(T* dst, T* init_data, int num) { #pragma unroll for (int i = 0; i < NX; i++) { if (IsBoundary) { @@ -535,30 +539,26 @@ __device__ __forceinline__ void Init(T* dst, T* init_data, int num) { */ template -__device__ __forceinline__ void ReadDataBc( - T* dst, const T _global_ptr_* src, uint32_t block_offset, - details::BroadcastConfig config, int total_num_output) { - uint32_t thread_offset = block_offset + core_id() * NX; - uint32_t index_src = 0; - __local__ T in_temp[1]; +__device__ __inline__ void ReadDataBc(T* dst, const T _global_ptr_* src, + uint32_t block_offset, + details::BroadcastConfig config, + int total_num_output) { + int thread_offset = block_offset + core_id() * NX; + int index_src = 0; + __local__ T in_temp; #pragma unroll - for (uint32_t nx = 0; nx < NX; ++nx) { - uint32_t index_output = thread_offset + nx; + for (int nx = 0; nx < NX; ++nx) { + int index_output = thread_offset + nx; index_src = 0; if (IsBoundary) { if (index_output >= total_num_output) { break; } } -#pragma unroll - for (int i = 0; i < Rank; ++i) { - uint32_t tmp = index_output / config.stride_out[i]; - index_output = index_output - tmp * config.stride_out[i]; - index_src += (tmp % config.shape_in[i]) * config.stride_in[i]; - } - GM2LM(src + index_src, in_temp, sizeof(T)); - dst[nx + ny * NX] = in_temp[0]; + index_src = config(index_output); + GM2LM(src + index_src, &in_temp, sizeof(T)); + dst[nx] = in_temp; } } diff --git a/paddle/fluid/operators/kernel_primitives/kernel_primitives.h b/paddle/fluid/operators/kernel_primitives/kernel_primitives.h index e20e77ae26a71..558f8c81c6642 100644 --- a/paddle/fluid/operators/kernel_primitives/kernel_primitives.h +++ b/paddle/fluid/operators/kernel_primitives/kernel_primitives.h @@ -13,11 +13,18 @@ // limitations under the License. #pragma once -#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h" #include "paddle/fluid/operators/kernel_primitives/helper_primitives.h" #ifdef PADDLE_WITH_XPU2 #include "paddle/fluid/operators/kernel_primitives/compute_primitives_xpu2.h" #include "paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h" +#include "paddle/fluid/operators/kernel_primitives/functor_primitives_xpu2.h" + +#define KPStream XPUStream +#define KPDevice paddle::platform::XPUDeviceContext +#define _ptr_ _global_ptr_ +#define __forceinline__ __inline__ +#define __restrict__ + #define THREAD_ID_X core_id() #define THREAD_ID_Y 0 #define THREAD_ID_Z 0 @@ -36,6 +43,12 @@ #else #include "paddle/fluid/operators/kernel_primitives/compute_primitives.h" #include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h" +#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h" + +#define KPStream gpuStream_t +#define KPDevice paddle::platform::CUDADeviceContext +#define _ptr_ + #define THREAD_ID_X threadIdx.x #define THREAD_ID_Y threadIdx.y #define THREAD_ID_Z threadIdx.z diff --git a/paddle/fluid/platform/hostdevice.h b/paddle/fluid/platform/hostdevice.h index 1ffbbc217e254..65005a5adbb1d 100644 --- a/paddle/fluid/platform/hostdevice.h +++ b/paddle/fluid/platform/hostdevice.h @@ -17,7 +17,14 @@ #include #endif -#if (defined(__CUDACC__) || defined(__HIPCC__)) +#ifdef __xpu_kp__ +#include +#include "xpu/kernel/cluster_header.h" +#include "xpu/kernel/debug.h" +#include "xpu/kernel/math.h" +#endif + +#if (defined(__CUDACC__) || defined(__HIPCC__) || defined(__xpu_kp__)) #define HOSTDEVICE __host__ __device__ #define DEVICE __device__ #define HOST __host__ diff --git a/paddle/pten/kernels/gpu/elementwise.h b/paddle/pten/kernels/gpu/elementwise.h index f78328c01a30d..e4cc894e48354 100644 --- a/paddle/pten/kernels/gpu/elementwise.h +++ b/paddle/pten/kernels/gpu/elementwise.h @@ -86,7 +86,7 @@ struct ElementwisePrimitiveCaller { template struct ElementwiseWriteDataCaller { __device__ __forceinline__ void operator()( - paddle::framework::Array outs, + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, ConditionalT src[VecSize], int block_offset, int num) { @@ -109,7 +109,7 @@ struct ElementwiseWriteDataCaller { template struct ElementwiseWriteDataCaller { __device__ __forceinline__ void operator()( - paddle::framework::Array outs, + paddle::framework::Array<_ptr_ OutT *, 1> outs, OutT src[VecSize], int block_offset, int num) { @@ -126,8 +126,8 @@ template __device__ void VectorizedElementwiseKernelImpl( - const paddle::framework::Array &in, - paddle::framework::Array outs, + const paddle::framework::Array &in, + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, int num, int data_offset, Functor func) { @@ -161,8 +161,8 @@ template __global__ void VectorizedElementwiseKernel( - paddle::framework::Array ins, - paddle::framework::Array outs, + paddle::framework::Array ins, + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, int size, int main_offset, Functor func) { @@ -212,17 +212,13 @@ template -void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx, +void ElementwiseCudaKernel(const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func) { auto numel = ins[0]->numel(); - int block_size = funcs::GetThreadsConfig(ctx, numel, VecSize); - int grid_size = - ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size; - auto stream = ctx.stream(); - paddle::framework::Array ins_data; - paddle::framework::Array outs_data; + paddle::framework::Array ins_data; + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs_data; for (int i = 0; i < Arity; ++i) { ins_data[i] = ins[i]->data(); @@ -231,8 +227,9 @@ void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx, outs_data[i] = (*outs)[i]->mutable_data(); } #ifdef PADDLE_WITH_XPU2 - block_size = 128; - grid_size = 8; + int block_size = 64; + int grid_size = 8; + auto stream = ctx.x_context()->xpu_stream; int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size; VectorizedElementwiseKernel<<>>( ins_data, outs_data, numel, main_offset, func); #else + int block_size = funcs::GetThreadsConfig(ctx, numel, VecSize); + int grid_size = + ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size; int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size; + auto stream = ctx.stream(); VectorizedElementwiseKernel void LaunchSameDimsElementwiseCudaKernel( - const paddle::platform::CUDADeviceContext &ctx, + const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func) { @@ -471,12 +472,12 @@ struct DimensionsTransform { template __device__ __forceinline__ void LoadData( T *dst, - const T *__restrict__ src, + const _ptr_ T *src, uint32_t block_offset, const kps::details::BroadcastConfig &config, int numel, int num, - bool need_broadcast) { + int need_broadcast) { // numel : whole num of output // num: how many data will be deal with in this time if (need_broadcast) { @@ -496,9 +497,9 @@ template __device__ void ElementwiseBroadcastKernelImpl( - const paddle::framework::Array &ins, - paddle::framework::Array outs, - const paddle::framework::Array &use_broadcast, + const paddle::framework::Array &ins, + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, + const paddle::framework::Array &use_broadcast, uint32_t numel, const paddle::framework::Array, Arity> &configs, @@ -540,9 +541,9 @@ template __global__ void ElementwiseBroadcastKernel( - paddle::framework::Array ins, - paddle::framework::Array outs, - paddle::framework::Array use_broadcast, + paddle::framework::Array ins, + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, + paddle::framework::Array use_broadcast, uint32_t numel, paddle::framework::Array, Arity> configs, @@ -570,7 +571,8 @@ __global__ void ElementwiseBroadcastKernel( block_offset, func); } - if (block_offset < numel) { + int num = numel - block_offset; + if (num > 0) { ElementwiseBroadcastKernelImpl( - ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func); + ins, outs, use_broadcast, numel, configs, num, block_offset, func); } #else if (block_offset < main_offset) { @@ -619,23 +621,16 @@ template -void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx, +void LaunchKernel(const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func, DimensionsTransform merge_dims) { int numel = (*outs)[0]->numel(); - const int threads = 256; - int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads; - - int main_offset = (numel / (VecSize * threads)) * VecSize * threads; - int tail_tid = numel % (VecSize * threads); - auto stream = ctx.stream(); - paddle::framework::Array, Arity> configs; - paddle::framework::Array use_broadcast; - paddle::framework::Array ins_data; - paddle::framework::Array outs_data; + paddle::framework::Array use_broadcast; + paddle::framework::Array ins_data; + paddle::framework::Array<_ptr_ OutT *, NumOuts> outs_data; for (int i = 0; i < NumOuts; ++i) { outs_data[i] = (*outs)[i]->mutable_data(); @@ -643,7 +638,7 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx, for (int i = 0; i < Arity; i++) { use_broadcast[i] = (ins[i]->numel() != numel); - ins_data[i] = ins[i]->data(); + ins_data[i] = (_ptr_ InT *)(ins[i]->data()); if (use_broadcast[i]) { // get the broadcast config, // if data shape is[m, n], then you should set data_dim = {n, m} @@ -654,10 +649,11 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx, } #ifdef PADDLE_WITH_XPU2 - threads = 128; - blocks = 8; - main_offset = (numel / (VecSize * threads)) * VecSize * threads; - tail_tid = numel % (VecSize * threads); + const int threads = 64; + const int blocks = 8; + int main_offset = (numel / (VecSize * threads)) * VecSize * threads; + int tail_tid = numel % (VecSize * threads); + auto stream = ctx.x_context()->xpu_stream; ElementwiseBroadcastKernel void LaunchBroadcastKernelForDifferentVecSize( - const paddle::platform::CUDADeviceContext &ctx, + const KPDevice &ctx, const std::vector &ins, std::vector *outs, int axis, @@ -737,7 +738,7 @@ template void LaunchBroadcastElementwiseCudaKernel( - const paddle::platform::CUDADeviceContext &ctx, + const KPDevice &ctx, const std::vector &ins, std::vector *outs, int axis, @@ -835,12 +836,11 @@ template -void LaunchElementwiseCudaKernel( - const paddle::platform::CUDADeviceContext &cuda_ctx, - const std::vector &ins, - std::vector *outs, - int axis, - Functor func) { +void LaunchElementwiseCudaKernel(const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + int axis, + Functor func) { std::vector dims_size; bool no_broadcast_flag = true; for (auto *in : ins) { @@ -849,14 +849,14 @@ void LaunchElementwiseCudaKernel( } if (no_broadcast_flag) { LaunchSameDimsElementwiseCudaKernel( - cuda_ctx, ins, outs, func); + ctx, ins, outs, func); } else { axis = axis == -1 ? *std::max_element(dims_size.begin(), dims_size.end()) - *std::min_element(dims_size.begin(), dims_size.end()) : axis; LaunchBroadcastElementwiseCudaKernel( - cuda_ctx, ins, outs, axis, func); + ctx, ins, outs, axis, func); } } From 7915d18056d4f4284f5f415d5f9111c157b782c7 Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Tue, 11 Jan 2022 20:00:12 +0800 Subject: [PATCH 12/15] Fix bug in elementwise_mul/div_grad when inplace strategy (#38840) * fix bug when inplace strategy * fix * fix * fix * fix * fix --- .../operators/elementwise/elementwise_div_op.cu | 10 ---------- .../operators/elementwise/elementwise_mul_op.cu | 12 +----------- .../operators/elementwise/elementwise_op_function.h | 1 + 3 files changed, 2 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu index 7a25f65366901..06f9107db27b4 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu @@ -31,20 +31,10 @@ ElementwiseDivGrad(const framework::ExecutionContext& ctx, const auto& dev_ctx = ctx.template device_context(); const auto place = ctx.GetPlace(); if (dx != nullptr && dy != nullptr) { - dx->mutable_data(place); - if (dx->IsSharedBufferWith(*dout)) { - dx->clear(); - dx->mutable_data(x->dims(), place); - } std::vector ins = {dout, out, y}; GetGradXAndYOut( dev_ctx, place, axis, ins, dout, dx, dy, DivGradXYFunctor()); } else if (dx != nullptr && dy == nullptr) { - dx->mutable_data(place); - if (dx->IsSharedBufferWith(*dout)) { - dx->clear(); - dx->mutable_data(x->dims(), place); - } std::vector ins = {dout, y}; GetGradXOrYOut(dev_ctx, place, axis, ins, dout, dx, DivGradXFunctor()); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index a8b6c2abe3bf9..5ece5cadc603f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -74,20 +74,10 @@ ElementwiseMulGrad(const framework::ExecutionContext& ctx, const auto place = ctx.GetPlace(); if (dx != nullptr && dy != nullptr) { - dx->mutable_data(place); - if (dx->IsSharedBufferWith(*dout)) { - dx->clear(); - dx->mutable_data(x->dims(), place); - } std::vector ins = {dout, y, x}; - GetGradXAndYOut( + GetGradXAndYOut( dev_ctx, place, axis, ins, dout, dx, dy, MulGradXYFunctor()); } else if (dx != nullptr && dy == nullptr) { - dx->mutable_data(place); - if (dx->IsSharedBufferWith(*dout)) { - dx->clear(); - dx->mutable_data(x->dims(), place); - } std::vector ins = {dout, y}; GetGradXOrYOut(dev_ctx, place, axis, ins, dout, dx, MulGradFunctor()); diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 3929699955a17..41cb2696f5492 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -2575,6 +2575,7 @@ void GetGradXAndYOut(const platform::CUDADeviceContext &dev_ctx, framework::Tensor *dy, Functor func) { framework::Tensor tmp_dx; framework::Tensor tmp_dy; + dx->mutable_data(place); dy->mutable_data(place); std::vector outs; if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) { From 5b940c44fd5e755e08573bac6fe3af5ed8ef3c83 Mon Sep 17 00:00:00 2001 From: Sing_chan <51314274+betterpig@users.noreply.github.com> Date: Tue, 11 Jan 2022 20:50:29 +0800 Subject: [PATCH 13/15] oepn third_party cache in wincheck_inference (#38877) --- paddle/scripts/paddle_build.bat | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/scripts/paddle_build.bat b/paddle/scripts/paddle_build.bat index f64acbeb72307..ca34b12b5d4f8 100644 --- a/paddle/scripts/paddle_build.bat +++ b/paddle/scripts/paddle_build.bat @@ -261,7 +261,6 @@ set ON_INFER=ON set WITH_TESTING=ON set WITH_TENSORRT=ON set WITH_INFERENCE_API_TEST=ON -set WITH_TPCACHE=OFF call :cmake || goto cmake_error call :build || goto build_error From be817719982f1821ab0519ceab85ec238bf99d43 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 11 Jan 2022 20:52:35 +0800 Subject: [PATCH 14/15] =?UTF-8?q?=E3=80=90PTen=E3=80=91Add=20dot=20and=20m?= =?UTF-8?q?atmul=20grad=20kernel=20in=20pten=20(#38713)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor matmul directory in pten * fix merge conflict * add dot_grad kernel * add dot_grad kernel in pten * add matmul_grad kernel * update the code * delete useless code in fluid * fix some bug of running matmul grad kernel * fix merge conflict * refactor some code * refactor code --- cmake/pten_kernel.cmake | 3 + paddle/fluid/framework/operator.cc | 26 +- paddle/fluid/framework/pten_utils.cc | 2 +- paddle/fluid/imperative/prepared_operator.cc | 66 +- paddle/fluid/operators/conj_op.h | 2 +- paddle/fluid/operators/dot_op.cc | 7 + paddle/fluid/operators/dot_op.h | 222 +- paddle/fluid/operators/math/blas.h | 4 + paddle/fluid/operators/math/blas_impl.h | 16 +- paddle/fluid/operators/matmul_v2_op.cc | 24 + paddle/fluid/operators/matmul_v2_op.h | 2222 +---------------- paddle/pten/core/dense_tensor.cc | 6 + paddle/pten/core/dense_tensor.h | 2 + paddle/pten/core/kernel_alias_name.h | 5 + paddle/pten/core/kernel_context.cc | 11 +- paddle/pten/core/kernel_context.h | 10 + paddle/pten/core/kernel_registry.h | 4 + paddle/pten/core/kernel_utils.h | 22 + paddle/pten/include/linalg.h | 2 +- paddle/pten/include/math.h | 11 - paddle/pten/kernels/complex_kernel.h | 13 +- paddle/pten/kernels/cpu/complex_kernel.cc | 2 +- paddle/pten/kernels/cpu/dot_grad_kernel.cc | 32 + paddle/pten/kernels/cpu/dot_kernel.cc | 10 +- paddle/pten/kernels/cpu/matmul_grad_kernel.cc | 47 + paddle/pten/kernels/dot_grad_kernel.h | 56 + paddle/pten/kernels/dot_kernel.h | 8 +- paddle/pten/kernels/empty_kernel.cc | 79 +- paddle/pten/kernels/empty_kernel.h | 8 + paddle/pten/kernels/gpu/complex_kernel.cu | 3 +- paddle/pten/kernels/gpu/dot_grad_kernel.cu | 32 + paddle/pten/kernels/gpu/dot_kernel.cu | 10 +- paddle/pten/kernels/gpu/matmul_grad_kernel.cu | 50 + paddle/pten/kernels/hybird/transpose.h | 28 + .../pten/kernels/impl/complex_kernel_impl.h | 6 +- .../pten/kernels/impl/dot_grad_kernel_impl.h | 919 +++++++ .../kernels/impl/matmul_grad_kernel_impl.h | 1742 +++++++++++++ paddle/pten/kernels/impl/matmul_kernel_impl.h | 14 +- paddle/pten/kernels/matmul_grad_kernel.h | 63 + paddle/pten/kernels/matmul_kernel.h | 14 +- 40 files changed, 3336 insertions(+), 2467 deletions(-) create mode 100644 paddle/pten/kernels/cpu/dot_grad_kernel.cc create mode 100644 paddle/pten/kernels/cpu/matmul_grad_kernel.cc create mode 100644 paddle/pten/kernels/dot_grad_kernel.h create mode 100644 paddle/pten/kernels/gpu/dot_grad_kernel.cu create mode 100644 paddle/pten/kernels/gpu/matmul_grad_kernel.cu create mode 100644 paddle/pten/kernels/impl/dot_grad_kernel_impl.h create mode 100644 paddle/pten/kernels/impl/matmul_grad_kernel_impl.h create mode 100644 paddle/pten/kernels/matmul_grad_kernel.h diff --git a/cmake/pten_kernel.cmake b/cmake/pten_kernel.cmake index 947defcea4a61..f962c1332093a 100644 --- a/cmake/pten_kernel.cmake +++ b/cmake/pten_kernel.cmake @@ -79,6 +79,9 @@ function(kernel_library TARGET) endif() list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.h) + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h) + list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h) + endif() list(APPEND all_srcs ${common_srcs}) list(APPEND all_srcs ${cpu_srcs}) list(APPEND all_srcs ${gpu_srcs}) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index c3e54290fd3da..dc4d1365093aa 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1880,16 +1880,32 @@ void OperatorWithKernel::BuildPtenKernelContext( // Otherwise,we will create new storage. for (size_t offset = 0; offset < outs_vector.size(); ++offset) { if (current_vector_size > start_idx + offset) { - experimental::ReMakePtenDenseTensorFromVar( - outs_vector[offset], out_def, + auto* buffer_tensor = pt_kernel_context_->MutableOutputAt(start_idx + - offset)); + offset); + if (buffer_tensor) { + experimental::ReMakePtenDenseTensorFromVar(outs_vector[offset], + out_def, buffer_tensor); + } } else { pt_kernel_context_->EmplaceBackOutputWithoutSetRange( experimental::MakePtenTensorBaseFromVar(outs_vector[offset], out_def)); } } + + // Deal with the case that some outputs are NULL when run the kernel. + // For example : the outputs of matmul_grad are dx and dy, + // sometimes dx or dy may be NULL. + if (outs_vector.empty()) { + if (current_vector_size > start_idx) { + pt_kernel_context_->SetOutputWithoutSetRange(start_idx, {nullptr}); + } else { + pt_kernel_context_->EmplaceBackOutputWithoutSetRange({nullptr}); + } + end_idx = start_idx + 1; + } + pt_kernel_context_->AssignOutputRange(std::make_pair(start_idx, end_idx), i); } @@ -2002,7 +2018,9 @@ void OperatorWithKernel::WriteBackToOutputs(RuntimeContext* ctx) const { range_pair.first, range_pair.second); for (size_t j = 0; j < pten_outs.size(); ++j) { - experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]); + if (pten_outs[j]) { + experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]); + } } } } diff --git a/paddle/fluid/framework/pten_utils.cc b/paddle/fluid/framework/pten_utils.cc index 9831c2628dc95..dddcd914ed28a 100644 --- a/paddle/fluid/framework/pten_utils.cc +++ b/paddle/fluid/framework/pten_utils.cc @@ -99,7 +99,7 @@ KernelSignatureMap& KernelSignatureMap::Instance() { const auto& op_type = pair.first; const auto* op_proto = pair.second.proto_; if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) && - op_proto != nullptr) { + op_proto) { KernelArgsNameMakerByOpProto maker(op_proto); VLOG(10) << "Register kernel signature for " << op_type; auto success = kernel_signature_map_->map_ diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index c355ace528d42..1d12ecf30ede5 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -338,19 +338,41 @@ static void BuildDygraphPtenKernelContext( for (size_t i = 0; i < output_names.size(); ++i) { auto& out_def = output_defs.at(i); - auto& outs_vector = outs.at(output_names[i]); size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second); - size_t end_idx = start_idx + outs_vector.size(); auto current_vector_size = kernel_ctx->OutputsSize(); + + auto iter = outs.find(output_names[i]); + if (iter == outs.end()) { + if (current_vector_size > start_idx) { + kernel_ctx->SetOutputWithoutSetRange(start_idx, {nullptr}); + } else { + kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr}); + } + kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1), + i); + continue; + } + + auto& outs_vector = iter->second; + size_t end_idx = start_idx + outs_vector.size(); + // If the memory needed is less than the current memory allocated, we will // reuse the current memory by using ReMakePtenDenseTensorFromVar. // Otherwise,we will create new storage. for (size_t offset = 0; offset < outs_vector.size(); ++offset) { if (current_vector_size > start_idx + offset) { - experimental::ReMakePtenDenseTensorFromVar( - outs_vector[offset]->MutableVar(), out_def, - kernel_ctx->MutableOutputAt(start_idx + offset)); + auto* buffer_tensor = + kernel_ctx->MutableOutputAt(start_idx + offset); + if (buffer_tensor) { + experimental::ReMakePtenDenseTensorFromVar( + outs_vector[offset]->MutableVar(), out_def, buffer_tensor); + } else { + kernel_ctx->SetOutputWithoutSetRange( + start_idx + offset, + experimental::MakePtenTensorBaseFromVar( + outs_vector[offset]->MutableVar(), out_def)); + } } else { kernel_ctx->EmplaceBackOutputWithoutSetRange( experimental::MakePtenTensorBaseFromVar( @@ -465,15 +487,18 @@ static void WriteBackToOutputs( auto& output_names = std::get<2>(pt_kernel_signature.args); for (size_t i = 0; i < output_names.size(); ++i) { - auto& outs_vector = outs.at(output_names[i]); + auto iter = outs.find(output_names[i]); + if (iter != outs.end()) { + auto& outs_vector = iter->second; - auto& range_pair = kernel_ctx->OutputRangeAt(i); - auto pten_outs = kernel_ctx->MutableOutputBetween( - range_pair.first, range_pair.second); + auto& range_pair = kernel_ctx->OutputRangeAt(i); + auto pten_outs = kernel_ctx->MutableOutputBetween( + range_pair.first, range_pair.second); - for (size_t j = 0; j < pten_outs.size(); ++j) { - experimental::MakeVariableFromPtenTensor(pten_outs[j], - outs_vector[j]->MutableVar()); + for (size_t j = 0; j < pten_outs.size(); ++j) { + experimental::MakeVariableFromPtenTensor(pten_outs[j], + outs_vector[j]->MutableVar()); + } } } } @@ -529,6 +554,7 @@ static void PreparedOpRunImpl( template static void PreparedOpRunPtImpl( const framework::OperatorBase& op, + const framework::OpKernelType& kernel_type, const framework::KernelSignature& pt_kernel_signature, const pten::Kernel& pt_kernel, pten::KernelContext* pt_kernel_context, platform::DeviceContext* dev_ctx, const NameVarMap& ins, @@ -558,7 +584,9 @@ static void PreparedOpRunPtImpl( pt_kernel_context->ClearData(); // TODO(chenweihang): add debug flags later - // TODO(chenweihang): deal with complex cases later + if (framework::IsComplexType(kernel_type.data_type_)) { + HandleComplexGradToRealGrad(outs); + } } void PreparedOp::Run(const NameVarMap& ins, @@ -566,9 +594,9 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { if (run_pten_kernel_) { - PreparedOpRunPtImpl(op_, pt_kernel_signature_, pt_kernel_, - pt_kernel_context_, dev_ctx_, ins, outs, attrs, - default_attrs); + PreparedOpRunPtImpl(op_, kernel_type_, pt_kernel_signature_, + pt_kernel_, pt_kernel_context_, dev_ctx_, ins, + outs, attrs, default_attrs); } else { PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, outs, attrs, default_attrs); @@ -580,9 +608,9 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { if (run_pten_kernel_) { - PreparedOpRunPtImpl(op_, pt_kernel_signature_, pt_kernel_, - pt_kernel_context_, dev_ctx_, ins, - outs, attrs, default_attrs); + PreparedOpRunPtImpl( + op_, kernel_type_, pt_kernel_signature_, pt_kernel_, pt_kernel_context_, + dev_ctx_, ins, outs, attrs, default_attrs); } else { PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, outs, attrs, default_attrs); diff --git a/paddle/fluid/operators/conj_op.h b/paddle/fluid/operators/conj_op.h index 1012e9383f607..381f4cb66b3cd 100644 --- a/paddle/fluid/operators/conj_op.h +++ b/paddle/fluid/operators/conj_op.h @@ -39,7 +39,7 @@ class ConjKernel : public framework::OpKernel { auto pt_out = paddle::experimental::MakePtenDenseTensor(*out); // call new kernel - pten::Conj(dev_ctx, *pt_x.get(), pt_out.get()); + pten::ConjKernel(dev_ctx, *pt_x.get(), pt_out.get()); } }; diff --git a/paddle/fluid/operators/dot_op.cc b/paddle/fluid/operators/dot_op.cc index 31acd9718115c..e1463c8ccb58e 100644 --- a/paddle/fluid/operators/dot_op.cc +++ b/paddle/fluid/operators/dot_op.cc @@ -117,6 +117,13 @@ class DotGradOp : public framework::OperatorWithKernel { ctx, framework::GradVarName("Out")), ctx.GetPlace()); } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext& ctx) const override { + return framework::KernelSignature( + "dot_grad", {"X", "Y", framework::GradVarName("Out")}, {}, + {framework::GradVarName("X"), framework::GradVarName("Y")}); + } }; template diff --git a/paddle/fluid/operators/dot_op.h b/paddle/fluid/operators/dot_op.h index f6877c57a5c18..02ba57ef8d495 100644 --- a/paddle/fluid/operators/dot_op.h +++ b/paddle/fluid/operators/dot_op.h @@ -22,217 +22,14 @@ // only can include the headers in paddle/pten/api dirs #include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/include/core.h" -#include "paddle/pten/include/linalg.h" +#include "paddle/pten/kernels/dot_grad_kernel.h" +#include "paddle/pten/kernels/dot_kernel.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -struct P { - void operator()(T a, R b); -}; - -template -struct DotGradFunction { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - const Tensor* tensor_dout, Tensor* tensor_dx, - Tensor* tensor_dy, - const paddle::framework::ExecutionContext& ctx); -}; - -template -struct DotGradFunction> { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - const Tensor* tensor_dout, Tensor* tensor_dx, - Tensor* tensor_dy, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == tensor_dout->dims().size()) { - auto dout = framework::EigenVector::Flatten(*tensor_dout); - - if (tensor_dx) { - auto y = framework::EigenVector::Flatten(*tensor_y); - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - Eigen::DSizes size(tensor_dx->numel()); - - paddle::platform::ForRange for_range(dev_raw, - tensor_y->numel()); - math::ConjFunctor functor(tensor_y->data(), tensor_y->numel(), - tensor_dx->data()); - for_range(functor); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - - dx.device(dev) = dx * dout.broadcast(size); - } - - if (tensor_dy) { - auto x = framework::EigenVector::Flatten(*tensor_x); - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - Eigen::DSizes size(tensor_dy->numel()); - - paddle::platform::ForRange for_range(dev_raw, - tensor_y->numel()); - math::ConjFunctor functor(tensor_x->data(), tensor_x->numel(), - tensor_dy->data()); - for_range(functor); - auto dy = framework::EigenVector::Flatten(*tensor_dy); - - dy.device(dev) = dy * dout.broadcast(size); - } - } else { - auto dout = framework::EigenMatrix::From(*tensor_dout); - - if (tensor_dx) { - tensor_dx->mutable_data(ctx.GetPlace()); - auto y = framework::EigenMatrix::From(*tensor_y); - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - Eigen::DSizes size(1, tensor_dx->dims()[1]); - - paddle::platform::ForRange for_range(dev_raw, - tensor_y->numel()); - math::ConjFunctor functor(tensor_y->data(), tensor_y->numel(), - tensor_dx->data()); - for_range(functor); - auto dx = framework::EigenMatrix::From(*tensor_dx); - - dx.device(dev) = dx * dout.broadcast(size); - } - - if (tensor_dy) { - tensor_dy->mutable_data(ctx.GetPlace()); - auto x = framework::EigenMatrix::From(*tensor_x); - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - Eigen::DSizes size(1, tensor_dy->dims()[1]); - - paddle::platform::ForRange for_range(dev_raw, - tensor_x->numel()); - math::ConjFunctor functor(tensor_x->data(), tensor_x->numel(), - tensor_dy->data()); - for_range(functor); - - auto dy = framework::EigenMatrix::From(*tensor_dy); - - dy.device(dev) = dy * dout.broadcast(size); - } - } -#else - const auto* data_dout = tensor_dout->data(); - - if (tensor_dx) { - auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); - const auto* data_y = tensor_y->data(); - const framework::DDim& dim = tensor_x->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s]; - } - } - - if (tensor_dy) { - auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); - const auto* data_x = tensor_x->data(); - const framework::DDim& dim = tensor_y->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s]; - } - } -#endif - } -}; - -template -struct DotGradFunction> { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - const Tensor* tensor_dout, Tensor* tensor_dx, - Tensor* tensor_dy, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == tensor_dout->dims().size()) { - auto dout = framework::EigenVector::Flatten(*tensor_dout); - - if (tensor_dx) { - auto y = framework::EigenVector::Flatten(*tensor_y); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(tensor_dx->numel()); - dx.device(dev) = y * dout.broadcast(size); - } - - if (tensor_dy) { - auto x = framework::EigenVector::Flatten(*tensor_x); - auto dy = framework::EigenVector::Flatten(*tensor_dy); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(tensor_dy->numel()); - dy.device(dev) = x * dout.broadcast(size); - } - } else { - auto dout = framework::EigenMatrix::From(*tensor_dout); - - if (tensor_dx) { - tensor_dx->mutable_data(ctx.GetPlace()); - auto y = framework::EigenMatrix::From(*tensor_y); - auto dx = framework::EigenMatrix::From(*tensor_dx); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(1, tensor_dx->dims()[1]); - dx.device(dev) = y * dout.broadcast(size); - } - - if (tensor_dy) { - tensor_dy->mutable_data(ctx.GetPlace()); - auto x = framework::EigenMatrix::From(*tensor_x); - auto dy = framework::EigenMatrix::From(*tensor_dy); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(1, tensor_dy->dims()[1]); - dy.device(dev) = x * dout.broadcast(size); - } - } -#else - auto const *x = tensor_x->data(), *y = tensor_y->data(), - *dz = tensor_dout->data(); - auto&& d = tensor_x->dims(); - auto const N = tensor_x->numel(); - auto const B = d[d.size() - 1]; - - if (tensor_dx) { - auto* dx = tensor_dx->mutable_data(ctx.GetPlace()); - for (auto j = 0; j < N / B; ++j) { - auto const ss = dz[j]; - for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss; - } - } - - if (tensor_dy) { - auto* dy = tensor_dy->mutable_data(ctx.GetPlace()); - for (auto j = 0; j < N / B; ++j) { - auto const ss = dz[j]; - for (auto i = 0; i < B; i++) *dy++ = *x++ * ss; - } - } -#endif - } -}; - // See Note [ Why still keep the original kernel implementation? ] template class DotKernel : public framework::OpKernel { @@ -249,7 +46,7 @@ class DotKernel : public framework::OpKernel { auto pt_out = paddle::experimental::MakePtenDenseTensor(*out); // call new kernel - pten::Dot(dev_ctx, *pt_x.get(), *pt_y.get(), pt_out.get()); + pten::DotKernel(dev_ctx, *pt_x.get(), *pt_y.get(), pt_out.get()); } }; @@ -266,8 +63,17 @@ class DotGradKernel : public framework::OpKernel { if (tensor_dx) tensor_dx->mutable_data(ctx.GetPlace()); if (tensor_dy) tensor_dy->mutable_data(ctx.GetPlace()); - DotGradFunction()(tensor_x, tensor_y, tensor_dout, - tensor_dx, tensor_dy, ctx); + auto pt_x = paddle::experimental::MakePtenDenseTensor(*tensor_x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*tensor_y); + auto pt_dout = paddle::experimental::MakePtenDenseTensor(*tensor_dout); + auto pt_dx = paddle::experimental::MakePtenDenseTensor(*tensor_dx); + auto pt_dy = paddle::experimental::MakePtenDenseTensor(*tensor_dy); + + auto& dev_ctx = ctx.device_context(); + + // call new kernel + pten::DotGradKernel(dev_ctx, *pt_x, *pt_y, *pt_dout, pt_dx.get(), + pt_dy.get()); } }; diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index f245bad01aa4c..2be7695e6a8c4 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -225,6 +225,10 @@ class Blas { const framework::Tensor& mat_b, const MatDescriptor& dim_b, T alpha, framework::Tensor* mat_out, T beta) const; + template + void MatMul(const T* mat_a, const MatDescriptor& dim_a, const T* mat_b, + const MatDescriptor& dim_b, T alpha, T* mat_out, T beta) const; + template void VINV(int n, const T* a, T* y) const; diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 4bcf3baa64932..be9cf1e3448b6 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -1249,6 +1249,15 @@ void Blas::MatMul(const framework::Tensor &mat_a, const framework::Tensor &mat_b, const MatDescriptor &dim_b, T alpha, framework::Tensor *mat_out, T beta) const { + MatMul(mat_a.data(), dim_a, mat_b.data(), dim_b, alpha, + mat_out->data(), beta); +} + +template +template +void Blas::MatMul(const T *mat_a, const MatDescriptor &dim_a, + const T *mat_b, const MatDescriptor &dim_b, + T alpha, T *mat_out, T beta) const { PADDLE_ENFORCE_EQ( dim_a.width_, dim_b.height_, platform::errors::InvalidArgument( @@ -1261,8 +1270,7 @@ void Blas::MatMul(const framework::Tensor &mat_a, CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { this->template GEMM(transA, transB, dim_a.height_, dim_b.width_, - dim_a.width_, alpha, mat_a.data(), - mat_b.data(), beta, mat_out->data()); + dim_a.width_, alpha, mat_a, mat_b, beta, mat_out); } else { PADDLE_ENFORCE_EQ( dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || @@ -1273,8 +1281,8 @@ void Blas::MatMul(const framework::Tensor &mat_a, "But got dim_a.batch_size = %d, dim_b.batch_size = %d.", dim_a.batch_size_, dim_b.batch_size_)); this->template BatchedGEMM( - transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, - mat_a.data(), mat_b.data(), beta, mat_out->data(), + transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, mat_a, + mat_b, beta, mat_out, dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, dim_a.stride_, dim_b.stride_); } diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 5add86f5b3c74..a5eca7b225558 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -389,6 +389,14 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel { tensor.place(), tensor.layout()); } } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext& ctx) const override { + return framework::KernelSignature( + "matmul_grad", {"X", "Y", framework::GradVarName("Out")}, + {"trans_x", "trans_y"}, + {framework::GradVarName("X"), framework::GradVarName("Y")}); + } }; template @@ -431,6 +439,13 @@ class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel { context->ShareDim("DOut", "DDOut"); } } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext& ctx) const override { + return framework::KernelSignature( + "matmul_double_grad", {"X", "Y", "DOut", "DDX", "DDY"}, + {"trans_x", "trans_y"}, {"DX", "DY", "DDOut"}); + } }; template @@ -500,6 +515,15 @@ class MatMulV2OpTripleGrad : public framework::OperatorWithKernel { context->ShareDim("Y", "D_DDY_out"); } } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext& ctx) const override { + return framework::KernelSignature( + "matmul_triple_grad", + {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"}, + {"trans_x", "trans_y"}, + {"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"}); + } }; template diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index b257f345eaf36..e93bd212868fd 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -28,6 +28,7 @@ limitations under the License. */ // only can include the headers in paddle/pten/api dirs #include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/include/core.h" +#include "paddle/pten/kernels/matmul_grad_kernel.h" #include "paddle/pten/kernels/matmul_kernel.h" #if defined(__NVCC__) || defined(__HIPCC__) @@ -39,333 +40,6 @@ namespace operators { using framework::Tensor; -template -void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output, - const std::vector& reduce_dims, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - auto stream = ctx.cuda_device_context().stream(); - TensorReduceFunctorImpl>( - *input, output, kps::IdentityFunctor(), reduce_dims, stream); -#else - ReduceKernelFunctor( - input, output, reduce_dims, true, false, ctx) - .template apply(); -#endif -} - -static void GetBroadcastFromDims(const int x_ndim, const std::int64_t* x_dims, - const int y_ndim, const std::int64_t* y_dims, - std::int64_t* x_bd_dims, - std::int64_t* y_bd_dims, - std::int64_t* out_bd_dims) { - const int ndim = (std::max)(x_ndim, y_ndim); - std::fill(x_bd_dims, x_bd_dims + ndim - x_ndim, 1); - std::fill(y_bd_dims, y_bd_dims + ndim - y_ndim, 1); - std::copy(x_dims, x_dims + x_ndim, x_bd_dims + ndim - x_ndim); - std::copy(y_dims, y_dims + y_ndim, y_bd_dims + ndim - y_ndim); - - for (int i = 0; i < ndim; ++i) { - PADDLE_ENFORCE_EQ( - x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1, - true, - platform::errors::InvalidArgument( - "Input(X) and Input(Y) has error dim." - "X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s]," - "or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1," - "But received X_broadcast's shape[%s] = [%s]" - "received Y_broadcast's shape[%s] = [%s]", - i, i, i, i, i, x_bd_dims[i], i, y_bd_dims[i])); - if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) { - out_bd_dims[i] = 0; - } else { - out_bd_dims[i] = (std::max)(x_bd_dims[i], y_bd_dims[i]); - } - } -} - -static int64_t GetIndexMessage(const int n, const int64_t* dims, - const int64_t* index) { - int64_t sum = 0; - for (int i = 0; i < n; ++i) { - if (dims[i] > 1) { - sum = sum * dims[i] + index[i]; - } - } - return sum; -} - -static void IndexIncreaseFromDims(const int ndim, const int64_t* dims, - int64_t* index) { - for (int i = ndim - 1; i >= 0; --i) { - ++index[i]; - if (index[i] >= dims[i]) { - index[i] -= dims[i]; - } else { - break; - } - } -} - -template -void MatMulFunction(const Tensor* X, const Tensor* Y, - const std::vector& x_dims, - const std::vector& y_dims, Tensor* Out, - bool trans_x, bool trans_y, - const paddle::framework::ExecutionContext& ctx, - bool flag = false) { - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - - // Get data ptr - const T* x_data = X->data(); - const T* y_data = Y->data(); - - if (x_ndim == 1 && y_ndim == 1) { - PADDLE_ENFORCE_EQ( - X->numel(), Y->numel(), - platform::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers," - "when X/Y's dims =1. But received X has [%d] elements," - "received Y has [%d] elements", - X->numel(), Y->numel())); - VLOG(3) << "MatMul's case 1"; - Out->Resize({1}); - Out->mutable_data(ctx.GetPlace()); - auto out_eigen = framework::EigenScalar::From(*Out); - auto x_eigen = framework::EigenVector::Flatten(*X); - auto y_eigen = framework::EigenVector::Flatten(*Y); - - auto& dev = *ctx.template device_context().eigen_device(); - if (flag) { - out_eigen.device(dev) = (x_eigen * y_eigen).sum() + out_eigen; - } else { - out_eigen.device(dev) = (x_eigen * y_eigen).sum(); - } - return; - } - - auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); - - if (x_ndim == 1) { - const int N = X->numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], N, - platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 1, N, y_ndim - 1, y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], N, - platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 2, N, y_ndim - 2, y_dims[y_ndim - 2])); - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - Out->Resize(framework::make_ddim(out_dims)); - Out->mutable_data(ctx.GetPlace()); - if (trans_y) { - const int M = Y->numel() / N; - VLOG(3) << "MatMul's case 2"; - blas.GEMV(false, M, N, static_cast(1), y_data, x_data, - static_cast(flag), Out->data()); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = Y->numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul's case 3"; - blas.GEMV(true, N, M, static_cast(1), y_data, x_data, - static_cast(flag), Out->data()); - } else { - VLOG(3) << "MatMul's case 4"; - blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast(1), - y_data, x_data, static_cast(flag), Out->data(), - batch_size, M * N, 0); - } - } - return; - } - - if (y_ndim == 1) { - const int N = Y->numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ(x_dims[x_ndim - 2], N, - platform::errors::InvalidArgument( - "Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, N, x_ndim - 2, x_dims[x_ndim - 2])); - } else { - PADDLE_ENFORCE_EQ(x_dims[x_ndim - 1], N, - platform::errors::InvalidArgument( - "Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, N, x_ndim - 1, x_dims[x_ndim - 1])); - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - Out->Resize(framework::make_ddim(out_dims)); - Out->mutable_data(ctx.GetPlace()); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = X->numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul's case 5"; - blas.GEMV(true, N, M, static_cast(1), x_data, y_data, - static_cast(flag), Out->data()); - } else { - VLOG(3) << "MatMul's case 6"; - blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast(1), - x_data, y_data, static_cast(flag), Out->data(), - batch_size, M * N, 0); - } - } else { - const int M = X->numel() / N; - VLOG(3) << "MatMul's case 7"; - blas.GEMV(false, M, N, static_cast(1), x_data, y_data, - static_cast(flag), Out->data()); - } - return; - } - - const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K, - platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 1, K, y_ndim - 1, y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K, - platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 2, K, y_ndim - 2, y_dims[y_ndim - 2])); - } - const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - - GetBroadcastFromDims(x_ndim - 2, x_dims.data(), y_ndim - 2, y_dims.data(), - x_broadcast_dims.data(), y_broadcast_dims.data(), - out_broadcast_dims.data()); - - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - Out->Resize(framework::make_ddim(out_broadcast_dims)); - Out->mutable_data(ctx.GetPlace()); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = !std::equal( - x_broadcast_dims.cbegin(), x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = std::accumulate( - x_broadcast_dims.cbegin(), x_broadcast_dims.cbegin() + batch_dim, 1LL, - std::multiplies()); - const std::int64_t y_batch_size = std::accumulate( - y_broadcast_dims.cbegin(), y_broadcast_dims.cbegin() + batch_dim, 1LL, - std::multiplies()); - const std::int64_t out_batch_size = std::accumulate( - out_broadcast_dims.cbegin(), out_broadcast_dims.cbegin() + batch_dim, 1LL, - std::multiplies()); - if (out_batch_size == 0) return; - if (x_batch_size == 1 && y_batch_size == 1) { - VLOG(3) << "MatMul's case 8"; - blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, M, N, K, static_cast(1), - x_data, y_data, static_cast(flag), Out->data()); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - VLOG(3) << "MatMul's case 9"; - blas.GEMV(false, y_batch_size * N, K, static_cast(1), y_data, x_data, - static_cast(flag), Out->data()); - } else { - VLOG(3) << "MatMul's case 10"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_data, y_data, static_cast(flag), - Out->data(), out_batch_size, 0, K * N); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - VLOG(3) << "MatMul's case 11"; - blas.GEMM(CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans, - x_batch_size * M, N, K, static_cast(1), x_data, y_data, - static_cast(flag), Out->data()); - } else { - VLOG(3) << "MatMul's case 12"; - blas.BatchedGEMM(CblasTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_data, y_data, static_cast(flag), - Out->data(), out_batch_size, M * K, 0); - } - } else if (!is_broadcast_dims) { - VLOG(3) << "MatMul's case 13"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_data, y_data, static_cast(flag), - Out->data(), out_batch_size, M * K, K * N); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data + x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = Out->data() + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - VLOG(3) << "MatMul's case 14"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_ptr.data(), y_ptr.data(), - static_cast(flag), out_ptr.data(), out_batch_size); - } -} - -template -void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, - bool trans_y, - const paddle::framework::ExecutionContext& ctx, - bool flag = false) { - const std::vector x_dims = vectorize(X->dims()); - const std::vector y_dims = vectorize(Y->dims()); - MatMulFunction(X, Y, x_dims, y_dims, Out, trans_x, trans_y, - ctx, flag); -} - template class MatMulV2Kernel : public framework::OpKernel { public: @@ -400,26 +74,6 @@ static framework::Tensor FoldInitDims(const framework::Tensor& input) { return output; } -// Reshape a rank-3 tensor from P x M x N to M x (P * N). -// (Warning: This requires transposing data and writes into new memory.) -// Identity op if the tensor is not of rank 3. -template -static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context, - const framework::Tensor& input) { - auto in_dims = input.dims(); - if (in_dims.size() != 3) { - return input; - } - framework::Tensor output; - output.Resize({in_dims[1], in_dims[0], in_dims[2]}); - output.mutable_data(context.GetPlace()); - std::vector axis = {1, 0, 2}; - math::Transpose trans; - trans(context, input, &output, axis); - output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); - return output; -} - /** * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the * original x_dim is returned. @@ -482,1000 +136,45 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, ReshapeTensorIntoMatrixSequence(y, mat_dim_y); } -template -struct ConjHelper { - explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} - HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { - dst.Resize(src.dims()); - dst.set_layout(src.layout()); - dst.ShareDataWith(src); - return; - } - - const framework::ExecutionContext& ctx_; -}; - -template -struct ConjHelper> { - explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} - - HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { - dst.Resize(src.dims()); - auto* src_data = src.data>(); - auto* dst_data = dst.mutable_data>( - ctx_.GetPlace(), - size_t(src.numel() * sizeof(paddle::platform::complex))); - - platform::ForRange for_range( - ctx_.template device_context(), src.numel()); - math::ConjFunctor> functor( - src_data, src.numel(), dst_data); - for_range(functor); - return; - } - const framework::ExecutionContext& ctx_; -}; - -template -struct ConjHelper> { - explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} - - HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { - dst.Resize(src.dims()); - auto* src_data = src.data>(); - auto* dst_data = dst.mutable_data>( - ctx_.GetPlace(), - size_t(src.numel() * sizeof(paddle::platform::complex))); - - platform::ForRange for_range( - ctx_.template device_context(), src.numel()); - math::ConjFunctor> functor( - src_data, src.numel(), dst_data); - for_range(functor); - return; - } - const framework::ExecutionContext& ctx_; -}; - -template -struct DotDoubleGradFunction { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - Tensor* tensor_dx, Tensor* tensor_dy, - const Tensor* tensor_dout, const Tensor* tensor_ddx, - const Tensor* tensor_ddy, Tensor* tensor_ddout, - const paddle::framework::ExecutionContext& ctx); -}; - -template -struct DotDoubleGradFunction> { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - Tensor* tensor_dx, Tensor* tensor_dy, - const Tensor* tensor_dout, const Tensor* tensor_ddx, - const Tensor* tensor_ddy, Tensor* tensor_ddout, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == tensor_dout->dims().size()) { - framework::Tensor tensor_dout_help; - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - if (tensor_dx || tensor_dy) { - tensor_dout_help.Resize(tensor_dout->dims()); - tensor_dout_help.mutable_data(ctx.GetPlace()); - paddle::platform::ForRange for_range( - dev_raw, tensor_dout->numel()); - math::ConjFunctor functor(tensor_dout->data(), - tensor_dout->numel(), - tensor_dout_help.data()); - for_range(functor); - } - if (tensor_dx) { - auto ddy = framework::EigenVector::Flatten(*tensor_ddy); - Eigen::DSizes size(tensor_ddy->numel()); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - auto dout = framework::EigenVector::Flatten(tensor_dout_help); - dx.device(dev) = ddy * dout.broadcast(size); - } - - if (tensor_dy) { - auto ddx = framework::EigenVector::Flatten(*tensor_ddx); - Eigen::DSizes size(tensor_ddx->numel()); - auto dy = framework::EigenVector::Flatten(*tensor_dy); - auto dout = framework::EigenVector::Flatten(tensor_dout_help); - dy.device(dev) = ddx * dout.broadcast(size); - } - - if (tensor_ddout) { - framework::Tensor tensor_x_help, tensor_y_help; - tensor_x_help.Resize(tensor_x->dims()); - tensor_x_help.mutable_data(ctx.GetPlace()); - tensor_y_help.Resize(tensor_y->dims()); - tensor_y_help.mutable_data(ctx.GetPlace()); - - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - paddle::platform::ForRange for_range(dev_raw, - tensor_x->numel()); - math::ConjFunctor functor_x(tensor_x->data(), tensor_x->numel(), - tensor_x_help.data()); - for_range(functor_x); - math::ConjFunctor functor_y(tensor_y->data(), tensor_y->numel(), - tensor_y_help.data()); - for_range(functor_y); - auto x = framework::EigenVector::Flatten(tensor_x_help); - auto y = framework::EigenVector::Flatten(tensor_y_help); - auto ddx = framework::EigenVector::Flatten(*tensor_ddx); - auto ddy = framework::EigenVector::Flatten(*tensor_ddy); - auto ddout = framework::EigenVector::Flatten(*tensor_ddout); - ddout.device(dev) = (x * ddy + y * ddx).sum(); - } - } -#else - const auto* data_dout = tensor_dout->data(); - - if (tensor_dx) { - auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); - const auto* data_ddy = tensor_ddy->data(); - const framework::DDim& dim = tensor_dx->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dx[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddy[i]; - } - } - - if (tensor_dy) { - auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); - const auto* data_ddx = tensor_ddx->data(); - const framework::DDim& dim = tensor_dy->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dy[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddx[i]; - } - } - - if (tensor_ddout) { - auto* data_ddout = tensor_ddout->mutable_data(ctx.GetPlace()); - auto* data_x = tensor_x->data(); - auto* data_y = tensor_y->data(); - auto* data_ddx = tensor_ddx->data(); - auto* data_ddy = tensor_ddy->data(); - - const framework::DDim& dim = tensor_dy->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - bool new_s = false; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) { - ++s; - new_s = true; - } - if (new_s) { - data_ddout[s] = T(data_x[i].real, -data_x[i].imag) * data_ddy[i] + - T(data_y[i].real, -data_y[i].imag) * data_ddx[i]; - } else { - data_ddout[s] += T(data_x[i].real, -data_x[i].imag) * data_ddy[i] + - T(data_y[i].real, -data_y[i].imag) * data_ddx[i]; - } - new_s = false; - } - } -#endif - } -}; - -template -struct DotDoubleGradFunction> { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - Tensor* tensor_dx, Tensor* tensor_dy, - const Tensor* tensor_dout, const Tensor* tensor_ddx, - const Tensor* tensor_ddy, Tensor* tensor_ddout, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == tensor_dout->dims().size()) { - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - auto dout = framework::EigenVector::Flatten(*tensor_dout); - if (tensor_dx) { - tensor_dx->mutable_data(ctx.GetPlace()); - auto ddy = framework::EigenVector::Flatten(*tensor_ddy); - Eigen::DSizes size(tensor_ddy->numel()); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - dx.device(dev) = ddy * dout.broadcast(size); - } - - if (tensor_dy) { - tensor_dy->mutable_data(ctx.GetPlace()); - auto ddx = framework::EigenVector::Flatten(*tensor_ddx); - Eigen::DSizes size(tensor_ddx->numel()); - - auto dy = framework::EigenVector::Flatten(*tensor_dy); - dy.device(dev) = ddx * dout.broadcast(size); - } - - if (tensor_ddout) { - tensor_ddout->mutable_data(ctx.GetPlace()); - auto x = framework::EigenVector::Flatten(*tensor_x); - auto y = framework::EigenVector::Flatten(*tensor_y); - auto ddx = framework::EigenVector::Flatten(*tensor_ddx); - auto ddy = framework::EigenVector::Flatten(*tensor_ddy); - auto ddout = framework::EigenVector::Flatten(*tensor_ddout); - ddout.device(dev) = (x * ddy + y * ddx).sum(); - } - } -#else - const auto* data_dout = tensor_dout->data(); - - if (tensor_dx) { - auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); - const auto* data_ddy = tensor_ddy->data(); - const framework::DDim& dim = tensor_dx->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dx[i] = data_dout[s] * data_ddy[i]; - } - } - - if (tensor_dy) { - auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); - const auto* data_ddx = tensor_ddx->data(); - const framework::DDim& dim = tensor_dy->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dy[i] = data_dout[s] * data_ddx[i]; - } - } - - if (tensor_ddout) { - auto* data_ddout = tensor_ddout->mutable_data(ctx.GetPlace()); - auto* data_x = tensor_x->data(); - auto* data_y = tensor_y->data(); - auto* data_ddx = tensor_ddx->data(); - auto* data_ddy = tensor_ddy->data(); - - const framework::DDim& dim = tensor_dy->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - bool new_s = false; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) { - ++s; - new_s = true; - } - if (new_s) { - data_ddout[s] = data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i]; - } else { - data_ddout[s] += data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i]; - } - new_s = false; - } - } -#endif - } -}; - -template -struct DotTripleGradFunction { - void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y, - const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy, - const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy, - const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout, - Tensor* out_tensor_d_x, Tensor* out_tensor_d_y, - Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx, - Tensor* out_tensor_d_ddy, - const paddle::framework::ExecutionContext& ctx); -}; - -// TODO(wuweilong): enable this function when the unittests framewark for multi -// grad is ok (dtype: complex64 or complex128). -template -struct DotTripleGradFunction> { - void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y, - const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy, - const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy, - const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout, - Tensor* out_tensor_d_x, Tensor* out_tensor_d_y, - Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx, - Tensor* out_tensor_d_ddy, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == in_tensor_d_ddout->dims().size()) { - framework::Tensor in_tensor_d_ddout_help; - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - if (out_tensor_d_x || out_tensor_d_y) { - in_tensor_d_ddout_help.Resize(in_tensor_d_ddout->dims()); - in_tensor_d_ddout_help.mutable_data(ctx.GetPlace()); - paddle::platform::ForRange for_range( - dev_raw, in_tensor_d_ddout->numel()); - math::ConjFunctor functor(in_tensor_d_ddout->data(), - in_tensor_d_ddout->numel(), - in_tensor_d_ddout_help.data()); - for_range(functor); - } - if (out_tensor_d_x) { - auto ddy = framework::EigenVector::Flatten(*in_tensor_ddy); - Eigen::DSizes size(in_tensor_ddy->numel()); - auto d_x = framework::EigenVector::Flatten(*out_tensor_d_x); - auto d_ddout = - framework::EigenVector::Flatten(in_tensor_d_ddout_help); - d_x.device(dev) = ddy * d_ddout.broadcast(size); - } - - if (out_tensor_d_y) { - auto ddx = framework::EigenVector::Flatten(*in_tensor_ddx); - Eigen::DSizes size(in_tensor_ddx->numel()); - auto d_y = framework::EigenVector::Flatten(*out_tensor_d_y); - auto d_ddout = - framework::EigenVector::Flatten(in_tensor_d_ddout_help); - d_y.device(dev) = ddx * d_ddout.broadcast(size); - } - - if (out_tensor_d_dout) { - framework::Tensor in_tensor_ddx_help, in_tensor_ddy_help; - in_tensor_ddx_help.Resize(in_tensor_ddx->dims()); - in_tensor_ddx_help.mutable_data(ctx.GetPlace()); - in_tensor_ddy_help.Resize(in_tensor_ddy->dims()); - in_tensor_ddy_help.mutable_data(ctx.GetPlace()); - - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - paddle::platform::ForRange for_range( - dev_raw, in_tensor_ddx->numel()); - math::ConjFunctor functor_ddx(in_tensor_ddx->data(), - in_tensor_ddx->numel(), - in_tensor_ddx_help.data()); - for_range(functor_ddx); - math::ConjFunctor functor_ddy(in_tensor_ddy->data(), - in_tensor_ddy->numel(), - in_tensor_ddy_help.data()); - for_range(functor_ddy); - auto ddx = framework::EigenVector::Flatten(in_tensor_ddx_help); - auto ddy = framework::EigenVector::Flatten(in_tensor_ddy_help); - auto d_dx = framework::EigenVector::Flatten(*in_tensor_d_dx); - auto d_dy = framework::EigenVector::Flatten(*in_tensor_d_dy); - auto d_dout = framework::EigenVector::Flatten(*out_tensor_d_dout); - d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum(); - } - if (out_tensor_d_ddx) { - framework::Tensor in_tensor_dout_help, in_tensor_y_help; - in_tensor_dout_help.Resize(in_tensor_dout->dims()); - in_tensor_dout_help.mutable_data(ctx.GetPlace()); - in_tensor_y_help.Resize(in_tensor_y->dims()); - in_tensor_y_help.mutable_data(ctx.GetPlace()); - - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - paddle::platform::ForRange for_range( - dev_raw, in_tensor_dout->numel()); - math::ConjFunctor functor_dout(in_tensor_dout->data(), - in_tensor_dout->numel(), - in_tensor_dout_help.data()); - for_range(functor_dout); - math::ConjFunctor functor_y(in_tensor_y->data(), - in_tensor_y->numel(), - in_tensor_y_help.data()); - for_range(functor_y); - auto dout = framework::EigenVector::Flatten(in_tensor_dout_help); - auto y = framework::EigenVector::Flatten(in_tensor_y_help); - auto d_ddout = framework::EigenVector::Flatten(*in_tensor_d_ddout); - auto d_dy = framework::EigenVector::Flatten(*in_tensor_d_dy); - auto d_ddx = framework::EigenVector::Flatten(*out_tensor_d_ddx); - Eigen::DSizes size(in_tensor_y->numel()); - d_ddx.device(dev) = - (dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size)); - } - if (out_tensor_d_ddy) { - framework::Tensor in_tensor_dout_help, in_tensor_x_help; - in_tensor_dout_help.Resize(in_tensor_dout->dims()); - in_tensor_dout_help.mutable_data(ctx.GetPlace()); - in_tensor_x_help.Resize(in_tensor_x->dims()); - in_tensor_x_help.mutable_data(ctx.GetPlace()); - - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - paddle::platform::ForRange for_range( - dev_raw, in_tensor_dout->numel()); - math::ConjFunctor functor_dout(in_tensor_dout->data(), - in_tensor_dout->numel(), - in_tensor_dout_help.data()); - for_range(functor_dout); - math::ConjFunctor functor_x(in_tensor_x->data(), - in_tensor_x->numel(), - in_tensor_x_help.data()); - for_range(functor_x); - auto dout = framework::EigenVector::Flatten(in_tensor_dout_help); - auto x = framework::EigenVector::Flatten(in_tensor_x_help); - auto d_ddout = framework::EigenVector::Flatten(*in_tensor_d_ddout); - auto d_dx = framework::EigenVector::Flatten(*in_tensor_d_dx); - auto d_ddy = framework::EigenVector::Flatten(*out_tensor_d_ddy); - Eigen::DSizes size(in_tensor_x->numel()); - d_ddy.device(dev) = - (dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size)); - } - } -#else - const auto* data_d_ddout = in_tensor_d_ddout->data(); - - if (out_tensor_d_x) { - auto* data_d_x = out_tensor_d_x->mutable_data(ctx.GetPlace()); - const auto* data_ddy = in_tensor_ddy->data(); - - const framework::DDim& dim = out_tensor_d_x->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_x[i] = T(data_ddy[i].real, -data_ddy[i].imag) * data_d_ddout[s]; - } - } - - if (out_tensor_d_y) { - auto* data_d_y = out_tensor_d_y->mutable_data(ctx.GetPlace()); - const auto* data_ddx = in_tensor_ddx->data(); - - const framework::DDim& dim = out_tensor_d_y->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_y[i] = T(data_ddx[i].real, -data_ddx[i].imag) * data_d_ddout[s]; - } - } - - if (out_tensor_d_dout) { - auto* data_d_dout = out_tensor_d_dout->mutable_data(ctx.GetPlace()); - auto* data_ddx = in_tensor_ddx->data(); - auto* data_ddy = in_tensor_ddy->data(); - auto* data_d_dx = in_tensor_d_dx->data(); - auto* data_d_dy = in_tensor_d_dy->data(); - - const framework::DDim& dim = out_tensor_d_dout->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - bool new_s = false; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) { - ++s; - new_s = true; - } - if (new_s) { - data_d_dout[s] = - T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] + - T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i]; - } else { - data_d_dout[s] += - T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] + - T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i]; - } - new_s = false; - } - } - - if (out_tensor_d_ddx) { - auto* data_d_ddx = out_tensor_d_ddx->mutable_data(ctx.GetPlace()); - auto* data_dout = in_tensor_dout->data(); - auto* data_d_dy = in_tensor_d_dy->data(); - auto* data_y = in_tensor_y->data(); - auto* data_d_ddout = in_tensor_d_ddout->data(); - - const framework::DDim& dim = out_tensor_d_ddx->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_ddx[i] = - T(data_dout[s].real, -data_dout[s].imag) * data_d_dy[i] + - T(data_y[i].real, -data_y[i].imag) * data_d_ddout[s]; - } - } - - if (out_tensor_d_ddy) { - auto* data_d_ddy = out_tensor_d_ddy->mutable_data(ctx.GetPlace()); - auto* data_dout = in_tensor_dout->data(); - auto* data_d_dx = in_tensor_d_dx->data(); - auto* data_x = in_tensor_x->data(); - auto* data_d_ddout = in_tensor_d_ddout->data(); - - const framework::DDim& dim = out_tensor_d_ddy->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_ddy[i] = - T(data_dout[s].real, -data_dout[s].imag) * data_d_dx[i] + - T(data_x[i].real, -data_x[i].imag) * data_d_ddout[s]; - } - } -#endif - } -}; - -template -struct DotTripleGradFunction> { - void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y, - const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy, - const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy, - const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout, - Tensor* out_tensor_d_x, Tensor* out_tensor_d_y, - Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx, - Tensor* out_tensor_d_ddy, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == in_tensor_d_ddout->dims().size()) { - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - auto d_ddout = framework::EigenVector::Flatten(*in_tensor_d_ddout); - if (out_tensor_d_x) { - out_tensor_d_x->mutable_data(ctx.GetPlace()); - auto ddy = framework::EigenVector::Flatten(*in_tensor_ddy); - Eigen::DSizes size(in_tensor_ddy->numel()); - auto d_x = framework::EigenVector::Flatten(*out_tensor_d_x); - d_x.device(dev) = ddy * d_ddout.broadcast(size); - } - - if (out_tensor_d_y) { - out_tensor_d_y->mutable_data(ctx.GetPlace()); - auto ddx = framework::EigenVector::Flatten(*in_tensor_ddx); - Eigen::DSizes size(in_tensor_ddx->numel()); - - auto d_y = framework::EigenVector::Flatten(*out_tensor_d_y); - d_y.device(dev) = ddx * d_ddout.broadcast(size); - } - - if (out_tensor_d_dout) { - out_tensor_d_dout->mutable_data(ctx.GetPlace()); - auto ddx = framework::EigenVector::Flatten(*in_tensor_ddx); - auto ddy = framework::EigenVector::Flatten(*in_tensor_ddy); - auto d_dx = framework::EigenVector::Flatten(*in_tensor_d_dx); - auto d_dy = framework::EigenVector::Flatten(*in_tensor_d_dy); - auto d_dout = framework::EigenVector::Flatten(*out_tensor_d_dout); - d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum(); - } - - if (out_tensor_d_ddx) { - out_tensor_d_ddx->mutable_data(ctx.GetPlace()); - auto dout = framework::EigenVector::Flatten(*in_tensor_dout); - auto y = framework::EigenVector::Flatten(*in_tensor_y); - auto d_ddout = framework::EigenVector::Flatten(*in_tensor_d_ddout); - auto d_dy = framework::EigenVector::Flatten(*in_tensor_d_dy); - auto d_ddx = framework::EigenVector::Flatten(*out_tensor_d_ddx); - Eigen::DSizes size(in_tensor_y->numel()); - d_ddx.device(dev) = - (dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size)); - } - - if (out_tensor_d_ddy) { - out_tensor_d_ddy->mutable_data(ctx.GetPlace()); - auto dout = framework::EigenVector::Flatten(*in_tensor_dout); - auto x = framework::EigenVector::Flatten(*in_tensor_x); - auto d_ddout = framework::EigenVector::Flatten(*in_tensor_d_ddout); - auto d_dx = framework::EigenVector::Flatten(*in_tensor_d_dx); - auto d_ddy = framework::EigenVector::Flatten(*out_tensor_d_ddy); - Eigen::DSizes size(in_tensor_x->numel()); - d_ddy.device(dev) = - (dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size)); - } - } -#else - const auto* data_d_ddout = in_tensor_d_ddout->data(); - - if (out_tensor_d_x) { - auto* data_d_x = out_tensor_d_x->mutable_data(ctx.GetPlace()); - const auto* data_ddy = in_tensor_ddy->data(); - - const framework::DDim& dim = out_tensor_d_x->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_x[i] = data_ddy[i] * data_d_ddout[s]; - } - } - - if (out_tensor_d_y) { - auto* data_d_y = out_tensor_d_y->mutable_data(ctx.GetPlace()); - const auto* data_ddx = in_tensor_ddx->data(); - - const framework::DDim& dim = out_tensor_d_y->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_y[i] = data_ddx[i] * data_d_ddout[s]; - } - } - - if (out_tensor_d_dout) { - auto* data_d_dout = out_tensor_d_dout->mutable_data(ctx.GetPlace()); - auto* data_ddx = in_tensor_ddx->data(); - auto* data_ddy = in_tensor_ddy->data(); - auto* data_d_dx = in_tensor_d_dx->data(); - auto* data_d_dy = in_tensor_d_dy->data(); - - const framework::DDim& dim = in_tensor_ddx->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - bool new_s = false; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) { - ++s; - new_s = true; - } - if (new_s) { - data_d_dout[s] = - data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i]; - } else { - data_d_dout[s] += - data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i]; - } - new_s = false; - } - } - - if (out_tensor_d_ddx) { - auto* data_d_ddx = out_tensor_d_ddx->mutable_data(ctx.GetPlace()); - auto* data_dout = in_tensor_dout->data(); - auto* data_d_dy = in_tensor_d_dy->data(); - auto* data_y = in_tensor_y->data(); - auto* data_d_ddout = in_tensor_d_ddout->data(); - - const framework::DDim& dim = out_tensor_d_ddx->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_ddx[i] = - data_dout[s] * data_d_dy[i] + data_y[i] * data_d_ddout[s]; - } - } - - if (out_tensor_d_ddy) { - auto* data_d_ddy = out_tensor_d_ddy->mutable_data(ctx.GetPlace()); - auto* data_dout = in_tensor_dout->data(); - auto* data_d_dx = in_tensor_d_dx->data(); - auto* data_x = in_tensor_x->data(); - auto* data_d_ddout = in_tensor_d_ddout->data(); - - const framework::DDim& dim = out_tensor_d_ddy->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_d_ddy[i] = - data_dout[s] * data_d_dx[i] + data_x[i] * data_d_ddout[s]; - } - } -#endif - } -}; - template class MatMulV2GradKernel : public framework::OpKernel { public: - void MatMul(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - const framework::Tensor& b, bool trans_b, - framework::Tensor* out) const { - out->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); - auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); - auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); - if (a.dims().size() == 3 && b.dims().size() <= 2) { - // the transpose_X must be false, if is true, the transpose cost much time - if (!trans_a) { - mat_dim_a.height_ *= mat_dim_a.batch_size_; - mat_dim_a.batch_size_ = 0; - } - } - blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast(1), out, - static_cast(0)); - } - - void CalcInputGrad(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - bool is_fold_init_dims_a, const framework::Tensor& b, - bool trans_b, bool is_fold_init_dims_b, - framework::Tensor* out) const { - if (out == nullptr) return; - bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && - out->dims().size() == 2; - if (!need_combine) { - MatMul(context, a, trans_a, b, trans_b, out); - } else { - auto& ctx = context.template device_context(); - MatMul(context, is_fold_init_dims_a - ? FoldInitDims(a) - : FoldHeadAndLastDims(ctx, a), - trans_a, is_fold_init_dims_b - ? FoldInitDims(b) - : FoldHeadAndLastDims(ctx, b), - trans_b, out); - } - } - void Compute(const framework::ExecutionContext& ctx) const override { bool transpose_x = ctx.Attr("trans_x"); bool transpose_y = ctx.Attr("trans_y"); - auto x = *ctx.Input("X"); - auto y = *ctx.Input("Y"); - auto dout = *ctx.Input(framework::GradVarName("Out")); - - framework::Tensor y_conj(y.type()); - framework::Tensor x_conj(y.type()); - - // get dims - std::vector x_dims = vectorize(x.dims()); - std::vector y_dims = vectorize(y.dims()); - std::vector dout_dims = vectorize(dout.dims()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); - // Case1 : x's or y's dim = 1 - if (x_ndim == 1 && y_ndim == 1) { - if (dx) dx->mutable_data(ctx.GetPlace()); - if (dy) dy->mutable_data(ctx.GetPlace()); - if (dout.numel() == 1) { - DotGradFunction()(&x, &y, &dout, dx, dy, ctx); - return; - } - } - - bool is_broadcast = true; - if (x_ndim <= 2 || y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, - y_dims.cbegin()); - } - - // Case2: no broadcast or no batch size, it aims to speed and it is same as - // matmul in old version. - if (!is_broadcast) { - ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); - framework::DDim dx_dims; - if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x.dims()) { - dx->Resize(x.dims()); - } + if (dx) dx->mutable_data(ctx.GetPlace()); + if (dy) dy->mutable_data(ctx.GetPlace()); - // for complex - ConjHelper conj_helper(ctx); - conj_helper(y, y_conj); - } + auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); + auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout); + auto pt_dx = dx ? paddle::experimental::MakePtenDenseTensor(*dx) + : std::unique_ptr(nullptr); + auto pt_dy = dy ? paddle::experimental::MakePtenDenseTensor(*dy) + : std::unique_ptr(nullptr); - framework::DDim dy_dims; - if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y.dims()) { - dy->Resize(y.dims()); - } - - // for complex - ConjHelper conj_helper(ctx); - conj_helper(x, x_conj); - } - if (transpose_x && transpose_y) { - CalcInputGrad(ctx, y_conj, true, true, dout, true, false, dx); - CalcInputGrad(ctx, dout, true, true, x_conj, true, false, dy); - } else if (transpose_x) { - CalcInputGrad(ctx, y_conj, false, false, dout, true, false, dx); - CalcInputGrad(ctx, x_conj, false, false, dout, false, true, dy); - } else if (transpose_y) { - CalcInputGrad(ctx, dout, false, false, y_conj, false, true, dx); - CalcInputGrad(ctx, dout, true, true, x_conj, false, true, dy); - } else { - CalcInputGrad(ctx, dout, false, false, y_conj, true, false, dx); - CalcInputGrad(ctx, x_conj, true, true, dout, false, true, dy); - } - - if (dx) { - if (dx_dims != x.dims()) { - dx->Resize(dx_dims); - } - } - if (dy) { - if (dy_dims != y.dims()) { - dy->Resize(dy_dims); - } - } - } else { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. - VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - Tensor dx_help, dy_help; - - ConjHelper conj_helper(ctx); - conj_helper(x, x_conj); - conj_helper(y, y_conj); - if (transpose_x) { - if (transpose_y) { - // X'Y': dA = Y'G', dB = G'X' - if (dx) - MatMulFunction(&y_conj, &dout, y_dims, dout_dims, - &dx_help, true, true, ctx); - if (dy) - MatMulFunction(&dout, &x_conj, dout_dims, x_dims, - &dy_help, true, true, ctx); - } else { - // X'Y: dX = YG', dY = XG - if (dx) - MatMulFunction(&y_conj, &dout, y_dims, dout_dims, - &dx_help, false, true, ctx); - if (dy) - MatMulFunction(&x_conj, &dout, x_dims, dout_dims, - &dy_help, false, false, ctx); - } - } else { - if (transpose_y) { - // XY': dX = GY, dY = G'X - if (dx) - MatMulFunction(&dout, &y_conj, dout_dims, y_dims, - &dx_help, false, false, ctx); - if (dy) - MatMulFunction(&dout, &x_conj, dout_dims, x_dims, - &dy_help, true, false, ctx); - } else { - // XY: dX = GY', dY = X'G - if (dx) - MatMulFunction(&dout, &y_conj, dout_dims, y_dims, - &dx_help, false, true, ctx); - if (dy) - MatMulFunction(&x_conj, &dout, x_dims, dout_dims, - &dy_help, true, false, ctx); - } - } - - // get help dims - const std::vector dx_help_dims = vectorize(dx_help.dims()); - const std::vector dy_help_dims = vectorize(dy_help.dims()); - - std::vector dx_broadcast_dims(ndim); - std::vector dy_broadcast_dims(ndim); - - std::fill(dx_broadcast_dims.data(), - dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill(dy_broadcast_dims.data(), - dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); - - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - // reduce sum to get grad by ReduceSum - if (dx) { - if (dx_reduce_dims.empty()) { - *dx = std::move(dx_help); - } else { - ReduceSumForMatmulGrad(&dx_help, dx, dx_reduce_dims, - ctx); - } - dx->Resize(x.dims()); - } - if (dy) { - if (dy_reduce_dims.empty()) { - *dy = std::move(dy_help); - } else { - ReduceSumForMatmulGrad(&dy_help, dy, dy_reduce_dims, - ctx); - } - dy->Resize(y.dims()); - } + auto& dev_ctx = ctx.device_context(); - // Get the OutputGrad(out) - } + // call new kernel + pten::MatmulGradKernel(dev_ctx, *pt_x, *pt_y, *pt_dout, transpose_x, + transpose_y, pt_dx.get(), pt_dy.get()); } }; template class MatMulV2DoubleGradKernel : public framework::OpKernel { public: - void MatMul(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - const framework::Tensor& b, bool trans_b, framework::Tensor* out, - bool flag) const { - out->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); - auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); - auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); - if (a.dims().size() == 3 && b.dims().size() <= 2) { - // the transpose_X must be false, if is true, the transpose cost much time - if (!trans_a) { - mat_dim_a.height_ *= mat_dim_a.batch_size_; - mat_dim_a.batch_size_ = 0; - } - } - blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast(1), out, - static_cast(flag)); - } - - void CalcInputGrad(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - bool is_fold_init_dims_a, const framework::Tensor& b, - bool trans_b, bool is_fold_init_dims_b, - framework::Tensor* out, bool flag) const { - if (out == nullptr) return; - bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && - out->dims().size() == 2; - if (!need_combine) { - MatMul(context, a, trans_a, b, trans_b, out, flag); - } else { - auto& ctx = context.template device_context(); - MatMul(context, is_fold_init_dims_a - ? FoldInitDims(a) - : FoldHeadAndLastDims(ctx, a), - trans_a, is_fold_init_dims_b - ? FoldInitDims(b) - : FoldHeadAndLastDims(ctx, b), - trans_b, out, flag); - } - } - void Compute(const framework::ExecutionContext& context) const override { - auto x = *context.Input("X"); - auto y = *context.Input("Y"); - auto dout = *context.Input("DOut"); + auto* x = context.Input("X"); + auto* y = context.Input("Y"); + auto* dout = context.Input("DOut"); auto* ddx = context.Input("DDX"); auto* ddy = context.Input("DDY"); @@ -1486,316 +185,38 @@ class MatMulV2DoubleGradKernel : public framework::OpKernel { bool transpose_x = context.Attr("trans_x"); bool transpose_y = context.Attr("trans_y"); - // Get dims from the input x, y, output_grad - std::vector x_dims = vectorize(x.dims()); - std::vector y_dims = vectorize(y.dims()); - std::vector dout_dims = vectorize(dout.dims()); - framework::Tensor x_conj(x.type()); - framework::Tensor y_conj(y.type()); - framework::Tensor dout_conj(dout.type()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); - - // Case1 : x's or y's dim = 1 - if (x_ndim == 1 && y_ndim == 1) { - DotDoubleGradFunction()(&x, &y, dx, dy, &dout, ddx, ddy, - ddout, context); - return; - } - - bool is_broadcast = true; - if (x_ndim <= 2 || y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, - y_dims.cbegin()); - } - - if (!is_broadcast) { - // Case2: no broadcast or no batch size - ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); - framework::DDim dx_dims; - - ConjHelper conj_helper(context); - if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x.dims()) { - dx->Resize(x.dims()); - } - } - - framework::DDim dy_dims; - if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y.dims()) { - dy->Resize(y.dims()); - } - } - - framework::DDim ddout_dims; - if (ddout) { - ddout_dims = ddout->dims(); - if (ddout_dims != dout.dims()) { - ddout->Resize(dout.dims()); - } - } - - if (ddx || ddy) { - ConjHelper conj_helper(context); - conj_helper(dout, dout_conj); - } - if (ddout) { - ConjHelper conj_helper(context); - conj_helper(x, x_conj); - conj_helper(y, y_conj); - } - bool ddout_flag = false; - if (ddx) { - auto ddx_mat = *ddx; - if (ddx_mat.dims() != x.dims()) { - ddx_mat.Resize(x.dims()); - } - if (dy) { - if (transpose_x && transpose_y) { - // dy = dout' * ddx' - CalcInputGrad(context, dout_conj, true, true, ddx_mat, true, false, - dy, false); - } else if (transpose_x) { - // dy = ddx * dout - CalcInputGrad(context, ddx_mat, false, false, dout_conj, false, - true, dy, false); - } else if (transpose_y) { - // dy = dout' * ddx - CalcInputGrad(context, dout_conj, true, true, ddx_mat, false, true, - dy, false); - } else { - // dy = ddx' * dout - CalcInputGrad(context, ddx_mat, true, true, dout_conj, false, true, - dy, false); - } - } - - if (ddout) { - CalcInputGrad(context, ddx_mat, transpose_x, true, y_conj, - transpose_y, false, ddout, ddout_flag); - ddout_flag = true; - } - } - - if (ddy) { - auto ddy_mat = *ddy; - if (ddy_mat.dims() != y.dims()) { - ddy_mat.Resize(y.dims()); - } - if (dx) { - if (transpose_x && transpose_y) { - // dx = ddy' * dout' - CalcInputGrad(context, ddy_mat, true, true, dout_conj, true, false, - dx, false); - } else if (transpose_x) { - // dx = ddy * dout' - CalcInputGrad(context, ddy_mat, false, false, dout_conj, true, - false, dx, false); - } else if (transpose_y) { - // dx = dout * ddy - CalcInputGrad(context, dout_conj, false, false, ddy_mat, false, - true, dx, false); - } else { - // dx = dout * ddy' - CalcInputGrad(context, dout_conj, false, false, ddy_mat, true, - false, dx, false); - } - } + if (dx) dx->mutable_data(context.GetPlace()); + if (dy) dy->mutable_data(context.GetPlace()); + if (ddout) ddout->mutable_data(context.GetPlace()); - if (ddout) { - CalcInputGrad(context, x_conj, transpose_x, true, ddy_mat, - transpose_y, false, ddout, ddout_flag); - } - } + auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); + auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout); + auto pt_ddx = paddle::experimental::MakePtenDenseTensor(*ddx); + auto pt_ddy = paddle::experimental::MakePtenDenseTensor(*ddy); + auto pt_dx = paddle::experimental::MakePtenDenseTensor(*dx); + auto pt_dy = paddle::experimental::MakePtenDenseTensor(*dy); + auto pt_ddout = paddle::experimental::MakePtenDenseTensor(*ddout); - if (dx) { - if (dx_dims != x.dims()) { - dx->Resize(dx_dims); - } - } + auto& dev_ctx = context.device_context(); - if (dy) { - if (dy_dims != y.dims()) { - dy->Resize(dy_dims); - } - } - - if (ddout) { - if (ddout_dims != dout.dims()) { - ddout->Resize(ddout_dims); - } - } - } else { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. - VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - framework::Tensor ddy_conj(ddx->type()); - framework::Tensor ddx_conj(ddy->type()); - - Tensor dx_help, dy_help; - if (dx || dy) { - ConjHelper conj_helper(context); - conj_helper(dout, dout_conj); - } - if (ddout) { - ConjHelper conj_helper(context); - conj_helper(x, x_conj); - conj_helper(y, y_conj); - } - if (transpose_x) { - if (transpose_y) { - if (dx) - MatMulFunction(ddy, &dout_conj, y_dims, dout_dims, - &dx_help, true, true, context); - if (dy) - MatMulFunction(&dout_conj, ddx, dout_dims, x_dims, - &dy_help, true, true, context); - } else { - if (dx) - MatMulFunction(ddy, &dout_conj, y_dims, dout_dims, - &dx_help, false, true, context); - if (dy) - MatMulFunction(ddx, &dout_conj, x_dims, dout_dims, - &dy_help, false, false, context); - } - } else { - if (transpose_y) { - if (dx) - MatMulFunction(&dout_conj, ddy, dout_dims, y_dims, - &dx_help, false, false, context); - if (dy) - MatMulFunction(&dout_conj, ddx, dout_dims, x_dims, - &dy_help, true, false, context); - } else { - if (dx) - MatMulFunction(&dout_conj, ddy, dout_dims, y_dims, - &dx_help, false, true, context); - if (dy) - MatMulFunction(ddx, &dout_conj, x_dims, dout_dims, - &dy_help, true, false, context); - } - } - - // get help dims - const std::vector dx_help_dims = vectorize(dx_help.dims()); - const std::vector dy_help_dims = vectorize(dy_help.dims()); - - std::vector dx_broadcast_dims(ndim); - std::vector dy_broadcast_dims(ndim); - - std::fill(dx_broadcast_dims.data(), - dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill(dy_broadcast_dims.data(), - dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); - - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - // Reduce sum to get grad by ReduceSum - if (dx) { - if (dx_reduce_dims.empty()) { - *dx = std::move(dx_help); - } else { - ReduceSumForMatmulGrad(&dx_help, dx, dx_reduce_dims, - context); - } - dx->Resize(x.dims()); - } - if (dy) { - if (dy_reduce_dims.empty()) { - *dy = std::move(dy_help); - } else { - ReduceSumForMatmulGrad(&dy_help, dy, dy_reduce_dims, - context); - } - dy->Resize(y.dims()); - } - - if (ddout) { - // Calculate the gradient of OutputGrad(Out) - MatMulFunction(ddx, &y_conj, x_dims, y_dims, ddout, - transpose_x, transpose_y, context); - MatMulFunction(&x_conj, ddy, x_dims, y_dims, ddout, - transpose_x, transpose_y, context, - true); - } - } + // call new kernel + pten::MatmulDoubleGradKernel(dev_ctx, *pt_x, *pt_y, *pt_dout, *pt_ddx, + *pt_ddy, transpose_x, transpose_y, + pt_dx.get(), pt_dy.get(), pt_ddout.get()); } }; template class MatMulV2TripleGradKernel : public framework::OpKernel { public: - void MatMul(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - const framework::Tensor& b, bool trans_b, framework::Tensor* out, - bool flag) const { - out->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); - auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); - auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); - if (a.dims().size() == 3 && b.dims().size() <= 2) { - // the transpose_X must be false, if is true, the transpose cost much time - if (!trans_a) { - mat_dim_a.height_ *= mat_dim_a.batch_size_; - mat_dim_a.batch_size_ = 0; - } - } - blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast(1), out, - static_cast(flag)); - } - - void CalcInputGrad(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - bool is_fold_init_dims_a, const framework::Tensor& b, - bool trans_b, bool is_fold_init_dims_b, - framework::Tensor* out, bool flag) const { - if (out == nullptr) return; - bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && - out->dims().size() == 2; - if (!need_combine) { - MatMul(context, a, trans_a, b, trans_b, out, flag); - } else { - auto& ctx = context.template device_context(); - MatMul(context, is_fold_init_dims_a - ? FoldInitDims(a) - : FoldHeadAndLastDims(ctx, a), - trans_a, is_fold_init_dims_b - ? FoldInitDims(b) - : FoldHeadAndLastDims(ctx, b), - trans_b, out, flag); - } - } - void Compute(const framework::ExecutionContext& context) const override { // get input - auto x = *context.Input("X"); - auto y = *context.Input("Y"); - auto dout = *context.Input("DOut"); - auto ddx = *context.Input("DDX"); - auto ddy = *context.Input("DDY"); + auto* x = context.Input("X"); + auto* y = context.Input("Y"); + auto* dout = context.Input("DOut"); + auto* ddx = context.Input("DDX"); + auto* ddy = context.Input("DDY"); auto* d_dx = context.Input("D_DX"); auto* d_dy = context.Input("D_DY"); @@ -1812,539 +233,34 @@ class MatMulV2TripleGradKernel : public framework::OpKernel { bool transpose_x = context.Attr("trans_x"); bool transpose_y = context.Attr("trans_y"); - // Get dims from the input x, y, output_grad - std::vector x_dims = vectorize(x.dims()); - std::vector y_dims = vectorize(y.dims()); - std::vector dout_dims = vectorize(dout.dims()); - framework::Tensor x_conj(x.type()); - framework::Tensor y_conj(y.type()); - framework::Tensor dout_conj(dout.type()); - framework::Tensor ddx_conj(ddx.type()); - framework::Tensor ddy_conj(ddy.type()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); - - // Case1 : x's and y's dim = 1 - if (x_ndim == 1 && y_ndim == 1) { - VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 1"; - - DotTripleGradFunction()( - &x, &y, &ddx, &ddy, d_dx, d_dy, &dout, d_ddout, out_d_x, out_d_y, - out_d_dout, out_d_ddx, out_d_ddy, context); - return; - } - - bool is_broadcast = true; - if (x_ndim <= 2 || y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, - y_dims.cbegin()); - } - - if (!is_broadcast) { - // Case2: no broadcast or no batch size - VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 2"; - ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); - - if (ddx.dims() != x.dims()) { - ddx.Resize(x.dims()); - } - - if (ddy.dims() != y.dims()) { - ddy.Resize(y.dims()); - } - - ConjHelper conj_helper(context); - - framework::DDim out_dx_dims; - if (out_d_x) { - out_dx_dims = out_d_x->dims(); - if (out_dx_dims != x.dims()) { - out_d_x->Resize(x.dims()); - } - } - - framework::DDim out_dy_dims; - if (out_d_y) { - out_dy_dims = out_d_y->dims(); - if (out_dy_dims != y.dims()) { - out_d_y->Resize(y.dims()); - } - } - - framework::DDim out_d_dout_dims; - if (out_d_dout) { - out_d_dout_dims = out_d_dout->dims(); - if (out_d_dout_dims != dout.dims()) { - out_d_dout->Resize(dout.dims()); - } - } - - framework::DDim out_d_ddx_dims; - if (out_d_ddx) { - out_d_ddx_dims = out_d_ddx->dims(); - if (out_d_ddx_dims != x.dims()) { - out_d_ddx->Resize(x.dims()); - } - } - - framework::DDim out_d_ddy_dims; - if (out_d_ddy) { - out_d_ddy_dims = out_d_ddy->dims(); - if (out_d_ddy_dims != y.dims()) { - out_d_ddy->Resize(y.dims()); - } - } - - if (out_d_dout) { - ConjHelper conj_helper(context); - conj_helper(ddx, ddx_conj); - conj_helper(ddy, ddy_conj); - } - - if (out_d_ddx || out_d_ddy) { - ConjHelper conj_helper(context); - conj_helper(x, x_conj); - conj_helper(y, y_conj); - conj_helper(dout, dout_conj); - } - - bool d_dout_flag = false; - bool d_ddx_flag = false; - bool d_ddy_flag = false; - - if (d_ddout) { - auto d_ddout_mat = *d_ddout; - if (d_ddout_mat.dims() != dout.dims()) { - d_ddout_mat.Resize(dout.dims()); - } - - if (out_d_y) { - if (transpose_x && transpose_y) { - // out_d_y = d_ddout' * ddx' - CalcInputGrad(context, d_ddout_mat, true, true, ddx_conj, true, - false, out_d_y, false); - } else if (transpose_x) { - // out_d_y = ddx * d_ddout - CalcInputGrad(context, ddx_conj, false, false, d_ddout_mat, false, - true, out_d_y, false); - } else if (transpose_y) { - // out_d_y = d_ddout' * ddx - CalcInputGrad(context, d_ddout_mat, true, true, ddx_conj, false, - true, out_d_y, false); - } else { - // out_d_y = ddx' * d_ddout - CalcInputGrad(context, ddx_conj, true, true, d_ddout_mat, false, - true, out_d_y, false); - } - } - - if (out_d_x) { - if (transpose_x && transpose_y) { - // out_d_x = ddy' * d_ddout' - CalcInputGrad(context, ddy_conj, true, true, d_ddout_mat, true, - false, out_d_x, false); - } else if (transpose_x) { - // out_d_x = ddy * d_ddout' - CalcInputGrad(context, ddy_conj, false, false, d_ddout_mat, true, - false, out_d_x, false); - } else if (transpose_y) { - // out_d_x = d_ddout * ddy - CalcInputGrad(context, d_ddout_mat, false, false, ddy_conj, false, - true, out_d_x, false); - } else { - // out_d_x = d_ddout * ddy' - CalcInputGrad(context, d_ddout_mat, false, false, ddy_conj, true, - false, out_d_x, false); - } - } - - // equations: - // d_ddx = DOut * D_DY + Y * D_DDOut - // Let: d_ddx1 = Y * D_DDOut - // Let: d_ddx2 = DOut * D_DY - - // d_ddy = DOut * D_DX + X * D_DDOut - // Let: d_ddy1 = X * D_DDOut - // Let: d_ddy2 = DOut * D_DX - - // d_dout = DDY * D_DX + DDX * D_DY - // Let: d_dout1 = DDX * D_DY - // Let: d_dout2 = DDY * D_DX - - // compute d_ddx1 - if (out_d_ddx) { - if (transpose_x && transpose_y) { - // out_d_ddx1 = y' * d_ddout' - CalcInputGrad(context, y_conj, true, true, d_ddout_mat, true, false, - out_d_ddx, d_ddx_flag); - } else if (transpose_x) { - // out_d_ddx1 = y * d_ddout' - CalcInputGrad(context, y_conj, false, false, d_ddout_mat, true, - false, out_d_ddx, d_ddx_flag); - } else if (transpose_y) { - // out_d_ddx1 = d_ddout * y - CalcInputGrad(context, d_ddout_mat, false, false, y_conj, false, - true, out_d_ddx, d_ddx_flag); - } else { - // out_d_ddx1 = d_ddout * y' - CalcInputGrad(context, d_ddout_mat, false, false, y_conj, true, - false, out_d_ddx, d_ddx_flag); - } - d_ddx_flag = true; - } - - // compute d_ddy1 - if (out_d_ddy) { - if (transpose_x && transpose_y) { - // out_d_ddy1 = d_ddout' * x' - CalcInputGrad(context, d_ddout_mat, true, true, x_conj, true, false, - out_d_ddy, false); - } else if (transpose_x) { - // out_d_ddy1 = x * d_ddout - CalcInputGrad(context, x_conj, false, false, d_ddout_mat, false, - true, out_d_ddy, false); - } else if (transpose_y) { - // out_d_ddy1 = d_ddout' * x - CalcInputGrad(context, d_ddout_mat, true, true, x_conj, false, true, - out_d_ddy, false); - } else { - // out_d_ddy1 = x' * d_ddout - CalcInputGrad(context, x_conj, true, true, d_ddout_mat, false, true, - out_d_ddy, false); - } - d_ddy_flag = true; - } - } - - if (d_dy) { - auto d_dy_mat = *d_dy; - if (d_dy_mat.dims() != y.dims()) { - d_dy_mat.Resize(y.dims()); - } - - // compute d_dout1 - if (out_d_dout) { - CalcInputGrad(context, ddx_conj, transpose_x, true, d_dy_mat, - transpose_y, false, out_d_dout, d_dout_flag); - d_dout_flag = true; - } - - // compute d_ddx2 - if (out_d_ddx) { - if (transpose_x && transpose_y) { - // out_d_ddx2 = D_DY' * DOut' - CalcInputGrad(context, d_dy_mat, true, true, dout_conj, true, false, - out_d_ddx, d_ddx_flag); - } else if (transpose_x) { - // out_d_ddx2 = D_DY * Dout' - CalcInputGrad(context, d_dy_mat, false, false, dout_conj, true, - false, out_d_ddx, d_ddx_flag); - } else if (transpose_y) { - // out_d_ddx2 = Dout * D_DY - CalcInputGrad(context, dout_conj, false, false, d_dy_mat, false, - true, out_d_ddx, d_ddx_flag); - } else { - // out_d_ddx2 = Dout * D_DY' - CalcInputGrad(context, dout_conj, false, false, d_dy_mat, true, - false, out_d_ddx, d_ddx_flag); - } - } - } - - if (d_dx) { - auto d_dx_mat = *d_dx; - if (d_dx_mat.dims() != x.dims()) { - d_dx_mat.Resize(x.dims()); - } - - // compute d_dout2 - if (out_d_dout) { - CalcInputGrad(context, d_dx_mat, transpose_x, true, ddy_conj, - transpose_y, false, out_d_dout, d_dout_flag); - } - - // compute d_ddy2 - if (out_d_ddy) { - if (transpose_x && transpose_y) { - // out_d_ddy2 = dout' * d_dx' - CalcInputGrad(context, dout_conj, true, true, d_dx_mat, true, false, - out_d_ddy, d_ddy_flag); - } else if (transpose_x) { - // out_d_ddy2 = d_dx * dout - CalcInputGrad(context, d_dx_mat, false, false, dout_conj, false, - true, out_d_ddy, d_ddy_flag); - } else if (transpose_y) { - // out_d_ddy2 = dout' * d_dx - CalcInputGrad(context, dout_conj, true, true, d_dx_mat, false, true, - out_d_ddy, d_ddy_flag); - } else { - // out_d_ddy2 = d_dx' * dout - CalcInputGrad(context, d_dx_mat, true, true, dout_conj, false, true, - out_d_ddy, d_ddy_flag); - } - } - } - - if (out_d_x) { - if (out_dx_dims != x.dims()) { - out_d_x->Resize(out_dx_dims); - } - } - - if (out_d_y) { - if (out_dy_dims != y.dims()) { - out_d_y->Resize(out_dy_dims); - } - } - - if (out_d_dout) { - if (out_d_dout_dims != dout.dims()) { - out_d_dout->Resize(out_d_dout_dims); - } - } - - if (out_d_ddx) { - if (out_d_ddx_dims != x.dims()) { - out_d_ddx->Resize(out_d_ddx_dims); - } - } - - if (out_d_ddy) { - if (out_d_ddy_dims != x.dims()) { - out_d_ddy->Resize(out_d_ddy_dims); - } - } - - } else { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. - VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 3"; - VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - - Tensor out_dx_help, out_dy_help; - Tensor out_d_ddx_help, out_d_ddy_help; - if (out_d_dout) { - ConjHelper conj_helper(context); - conj_helper(ddx, ddx_conj); - conj_helper(ddy, ddy_conj); - } - if (out_d_ddx || out_d_ddy) { - ConjHelper conj_helper(context); - conj_helper(x, x_conj); - conj_helper(y, y_conj); - conj_helper(dout, dout_conj); - } - - if (transpose_x) { - if (transpose_y) { - // dX = ddY' d_ddout’, dY = d_ddout’ ddX' - if (out_d_x) - MatMulFunction(&ddy_conj, d_ddout, y_dims, - dout_dims, &out_dx_help, true, - true, context); - if (out_d_y) - MatMulFunction(d_ddout, &ddx_conj, dout_dims, - x_dims, &out_dy_help, true, true, - context); - } else { - // dX = ddY d_ddout', dY = ddX d_ddout - if (out_d_x) - MatMulFunction(&ddy_conj, d_ddout, y_dims, - dout_dims, &out_dx_help, false, - true, context); - if (out_d_y) - MatMulFunction(&ddx_conj, d_ddout, x_dims, - dout_dims, &out_dy_help, false, - false, context); - } - } else { - if (transpose_y) { - // dX = d_ddout ddY, dY = d_ddout’ ddX - if (out_d_x) - MatMulFunction(d_ddout, &ddy_conj, dout_dims, - y_dims, &out_dx_help, false, false, - context); - if (out_d_y) - MatMulFunction(d_ddout, &ddx_conj, dout_dims, - x_dims, &out_dy_help, true, false, - context); - } else { - // dX = d_ddout ddY', dY = ddX' d_ddout - if (out_d_x) - MatMulFunction(d_ddout, &ddy_conj, dout_dims, - y_dims, &out_dx_help, false, true, - context); - if (out_d_y) - MatMulFunction(&ddx_conj, d_ddout, x_dims, - dout_dims, &out_dy_help, true, - false, context); - } - } - - // get help dims - const std::vector dx_help_dims = - vectorize(out_dx_help.dims()); - const std::vector dy_help_dims = - vectorize(out_dx_help.dims()); - - std::vector dx_broadcast_dims(ndim); - std::vector dy_broadcast_dims(ndim); - - std::fill(dx_broadcast_dims.data(), - dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill(dy_broadcast_dims.data(), - dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); - - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - // Reduce sum to get grad by ReduceSum - if (out_d_x) { - if (dx_reduce_dims.empty()) { - *out_d_x = std::move(out_dx_help); - } else { - ReduceSumForMatmulGrad(&out_dx_help, out_d_x, - dx_reduce_dims, context); - } - out_d_x->Resize(x.dims()); - } - - if (out_d_y) { - if (dy_reduce_dims.empty()) { - *out_d_y = std::move(out_dy_help); - } else { - ReduceSumForMatmulGrad(&out_dy_help, out_d_y, - dy_reduce_dims, context); - } - out_d_y->Resize(y.dims()); - } - - // compute d_dout - if (out_d_dout) { - MatMulFunction(d_dx, &ddy_conj, x_dims, y_dims, - out_d_dout, transpose_x, transpose_y, - context); - MatMulFunction(&ddx_conj, d_dy, x_dims, y_dims, - out_d_dout, transpose_x, transpose_y, - context, true); - } - - // compute d_ddx - if (out_d_ddx) { - if (transpose_x && transpose_y) { - // out_d_ddx1 = y' * d_ddout' - MatMulFunction(&y_conj, d_ddout, y_dims, dout_dims, - &out_d_ddx_help, true, true, - context); - // out_d_ddx2 = D_DY' * DOut' - MatMulFunction(d_dy, &dout_conj, y_dims, dout_dims, - &out_d_ddx_help, true, true, context, - true); - } else if (transpose_x) { - // out_d_ddx1 = y * d_ddout' - MatMulFunction(&y_conj, d_ddout, y_dims, dout_dims, - &out_d_ddx_help, false, true, - context); - // out_d_ddx2 = D_DY * Dout' - MatMulFunction(d_dy, &dout_conj, y_dims, dout_dims, - &out_d_ddx_help, false, true, - context, true); - } else if (transpose_y) { - // out_d_ddx1 = d_ddout * y - MatMulFunction(d_ddout, &y_conj, dout_dims, y_dims, - &out_d_ddx_help, false, false, - context); - // out_d_ddx2 = Dout * D_DY - MatMulFunction(&dout_conj, d_dy, dout_dims, y_dims, - &out_d_ddx_help, false, false, - context, true); - } else { - // out_d_ddx1 = d_ddout * y' - MatMulFunction(d_ddout, &y_conj, dout_dims, y_dims, - &out_d_ddx_help, false, true, - context); - // out_d_ddx2 = Dout * D_DY' - MatMulFunction(&dout_conj, d_dy, dout_dims, y_dims, - &out_d_ddx_help, false, true, - context, true); - } - if (dx_reduce_dims.empty()) { - *out_d_ddx = std::move(out_d_ddx_help); - } else { - ReduceSumForMatmulGrad(&out_d_ddx_help, out_d_ddx, - dx_reduce_dims, context); - } - out_d_ddx->Resize(x.dims()); - } - - // compute d_ddy - if (out_d_ddy) { - if (transpose_x && transpose_y) { - // out_d_ddy1 = d_ddout' * x' - MatMulFunction(d_ddout, &x_conj, dout_dims, x_dims, - &out_d_ddy_help, true, true, - context); - // out_d_ddy2 = dout' * d_dx' - MatMulFunction(&dout_conj, d_dx, dout_dims, x_dims, - &out_d_ddy_help, true, true, context, - true); - } else if (transpose_x) { - // out_d_ddy1 = x * d_ddout - MatMulFunction(&x_conj, d_ddout, x_dims, dout_dims, - &out_d_ddy_help, false, false, - context); - // out_d_ddy2 = d_dx * dout - MatMulFunction(d_dx, &dout_conj, x_dims, dout_dims, - &out_d_ddy_help, false, false, - context, true); - } else if (transpose_y) { - // out_d_ddy1 = d_ddout' * x - MatMulFunction(d_ddout, &x_conj, dout_dims, x_dims, - &out_d_ddy_help, true, false, - context); - // out_d_ddy2 = dout' * d_dx - MatMulFunction(&dout_conj, d_dx, dout_dims, x_dims, - &out_d_ddy_help, true, false, - context, true); - } else { - // out_d_ddy1 = x' * d_ddout - MatMulFunction(&x_conj, d_ddout, x_dims, dout_dims, - &out_d_ddy_help, true, false, - context); - // out_d_ddy2 = d_dx' * dout - MatMulFunction(d_dx, &dout_conj, x_dims, dout_dims, - &out_d_ddy_help, true, false, - context, true); - } - - if (dy_reduce_dims.empty()) { - *out_d_ddy = std::move(out_d_ddy_help); - } else { - ReduceSumForMatmulGrad(&out_d_ddy_help, out_d_ddy, - dy_reduce_dims, context); - } - out_d_ddy->Resize(y.dims()); - } - } + if (out_d_x) out_d_x->mutable_data(context.GetPlace()); + if (out_d_y) out_d_y->mutable_data(context.GetPlace()); + if (out_d_dout) out_d_dout->mutable_data(context.GetPlace()); + if (out_d_ddx) out_d_ddx->mutable_data(context.GetPlace()); + if (out_d_ddy) out_d_ddy->mutable_data(context.GetPlace()); + + auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); + auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout); + auto pt_ddx = paddle::experimental::MakePtenDenseTensor(*ddx); + auto pt_ddy = paddle::experimental::MakePtenDenseTensor(*ddy); + auto pt_d_dx = paddle::experimental::MakePtenDenseTensor(*d_dx); + auto pt_d_dy = paddle::experimental::MakePtenDenseTensor(*d_dy); + auto pt_d_ddout = paddle::experimental::MakePtenDenseTensor(*d_ddout); + + auto pt_out_d_x = paddle::experimental::MakePtenDenseTensor(*out_d_x); + auto pt_out_d_y = paddle::experimental::MakePtenDenseTensor(*out_d_y); + auto pt_out_d_dout = paddle::experimental::MakePtenDenseTensor(*out_d_dout); + auto pt_out_d_ddx = paddle::experimental::MakePtenDenseTensor(*out_d_ddx); + auto pt_out_d_ddy = paddle::experimental::MakePtenDenseTensor(*out_d_ddy); + + auto& dev_ctx = context.device_context(); + // call new kernel + pten::MatmulTripleGradKernel(dev_ctx, *pt_x, *pt_y, *pt_dout, *pt_ddx, + *pt_ddy, *pt_d_dx, *pt_d_dy, *pt_d_ddout, + transpose_x, transpose_y, pt_out_d_x.get(), + pt_out_d_y.get(), pt_out_d_dout.get(), + pt_out_d_ddx.get(), pt_out_d_ddy.get()); } }; diff --git a/paddle/pten/core/dense_tensor.cc b/paddle/pten/core/dense_tensor.cc index 1b4254ad2c103..0b5f5cb18e13d 100644 --- a/paddle/pten/core/dense_tensor.cc +++ b/paddle/pten/core/dense_tensor.cc @@ -70,6 +70,12 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) { return *this; } +DenseTensor& DenseTensor::operator=(DenseTensor&& other) { + meta_ = std::move(other.meta_); + storage_.swap(other.storage_); + return *this; +} + int64_t DenseTensor::numel() const { if (meta_.is_scalar) { return 1; diff --git a/paddle/pten/core/dense_tensor.h b/paddle/pten/core/dense_tensor.h index fc92e84f52cea..1502accd197be 100644 --- a/paddle/pten/core/dense_tensor.h +++ b/paddle/pten/core/dense_tensor.h @@ -97,6 +97,8 @@ class DenseTensor : public TensorBase, /// \brief DenseTensor shallow copy assignment. DenseTensor& operator=(const DenseTensor& other); + DenseTensor& operator=(DenseTensor&& other); + /// \brief Destroy the tensor object and release exclusive resources. virtual ~DenseTensor() = default; diff --git a/paddle/pten/core/kernel_alias_name.h b/paddle/pten/core/kernel_alias_name.h index 56f7eea7ea802..46fa6dd376ee3 100644 --- a/paddle/pten/core/kernel_alias_name.h +++ b/paddle/pten/core/kernel_alias_name.h @@ -29,6 +29,9 @@ const std::unordered_map kernel_alias_name_map = { {"flatten_contiguous_range", "flatten"}, {"flatten_contiguous_range_grad", "flatten_grad"}, {"matmul_v2", "matmul"}, + {"matmul_v2_grad", "matmul_grad"}, + {"matmul_v2_grad_grad", "matmul_double_grad"}, + {"matmul_v2_triple_grad", "matmul_triple_grad"}, {"reduce_mean", "mean"}, {"reduce_sum", "sum"}, {"reshape2", "reshape"}, @@ -36,6 +39,8 @@ const std::unordered_map kernel_alias_name_map = { {"flatten", "deprecated"}, {"flatten_grad", "deprecated"}, {"matmul", "deprecated"}, + {"matmul_grad", "deprecated"}, + {"matmul_grad_grad", "deprecated"}, {"mean", "deprecated"}, {"reshape", "deprecated"}, {"sum", "deprecated"}}; diff --git a/paddle/pten/core/kernel_context.cc b/paddle/pten/core/kernel_context.cc index b2c84807951a5..74bd6d17f066a 100644 --- a/paddle/pten/core/kernel_context.cc +++ b/paddle/pten/core/kernel_context.cc @@ -50,6 +50,11 @@ void KernelContext::EmplaceBackOutputWithoutSetRange( outputs_.emplace_back(std::move(output)); } +void KernelContext::SetOutputWithoutSetRange( + int index, std::shared_ptr output) { + outputs_.at(index) = std::move(output); +} + void KernelContext::EmplaceBackOutputs( paddle::SmallVector> outputs) { int index = outputs_.size(); @@ -119,8 +124,10 @@ void KernelContext::ClearData() { } } for (auto& out : outputs_) { - CompatibleDenseTensorUtils::ClearStorage( - static_cast(out.get())); + if (out) { + CompatibleDenseTensorUtils::ClearStorage( + static_cast(out.get())); + } } attrs_.clear(); } diff --git a/paddle/pten/core/kernel_context.h b/paddle/pten/core/kernel_context.h index 6c695987096cb..b6cc15c084ac0 100644 --- a/paddle/pten/core/kernel_context.h +++ b/paddle/pten/core/kernel_context.h @@ -62,6 +62,8 @@ class KernelContext { void EmplaceBackOutputWithoutSetRange(std::shared_ptr output); + void SetOutputWithoutSetRange(int index, std::shared_ptr output); + void EmplaceBackOutputs( paddle::SmallVector> outputs); @@ -80,6 +82,14 @@ class KernelContext { return static_cast(*(inputs_.at(idx))); } + template + paddle::optional OptionalInputAt(size_t idx) const { + const auto& input = inputs_.at(idx); + return input ? paddle::optional{static_cast< + const TensorType&>(*input)} + : paddle::optional{paddle::none}; + } + std::shared_ptr& MutableInputPtrAt(size_t idx) { return inputs_.at(idx); } diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index bd4687c6e7f4e..f08ef4acfd9ce 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -65,6 +65,10 @@ struct KernelArgsParseFunctor { } else if (arg_type == std::type_index(typeid(const DenseTensor&))) { args_def->AppendInput( default_key.backend(), default_tensor_layout, default_key.dtype()); + } else if (arg_type == std::type_index(typeid( + paddle::optional))) { + args_def->AppendInput( + default_key.backend(), default_tensor_layout, default_key.dtype()); } else if (arg_type == std::type_index(typeid(const std::vector&))) { args_def->AppendInput( diff --git a/paddle/pten/core/kernel_utils.h b/paddle/pten/core/kernel_utils.h index 5087d912ed525..60201151c62a2 100644 --- a/paddle/pten/core/kernel_utils.h +++ b/paddle/pten/core/kernel_utils.h @@ -77,6 +77,27 @@ namespace pten { } \ } +#define PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(tensor_type) \ + template \ + struct KernelCallHelper, Tail...> { \ + template \ + static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ + static_assert(attr_idx == 0, \ + "Kernel's Input should appear before Attributes."); \ + static_assert(out_idx == 0, \ + "Kernel's Input should appear before Outputs."); \ + const std::pair range = ctx->InputRangeAt(in_idx); \ + auto arg = ctx->OptionalInputAt(range.first); \ + KernelCallHelper:: \ + template Compute( \ + ctx, pargs..., arg); \ + } \ + } + #define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type) \ template \ struct KernelCallHelper&, Tail...> { \ @@ -190,6 +211,7 @@ struct KernelImpl { /* Input Helpers */ PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor); + PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor); // TODO(chenweihang): adapt SelectedRows // PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRowsTensor); diff --git a/paddle/pten/include/linalg.h b/paddle/pten/include/linalg.h index 22f287468e673..71bc518aa89f8 100644 --- a/paddle/pten/include/linalg.h +++ b/paddle/pten/include/linalg.h @@ -30,7 +30,7 @@ DenseTensor Dot(const ContextT& dev_ctx, pten::make_intrusive( dev_ctx.GetPlace()), std::move(out_meta)); - Dot(dev_ctx, x, y, &dense_out); + DotKernel(dev_ctx, x, y, &dense_out); return dense_out; } diff --git a/paddle/pten/include/math.h b/paddle/pten/include/math.h index faa4c8db8dac3..5070d0d4e0e5a 100644 --- a/paddle/pten/include/math.h +++ b/paddle/pten/include/math.h @@ -48,15 +48,4 @@ DenseTensor Scale(const ContextT& dev_ctx, return dense_out; } -template -DenseTensor Conj(const ContextT& dev_ctx, const DenseTensor& x) { - auto out_meta = UnchangedInferMeta(x.meta()); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - Conj(dev_ctx, x, &dense_out); - return dense_out; -} - } // namespace pten diff --git a/paddle/pten/kernels/complex_kernel.h b/paddle/pten/kernels/complex_kernel.h index dfe8fff43e6ef..e9f717152a458 100644 --- a/paddle/pten/kernels/complex_kernel.h +++ b/paddle/pten/kernels/complex_kernel.h @@ -16,9 +16,20 @@ limitations under the License. */ #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/infermeta/unary.h" +#include "paddle/pten/kernels/empty_kernel.h" + namespace pten { template -void Conj(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); +void ConjKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); + +template +DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) { + auto out_meta = UnchangedInferMeta(x.meta()); + auto dense_out = Empty(dev_ctx, std::move(out_meta)); + ConjKernel(dev_ctx, x, &dense_out); + return dense_out; +} } // namespace pten diff --git a/paddle/pten/kernels/cpu/complex_kernel.cc b/paddle/pten/kernels/cpu/complex_kernel.cc index 9bf27ef22dcd7..10e7e684db3c1 100644 --- a/paddle/pten/kernels/cpu/complex_kernel.cc +++ b/paddle/pten/kernels/cpu/complex_kernel.cc @@ -24,7 +24,7 @@ PT_REGISTER_CTX_KERNEL(conj, CPU, ALL_LAYOUT, - pten::Conj, + pten::ConjKernel, paddle::platform::complex, paddle::platform::complex, float, diff --git a/paddle/pten/kernels/cpu/dot_grad_kernel.cc b/paddle/pten/kernels/cpu/dot_grad_kernel.cc new file mode 100644 index 0000000000000..c9d5c35e134c8 --- /dev/null +++ b/paddle/pten/kernels/cpu/dot_grad_kernel.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/dot_grad_kernel.h" +#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/fluid/platform/complex.h" + +PT_REGISTER_CTX_KERNEL(dot_grad, + CPU, + ALL_LAYOUT, + pten::DotGradKernel, + float, + double, + int, + int64_t, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/cpu/dot_kernel.cc b/paddle/pten/kernels/cpu/dot_kernel.cc index 247ad1216a266..72e9e28907f90 100644 --- a/paddle/pten/kernels/cpu/dot_kernel.cc +++ b/paddle/pten/kernels/cpu/dot_kernel.cc @@ -23,10 +23,10 @@ namespace pten { template -void Dot(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out) { +void DotKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { auto const *x_ptr = x.data(), *x_ptr_ = &x_ptr[0]; auto const *y_ptr = y.data(), *y_ptr_ = &y_ptr[0]; auto* z = out->mutable_data(); @@ -52,7 +52,7 @@ using complex128 = ::paddle::platform::complex; PT_REGISTER_CTX_KERNEL(dot, CPU, ALL_LAYOUT, - pten::Dot, + pten::DotKernel, float, double, int, diff --git a/paddle/pten/kernels/cpu/matmul_grad_kernel.cc b/paddle/pten/kernels/cpu/matmul_grad_kernel.cc new file mode 100644 index 0000000000000..5a8abb6701b0e --- /dev/null +++ b/paddle/pten/kernels/cpu/matmul_grad_kernel.cc @@ -0,0 +1,47 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/kernels/matmul_grad_kernel.h" + +#include "paddle/fluid/platform/complex.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h" + +PT_REGISTER_CTX_KERNEL(matmul_grad, + CPU, + ALL_LAYOUT, + pten::MatmulGradKernel, + float, + double, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_CTX_KERNEL(matmul_double_grad, + CPU, + ALL_LAYOUT, + pten::MatmulDoubleGradKernel, + float, + double, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_CTX_KERNEL(matmul_triple_grad, + CPU, + ALL_LAYOUT, + pten::MatmulTripleGradKernel, + float, + double, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/dot_grad_kernel.h b/paddle/pten/kernels/dot_grad_kernel.h new file mode 100644 index 0000000000000..b0940e5b16a33 --- /dev/null +++ b/paddle/pten/kernels/dot_grad_kernel.h @@ -0,0 +1,56 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void DotGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy); + +template +void DotDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* ddout); + +template +void DotTripleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& d_dx, + const DenseTensor& d_dy, + const DenseTensor& dout, + const DenseTensor& d_ddout, + DenseTensor* d_x, + DenseTensor* d_y, + DenseTensor* d_ddx, + DenseTensor* d_ddy, + DenseTensor* d_dout); + +} // namespace pten diff --git a/paddle/pten/kernels/dot_kernel.h b/paddle/pten/kernels/dot_kernel.h index 9924749cd2141..5ef660265333e 100644 --- a/paddle/pten/kernels/dot_kernel.h +++ b/paddle/pten/kernels/dot_kernel.h @@ -19,9 +19,9 @@ namespace pten { template -void Dot(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out); +void DotKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); } // namespace pten diff --git a/paddle/pten/kernels/empty_kernel.cc b/paddle/pten/kernels/empty_kernel.cc index 94886806bccf3..2dd55a13e38e5 100644 --- a/paddle/pten/kernels/empty_kernel.cc +++ b/paddle/pten/kernels/empty_kernel.cc @@ -1,33 +1,34 @@ /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ #include "paddle/pten/kernels/empty_kernel.h" #include "paddle/pten/backends/all_context.h" #include "paddle/pten/core/kernel_registry.h" +#include "paddle/fluid/platform/complex.h" + namespace pten { -template -void EmptyKernel(const ContextT& dev_ctx, +template +void EmptyKernel(const Context& dev_ctx, const ScalarArray& shape, DenseTensor* out) { out->Resize(paddle::framework::make_ddim(shape.GetData())); } -template -void EmptyLikeKernel(const ContextT& dev_ctx, DenseTensor* out) { +template +void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) { out->mutable_data(); } @@ -37,44 +38,62 @@ PT_REGISTER_CTX_KERNEL(empty, CPU, ALL_LAYOUT, pten::EmptyKernel, - bool, - int, - int64_t, float, double, - paddle::platform::float16) {} + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::bfloat16, + paddle::platform::complex, + paddle::platform::complex) {} PT_REGISTER_CTX_KERNEL(empty_like, CPU, ALL_LAYOUT, pten::EmptyLikeKernel, - bool, - int, - int64_t, float, double, - paddle::platform::float16) {} + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::bfloat16, + paddle::platform::complex, + paddle::platform::complex) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PT_REGISTER_CTX_KERNEL(empty, GPU, ALL_LAYOUT, pten::EmptyKernel, - bool, - int, - int64_t, float, double, - paddle::platform::float16) {} + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} PT_REGISTER_CTX_KERNEL(empty_like, GPU, ALL_LAYOUT, pten::EmptyLikeKernel, - bool, - int, - int64_t, float, double, - paddle::platform::float16) {} + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} #endif diff --git a/paddle/pten/kernels/empty_kernel.h b/paddle/pten/kernels/empty_kernel.h index d71ee0b1266f2..d283ef5c1e41e 100644 --- a/paddle/pten/kernels/empty_kernel.h +++ b/paddle/pten/kernels/empty_kernel.h @@ -41,6 +41,14 @@ DenseTensor Empty(const Context& dev_ctx, DenseTensorMeta&& meta) { return dense_out; } +template +DenseTensor Empty(const Context& dev_ctx) { + return Empty(dev_ctx, + {paddle::experimental::CppTypeToDataType::Type(), + {-1}, + DataLayout::NCHW}); +} + template DenseTensor Empty(const Context& dev_ctx, const ScalarArray& shape, diff --git a/paddle/pten/kernels/gpu/complex_kernel.cu b/paddle/pten/kernels/gpu/complex_kernel.cu index 5a3c14de4036a..02f050f5bc838 100644 --- a/paddle/pten/kernels/gpu/complex_kernel.cu +++ b/paddle/pten/kernels/gpu/complex_kernel.cu @@ -24,7 +24,8 @@ PT_REGISTER_CTX_KERNEL(conj, GPU, ALL_LAYOUT, - pten::Conj, + pten::ConjKernel, + paddle::platform::float16, paddle::platform::complex, paddle::platform::complex, float, diff --git a/paddle/pten/kernels/gpu/dot_grad_kernel.cu b/paddle/pten/kernels/gpu/dot_grad_kernel.cu new file mode 100644 index 0000000000000..42af96f7c7265 --- /dev/null +++ b/paddle/pten/kernels/gpu/dot_grad_kernel.cu @@ -0,0 +1,32 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/kernels/dot_grad_kernel.h" +#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h" + +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/fluid/platform/complex.h" + +PT_REGISTER_CTX_KERNEL(dot_grad, + GPU, + ALL_LAYOUT, + pten::DotGradKernel, + float, + double, + int, + int64_t, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/gpu/dot_kernel.cu b/paddle/pten/kernels/gpu/dot_kernel.cu index 6b66d45b7dd48..1f9e7aa3f1cfd 100644 --- a/paddle/pten/kernels/gpu/dot_kernel.cu +++ b/paddle/pten/kernels/gpu/dot_kernel.cu @@ -25,10 +25,10 @@ namespace pten { template -void Dot(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out) { +void DotKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { out->mutable_data(); if (1 == out->dims().size()) { auto eigen_out = pten::EigenScalar::From(*out); @@ -55,7 +55,7 @@ using complex128 = ::paddle::platform::complex; PT_REGISTER_CTX_KERNEL(dot, GPU, ALL_LAYOUT, - pten::Dot, + pten::DotKernel, float, double, int, diff --git a/paddle/pten/kernels/gpu/matmul_grad_kernel.cu b/paddle/pten/kernels/gpu/matmul_grad_kernel.cu new file mode 100644 index 0000000000000..f20c3f82c9262 --- /dev/null +++ b/paddle/pten/kernels/gpu/matmul_grad_kernel.cu @@ -0,0 +1,50 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/kernels/matmul_grad_kernel.h" + +#include "paddle/fluid/platform/complex.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h" + +PT_REGISTER_CTX_KERNEL(matmul_grad, + GPU, + ALL_LAYOUT, + pten::MatmulGradKernel, + float, + double, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_CTX_KERNEL(matmul_double_grad, + GPU, + ALL_LAYOUT, + pten::MatmulDoubleGradKernel, + float, + double, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_CTX_KERNEL(matmul_triple_grad, + GPU, + ALL_LAYOUT, + pten::MatmulTripleGradKernel, + float, + double, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/hybird/transpose.h b/paddle/pten/kernels/hybird/transpose.h index 459fed6b9fa04..17f52c74a1344 100644 --- a/paddle/pten/kernels/hybird/transpose.h +++ b/paddle/pten/kernels/hybird/transpose.h @@ -17,6 +17,9 @@ #include "paddle/fluid/framework/ddim.h" #include "paddle/pten/core/dense_tensor.h" +#include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/pten/kernels/hybird/eigen/common.h" + namespace pten { namespace math { @@ -30,5 +33,30 @@ struct TransposeNormal { const std::vector& axis); }; +template +struct Transpose { + void operator()(const DeviceContext& dev_ctx, + const DenseTensor& in, + DenseTensor* out, + const std::vector& axis) { + Eigen::array permute; + for (int i = 0; i < Rank; i++) { + permute[i] = axis[i]; + } + auto eigen_in = pten::EigenTensor::From(in); + auto eigen_out = pten::EigenTensor::From(*out); + auto* dev = dev_ctx.eigen_device(); + // use 32bit index to speed up computation + bool use_32bit_index = eigen_out.size() < Eigen::NumTraits::highest(); + bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace()); + if (use_32bit_index && is_gpu_place) { + To32BitIndex(eigen_out).device(*dev) = + To32BitIndex(eigen_in).shuffle(permute); + } else { + eigen_out.device(*dev) = eigen_in.shuffle(permute); + } + } +}; + } // namespace math } // namespace pten diff --git a/paddle/pten/kernels/impl/complex_kernel_impl.h b/paddle/pten/kernels/impl/complex_kernel_impl.h index 6f3a6049faa9a..e0c6825a78a53 100644 --- a/paddle/pten/kernels/impl/complex_kernel_impl.h +++ b/paddle/pten/kernels/impl/complex_kernel_impl.h @@ -21,12 +21,14 @@ namespace pten { template -void Conj(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { +void ConjKernel(const Context& context, + const DenseTensor& x, + DenseTensor* out) { auto numel = x.numel(); auto* x_data = x.data(); auto* out_data = out->mutable_data(); - paddle::platform::ForRange for_range(dev_ctx, numel); + paddle::platform::ForRange for_range(context, numel); paddle::operators::math::ConjFunctor functor(x_data, numel, out_data); for_range(functor); } diff --git a/paddle/pten/kernels/impl/dot_grad_kernel_impl.h b/paddle/pten/kernels/impl/dot_grad_kernel_impl.h new file mode 100644 index 0000000000000..16c87bbab474a --- /dev/null +++ b/paddle/pten/kernels/impl/dot_grad_kernel_impl.h @@ -0,0 +1,919 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/hybird/eigen/common.h" + +#include "paddle/pten/kernels/complex_kernel.h" + +#include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/fluid/operators/math/complex_functors.h" + +namespace pten { + +template +struct DotGradFunction { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy); +}; + +template +struct DotGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == tensor_dout->dims().size()) { + auto dout = EigenVector::Flatten(*tensor_dout); + + if (tensor_dx) { + auto y = EigenVector::Flatten(*tensor_y); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(tensor_dx->numel()); + + ConjKernel(ctx, *tensor_y, tensor_dx); + + auto dx = EigenVector::Flatten(*tensor_dx); + dx.device(dev) = dx * dout.broadcast(size); + } + + if (tensor_dy) { + auto x = EigenVector::Flatten(*tensor_x); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(tensor_dy->numel()); + + ConjKernel(ctx, *tensor_x, tensor_dy); + + auto dy = EigenVector::Flatten(*tensor_dy); + dy.device(dev) = dy * dout.broadcast(size); + } + } else { + auto dout = EigenMatrix::From(*tensor_dout); + + if (tensor_dx) { + tensor_dx->mutable_data(); + auto y = EigenMatrix::From(*tensor_y); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(1, tensor_dx->dims()[1]); + + ConjKernel(ctx, *tensor_y, tensor_dx); + + auto dx = EigenMatrix::From(*tensor_dx); + dx.device(dev) = dx * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(); + auto x = EigenMatrix::From(*tensor_x); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(1, tensor_dy->dims()[1]); + + ConjKernel(ctx, *tensor_x, tensor_dy); + + auto dy = EigenMatrix::From(*tensor_dy); + dy.device(dev) = dy * dout.broadcast(size); + } + } +#else + const auto* data_dout = tensor_dout->data(); + + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(); + const auto* data_y = tensor_y->data(); + const DDim& dim = tensor_x->dims(); + size_t N = static_cast(paddle::framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s]; + } + } + + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(); + const auto* data_x = tensor_x->data(); + const DDim& dim = tensor_y->dims(); + size_t N = static_cast(paddle::framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s]; + } + } +#endif + } +}; + +template +struct DotGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == tensor_dout->dims().size()) { + auto dout = EigenVector::Flatten(*tensor_dout); + if (tensor_dx) { + auto y = EigenVector::Flatten(*tensor_y); + auto dx = EigenVector::Flatten(*tensor_dx); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(tensor_dx->numel()); + dx.device(dev) = y * dout.broadcast(size); + } + + if (tensor_dy) { + auto x = EigenVector::Flatten(*tensor_x); + auto dy = EigenVector::Flatten(*tensor_dy); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(tensor_dy->numel()); + dy.device(dev) = x * dout.broadcast(size); + } + } else { + auto dout = EigenMatrix::From(*tensor_dout); + + if (tensor_dx) { + tensor_dx->mutable_data(); + auto y = EigenMatrix::From(*tensor_y); + auto dx = EigenMatrix::From(*tensor_dx); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(1, tensor_dx->dims()[1]); + dx.device(dev) = y * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(); + auto x = EigenMatrix::From(*tensor_x); + auto dy = EigenMatrix::From(*tensor_dy); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(1, tensor_dy->dims()[1]); + dy.device(dev) = x * dout.broadcast(size); + } + } +#else + auto const *x = tensor_x->data(), *y = tensor_y->data(), + *dz = tensor_dout->data(); + auto&& d = tensor_x->dims(); + auto const N = tensor_x->numel(); + auto const B = d[d.size() - 1]; + + if (tensor_dx) { + auto* dx = tensor_dx->mutable_data(); + for (auto j = 0; j < N / B; ++j) { + auto const ss = dz[j]; + for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss; + } + } + + if (tensor_dy) { + auto* dy = tensor_dy->mutable_data(); + for (auto j = 0; j < N / B; ++j) { + auto const ss = dz[j]; + for (auto i = 0; i < B; i++) *dy++ = *x++ * ss; + } + } +#endif + } +}; + +template +struct DotDoubleGradFunction { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + const DenseTensor* tensor_ddx, + const DenseTensor* tensor_ddy, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy, + DenseTensor* tensor_ddout); +}; + +template +struct DotDoubleGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + const DenseTensor* tensor_ddx, + const DenseTensor* tensor_ddy, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy, + DenseTensor* tensor_ddout) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == tensor_dout->dims().size()) { + DenseTensor tensor_dout_help; + auto& dev = *ctx.eigen_device(); + if (tensor_dx || tensor_dy) { + tensor_dout_help = Conj(ctx, *tensor_dout); + } + if (tensor_dx) { + auto ddy = EigenVector::Flatten(*tensor_ddy); + Eigen::DSizes size(tensor_ddy->numel()); + auto dx = EigenVector::Flatten(*tensor_dx); + auto dout = EigenVector::Flatten(tensor_dout_help); + dx.device(dev) = ddy * dout.broadcast(size); + } + + if (tensor_dy) { + auto ddx = EigenVector::Flatten(*tensor_ddx); + Eigen::DSizes size(tensor_ddx->numel()); + auto dy = EigenVector::Flatten(*tensor_dy); + auto dout = EigenVector::Flatten(tensor_dout_help); + dy.device(dev) = ddx * dout.broadcast(size); + } + + if (tensor_ddout) { + DenseTensor tensor_x_help = Conj(ctx, *tensor_x); + DenseTensor tensor_y_help = Conj(ctx, *tensor_y); + + auto x = EigenVector::Flatten(tensor_x_help); + auto y = EigenVector::Flatten(tensor_y_help); + auto ddx = EigenVector::Flatten(*tensor_ddx); + auto ddy = EigenVector::Flatten(*tensor_ddy); + auto ddout = EigenVector::Flatten(*tensor_ddout); + ddout.device(dev) = (x * ddy + y * ddx).sum(); + } + } +#else + const auto* data_dout = tensor_dout->data(); + + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(); + const auto* data_ddy = tensor_ddy->data(); + const DDim& dim = tensor_dx->dims(); + size_t N = static_cast(product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddy[i]; + } + } + + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(); + const auto* data_ddx = tensor_ddx->data(); + const DDim& dim = tensor_dy->dims(); + size_t N = static_cast(product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddx[i]; + } + } + + if (tensor_ddout) { + auto* data_ddout = tensor_ddout->mutable_data(); + auto* data_x = tensor_x->data(); + auto* data_y = tensor_y->data(); + auto* data_ddx = tensor_ddx->data(); + auto* data_ddy = tensor_ddy->data(); + + const DDim& dim = tensor_dy->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_ddout[s] = T(data_x[i].real, -data_x[i].imag) * data_ddy[i] + + T(data_y[i].real, -data_y[i].imag) * data_ddx[i]; + } else { + data_ddout[s] += T(data_x[i].real, -data_x[i].imag) * data_ddy[i] + + T(data_y[i].real, -data_y[i].imag) * data_ddx[i]; + } + new_s = false; + } + } +#endif + } +}; + +template +struct DotDoubleGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + const DenseTensor* tensor_ddx, + const DenseTensor* tensor_ddy, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy, + DenseTensor* tensor_ddout) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == tensor_dout->dims().size()) { + auto& dev = *ctx.eigen_device(); + auto dout = EigenVector::Flatten(*tensor_dout); + if (tensor_dx) { + tensor_dx->mutable_data(); + auto ddy = EigenVector::Flatten(*tensor_ddy); + Eigen::DSizes size(tensor_ddy->numel()); + auto dx = EigenVector::Flatten(*tensor_dx); + dx.device(dev) = ddy * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(); + auto ddx = EigenVector::Flatten(*tensor_ddx); + Eigen::DSizes size(tensor_ddx->numel()); + + auto dy = EigenVector::Flatten(*tensor_dy); + dy.device(dev) = ddx * dout.broadcast(size); + } + + if (tensor_ddout) { + tensor_ddout->mutable_data(); + auto x = EigenVector::Flatten(*tensor_x); + auto y = EigenVector::Flatten(*tensor_y); + auto ddx = EigenVector::Flatten(*tensor_ddx); + auto ddy = EigenVector::Flatten(*tensor_ddy); + auto ddout = EigenVector::Flatten(*tensor_ddout); + ddout.device(dev) = (x * ddy + y * ddx).sum(); + } + } +#else + const auto* data_dout = tensor_dout->data(); + + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(); + const auto* data_ddy = tensor_ddy->data(); + const DDim& dim = tensor_dx->dims(); + size_t N = static_cast(product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = data_dout[s] * data_ddy[i]; + } + } + + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(); + const auto* data_ddx = tensor_ddx->data(); + const DDim& dim = tensor_dy->dims(); + size_t N = static_cast(product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = data_dout[s] * data_ddx[i]; + } + } + + if (tensor_ddout) { + auto* data_ddout = tensor_ddout->mutable_data(); + auto* data_x = tensor_x->data(); + auto* data_y = tensor_y->data(); + auto* data_ddx = tensor_ddx->data(); + auto* data_ddy = tensor_ddy->data(); + + const DDim& dim = tensor_dy->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_ddout[s] = data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i]; + } else { + data_ddout[s] += data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i]; + } + new_s = false; + } + } +#endif + } +}; + +template +struct DotTripleGradFunction { + void operator()(const DeviceContext& ctx, + const DenseTensor* in_tensor_x, + const DenseTensor* in_tensor_y, + const DenseTensor* in_tensor_ddx, + const DenseTensor* in_tensor_ddy, + const DenseTensor* in_tensor_d_dx, + const DenseTensor* in_tensor_d_dy, + const DenseTensor* in_tensor_dout, + const DenseTensor* in_tensor_d_ddout, + DenseTensor* out_tensor_d_x, + DenseTensor* out_tensor_d_y, + DenseTensor* out_tensor_d_dout, + DenseTensor* out_tensor_d_ddx, + DenseTensor* out_tensor_d_ddy); +}; + +// TODO(wuweilong): enable this function when the unittests framewark for multi +// grad is ok (dtype: complex64 or complex128). +template +struct DotTripleGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* in_tensor_x, + const DenseTensor* in_tensor_y, + const DenseTensor* in_tensor_ddx, + const DenseTensor* in_tensor_ddy, + const DenseTensor* in_tensor_d_dx, + const DenseTensor* in_tensor_d_dy, + const DenseTensor* in_tensor_dout, + const DenseTensor* in_tensor_d_ddout, + DenseTensor* out_tensor_d_x, + DenseTensor* out_tensor_d_y, + DenseTensor* out_tensor_d_dout, + DenseTensor* out_tensor_d_ddx, + DenseTensor* out_tensor_d_ddy) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == in_tensor_d_ddout->dims().size()) { + DenseTensor in_tensor_d_ddout_help; + auto& dev = *ctx.eigen_device(); + if (out_tensor_d_x || out_tensor_d_y) { + in_tensor_d_ddout_help = + Conj(ctx, *in_tensor_d_ddout); + } + if (out_tensor_d_x) { + auto ddy = EigenVector::Flatten(*in_tensor_ddy); + Eigen::DSizes size(in_tensor_ddy->numel()); + auto d_x = EigenVector::Flatten(*out_tensor_d_x); + auto d_ddout = EigenVector::Flatten(in_tensor_d_ddout_help); + d_x.device(dev) = ddy * d_ddout.broadcast(size); + } + + if (out_tensor_d_y) { + auto ddx = EigenVector::Flatten(*in_tensor_ddx); + Eigen::DSizes size(in_tensor_ddx->numel()); + auto d_y = EigenVector::Flatten(*out_tensor_d_y); + auto d_ddout = EigenVector::Flatten(in_tensor_d_ddout_help); + d_y.device(dev) = ddx * d_ddout.broadcast(size); + } + + if (out_tensor_d_dout) { + DenseTensor in_tensor_ddx_help = + Conj(ctx, *in_tensor_ddx); + DenseTensor in_tensor_ddy_help = + Conj(ctx, *in_tensor_ddy); + + auto ddx = EigenVector::Flatten(in_tensor_ddx_help); + auto ddy = EigenVector::Flatten(in_tensor_ddy_help); + auto d_dx = EigenVector::Flatten(*in_tensor_d_dx); + auto d_dy = EigenVector::Flatten(*in_tensor_d_dy); + auto d_dout = EigenVector::Flatten(*out_tensor_d_dout); + d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum(); + } + + if (out_tensor_d_ddx) { + DenseTensor in_tensor_dout_help = + Conj(ctx, *in_tensor_dout); + DenseTensor in_tensor_y_help = + Conj(ctx, *in_tensor_y); + + auto dout = EigenVector::Flatten(in_tensor_dout_help); + auto y = EigenVector::Flatten(in_tensor_y_help); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dy = EigenVector::Flatten(*in_tensor_d_dy); + auto d_ddx = EigenVector::Flatten(*out_tensor_d_ddx); + Eigen::DSizes size(in_tensor_y->numel()); + d_ddx.device(dev) = + (dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size)); + } + + if (out_tensor_d_ddy) { + DenseTensor in_tensor_dout_help = + Conj(ctx, *in_tensor_dout); + DenseTensor in_tensor_x_help = + Conj(ctx, *in_tensor_x); + + auto dout = EigenVector::Flatten(in_tensor_dout_help); + auto x = EigenVector::Flatten(in_tensor_x_help); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dx = EigenVector::Flatten(*in_tensor_d_dx); + auto d_ddy = EigenVector::Flatten(*out_tensor_d_ddy); + Eigen::DSizes size(in_tensor_x->numel()); + d_ddy.device(dev) = + (dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size)); + } + } +#else + const auto* data_d_ddout = in_tensor_d_ddout->data(); + + if (out_tensor_d_x) { + auto* data_d_x = out_tensor_d_x->mutable_data(); + const auto* data_ddy = in_tensor_ddy->data(); + + const DDim& dim = out_tensor_d_x->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_x[i] = T(data_ddy[i].real, -data_ddy[i].imag) * data_d_ddout[s]; + } + } + + if (out_tensor_d_y) { + auto* data_d_y = out_tensor_d_y->mutable_data(); + const auto* data_ddx = in_tensor_ddx->data(); + + const DDim& dim = out_tensor_d_y->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_y[i] = T(data_ddx[i].real, -data_ddx[i].imag) * data_d_ddout[s]; + } + } + + if (out_tensor_d_dout) { + auto* data_d_dout = out_tensor_d_dout->mutable_data(); + auto* data_ddx = in_tensor_ddx->data(); + auto* data_ddy = in_tensor_ddy->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + + const DDim& dim = out_tensor_d_dout->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_d_dout[s] = + T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] + + T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i]; + } else { + data_d_dout[s] += + T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] + + T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i]; + } + new_s = false; + } + } + + if (out_tensor_d_ddx) { + auto* data_d_ddx = out_tensor_d_ddx->mutable_data(); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + auto* data_y = in_tensor_y->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const DDim& dim = out_tensor_d_ddx->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddx[i] = + T(data_dout[s].real, -data_dout[s].imag) * data_d_dy[i] + + T(data_y[i].real, -data_y[i].imag) * data_d_ddout[s]; + } + } + + if (out_tensor_d_ddy) { + auto* data_d_ddy = out_tensor_d_ddy->mutable_data(); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_x = in_tensor_x->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const DDim& dim = out_tensor_d_ddy->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddy[i] = + T(data_dout[s].real, -data_dout[s].imag) * data_d_dx[i] + + T(data_x[i].real, -data_x[i].imag) * data_d_ddout[s]; + } + } +#endif + } +}; + +template +struct DotTripleGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* in_tensor_x, + const DenseTensor* in_tensor_y, + const DenseTensor* in_tensor_ddx, + const DenseTensor* in_tensor_ddy, + const DenseTensor* in_tensor_d_dx, + const DenseTensor* in_tensor_d_dy, + const DenseTensor* in_tensor_dout, + const DenseTensor* in_tensor_d_ddout, + DenseTensor* out_tensor_d_x, + DenseTensor* out_tensor_d_y, + DenseTensor* out_tensor_d_dout, + DenseTensor* out_tensor_d_ddx, + DenseTensor* out_tensor_d_ddy) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == in_tensor_d_ddout->dims().size()) { + auto& dev = *ctx.eigen_device(); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + if (out_tensor_d_x) { + out_tensor_d_x->mutable_data(); + auto ddy = EigenVector::Flatten(*in_tensor_ddy); + Eigen::DSizes size(in_tensor_ddy->numel()); + auto d_x = EigenVector::Flatten(*out_tensor_d_x); + d_x.device(dev) = ddy * d_ddout.broadcast(size); + } + + if (out_tensor_d_y) { + out_tensor_d_y->mutable_data(); + auto ddx = EigenVector::Flatten(*in_tensor_ddx); + Eigen::DSizes size(in_tensor_ddx->numel()); + + auto d_y = EigenVector::Flatten(*out_tensor_d_y); + d_y.device(dev) = ddx * d_ddout.broadcast(size); + } + + if (out_tensor_d_dout) { + out_tensor_d_dout->mutable_data(); + auto ddx = EigenVector::Flatten(*in_tensor_ddx); + auto ddy = EigenVector::Flatten(*in_tensor_ddy); + auto d_dx = EigenVector::Flatten(*in_tensor_d_dx); + auto d_dy = EigenVector::Flatten(*in_tensor_d_dy); + auto d_dout = EigenVector::Flatten(*out_tensor_d_dout); + d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum(); + } + + if (out_tensor_d_ddx) { + out_tensor_d_ddx->mutable_data(); + auto dout = EigenVector::Flatten(*in_tensor_dout); + auto y = EigenVector::Flatten(*in_tensor_y); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dy = EigenVector::Flatten(*in_tensor_d_dy); + auto d_ddx = EigenVector::Flatten(*out_tensor_d_ddx); + Eigen::DSizes size(in_tensor_y->numel()); + d_ddx.device(dev) = + (dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size)); + } + + if (out_tensor_d_ddy) { + out_tensor_d_ddy->mutable_data(); + auto dout = EigenVector::Flatten(*in_tensor_dout); + auto x = EigenVector::Flatten(*in_tensor_x); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dx = EigenVector::Flatten(*in_tensor_d_dx); + auto d_ddy = EigenVector::Flatten(*out_tensor_d_ddy); + Eigen::DSizes size(in_tensor_x->numel()); + d_ddy.device(dev) = + (dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size)); + } + } +#else + const auto* data_d_ddout = in_tensor_d_ddout->data(); + + if (out_tensor_d_x) { + auto* data_d_x = out_tensor_d_x->mutable_data(); + const auto* data_ddy = in_tensor_ddy->data(); + + const DDim& dim = out_tensor_d_x->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_x[i] = data_ddy[i] * data_d_ddout[s]; + } + } + + if (out_tensor_d_y) { + auto* data_d_y = out_tensor_d_y->mutable_data(); + const auto* data_ddx = in_tensor_ddx->data(); + + const DDim& dim = out_tensor_d_y->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_y[i] = data_ddx[i] * data_d_ddout[s]; + } + } + + if (out_tensor_d_dout) { + auto* data_d_dout = out_tensor_d_dout->mutable_data(); + auto* data_ddx = in_tensor_ddx->data(); + auto* data_ddy = in_tensor_ddy->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + + const DDim& dim = in_tensor_ddx->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_d_dout[s] = + data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i]; + } else { + data_d_dout[s] += + data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i]; + } + new_s = false; + } + } + + if (out_tensor_d_ddx) { + auto* data_d_ddx = out_tensor_d_ddx->mutable_data(); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + auto* data_y = in_tensor_y->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const DDim& dim = out_tensor_d_ddx->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddx[i] = + data_dout[s] * data_d_dy[i] + data_y[i] * data_d_ddout[s]; + } + } + + if (out_tensor_d_ddy) { + auto* data_d_ddy = out_tensor_d_ddy->mutable_data(); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_x = in_tensor_x->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const DDim& dim = out_tensor_d_ddy->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddy[i] = + data_dout[s] * data_d_dx[i] + data_x[i] * data_d_ddout[s]; + } + } +#endif + } +}; + +template +void DotGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy) { + if (dx) { + dx->mutable_data(); + } + if (dy) { + dy->mutable_data(); + } + DotGradFunction()(dev_ctx, &x, &y, &dout, dx, dy); +} + +template +void DotDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* ddout) { + if (dx) { + dx->mutable_data(); + } + if (dy) { + dy->mutable_data(); + } + if (ddout) { + ddout->mutable_data(); + } + DotDoubleGradFunction()( + dev_ctx, &x, &y, &dout, ddx, ddy, dx, dy, ddout); +} + +template +void DotTripleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& d_dx, + const DenseTensor& d_dy, + const DenseTensor& dout, + const DenseTensor& d_ddout, + DenseTensor* d_x, + DenseTensor* d_y, + DenseTensor* d_ddx, + DenseTensor* d_ddy, + DenseTensor* d_dout) { + if (d_x) { + d_x->mutable_data(); + } + if (d_y) { + d_y->mutable_data(); + } + if (d_ddx) { + d_ddx->mutable_data(); + } + if (d_ddy) { + d_ddy->mutable_data(); + } + if (d_dout) { + d_dout->mutable_data(); + } + + DotTripleGradFunction()(dev_ctx, + &x, + &y, + ddx, + ddy, + d_dx, + d_dy, + dout, + d_ddout, + d_x, + d_y, + d_dout, + d_ddx, + d_ddy); +} + +} // namespace pten diff --git a/paddle/pten/kernels/impl/matmul_grad_kernel_impl.h b/paddle/pten/kernels/impl/matmul_grad_kernel_impl.h new file mode 100644 index 0000000000000..802cc019d78c5 --- /dev/null +++ b/paddle/pten/kernels/impl/matmul_grad_kernel_impl.h @@ -0,0 +1,1742 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +// #include "paddle/pten/kernels/complex_kernel.h" +#include "paddle/pten/include/math.h" +#include "paddle/pten/kernels/empty_kernel.h" +#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h" +#include "paddle/pten/kernels/impl/matmul_kernel_impl.h" + +#include "paddle/pten/kernels/cpu/reduce.h" +#include "paddle/pten/kernels/funcs/reduce_functor.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/backends/gpu/gpu_context.h" + +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/pten/kernels/gpu/reduce.h" +#endif + +namespace pten { + +template +struct ReduceSumForMatmulGrad { + void operator()(const Context& dev_ctx, + const DenseTensor& input, + DenseTensor* output, + const std::vector& reduce_dims); +}; + +template +struct ReduceSumForMatmulGrad { + void operator()(const CPUContext& dev_ctx, + const DenseTensor& input, + DenseTensor* output, + const std::vector& reduce_dims) { + std::vector reduce_dims_tmp(reduce_dims.begin(), + reduce_dims.end()); + ReduceKernelImpl( + dev_ctx, input, output, reduce_dims_tmp, true, false); + } +}; + +#if defined(__NVCC__) || defined(__HIPCC__) +template +struct ReduceSumForMatmulGrad { + void operator()(const GPUContext& dev_ctx, + const DenseTensor& input, + DenseTensor* output, + const std::vector& reduce_dims) { + auto stream = dev_ctx.stream(); + kernels:: + TensorReduceFunctorImpl>( + input, output, kps::IdentityFunctor(), reduce_dims, stream); + } +}; +#endif + +// Reshape a rank-3 tensor from P x M x N to (P * M) x N. +// Identity op if the tensor is not of rank 3. +static DenseTensor FoldInitDims(const DenseTensor& input) { + DenseTensor output = input; + auto in_dims = input.dims(); + if (in_dims.size() == 3) { + output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); + } + return output; +} + +// Reshape a rank-3 tensor from P x M x N to M x (P * N). +// (Warning: This requires transposing data and writes into new memory.) +// Identity op if the tensor is not of rank 3. +template +static DenseTensor FoldHeadAndLastDims(const Context& dev_ctx, + const DenseTensor& input) { + auto in_dims = input.dims(); + if (in_dims.size() != 3) { + return input; + } + DenseTensor output = EmptyLike(dev_ctx, input); + output.Resize({in_dims[1], in_dims[0], in_dims[2]}); + std::vector axis = {1, 0, 2}; + math::Transpose trans; + trans(dev_ctx, input, &output, axis); + output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); + return output; +} + +template +void MatMul(const Context& dev_ctx, + const DenseTensor& a, + bool trans_a, + const DenseTensor& b, + bool trans_b, + DenseTensor* out, + bool flag = false) { + out->mutable_data(); + auto blas = paddle::operators::math::GetBlas(dev_ctx); + auto mat_dim_a = + paddle::operators::math::CreateMatrixDescriptor(a.dims(), 0, trans_a); + auto mat_dim_b = + paddle::operators::math::CreateMatrixDescriptor(b.dims(), 0, trans_b); + if (a.dims().size() == 3 && b.dims().size() <= 2) { + // the transpose_X must be false, if is true, the transpose cost much time + if (!trans_a) { + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + } + } + blas.MatMul(a.data(), + mat_dim_a, + b.data(), + mat_dim_b, + static_cast(1), + out->mutable_data(), + static_cast(flag)); +} + +/** + * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the + * original x_dim is returned. + */ +static DDim RowMatrixFromVector(const DDim& x_dim) { + if (x_dim.size() > 1) { + return x_dim; + } + return paddle::framework::make_ddim({1, x_dim[0]}); +} + +/** + * Get column matrix shape from a vector shape. If the ran of y_dim > 1, the + * original y_dim is returned. + */ +static DDim ColumnMatrixFromVector(const DDim& y_dim) { + if (y_dim.size() > 1) { + return y_dim; + } + return paddle::framework::make_ddim({y_dim[0], 1}); +} + +/** + * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor. + * + * The shape would be [BatchSize, H, W] or [H, W]. + * If transposed, `H,W` will be swapped. + */ +static void ReshapeTensorIntoMatrixSequence( + DenseTensor* x, const paddle::operators::math::MatDescriptor& descriptor) { + int64_t h, w; + h = descriptor.height_; + w = descriptor.width_; + if (descriptor.trans_) { + std::swap(w, h); + } + if (descriptor.batch_size_) { + x->Resize({descriptor.batch_size_, h, w}); + } else { + x->Resize({h, w}); + } +} + +static void ReshapeXYOutIntoMatrixSequence(DenseTensor* x, + DenseTensor* y, + DenseTensor* out, + bool trans_x, + bool trans_y) { + auto x_dim = RowMatrixFromVector(x->dims()); + auto y_dim = ColumnMatrixFromVector(y->dims()); + auto mat_dim_x = + paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x); + auto mat_dim_y = + paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y); + if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { + out->Resize({mat_dim_x.height_, mat_dim_y.width_}); + } else { + out->Resize({(std::max)(mat_dim_x.batch_size_, mat_dim_y.batch_size_), + mat_dim_x.height_, + mat_dim_y.width_}); + } + + ReshapeTensorIntoMatrixSequence(x, mat_dim_x); + ReshapeTensorIntoMatrixSequence(y, mat_dim_y); +} + +template +void CalcInputGrad(const Context& dev_ctx, + const DenseTensor& a, + bool trans_a, + bool is_fold_init_dims_a, + const DenseTensor& b, + bool trans_b, + bool is_fold_init_dims_b, + DenseTensor* out, + bool flag = false) { + if (out == nullptr) return; + bool need_combine = + (a.dims().size() == 3 || b.dims().size() == 3) && out->dims().size() == 2; + if (!need_combine) { + MatMul(dev_ctx, a, trans_a, b, trans_b, out, flag); + } else { + MatMul( + dev_ctx, + is_fold_init_dims_a ? FoldInitDims(a) + : FoldHeadAndLastDims(dev_ctx, a), + trans_a, + is_fold_init_dims_b ? FoldInitDims(b) + : FoldHeadAndLastDims(dev_ctx, b), + trans_b, + out, + flag); + } +} + +template +void MatmulGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + bool transpose_x, + bool transpose_y, + DenseTensor* dx, + DenseTensor* dy) { + // get dims + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + std::vector dout_dims = vectorize(out_grad.dims()); + + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + int ndim = dout_dims.size(); + + // Case1 : x's or y's dim = 1 + if (x_ndim == 1 && y_ndim == 1) { + if (dx) dx->mutable_data(); + if (dy) dy->mutable_data(); + if (out_grad.numel() == 1) { + DotGradFunction()(dev_ctx, &x, &y, &out_grad, dx, dy); + return; + } + } + + bool is_broadcast = true; + if (x_ndim <= 2 || y_ndim <= 2) { + is_broadcast = false; + } else if (x_ndim != y_ndim) { + is_broadcast = true; + } else { + is_broadcast = !std::equal( + x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); + } + + // for complex + DenseTensor x_conj; + DenseTensor y_conj; + + // Case2: no broadcast or no batch size, it aims to speed and it is same as + // matmul in old version. + if (!is_broadcast) { + DenseTensor x_help = x; + DenseTensor y_help = y; + DenseTensor out_grad_help = out_grad; + ReshapeXYOutIntoMatrixSequence( + &x_help, &y_help, &out_grad_help, transpose_x, transpose_y); + + DDim dx_dims; + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x_help.dims()) { + dx->Resize(x_help.dims()); + } + + y_conj = Conj(dev_ctx, y_help); + } + + DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y_help.dims()) { + dy->Resize(y_help.dims()); + } + + x_conj = Conj(dev_ctx, x_help); + } + + if (transpose_x && transpose_y) { + CalcInputGrad( + dev_ctx, y_conj, true, true, out_grad_help, true, false, dx); + CalcInputGrad( + dev_ctx, out_grad_help, true, true, x_conj, true, false, dy); + } else if (transpose_x) { + CalcInputGrad( + dev_ctx, y_conj, false, false, out_grad_help, true, false, dx); + CalcInputGrad( + dev_ctx, x_conj, false, false, out_grad_help, false, true, dy); + } else if (transpose_y) { + CalcInputGrad( + dev_ctx, out_grad_help, false, false, y_conj, false, true, dx); + CalcInputGrad( + dev_ctx, out_grad_help, true, true, x_conj, false, true, dy); + } else { + CalcInputGrad( + dev_ctx, out_grad_help, false, false, y_conj, true, false, dx); + CalcInputGrad( + dev_ctx, x_conj, true, true, out_grad_help, false, true, dy); + } + + if (dx) { + if (dx_dims != x_help.dims()) { + dx->Resize(dx_dims); + } + } + if (dy) { + if (dy_dims != y_help.dims()) { + dy->Resize(dy_dims); + } + } + } else { + // Case3: broadcast. It need cost much time to reduce sum for the + // broadcast and wastes the memory. + // So we should avoid the case in reality. + VLOG(3) << "It need cost much time to reduce sum for the broadcast and " + "wastes the memory. So we should avoid the case in reality"; + x_conj = Conj(dev_ctx, x); + y_conj = Conj(dev_ctx, y); + + DenseTensor dx_help = Empty(dev_ctx); + DenseTensor dy_help = Empty(dev_ctx); + + if (transpose_x) { + if (transpose_y) { + // X'Y': dA = Y'G', dB = G'X' + if (dx) + MatMulFunction(dev_ctx, + y_conj, + out_grad, + y_dims, + dout_dims, + &dx_help, + true, + true); + if (dy) + MatMulFunction(dev_ctx, + out_grad, + x_conj, + dout_dims, + x_dims, + &dy_help, + true, + true); + } else { + // X'Y: dX = YG', dY = XG + if (dx) + MatMulFunction(dev_ctx, + y_conj, + out_grad, + y_dims, + dout_dims, + &dx_help, + false, + true); + if (dy) + MatMulFunction(dev_ctx, + x_conj, + out_grad, + x_dims, + dout_dims, + &dy_help, + false, + false); + } + } else { + if (transpose_y) { + // XY': dX = GY, dY = G'X + if (dx) + MatMulFunction(dev_ctx, + out_grad, + y_conj, + dout_dims, + y_dims, + &dx_help, + false, + false); + if (dy) + MatMulFunction(dev_ctx, + out_grad, + x_conj, + dout_dims, + x_dims, + &dy_help, + true, + false); + } else { + // XY: dX = GY', dY = X'G + if (dx) + MatMulFunction(dev_ctx, + out_grad, + y_conj, + dout_dims, + y_dims, + &dx_help, + false, + true); + if (dy) + MatMulFunction(dev_ctx, + x_conj, + out_grad, + x_dims, + dout_dims, + &dy_help, + true, + false); + } + } + + // get help dims + const std::vector dx_help_dims = vectorize(dx_help.dims()); + const std::vector dy_help_dims = vectorize(dy_help.dims()); + + std::vector dx_broadcast_dims(ndim); + std::vector dy_broadcast_dims(ndim); + + std::fill( + dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::fill( + dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(x_dims.data(), + x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + std::copy(y_dims.data(), + y_dims.data() + y_ndim, + dy_broadcast_dims.data() + ndim - y_ndim); + + std::vector dx_reduce_dims; + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { + dy_reduce_dims.push_back(idx); + } + } + // reduce sum to get grad by ReduceSum + if (dx) { + if (dx_reduce_dims.empty()) { + *dx = std::move(dx_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, dx_help, dx, dx_reduce_dims); + } + dx->Resize(x.dims()); + } + if (dy) { + if (dy_reduce_dims.empty()) { + *dy = std::move(dy_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, dy_help, dy, dy_reduce_dims); + } + dy->Resize(y.dims()); + } + // Get the OutputGrad(out) + } +} + +template +void MatmulDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + paddle::optional ddx, + paddle::optional ddy, + bool transpose_x, + bool transpose_y, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* ddout) { + // Get dims from the input x, y, output_grad + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + std::vector dout_dims = vectorize(dout.dims()); + + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + int ndim = dout_dims.size(); + + // Case1 : x's or y's dim = 1 + if (x_ndim == 1 && y_ndim == 1) { + DotDoubleGradFunction()( + dev_ctx, &x, &y, &dout, ddx.get_ptr(), ddy.get_ptr(), dx, dy, ddout); + return; + } + + DenseTensor x_conj; + DenseTensor y_conj; + DenseTensor dout_conj; + + bool is_broadcast = true; + if (x_ndim <= 2 || y_ndim <= 2) { + is_broadcast = false; + } else if (x_ndim != y_ndim) { + is_broadcast = true; + } else { + is_broadcast = !std::equal( + x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); + } + + if (!is_broadcast) { + // Case2: no broadcast or no batch size + DenseTensor x_help = x; + DenseTensor y_help = y; + DenseTensor dout_help = dout; + ReshapeXYOutIntoMatrixSequence( + &x_help, &y_help, &dout_help, transpose_x, transpose_y); + DDim dx_dims; + + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x_help.dims()) { + dx->Resize(x_help.dims()); + } + } + + DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y_help.dims()) { + dy->Resize(y_help.dims()); + } + } + + DDim ddout_dims; + if (ddout) { + ddout_dims = ddout->dims(); + if (ddout_dims != dout_help.dims()) { + ddout->Resize(dout_help.dims()); + } + + x_conj = Conj(dev_ctx, x_help); + y_conj = Conj(dev_ctx, y_help); + } + + if (dx || dy) { + dout_conj = Conj(dev_ctx, dout_help); + } + + bool ddout_flag = false; + if (ddx) { + auto ddx_mat = ddx.get(); + if (ddx_mat.dims() != x_help.dims()) { + ddx_mat.Resize(x_help.dims()); + } + if (dy) { + if (transpose_x && transpose_y) { + // dy = dout' * ddx' + CalcInputGrad( + dev_ctx, dout_conj, true, true, ddx_mat, true, false, dy, false); + } else if (transpose_x) { + // dy = ddx * dout + CalcInputGrad(dev_ctx, + ddx_mat, + false, + false, + dout_conj, + false, + true, + dy, + false); + } else if (transpose_y) { + // dy = dout' * ddx + CalcInputGrad( + dev_ctx, dout_conj, true, true, ddx_mat, false, true, dy, false); + } else { + // dy = ddx' * dout + CalcInputGrad( + dev_ctx, ddx_mat, true, true, dout_conj, false, true, dy, false); + } + } + + if (ddout) { + CalcInputGrad(dev_ctx, + ddx_mat, + transpose_x, + true, + y_conj, + transpose_y, + false, + ddout, + ddout_flag); + ddout_flag = true; + } + } + + if (ddy) { + auto ddy_mat = ddy.get(); + if (ddy_mat.dims() != y_help.dims()) { + ddy_mat.Resize(y_help.dims()); + } + if (dx) { + if (transpose_x && transpose_y) { + // dx = ddy' * dout' + CalcInputGrad( + dev_ctx, ddy_mat, true, true, dout_conj, true, false, dx, false); + } else if (transpose_x) { + // dx = ddy * dout' + CalcInputGrad(dev_ctx, + ddy_mat, + false, + false, + dout_conj, + true, + false, + dx, + false); + } else if (transpose_y) { + // dx = dout * ddy + CalcInputGrad(dev_ctx, + dout_conj, + false, + false, + ddy_mat, + false, + true, + dx, + false); + } else { + // dx = dout * ddy' + CalcInputGrad(dev_ctx, + dout_conj, + false, + false, + ddy_mat, + true, + false, + dx, + false); + } + } + + if (ddout) { + CalcInputGrad(dev_ctx, + x_conj, + transpose_x, + true, + ddy_mat, + transpose_y, + false, + ddout, + ddout_flag); + } + } + + if (dx) { + if (dx_dims != x_help.dims()) { + dx->Resize(dx_dims); + } + } + + if (dy) { + if (dy_dims != y_help.dims()) { + dy->Resize(dy_dims); + } + } + + if (ddout) { + if (ddout_dims != dout_help.dims()) { + ddout->Resize(ddout_dims); + } + } + } else { + // Case3: broadcast. It need cost much time to reduce sum for the + // broadcast and wastes the memory. + // So we should avoid the case in reality. + VLOG(3) << "It need cost much time to reduce sum for the broadcast and " + "wastes the memory. So we should avoid the case in reality"; + if (dx || dy) { + dout_conj = Conj(dev_ctx, dout); + } + if (ddout) { + x_conj = Conj(dev_ctx, x); + y_conj = Conj(dev_ctx, y); + } + + DenseTensor dx_help = Empty(dev_ctx); + DenseTensor dy_help = Empty(dev_ctx); + + if (transpose_x) { + if (transpose_y) { + if (dx) { + MatMulFunction(dev_ctx, + ddy.get(), + dout_conj, + y_dims, + dout_dims, + &dx_help, + true, + true); + } + if (dy) { + MatMulFunction(dev_ctx, + dout_conj, + ddx.get(), + dout_dims, + x_dims, + &dy_help, + true, + true); + } + } else { + if (dx) + MatMulFunction(dev_ctx, + ddy.get(), + dout_conj, + y_dims, + dout_dims, + &dx_help, + false, + true); + if (dy) + MatMulFunction(dev_ctx, + ddx.get(), + dout_conj, + x_dims, + dout_dims, + &dy_help, + false, + false); + } + } else { + if (transpose_y) { + if (dx) { + MatMulFunction(dev_ctx, + dout_conj, + ddy.get(), + dout_dims, + y_dims, + &dx_help, + false, + false); + } + if (dy) { + MatMulFunction(dev_ctx, + dout_conj, + ddx.get(), + dout_dims, + x_dims, + &dy_help, + true, + false); + } + } else { + if (dx) { + MatMulFunction(dev_ctx, + dout_conj, + ddy.get(), + dout_dims, + y_dims, + &dx_help, + false, + true); + } + if (dy) { + MatMulFunction(dev_ctx, + ddx.get(), + dout_conj, + x_dims, + dout_dims, + &dy_help, + true, + false); + } + } + } + + // get help dims + const std::vector dx_help_dims = vectorize(dx_help.dims()); + const std::vector dy_help_dims = vectorize(dy_help.dims()); + + std::vector dx_broadcast_dims(ndim); + std::vector dy_broadcast_dims(ndim); + + std::fill( + dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::fill( + dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(x_dims.data(), + x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + std::copy(y_dims.data(), + y_dims.data() + y_ndim, + dy_broadcast_dims.data() + ndim - y_ndim); + + std::vector dx_reduce_dims; + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { + dy_reduce_dims.push_back(idx); + } + } + // Reduce sum to get grad by ReduceSum + if (dx) { + if (dx_reduce_dims.empty()) { + *dx = std::move(dx_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, dx_help, dx, dx_reduce_dims); + } + dx->Resize(x.dims()); + } + if (dy) { + if (dy_reduce_dims.empty()) { + *dy = std::move(dy_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, dy_help, dy, dy_reduce_dims); + } + dy->Resize(y.dims()); + } + + if (ddout) { + // Calculate the gradient of OutputGrad(Out) + MatMulFunction(dev_ctx, + ddx.get(), + y_conj, + x_dims, + y_dims, + ddout, + transpose_x, + transpose_y); + MatMulFunction(dev_ctx, + x_conj, + ddy.get(), + x_dims, + y_dims, + ddout, + transpose_x, + transpose_y, + true); + } + } +} + +template +void MatmulTripleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + const DenseTensor& ddx, + const DenseTensor& ddy, + paddle::optional d_dx, + paddle::optional d_dy, + paddle::optional d_ddout, + bool transpose_x, + bool transpose_y, + DenseTensor* out_d_x, + DenseTensor* out_d_y, + DenseTensor* out_d_dout, + DenseTensor* out_d_ddx, + DenseTensor* out_d_ddy) { + // Get dims from the input x, y, output_grad + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + std::vector dout_dims = vectorize(dout.dims()); + + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + int ndim = dout_dims.size(); + + // Case1 : x's and y's dim = 1 + if (x_ndim == 1 && y_ndim == 1) { + VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 1"; + DotTripleGradFunction()(dev_ctx, + &x, + &y, + &ddx, + &ddy, + d_dx.get_ptr(), + d_dy.get_ptr(), + &dout, + d_ddout.get_ptr(), + out_d_x, + out_d_y, + out_d_dout, + out_d_ddx, + out_d_ddy); + return; + } + + DenseTensor x_conj; + DenseTensor y_conj; + DenseTensor dout_conj; + DenseTensor ddx_conj; + DenseTensor ddy_conj; + + bool is_broadcast = true; + if (x_ndim <= 2 || y_ndim <= 2) { + is_broadcast = false; + } else if (x_ndim != y_ndim) { + is_broadcast = true; + } else { + is_broadcast = !std::equal( + x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); + } + + if (!is_broadcast) { + // Case2: no broadcast or no batch size + VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 2"; + DenseTensor x_help = x; + DenseTensor y_help = y; + DenseTensor dout_help = dout; + DenseTensor ddx_help = ddx; + DenseTensor ddy_help = ddy; + ReshapeXYOutIntoMatrixSequence( + &x_help, &y_help, &dout_help, transpose_x, transpose_y); + + if (ddx_help.dims() != x_help.dims()) { + ddx_help.Resize(x_help.dims()); + } + + if (ddy_help.dims() != y_help.dims()) { + ddy_help.Resize(y_help.dims()); + } + + DDim out_dx_dims; + if (out_d_x) { + out_dx_dims = out_d_x->dims(); + if (out_dx_dims != x_help.dims()) { + out_d_x->Resize(x_help.dims()); + } + } + + DDim out_dy_dims; + if (out_d_y) { + out_dy_dims = out_d_y->dims(); + if (out_dy_dims != y_help.dims()) { + out_d_y->Resize(y_help.dims()); + } + } + + DDim out_d_dout_dims; + if (out_d_dout) { + out_d_dout_dims = out_d_dout->dims(); + if (out_d_dout_dims != dout_help.dims()) { + out_d_dout->Resize(dout_help.dims()); + } + + ddx_conj = Conj(dev_ctx, ddx_help); + ddy_conj = Conj(dev_ctx, ddy_help); + } + + DDim out_d_ddx_dims; + if (out_d_ddx) { + out_d_ddx_dims = out_d_ddx->dims(); + if (out_d_ddx_dims != x_help.dims()) { + out_d_ddx->Resize(x_help.dims()); + } + } + + DDim out_d_ddy_dims; + if (out_d_ddy) { + out_d_ddy_dims = out_d_ddy->dims(); + if (out_d_ddy_dims != y_help.dims()) { + out_d_ddy->Resize(y_help.dims()); + } + } + + if (out_d_ddx || out_d_ddy) { + x_conj = Conj(dev_ctx, x_help); + y_conj = Conj(dev_ctx, y_help); + dout_conj = Conj(dev_ctx, dout_help); + } + + bool d_dout_flag = false; + bool d_ddx_flag = false; + bool d_ddy_flag = false; + + if (d_ddout) { + auto d_ddout_mat = d_ddout.get(); + if (d_ddout_mat.dims() != dout_help.dims()) { + d_ddout_mat.Resize(dout_help.dims()); + } + + if (out_d_y) { + if (transpose_x && transpose_y) { + // out_d_y = d_ddout' * ddx' + CalcInputGrad(dev_ctx, + d_ddout_mat, + true, + true, + ddx_conj, + true, + false, + out_d_y, + false); + } else if (transpose_x) { + // out_d_y = ddx * d_ddout + CalcInputGrad(dev_ctx, + ddx_conj, + false, + false, + d_ddout_mat, + false, + true, + out_d_y, + false); + } else if (transpose_y) { + // out_d_y = d_ddout' * ddx + CalcInputGrad(dev_ctx, + d_ddout_mat, + true, + true, + ddx_conj, + false, + true, + out_d_y, + false); + } else { + // out_d_y = ddx' * d_ddout + CalcInputGrad(dev_ctx, + ddx_conj, + true, + true, + d_ddout_mat, + false, + true, + out_d_y, + false); + } + } + if (out_d_x) { + if (transpose_x && transpose_y) { + // out_d_x = ddy' * d_ddout' + CalcInputGrad(dev_ctx, + ddy_conj, + true, + true, + d_ddout_mat, + true, + false, + out_d_x, + false); + } else if (transpose_x) { + // out_d_x = ddy * d_ddout' + CalcInputGrad(dev_ctx, + ddy_conj, + false, + false, + d_ddout_mat, + true, + false, + out_d_x, + false); + } else if (transpose_y) { + // out_d_x = d_ddout * ddy + CalcInputGrad(dev_ctx, + d_ddout_mat, + false, + false, + ddy_conj, + false, + true, + out_d_x, + false); + } else { + // out_d_x = d_ddout * ddy' + CalcInputGrad(dev_ctx, + d_ddout_mat, + false, + false, + ddy_conj, + true, + false, + out_d_x, + false); + } + } + + // equations: + // d_ddx = DOut * D_DY + Y * D_DDOut + // Let: d_ddx1 = Y * D_DDOut + // Let: d_ddx2 = DOut * D_DY + + // d_ddy = DOut * D_DX + X * D_DDOut + // Let: d_ddy1 = X * D_DDOut + // Let: d_ddy2 = DOut * D_DX + + // d_dout = DDY * D_DX + DDX * D_DY + // Let: d_dout1 = DDX * D_DY + // Let: d_dout2 = DDY * D_DX + + // compute d_ddx1 + if (out_d_ddx) { + if (transpose_x && transpose_y) { + // out_d_ddx1 = y' * d_ddout' + CalcInputGrad(dev_ctx, + y_conj, + true, + true, + d_ddout_mat, + true, + false, + out_d_ddx, + d_ddx_flag); + } else if (transpose_x) { + // out_d_ddx1 = y * d_ddout' + CalcInputGrad(dev_ctx, + y_conj, + false, + false, + d_ddout_mat, + true, + false, + out_d_ddx, + d_ddx_flag); + } else if (transpose_y) { + // out_d_ddx1 = d_ddout * y + CalcInputGrad(dev_ctx, + d_ddout_mat, + false, + false, + y_conj, + false, + true, + out_d_ddx, + d_ddx_flag); + } else { + // out_d_ddx1 = d_ddout * y' + CalcInputGrad(dev_ctx, + d_ddout_mat, + false, + false, + y_conj, + true, + false, + out_d_ddx, + d_ddx_flag); + } + d_ddx_flag = true; + } + + // compute d_ddy1 + if (out_d_ddy) { + if (transpose_x && transpose_y) { + // out_d_ddy1 = d_ddout' * x' + CalcInputGrad(dev_ctx, + d_ddout_mat, + true, + true, + x_conj, + true, + false, + out_d_ddy, + false); + } else if (transpose_x) { + // out_d_ddy1 = x * d_ddout + CalcInputGrad(dev_ctx, + x_conj, + false, + false, + d_ddout_mat, + false, + true, + out_d_ddy, + false); + } else if (transpose_y) { + // out_d_ddy1 = d_ddout' * x + CalcInputGrad(dev_ctx, + d_ddout_mat, + true, + true, + x_conj, + false, + true, + out_d_ddy, + false); + } else { + // out_d_ddy1 = x' * d_ddout + CalcInputGrad(dev_ctx, + x_conj, + true, + true, + d_ddout_mat, + false, + true, + out_d_ddy, + false); + } + d_ddy_flag = true; + } + } + + if (d_dy) { + auto d_dy_mat = d_dy.get(); + if (d_dy_mat.dims() != y_help.dims()) { + d_dy_mat.Resize(y_help.dims()); + } + + // compute d_dout1 + if (out_d_dout) { + CalcInputGrad(dev_ctx, + ddx_conj, + transpose_x, + true, + d_dy_mat, + transpose_y, + false, + out_d_dout, + d_dout_flag); + d_dout_flag = true; + } + + // compute d_ddx2 + if (out_d_ddx) { + if (transpose_x && transpose_y) { + // out_d_ddx2 = D_DY' * DOut' + CalcInputGrad(dev_ctx, + d_dy_mat, + true, + true, + dout_conj, + true, + false, + out_d_ddx, + d_ddx_flag); + } else if (transpose_x) { + // out_d_ddx2 = D_DY * Dout' + CalcInputGrad(dev_ctx, + d_dy_mat, + false, + false, + dout_conj, + true, + false, + out_d_ddx, + d_ddx_flag); + } else if (transpose_y) { + // out_d_ddx2 = Dout * D_DY + CalcInputGrad(dev_ctx, + dout_conj, + false, + false, + d_dy_mat, + false, + true, + out_d_ddx, + d_ddx_flag); + } else { + // out_d_ddx2 = Dout * D_DY' + CalcInputGrad(dev_ctx, + dout_conj, + false, + false, + d_dy_mat, + true, + false, + out_d_ddx, + d_ddx_flag); + } + } + } + + if (d_dx) { + auto d_dx_mat = d_dx.get(); + if (d_dx_mat.dims() != x_help.dims()) { + d_dx_mat.Resize(x_help.dims()); + } + + // compute d_dout2 + if (out_d_dout) { + CalcInputGrad(dev_ctx, + d_dx_mat, + transpose_x, + true, + ddy_conj, + transpose_y, + false, + out_d_dout, + d_dout_flag); + } + + // compute d_ddy2 + if (out_d_ddy) { + if (transpose_x && transpose_y) { + // out_d_ddy2 = dout' * d_dx' + CalcInputGrad(dev_ctx, + dout_conj, + true, + true, + d_dx_mat, + true, + false, + out_d_ddy, + d_ddy_flag); + } else if (transpose_x) { + // out_d_ddy2 = d_dx * dout + CalcInputGrad(dev_ctx, + d_dx_mat, + false, + false, + dout_conj, + false, + true, + out_d_ddy, + d_ddy_flag); + } else if (transpose_y) { + // out_d_ddy2 = dout' * d_dx + CalcInputGrad(dev_ctx, + dout_conj, + true, + true, + d_dx_mat, + false, + true, + out_d_ddy, + d_ddy_flag); + } else { + // out_d_ddy2 = d_dx' * dout + CalcInputGrad(dev_ctx, + d_dx_mat, + true, + true, + dout_conj, + false, + true, + out_d_ddy, + d_ddy_flag); + } + } + } + + if (out_d_x) { + if (out_dx_dims != x_help.dims()) { + out_d_x->Resize(out_dx_dims); + } + } + + if (out_d_y) { + if (out_dy_dims != y_help.dims()) { + out_d_y->Resize(out_dy_dims); + } + } + + if (out_d_dout) { + if (out_d_dout_dims != dout_help.dims()) { + out_d_dout->Resize(out_d_dout_dims); + } + } + + if (out_d_ddx) { + if (out_d_ddx_dims != x_help.dims()) { + out_d_ddx->Resize(out_d_ddx_dims); + } + } + + if (out_d_ddy) { + if (out_d_ddy_dims != y_help.dims()) { + out_d_ddy->Resize(out_d_ddy_dims); + } + } + } else { + // Case3: broadcast. It need cost much time to reduce sum for the + // broadcast and wastes the memory. + // So we should avoid the case in reality. + VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 3"; + VLOG(3) << "It need cost much time to reduce sum for the broadcast and " + "wastes the memory. So we should avoid the case in reality"; + + DenseTensor out_dx_help = Empty(dev_ctx); + DenseTensor out_dy_help = Empty(dev_ctx); + DenseTensor out_d_ddx_help = Empty(dev_ctx); + DenseTensor out_d_ddy_help = Empty(dev_ctx); + + if (out_d_dout) { + ddx_conj = Conj(dev_ctx, ddx); + ddy_conj = Conj(dev_ctx, ddy); + } + if (out_d_ddx || out_d_ddy) { + x_conj = Conj(dev_ctx, x); + y_conj = Conj(dev_ctx, y); + dout_conj = Conj(dev_ctx, dout); + } + + if (transpose_x) { + if (transpose_y) { + // dX = ddY' d_ddout’, dY = d_ddout’ ddX' + if (out_d_x) + MatMulFunction(dev_ctx, + ddy_conj, + d_ddout.get(), + y_dims, + dout_dims, + &out_dx_help, + true, + true); + if (out_d_y) + MatMulFunction(dev_ctx, + d_ddout.get(), + ddx_conj, + dout_dims, + x_dims, + &out_dy_help, + true, + true); + } else { + // dX = ddY d_ddout', dY = ddX d_ddout + if (out_d_x) + MatMulFunction(dev_ctx, + ddy_conj, + d_ddout.get(), + y_dims, + dout_dims, + &out_dx_help, + false, + true); + if (out_d_y) + MatMulFunction(dev_ctx, + ddx_conj, + d_ddout.get(), + x_dims, + dout_dims, + &out_dy_help, + false, + false); + } + } else { + if (transpose_y) { + // dX = d_ddout ddY, dY = d_ddout’ ddX + if (out_d_x) + MatMulFunction(dev_ctx, + d_ddout.get(), + ddy_conj, + dout_dims, + y_dims, + &out_dx_help, + false, + false); + if (out_d_y) + MatMulFunction(dev_ctx, + d_ddout.get(), + ddx_conj, + dout_dims, + x_dims, + &out_dy_help, + true, + false); + } else { + // dX = d_ddout ddY', dY = ddX' d_ddout + if (out_d_x) + MatMulFunction(dev_ctx, + d_ddout.get(), + ddy_conj, + dout_dims, + y_dims, + &out_dx_help, + false, + true); + if (out_d_y) + MatMulFunction(dev_ctx, + ddx_conj, + d_ddout.get(), + x_dims, + dout_dims, + &out_dy_help, + true, + false); + } + } + + // get help dims + const std::vector dx_help_dims = + vectorize(out_dx_help.dims()); + const std::vector dy_help_dims = + vectorize(out_dx_help.dims()); + + std::vector dx_broadcast_dims(ndim); + std::vector dy_broadcast_dims(ndim); + + std::fill( + dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::fill( + dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(x_dims.data(), + x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + std::copy(y_dims.data(), + y_dims.data() + y_ndim, + dy_broadcast_dims.data() + ndim - y_ndim); + + std::vector dx_reduce_dims; + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { + dy_reduce_dims.push_back(idx); + } + } + // Reduce sum to get grad by ReduceSum + if (out_d_x) { + if (dx_reduce_dims.empty()) { + *out_d_x = std::move(out_dx_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, out_dx_help, out_d_x, dx_reduce_dims); + } + out_d_x->Resize(x.dims()); + } + + if (out_d_y) { + if (dy_reduce_dims.empty()) { + *out_d_y = std::move(out_dy_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, out_dy_help, out_d_y, dy_reduce_dims); + } + out_d_y->Resize(y.dims()); + } + + // compute d_dout + if (out_d_dout) { + MatMulFunction(dev_ctx, + d_dx.get(), + ddy_conj, + x_dims, + y_dims, + out_d_dout, + transpose_x, + transpose_y); + MatMulFunction(dev_ctx, + ddx_conj, + d_dy.get(), + x_dims, + y_dims, + out_d_dout, + transpose_x, + transpose_y, + true); + } + // compute d_ddx + if (out_d_ddx) { + if (transpose_x && transpose_y) { + // out_d_ddx1 = y' * d_ddout' + MatMulFunction(dev_ctx, + y_conj, + d_ddout.get(), + y_dims, + dout_dims, + &out_d_ddx_help, + true, + true); + // out_d_ddx2 = D_DY' * DOut' + MatMulFunction(dev_ctx, + d_dy.get(), + dout_conj, + y_dims, + dout_dims, + &out_d_ddx_help, + true, + true, + true); + } else if (transpose_x) { + // out_d_ddx1 = y * d_ddout' + MatMulFunction(dev_ctx, + y_conj, + d_ddout.get(), + y_dims, + dout_dims, + &out_d_ddx_help, + false, + true); + // out_d_ddx2 = D_DY * Dout' + MatMulFunction(dev_ctx, + d_dy.get(), + dout_conj, + y_dims, + dout_dims, + &out_d_ddx_help, + false, + true, + true); + } else if (transpose_y) { + // out_d_ddx1 = d_ddout * y + MatMulFunction(dev_ctx, + d_ddout.get(), + y_conj, + dout_dims, + y_dims, + &out_d_ddx_help, + false, + false); + // out_d_ddx2 = Dout * D_DY + MatMulFunction(dev_ctx, + dout_conj, + d_dy.get(), + dout_dims, + y_dims, + &out_d_ddx_help, + false, + false, + true); + } else { + // out_d_ddx1 = d_ddout * y' + MatMulFunction(dev_ctx, + d_ddout.get(), + y_conj, + dout_dims, + y_dims, + &out_d_ddx_help, + false, + true); + // out_d_ddx2 = Dout * D_DY' + MatMulFunction(dev_ctx, + dout_conj, + d_dy.get(), + dout_dims, + y_dims, + &out_d_ddx_help, + false, + true, + true); + } + if (dx_reduce_dims.empty()) { + *out_d_ddx = std::move(out_d_ddx_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, out_d_ddx_help, out_d_ddx, dx_reduce_dims); + } + out_d_ddx->Resize(x.dims()); + } + + // compute d_ddy + if (out_d_ddy) { + if (transpose_x && transpose_y) { + // out_d_ddy1 = d_ddout' * x' + MatMulFunction(dev_ctx, + d_ddout.get(), + x_conj, + dout_dims, + x_dims, + &out_d_ddy_help, + true, + true); + // out_d_ddy2 = dout' * d_dx' + MatMulFunction(dev_ctx, + dout_conj, + d_dx.get(), + dout_dims, + x_dims, + &out_d_ddy_help, + true, + true, + true); + } else if (transpose_x) { + // out_d_ddy1 = x * d_ddout + MatMulFunction(dev_ctx, + x_conj, + d_ddout.get(), + x_dims, + dout_dims, + &out_d_ddy_help, + false, + false); + // out_d_ddy2 = d_dx * dout + MatMulFunction(dev_ctx, + d_dx.get(), + dout_conj, + x_dims, + dout_dims, + &out_d_ddy_help, + false, + false, + true); + } else if (transpose_y) { + // out_d_ddy1 = d_ddout' * x + MatMulFunction(dev_ctx, + d_ddout.get(), + x_conj, + dout_dims, + x_dims, + &out_d_ddy_help, + true, + false); + // out_d_ddy2 = dout' * d_dx + MatMulFunction(dev_ctx, + dout_conj, + d_dx.get(), + dout_dims, + x_dims, + &out_d_ddy_help, + true, + false, + true); + } else { + // out_d_ddy1 = x' * d_ddout + MatMulFunction(dev_ctx, + x_conj, + d_ddout.get(), + x_dims, + dout_dims, + &out_d_ddy_help, + true, + false); + // out_d_ddy2 = d_dx' * dout + MatMulFunction(dev_ctx, + d_dx.get(), + dout_conj, + x_dims, + dout_dims, + &out_d_ddy_help, + true, + false, + true); + } + + if (dy_reduce_dims.empty()) { + *out_d_ddy = std::move(out_d_ddy_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, out_d_ddy_help, out_d_ddy, dy_reduce_dims); + } + out_d_ddy->Resize(y.dims()); + } + } +} + +} // namespace pten diff --git a/paddle/pten/kernels/impl/matmul_kernel_impl.h b/paddle/pten/kernels/impl/matmul_kernel_impl.h index e50b2f0641a46..f5f69f327a69f 100644 --- a/paddle/pten/kernels/impl/matmul_kernel_impl.h +++ b/paddle/pten/kernels/impl/matmul_kernel_impl.h @@ -86,7 +86,7 @@ static void IndexIncreaseFromDims(const int ndim, } template -void MatMulFunction(const Context& context, +void MatMulFunction(const Context& dev_ctx, const DenseTensor& X, const DenseTensor& Y, const std::vector& x_dims, @@ -102,7 +102,7 @@ void MatMulFunction(const Context& context, const T* x_data = X.data(); const T* y_data = Y.data(); - auto blas = paddle::operators::math::GetBlas(context); + auto blas = paddle::operators::math::GetBlas(dev_ctx); if (x_ndim == 1 && y_ndim == 1) { const int M = X.numel(); @@ -117,6 +117,8 @@ void MatMulFunction(const Context& context, M, N)); VLOG(3) << "MatMul's case 1"; + Out->Resize({1}); + Out->mutable_data(); blas.GEMM(CblasNoTrans, CblasTrans, 1, @@ -471,7 +473,7 @@ void MatMulFunction(const Context& context, } template -void MatMulFunction(const Context& context, +void MatMulFunction(const Context& dev_ctx, const DenseTensor& X, const DenseTensor& Y, DenseTensor* Out, @@ -481,11 +483,11 @@ void MatMulFunction(const Context& context, const std::vector x_dims = vectorize(X.dims()); const std::vector y_dims = vectorize(Y.dims()); MatMulFunction( - context, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag); + dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag); } template -void MatmulKernel(const Context& context, +void MatmulKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, bool transpose_x, @@ -501,7 +503,7 @@ void MatmulKernel(const Context& context, paddle::platform::errors::InvalidArgument( "The Input(Y) dims size must not be equal 0," " but reviced dims size is 0. ")); - MatMulFunction(context, x, y, out, transpose_x, transpose_y); + MatMulFunction(dev_ctx, x, y, out, transpose_x, transpose_y); } } // namespace pten diff --git a/paddle/pten/kernels/matmul_grad_kernel.h b/paddle/pten/kernels/matmul_grad_kernel.h new file mode 100644 index 0000000000000..db485b79d2736 --- /dev/null +++ b/paddle/pten/kernels/matmul_grad_kernel.h @@ -0,0 +1,63 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace pten { + +template +void MatmulGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + bool transpose_x, + bool transpose_y, + DenseTensor* dx, + DenseTensor* dy); + +template +void MatmulDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + paddle::optional ddx, + paddle::optional ddy, + bool transpose_x, + bool transpose_y, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* ddout); + +template +void MatmulTripleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + const DenseTensor& ddx, + const DenseTensor& ddy, + paddle::optional d_dx, + paddle::optional d_dy, + paddle::optional d_ddout, + bool transpose_x, + bool transpose_y, + DenseTensor* out_d_x, + DenseTensor* out_d_y, + DenseTensor* out_d_dout, + DenseTensor* out_d_ddx, + DenseTensor* out_d_ddy); + +} // namespace pten diff --git a/paddle/pten/kernels/matmul_kernel.h b/paddle/pten/kernels/matmul_kernel.h index fb54a5301e61c..f9cb2c3801caa 100644 --- a/paddle/pten/kernels/matmul_kernel.h +++ b/paddle/pten/kernels/matmul_kernel.h @@ -14,14 +14,15 @@ #pragma once -#include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/infermeta/binary.h" +#include "paddle/pten/kernels/empty_kernel.h" + namespace pten { template -void MatmulKernel(const Context& context, +void MatmulKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, bool transpose_x, @@ -29,17 +30,14 @@ void MatmulKernel(const Context& context, DenseTensor* out); template -DenseTensor Matmul(const Context& context, +DenseTensor Matmul(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, bool transpose_x, bool transpose_y) { auto out_meta = MatmulInferMeta(x.meta(), y.meta(), transpose_x, transpose_y); - DenseTensor dense_out( - pten::make_intrusive( - context.GetPlace()), - std::move(out_meta)); - MatmulKernel(context, x, y, transpose_x, transpose_y, &dense_out); + auto dense_out = Empty(dev_ctx, std::move(out_meta)); + MatmulKernel(dev_ctx, x, y, transpose_x, transpose_y, &dense_out); return dense_out; } From 8cc09552473b842c651ead3b9848d41827a3dbab Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Tue, 11 Jan 2022 20:58:24 +0800 Subject: [PATCH 15/15] refactor reshape grad kernel (#38833) --- paddle/fluid/operators/reshape_op.cc | 64 ++++++++++++++---- paddle/pten/core/kernel_alias_name.h | 3 + paddle/pten/kernels/reshape_grad_kernel.cc | 75 ++++++++++++++++++++++ paddle/pten/kernels/reshape_grad_kernel.h | 31 +++++++++ 4 files changed, 161 insertions(+), 12 deletions(-) create mode 100644 paddle/pten/kernels/reshape_grad_kernel.cc create mode 100644 paddle/pten/kernels/reshape_grad_kernel.h diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index f2162f55636e5..a25e53aac5d73 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/common/scalar_array.h" #include "paddle/pten/include/core.h" +#include "paddle/pten/kernels/reshape_grad_kernel.h" #include "paddle/pten/kernels/reshape_kernel.h" namespace paddle { namespace framework { @@ -467,13 +468,27 @@ class ReshapeGradKernel { void operator()(const framework::ExecutionContext &ctx) const { auto *d_out = ctx.Input(framework::GradVarName("Out")); auto *d_x = ctx.Output(framework::GradVarName("X")); - auto in_dims = d_x->dims(); - d_x->mutable_data(ctx.GetPlace(), d_out->type()); - framework::TensorCopy( - *d_out, ctx.GetPlace(), - ctx.template device_context(), d_x); - d_x->Resize(in_dims); + + auto pt_d_x = paddle::experimental::MakePtenDenseTensor(*d_x); + auto pt_d_out = paddle::experimental::MakePtenDenseTensor(*d_out); + + if (platform::is_cpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeGradKernel(dev_ctx, *pt_d_out.get(), pt_d_x.get()); + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (platform::is_gpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeGradKernel(dev_ctx, *pt_d_out.get(), pt_d_x.get()); + } +#endif +#ifdef PADDLE_WITH_XPU + if (platform::is_xpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeGradKernel(dev_ctx, *pt_d_out.get(), pt_d_x.get()); + } +#endif } }; @@ -482,14 +497,27 @@ class ReshapeDoubleGradKernel { void operator()(const framework::ExecutionContext &ctx) const { auto *dd_x = ctx.Input("DDX"); auto *dd_out = ctx.Output("DDOut"); + dd_out->mutable_data(ctx.GetPlace(), dd_x->type()); - auto out_dims = dd_out->dims(); + auto pt_dd_x = paddle::experimental::MakePtenDenseTensor(*dd_x); + auto pt_dd_out = paddle::experimental::MakePtenDenseTensor(*dd_out); - dd_out->mutable_data(ctx.GetPlace(), dd_x->type()); - framework::TensorCopy( - *dd_x, ctx.GetPlace(), - ctx.template device_context(), dd_out); - dd_out->Resize(out_dims); + if (platform::is_cpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeDoubleGradKernel(dev_ctx, *pt_dd_x.get(), pt_dd_out.get()); + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (platform::is_gpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeDoubleGradKernel(dev_ctx, *pt_dd_x.get(), pt_dd_out.get()); + } +#endif +#ifdef PADDLE_WITH_XPU + if (platform::is_xpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeDoubleGradKernel(dev_ctx, *pt_dd_x.get(), pt_dd_out.get()); + } +#endif } }; @@ -624,6 +652,13 @@ class Reshape2GradOp : public framework::OperatorWithKernel { return framework::OpKernelType(expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext &ctx) const override { + return framework::KernelSignature("reshape_grad", + {framework::GradVarName("Out")}, {}, + {framework::GradVarName("X")}); + } }; class Reshape2DoubleGradOp : public framework::OperatorWithKernel { @@ -660,6 +695,11 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel { return framework::OpKernelType(expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext &ctx) const override { + return framework::KernelSignature("reshape_double_grad", {"DDX"}, {}, + {"DDOut"}); + } }; DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"}); diff --git a/paddle/pten/core/kernel_alias_name.h b/paddle/pten/core/kernel_alias_name.h index 46fa6dd376ee3..5c86787966368 100644 --- a/paddle/pten/core/kernel_alias_name.h +++ b/paddle/pten/core/kernel_alias_name.h @@ -35,6 +35,8 @@ const std::unordered_map kernel_alias_name_map = { {"reduce_mean", "mean"}, {"reduce_sum", "sum"}, {"reshape2", "reshape"}, + {"reshape2_grad", "reshape_grad"}, + {"reshape2_grad_grad", "reshape_double_grad"}, // fluid kernel "mean/reshape/matmul/flatten/sum" should be deprecated {"flatten", "deprecated"}, {"flatten_grad", "deprecated"}, @@ -43,6 +45,7 @@ const std::unordered_map kernel_alias_name_map = { {"matmul_grad_grad", "deprecated"}, {"mean", "deprecated"}, {"reshape", "deprecated"}, + {"reshape_grad", "deprecated"}, {"sum", "deprecated"}}; } // namespace pten diff --git a/paddle/pten/kernels/reshape_grad_kernel.cc b/paddle/pten/kernels/reshape_grad_kernel.cc new file mode 100644 index 0000000000000..99f0556765ef6 --- /dev/null +++ b/paddle/pten/kernels/reshape_grad_kernel.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/reshape_grad_kernel.h" +#include "paddle/pten/backends/all_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/copy_kernel.h" + +namespace pten { + +template +void ReshapeGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + auto x_dims = x_grad->dims(); + pten::Copy(dev_ctx, out_grad, false, x_grad); + x_grad->Resize(x_dims); +} + +template +void ReshapeDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x_grad_grad, + DenseTensor* out_grad_grad) { + ReshapeGradKernel(dev_ctx, x_grad_grad, out_grad_grad); +} + +} // namespace pten + +PT_REGISTER_GENERAL_KERNEL(reshape_grad, + CPU, + ALL_LAYOUT, + pten::ReshapeGradKernel, + ALL_DTYPE) {} +PT_REGISTER_GENERAL_KERNEL(reshape_double_grad, + CPU, + ALL_LAYOUT, + pten::ReshapeDoubleGradKernel, + ALL_DTYPE) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PT_REGISTER_GENERAL_KERNEL(reshape_grad, + GPU, + ALL_LAYOUT, + pten::ReshapeGradKernel, + ALL_DTYPE) {} +PT_REGISTER_GENERAL_KERNEL(reshape_double_grad, + GPU, + ALL_LAYOUT, + pten::ReshapeDoubleGradKernel, + ALL_DTYPE) {} +#endif + +#ifdef PADDLE_WITH_XPU +PT_REGISTER_GENERAL_KERNEL(reshape_grad, + XPU, + ALL_LAYOUT, + pten::ReshapeGradKernel, + ALL_DTYPE) {} +PT_REGISTER_GENERAL_KERNEL(reshape_double_grad, + XPU, + ALL_LAYOUT, + pten::ReshapeDoubleGradKernel, + ALL_DTYPE) {} +#endif diff --git a/paddle/pten/kernels/reshape_grad_kernel.h b/paddle/pten/kernels/reshape_grad_kernel.h new file mode 100644 index 0000000000000..1492d753704fd --- /dev/null +++ b/paddle/pten/kernels/reshape_grad_kernel.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void ReshapeGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + DenseTensor* x_grad); + +template +void ReshapeDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x_grad_grad, + DenseTensor* out_grad_grad); + +} // namespace pten