Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Paddle Inference] rewrite convert_to_mixed_precision #48853

Merged
merged 9 commits into from
Dec 14, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 46 additions & 30 deletions paddle/fluid/framework/ir/float_to_half_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,23 @@ bool GpuKernelSupportPrecision(
return support;
}

// True iff the variable node's proto type is one that carries a data
// type (tensor-like containers plus STRINGS/VOCAB); other var types
// have no dtype to convert.
inline bool VarNodeHasDtype(Node* var_node) {
  switch (var_node->Var()->GetType()) {
    case VarType::SELECTED_ROWS:
    case VarType::LOD_TENSOR:
    case VarType::LOD_TENSOR_ARRAY:
    case VarType::STRINGS:
    case VarType::VOCAB:
      return true;
    default:
      return false;
  }
}

// True iff `type` is a full-precision floating-point dtype (FP32/FP64).
inline bool IsFloatType(VarType::Type type) {
  return type == VarType::FP32 || type == VarType::FP64;
}

// True iff `type` is a reduced-precision floating-point dtype
// (FP16/BF16).
inline bool IsHalfType(VarType::Type type) {
  return type == VarType::BF16 || type == VarType::FP16;
}

}; // namespace

void DoInsertCastOp(Graph* graph,
Node* var_node,
Node* op_node,
Expand Down Expand Up @@ -118,23 +135,19 @@ void DoInsertCastOp(Graph* graph,
IR_NODE_UNLINK(var_node, op_node);
}

inline bool VarNodeHasDtype(Node* var_node) {
auto type = var_node->Var()->GetType();
return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
(type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
(type == VarType::VOCAB);
}

inline bool IsFloatType(VarType::Type type) {
return (type == VarType::FP64) || (type == VarType::FP32);
}

inline bool IsHalfType(VarType::Type type) {
return (type == VarType::FP16) || (type == VarType::BF16);
// Whether `op_type` may run at `precision` on `backend`.
// Ops found in `black_list` are never converted; currently only the
// GPU backend is queried for registered kernel support.
bool OpSupportPrecision(const std::string& op_type,
                        phi::Backend backend,
                        phi::DataType precision,
                        const std::unordered_set<std::string>& black_list) {
  if (black_list.count(op_type) != 0) {
    return false;
  }
  if (backend != phi::Backend::GPU) {
    return false;
  }
  return GpuKernelSupportPrecision(op_type, precision);
}

}; // namespace

// The set of ops that support fp16 calculation and are considered
// numerically-dangerous, slower and whose effects may also be observed in
// downstream ops.
Expand Down Expand Up @@ -172,10 +185,17 @@ void FloatToHalfPass::SetDefaultBlacklist() const {

void FloatToHalfPass::Init(Graph* graph) const {
keep_io_types_ = true;
if (Has("keep_io_types")) {
keep_io_types_ = Get<bool>("keep_io_types");
}
half_precision_ =
static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
SetDefaultBlacklist();
VLOG(4) << "black_list has ";
for (const auto& name : black_list_) {
VLOG(4) << " - " << name;
}

auto graph_size = graph->SubGraphsSize();
VLOG(4) << "graph size: " << graph_size;
Expand Down Expand Up @@ -235,18 +255,6 @@ void FloatToHalfPass::ApplyImpl(Graph* graph) const {
VLOG(4) << "RestoreOpOriginType done";
}

// Whether `op_type` can execute at `precision` on `backend`:
// ops on black_list_ are rejected outright; otherwise only the GPU
// backend is checked against the available kernels.
bool FloatToHalfPass::OpSupportPrecision(const std::string& op_type,
                                         phi::DataType precision,
                                         phi::Backend backend) const {
  if (black_list_.count(op_type) != 0) {
    return false;
  }
  return backend == phi::Backend::GPU &&
         GpuKernelSupportPrecision(op_type, precision);
}

void FloatToHalfPass::SetOpUniqueType() const {
int suffix = 0;
for (const auto& nodes : all_op_nodes_) {
Expand Down Expand Up @@ -328,8 +336,10 @@ void FloatToHalfPass::GetOpPrecision() const {
GetOpOriginalType(op_type) == "fetch") {
support_half = !keep_io_types_;
} else {
support_half =
OpSupportPrecision(GetOpOriginalType(op_type), half_precision_);
support_half = OpSupportPrecision(GetOpOriginalType(op_type),
phi::Backend::GPU,
half_precision_,
black_list_);
}

if (op_node->Op()->HasAttr("dtype")) {
Expand Down Expand Up @@ -555,7 +565,11 @@ bool FloatToHalfPass::OutputVarsNotConvert(Node* op_node,
void FloatToHalfPass::SetVarPrecision() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
if (op_run_half_.count(op_node->Op()->Type())) {
if (op_run_half_.count(op_node->Op()->Type()) == 0) {
continue;
}

if (GetOpOriginalType(op_node->Op()->Type()) != "feed") {
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);

Expand All @@ -572,7 +586,9 @@ void FloatToHalfPass::SetVarPrecision() const {
vars_convert_to_half_.insert(in_var_name);
}
}
}

if (GetOpOriginalType(op_node->Op()->Type()) != "fetch") {
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);

Expand Down
18 changes: 14 additions & 4 deletions paddle/fluid/framework/ir/float_to_half_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ class FloatToHalfPass : public FusePassBase {

void SetDefaultBlacklist() const;

bool OpSupportPrecision(const std::string& op_type,
phi::DataType precision,
phi::Backend backend = phi::Backend::GPU) const;

void SetOpUniqueType() const;

void RestoreOpOriginType() const;
Expand Down Expand Up @@ -93,6 +89,20 @@ class FloatToHalfPass : public FusePassBase {
mutable std::unordered_set<std::string> vars_convert_to_half_;
};

bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& black_list);

void DoInsertCastOp(Graph* graph,
Node* var_node,
Node* op_node,
proto::VarType::Type from_type,
proto::VarType::Type to_type,
framework::BlockDesc* block_desc,
int* suffix,
std::unordered_map<Node*, Node*>* cache);

} // namespace ir
} // namespace framework
} // namespace paddle
31 changes: 18 additions & 13 deletions paddle/fluid/inference/analysis/ir_pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ IRPassManager::IRPassManager(Argument *argument) {

void IRPassManager::CreatePasses(Argument *argument,
const std::vector<std::string> &passes) {
// For graph_viz_pass
std::string pre_pass;
int pass_num = 0;

for (const std::string &pass_name : passes) {
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen()));
Expand Down Expand Up @@ -86,15 +88,6 @@ void IRPassManager::CreatePasses(Argument *argument,
argument->tensorrt_tuned_dynamic_shape();
pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));

// mixed precision related
pass->Set("model_precision", new int(argument->model_precision()));
pass->Set(
"mixed_black_list",
new std::unordered_set<std::string>(argument->mixed_black_list()));
pass->Set("enable_gpu_half", new bool(argument->enable_gpu_half()));
pass->Set("mixed_precision_mode",
new int(argument->mixed_precision_mode()));

if (pass_name == "graph_viz_pass") {
std::string optim_cache_dir = argument->optim_cache_dir();
std::string dot_file_path;
Expand Down Expand Up @@ -209,10 +202,17 @@ void IRPassManager::CreatePasses(Argument *argument,
new std::vector<std::string>(argument->tensorrt_disabled_ops()));
pass->Set("trt_use_dla", new bool(argument->tensorrt_use_dla()));
pass->Set("trt_dla_core", new int(argument->tensorrt_dla_core()));

// Setting the disable_trt_plugin_fp16 to true means that TRT plugin will
// not run fp16.
pass->Set("disable_trt_plugin_fp16",
new bool(argument->disable_trt_plugin_fp16()));

// Mixed precision related.
pass->Set("model_precision", new int(argument->model_precision()));
pass->Set(
"mixed_black_list",
new std::unordered_set<std::string>(argument->mixed_black_list()));
} else if (pass_name == "dlnne_subgraph_pass") {
auto precision_mode = argument->dlnne_precision_mode();
pass->Set("min_subgraph_size",
Expand All @@ -237,8 +237,7 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("root_predictor_id", new int(argument->root_predictor_id()));
} else if (pass_name == "build_cinn_pass") {
pass->Set("is_inference_stage", new bool(argument->use_cinn_compiler()));
}
if (pass_name == "lite_subgraph_pass") {
} else if (pass_name == "lite_subgraph_pass") {
bool lite_enable_int8 =
argument->lite_precision_mode() == AnalysisConfig::Precision::kInt8;
pass->Set("program",
Expand Down Expand Up @@ -286,8 +285,7 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("nnadapter_model_cache_token",
new std::vector<std::string>(
argument->nnadapter_model_cache_token()));
}
if (pass_name == "fc_fuse_pass") {
} else if (pass_name == "fc_fuse_pass") {
pass->Set("use_gpu", new bool(argument->use_gpu()));
bool fc_mkldnn_pass = 0;
for (const std::string &pass_n : passes) {
Expand All @@ -297,6 +295,13 @@ void IRPassManager::CreatePasses(Argument *argument,
}
bool use_fc_padding = !fc_mkldnn_pass && argument->use_fc_padding();
pass->Set("use_fc_padding", new bool(use_fc_padding));
} else if (pass_name == "float_to_half_pass") {
pass->Set(
"mixed_black_list",
new std::unordered_set<std::string>(argument->mixed_black_list()));
pass->Set("enable_gpu_half", new bool(argument->enable_gpu_half()));
pass->Set("mixed_precision_mode",
new int(argument->mixed_precision_mode()));
}
pre_pass = pass_name;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ void OutputProcess(framework::ir::Graph *graph,
backend,
precision,
blacklist)) {
AddCastOp(graph,
var_node,
next_op,
framework::proto::VarType::FP32,
to_type,
&suffix,
block_desc,
&var_to_cast_op_map);
InsertCastOp(graph,
var_node,
next_op,
framework::proto::VarType::FP32,
to_type,
block_desc,
&suffix,
&var_to_cast_op_map);
var_node->Var()->SetDataType(framework::proto::VarType::FP32);
}
}
Expand Down
Loading