Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NPUW] Add Slice before last MatMul #27229

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ DEFINE_OPT(NPUW_FOLD, bool, false, npuw::partitioning::fold, CompileTime);
DEFINE_OPT(NPUW_CWAI, bool, false, npuw::partitioning::cwai, CompileTime);
DEFINE_OPT(NPUW_DQ, bool, false, npuw::partitioning::dyn_quant, CompileTime);
DEFINE_OPT(NPUW_PMM, std::string, "2", npuw::partitioning::par_matmul_merge_dims, CompileTime);
DEFINE_OPT(NPUW_SLICE_OUT, bool, false, npuw::partitioning::slice_out, CompileTime);
DEFINE_OPT(NPUW_HOST_GATHER, bool, true, npuw::partitioning::host_gather, CompileTime);
DEFINE_OPT(NPUW_SPATIAL, bool, false, npuw::partitioning::spatial, CompileTime);
DEFINE_OPT(NPUW_SPATIAL_NWAY, std::size_t, 128, npuw::partitioning::spatial_nway, CompileTime);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ static constexpr ov::Property<bool> dyn_quant{"NPUW_DQ"};
*/
static constexpr ov::Property<std::string> par_matmul_merge_dims{"NPUW_PMM"};

/**
* @brief
* Type: bool.
 * Add a Slice operation before the last MatMul, reducing the output's dimension.
* Default value: false.
*/
static constexpr ov::Property<bool> slice_out{"NPUW_SLICE_OUT"};

