Skip to content

Commit

Permalink
fix embedding multihead (#49085)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wangzheee authored Dec 15, 2022
1 parent e577040 commit 439b2b9
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
layer = plugin_layer;
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer,
- "ManyEmbLayerNormPluginDynamic_V1",
+ "ManyEmbLayerNormVarlenPluginDynamicV1",
{output_name,
std::string("qkv_plugin_mask"),
std::string("max_seqlen_tensor")},
Expand Down
15 changes: 10 additions & 5 deletions paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
max_seqlen_tensor); // max_seqlen, eval_placeholder_3
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
- layer = plugin_layer;
+ RreplenishLayerAndOutput(
+ plugin_layer, "multihead_matmul", {output_name}, test_mode);
} else {
int head_size = hidden_out / head_number;
// [3, head_number, head_size, hidden_in] -> [head_number, 3,
Expand Down Expand Up @@ -381,7 +382,8 @@ class MultiheadMatMulOpConverter : public OpConverter {

auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
-
+ plugin_layer->setName(
+ ("CustomQKVToContextPluginDynamic: " + output_name).c_str());
// recover no_varlen output
if (!flag_varseqlen) {
std::vector<nvinfer1::ITensor*> output_transformer;
Expand All @@ -394,7 +396,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
engine_->AddDynamicPlugin(output_transformer.data(),
output_transformer.size(),
plugin);
- layer = transformer_output_layer;
+ engine_->SetITensor(output_name,
+ transformer_output_layer->getOutput(0));
+ } else {
+ engine_->SetITensor(output_name, plugin_layer->getOutput(0));
}
}
} else {
Expand Down Expand Up @@ -776,6 +781,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
new plugin::QkvToContextPluginDynamic(
hidden_in, head_number, head_size, scale, with_fp16);
layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin);
+ RreplenishLayerAndOutput(
+ layer, "multihead_matmul", {output_name}, test_mode);
}
}
} else {
Expand All @@ -785,8 +792,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
"You can use the config.SetTRTDynamicShapeInfo(...) interface to set "
"the shape information to run the dynamic shape mode."));
}
- RreplenishLayerAndOutput(
- layer, "multihead_matmul", {output_name}, test_mode);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
desc.dims.d[0] == prev.dims.d[0];
}
if (pos == nbInputs - 1) { // mask id
- return desc.type == prev.type;
+ return desc.type == mType;
}
// embedded sequence
if (pos == nbInputs) {
Expand All @@ -265,11 +265,11 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
}
// mask(HFace) or pre_layernorm_bias(MTron)
if (pos == nbInputs + 1) {
- return desc.type == prev.type;
+ return desc.type == mType;
}
// max seqlen
if (pos == nbInputs + 2) {
- return desc.type == prev.type;
+ return desc.type == mType;
}
}

Expand Down

0 comments on commit 439b2b9

Please sign in to comment.