[Auto Parallel] Add paddle.distributed.to_static api #59682
@@ -18,6 +18,7 @@
import paddle.distributed as dist
from paddle import nn
from paddle.base.framework import EagerParamBase
from paddle.distributed.auto_parallel import Engine
from paddle.distributed.auto_parallel.interface import (
    shard_tensor as shard_tensor_static,
)
@@ -90,9 +91,263 @@ def sharding_specs(self):
        return self._sharding_specs


class DistModel:
    """
    DistModel is a wrapper of the network model for static-mode auto
    parallel. DistModel contains the distributed graph of the model and
    offers APIs for training, evaluation and prediction.

    Please first set the DistModel to "train", "eval" or "predict" mode and
    then use the __call__ method for training, evaluation and prediction
    respectively.

    In "train" mode, executing ``__call__`` will update the parameters
    of the model and return the loss. In "eval" mode, executing ``__call__``
    will return the loss. In "predict" mode, executing ``__call__`` returns a
    dict that contains the outputs of the model, where the value of "out0" is
    the first output.

    DistModel is generated by ``static_decorate``. For more details on its
    usage, please refer to the sample code in ``static_decorate``.
    """

    def __init__(
        self, layer, loss=None, optimizer=None, strategy=None, metrics=None
    ):
        self._feed_name_list = []
        self._engine = Engine(
            layer, loss, optimizer, metrics, strategy=strategy
        )
        self._mode = None
        self._feed_name_list = {}

    def train(self):
        """
        Set the mode of DistModel to "train".
        """
        if not self._engine._has_prepared["train"]:
            raise RuntimeError("The model for training has not been prepared.")
        self._mode = "train"
        self._engine.to_mode("train")

    def eval(self):
        """
        Set the mode of DistModel to "eval".
        """
        if not self._engine._has_prepared["eval"]:
            raise RuntimeError(
                "The model for evaluation has not been prepared."
            )
        self._mode = "eval"
        self._engine.to_mode("eval")

    def predict(self):
        """
        Set the mode of DistModel to "predict".
        """
        if not self._engine._has_prepared["predict"]:
            raise RuntimeError(
                "The model for prediction has not been prepared."
            )
        self._mode = "predict"
        self._engine.to_mode("predict")
    def _make_feeds(self, data_list):
        if (
            self._mode not in self._feed_name_list
            or self._feed_name_list[self._mode] == []
        ):
            feed_list = self._engine.get_feed_list()
            self._feed_name_list[self._mode] = [var.name for var in feed_list]
        feed_name_list = self._feed_name_list[self._mode]
        if len(feed_name_list) != len(data_list):
            raise ValueError(
                "The input data and feed_list are not consistent. "
                "The model takes %s as input" % (str(feed_name_list))
            )
        return dict(zip(feed_name_list, data_list))

    def __call__(self, *args):
        if self._mode is None:
            raise ValueError("Please call train()/eval()/predict() first.")
Review comment: default is train.
Reply: Done, the default mode is now set according to the init args.
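A minimal sketch of what the reply describes (hypothetical code, not part of this diff): the default mode can be derived from the init args, so ``__call__`` does not have to raise when no mode was set explicitly.

    # Hypothetical helper, assuming the Engine attributes already used in this
    # diff (self._engine._optimizer, self._engine._loss): an optimizer plus a
    # loss implies "train", a loss alone implies "eval", otherwise "predict".
    def _default_mode(self):
        if self._engine._optimizer is not None and self._engine._loss is not None:
            return "train"
        if self._engine._loss is not None:
            return "eval"
        return "predict"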

        if self._mode == "train":
            if self._engine._optimizer is None or self._engine._loss is None:
                raise ValueError(
                    "Please set optimizer and loss function before training."
                )
        if self._mode == "eval":
            if self._engine._loss is None:
                raise ValueError("Please set loss function before evaluation.")
        feeds = self._make_feeds(list(args))
        outs = self._engine.run(feeds)
        if self._mode == "predict":
            return outs["outputs"]
        else:
            return outs["loss"]


# Part2: DistTensor construction related APIs


def static_decorate(
    layer: paddle.nn.Layer,
Review comment: I'd like to suggest not using type annotations for any arguments.
Reply: Will update in the next PR.
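As a sketch of the suggestion (not code from this PR), the signature would simply drop the annotation:

def static_decorate(layer, loader=None, loss=None, optimizer=None, strategy=None):
    ...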
    loader=None,
    loss=None,
    optimizer=None,
    strategy=None,
):
    """
    Converts the model and data loader used in dygraph auto-parallelism to
    their counterparts in static-mode auto-parallelism. static_decorate
    returns a DistModel instance that provides the APIs for static-mode
    auto-parallel training, evaluation and prediction, together with a
    DistributedDataLoader that generates the corresponding data.

    Args:
        layer(paddle.nn.Layer): The layer in the dygraph model, whose
            parameters or inputs can be sharded.
        loader(paddle.io.DataLoader): The data loader used in the dygraph
            model, used to generate the DistributedDataLoader for static-mode
            auto parallel.
        loss(Loss|Callable|None, optional): The loss function for training
            or evaluating the model. Can be a `paddle.nn.Layer` instance or
            any callable function. Default: None.
        optimizer(paddle.optimizer.Optimizer|None, optional): The optimizer
            for training. Default: None.
        strategy(Strategy|None, optional): Configs for parallel strategies
            (e.g. data parallel, hybrid parallel etc.) and optimization
            settings (e.g. mixed-precision). Default: None.

    Returns:
        DistModel: A DistModel that contains the corresponding computational
            graph of the input layer and provides APIs for training,
            evaluation and prediction.
        DistributedDataLoader: An optimized data loader that can be used
            to generate data.

    Examples:
        .. code-block:: python

            >>> import numpy as np
            >>> import paddle
            >>> import paddle.distributed as dist
            >>> from paddle import nn
            >>> from paddle.distributed import Replicate, Shard

            >>> BATCH_SIZE = 4
            >>> BATCH_NUM = 4
            >>> IMAGE_SIZE = 16
            >>> CLASS_NUM = 8
            >>> class RandomDataset(paddle.io.Dataset):
            ...     def __init__(self, images, labels, num_samples):
            ...         self.images = images
            ...         self.labels = labels
            ...         self.num_samples = num_samples
            ...     def __getitem__(self, idx):
            ...         return self.images[idx], self.labels[idx]
            ...     def __len__(self):
            ...         return self.num_samples

            >>> class DemoNet(nn.Layer):
            ...     def __init__(self, mesh):
            ...         super().__init__()
            ...         self._mesh = mesh
            ...         self.linear_0 = nn.Linear(IMAGE_SIZE, IMAGE_SIZE)
            ...         self.linear_1 = nn.Linear(IMAGE_SIZE, CLASS_NUM)
            ...         self.relu = nn.ReLU()
            ...         # shard the weights of this layer
            ...         self.linear_0.weight = dist.shard_tensor(
            ...             self.linear_0.weight,
            ...             self._mesh,
            ...             [Shard(1)],
            ...             stop_gradient=False,
            ...         )
            ...         self.linear_1.weight = dist.shard_tensor(
            ...             self.linear_1.weight,
            ...             self._mesh,
            ...             [Shard(0)],
            ...             stop_gradient=False,
            ...         )
            ...     def forward(self, x):
            ...         out = self.linear_0(x)
            ...         out = self.relu(out)
            ...         out = self.linear_1(out)
            ...         return out

            >>> images = np.random.rand(BATCH_SIZE, IMAGE_SIZE).astype('float32')
            >>> labels = np.random.rand(BATCH_SIZE, CLASS_NUM).astype('float32')
            >>> dataset = RandomDataset(images, labels, BATCH_SIZE)
            >>> loader = paddle.io.DataLoader(dataset, batch_size=BATCH_SIZE)

            >>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
            >>> layer = DemoNet(mesh)
            >>> opt = paddle.optimizer.SGD(
            ...     learning_rate=0.1, parameters=layer.parameters()
            ... )
            >>> loss_fn = nn.MSELoss()

            >>> dist_model, dist_loader = dist.static_decorate(
            ...     layer, loader, loss_fn, opt
            ... )

            >>> # training
            >>> dist_model.train()
            >>> for batch_id, (image, label) in enumerate(dist_loader()):
            ...     # in train mode, executing the __call__ method will
            ...     # update the parameters of the model and return the
            ...     # loss
            ...     loss = dist_model(image, label)

            >>> # evaluation
            >>> dist_model.eval()
            >>> for batch_id, (image, label) in enumerate(dist_loader()):
            ...     # in eval mode, executing the __call__ method will
            ...     # return the loss
            ...     loss = dist_model(image, label)
            ...     print(loss)

            >>> # prediction
            >>> dist_model.predict()
            >>> for batch_id, (image, label) in enumerate(dist_loader()):
            ...     # in predict mode, executing the __call__ method will
            ...     # return a dict that contains the outputs of the model,
            ...     # where the value of "out0" is the first output.
            ...     outs = dist_model(image)
            ...     print(outs['out0'])

            >>> # This case needs to be executed in a multi-card environment
            >>> # export CUDA_VISIBLE_DEVICES=0,1
            >>> # python -m paddle.distributed.launch {test_case}.py
    """
    dist_model = DistModel(layer, loss, optimizer, strategy)

    # convert dygraph model to static model
    batch_size = loader.batch_sampler.batch_size
    inputs_spec, labels_spec = dist_model._engine._prepare_data_spec(
        loader.dataset, None, batch_size
    )

    if optimizer is not None and loss is not None:
        # get the static graph in train mode
        dist_model._engine.prepare(
            inputs_spec, labels_spec, mode="train", init_parameters=False
        )
    if loss is not None:
        # get the static graph in eval mode
        dist_model._engine.prepare(
            inputs_spec, labels_spec, mode="eval", init_parameters=False
        )
    # get the static graph in predict mode
    dist_model._engine.prepare(
        inputs_spec, None, mode="predict", init_parameters=False
    )

    # get DistributedDataLoader for static mode auto-parallelism
    batch_size = dist_model._engine._validate_batch_size(batch_size)
    dist_loader = dist_model._engine._prepare_dataloader(
        loader.dataset, return_list=True, batch_size=batch_size
    )
Review comment: Better move these lines into ...
Reply: Done.
Review comment: About the API name: from this 2-LOC implementation, I think this API is basically a creator API for ...
Reply: The function (or purpose) of this API is to convert a model whose parameters are distributed tensors (generated by ...
    return dist_model, dist_loader


def shard_tensor(
    data, mesh, placements, dtype=None, place=None, stop_gradient=True
):
@@ -812,13 +812,27 @@ def _init_comm(self):
    def _share_parameters(self):
        # mapping from {variable -> parameter}
        named_params = self.program_helper.named_parameters()
        dist_context = self._dist_contexts[self._mode]
        dist_main_program = dist_context.dist_main_programs[self._cur_rank]

        for name, param in named_params.items():
            var = global_scope().var(name)
            dense_tensor = var.get_tensor()
Review comment: This misses a filter that avoids sharing parameters that are not on this rank under pipeline parallelism (PP).
Reply: The parameter initialization is now moved to LayerHelper.init(), and the filter is included in LayerHelper.init().
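A minimal sketch of the kind of filter the comment asks for (hypothetical, reusing the dist_context and dist_main_program objects from this diff): skip parameters whose process mesh does not include the current rank before sharing their storage.

        for name, param in named_params.items():
            var_in_program = dist_main_program.global_block().vars[name]
            var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                var_in_program
            )
            # parameters placed on another pipeline stage are skipped
            if self._cur_rank not in var_dist_attr.process_mesh.process_ids:
                continue
            # ... share or slice the parameter as in the branches below ...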
            if param.is_dense():
                dense_tensor = var.get_tensor()
                dense_tensor._share_data_with(param.get_tensor())
                var_in_program = dist_main_program.global_block().vars[name]
                var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                    var_in_program
                )
                dict_dist_attr = {
                    "dims_mapping": var_dist_attr.dims_mapping,
                    "process_shape": var_dist_attr.process_mesh.shape,
                    "process_group": var_dist_attr.process_mesh.process_ids,
                }
                sliced_param = Converter.slice_with_dist_attr(
                    param.numpy(), dict_dist_attr
                )
                dense_tensor.set(sliced_param, self._place)
            elif param.is_dist():
                dense_tensor = var.get_tensor()
                dense_tensor._share_data_with(param.get_tensor().get_tensor())

    def _initialize(self, mode, init_parameters=True):
@@ -1545,6 +1559,29 @@ def run(self, data=None, feed=None, fetch_list=None, mode=None):
        )
        return logs

    def get_feed_list(self):
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Get the feed_list, then insert the dataloader op with the sharded var shape.
        # Because the predict_program does not contain the labels var,
        # we add the labels var from the serial_program to the dist_program,
        # which keeps the length of feed_list equal to the length of the dataset's values.
        inputs_var = dist_context.serial_feed_vars["inputs"]
        labels_var = dist_context.serial_feed_vars["labels"]
Comment on lines +1543 to +1544
Review comment: what if the feed is not called?

        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in dist_main_block.vars:
                feed_list.append(dist_main_block.vars[var.name])
            else:
                copy_var = dist_main_block._clone_variable(var, var.persistable)
                copy_var.desc.set_original_id(var.desc.original_id())
                feed_list.append(copy_var)

        return feed_list

    def _prepare_dataloader(
        self,
        dataset,
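To show how ``get_feed_list`` is consumed on the DistModel side (an illustrative sketch based on ``_make_feeds`` earlier in this PR; the example feed names and batch variables are hypothetical):

# Illustrative only: DistModel._make_feeds pairs the names returned by
# Engine.get_feed_list with the positional arguments passed to __call__.
feed_list = engine.get_feed_list()              # e.g. [input_var, label_var]
feed_names = [var.name for var in feed_list]    # e.g. ["image", "label"] (hypothetical)
feeds = dict(zip(feed_names, [image_batch, label_batch]))
outs = engine.run(feeds)                        # what __call__ ultimately executes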
Review comment: Is it possible to make this docstring more understandable? For example, the term "static mode auto parallel" was very hard for me to understand, even after googling the phrase.
Reply: I will rephrase the docstring and update it in the next PR.