Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg committed Apr 5, 2022
1 parent 1ffaf63 commit 518d60d
Showing 1 changed file with 11 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1257,11 +1257,18 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
name)

is_optional = (name in self.optional_inputs)
optional_suffix = '_optional' if is_optional else ''
tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name}{optional_suffix} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());\n"
tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());"
if is_optional:
tensor_wrapper_recover_str = tensor_wrapper_recover_str + f"{indent}auto {transformed_tensor_name} = {transformed_tensor_name}{optional_suffix}.initialized() ? paddle::make_optional<const paddle::experimental::Tensor&>({transformed_tensor_name}{optional_suffix}) : paddle::none;"
grad_api_args[grad_api_position] = transformed_tensor_name
tensor_wrapper_recover_str += "\n" + CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE.format(
transformed_tensor_name, transformed_tensor_name,
transformed_tensor_name, transformed_tensor_name)

grad_api_args[
grad_api_position] = transformed_tensor_name + "_optional"

else:
grad_api_args[grad_api_position] = transformed_tensor_name

get_grad_in_args_list.append(tensor_wrapper_recover_str)

# Grad Ins from grads
Expand Down

1 comment on commit 518d60d

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.