/**
* @brief
 * Type: bool.
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_npu/src/al/src/config/npuw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ void intel_npu::registerNPUWOptions(OptionsDesc& desc) {
desc.add<NPUW_CWAI>();
desc.add<NPUW_DQ>();
desc.add<NPUW_PMM>();
desc.add<NPUW_SLICE_OUT>();
desc.add<NPUW_SPATIAL>();
desc.add<NPUW_SPATIAL_NWAY>();
desc.add<NPUW_SPATIAL_DYN>();
Expand Down
11 changes: 11 additions & 0 deletions src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@ ov::npuw::CompiledModel::CompiledModel(const std::shared_ptr<ov::Model>& model,
rewr.run_on_model(model);
}

if (m_cfg.get<::intel_npu::NPUW_SLICE_OUT>()) {
// Add Slice before last MatMul for the prefill model
ov::pass::GraphRewrite rewr;
rewr.add_matcher<ov::npuw::patterns::opt::SliceLastMatmul>();
rewr.add_matcher<ov::npuw::patterns::opt::SliceLastMatmulAdd>();
rewr.add_matcher<ov::npuw::patterns::opt::SliceLastMatmulTranspose>();
rewr.add_matcher<ov::npuw::patterns::opt::SliceLastMatmulMultiply>();
rewr.run_on_model(model);
}

auto partitioning = getPartitioning(model, m_cfg);
m_total_stat.gflops = partitioning.total_gflops;
m_total_stat.ops = partitioning.total_ops;
Expand Down Expand Up @@ -907,6 +917,7 @@ void ov::npuw::CompiledModel::implement_properties() {
BIND(npuw::partitioning::cwai, NPUW_CWAI),
BIND(npuw::partitioning::dyn_quant, NPUW_DQ),
BIND(npuw::partitioning::par_matmul_merge_dims, NPUW_PMM),
BIND(npuw::partitioning::slice_out, NPUW_SLICE_OUT),
BIND(npuw::partitioning::spatial, NPUW_SPATIAL),
BIND(npuw::partitioning::spatial_nway, NPUW_SPATIAL_NWAY),
BIND(npuw::partitioning::spatial_dyn, NPUW_SPATIAL_DYN),
Expand Down
158 changes: 146 additions & 12 deletions src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,7 @@

#include "../../logging.hpp"
#include "../../util.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reduce_sum.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/ops.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/pass/pattern/op/label.hpp" // any_input
#include "openvino/pass/pattern/op/optional.hpp"
Expand Down Expand Up @@ -1296,6 +1285,151 @@ CompressDictMatMulf32::CompressDictMatMulf32(Context::Ref ctx) {
register_matcher(std::make_shared<opp::Matcher>(res, "OptCompressDictMatMulf32"), std::move(callback));
}

SliceLastMatmul::SliceLastMatmul() {
    auto matmul = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), opp::any_input()});
    auto res = opp::wrap_type<ov::op::v0::Result>({matmul});

    // Note: Use [=] so the pattern nodes above stay alive inside the callback
    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& node_to_output = m.get_pattern_value_map();
        auto matched_out_matmul = node_to_output.at(matmul);
        auto* matmul_node = matched_out_matmul.get_node();

        const auto& act_shape = matmul_node->input(0).get_shape();

        // Rewrite only 3D activations with more than one row along axis 1:
        // the Slice keeps just the last row [seq-1, seq) of the MatMul input.
        if (act_shape.size() != 3 || act_shape[1] <= 1) {
            return false;  // root hasn't changed
        }

        const auto seq = static_cast<int32_t>(act_shape[1]);
        const auto hidden = static_cast<int32_t>(act_shape[2]);

        auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i32,
                                                            ov::Shape{3},
                                                            std::vector<int32_t>{0, seq - 1, 0});
        auto stop = std::make_shared<ov::op::v0::Constant>(ov::element::i32,
                                                           ov::Shape{3},
                                                           std::vector<int32_t>{1, seq, hidden});
        auto step =
            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, std::vector<int32_t>{1, 1, 1});

        auto slice = std::make_shared<ov::op::v8::Slice>(matmul_node->input_value(0), start, stop, step);
        matmul_node->input(0).replace_source_output(slice);

        return true;  // root was changed
    };
    register_matcher(std::make_shared<opp::Matcher>(res, "SliceLastMatmul"), std::move(callback));
}

SliceLastMatmulAdd::SliceLastMatmulAdd() {
    // Tail pattern: MatMul -> Add -> Result (seen in e.g. codegen2/gpt-j/phi-2
    // per the review discussion). The Slice is inserted on the MatMul input.
    auto matmul = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), opp::any_input()});
    auto add = opp::wrap_type<ov::op::v1::Add>({matmul, opp::any_input()});
    auto res = opp::wrap_type<ov::op::v0::Result>({add});

    // Note: Use [=] so the pattern nodes above stay alive inside the callback
    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& node_to_output = m.get_pattern_value_map();
        auto matched_out_matmul = node_to_output.at(matmul);

        const auto shape = matched_out_matmul.get_node()->input(0).get_shape();

        // Only 3D inputs with more than one row along axis 1 are rewritten
        if (shape.size() != 3 || shape[1] <= 1) {
            return false;  // root hasn't changed
        }

        // Helper: a 3-element i32 constant for the Slice bounds
        auto make_i32 = [](std::vector<int32_t>&& vals) {
            return std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, std::move(vals));
        };
        auto start = make_i32({0, static_cast<int32_t>(shape[1] - 1), 0});
        auto stop = make_i32({1, static_cast<int32_t>(shape[1]), static_cast<int32_t>(shape[2])});
        auto step = make_i32({1, 1, 1});

        auto slice =
            std::make_shared<ov::op::v8::Slice>(matched_out_matmul.get_node()->input_value(0), start, stop, step);
        matched_out_matmul.get_node()->input(0).replace_source_output(slice);

        return true;  // root was changed
    };
    register_matcher(std::make_shared<opp::Matcher>(res, "SliceLastMatmulAdd"), std::move(callback));
}

SliceLastMatmulTranspose::SliceLastMatmulTranspose() {
    // Tail pattern: MatMul -> Transpose -> Result (seen in e.g. chatglm2/3
    // per the review discussion). The Slice is inserted on the MatMul input.
    auto matmul = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), opp::any_input()});
    auto transpose = opp::wrap_type<ov::op::v1::Transpose>({matmul, opp::any_input()});
    // FIX: the Result must be fed by the Transpose, not by the MatMul directly.
    // The original wrapped {matmul}, leaving the Transpose pattern node dead and
    // making this matcher an exact duplicate of SliceLastMatmul.
    auto res = opp::wrap_type<ov::op::v0::Result>({transpose});

    // Note: Use [=] so the pattern nodes above stay alive inside the callback
    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& node_to_output = m.get_pattern_value_map();
        auto matched_out_matmul = node_to_output.at(matmul);

        const auto shape = matched_out_matmul.get_node()->input(0).get_shape();

        // Only 3D inputs with more than one row along axis 1 are rewritten:
        // keep the last row [seq-1, seq) of the MatMul's first input.
        if (shape.size() != 3 || shape[1] <= 1) {
            return false;  // root hasn't changed
        }

        auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i32,
                                                            ov::Shape{3},
                                                            std::vector<int32_t>{0, int32_t(shape[1] - 1), 0});
        auto stop =
            std::make_shared<ov::op::v0::Constant>(ov::element::i32,
                                                   ov::Shape{3},
                                                   std::vector<int32_t>{1, int32_t(shape[1]), int32_t(shape[2])});
        auto step =
            std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, std::vector<int32_t>{1, 1, 1});

        auto slice =
            std::make_shared<ov::op::v8::Slice>(matched_out_matmul.get_node()->input_value(0), start, stop, step);
        matched_out_matmul.get_node()->input(0).replace_source_output(slice);

        return true;  // root was changed
    };
    register_matcher(std::make_shared<opp::Matcher>(res, "SliceLastMatmulTranspose"), std::move(callback));
}

// Matches the tail pattern MatMul -> Divide -> Tanh -> Multiply -> Result
// (reported for gemma-2 models in the review thread) and inserts a Slice on
// the MatMul's first input so only the last row along axis 1 is computed.
SliceLastMatmulMultiply::SliceLastMatmulMultiply() {
    auto matmul = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), opp::any_input()});
    auto div = opp::wrap_type<ov::op::v1::Divide>({matmul, opp::any_input()});
    auto tanh = opp::wrap_type<ov::op::v0::Tanh>({div});
    auto multiply = opp::wrap_type<ov::op::v1::Multiply>({tanh, opp::any_input()});
    auto res = opp::wrap_type<ov::op::v0::Result>({multiply});

    // Note: Use [=] to make sure the above objects stay alive in the callback
    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& node_to_output = m.get_pattern_value_map();

        auto matched_out_matmul = node_to_output.at(matmul);

        // Shape of the MatMul's activation input (assumed static here — TODO
        // confirm the pass only runs on static-shape (prefill) models).
        auto shape = matched_out_matmul.get_node()->input(0).get_shape();

        // Only rewrite 3D inputs with more than one row along axis 1
        if (shape.size() == 3 && shape[1] > 1) {
            // Slice bounds selecting the last row: [0, seq-1, 0] .. [1, seq, hidden]
            auto start = std::make_shared<ov::op::v0::Constant>(ov::element::i32,
                                                                ov::Shape{3},
                                                                std::vector<int32_t>{0, int32_t(shape[1] - 1), 0});
            auto stop =
                std::make_shared<ov::op::v0::Constant>(ov::element::i32,
                                                       ov::Shape{3},
                                                       std::vector<int32_t>{1, int32_t(shape[1]), int32_t(shape[2])});
            auto step =
                std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, std::vector<int32_t>{1, 1, 1});

            auto slice =
                std::make_shared<ov::op::v8::Slice>(matched_out_matmul.get_node()->input_value(0), start, stop, step);

            // Rewire the MatMul to consume the sliced activation instead
            matched_out_matmul.get_node()->input(0).replace_source_output(slice);

            return true;  // root was changed
        }
        return false;  // root hasn't changed
    };
    register_matcher(std::make_shared<opp::Matcher>(res, "SliceLastMatmulMultiply"), std::move(callback));
}
Comment on lines +1395 to +1431
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same concern here. Not sure which topologies does it serve.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, keeping for now


} // namespace opt
} // namespace patterns
} // namespace npuw
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,27 @@ class CompressDictMatMulf32 : public ov::pass::MatcherPass {
CompressDictMatMulf32(Context::Ref ctx);
};

// "Slice last MatMul" matchers: insert a Slice before a model's final MatMul so
// only the last row along the sequence axis is computed. Each class below
// covers a different tail pattern ending in a Result.

// Tail pattern: MatMul -> Result
class SliceLastMatmul : public ov::pass::MatcherPass {
public:
    SliceLastMatmul();
};

// Tail pattern: MatMul -> Add -> Result
class SliceLastMatmulAdd : public ov::pass::MatcherPass {
public:
    SliceLastMatmulAdd();
};

// Tail pattern: MatMul -> Transpose -> Result
class SliceLastMatmulTranspose : public ov::pass::MatcherPass {
public:
    SliceLastMatmulTranspose();
};

// Tail pattern: MatMul -> Divide -> Tanh -> Multiply -> Result
class SliceLastMatmulMultiply : public ov::pass::MatcherPass {
public:
    SliceLastMatmulMultiply();
};

} // namespace opt
} // namespace patterns
} // namespace npuw
Expand Down
Loading