Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auto parallel support pipeline scheduler with standalone executor #54727

Merged
merged 11 commits into from
Jun 25, 2023
2 changes: 2 additions & 0 deletions paddle/fluid/operators/controlflow/fetch_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ class FetchV2Kernel {
"operator 'Fetch') of current fetching variable to be "
"no less than 0. But received column index = %d.",
col));
VLOG(3) << "Fetch variable " << fetch_var_name << "'s " << col
<< " column.";

auto *fetch_list = out_var->GetMutable<framework::FetchList>();

Expand Down
9 changes: 5 additions & 4 deletions python/paddle/distributed/auto_parallel/static/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,11 @@ def _prepare_reader(self, feed_list=[]):
dist_main_block._sync_with_cpp()
self._has_prepared_reader[self._mode] = True

# Insert read op to forward TaskNode if 1F1B pass is setted
if self.main_program._pipeline_opt:
# Insert read op to forward TaskNode for fleet executor if 1F1B pass is set
if (
self.main_program._pipeline_opt
and not auto_utils.use_new_executor()
):
assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
fwd_task = None
Expand Down Expand Up @@ -471,8 +474,6 @@ def _process_fetch_group(group_name, var_list):
if var_name not in fetch_names:
fetch_names.append(var_name)
group_indices.append(fetch_names.index(var_name))
if not group_indices:
fetch_names.append([])
fetch_indices.append(group_indices)

dist_context = self._dist_contexts[mode]
Expand Down
36 changes: 26 additions & 10 deletions python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .partitioner import Partitioner
from .process_group import get_world_process_group
from .reshard import Resharder
from .utils import set_grad_var_shape
from .utils import set_grad_var_shape, use_new_executor


class Parallelizer:
Expand All @@ -38,6 +38,14 @@ def __init__(self, mode, completer, dist_context):
self._strategy = self._dist_context.strategy
self._logger = get_logger(logging.INFO)

@property
def is_train(self):
    """Whether this parallelizer was built for the "train" mode."""
    current_mode = self._mode
    return current_mode == "train"

@property
def is_test(self):
    """Whether this parallelizer runs a non-training mode ("eval" or "predict")."""
    return self._mode in ("eval", "predict")

def parallel_all(self):
world_process_group = get_world_process_group()
all_ranks = world_process_group.ranks
Expand All @@ -50,7 +58,7 @@ def parallel(self, rank):
serial_main_program = self._dist_context.serial_main_program
serial_startup_program = self._dist_context.serial_startup_program
serial_optimizer = self._dist_context.serial_optimizer
if self._mode == "train" and serial_optimizer:
if self.is_train and serial_optimizer:
# Generate backward
serial_loss = self._dist_context.serial_loss
params_grads = self._generate_backward(
Expand Down Expand Up @@ -191,8 +199,9 @@ def parallel(self, rank):
time.time() - time0, self._mode
)
)

