[Auto Parallel] Add paddle.distributed.to_static api #59682
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open-source project!
Force-pushed from 0a9798f to daee140.
Force-pushed from daee140 to 228e8fa.
```python
def __call__(self, *args):
    if self._mode is None:
        raise ValueError("Please call train()/eval()/predict() first.")
```
The default mode should be train.
Done, the default mode is now set according to the init args.
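A minimal sketch (an assumption, not the PR's actual code) of how the default mode could be derived from the init args, consistent with the prepare logic quoted in the snippets below:

```python
def _init_default_mode(self, loss, optimizer):
    # hypothetical helper: derive a default mode from the init args so that
    # __call__ works without an explicit train()/eval()/predict() call
    if optimizer is not None and loss is not None:
        self._mode = "train"    # optimizer + loss: training is possible
    elif loss is not None:
        self._mode = "eval"     # loss only: evaluation
    else:
        self._mode = "predict"  # neither: prediction
```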
```python
# convert dygraph model to static model
batch_size = loader.batch_sampler.batch_size
inputs_spec, labels_spec = dist_model._engine._prepare_data_spec(
    loader.dataset, None, batch_size
)

if optimizer is not None and loss is not None:
    # get the static graph in train mode
    dist_model._engine.prepare(
        inputs_spec, labels_spec, mode="train", init_parameters=False
    )
if loss is not None:
    # get the static graph in eval mode
    dist_model._engine.prepare(
        inputs_spec, labels_spec, mode="eval", init_parameters=False
    )
# get the static graph in predict mode
dist_model._engine.prepare(
    inputs_spec, None, mode="predict", init_parameters=False
)

# get DistributedDataLoader for static mode auto-parallelism
batch_size = dist_model._engine._validate_batch_size(batch_size)
dist_loader = dist_model._engine._prepare_dataloader(
    loader.dataset, return_list=True, batch_size=batch_size
)
```
Better to move these lines into the `__init__` of `DistModel`.
Done.
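A sketch of the suggested refactor, folding the preparation steps above into `DistModel.__init__` (the `Engine` import path and helper names are assumptions based on the quoted snippets, not necessarily the PR's final code):

```python
from paddle.distributed.auto_parallel import Engine  # assumed import path

class DistModel:
    def __init__(self, layer, loader, loss=None, optimizer=None, strategy=None):
        self._engine = Engine(layer, loss, optimizer, strategy=strategy)
        batch_size = loader.batch_sampler.batch_size
        inputs_spec, labels_spec = self._engine._prepare_data_spec(
            loader.dataset, None, batch_size
        )
        # build the static graphs for the modes the init args allow
        if optimizer is not None and loss is not None:
            self._engine.prepare(
                inputs_spec, labels_spec, mode="train", init_parameters=False
            )
        if loss is not None:
            self._engine.prepare(
                inputs_spec, labels_spec, mode="eval", init_parameters=False
            )
        self._engine.prepare(
            inputs_spec, None, mode="predict", init_parameters=False
        )
        # the distributed data loader becomes an attribute of the model
        batch_size = self._engine._validate_batch_size(batch_size)
        self.dist_loader = self._engine._prepare_dataloader(
            loader.dataset, return_list=True, batch_size=batch_size
        )
```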
```python
inputs_var = dist_context.serial_feed_vars["inputs"]
labels_var = dist_context.serial_feed_vars["labels"]
```
What if the feed variables are not named `inputs` and `labels`?
`inputs` and `labels` are the keys of the dict `dist_context.serial_feed_vars`, not the names of the model's input and label variables.
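A small, self-contained illustration of the point (the variable names here are made up):

```python
# "inputs" and "labels" are fixed keys of the dict, independent of the
# names the user gave the feed variables themselves
serial_feed_vars = {
    "inputs": ["image"],   # feed vars the user happened to name "image" ...
    "labels": ["target"],  # ... and "target"
}
inputs_var = serial_feed_vars["inputs"]  # looked up by dict key, not by var name
labels_var = serial_feed_vars["labels"]
```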
LGTM
```python
for name, param in named_params.items():
    var = global_scope().var(name)
    dense_tensor = var.get_tensor()
```
This is missing a filter to skip sharing parameters that do not belong to the current rank under pipeline parallelism (PP).
Parameter initialization has now been moved to LayerHelper.init(), which includes this filter.
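A hedged sketch of what such a rank filter could look like (the `process_mesh`/`process_ids` attribute names and the helper function are assumptions, not the PR's actual LayerHelper.init() code):

```python
import paddle
from paddle.static import global_scope

def share_params_into_scope(named_params):
    # hypothetical sketch: skip parameters that are not placed on the
    # current rank, e.g. parameters of another pipeline-parallel stage
    cur_rank = paddle.distributed.get_rank()
    for name, param in named_params.items():
        mesh = getattr(param, "process_mesh", None)
        if mesh is not None and cur_rank not in mesh.process_ids:
            continue  # parameter lives on a different rank
        var = global_scope().var(name)
        dense_tensor = var.get_tensor()
        # ... share `param`'s storage into `dense_tensor` as in the PR code
```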
```python
...     )
>>> loss_fn = nn.MSELoss()

>>> dist_model, dist_loader = dist.static_decorate(
```
`dist.static_decorate` should be `dist.to_static`?
Done.
```python
):
    """
    Converts the model and data loader used in dygraph auto-parallelism to
    that in static mode auto-parallelism. static_decorate returns a DistModel
```
`static_decorate` should be `to_static`?
Done.
```python
dist_model._engine._has_prepared["eval"] = True
dist_model._engine._has_prepared["predict"] = True

# python -m paddle.distributed.launch --devices=0,1 semi_auto_parallel_static_decorate_api.py
```
`semi_auto_parallel_static_decorate_api.py` should be `semi_auto_parallel_dist_to_static_api.py`?
Done.
```python
np.testing.assert_allclose(dy_losses, dy2static_losses, rtol=1e-6)

# python -m paddle.distributed.launch --devices=0,1 semi_auto_parallel_static_decorate_mlp.py
```
`semi_auto_parallel_static_decorate_mlp.py` should be `semi_auto_parallel_dist_to_static_mlp.py`?
Done.
```python
def to_static(
    layer: paddle.nn.Layer,
    loader=None,
    loss=None,
    optimizer=None,
    strategy=None,
):
```
I saw in the design document that there is a `metrics` parameter. Do we need to implement `metrics`, which is not implemented here? If not, please explain the reason and update the design document.
`metrics` was not part of the original design and is not currently used, so it has been removed here. I will update the design document.
LGTM
LGTM
LGTM for `set_tests_properties(test_semi_auto_parallel_dist_to_static PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 300)`
""" | ||
DistModel is a wrapper of the network model for the use of static mode | ||
auto parallel. DistModel contains the distributed Graph of the model and | ||
offers the APIs for training, evaluation and prediction. |
Is it possible to make this docstring more understandable? For example, the term "static mode auto parallel" is very challenging to understand, even after googling the phrase.
I will rephrase the docstring and update it in the next PR.
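One possible rephrasing, as an assumption of what the follow-up wording could look like (not the final text):

```python
"""
DistModel wraps a paddle.nn.Layer for static-graph auto-parallel execution:
the dygraph model is converted into a static computation graph that is
automatically partitioned across devices. DistModel holds this distributed
graph and offers the APIs for training, evaluation and prediction.
"""
```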
```python
# Part2: DistTensor construction related APIs


def to_static(
    layer: paddle.nn.Layer,
```
I'd like to suggest not using type annotations for any of the arguments.
Will update in the next PR.
""" | ||
dist_model = DistModel(layer, loader, loss, optimizer, strategy) | ||
dist_loader = dist_model.dist_loader | ||
|
The API name `to_static` made me very confused, especially when trying to understand its relation to `paddle.jit.to_static`. From this two-line implementation, this API is basically a creator for `DistModel`, so I'd like to suggest a more intuitive name, e.g. `dist_model_creator`.
The purpose of this API is to convert a model whose parameters are distributed tensors (generated by `shard_tensor`) to static graph mode. I think `to_static` is more suitable for this function.
PR types: New features
PR changes: APIs

Description
Pcard-76459
Add the paddle.distributed.to_static API and its returned class DistModel, which convert a dygraph auto-parallel model to static mode.
paddle.distributed.to_static converts the model and data loader used in dygraph auto-parallelism to their static-mode counterparts. It returns a DistModel instance that provides APIs for static-mode auto-parallel training, evaluation and prediction, together with a DistributedDataLoader that generates the corresponding input data.
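A hedged usage sketch based on this description (`model`, `loader`, `loss_fn` and `opt` are placeholders the user defines; launch with `paddle.distributed.launch` as in the tests above):

```python
import paddle.distributed as dist

# convert the dygraph model and loader to their static-mode counterparts
dist_model, dist_loader = dist.to_static(model, loader, loss_fn, opt)

dist_model.train()
for batch_id, (image, label) in enumerate(dist_loader()):
    # one static-graph auto-parallel training step; returns the loss value
    loss = dist_model(image, label)
```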
Doc: PaddlePaddle/docs#6357