From 6f194f84a0ff3647a25539e9de0f59943a705d86 Mon Sep 17 00:00:00 2001 From: ronnywang Date: Tue, 5 Sep 2023 12:44:54 +0800 Subject: [PATCH] [CustomDevice] add mp_allreduce_sum op for all custom devices (#56927) --- .../custom_device_common_op_registry.cc | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/paddle/fluid/operators/custom_device_common_op_registry.cc b/paddle/fluid/operators/custom_device_common_op_registry.cc index df4d65af1d03f0..a1106b1386757b 100644 --- a/paddle/fluid/operators/custom_device_common_op_registry.cc +++ b/paddle/fluid/operators/custom_device_common_op_registry.cc @@ -535,6 +535,23 @@ template class CAllReduceOpCustomDeviceKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + if (ctx.HasInput("Cond")) { + auto cond = ctx.Input("Cond"); + auto place = cond->place(); + PADDLE_ENFORCE_EQ(platform::is_cpu_place(place), + true, + platform::errors::PreconditionNotMet( + "The input `cond` tensor should be on cpu place")); + PADDLE_ENFORCE_EQ(cond->numel(), + 1, + platform::errors::PreconditionNotMet( + "The input `cond` should be shape [1]")); + if (!cond->data()[0]) { + VLOG(4) << "Skip all reduce Op since cond is 0"; + return; + } + } + auto in = ctx.Input("X"); auto out = ctx.Output("Out"); int rid = ctx.Attr("ring_id"); @@ -1441,6 +1458,29 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { paddle::platform::CustomDeviceContext, int64_t, phi::ccl::CCLReduceOp::SUM>) {} + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + mp_allreduce_sum, + device_type, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + float, + phi::ccl::CCLReduceOp::SUM>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + double, + phi::ccl::CCLReduceOp::SUM>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + paddle::platform::float16, + phi::ccl::CCLReduceOp::SUM>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + int32_t, + phi::ccl::CCLReduceOp::SUM>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + int64_t, + phi::ccl::CCLReduceOp::SUM>) {} REGISTER_OP_CUSTOM_DEVICE_KERNEL( c_allreduce_min, device_type,