From f027b2ad964ea84051c39703e60738b4e10c811e Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Fri, 25 Mar 2022 16:27:20 +0800 Subject: [PATCH] [Refactor] refactored eager_gen.py PR #2 (#40907) --- .../final_state_generator/codegen_utils.py | 10 +- .../final_state_generator/eager_gen.py | 109 ++++++++++++------ 2 files changed, 78 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py index 56cbc05b1a9834..4699fe73552c7e 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py @@ -50,6 +50,10 @@ ############################# ### File Reader Helpers ### ############################# +def AssertMessage(lhs_str, rhs_str): + return f"lhs: {lhs_str}, rhs: {rhs_str}" + + def ReadFwdFile(filepath): f = open(filepath, 'r') contents = yaml.load(f, Loader=yaml.FullLoader) @@ -62,10 +66,10 @@ def ReadBwdFile(filepath): contents = yaml.load(f, Loader=yaml.FullLoader) ret = {} for content in contents: + assert 'backward_api' in content.keys(), AssertMessage('backward_api', + content.keys()) if 'backward_api' in content.keys(): api_name = content['backward_api'] - else: - assert False ret[api_name] = content f.close() @@ -225,7 +229,7 @@ def ParseYamlReturns(string): ), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping." ret_type = yaml_types_mapping[ret_type] - assert "Tensor" in ret_type + assert "Tensor" in ret_type, AssertMessage("Tensor", ret_type) ret_name = RemoveSpecialSymbolsInName(ret_name) returns_list.append([ret_name, ret_type, i]) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index fd750c0d07369e..cd59211f02f3bb 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -16,6 +16,7 @@ import re import argparse import os +import logging from codegen_utils import core_ops_returns_info, core_ops_args_info, core_ops_args_type_info from codegen_utils import yaml_types_mapping from codegen_utils import ReadFwdFile, ReadBwdFile @@ -30,6 +31,7 @@ from codegen_utils import ParseYamlForward, ParseYamlBackward from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase from codegen_utils import ops_to_fill_zero_for_empty_grads +from codegen_utils import AssertMessage ########### @@ -398,14 +400,21 @@ def DygraphYamlValidationCheck(self): forward_api_contents = self.forward_api_contents grad_api_contents = self.grad_api_contents - assert 'api' in forward_api_contents.keys() - assert 'args' in forward_api_contents.keys() - assert 'output' in forward_api_contents.keys() - assert 'backward' in forward_api_contents.keys() - - assert 'args' in grad_api_contents.keys() - assert 'output' in grad_api_contents.keys() - assert 'forward' in grad_api_contents.keys() + assert 'api' in forward_api_contents.keys( + ), "Unable to find \"api\" in api.yaml" + assert 'args' in forward_api_contents.keys( + ), "Unable to find \"args\" in api.yaml" + assert 'output' in forward_api_contents.keys( + ), "Unable to find \"output\" in api.yaml" + assert 'backward' in forward_api_contents.keys( + ), "Unable to find \"backward\" in api.yaml" + + assert 'args' in grad_api_contents.keys( + ), "Unable to find \"args\" in backward.yaml" + assert 'output' in grad_api_contents.keys( + ), "Unable to find \"output\" in backward.yaml" + assert 'forward' in grad_api_contents.keys( + ), "Unable to find \"forward\" in backward.yaml" def ForwardsValidationCheck(self): forward_inputs_list = self.forward_inputs_list @@ -424,8 +433,10 @@ def ForwardsValidationCheck(self): orig_input_type = orig_forward_inputs_list[i][1] orig_input_pos = orig_forward_inputs_list[i][2] - assert forward_input_type == orig_input_type - assert forward_input_pos == orig_input_pos + assert forward_input_type == orig_input_type, AssertMessage( + forward_input_type, orig_input_type) + assert forward_input_pos == orig_input_pos, AssertMessage( + forward_input_pos, orig_input_pos) for i in range(len(forward_attrs_list)): orig_attr_name = orig_forward_attrs_list[i][0] @@ -436,9 +447,12 @@ def ForwardsValidationCheck(self): forward_attr_type = forward_attrs_list[i][1] forward_attr_default = forward_attrs_list[i][2] forward_attr_pos = forward_attrs_list[i][3] - assert orig_attr_type == forward_attr_type - assert orig_attr_default == forward_attr_default - assert orig_attr_pos == forward_attr_pos + assert orig_attr_type == forward_attr_type, AssertMessage( + orig_attr_type, forward_attr_type) + assert orig_attr_default == forward_attr_default, AssertMessage( + orig_attr_default, forward_attr_default) + assert orig_attr_pos == forward_attr_pos, AssertMessage( + orig_attr_pos, forward_attr_pos) for i in range(len(forward_returns_list)): orig_return_type = orig_forward_returns_list[i][1] @@ -446,8 +460,10 @@ def ForwardsValidationCheck(self): forward_return_type = forward_returns_list[i][1] forward_return_pos = forward_returns_list[i][2] - assert orig_return_type == forward_return_type - assert orig_return_pos == forward_return_pos + assert orig_return_type == forward_return_type, AssertMessage( + orig_return_type, forward_return_type) + assert orig_return_pos == forward_return_pos, AssertMessage( + orig_return_pos, forward_return_pos) # Check Order: Inputs, Attributes max_input_position = -1 @@ -456,7 +472,8 @@ def ForwardsValidationCheck(self): max_attr_position = -1 for _, _, _, pos in forward_attrs_list: - assert pos > max_input_position + assert pos > max_input_position, AssertMessage(pos, + max_input_position) max_attr_position = max(max_attr_position, pos) def BackwardValidationCheck(self): @@ -471,12 +488,14 @@ def BackwardValidationCheck(self): max_grad_tensor_position = -1 for _, (_, _, pos) in backward_grad_inputs_map.items(): - assert pos > max_fwd_input_position + assert pos > max_fwd_input_position, AssertMessage( + pos, max_grad_tensor_position) max_grad_tensor_position = max(max_grad_tensor_position, pos) max_attr_position = -1 for _, _, _, pos in backward_attrs_list: - assert pos > max_grad_tensor_position + assert pos > max_grad_tensor_position, AssertMessage( + pos, max_grad_tensor_position) max_attr_position = max(max_attr_position, pos) def IntermediateValidationCheck(self): @@ -491,7 +510,8 @@ def IntermediateValidationCheck(self): len(forward_returns_list)) for ret_name, _, pos in forward_returns_list: if ret_name in intermediate_outputs: - assert pos in intermediate_positions + assert pos in intermediate_positions, AssertMessage( + pos, intermediate_positions) def CollectBackwardInfo(self): forward_api_contents = self.forward_api_contents @@ -505,9 +525,12 @@ def CollectBackwardInfo(self): self.backward_inputs_list, self.backward_attrs_list, self.backward_returns_list = ParseYamlBackward( backward_args_str, backward_returns_str) - print("Parsed Backward Inputs List: ", self.backward_inputs_list) - print("Prased Backward Attrs List: ", self.backward_attrs_list) - print("Parsed Backward Returns List: ", self.backward_returns_list) + + logging.info( + f"Parsed Backward Inputs List: {self.backward_inputs_list}") + logging.info(f"Prased Backward Attrs List: {self.backward_attrs_list}") + logging.info( + f"Parsed Backward Returns List: {self.backward_returns_list}") def CollectForwardInfoFromBackwardContents(self): @@ -530,7 +553,9 @@ def SlotNameMatching(self): backward_fwd_name = FindForwardName(backward_input_name) if backward_fwd_name: # Grad Input - assert backward_fwd_name in forward_outputs_position_map.keys() + assert backward_fwd_name in forward_outputs_position_map.keys( + ), AssertMessage(backward_fwd_name, + forward_outputs_position_map.keys()) matched_forward_output_type = forward_outputs_position_map[ backward_fwd_name][0] matched_forward_output_pos = forward_outputs_position_map[ @@ -556,7 +581,7 @@ def SlotNameMatching(self): backward_input_type, False, backward_input_pos ] else: - assert False, backward_input_name + assert False, f"Cannot find {backward_input_name} in forward position map" for backward_output in backward_returns_list: backward_output_name = backward_output[0] @@ -564,9 +589,10 @@ def SlotNameMatching(self): backward_output_pos = backward_output[2] backward_fwd_name = FindForwardName(backward_output_name) - assert backward_fwd_name is not None + assert backward_fwd_name is not None, f"Detected {backward_fwd_name} = None" assert backward_fwd_name in forward_inputs_position_map.keys( - ), f"Unable to find {backward_fwd_name} in forward inputs" + ), AssertMessage(backward_fwd_name, + forward_inputs_position_map.keys()) matched_forward_input_type = forward_inputs_position_map[ backward_fwd_name][0] @@ -577,12 +603,15 @@ def SlotNameMatching(self): backward_output_type, matched_forward_input_pos, backward_output_pos ] - print("Generated Backward Fwd Input Map: ", - self.backward_forward_inputs_map) - print("Generated Backward Grad Input Map: ", - self.backward_grad_inputs_map) - print("Generated Backward Grad Output Map: ", - self.backward_grad_outputs_map) + logging.info( + f"Generated Backward Fwd Input Map: {self.backward_forward_inputs_map}" + ) + logging.info( + f"Generated Backward Grad Input Map: {self.backward_grad_inputs_map}" + ) + logging.info( + f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}" + ) def GenerateNodeDeclaration(self): forward_op_name = self.forward_api_name @@ -642,7 +671,7 @@ def GenerateNodeDeclaration(self): set_tensor_wrapper_methods_str, set_attribute_methods_str, tensor_wrapper_members_str, attribute_members_str) - print("Generated Node Declaration: ", self.node_declaration_str) + logging.info(f"Generated Node Declaration: {self.node_declaration_str}") def GenerateNodeDefinition(self): namespace = self.namespace @@ -710,7 +739,7 @@ def GenerateNodeDefinition(self): grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace, backward_api_name, grad_api_args_str, returns_str) - print("Generated Node Definition: ", self.node_definition_str) + logging.info(f"Generated Node Definition: {self.node_definition_str}") def GenerateForwardDefinition(self, is_inplaced): namespace = self.namespace @@ -813,8 +842,10 @@ def GenerateForwardDefinition(self, is_inplaced): dygraph_event_str, node_creation_str, returns_str) self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n" - print("Generated Forward Definition: ", self.forward_definition_str) - print("Generated Forward Declaration: ", self.forward_declaration_str) + logging.info( + f"Generated Forward Definition: {self.forward_definition_str}") + logging.info( + f"Generated Forward Declaration: {self.forward_declaration_str}") def GenerateNodeCreationCodes(self, forward_call_str): forward_api_name = self.forward_api_name @@ -921,7 +952,8 @@ def GenerateNodeCreationCodes(self, forward_call_str): else: if num_fwd_outputs > 1: # Aligned with forward output position - assert name in forward_outputs_position_map.keys() + assert name in forward_outputs_position_map.keys( + ), AssertMessage(name, forward_outputs_position_map.keys()) fwd_output_pos = forward_outputs_position_map[name][1] tw_name = f"std::get<{fwd_output_pos}>(api_result)" else: @@ -1114,7 +1146,8 @@ def GetBackwardAPIContents(self, forward_api_contents): if 'backward' not in forward_api_contents.keys(): return None backward_api_name = forward_api_contents['backward'] - assert backward_api_name in grad_api_dict.keys() + assert backward_api_name in grad_api_dict.keys(), AssertMessage( + backward_api_name, grad_api_dict.keys()) backward_api_contents = grad_api_dict[backward_api_name] return backward_api_contents