forward backward contiguous ones #57189

Merged
77 changes: 75 additions & 2 deletions paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -81,6 +81,30 @@
"matmul_grad": {"x": "grad_y", "y": "grad_x"},
}

strided_op_list = {
"as_complex",
"as_real",
"as_strided",
"real",
"imag",
"diagonal",
"flatten",
"flatten_infer",
"reshape",
"slice",
"squeeze_infer",
"squeeze",
"strided_slice",
"strided_slice_raw",
"tensor_unfold",
"transpose",
"unbind",
"unsqueeze_infer",
"unsqueeze",
"view_shape",
"view_dtype",
}
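
These ops produce strided views that share storage with their inputs, so the generator leaves their inputs as-is; every other op gets a contiguity guard on inputs saved for backward. A minimal Python sketch of the gating predicate inlined in the SetTensorWrapper hunks below (the helper name is hypothetical; the real code repeats these checks inline):

def wraps_original_tensor(forward_api_name, for_backward, is_vector, is_optional):
    # True -> wrap the original tensor; False -> wrap a contiguous "_tmp" copy.
    return (
        forward_api_name in strided_op_list
        or for_backward     # generating a backward (higher-order) node
        or is_vector        # i.e. IsVectorTensorType(atype)
        or is_optional      # i.e. name in self.optional_inputs
    )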


#########
# Utils #
@@ -234,6 +258,9 @@ class {} : public egr::GradNodeBase {{
// Node Declaration
std::shared_ptr<{}> grad_node;

// Pre-contiguous tensor for non-strided ops, emitted when 1) require_any_grad == true; 2) the input is wrapped for backward; 3) the input is not contiguous
{}

// Set grad_node before API Call
{}

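For a hypothetical non-strided op foo(x), the placeholder above ends up holding a guard like the one below, and the tensor wrapper then captures x_tmp instead of x. A sketch of the emitted C++, held in a Python string since this file is the code generator; formatting is illustrative:

EMITTED_SKETCH = r"""
  // Pre-contiguous tensor for a non-strided op
  const auto& x_tmp =
      (require_any_grad && x.is_dense_tensor() &&
       !std::dynamic_pointer_cast<phi::DenseTensor>(x.impl())
            ->meta()
            .is_contiguous())
          ? paddle::Tensor(std::make_shared<phi::DenseTensor>(
                std::move(paddle::experimental::Trans2Contiguous(*(
                    std::dynamic_pointer_cast<phi::DenseTensor>(x.impl()))))))
          : x;

  // Set grad_node before API Call
  grad_node->SetTensorWrapperx(x_tmp);
"""
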
@@ -409,6 +436,7 @@ class {} : public egr::GradNodeBase {{
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/api/lib/data_transform.h"

PHI_DECLARE_bool(check_nan_inf);
PHI_DECLARE_string(tensor_operants_mode);
@@ -984,6 +1012,7 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False):
set_attributes_list.append(set_attributes)
set_attributes_str = "\n".join(set_attributes_list)

need_pre_contiguous_set = set()
# SetTensorWrappers
set_input_tensor_wrappers_list = []
set_output_tensor_wrappers_list = []
@@ -1007,12 +1036,30 @@
{"indent": indent, "name": name}
)
else:
set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});"
if (
(forward_api_name in strided_op_list)
or for_backward
or IsVectorTensorType(atype)
or (name in self.optional_inputs)
):
set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});"
else:
need_pre_contiguous_set.add(name)
set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name}_tmp);"
else:
if is_inplace_input:
set_tensor_wrappers = f"{indent}auto {name}_clone = paddle::experimental::assign({name});\n{indent}grad_node->SetTensorWrapper{name}({name}_clone);"
else:
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});"
if (
(forward_api_name in strided_op_list)
or for_backward
or IsVectorTensorType(atype)
or (name in self.optional_inputs)
):
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});"
else:
need_pre_contiguous_set.add(name)
set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}_tmp);"
set_input_tensor_wrappers_list.append(set_tensor_wrappers)
else: # Forward's output as backward's input
if num_fwd_outputs > 1:
Expand All @@ -1032,6 +1079,24 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False):
set_output_tensor_wrappers_list
)

if (forward_api_name in strided_op_list) or for_backward:
self.inputs_call_list_tmp = None
self.node_creation_pre_contiguous_str = ""
else:
self.inputs_call_list_tmp = self.inputs_call_list
pre_contiguous_list = []
for name, (ttype, pos) in forward_inputs_position_map.items():
if name in need_pre_contiguous_set:
pre_contiguous_list.append(
f"{indent}const auto& {name}_tmp = (require_any_grad && {name}.is_dense_tensor() && !std::dynamic_pointer_cast<phi::DenseTensor>({name}.impl())->meta().is_contiguous()) ? paddle::Tensor(std::make_shared<phi::DenseTensor>(std::move(paddle::experimental::Trans2Contiguous(*(std::dynamic_pointer_cast<phi::DenseTensor>({name}.impl())))))) : {name};"
)
self.inputs_call_list_tmp[pos] = (
self.inputs_call_list_tmp[pos] + '_tmp'
)
self.node_creation_pre_contiguous_str = "\n".join(
pre_contiguous_list
)

# SetGradOutMeta & SetEdges
grad_node_out_list = []
set_grad_out_meta_list = []
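
A note on the block just above: self.inputs_call_list_tmp = self.inputs_call_list binds an alias, not a copy, so appending "_tmp" to an entry rewrites the shared list in place. A minimal self-contained repro of that pattern (assumed values):

inputs_call_list = ["x", "y"]
need_pre_contiguous_set = {"x"}
inputs_call_list_tmp = inputs_call_list  # alias, not a copy
for pos, name in enumerate(["x", "y"]):
    if name in need_pre_contiguous_set:
        inputs_call_list_tmp[pos] = inputs_call_list_tmp[pos] + "_tmp"
print(inputs_call_list)  # ['x_tmp', 'y'] -- the aliased original mutates too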
@@ -1470,6 +1535,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
inputs_args_declaration_str = ", ".join(inputs_args_declaration_list)
inputs_args_definition_str = ", ".join(inputs_args_definition_list)
inputs_call_args_str = ", ".join(inputs_call_list)
self.inputs_call_list = inputs_call_list

# Forward Full Logic
function_name = forward_api_name
@@ -1656,6 +1722,12 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
node_creation_str = self.node_creation_str
node_creation_before_call_str = self.node_creation_before_call_str
node_creation_after_call_str = self.node_creation_after_call_str
node_creation_pre_contiguous_str = (
self.node_creation_pre_contiguous_str
)
if self.inputs_call_list_tmp is not None:
inputs_call_args_str_tmp = ", ".join(self.inputs_call_list_tmp)
forward_call_str = f"{indent}{api_out_type} api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str_tmp});"

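Net effect of the branch above, for a hypothetical op foo(x, y) where only x was added to need_pre_contiguous_set (assumed names; a sketch, not the generator's exact template):

inputs_call_list_tmp = ["x_tmp", "y"]
inputs_call_args_str_tmp = ", ".join(inputs_call_list_tmp)
print(f"auto api_result = paddle::experimental::foo({inputs_call_args_str_tmp});")
# auto api_result = paddle::experimental::foo(x_tmp, y);
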
dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n"
forward_ad_function_name = GetDygraphForwardFunctionName(
@@ -1767,6 +1839,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
before_log_str,
compute_require_grad_args_str,
self.grad_node_name,
node_creation_pre_contiguous_str,
node_creation_before_call_str,
forward_call_str,
check_nan_inf_str,