Skip to content

Commit

Permalink
Enabled Eager Dygraph AutoCodeGen for 500+ existing ops
Browse files Browse the repository at this point in the history
  • Loading branch information
jim19930609 committed Dec 1, 2021
1 parent c1ac58b commit 361742c
Show file tree
Hide file tree
Showing 6 changed files with 538 additions and 15 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/eager/auto_code_generator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ if(WIN32)
endif()

add_custom_target(eager_codegen
COMMAND "${eager_generator_path}/eager_generator.exe" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
COMMAND "${eager_generator_path}/eager_generator.exe" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/op_list.txt"
DEPENDS ${EAGER_CODEGEN_DEPS}
VERBATIM)
else()
add_custom_target(eager_codegen
COMMAND "${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
COMMAND "${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/op_list.txt"
DEPENDS eager_generator
VERBATIM)
endif()
40 changes: 31 additions & 9 deletions paddle/fluid/eager/auto_code_generator/eager_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gflags/gflags.h>
#include <algorithm>
#include <fstream>
#include <iostream>
Expand All @@ -26,6 +27,9 @@
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"

DEFINE_bool(generate_all, false,
"Generate all operators currently registered in Paddle");

static std::unordered_set<std::string> operators_to_skip = {
"fused_elemwise_add_activation", // No Default Attr
"fused_elemwise_activation", // No Default Attr
Expand All @@ -40,12 +44,10 @@ static std::unordered_set<std::string> operators_to_skip = {
"pull_box_sparse",
"fused_attention",
"diag_v2",
};

static std::unordered_set<std::string> operators_to_codegen = {
"sigmoid", "matmul_v2", "reduce_sum", "elementwise_add",
"share_buffer", "var_conv_2d", "split"};
"transfer_dtype",
"c_split"};

static std::unordered_set<std::string> operators_to_codegen = {};
static std::unordered_set<std::string> skipped_operators = {};

namespace paddle {
Expand Down Expand Up @@ -353,7 +355,10 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
// Only handle matmul_v2 for now
VLOG(1) << "------ Analyzing Op ------: " << op_type;

if (!operators_to_codegen.count(op_type)) return false;
if (!FLAGS_generate_all) {
if (!operators_to_codegen.count(op_type)) return false;
}

if (operators_to_skip.count(op_type)) return false;

return true;
Expand Down Expand Up @@ -976,7 +981,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, outnum);
dygraph_function_args_str += arg_str;
const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::ConstructDuplicableOutput(%s) },";
"{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput(%s) },";
outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE,
output_name, outnum);
} else {
Expand Down Expand Up @@ -1253,7 +1258,7 @@ static std::string GenerateGradNodeCCContents(

if (duplicable_input_name_set.count(fwd_input_name)) {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::ConstructDuplicableOutput( "
"{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( "
"this->OutputMeta()[%d].Size() ) },";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position);
Expand Down Expand Up @@ -1639,13 +1644,30 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
} // namespace framework
} // namespace paddle

static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
std::string line;
std::ifstream op_list_file(op_list_path);
if (op_list_file.is_open()) {
while (getline(op_list_file, line)) {
operators_to_codegen.insert(line);
}
op_list_file.close();
} else {
PADDLE_THROW(
paddle::platform::errors::Fatal("Unable to open op_list.txt file"));
}
}

int main(int argc, char* argv[]) {
if (argc != 2) {
if (argc != 3) {
std::cerr << "argc must be 2" << std::endl;
return -1;
}

std::string eager_root = argv[1];
std::string op_list_path = argv[2];

CollectOperatorsToCodeGen(op_list_path);
paddle::framework::DygraphCodeGeneration(eager_root);

return 0;
Expand Down
Loading

1 comment on commit 361742c

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.