Commit

Fix Tensor double-free problem and polish code.
GhostScreaming committed Oct 30, 2023
1 parent a024e6a commit 624648a
Showing 5 changed files with 36 additions and 8 deletions.
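
The double free appears to stem from the generated reshard code in dist_api_gen.py, which wrapped the raw dist_out pointer in a fresh std::shared_ptr<phi::distributed::DistTensor> solely to call GetDistTensorDeviceContext. A shared_ptr constructed from a raw pointer becomes an independent owner, so the tensor is deleted once by that temporary and again by its real owner. The fix changes GetDistTensorDeviceContext to take a plain DistTensor*, adjusts the call sites, and, as the polish part, moves the eager code generator's inline f-strings into named FILL_ZERO_*_TEMPLATE_BACKWARD templates. Below is a minimal, self-contained C++ sketch of the ownership bug; the types and helper names are stand-ins, not Paddle's real classes.

#include <memory>

// Stand-in for phi::distributed::DistTensor (illustrative only).
struct DistTensor {};

// Old shape of the helper: callers had to hand over a shared_ptr.
void GetDeviceContextOld(const std::shared_ptr<DistTensor>& /*t*/) {}

// New shape of the helper: a non-owning raw pointer is enough.
void GetDeviceContextNew(DistTensor* /*t*/) {}

int main() {
  // dist_out stands for an output tensor that already has an owner elsewhere.
  auto owner = std::make_shared<DistTensor>();
  DistTensor* dist_out = owner.get();

  // Bug pattern: a shared_ptr built from the raw pointer is a second,
  // unrelated owner; it deletes the tensor when it goes out of scope, and
  // `owner` deletes it again later -> double free.
  // GetDeviceContextOld(std::shared_ptr<DistTensor>(dist_out));

  // Fixed pattern: pass the raw pointer; ownership stays with `owner`.
  GetDeviceContextNew(dist_out);
  return 0;
}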
34 changes: 31 additions & 3 deletions paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -580,6 +580,24 @@ class {} : public egr::GradNodeBase {{
}}
"""

FILL_ZERO_GRAD_TEMPLATE_BACKWARD = """
if (!IsRunAutoParallel()) {{
egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}], input_metas[{fwd_position}]);
}}
"""

FILL_ZERO_PLAIN_GRAD_TEMPLATE_BACKWARD = """
if (!IsRunAutoParallel()) {{
egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}][0], input_metas[{fwd_position}][0]);
}}
"""

FILL_ZERO_OPTIONAL_PLAIN_GRAD_TEMPLATE_BACKWARD = """
if (!IsRunAutoParallel()) {{
egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[{fwd_position}][0], input_metas[{fwd_position}][0]);
}}
"""

inplace_optional_out_type_map = {
"Tensor": "paddle::optional<paddle::Tensor>&",
"std::vector<Tensor>": "paddle::optional<std::vector<paddle::Tensor>>&",
@@ -2224,12 +2242,22 @@ def GenerateNodeDefinition(
) in backward_grad_inputs_map.items():
if name in self.optional_inputs:
if IsPlainTensorType(ttype):
fill_zero_str += f"{indent}if (!IsRunAutoParallel()) {{\n{indent}{indent}egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[{fwd_position}][0], input_metas[{fwd_position}][0]);\n{indent}}}"
fill_zero_str += FILL_ZERO_OPTIONAL_PLAIN_GRAD_TEMPLATE_BACKWARD.format(
fwd_position=fwd_position
)
else:
if IsPlainTensorType(ttype):
fill_zero_str += f"{indent}if (!IsRunAutoParallel()) {{\n{indent}{indent}egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}][0], input_metas[{fwd_position}][0]);\n{indent}}}"
fill_zero_str += (
FILL_ZERO_PLAIN_GRAD_TEMPLATE_BACKWARD.format(
fwd_position=fwd_position
)
)
else:
fill_zero_str += f"{indent}if (!IsRunAutoParallel()) {{\n{indent}{indent}egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}], input_metas[{fwd_position}]);\n{indent}}}"
fill_zero_str += (
FILL_ZERO_GRAD_TEMPLATE_BACKWARD.format(
fwd_position=fwd_position
)
)

inplace_grad_input_str = ""
inplace_check_str = ""
2 changes: 1 addition & 1 deletion paddle/phi/api/lib/tensor_utils.cc
@@ -120,7 +120,7 @@ PADDLE_API std::shared_ptr<phi::distributed::DistTensor> reshard(
"However it's %s",
typeid(input.impl().get()).name()));
auto dev_ctx = phi::distributed::GetDistTensorDeviceContext(
std::static_pointer_cast<phi::distributed::DistTensor>(input.impl()));
static_cast<phi::distributed::DistTensor*>(input.impl().get()));
auto input_tensor_impl = input.impl();
std::shared_ptr<phi::distributed::DistTensor> dist_out_ptr = nullptr;
if (input_tensor_impl) {
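
For context: the std::static_pointer_cast it replaces shared ownership with input.impl() (it was not itself a double free), but the helper now takes a raw pointer, so the call site downcasts the pointee without manufacturing another shared_ptr. A minimal sketch with stand-in types for phi::TensorBase and phi::distributed::DistTensor:

#include <memory>

// Stand-in hierarchy; the real code casts a tensor base class to DistTensor.
struct TensorBase { virtual ~TensorBase() = default; };
struct DistTensor : TensorBase {};

int main() {
  std::shared_ptr<TensorBase> impl = std::make_shared<DistTensor>();

  // Old call site: shares ownership with `impl` (use count becomes 2).
  std::shared_ptr<DistTensor> owning =
      std::static_pointer_cast<DistTensor>(impl);

  // New call site: a non-owning view of the same object, matching the
  // raw-pointer parameter of GetDistTensorDeviceContext.
  DistTensor* observing = static_cast<DistTensor*>(impl.get());

  (void)owning;
  (void)observing;
  return 0;
}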
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -348,10 +348,10 @@

# 9. Reshard Partial Output to Replicated
RESHARD_P2R_SINGLE_OUTPUT_TEMPLATE = """
dev_ctx = phi::distributed::GetDistTensorDeviceContext(std::shared_ptr<phi::distributed::DistTensor>(dist_out));
dev_ctx = phi::distributed::GetDistTensorDeviceContext(dist_out);
ReshardOutputPartialAxisToReplicated(dev_ctx, dist_out);"""
RESHARD_P2R_MULTI_SINGLE_OUTPUT_TEMPLATE = """
dev_ctx = phi::distributed::GetDistTensorDeviceContext(std::shared_ptr<phi::distributed::DistTensor>(dist_out_{idx}));
dev_ctx = phi::distributed::GetDistTensorDeviceContext(dist_out_{idx});
ReshardOutputPartialAxisToReplicated(dev_ctx, dist_out_{idx});"""
UNSUPPORTED_RESHARD_OUTPUT_COMMENT_TEMPLATE = """
// API `{}` does not need to support ReshardOutput now."""
2 changes: 1 addition & 1 deletion paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
@@ -173,7 +173,7 @@ Place GetDefaultPlace() {
}

phi::DeviceContext* GetDistTensorDeviceContext(
const std::shared_ptr<phi::distributed::DistTensor>& input) {
phi::distributed::DistTensor* input) {
// TODO(GhostScreaming): pipeline parallel may create an undefined middle grad
// tensor. In such case, we need to get default place.
auto place = input && input->defined() ? input->place() : GetDefaultPlace();
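
The raw-pointer parameter keeps the existing tolerance for missing or undefined tensors noted in the TODO above: when pipeline parallel hands in an undefined middle grad tensor (or no tensor at all), the default place is used. A minimal sketch of that selection logic, with stand-in types and a hypothetical PickPlace helper in place of the real function:

#include <iostream>
#include <string>

// Stand-in for phi::distributed::DistTensor (illustrative only).
struct DistTensor {
  bool defined_ = false;
  bool defined() const { return defined_; }
  std::string place() const { return "GPU:0"; }
};

std::string GetDefaultPlace() { return "CPU"; }

// Mirrors the place selection inside GetDistTensorDeviceContext: the function
// only observes `input`, never owns it, and accepts nullptr or an undefined tensor.
std::string PickPlace(const DistTensor* input) {
  return (input && input->defined()) ? input->place() : GetDefaultPlace();
}

int main() {
  DistTensor t;
  std::cout << PickPlace(nullptr) << "\n";  // no tensor at all -> default place
  std::cout << PickPlace(&t) << "\n";       // undefined tensor -> default place
  t.defined_ = true;
  std::cout << PickPlace(&t) << "\n";       // defined tensor -> its own place
  return 0;
}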
2 changes: 1 addition & 1 deletion paddle/phi/core/distributed/auto_parallel/reshard_utils.h
@@ -41,7 +41,7 @@ bool NeedComputationClipForPP(
Place GetDefaultPlace();

phi::DeviceContext* GetDistTensorDeviceContext(
const std::shared_ptr<phi::distributed::DistTensor>& input);
phi::distributed::DistTensor* input);

int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids,
int64_t global_rank = -1);
