From 11eab76e1e5e82fdcd361c7ff67da746aa22f418 Mon Sep 17 00:00:00 2001 From: meiyang-intel Date: Tue, 7 Jun 2022 23:59:02 -0400 Subject: [PATCH] fix variable support issue for paddle top_k_v2 --- src/core/tests/frontend/paddle/op_fuzzy.cpp | 1 + .../test_models/gen_scripts/generate_top_k_v2.py | 16 +++++++++++----- src/frontends/paddle/src/op/top_k_v2.cpp | 3 ++- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/core/tests/frontend/paddle/op_fuzzy.cpp b/src/core/tests/frontend/paddle/op_fuzzy.cpp index bfdbea5dd825d6..37b2fd7c802b40 100644 --- a/src/core/tests/frontend/paddle/op_fuzzy.cpp +++ b/src/core/tests/frontend/paddle/op_fuzzy.cpp @@ -378,6 +378,7 @@ static const std::vector models{ std::string("top_k_v2_test_3"), std::string("top_k_v2_test_4"), std::string("top_k_v2_test_5"), + std::string("top_k_v2_test_6"), std::string("trilinear_downsample_false_0"), std::string("trilinear_downsample_false_1"), std::string("trilinear_downsample_true_0"), diff --git a/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_top_k_v2.py b/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_top_k_v2.py index 28e9320f6aa052..86a28263dd4594 100644 --- a/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_top_k_v2.py +++ b/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_top_k_v2.py @@ -9,14 +9,17 @@ data_type = 'float32' -def top_k_v2(name: str, x, k: int, axis=None, largest=True, sorted=True): +def top_k_v2(name: str, x, k: int, axis=None, largest=True, sorted=True, k_is_var=True): paddle.enable_static() + k = np.array([k], dtype='int32') if k_is_var else k + with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): node_x = paddle.static.data(name='x', shape=x.shape, dtype='float32') + input_k = paddle.static.data(name='k', shape=[1], dtype='int32') if k_is_var else k value, indices = paddle.topk( - node_x, k=k, axis=axis, largest=largest, sorted=sorted, name="top_k") + node_x, k=input_k, axis=axis, largest=largest, sorted=sorted, name="top_k") indices = paddle.cast(indices, np.float32) cpu = paddle.static.cpu_places(1) @@ -24,12 +27,14 @@ def top_k_v2(name: str, x, k: int, axis=None, largest=True, sorted=True): # startup program will call initializer to initialize the parameters. exe.run(paddle.static.default_startup_program()) + feed_list = {'x': x, 'k': k} if k_is_var else {'x': x} outs = exe.run( - feed={'x': x}, + feed=feed_list, fetch_list=[value, indices]) - saveModel(name, exe, feedkeys=['x'], fetchlist=[value, indices], inputs=[ - x], outputs=outs, target_dir=sys.argv[1]) + feedkey_list = ['x', 'k'] if k_is_var else ['x'] + input_list = [x, k] if k_is_var else [x] + saveModel(name, exe, feedkeys=feedkey_list, fetchlist=[value, indices], inputs=input_list, outputs=outs, target_dir=sys.argv[1]) return outs[0] @@ -43,6 +48,7 @@ def main(): top_k_v2("top_k_v2_test_4", data, k=7, axis=None, largest=True, sorted=True) top_k_v2("top_k_v2_test_5", data, k=6, axis=2, largest=False, sorted=True) + top_k_v2("top_k_v2_test_6", data, k=6, axis=2, largest=False, sorted=True, k_is_var=False) if __name__ == "__main__": diff --git a/src/frontends/paddle/src/op/top_k_v2.cpp b/src/frontends/paddle/src/op/top_k_v2.cpp index b4724026859baf..e48a4dd286d207 100644 --- a/src/frontends/paddle/src/op/top_k_v2.cpp +++ b/src/frontends/paddle/src/op/top_k_v2.cpp @@ -13,7 +13,8 @@ NamedOutputs top_k_v2(const NodeContext& node) { Output k_expected_node; if (node.has_input("K")) { auto k_variable = node.get_input("K"); - k_expected_node = std::make_shared(k_variable, element::i32); + auto k_var_node = std::make_shared(k_variable, element::i32); + k_expected_node = std::make_shared(k_var_node); } else { const auto k_expected = node.get_attribute("k", 1); k_expected_node = default_opset::Constant::create(element::i32, {}, {k_expected});