# Clone program for test
if self._mode != 'train':
if self.is_test:
pipeline_opt = dist_main_prog._pipeline_opt
dist_main_prog = dist_main_prog.clone(for_test=True)
dist_startup_prog = dist_startup_prog.clone(for_test=True)
Expand Down Expand Up @@ -263,7 +272,7 @@ def _apply_pre_optimization(

# apply quantization pass
# The pass can be applied when mode must be 'train'
if self._mode == 'train' and self._strategy.qat.enable:
if self.is_train and self._strategy.qat.enable:
config = copy.deepcopy(self._strategy.qat.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
Expand All @@ -282,7 +291,7 @@ def _apply_pre_optimization(

# apply recompute pass
# recompute is then train-only optimization
if self._mode == "train" and self._strategy.recompute.enable:
if self.is_train and self._strategy.recompute.enable:
config = copy.deepcopy(self._strategy.recompute.to_dict())
config["dist_context"] = self._dist_context
config["no_grad_set"] = None
Expand Down Expand Up @@ -326,7 +335,7 @@ def _apply_post_optimization(
)
params_grads = self._pass_context.get_attr("params_grads")

if self._mode == "train":
if self.is_train:
# GradClip is train-only optimization
config = copy.deepcopy(self._strategy.sharding.to_dict())
config["dist_context"] = self._dist_context
Expand All @@ -349,15 +358,15 @@ def _apply_post_optimization(
[main_program], [startup_program], self._pass_context
)

if self._strategy.pipeline.enable:
if self.is_train and self._strategy.pipeline.enable:
self._strategy.gradient_merge.enable = True
self._strategy.gradient_merge.k_steps = (
self._strategy.pipeline.accumulate_steps
)
self._strategy.gradient_merge.avg = True

# gradient_merge is then train-only optimization
if self._mode == "train" and self._strategy.gradient_merge.enable:
if self.is_train and self._strategy.gradient_merge.enable:
config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
Expand All @@ -368,7 +377,7 @@ def _apply_post_optimization(
[main_program], [startup_program], self._pass_context
)

if self._strategy.pipeline.enable:
if self._strategy.pipeline.enable and not use_new_executor():
config = copy.deepcopy(self._strategy.pipeline.to_dict())
config["dist_context"] = self._dist_context
auto_parallel_pipeline_pass = new_pass(
Expand All @@ -378,10 +387,17 @@ def _apply_post_optimization(
[main_program], [startup_program], self._pass_context
)

if self._mode == "train" and self._strategy.fused_passes.enable:
if self.is_train and self._strategy.fused_passes.enable:
if len(self._strategy.fused_passes.fused_passes_list) > 0:
new_pass_list = []
for op in self._strategy.fused_passes.fused_passes_list:
new_pass_list.append(new_pass(op))
pass_manager = PassManager(new_pass_list)
pass_manager.apply([main_program], [startup_program])

if self._strategy.pipeline.enable and use_new_executor():
main_program._pipeline_opt = {}
main_program._pipeline_opt["standalone_opt"] = {
"schedule_mode": self._strategy.pipeline.schedule_mode,
"num_micro_batches": self._strategy.pipeline.accumulate_steps,
}
13 changes: 13 additions & 0 deletions python/paddle/distributed/auto_parallel/static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2367,3 +2367,16 @@ def __impl__(*args, **kwargs):


dygraph_guard = wrap_decorator(_dygraph_guard_)


def use_new_executor():
    """Return True if the standalone (new) executor is enabled.

    Enablement is controlled by the ``FLAGS_new_executor_micro_batching``
    environment variable: the values ``"1"`` and ``"true"`` (compared
    case-insensitively) enable it; unset or any other value disables it.

    Returns:
        bool: True when the new executor should be used.
    """
    new_executor_micro_batching = os.environ.get(
        'FLAGS_new_executor_micro_batching', None
    )
    if new_executor_micro_batching is None:
        return False
    # os.environ values are always strings, so comparing against the ints
    # 1/True (as the old list did) could never match; a lower-cased string
    # comparison covers '1', 'True', 'true', 'TRUE', etc.
    return new_executor_micro_batching.lower() in ('1', 'true')
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,13 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
)
set_var_dist_attr(dist_context, step_var, [-1], world_process_group.ranks)

cond_var = main_block.create_var(
name="gradient_merge_cond", shape=[1], dtype='bool'
cond_var = paddle.static.create_global_var(
name="gradient_merge_cond",
shape=[1],
value=bool(0),
dtype='bool',
persistable=True,
force_cpu=True,
)
set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks)

Expand Down
23 changes: 18 additions & 5 deletions python/paddle/distributed/passes/pipeline_scheduler_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from paddle.fluid import core
from paddle.fluid.framework import Parameter, Program

from .pass_base import PassBase, register_pass
from .pass_base import PassBase, PassContext, new_pass, register_pass

__not_shape_var_type__ = [
core.VarDesc.VarType.READER,
Expand Down Expand Up @@ -257,7 +257,7 @@ def _program_for_fthenb_and_1f1b(program):
}


@register_pass("pipeline_fthenb_scheduler")
@register_pass("pipeline_scheduler_FThenB")
class PipelineFThenBPass(PassBase):
def __init__(self):
super().__init__()
Expand All @@ -272,12 +272,12 @@ def _create_job_list(self):
job_list = []
lr_job = core.Job("lr")
job_list.append(lr_job)
for i in range(self._micro_batch_size):
for i in range(self._num_micro_batches):
forward_job = core.Job("forward")
forward_job.set_micro_batch_id(i)
job_list.append(forward_job)

for i in range(self._micro_batch_size):
for i in range(self._num_micro_batches):
backward_job = core.Job("backward")
backward_job.set_micro_batch_id(i)
job_list.append(backward_job)
Expand All @@ -287,7 +287,7 @@ def _create_job_list(self):
return job_list

def _apply_single_impl(self, main_program, startup_program, context):
self._micro_batch_size = self.get_attr("micro_batch_size")
self._num_micro_batches = self.get_attr("num_micro_batches")
self._program = main_program

_insert_sync_for_fthenb_1f1b(self._program)
Expand All @@ -296,3 +296,16 @@ def _apply_single_impl(self, main_program, startup_program, context):

plan = core.Plan(job_list, type_to_program)
context.set_attr("plan", plan)


def apply_pass(main_program, startup_program, pass_name, pass_attr=None):
    """Apply the named pipeline-scheduler pass and return the resulting plan.

    Args:
        main_program: the main Program to be scheduled.
        startup_program: the corresponding startup Program.
        pass_name (str): scheduler name; only "FThenB" is supported.
        pass_attr (dict|None): attributes forwarded to the pass; defaults
            to an empty dict.

    Returns:
        The ``plan`` attribute produced by the pass.

    Raises:
        AssertionError: if ``pass_name`` is not a supported scheduler.
    """
    # Use None as the default to avoid the shared-mutable-default pitfall
    # of ``pass_attr={}``.
    pass_attr = {} if pass_attr is None else pass_attr
    assert pass_name in [
        "FThenB"
    ], "pipeline scheduler only support FThenB, but receive {}".format(
        pass_name
    )
    pipeline_pass = new_pass("pipeline_scheduler_" + pass_name, pass_attr)
    pass_context = PassContext()
    pipeline_pass.apply([main_program], [startup_program], pass_context)
    plan = pass_context.get_attr("plan")
    return plan
75 changes: 68 additions & 7 deletions python/paddle/fluid/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,36 @@ def _add_feed_fetch_ops(
return tmp_program


def _set_micro_batch_fetch(plan):
    """Redirect fetch ops so each micro batch writes a distinct column.

    For a plan with M micro batches, a fetch/fetch_v2 op originally
    targeting column ``c`` inside a job with micro batch id ``b`` is
    redirected to column ``c * M + b``. Plans with at most one micro
    batch are left untouched.
    """
    num_micro_batches = plan.micro_batch_num()
    if num_micro_batches <= 1:
        return

    fetch_op_types = ("fetch", "fetch_v2")
    for job in plan.job_list():
        block = plan.program(job.type()).block(0)
        micro_batch_id = job.micro_batch_id()
        for op_idx in range(block.op_size()):
            op = block.op(op_idx)
            if op.type() not in fetch_op_types:
                continue
            new_col = op.attr('col') * num_micro_batches + micro_batch_id
            job.set_col_attr_for_fetch_op(op_idx, new_col)


def _merge_tensors(tensor, micro_batch_num):
    """Group a flat fetch-result list into per-variable numpy arrays.

    The flat list holds ``micro_batch_num`` consecutive entries per
    fetched variable; each such run is stacked into a single numpy
    array. The input is returned unchanged when there is at most one
    micro batch.
    """
    if micro_batch_num <= 1:
        return tensor
    assert len(tensor) % micro_batch_num == 0
    return [
        np.array(tensor[start : start + micro_batch_num])
        for start in range(0, len(tensor), micro_batch_num)
    ]


def _apply_inplace_addto_pass(
program, enable_inplace, enable_addto, skip_var_names
):
Expand Down Expand Up @@ -653,8 +683,13 @@ def run(self, feed_names, return_numpy=True):
"""
tensors = self._new_exe.run(feed_names)._move_to_list()
if return_numpy:
return as_numpy(tensors, copy=True)
tensors = as_numpy(tensors, copy=True)
return _merge_tensors(tensors, self._plan.micro_batch_num())
else:
if self._plan.micro_batch_num() > 1:
raise RuntimeError(
"`merge_tensor` does not support when return_numpy is False."
)
return tensors

def _create_new_executor(self):
Expand Down Expand Up @@ -831,12 +866,30 @@ def _get_program_and_executor(self, cached_data):
_apply_inplace_addto_pass(
program, enable_inplace, enable_addto, skip_var_names
)

new_program = program.clone()
new_exe = _StandaloneExecutor(
place,
core.Plan([core.Job("default")], {"default": new_program.desc}),
scope,
)
if (
new_program._pipeline_opt
and "standalone_opt" in new_program._pipeline_opt
):
from paddle.distributed.passes.pipeline_scheduler_pass import (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不建议在函数内部做导入。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不在此处做导入会发生循环引用的问题。

apply_pass,
)

standalone_opt = new_program._pipeline_opt["standalone_opt"]
pass_name = standalone_opt["schedule_mode"]
pass_attr = {
"num_micro_batches": standalone_opt["num_micro_batches"]
}
plan = apply_pass(new_program, new_program, pass_name, pass_attr)
else:
default_job = core.Job("default")
type_to_program = {"default": new_program.desc}
plan = core.Plan([default_job], type_to_program)

_set_micro_batch_fetch(plan)

new_exe = _StandaloneExecutor(place, plan, scope)
return new_program, new_exe


Expand Down Expand Up @@ -1408,7 +1461,15 @@ def _run_impl(

fetch_list = self._check_fetch_list(fetch_list)

if isinstance(program, Program) and program._pipeline_opt:
from paddle.distributed.auto_parallel.static.utils import (
use_new_executor,
)

if (
isinstance(program, Program)
and program._pipeline_opt
and not use_new_executor()
):
if "fleet_opt" in program._pipeline_opt:
# Move prepare here for port conflict with nccl in startup program
if self._fleet_executor is None:
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -5921,6 +5921,8 @@ def network():
p._appending_grad_times = self._appending_grad_times
if hasattr(self, 'lr_scheduler'):
p.lr_scheduler = self.lr_scheduler
if hasattr(self, '_pipeline_opt'):
p._pipeline_opt = self._pipeline_opt

# NOTE(zhiqiu): we sync the cloned program, to update its program by
# its desc.
Expand Down
4 changes: 4 additions & 0 deletions test/auto_parallel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_auto_tuner MODULES test_auto_tuner)
set_tests_properties(test_auto_tuner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 100)
py_test_modules(test_pipeline_scheduler_FThenB MODULES
test_pipeline_scheduler_FThenB)
set_tests_properties(test_pipeline_scheduler_FThenB
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_auto_tuner_compare MODULES test_auto_tuner_compare)
set_tests_properties(test_auto_tuner_compare
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
Expand Down
Loading