[NewIR]support c_reduce_sum/c_allgather/c_allreduce_max/seed #57185

Merged 8 commits on Sep 14, 2023.
4 changes: 4 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -86,6 +86,10 @@
'c_allreduce_sum',
'c_embedding',
'c_identity',
'c_reduce_sum',
'c_allreduce_max',
'c_allgather',
'seed',
]


9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -106,6 +106,15 @@
param: [x, file_path, overwrite, save_as_fp16, save_to_memory]
optional : out

- op : seed
args : (int seed, bool deterministic, str rng_name, bool force_cpu)
output : Tensor(out)
infer_meta:
func: SeedInferMeta
param: [seed]
kernel:
func: seed

- op : send_v2
args : (Tensor x, int ring_id = 0, int peer = 0, bool use_calc_stream = false, bool dynamic_shape = false)
output :
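The seed block above, together with its registration in ops_api_gen.py, is what lets the NewIR op generator expose the op through the Python API. A minimal sketch of the call surface this implies (the signature below illustrates the codegen convention and is not the generated code itself); per the SeedInferMeta addition later in this diff, the single output is a 1-element int32 tensor:

def seed(seed, deterministic, rng_name, force_cpu):
    # Illustrative only: returns one Tensor(out); SeedInferMeta
    # (param: [seed]) fixes its meta to shape [1] and dtype int32.
    ...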
7 changes: 6 additions & 1 deletion paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -29,7 +29,12 @@ const std::unordered_set<std::string> LegacyOpList = {
"pd_op.send_v2",
"pd_op.recv_v2",
"pd_op.c_allreduce_sum",
"pd_op.c_allreduce_sum_"};
"pd_op.c_allreduce_sum_",
"pd_op.c_reduce_sum",
"pd_op.c_reduce_sum_",
"pd_op.c_allreduce_max_",
"pd_op.c_allgather",
"pd_op.seed"};

enum class AttrType {
UNDEFINED = 0,
8 changes: 8 additions & 0 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -646,6 +646,14 @@ phi::KernelKey GetKernelKey(
op->result(0).type().dyn_cast<DenseTensorType>().dtype())};
}

if (op->name() == "pd_op.seed") {
auto backend = paddle::experimental::ParseBackend(place);
return {backend,
phi::DataLayout::ANY,
TransToPhiDataType(
op->result(0).type().dyn_cast<DenseTensorType>().dtype())};
}

phi::Backend kernel_backend = phi::Backend::UNDEFINED;
phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED;
phi::DataType kernel_data_type = phi::DataType::UNDEFINED;
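The special case above covers pd_op.seed, which has no tensor inputs to infer a kernel key from: the backend is parsed from the target place, the layout is left as ANY, and the dtype is read from the op's first result. A conceptual sketch in Python (helper and field names are assumptions for illustration, not Paddle source):

def seed_kernel_key(place_backend, result_dtype):
    # Backend follows the placement; layout is unconstrained; dtype comes
    # from the single output, since pd_op.seed has no tensor inputs.
    return {"backend": place_backend, "layout": "ANY", "dtype": result_dtype}

print(seed_kernel_key("GPU", "int32"))  # {'backend': 'GPU', 'layout': 'ANY', 'dtype': 'int32'}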
29 changes: 29 additions & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -123,6 +123,25 @@
backward : batch_norm_grad
optional : reserve_space

- op : c_allgather
args : (Tensor x, int ring_id, int nranks, bool use_calc_stream)
output : Tensor(out)
infer_meta :
func : AllGatherInferMeta
param: [x, nranks]
kernel :
func : c_allgather

- op : c_allreduce_max
args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel)
output : Tensor(out)
infer_meta :
func : AllReduceInferMeta
param : [x]
kernel :
func : c_allreduce_max
inplace : (x -> out)

- op : c_allreduce_sum
args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel)
output : Tensor(out)
@@ -173,6 +192,16 @@
func : c_identity
inplace : (x -> out)

- op : c_reduce_sum
args : (Tensor x, int ring_id, int root_id, bool use_calc_stream)
output : Tensor(out)
infer_meta :
func : DistReduceInferMeta
param : [x]
kernel :
func : c_reduce_sum
inplace : (x -> out)

- op : c_sync_calc_stream
args : (Tensor x)
output : Tensor(out)
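For context on the c_allgather entry above: AllGatherInferMeta consumes only x and nranks, and a reasonable reading of that rule (an assumption here, not a quote of its source) is that the gathered output stacks the nranks shards along the first dimension. A small sketch:

def c_allgather_out_shape(x_shape, nranks):
    # Assumed shape rule: dim 0 grows by the number of participating ranks.
    out_shape = list(x_shape)
    out_shape[0] *= nranks
    return out_shape

print(c_allgather_out_shape([8, 16], 4))  # [32, 16]

By contrast, c_allreduce_max and c_reduce_sum keep the input shape, which is consistent with their shape-preserving infer_meta functions and the inplace : (x -> out) marking.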
21 changes: 21 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -810,6 +810,7 @@
backward : dropout_grad
inputs :
x : X
seed_tensor : Seed
outputs :
out : Out
mask : Mask
@@ -2428,6 +2429,8 @@
out : Out

- op : seed
outputs :
out : Out
extra :
attrs : [bool deterministic = false, str rng_name = "", bool force_cpu = false]

@@ -3047,6 +3050,18 @@
yolo_loss : GetYoloLossExpectedKernelType
yolo_loss_grad : GetYoloLossExpectedKernelType

- op: c_allgather
inputs :
x : X
outputs :
out: Out

- op: c_allreduce_max
inputs :
x : X
outputs :
out: Out

- op: c_allreduce_sum
inputs :
x : X
@@ -3065,6 +3080,12 @@
outputs :
out: Out

- op: c_reduce_sum
inputs :
x : X
outputs :
out: Out

- op: c_sync_calc_stream
inputs :
x : X
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/nullary.cc
@@ -223,6 +223,11 @@ void RecvV2InferMeta(const int ring_id,
out->set_dtype(dtype);
}

void SeedInferMeta(int seed, MetaTensor* out) {
out->set_dims(phi::make_ddim({1}));
out->set_dtype(DataType::INT32);
}

void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
float mean,
float std,
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/nullary.h
@@ -83,6 +83,8 @@ void RecvV2InferMeta(const int ring_id,
DataType dtype,
MetaTensor* out);

void SeedInferMeta(int seed, MetaTensor* out);

void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
float mean,
float std,
12 changes: 10 additions & 2 deletions python/paddle/distributed/auto_parallel/random.py
@@ -25,6 +25,7 @@
_inited_rng_name_to_seed = {}
_enable_random_control = False
_basic_seed = 42
_basic_name = ""

# use Prime number as offset to avoid conflict
_mesh_offset = 173
@@ -41,7 +42,7 @@ def enable_auto_rand_ctrl():
_enable_random_control = True


def parallel_manual_seed(seed):
def parallel_manual_seed(seed, name=""):
"""Enable auto parallel random control.
Random control maintain the randomness when tensor is distributed across devices on a Mesh(any order).
* Independency: If tensor is **Sharded** on a Mesh dimension, Devices along that Mesh dimension should have Different randomness.
@@ -66,6 +67,8 @@ def parallel_manual_seed(seed):
enable_auto_rand_ctrl()
global _basic_seed
_basic_seed = seed
global _basic_name
_basic_name = name


def determinate_rng(rank, dims_mapping, process_mesh):
@@ -80,13 +83,18 @@ def determinate_rng(rank, dims_mapping, process_mesh):
)
global _basic_seed
seed_ = _basic_seed
global _basic_name
name_ = _basic_name

if name_:
name_ += "_"

# FIXME
# unique_id = process_mesh.unique_id
unique_id = retrive_unique_id_for_process_mesh(
process_mesh.shape, process_mesh.process_ids
)
sharding_expr = f'mesh:{unique_id}'
sharding_expr = name_ + f'mesh:{unique_id}'
seed_ += _mesh_offset * (unique_id + 1)

for i in range(len(process_mesh.shape)):
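A small worked example (values assumed for illustration) of what the new name argument changes: parallel_manual_seed(seed, name=...) stores the name, and determinate_rng prefixes it onto the rng key it builds, presumably so separately named setups in the same process do not collide on the same key:

# Assumed values, mirroring the logic in determinate_rng above.
_basic_name = "pp"                 # set via parallel_manual_seed(1234, name="pp")
name_ = _basic_name + "_" if _basic_name else _basic_name
unique_id = 3                      # id retrieved for the process mesh
sharding_expr = name_ + f"mesh:{unique_id}"
print(sharding_expr)               # pp_mesh:3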
(file path not shown in this view)
@@ -982,7 +982,10 @@ def amend_dist_attr_for_program(self):
):
dims_mapping[i] = -1
dist_attr.set_output_dims_mapping(arg_name, dims_mapping)
if len(process_mesh_processes) == 1:
if (
len(process_mesh_processes) == 1
and dist_op.serial_op.type != "dropout"
):
dist_op.dist_attr.impl_type = "default"
dist_op.dist_attr.impl_idx = 0

(file path not shown in this view)
@@ -58,12 +58,11 @@ def forward(ctx, *args, **kwargs):
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"

if is_enable_auto_rand_ctrl() and not op_dist_attr.is_recompute:
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"

# check validation of inputs / outputs
assert 'X' in kwargs, "input [{}] is not given".format('X')
assert (
@@ -164,8 +163,8 @@ def forward(ctx, *args, **kwargs):

# modify dropout op
src_op.desc.set_input("Seed", [seed_var.name])
src_op._remove_attr("fix_seed")
src_op._remove_attr("seed")
src_op.desc._set_attr("fix_seed", False)
src_op.desc._set_attr("seed", 0)
op_dist_attr.set_input_dist_attr(
seed_var.name, seed_var_dist_attr
)
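A hedged reading of the change above: instead of removing the legacy fix_seed and seed attributes from the dropout op, the pass now keeps them but neutralizes their values, and drives randomness through the Seed input it just wired up. A minimal stand-in (plain dict, not Paddle's OpDesc) to summarize the effect:

op_desc = {"inputs": {}, "attrs": {"fix_seed": True, "seed": 123}}
op_desc["inputs"]["Seed"] = ["seed_generated_var"]  # randomness comes from this input
op_desc["attrs"]["fix_seed"] = False                # attr kept but neutralized
op_desc["attrs"]["seed"] = 0                        # attr kept, value effectively unused
print(op_desc)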
(file path not shown in this view)
@@ -25,7 +25,12 @@
from .partitioner import Partitioner
from .process_group import get_world_process_group
from .reshard import Resharder
from .utils import get_pp_stage, set_grad_var_shape, use_new_executor
from .utils import (
get_pp_stage,
is_sequential_run,
set_grad_var_shape,
use_new_executor,
)


class Parallelizer:
@@ -367,6 +372,7 @@ def _apply_post_optimization(
[main_program], [startup_program], self._pass_context
)

if not is_sequential_run():
# deps for newexe
config = {}
config["dist_context"] = self._dist_context
(file path not shown in this view)
@@ -294,5 +294,6 @@ def profiler(args):


if __name__ == "__main__":
paddle.framework.set_flags({'FLAGS_new_executor_sequential_run': 1})
args = parse_args()
profiler(args)
14 changes: 14 additions & 0 deletions python/paddle/distributed/auto_parallel/static/utils.py
@@ -2250,6 +2250,9 @@ def insert_dependencies_for_two_ops(
dependency: prior_op should be run before posterior_op
"""

if is_sequential_run():
return

assert (
len(prior_op.output_arg_names) >= 1
), "first op of dependency should at least have one output. [{}]".format(
@@ -2320,6 +2323,9 @@ def insert_dependencies_for_vars(
dependency: op that generates prior_vars should be run before op that generates post_vars
"""

if is_sequential_run():
return

if isinstance(prior_vars, Variable):
prior_vars = [prior_vars]
if isinstance(post_vars, Variable):
@@ -2423,6 +2429,14 @@ def use_new_executor():
]


def is_sequential_run():
return bool(
paddle.get_flags("FLAGS_new_executor_sequential_run")[
"FLAGS_new_executor_sequential_run"
]
)


def get_pp_stage(dist_context, rank):
pp_idx = None
for idx, process_mesh in enumerate(dist_context.process_meshes):
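A minimal usage sketch of the new helper, assuming FLAGS_new_executor_sequential_run is available in the running Paddle build: once the flag is switched on (as the profiler entry point above now does), is_sequential_run() returns True and both insert_dependencies_* helpers return early without adding dependency ops.

import paddle
from paddle.distributed.auto_parallel.static.utils import is_sequential_run

paddle.framework.set_flags({'FLAGS_new_executor_sequential_run': 1})
assert is_sequential_run()  # the helper reads the flag back via paddle.get_flags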
8 changes: 2 additions & 6 deletions python/paddle/distributed/passes/auto_parallel_recompute.py
@@ -185,8 +185,8 @@ def modify_forward_desc_for_recompute(self, dist_context):
# modify dropout op's desc
self.ops.insert(op_idx, seed_op)
cur_op.desc.set_input(seed_tensor_name, [var_unique_name])
cur_op._remove_attr("fix_seed")
cur_op._remove_attr("seed")
cur_op.desc._set_attr("fix_seed", False)
cur_op.desc._set_attr("seed", 0)
cur_op_dist_attr.set_input_dist_attr(
seed_var.name, seed_var_dist_attr
)
@@ -416,10 +416,6 @@ def _apply_single_impl(self, main_program, startup_program, context):
# segments ops should be inserted.
for i in range(len(ops) - 1, loss_op_idx, -1):
grad_op = ops[i]
# remove some attrs of dropout_grad op's desc
if grad_op.type == "dropout_grad":
grad_op._remove_attr("fix_seed")
grad_op._remove_attr("seed")

input_and_output_names = []
input_and_output_names.extend(grad_op.input_arg_names)