diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
index bde5c7c23a7bc..479ae1bef5787 100644
--- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -86,6 +86,10 @@
     'c_allreduce_sum',
     'c_embedding',
     'c_identity',
+    'c_reduce_sum',
+    'c_allreduce_max',
+    'c_allgather',
+    'seed',
 ]
diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml
index da4c252af7217..bf80652d03134 100644
--- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml
+++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -106,6 +106,15 @@
     param: [x, file_path, overwrite, save_as_fp16, save_to_memory]
   optional : out
 
+- op : seed
+  args : (int seed, bool deterministic, str rng_name, bool force_cpu)
+  output : Tensor(out)
+  infer_meta:
+    func: SeedInferMeta
+    param: [seed]
+  kernel:
+    func: seed
+
 - op : send_v2
   args : (Tensor x, int ring_id = 0, int peer = 0, bool use_calc_stream = false, bool dynamic_shape = false)
   output :
diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc
index 72e5f63d28673..164e4c41f7065 100644
--- a/paddle/fluid/pir/dialect/operator/utils/utils.cc
+++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -29,7 +29,12 @@ const std::unordered_set LegacyOpList = {
     "pd_op.send_v2",
     "pd_op.recv_v2",
     "pd_op.c_allreduce_sum",
-    "pd_op.c_allreduce_sum_"};
+    "pd_op.c_allreduce_sum_",
+    "pd_op.c_reduce_sum",
+    "pd_op.c_reduce_sum_",
+    "pd_op.c_allreduce_max_",
+    "pd_op.c_allgather",
+    "pd_op.seed"};
 
 enum class AttrType {
   UNDEFINED = 0,
diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
index 29b5c07f1dab9..90b9f5b8749f1 100644
--- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
+++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -646,6 +646,14 @@ phi::KernelKey GetKernelKey(
                op->result(0).type().dyn_cast().dtype())};
   }
 
+  if (op->name() == "pd_op.seed") {
+    auto backend = paddle::experimental::ParseBackend(place);
+    return {backend,
+            phi::DataLayout::ANY,
+            TransToPhiDataType(
+                op->result(0).type().dyn_cast().dtype())};
+  }
+
   phi::Backend kernel_backend = phi::Backend::UNDEFINED;
   phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED;
   phi::DataType kernel_data_type = phi::DataType::UNDEFINED;
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 4c151374c6893..55577dd3b006b 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -123,6 +123,25 @@
   backward : batch_norm_grad
   optional : reserve_space
 
+- op : c_allgather
+  args : (Tensor x, int ring_id, int nranks, bool use_calc_stream)
+  output : Tensor(out)
+  infer_meta :
+    func : AllGatherInferMeta
+    param: [x, nranks]
+  kernel :
+    func : c_allgather
+
+- op : c_allreduce_max
+  args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel)
+  output : Tensor(out)
+  infer_meta :
+    func : AllReduceInferMeta
+    param : [x]
+  kernel :
+    func : c_allreduce_max
+  inplace : (x -> out)
+
 - op : c_allreduce_sum
   args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel)
   output : Tensor(out)
@@ -173,6 +192,16 @@
     func : c_identity
   inplace : (x -> out)
 
+- op : c_reduce_sum
+  args : (Tensor x, int ring_id, int root_id, bool use_calc_stream)
+  output : Tensor(out)
+  infer_meta :
+    func : DistReduceInferMeta
+    param : [x]
+  kernel :
+    func : c_reduce_sum
+  inplace : (x -> out)
+
 - op : c_sync_calc_stream
   args : (Tensor x)
   output : Tensor(out)
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 495ba53cd7613..0ae5ec57a17b7 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -810,6 +810,7 @@
   backward : dropout_grad
   inputs :
     x : X
+    seed_tensor : Seed
   outputs :
     out : Out
     mask : Mask
@@ -2428,6 +2429,8 @@
     out : Out
 
 - op : seed
