[XPU] Add embedding plugin #56488

Merged: 26 commits, Aug 24, 2023

Changes from 23 commits

Commits (26)
ac8c61f
delete repeat ops: gather,squeeze,unsqueeze
csy0225 Jul 12, 2023
e82f825
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
csy0225 Jul 13, 2023
e2fea8c
add ut
csy0225 Jul 18, 2023
a965eed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
csy0225 Jul 18, 2023
fefe3bc
add transpose + matmul fuse
csy0225 Jul 28, 2023
365af8b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
csy0225 Jul 28, 2023
00f264e
add ut
csy0225 Jul 28, 2023
fda5c8a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
csy0225 Jul 31, 2023
c28e0e6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
csy0225 Aug 2, 2023
73412c8
add conv2d trans filter case: dilation > 1
csy0225 Aug 4, 2023
ed0545a
merge develop into branch
csy0225 Aug 4, 2023
f6aefb0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
csy0225 Aug 8, 2023
cef35c0
merge develop into branch
csy0225 Aug 8, 2023
ec44226
fix transfilter in fp16 mode
csy0225 Aug 8, 2023
204bf1f
add unit test
csy0225 Aug 9, 2023
521d2f2
support fp16 trans
csy0225 Aug 9, 2023
d807dca
fix comment
csy0225 Aug 9, 2023
b49efb4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
csy0225 Aug 17, 2023
d557e8f
add xpu embedding tiny dict plugin
csy0225 Aug 21, 2023
659d682
merge develop into branch
heavengate Aug 21, 2023
a6af0d6
add cast pass
heavengate Aug 22, 2023
597127e
fix codestyle
heavengate Aug 22, 2023
6a08122
add unit test
heavengate Aug 22, 2023
67b90c2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
heavengate Aug 22, 2023
3ef5492
fix comment
heavengate Aug 23, 2023
aab6c3b
fix comment
heavengate Aug 23, 2023
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -239,6 +239,8 @@ if(WITH_XPU)
  pass_library(cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS
               ${XPU_PASS_DEPS})
  pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
  pass_library(cast_embedding_trans_ids_to_int32_pass inference DIR xpu DEPS
               ${XPU_PASS_DEPS})
  pass_library(conv1d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
  pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
  pass_library(redundant_unsqueeze_squeeze_elimination_pass inference DIR xpu
133 changes: 133 additions & 0 deletions paddle/fluid/framework/ir/xpu/cast_embedding_trans_ids_to_int32_pass.cc
@@ -0,0 +1,133 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>

#include "glog/logging.h"

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace phi {
class DenseTensor;
} // namespace phi

namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

struct CastEmbeddingTransIdsToInt32Pattern : public PatternBase {
  CastEmbeddingTransIdsToInt32Pattern(PDPattern* pattern,
                                      const std::string& name_scope);
  // declare operator node's name
  PATTERN_DECL_NODE(cast);
  PATTERN_DECL_NODE(embedding);
  // declare variable node's name
  PATTERN_DECL_NODE(cast_x);
  PATTERN_DECL_NODE(embedding_ids);
  PATTERN_DECL_NODE(embedding_w);
  PATTERN_DECL_NODE(embedding_out);
};

CastEmbeddingTransIdsToInt32Pattern::CastEmbeddingTransIdsToInt32Pattern(
    PDPattern* pattern, const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  auto cast = pattern->NewNode(cast_repr())->assert_is_op("cast");
  auto cast_x = pattern->NewNode(cast_x_repr())
                    ->assert_is_op_input("cast", "X")
                    ->assert_var_not_persistable()
                    ->AsInput();
  auto embedding_ids = pattern->NewNode(embedding_ids_repr())
                           ->assert_is_op_output("cast", "Out")
                           ->assert_is_op_input("lookup_table_v2", "Ids")
                           ->assert_has_n_outputs(1);
  cast->LinksFrom({cast_x}).LinksTo({embedding_ids});
  auto embedding_w = pattern->NewNode(embedding_w_repr())
                         ->assert_is_op_input("lookup_table_v2", "W");
  auto embedding =
      pattern->NewNode(embedding_repr())->assert_is_op("lookup_table_v2");
  auto embedding_out = pattern->NewNode(embedding_out_repr())
                           ->assert_is_op_output("lookup_table_v2", "Out")
                           ->AsOutput();
  embedding->LinksFrom({embedding_ids, embedding_w}).LinksTo({embedding_out});
}

} // namespace patterns

class CastEmbeddingTransIdsToInt32Pass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  const std::string name_scope_{"cast_embedding_trans_ids_to_int32_pass"};
};
void CastEmbeddingTransIdsToInt32Pass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);

  GraphPatternDetector gpd;
  patterns::CastEmbeddingTransIdsToInt32Pattern pattern(gpd.mutable_pattern(),
                                                        name_scope_);
  int found_subgraph_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle CastEmbeddingTransIdsToInt32Pass";
    GET_IR_NODE(cast);
    GET_IR_NODE(embedding);
    GET_IR_NODE(embedding_ids);
    auto* block = cast->Op()->Block();
    auto cast_node_attr_in_dtype = cast->Op()->GetAttrIfExists<int>("in_dtype");
Contributor:

  1. Have you considered the case where in_dtype is int32 and out_dtype is int64? In theory the cast could simply be removed in that case.
  2. In theory in_dtype should not be restricted at all; isn't it enough to check that out_dtype == int64?

Contributor Author:

I don't think that case can show up in a cast + embedding combination: embedding already supports int32, so in theory the situation you describe won't occur.

Contributor Author:

I'll change this to guard against a model hitting that case; the current models don't contain it.

    auto cast_node_attr_out_dtype =
        cast->Op()->GetAttrIfExists<int>("out_dtype");
    if (cast_node_attr_in_dtype !=
            static_cast<int>(paddle::framework::proto::VarType::FP32) &&
        cast_node_attr_out_dtype !=
            static_cast<int>(paddle::framework::proto::VarType::INT64)) {
      return;
    }
    cast->Op()->SetAttr(
        "out_dtype",
        static_cast<int>(paddle::framework::proto::VarType::INT32));
    embedding_ids->Var()->SetDataType(paddle::framework::proto::VarType::INT32);
    embedding->Op()->Flush();
    found_subgraph_count++;
  };
  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(cast_embedding_trans_ids_to_int32_pass,
              paddle::framework::ir::CastEmbeddingTransIdsToInt32Pass);

REGISTER_PASS_CAPABILITY(cast_embedding_trans_ids_to_int32_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().LE(
            "lookup_table_v2", 1));
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -516,6 +516,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
      "reshape_unstack_concat_fuse_pass",
      "delete_op_device_pass",
      "constant_folding_pass",
      "cast_embedding_trans_ids_to_int32_pass",
      "delete_elementwise_mul_op_pass",
      "generate_sequence_xpu_fuse_pass",
      "embedding_with_eltwise_add_xpu_fuse_pass",
72 changes: 51 additions & 21 deletions paddle/phi/kernels/xpu/embedding_kernel.cc
@@ -44,18 +44,6 @@ void EmbeddingKernel(const Context &ctx,
  auto *table = table_t->data<T>();
  auto *output = dev_ctx.template Alloc<T>(output_t);

  xpu::ctx_guard RAII_GUARD(ctx.x_context());
  const int64_t *ids;
  if (ids_t->dtype() == phi::DataType::INT64) {
    ids = ids_t->data<int64_t>();
  } else {
    int64_t *ids_tt = RAII_GUARD.alloc_l3_or_gm<int64_t>(ids_t->numel());
    int r = xpu::cast<int32_t, int64_t>(
        ctx.x_context(), ids_t->data<int>(), ids_tt, ids_t->numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    ids = reinterpret_cast<const int64_t *>(ids_tt);
  }

  PADDLE_ENFORCE_EQ(
      ids_numel <= std::numeric_limits<int32_t>::max(),
      true,
@@ -68,15 +56,57 @@
  size_t xm = table_t->dims()[0];
  size_t n = table_t->dims()[1];

  int r = xpu::embedding<XPUType>(dev_ctx.x_context(),
                                  reinterpret_cast<const XPUType *>(table),
                                  ids,
                                  reinterpret_cast<XPUType *>(output),
                                  xm,
                                  n,
                                  ym,
                                  padding_idx);

  int r;
  xpu::ctx_guard RAII_GUARD(ctx.x_context());
  if (ids_t->dtype() == phi::DataType::INT64) {
#ifndef PADDLE_WITH_XPU_PLUGIN
    r = xpu::embedding<XPUType, int64_t>(
        dev_ctx.x_context(),
        reinterpret_cast<const XPUType *>(table),
        ids_t->data<int64_t>(),
        reinterpret_cast<XPUType *>(output),
        xm,
        n,
        ym,
        padding_idx);
#else
    r = xpu::plugin::embedding_tiny_dict<XPUType, int64_t>(
Contributor:

Let's unify the name as fast_embedding; the accelerated variants for the various cases can then be dispatched inside the fast_embedding wrapper.

Contributor Author:

I don't think "fast" is a good name, since it doesn't convey the concrete scenario the kernel targets. This one is a kernel optimization for small embedding tables, so the current name feels more appropriate. A single plugin could be renamed to fast, but if we later need to optimize embedding for other scenarios, naming them would get awkward.

        dev_ctx.x_context(),
        reinterpret_cast<const XPUType *>(table),
        ids_t->data<int64_t>(),
        reinterpret_cast<XPUType *>(output),
        xm,
        n,
        ym,
        padding_idx);
#endif
  } else {
#ifndef PADDLE_WITH_XPU_PLUGIN
    int64_t *ids_tt = RAII_GUARD.alloc_l3_or_gm<int64_t>(ids_t->numel());
Contributor:

Doesn't xpu::embedding support an int32 index?

Contributor Author:

This keeps the previous phi kernel logic: the kernel inserts a cast op to convert int32 to int64. As far as I can tell Kunlun 2 supports int32, but I'm not sure about Kunlun 1, so for compatibility I didn't change the original logic.

    r = xpu::cast<int32_t, int64_t>(
        ctx.x_context(), ids_t->data<int>(), ids_tt, ids_t->numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    const int64_t *ids = reinterpret_cast<const int64_t *>(ids_tt);
    r = xpu::embedding<XPUType>(dev_ctx.x_context(),
                                reinterpret_cast<const XPUType *>(table),
                                ids,
                                reinterpret_cast<XPUType *>(output),
                                xm,
                                n,
                                ym,
                                padding_idx);
#else
    r = xpu::plugin::embedding_tiny_dict<XPUType, int>(
        dev_ctx.x_context(),
        reinterpret_cast<const XPUType *>(table),
        ids_t->data<int>(),
        reinterpret_cast<XPUType *>(output),
        xm,
        n,
        ym,
        padding_idx);
#endif
  }
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
}

11 changes: 11 additions & 0 deletions paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
@@ -104,6 +104,17 @@ DLL_EXPORT int fast_reduce_min(Context* ctx,
                               const std::vector<int>& xshape,
                               const std::vector<int>& rdims);

template <typename T, typename TID>
DLL_EXPORT int embedding_tiny_dict(Context* ctx,
                                   const T* x,
                                   const TID* indices,
                                   T* y,
                                   int64_t xm,
                                   int64_t n,
                                   int64_t ym,
                                   int64_t padding_idx,
                                   TID start_index = 0);

} // namespace plugin
} // namespace api
} // namespace xpu
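
The review thread on embedding_kernel.cc above debated folding case-specific speedups behind a single fast_embedding entry point instead of a scenario-named plugin. A minimal sketch of such a wrapper, assuming a hypothetical table-size cutoff and reusing the two call signatures shown in this PR (fast_embedding and kTinyDictRows are illustrative names, not part of the merged change):

template <typename T, typename TID>
int fast_embedding(Context* ctx,
                   const T* x,
                   const TID* indices,
                   T* y,
                   int64_t xm,
                   int64_t n,
                   int64_t ym,
                   int64_t padding_idx) {
  // Assumed cutoff for a "tiny" dictionary; the real heuristic would live
  // inside the wrapper, as suggested in review.
  constexpr int64_t kTinyDictRows = 8192;
  if (xm <= kTinyDictRows) {
    return embedding_tiny_dict<T, TID>(
        ctx, x, indices, y, xm, n, ym, padding_idx);
  }
  // Fall back to the stock XDNN embedding for larger tables.
  return xpu::embedding<T, TID>(ctx, x, indices, y, xm, n, ym, padding_idx);
}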