Skip to content

Commit

Permalink
Enabled Eager Dygraph AutoCodeGen for 500+ existing ops (#37753)
Browse files Browse the repository at this point in the history
* Handled dispensable tensors in AutoCodeGen for Eager Dygraph

* Enabled Eager Dygraph AutoCodeGen for 500+ existing ops
  • Loading branch information
jim19930609 authored Dec 2, 2021
1 parent 7094251 commit 9ecb746
Show file tree
Hide file tree
Showing 6 changed files with 538 additions and 15 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/eager/auto_code_generator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ if(WIN32)
endif()

add_custom_target(eager_codegen
COMMAND "${eager_generator_path}/eager_generator.exe" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
COMMAND "${eager_generator_path}/eager_generator.exe" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/op_list.txt"
DEPENDS ${EAGER_CODEGEN_DEPS}
VERBATIM)
else()
add_custom_target(eager_codegen
COMMAND "${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
COMMAND "${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/op_list.txt"
DEPENDS eager_generator
VERBATIM)
endif()
40 changes: 31 additions & 9 deletions paddle/fluid/eager/auto_code_generator/eager_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gflags/gflags.h>
#include <algorithm>
#include <fstream>
#include <iostream>
Expand All @@ -26,6 +27,9 @@
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"

DEFINE_bool(generate_all, false,
"Generate all operators currently registered in Paddle");

static std::unordered_set<std::string> operators_to_skip = {
"fused_elemwise_add_activation", // No Default Attr
"fused_elemwise_activation", // No Default Attr
Expand All @@ -40,12 +44,10 @@ static std::unordered_set<std::string> operators_to_skip = {
"pull_box_sparse",
"fused_attention",
"diag_v2",
};

static std::unordered_set<std::string> operators_to_codegen = {
"sigmoid", "matmul_v2", "reduce_sum", "elementwise_add",
"share_buffer", "var_conv_2d", "split"};
"transfer_dtype",
"c_split"};

static std::unordered_set<std::string> operators_to_codegen = {};
static std::unordered_set<std::string> skipped_operators = {};

namespace paddle {
Expand Down Expand Up @@ -353,7 +355,10 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
// Only handle matmul_v2 for now
VLOG(1) << "------ Analyzing Op ------: " << op_type;

if (!operators_to_codegen.count(op_type)) return false;
if (!FLAGS_generate_all) {
if (!operators_to_codegen.count(op_type)) return false;
}

if (operators_to_skip.count(op_type)) return false;

return true;
Expand Down Expand Up @@ -976,7 +981,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, outnum);
dygraph_function_args_str += arg_str;
const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::ConstructDuplicableOutput(%s) },";
"{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput(%s) },";
outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE,
output_name, outnum);
} else {
Expand Down Expand Up @@ -1253,7 +1258,7 @@ static std::string GenerateGradNodeCCContents(

if (duplicable_input_name_set.count(fwd_input_name)) {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::ConstructDuplicableOutput( "
"{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( "
"this->OutputMeta()[%d].Size() ) },";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position);
Expand Down Expand Up @@ -1639,13 +1644,30 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
} // namespace framework
} // namespace paddle

static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
std::string line;
std::ifstream op_list_file(op_list_path);
if (op_list_file.is_open()) {
while (getline(op_list_file, line)) {
operators_to_codegen.insert(line);
}
op_list_file.close();
} else {
PADDLE_THROW(
paddle::platform::errors::Fatal("Unable to open op_list.txt file"));
}
}

int main(int argc, char* argv[]) {
if (argc != 2) {
if (argc != 3) {
std::cerr << "argc must be 2" << std::endl;
return -1;
}

std::string eager_root = argv[1];
std::string op_list_path = argv[2];

CollectOperatorsToCodeGen(op_list_path);
paddle::framework::DygraphCodeGeneration(eager_root);

return 0;
Expand Down
Loading

0 comments on commit 9ecb746

Please sign in to comment.