Skip to content

Commit

Permalink
To add RunOp method in ConstantFoldingPass
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyuqin1998 committed Dec 28, 2023
1 parent c705025 commit 274ffca
Showing 1 changed file with 23 additions and 27 deletions.
50 changes: 23 additions & 27 deletions paddle/fluid/pir/transforms/constant_folding_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,20 +126,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
pir::PatternRewriter& rewriter) const override { // NOLINT
VLOG(4) << "constant_folding_pass applys rewrite on [" << op->name()
<< "] op";
pir::Program new_program(rewriter.ir_context());
auto output_var_names =
BuildProgramFromOperation(op, &new_program, rewriter);

// execute program
for (auto output_var_name : output_var_names) {
exe_config_->skip_gc_vars.insert(output_var_name);
}
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
paddle::framework::InterpreterCore core(
place_, {}, kernel_program->block(), scope_, *exe_config_);

core.Run({});
auto output_var_names = RunOp(op, rewriter, place_);

// ParameterOp and ConstantTensorOp should be created in the top-level block
rewriter.SetInsertionPointToStart(
Expand Down Expand Up @@ -237,6 +224,27 @@ class ConstantFoldingPattern : public pir::RewritePattern {
}

protected:
std::vector<std::string> RunOp(
pir::Operation* op,
pir::PatternRewriter& rewriter,
phi::Place place) const { // NOLINT
pir::Program new_program(rewriter.ir_context());
auto output_var_names =
BuildProgramFromOperation(op, &new_program, rewriter);

// execute program
for (auto output_var_name : output_var_names) {
exe_config_->skip_gc_vars.insert(output_var_name);
}
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(&new_program, place);
paddle::framework::InterpreterCore core(
place, {}, kernel_program->block(), scope_, *exe_config_);

core.Run({});
return output_var_names;
}

std::vector<std::string> BuildProgramFromOperation(
pir::Operation* op,
pir::Program* new_program,
Expand Down Expand Up @@ -340,20 +348,8 @@ class ConstantFoldingTrainingPattern : public ConstantFoldingPattern {
pir::PatternRewriter& rewriter) const override { // NOLINT
VLOG(4) << "constant_folding_pass applys rewrite on [" << op->name()
<< "] op";
pir::Program new_program(rewriter.ir_context());
auto output_var_names =
BuildProgramFromOperation(op, &new_program, rewriter);

// execute program
for (auto output_var_name : output_var_names) {
exe_config_->skip_gc_vars.insert(output_var_name);
}
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(&new_program, phi::CPUPlace{});
paddle::framework::InterpreterCore core(
phi::CPUPlace{}, {}, kernel_program->block(), scope_, *exe_config_);

core.Run({});
auto output_var_names = RunOp(op, rewriter, phi::CPUPlace{});

// ConstantTensorOp should be created in the top-level block
rewriter.SetInsertionPointToStart(
Expand Down

0 comments on commit 274ffca

Please sign in to comment.