[NewIR]support c_reduce_sum/c_allgather/c_allreduce_max/seed (#57185)
* support c_reduce_sum/c_allgather/c_allreduce_max/seed

* fix dropout_grad attr

* rename pd to pd_op

* fix _set_attr

* update unittest and control random

* hack for dropout in dist_context

* set env flag for profiler
zhaoyinglia authored Sep 14, 2023
1 parent 5934505 commit 7eea08b
Showing 17 changed files with 168 additions and 32 deletions.
4 changes: 4 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -91,6 +91,10 @@
'c_allreduce_sum',
'c_embedding',
'c_identity',
'c_reduce_sum',
'c_allreduce_max',
'c_allgather',
'seed',
]


9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -106,6 +106,15 @@
param: [x, file_path, overwrite, save_as_fp16, save_to_memory]
optional : out

- op : seed
args : (int seed, bool deterministic, str rng_name, bool force_cpu)
output : Tensor(out)
infer_meta:
func: SeedInferMeta
param: [seed]
kernel:
func: seed

- op : send_v2
args : (Tensor x, int ring_id = 0, int peer = 0, bool use_calc_stream = false, bool dynamic_shape = false)
output :
7 changes: 6 additions & 1 deletion paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -29,7 +29,12 @@ const std::unordered_set<std::string> LegacyOpList = {
"pd_op.send_v2",
"pd_op.recv_v2",
"pd_op.c_allreduce_sum",
"pd_op.c_allreduce_sum_"};
"pd_op.c_allreduce_sum_",
"pd_op.c_reduce_sum",
"pd_op.c_reduce_sum_",
"pd_op.c_allreduce_max_",
"pd_op.c_allgather",
"pd_op.seed"};

enum class AttrType {
UNDEFINED = 0,
8 changes: 8 additions & 0 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -493,6 +493,14 @@ phi::KernelKey GetKernelKey(
op->result(0).type().dyn_cast<DenseTensorType>().dtype())};
}

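// pd_op.seed has no tensor inputs, so the kernel backend is parsed from the
// target place and the kernel dtype is taken from the op's single output.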
if (op->name() == "pd_op.seed") {
auto backend = paddle::experimental::ParseBackend(place);
return {backend,
phi::DataLayout::ANY,
TransToPhiDataType(
op->result(0).type().dyn_cast<DenseTensorType>().dtype())};
}

phi::Backend kernel_backend = phi::Backend::UNDEFINED;
phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED;
phi::DataType kernel_data_type = phi::DataType::UNDEFINED;
29 changes: 29 additions & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -123,6 +123,25 @@
backward : batch_norm_grad
optional : reserve_space

- op : c_allgather
args : (Tensor x, int ring_id, int nranks, bool use_calc_stream)
output : Tensor(out)
infer_meta :
func : AllGatherInferMeta
param: [x, nranks]
kernel :
func : c_allgather

- op : c_allreduce_max
args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel)
output : Tensor(out)
infer_meta :
func : AllReduceInferMeta
param : [x]
kernel :
func : c_allreduce_max
inplace : (x -> out)
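  # note: the inplace (x -> out) spec corresponds to the underscored name
  # pd_op.c_allreduce_max_ added to LegacyOpList in utils.cc above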

- op : c_allreduce_sum
args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel)
output : Tensor(out)
@@ -173,6 +192,16 @@
func : c_identity
inplace : (x -> out)

- op : c_reduce_sum
args : (Tensor x, int ring_id, int root_id, bool use_calc_stream)
output : Tensor(out)
infer_meta :
func : DistReduceInferMeta
param : [x]
kernel :
func : c_reduce_sum
inplace : (x -> out)

- op : c_sync_calc_stream
args : (Tensor x)
output : Tensor(out)
21 changes: 21 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -810,6 +810,7 @@
backward : dropout_grad
inputs :
x : X
seed_tensor : Seed
outputs :
out : Out
mask : Mask
@@ -2428,6 +2429,8 @@
out : Out

- op : seed
outputs :
out : Out
extra :
attrs : [bool deterministic = false, str rng_name = "", bool force_cpu = false]

@@ -3048,6 +3051,18 @@
yolo_loss : GetYoloLossExpectedKernelType
yolo_loss_grad : GetYoloLossExpectedKernelType

- op: c_allgather
inputs :
x : X
outputs :
out: Out

- op: c_allreduce_max
inputs :
x : X
outputs :
out: Out

- op: c_allreduce_sum
inputs :
x : X
@@ -3066,6 +3081,12 @@
outputs :
out: Out

- op: c_reduce_sum
inputs :
x : X
outputs :
out: Out

- op: c_sync_calc_stream
inputs :
x : X
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/nullary.cc
@@ -223,6 +223,11 @@ void RecvV2InferMeta(const int ring_id,
out->set_dtype(dtype);
}

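// SeedInferMeta ignores the seed value itself; the output is always a
// 1-element int32 tensor (see the seed entry in ops.yaml above).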
void SeedInferMeta(int seed, MetaTensor* out) {
out->set_dims(phi::make_ddim({1}));
out->set_dtype(DataType::INT32);
}

void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
float mean,
float std,
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/nullary.h
@@ -83,6 +83,8 @@ void RecvV2InferMeta(const int ring_id,
DataType dtype,
MetaTensor* out);

void SeedInferMeta(int seed, MetaTensor* out);

void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
float mean,
float std,
12 changes: 10 additions & 2 deletions python/paddle/distributed/auto_parallel/random.py
@@ -25,6 +25,7 @@
_inited_rng_name_to_seed = {}
_enable_random_control = False
_basic_seed = 42
_basic_name = ""

# use a prime number as the offset to avoid conflicts
_mesh_offset = 173
@@ -41,7 +42,7 @@ def enable_auto_rand_ctrl():
_enable_random_control = True


def parallel_manual_seed(seed):
def parallel_manual_seed(seed, name=""):
"""Enable auto parallel random control.
Random control maintains randomness when a tensor is distributed across devices on a Mesh (in any order).
* Independence: if a tensor is **Sharded** along a Mesh dimension, devices along that Mesh dimension should have different randomness.
@@ -66,6 +67,8 @@ def parallel_manual_seed(seed):
enable_auto_rand_ctrl()
global _basic_seed
_basic_seed = seed
global _basic_name
_basic_name = name


def determinate_rng(rank, dims_mapping, process_mesh):
Expand All @@ -80,13 +83,18 @@ def determinate_rng(rank, dims_mapping, process_mesh):
)
global _basic_seed
seed_ = _basic_seed
global _basic_name
name_ = _basic_name

if name_:
name_ += "_"

# FIXME
# unique_id = process_mesh.unique_id
unique_id = retrive_unique_id_for_process_mesh(
process_mesh.shape, process_mesh.process_ids
)
sharding_expr = f'mesh:{unique_id}'
sharding_expr = name_ + f'mesh:{unique_id}'
seed_ += _mesh_offset * (unique_id + 1)

for i in range(len(process_mesh.shape)):
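A minimal usage sketch (not part of this commit; the module path is assumed from the file name above) of the new `name` argument, which is prepended to the RNG key built in determinate_rng:

from paddle.distributed.auto_parallel.random import parallel_manual_seed

parallel_manual_seed(42)               # sharding_expr begins with "mesh:<unique_id>"
parallel_manual_seed(42, name="eval")  # sharding_expr begins with "eval_mesh:<unique_id>"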
@@ -982,7 +982,10 @@ def amend_dist_attr_for_program(self):
):
dims_mapping[i] = -1
dist_attr.set_output_dims_mapping(arg_name, dims_mapping)
if len(process_mesh_processes) == 1:
if (
len(process_mesh_processes) == 1
and dist_op.serial_op.type != "dropout"
):
dist_op.dist_attr.impl_type = "default"
dist_op.dist_attr.impl_idx = 0

@@ -58,12 +58,11 @@ def forward(ctx, *args, **kwargs):
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"

if is_enable_auto_rand_ctrl() and not op_dist_attr.is_recompute:
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"

# check validation of inputs / outputs
assert 'X' in kwargs, "input [{}] is not given".format('X')
assert (
@@ -164,8 +163,8 @@ def forward(ctx, *args, **kwargs):

# modify dropout op
src_op.desc.set_input("Seed", [seed_var.name])
src_op._remove_attr("fix_seed")
src_op._remove_attr("seed")
src_op.desc._set_attr("fix_seed", False)
src_op.desc._set_attr("seed", 0)
op_dist_attr.set_input_dist_attr(
seed_var.name, seed_var_dist_attr
)
@@ -25,7 +25,12 @@
from .partitioner import Partitioner
from .process_group import get_world_process_group
from .reshard import Resharder
from .utils import get_pp_stage, set_grad_var_shape, use_new_executor
from .utils import (
get_pp_stage,
is_sequential_run,
set_grad_var_shape,
use_new_executor,
)


class Parallelizer:
@@ -367,6 +372,7 @@ def _apply_post_optimization(
[main_program], [startup_program], self._pass_context
)

if not is_sequential_run():
# deps for newexe
config = {}
config["dist_context"] = self._dist_context
@@ -294,5 +294,6 @@ def profiler(args):


if __name__ == "__main__":
paddle.framework.set_flags({'FLAGS_new_executor_sequential_run': 1})
args = parse_args()
profiler(args)
14 changes: 14 additions & 0 deletions python/paddle/distributed/auto_parallel/static/utils.py
@@ -2250,6 +2250,9 @@ def insert_dependencies_for_two_ops(
dependency: prior_op should be run before posterior_op
"""

if is_sequential_run():
return

assert (
len(prior_op.output_arg_names) >= 1
), "first op of dependency should at least have one output. [{}]".format(
@@ -2320,6 +2323,9 @@ def insert_dependencies_for_vars(
dependency: op that generates prior_vars should be run before op that generates post_vars
"""

if is_sequential_run():
return

if isinstance(prior_vars, Variable):
prior_vars = [prior_vars]
if isinstance(post_vars, Variable):
@@ -2423,6 +2429,14 @@ def use_new_executor():
]


def is_sequential_run():
return bool(
paddle.get_flags("FLAGS_new_executor_sequential_run")[
"FLAGS_new_executor_sequential_run"
]
)


def get_pp_stage(dist_context, rank):
pp_idx = None
for idx, process_mesh in enumerate(dist_context.process_meshes):
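A small sketch (not part of the diff) of the new guard, using only the flag APIs that already appear in this commit. With the sequential-run flag enabled, as the profiler entry point above now does, is_sequential_run() returns True and insert_dependencies_for_two_ops / insert_dependencies_for_vars return early:

import paddle

paddle.framework.set_flags({'FLAGS_new_executor_sequential_run': 1})
flag = paddle.get_flags("FLAGS_new_executor_sequential_run")[
    "FLAGS_new_executor_sequential_run"
]
print(bool(flag))  # True -> the dependency-insertion helpers become no-ops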
8 changes: 2 additions & 6 deletions python/paddle/distributed/passes/auto_parallel_recompute.py
@@ -185,8 +185,8 @@ def modify_forward_desc_for_recompute(self, dist_context):
# modify dropout op's desc
self.ops.insert(op_idx, seed_op)
cur_op.desc.set_input(seed_tensor_name, [var_unique_name])
cur_op._remove_attr("fix_seed")
cur_op._remove_attr("seed")
cur_op.desc._set_attr("fix_seed", False)
cur_op.desc._set_attr("seed", 0)
cur_op_dist_attr.set_input_dist_attr(
seed_var.name, seed_var_dist_attr
)
@@ -416,10 +416,6 @@ def _apply_single_impl(self, main_program, startup_program, context):
# segments ops should be inserted.
for i in range(len(ops) - 1, loss_op_idx, -1):
grad_op = ops[i]
# remove some attrs of dropout_grad op's desc
if grad_op.type == "dropout_grad":
grad_op._remove_attr("fix_seed")
grad_op._remove_attr("seed")

input_and_output_names = []
input_and_output_names.extend(grad_op.input_arg_names)