+  outputs :
+    out : Out
   extra :
     attrs : [bool deterministic = false, str rng_name = "", bool force_cpu = false]
@@ -3047,6 +3050,18 @@
     yolo_loss : GetYoloLossExpectedKernelType
     yolo_loss_grad : GetYoloLossExpectedKernelType
 
+- op: c_allgather
+  inputs :
+    x : X
+  outputs :
+    out: Out
+
+- op: c_allreduce_max
+  inputs :
+    x : X
+  outputs :
+    out: Out
+
 - op: c_allreduce_sum
   inputs :
     x : X
@@ -3065,6 +3080,12 @@
   outputs :
     out: Out
 
+- op: c_reduce_sum
+  inputs :
+    x : X
+  outputs :
+    out: Out
+
 - op: c_sync_calc_stream
   inputs :
     x : X
diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc
index d5da3a2f8bc87..1c57e2fae92ac 100644
--- a/paddle/phi/infermeta/nullary.cc
+++ b/paddle/phi/infermeta/nullary.cc
@@ -223,6 +223,11 @@ void RecvV2InferMeta(const int ring_id,
   out->set_dtype(dtype);
 }
 
+void SeedInferMeta(int seed, MetaTensor* out) {
+  out->set_dims(phi::make_ddim({1}));
+  out->set_dtype(DataType::INT32);
+}
+
 void TruncatedGaussianRandomInferMeta(const std::vector& shape,
                                       float mean,
                                       float std,
diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h
index bc73942c8ec1c..2f9c9a69a13f1 100644
--- a/paddle/phi/infermeta/nullary.h
+++ b/paddle/phi/infermeta/nullary.h
@@ -83,6 +83,8 @@ void RecvV2InferMeta(const int ring_id,
                      DataType dtype,
                      MetaTensor* out);
 
+void SeedInferMeta(int seed, MetaTensor* out);
+
 void TruncatedGaussianRandomInferMeta(const std::vector& shape,
                                       float mean,
                                       float std,
diff --git a/python/paddle/distributed/auto_parallel/random.py b/python/paddle/distributed/auto_parallel/random.py
index dbd0dbc691d02..d79f94e166524 100644
--- a/python/paddle/distributed/auto_parallel/random.py
+++ b/python/paddle/distributed/auto_parallel/random.py
@@ -25,6 +25,7 @@
 _inited_rng_name_to_seed = {}
 _enable_random_control = False
 _basic_seed = 42
+_basic_name = ""
 
 # use Prime number as offset to avoid confict
 _mesh_offset = 173
@@ -41,7 +42,7 @@ def enable_auto_rand_ctrl():
     _enable_random_control = True
 
 
-def parallel_manual_seed(seed):
+def parallel_manual_seed(seed, name=""):
     """Enable auto parallel random control.
     Random control maintain the randomness when tensor is distributed across devices on a Mesh(any order).
     * Independency: If tensor is **Sharded** on a Mesh dimension, Devices along that Mesh dimension should have Different randomness.
@@ -66,6 +67,8 @@ def parallel_manual_seed(seed):
     enable_auto_rand_ctrl()
     global _basic_seed
     _basic_seed = seed
+    global _basic_name
+    _basic_name = name
 
 
 def determinate_rng(rank, dims_mapping, process_mesh):
@@ -80,13 +83,18 @@ def determinate_rng(rank, dims_mapping, process_mesh):
     )
     global _basic_seed
     seed_ = _basic_seed
+    global _basic_name
+    name_ = _basic_name
+
+    if name_:
+        name_ += "_"
 
     # FIXME
     # unique_id = process_mesh.unique_id
     unique_id = retrive_unique_id_for_process_mesh(
         process_mesh.shape, process_mesh.process_ids
     )
-    sharding_expr = f'mesh:{unique_id}'
+    sharding_expr = name_ + f'mesh:{unique_id}'
     seed_ += _mesh_offset * (unique_id + 1)
 
     for i in range(len(process_mesh.shape)):
diff --git a/python/paddle/distributed/auto_parallel/static/dist_context.py b/python/paddle/distributed/auto_parallel/static/dist_context.py
index 09dfcee60b8fc..5eabdd312bbb7 100644
--- a/python/paddle/distributed/auto_parallel/static/dist_context.py
+++ b/python/paddle/distributed/auto_parallel/static/dist_context.py
@@ -982,7 +982,10 @@ def amend_dist_attr_for_program(self):
                     ):
                         dims_mapping[i] = -1
                 dist_attr.set_output_dims_mapping(arg_name, dims_mapping)
-            if len(process_mesh_processes) == 1:
+            if (
+                len(process_mesh_processes) == 1
+                and dist_op.serial_op.type != "dropout"
+            ):
                 dist_op.dist_attr.impl_type = "default"
                 dist_op.dist_attr.impl_idx = 0
diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py b/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py
index 08976d7eba74f..ec3692f0385b5 100644
--- a/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py
+++ b/python/paddle/distributed/auto_parallel/static/operators/dist_dropout.py
@@ -58,12 +58,11 @@ def forward(ctx, *args, **kwargs):
         src_op = dist_op_context.cur_src_op
         rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
+        assert (
+            op_dist_attr is not None
+        ), f"forward op [{str(src_op)}] don't have dist attribute !"
 
         if is_enable_auto_rand_ctrl() and not op_dist_attr.is_recompute:
-            assert (
-                op_dist_attr is not None
-            ), f"forward op [{str(src_op)}] don't have dist attribute !"
-
             # check validation of inputs / outputs
             assert 'X' in kwargs, "input [{}] is not given".format('X')
             assert (
@@ -164,8 +163,8 @@ def forward(ctx, *args, **kwargs):
 
             # modify dropout op
             src_op.desc.set_input("Seed", [seed_var.name])
-            src_op._remove_attr("fix_seed")
-            src_op._remove_attr("seed")
+            src_op.desc._set_attr("fix_seed", False)
+            src_op.desc._set_attr("seed", 0)
             op_dist_attr.set_input_dist_attr(
                 seed_var.name, seed_var_dist_attr
             )
diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
index 8b5136a61b9f6..086221d64f092 100644
--- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
+++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
@@ -25,7 +25,12 @@
 from .partitioner import Partitioner
 from .process_group import get_world_process_group
 from .reshard import Resharder
-from .utils import get_pp_stage, set_grad_var_shape, use_new_executor
+from .utils import (
+    get_pp_stage,
+    is_sequential_run,
+    set_grad_var_shape,
+    use_new_executor,
+)
 
 
 class Parallelizer:
@@ -367,6 +372,7 @@ def _apply_post_optimization(
                 [main_program], [startup_program], self._pass_context
             )
 
+        if not is_sequential_run():
             # deps for newexe
             config = {}
             config["dist_context"] = self._dist_context
diff --git a/python/paddle/distributed/auto_parallel/static/tuner/profiler.py b/python/paddle/distributed/auto_parallel/static/tuner/profiler.py
index 992ac034e16b7..bbc3171cbebe9 100644
--- a/python/paddle/distributed/auto_parallel/static/tuner/profiler.py
+++ b/python/paddle/distributed/auto_parallel/static/tuner/profiler.py
@@ -294,5 +294,6 @@ def profiler(args):
 
 
 if __name__ == "__main__":
+    paddle.framework.set_flags({'FLAGS_new_executor_sequential_run': 1})
     args = parse_args()
     profiler(args)
diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py
index fa12cfd68e3b2..062880a4e70d2 100644
--- a/python/paddle/distributed/auto_parallel/static/utils.py
+++ b/python/paddle/distributed/auto_parallel/static/utils.py
@@ -2250,6 +2250,9 @@ def insert_dependencies_for_two_ops(
     dependency: prior_op should be run before posterior_op
     """
 
+    if is_sequential_run():
+        return
+
     assert (
         len(prior_op.output_arg_names) >= 1
     ), "first op of dependency should at least have one output. [{}]".format(
@@ -2320,6 +2323,9 @@ def insert_dependencies_for_vars(
     dependency: op that generates prior_vars should be run before op that generates post_vars
     """
 
+    if is_sequential_run():
+        return
+
     if isinstance(prior_vars, Variable):
         prior_vars = [prior_vars]
     if isinstance(post_vars, Variable):
@@ -2423,6 +2429,14 @@ def use_new_executor():
     ]
 
 
+def is_sequential_run():
+    return bool(
+        paddle.get_flags("FLAGS_new_executor_sequential_run")[
+            "FLAGS_new_executor_sequential_run"
+        ]
+    )
+
+
 def get_pp_stage(dist_context, rank):
     pp_idx = None
     for idx, process_mesh in enumerate(dist_context.process_meshes):
diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py
index 0fd008ff5a701..5c63d93cb14b1 100644
--- a/python/paddle/distributed/passes/auto_parallel_recompute.py
+++ b/python/paddle/distributed/passes/auto_parallel_recompute.py
@@ -185,8 +185,8 @@ def modify_forward_desc_for_recompute(self, dist_context):
             # modify dropout op's desc
             self.ops.insert(op_idx, seed_op)
             cur_op.desc.set_input(seed_tensor_name, [var_unique_name])
-            cur_op._remove_attr("fix_seed")
-            cur_op._remove_attr("seed")
+            cur_op.desc._set_attr("fix_seed", False)
+            cur_op.desc._set_attr("seed", 0)
             cur_op_dist_attr.set_input_dist_attr(
                 seed_var.name, seed_var_dist_attr
             )
@@ -416,10 +416,6 @@ def _apply_single_impl(self, main_program, startup_program, context):
         # segments ops should be inserted.
         for i in range(len(ops) - 1, loss_op_idx, -1):
             grad_op = ops[i]
-            # remove some attrs of dropout_grad op's desc
-            if grad_op.type == "dropout_grad":
-                grad_op._remove_attr("fix_seed")
-                grad_op._remove_attr("seed")
 
             input_and_output_names = []
             input_and_output_names.extend(grad_op.input_arg_names)
diff --git a/test/auto_parallel/gpt_with_newir.py b/test/auto_parallel/gpt_with_newir.py
index 4ddfd5a76ffe0..1be3202a23777 100644
--- a/test/auto_parallel/gpt_with_newir.py
+++ b/test/auto_parallel/gpt_with_newir.py
@@ -26,10 +26,31 @@
 paddle.enable_static()
 
 
-def apply_pass():
+def apply_pass(use_sharding=False):
     strategy = auto.Strategy()
     strategy.auto_mode = "semi"
     strategy.reinit = True
+
+    amp = strategy.amp
+    amp.enable = True
+    amp.dtype = "float16"
+    amp.level = "o2"
+    amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
+    amp.custom_black_list = [
+        'c_softmax_with_cross_entropy',
+        'elementwise_div',
+        'reduce_sum',
+    ]
+
+    recompute = strategy.recompute
+    recompute.enable = True
+
+    if use_sharding:
+        sharding = strategy.sharding
+        sharding.enable = True
+        sharding.degree = 2
+        sharding.stage = 2
+
     return strategy
@@ -49,24 +70,27 @@ def setUp(self):
         paddle.set_flags({'FLAGS_embedding_deterministic': 1})
         paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
 
-    def init(self, engine):
+    def init(self, engine, name):
         paddle.seed(2021)
         np.random.seed(2021)
         random.seed(2021)
         paddle.distributed.fleet.init(is_collective=True)
+        paddle.distributed.auto_parallel.random._rng_name_to_seed.clear()
+        paddle.distributed.auto_parallel.random._inited_rng_name_to_seed.clear()
+        paddle.distributed.auto_parallel.parallel_manual_seed(2021, name)
         place = paddle.CUDAPlace(ParallelEnv().dev_id)
         engine._executor = paddle.static.Executor(place)
 
-    def get_engine(self, mode):
+    def get_engine(self, mode, name, use_sharding=False):
         reset_prog()
 
-        strategy = apply_pass()
+        strategy = apply_pass(use_sharding)
         clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
-        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=None)
-        model, loss = generate_model(mode)
+        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
+        model, loss = generate_model(mode, dropout_prob=0.1)
         engine = auto.Engine(model, loss, opt, strategy=strategy)
-        self.init(engine)
+        self.init(engine, name)
         return engine
 
     def check_results(self, ref_losses, check_losses):
@@ -84,13 +108,15 @@ def enable_new_ir(self, flag):
 
     def test_dp(self):
         self.enable_new_ir(False)
-        engine_dp_prog = self.get_engine("dp")
+        engine_dp_prog = self.get_engine(
+            "dp", name="dp_prog", use_sharding=True
+        )
         out_dp_prog = engine_dp_prog.fit(
             self.dataset, 3, batch_size=self.batch_size, log_freq=1
         )
 
         self.enable_new_ir(True)
-        engine_dp_ir = self.get_engine("dp")
+        engine_dp_ir = self.get_engine("dp", name="dp_newir", use_sharding=True)
         out_dp_ir = engine_dp_ir.fit(
             self.dataset, 3, batch_size=self.batch_size, log_freq=1
         )
@@ -101,13 +127,13 @@ def test_dp(self):
 
     def test_mp(self):
         self.enable_new_ir(False)
-        engine_mp_prog = self.get_engine("mp")
+        engine_mp_prog = self.get_engine("mp", name="mp_prog")
         out_mp_prog = engine_mp_prog.fit(
             self.dataset, 3, batch_size=self.batch_size, log_freq=1
         )
 
         self.enable_new_ir(True)
-        engine_mp_ir = self.get_engine("mp")
+        engine_mp_ir = self.get_engine("mp", name="mp_newir")
        out_mp_ir = engine_mp_ir.fit(
             self.dataset, 3, batch_size=self.batch_size, log_freq=1
         )
@@ -119,14 +145,14 @@ def test_pp(self):
         # navie pipeline parallel without schedule
         self.enable_new_ir(False)
-        engine_pp_prog = self.get_engine("pp")
+        engine_pp_prog = self.get_engine("pp", name="pp_prog0")
         out_pp_prog = engine_pp_prog.fit(
             self.dataset, 3, batch_size=self.batch_size, log_freq=1
         )
 
         self.enable_new_ir(True)
         # send_v2/recv_v2 dynamic_shape is True
-        engine_pp_ir = self.get_engine("pp")
+        engine_pp_ir = self.get_engine("pp", name="pp_newir")
         out_pp_ir = engine_pp_ir.fit(
             self.dataset, 3, batch_size=self.batch_size, log_freq=1
         )
@@ -137,7 +163,7 @@
         )
 
         # send_v2/recv_v2 dynamic_shape is False
-        engine_pp_prog1 = self.get_engine("pp")
+        engine_pp_prog1 = self.get_engine("pp", name="pp_prog1")
         dataloader_pp_prog = engine_pp_prog1.dataloader(
             self.dataset,
             batch_size=self.batch_size,
diff --git a/test/auto_parallel/test_tuning_recompute.py b/test/auto_parallel/test_tuning_recompute.py
index ef9a16a2cae72..ca266475b5aff 100644
--- a/test/auto_parallel/test_tuning_recompute.py
+++ b/test/auto_parallel/test_tuning_recompute.py
@@ -79,7 +79,7 @@ def apply_pass():
 class TestRecomputePassTuning(unittest.TestCase):
     def setUp(self):
         self.batch_size = 8
-        self.batch_num = 200
+        self.batch_num = 5
         self.dataset = FakeDataset(
             self.batch_size * self.batch_num,
             vocab_size=50304,