From 9755782bc35647d719ab06958635dd699e4922dd Mon Sep 17 00:00:00 2001 From: liangjianzhong Date: Tue, 14 Nov 2023 17:14:58 +0800 Subject: [PATCH 1/3] add reducescatter for pir yaml --- paddle/fluid/pir/dialect/operator/ir/ops.yaml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 036e4818da2bd..e9053d5526344 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -38,6 +38,16 @@ backend: place> data_type : dtype +- op : c_reducescatter + args : (Tensor x, int ring_id = 0, int nranks = 1, bool use_calc_stream = false) + output : Tensor(out) + infer_meta : + func : ReduceScatterInferMeta + param: [x, nranks] + kernel : + func : reduce_scatter + param: [x, nranks] + - op : embedding_grad_sparse args : (Tensor x, Tensor weight, Tensor out_grad, int64_t padding_idx = -1, bool sparse = false) output : SelectedRows(weight_grad) From 8e3c84c3eff37967a67bb04dca746a6d56226be5 Mon Sep 17 00:00:00 2001 From: liangjianzhong Date: Tue, 14 Nov 2023 20:25:47 +0800 Subject: [PATCH 2/3] adopt for static op --- paddle/fluid/pir/dialect/op_generator/ops_api_gen.py | 1 + paddle/fluid/pir/dialect/operator/utils/utils.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index bf02ad8626c71..096e89b3b7929 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -103,6 +103,7 @@ 'c_embedding', 'c_identity', 'c_reduce_sum', + 'c_reducescatter', 'dpsgd', 'embedding_grad_sparse', 'fused_batch_norm_act_', diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 8961c70569c8b..21d5be31090a9 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -36,6 +36,7 @@ const std::unordered_set LegacyOpList = { "pd_op.c_reduce_sum_", "pd_op.c_allreduce_max_", "pd_op.c_allgather", + "pd_op.c_reducescatter", "pd_op.seed", "pd_op.share_data", "pd_op.sparse_momentum"}; From 35e6487e44c8b85b977535b1c1be43a7a506f45a Mon Sep 17 00:00:00 2001 From: liangjianzhong Date: Wed, 15 Nov 2023 14:32:52 +0800 Subject: [PATCH 3/3] remove list --- paddle/fluid/pir/dialect/operator/utils/utils.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 21d5be31090a9..8961c70569c8b 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -36,7 +36,6 @@ const std::unordered_set LegacyOpList = { "pd_op.c_reduce_sum_", "pd_op.c_allreduce_max_", "pd_op.c_allgather", - "pd_op.c_reducescatter", "pd_op.seed", "pd_op.share_data", "pd_op.sparse_momentum"};