[Auto-Parallel] Reshard API & Hybrid Parallel Unittest for dy2static mode (#59856)

* static_decorate v0.1

* update static_decorate as comments

* add unit tests and adapt placement api

* add docs for the api

* remove useless print and comments

* first commit

* static_decorate v0.1

* update static_decorate as comments

* add unit tests and adapt placement api

* add docs for the api

* remove useless print and comments

* add unit execution code

* fix sample code for static_decorate

* add get_program interface in DistModel

* modify as suggested

* move the init parameters part to helper.py

* fix unittest name in CMakeList

* add api

* add unittest

* fix typos

* add dy2static test case for llama

* static dist model pp

* bug fixed

* enable all test

* fix typos

---------

Co-authored-by: Yichen Zhang <zhangyichen03@baidu.com>
JZ-LIANG and pkuzyc authored Dec 9, 2023
1 parent 5be87ba commit 70c4d21
Showing 10 changed files with 593 additions and 28 deletions.
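
As context for the diff below, here is a minimal, hypothetical sketch of how the new static-graph branch of paddle.distributed.reshard added in python/paddle/distributed/auto_parallel/api.py is meant to be exercised under dy2static. The mesh, layer, and shapes are assumptions for illustration only and are not taken from this commit.

import paddle
import paddle.distributed as dist
from paddle import nn

# Hypothetical 1-D mesh over two ranks (the commit's own tests use 8-card 3-D meshes).
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])


class DemoNet(nn.Layer):
    # Hypothetical layer for illustration; not part of this commit.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x):
        y = self.linear(x)
        # Under dy2static tracing, `y` is a static Variable, so this call takes
        # the new static-graph branch in api.py: an `assign` op is appended whose
        # input/output dist_attr carry the target dims_mapping, and the op is
        # marked so sharding propagation skips it.
        return dist.reshard(y, mesh, [dist.Replicate()])

In the commit's own test (run_dy2static in semi_auto_llama.py below), the converted model is built with dist.to_static(model, train_dataloader, criterion, opt) and then driven step by step through dist_loader().
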
82 changes: 74 additions & 8 deletions python/paddle/distributed/auto_parallel/api.py
@@ -17,11 +17,26 @@
import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.base.framework import EagerParamBase
from paddle.base import unique_name
from paddle.base.framework import (
EagerParamBase,
Variable,
default_main_program,
)
from paddle.distributed.auto_parallel import Engine
from paddle.distributed.auto_parallel.interface import (
shard_tensor as shard_tensor_static,
)
from paddle.distributed.auto_parallel.static.completion import (
mark_as_sharding_propagation_skip_op,
)
from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context,
)
from paddle.distributed.auto_parallel.static.dist_op import DistributedOperator
from paddle.distributed.auto_parallel.static.utils import (
convert_to_dims_mapping,
)
from paddle.framework import core

from .placement_type import get_shard_spec
@@ -131,6 +146,7 @@ def __init__(
)
self._mode = None
self._feed_name_list = {}

# convert dygraph model to static model
batch_size = loader.batch_sampler.batch_size
inputs_spec, labels_spec = self._engine._prepare_data_spec(
@@ -268,15 +284,20 @@ def __call__(self, *args):
raise ValueError("Please set loss function before evaluation.")
feeds = self._make_feeds(list(args))
outs = self._engine.run(feeds)

if self._mode == "predict":
return outs["outputs"]
if "outputs" in outs:
return outs["outputs"]
else:
return None
else:
return outs["loss"]
if "loss" in outs:
return outs["loss"]
else:
return None


# Part2: DistTensor construction related APIs


def to_static(
layer: paddle.nn.Layer,
loader=None,
@@ -566,10 +587,55 @@ def reshard(dist_tensor, mesh, placements):

return paddle.base.core.reshard(dist_tensor, dist_attr)
else:
# TODO(GhostScreaming): Support static DistTensor later.
raise RuntimeError(
"paddle.dist.reshard only support dynamic graph now. It will be supported for static graph later."
assert isinstance(
dist_tensor, Variable
), "in dy2static mode, reshard's input should be Variable, but got [{}]".format(
dist_tensor
)
sharding_specs = get_shard_spec(mesh, placements, dist_tensor.ndim)
main_program = default_main_program()
default_dist_ctx = get_default_distributed_context()

# output variable
out_var = main_program.current_block().create_var(
name=unique_name.generate_with_ignorable_key(
".".join(['reshard_api', 'tmp'])
),
dtype=dist_tensor.dtype,
shape=dist_tensor.shape,
type=dist_tensor.type,
persistable=dist_tensor.persistable,
stop_gradient=dist_tensor.stop_gradient,
)

# transition op
# optimization in future to remove redundant D2D memory copy
target_dims_mapping = convert_to_dims_mapping(sharding_specs, mesh)
trans_op = main_program.current_block().append_op(
type='assign',
inputs={'X': [dist_tensor]},
outputs={'Out': [out_var]},
)
dist_op = DistributedOperator(trans_op)
dist_op.dist_attr.process_mesh = mesh
dist_op.dist_attr.mark_annotated("process_mesh")
dist_op.dist_attr.chunk_id = 0

input_dist_attr = dist_op.dist_attr.get_input_dist_attr(
dist_tensor.name
)
input_dist_attr.dims_mapping = target_dims_mapping
input_dist_attr.mark_annotated("dims_mapping")
output_dist_attr = dist_op.dist_attr.get_output_dist_attr(out_var.name)
output_dist_attr.dims_mapping = target_dims_mapping
output_dist_attr.mark_annotated("dims_mapping")

default_dist_ctx.add_dist_op_for_program(dist_op)
mark_as_sharding_propagation_skip_op(trans_op)
# trans_op = shard_op_static(paddle.assign, mesh, [sharding_specs])
# out_var = trans_op(dist_tensor)

return out_var


def shard_layer(
27 changes: 26 additions & 1 deletion python/paddle/distributed/auto_parallel/static/completion.py
@@ -16,11 +16,13 @@
import logging
import os

import paddle
from paddle.base.core import ( # noqa: F401
contains_spmd_rule,
get_phi_spmd_rule,
get_spmd_rule,
)
from paddle.base.framework import Operator
from paddle.base.log_helper import get_logger
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.framework import core
@@ -53,6 +55,24 @@
"read",
]

_skip_propagation_prefix = "Auto_Parallel_Completion_Skipped"


def mark_as_sharding_propagation_skip_op(op):
op._set_attr('op_namescope', '/' + _skip_propagation_prefix)


def is_sharding_propagation_skip_op(op):
if isinstance(op, paddle.base.libpaddle.OpDesc):
op_desc = op
elif isinstance(op, Operator):
op_desc = op.desc
else:
raise RuntimeError(f"static mode operator is expected but got [{op}]")
return op_desc.has_attr(
"op_namescope"
) and _skip_propagation_prefix in op_desc.attr("op_namescope")


def compute_compatible_dim_mapping(dim_mapping_list):
"""Compute the compatible dim mapping given a list of dim mapping."""
@@ -218,6 +238,7 @@ def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
or pred_op_node.op().type()
== "create_double_buffer_reader"
or pred_op_node.op().type() == "read"
# or is_sharding_propagation_skip_op(pred_op_node.op()) # reshard should only fwd tensor propagation
):
continue
op_dist_attr = (
@@ -255,6 +276,7 @@ def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
or succ_op_node.op().type()
== "create_double_buffer_reader"
or succ_op_node.op().type() == "read"
or is_sharding_propagation_skip_op(succ_op_node.op())
):
continue
op_dist_attr = (
@@ -293,7 +315,10 @@ def _update_op_node_dims_mapping(self, op_node, fwd=True):
if (not op_node.is_op()) or (op_node.op() is None):
return False
# Skip reader op
if op_desc.type() in __skip_dims_mapping_op__:
if (
op_desc.type() in __skip_dims_mapping_op__
or is_sharding_propagation_skip_op(op_node.op())
):
return False

dist_op = self._dist_context.get_dist_op_for_graph(op_node)
4 changes: 4 additions & 0 deletions test/auto_parallel/CMakeLists.txt
@@ -178,6 +178,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
test_semi_auto_parallel_dist_to_static)
set_tests_properties(test_semi_auto_parallel_dist_to_static
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 300)
py_test_modules(test_static_reshard_api MODULES test_static_reshard_api)
set_tests_properties(test_static_reshard_api
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 300)

# End of unittests WITH multi cards and timeout

# NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
57 changes: 54 additions & 3 deletions test/auto_parallel/hybrid_strategy/semi_auto_llama.py
@@ -16,7 +16,11 @@
from functools import reduce

import numpy as np
from semi_auto_parallel_llama_model import LlamaForCausalLMAuto, set_global_mesh
from semi_auto_parallel_llama_model import (
LlamaForCausalLMAuto,
LlamaPretrainingCriterionAuto,
set_global_mesh,
)

import paddle
import paddle.distributed as dist
@@ -104,8 +108,9 @@ def init_dist_env(self):
global_mesh = dist.ProcessMesh(mesh_arr, dim_names)
set_global_mesh(global_mesh)

def run_test_cases(self):
def run_dynamic(self):
model = LlamaForCausalLMAuto(self.config)
criterion = LlamaPretrainingCriterionAuto(self.config)

lr_scheduler = paddle.optimizer.lr.LinearWarmup(
learning_rate=0.0001, warmup_steps=2, start_lr=0, end_lr=0.0001
@@ -133,7 +138,8 @@ def run_test_cases(self):
for epoch_idx in range(1):
for step, inputs in enumerate(train_dataloader):
input_ids, labels = inputs
tr_loss_step, _ = model(input_ids, labels=labels)
logits = model(input_ids)
tr_loss_step = criterion(logits, labels)

if self.gradient_accumulation_steps > 1:
tr_loss_step /= self.gradient_accumulation_steps
@@ -154,6 +160,51 @@
if global_step // self.gradient_accumulation_steps >= 10:
break

def run_dy2static(self):
model = LlamaForCausalLMAuto(self.config)
criterion = LlamaPretrainingCriterionAuto(self.config)

lr_scheduler = paddle.optimizer.lr.LinearWarmup(
learning_rate=0.0001, warmup_steps=2, start_lr=0, end_lr=0.0001
)
optimizer = create_optimizer(model, lr_scheduler)
optimizer = dist.shard_optimizer(optimizer)

train_dataset = RandomDataset(self.config.seq_length)
train_sampler = BatchSampler(
train_dataset,
batch_size=2,
shuffle=True,
drop_last=True,
)
train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
num_workers=0,
)

if isinstance(optimizer, dist.auto_parallel.api._ShardOptimizer):
opt = optimizer._inner_opt
else:
opt = optimizer

dist_model, dist_loader = dist.to_static(
model, train_dataloader, criterion, opt
)

dist_model.train()
for step, inputs in enumerate(dist_loader()):
input_ids, labels = inputs
loss = dist_model(input_ids, labels)
print(step, loss)

if step >= 10:
break

def run_test_cases(self):
self.run_dynamic()
self.run_dy2static()


if __name__ == '__main__':
TestLlamaAuto().run_test_cases()
@@ -251,7 +251,7 @@ def forward(
attn_output = outputs

attn_output = self.o_proj(attn_output)

# TODO add should be in SP region
if self.config.sequence_parallel:
attn_output = paddle.transpose(attn_output, [1, 0, 2])
attn_output = dist.reshard(
@@ -501,7 +501,7 @@ def _prepare_decoder_attention_mask(
combined_attention_mask = dist.shard_tensor(
combined_attention_mask,
get_mesh(),
[dist.Shard(0), dist.Replicate()],
[dist.Replicate(), dist.Replicate()],
)
expanded_attn_mask = (
expanded_attn_mask & combined_attention_mask
@@ -582,7 +582,7 @@ def forward(
(batch_size, seq_length)
)
position_ids = dist.shard_tensor(
position_ids, get_mesh(), [dist.Shard(0), dist.Replicate()]
position_ids, get_mesh(), [dist.Replicate(), dist.Replicate()]
)

if self.config.sequence_parallel:
@@ -607,7 +607,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None

pre_ipp = 0
pre_ipp = None
for idx, (decoder_layer) in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -708,6 +708,9 @@ def __init__(self, config):
)

def forward(self, prediction_scores, masked_lm_labels):
masked_lm_labels = dist.shard_tensor(
masked_lm_labels, get_mesh(-1), [dist.Shard(0), dist.Replicate()]
)
masked_lm_loss = self.loss_func(
prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)
)
@@ -732,7 +735,7 @@ def __init__(self, config):

self.llama = LlamaModelAuto(config)
self.lm_head = LlamaLMHeadAuto(config)
self.criterion = LlamaPretrainingCriterionAuto(config)
# self.criterion = LlamaPretrainingCriterionAuto(config)

def forward(
self,
@@ -770,25 +773,27 @@ def forward(

hidden_states = outputs[0] # [bs, seq_len, dim]

# if labels is None,means we need full output, instead of tensor_parallel_output
if self.config.sequence_parallel:
hidden_states = dist.reshard(
hidden_states, get_mesh(-1), [dist.Shard(1), dist.Replicate()]
)
# [S, B, H] -> [B, S, H]
hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
# if labels is None,means we need full output, instead of tensor_parallel_output

logits = self.lm_head(hidden_states)

loss = None
if labels is not None:
labels.stop_gradient = True
labels = dist.shard_tensor(
labels, get_mesh(-1), [dist.Shard(0), dist.Replicate()]
)
loss = self.criterion(logits, labels)
# loss = None
# if labels is not None:
# labels.stop_gradient = True
# labels = dist.shard_tensor(
# labels, get_mesh(-1), [dist.Shard(0), dist.Replicate()]
# )
# loss = self.criterion(logits, labels)

output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
# output = (logits,) + outputs[1:]
# return (loss,) + output if loss is not None else output
return logits


def _expand_2d_mask(mask, dtype, tgt_length):
@@ -164,7 +164,7 @@ def test_simple_net_hybrid_strategy(self):
class TestSemiAutoParallelLlama3D(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=8, timeout=200, nnode=1)
self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "2"}
self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "1"}
self._changeable_envs = {
"backend": ["gpu"],
"use_sp": ["true", "false"],
@@ -200,6 +200,7 @@ def run_test(self):
# not prepared
# NOTE: This use is not recommended, only for the test. In normal
# use, DistModel is generated by dist.to_static.

dist_model._engine._has_prepared["train"] = False
dist_model._engine._has_prepared["eval"] = False
dist_model._engine._has_prepared["predict"] = False