[XPU] Add embedding plugin #56488
Changes from 23 commits
@@ -0,0 +1,133 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>

#include "glog/logging.h"

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

struct CastEmbeddingTransIdsToInt32Pattern : public PatternBase {
  CastEmbeddingTransIdsToInt32Pattern(PDPattern* pattern,
                                      const std::string& name_scope);
  // declare operator node's name
  PATTERN_DECL_NODE(cast);
  PATTERN_DECL_NODE(embedding);
  // declare variable node's name
  PATTERN_DECL_NODE(cast_x);
  PATTERN_DECL_NODE(embedding_ids);
  PATTERN_DECL_NODE(embedding_w);
  PATTERN_DECL_NODE(embedding_out);
};

CastEmbeddingTransIdsToInt32Pattern::CastEmbeddingTransIdsToInt32Pattern(
    PDPattern* pattern, const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  auto cast = pattern->NewNode(cast_repr())->assert_is_op("cast");
  auto cast_x = pattern->NewNode(cast_x_repr())
                    ->assert_is_op_input("cast", "X")
                    ->assert_var_not_persistable()
                    ->AsInput();
  auto embedding_ids = pattern->NewNode(embedding_ids_repr())
                           ->assert_is_op_output("cast", "Out")
                           ->assert_is_op_input("lookup_table_v2", "Ids")
                           ->assert_has_n_outputs(1);
  cast->LinksFrom({cast_x}).LinksTo({embedding_ids});
  auto embedding_w = pattern->NewNode(embedding_w_repr())
                         ->assert_is_op_input("lookup_table_v2", "W");
  auto embedding =
      pattern->NewNode(embedding_repr())->assert_is_op("lookup_table_v2");
  auto embedding_out = pattern->NewNode(embedding_out_repr())
                           ->assert_is_op_output("lookup_table_v2", "Out")
                           ->AsOutput();
  embedding->LinksFrom({embedding_ids, embedding_w}).LinksTo({embedding_out});
}
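
// The pattern above matches the following subgraph; the cast's output must
// be consumed only by the Ids input of lookup_table_v2:
//
//   cast_x (non-persistable)
//        |
//      cast
//        |
//   embedding_ids        embedding_w
//          \              /
//      lookup_table_v2 (embedding)
//              |
//        embedding_out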

}  // namespace patterns

class CastEmbeddingTransIdsToInt32Pass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  const std::string name_scope_{"cast_embedding_trans_ids_to_int32_pass"};
};

void CastEmbeddingTransIdsToInt32Pass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);

  GraphPatternDetector gpd;
  patterns::CastEmbeddingTransIdsToInt32Pattern pattern(gpd.mutable_pattern(),
                                                        name_scope_);
  int found_subgraph_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle CastEmbeddingTransIdsToInt32Pass";
    GET_IR_NODE(cast);
    GET_IR_NODE(embedding);
    GET_IR_NODE(embedding_ids);
    auto* block = cast->Op()->Block();
    auto cast_node_attr_in_dtype = cast->Op()->GetAttrIfExists<int>("in_dtype");
    auto cast_node_attr_out_dtype =
        cast->Op()->GetAttrIfExists<int>("out_dtype");
    // Only rewrite a cast that produces int64 ids from fp32 input; skip if
    // either dtype differs (the original `&&` here would let unrelated casts
    // through).
    if (cast_node_attr_in_dtype !=
            static_cast<int>(paddle::framework::proto::VarType::FP32) ||
        cast_node_attr_out_dtype !=
            static_cast<int>(paddle::framework::proto::VarType::INT64)) {
      return;
    }
    cast->Op()->SetAttr(
        "out_dtype",
        static_cast<int>(paddle::framework::proto::VarType::INT32));
    embedding_ids->Var()->SetDataType(paddle::framework::proto::VarType::INT32);
    cast->Op()->Flush();  // the cast op's attribute was modified above
    embedding->Op()->Flush();
    found_subgraph_count++;
  };
  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(cast_embedding_trans_ids_to_int32_pass,
              paddle::framework::ir::CastEmbeddingTransIdsToInt32Pass);

REGISTER_PASS_CAPABILITY(cast_embedding_trans_ids_to_int32_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().LE(
            "lookup_table_v2", 1));
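In effect, the pass above narrows the ids path of a matched cast + lookup_table_v2 pair from int64 to int32. A schematic of the rewrite (a sketch of the effect, not lines from the diff):

// Before: x --> cast(out_dtype=INT64) --> ids(int64) --> lookup_table_v2(W) --> out
// After:  x --> cast(out_dtype=INT32) --> ids(int32) --> lookup_table_v2(W) --> out

This pairs with the kernel change in the next file: when PADDLE_WITH_XPU_PLUGIN is defined, the plugin embedding kernel consumes int32 ids directly, so the device-side int32-to-int64 cast performed by the stock kernel path is avoided.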
@@ -44,18 +44,6 @@ void EmbeddingKernel(const Context &ctx,
   auto *table = table_t->data<T>();
   auto *output = dev_ctx.template Alloc<T>(output_t);
 
-  xpu::ctx_guard RAII_GUARD(ctx.x_context());
-  const int64_t *ids;
-  if (ids_t->dtype() == phi::DataType::INT64) {
-    ids = ids_t->data<int64_t>();
-  } else {
-    int64_t *ids_tt = RAII_GUARD.alloc_l3_or_gm<int64_t>(ids_t->numel());
-    int r = xpu::cast<int32_t, int64_t>(
-        ctx.x_context(), ids_t->data<int>(), ids_tt, ids_t->numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-    ids = reinterpret_cast<const int64_t *>(ids_tt);
-  }
-
   PADDLE_ENFORCE_EQ(
       ids_numel <= std::numeric_limits<int32_t>::max(),
       true,
@@ -68,15 +56,57 @@
   size_t xm = table_t->dims()[0];
   size_t n = table_t->dims()[1];
 
-  int r = xpu::embedding<XPUType>(dev_ctx.x_context(),
-                                  reinterpret_cast<const XPUType *>(table),
-                                  ids,
-                                  reinterpret_cast<XPUType *>(output),
-                                  xm,
-                                  n,
-                                  ym,
-                                  padding_idx);
-
+  int r;
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+  if (ids_t->dtype() == phi::DataType::INT64) {
+#ifndef PADDLE_WITH_XPU_PLUGIN
+    r = xpu::embedding<XPUType, int64_t>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType *>(table),
+        ids_t->data<int64_t>(),
+        reinterpret_cast<XPUType *>(output),
+        xm,
+        n,
+        ym,
+        padding_idx);
+#else
+    r = xpu::plugin::embedding_tiny_dict<XPUType, int64_t>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType *>(table),
+        ids_t->data<int64_t>(),
+        reinterpret_cast<XPUType *>(output),
+        xm,
+        n,
+        ym,
+        padding_idx);
+#endif

> Reviewer (on the embedding_tiny_dict call): Let's unify the name as fast_embedding, and have the accelerated versions for the various cases dispatched inside the fast_embedding wrapper.
> Author: "fast" doesn't seem like a good name to me; it says nothing about the kernel's specific scenario. This kernel is optimized for small vocabularies, so the current name fits better. With a single plugin, renaming it to fast would be fine, but if other embedding scenarios need optimizing later, naming them becomes awkward.

+  } else {
+#ifndef PADDLE_WITH_XPU_PLUGIN
+    int64_t *ids_tt = RAII_GUARD.alloc_l3_or_gm<int64_t>(ids_t->numel());

> Reviewer (on the int32-to-int64 cast): Doesn't xpu::embedding support int32 indices?
> Author: This is the previous phi kernel's logic; the kernel inserts a cast op to convert int32 to int64. Kunlun 2 supports int32 as far as I can tell, but I'm not sure about Kunlun 1, so for compatibility I didn't change the original logic.

+    r = xpu::cast<int32_t, int64_t>(
+        ctx.x_context(), ids_t->data<int>(), ids_tt, ids_t->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    const int64_t *ids = reinterpret_cast<const int64_t *>(ids_tt);
+    r = xpu::embedding<XPUType>(dev_ctx.x_context(),
+                                reinterpret_cast<const XPUType *>(table),
+                                ids,
+                                reinterpret_cast<XPUType *>(output),
+                                xm,
+                                n,
+                                ym,
+                                padding_idx);
+#else
+    r = xpu::plugin::embedding_tiny_dict<XPUType, int>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType *>(table),
+        ids_t->data<int>(),
+        reinterpret_cast<XPUType *>(output),
+        xm,
+        n,
+        ym,
+        padding_idx);
+#endif
+  }
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
 }
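
The naming exchange above sketches a design: one fast_embedding entry point that picks the right specialization internally. A minimal sketch of such a wrapper, assuming the signatures shown in this diff; fast_embedding itself and the kTinyDictRows cutoff are hypothetical, not part of the PR:

// Hypothetical dispatch wrapper (per the review suggestion): route small
// vocabularies to the specialized plugin kernel and fall back to the stock
// xdnn embedding otherwise.
template <typename T, typename TID>
int fast_embedding(xpu::Context* ctx,
                   const T* table,   // [xm, n] embedding table
                   const TID* ids,   // [ym] lookup indices
                   T* out,           // [ym, n] gathered rows
                   size_t xm,
                   size_t n,
                   size_t ym,
                   int64_t padding_idx) {
  constexpr size_t kTinyDictRows = 8192;  // assumed cutoff, not a measured value
  if (xm <= kTinyDictRows) {
    return xpu::plugin::embedding_tiny_dict<T, TID>(
        ctx, table, ids, out, xm, n, ym, padding_idx);
  }
  return xpu::embedding<T, TID>(ctx, table, ids, out, xm, n, ym, padding_idx);
}

The author's counterpoint still holds in this sketch: the scenario-specific name survives as the plugin symbol, and only the public entry point is generic.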
> Reviewer: In the cast + embedding combination this case shouldn't really come up; embedding supports int32, so in theory the situation you describe won't occur.
> Author: I'll change it anyway to guard against a model hitting this case; none of the current models do.
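
This exchange appears to motivate the dtype guard in ApplyImpl above. Restated as a standalone predicate to make the intent concrete (a sketch; ShouldRewrite is not a name from the PR):

// Rewrite only a cast that turns fp32 input into int64 ids; any other cast
// feeding lookup_table_v2 is left untouched so the graph stays valid.
bool ShouldRewrite(int in_dtype, int out_dtype) {
  using VarType = paddle::framework::proto::VarType;
  return in_dtype == static_cast<int>(VarType::FP32) &&
         out_dtype == static_cast<int>(VarType::INT64);
}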