diff --git a/src/graph/TraverseExecutor.cpp b/src/graph/TraverseExecutor.cpp index b0227b98d9e..d68e4a893b2 100644 --- a/src/graph/TraverseExecutor.cpp +++ b/src/graph/TraverseExecutor.cpp @@ -353,33 +353,12 @@ bool WhereWrapper::rewrite(Expression *expr) const { return false; } } - case Expression::kUnary: { - auto *unaExpr = static_cast(expr); - return rewrite(const_cast(unaExpr->operand())); - } - case Expression::kTypeCasting: { - auto *typExpr = static_cast(expr); - return rewrite(const_cast(typExpr->operand())); - } - case Expression::kArithmetic: { - auto *ariExp = static_cast(expr); - return rewrite(const_cast(ariExp->left())) - && rewrite(const_cast(ariExp->right())); - } - case Expression::kRelational: { - auto *relExp = static_cast(expr); - return rewrite(const_cast(relExp->left())) - && rewrite(const_cast(relExp->right())); - } + case Expression::kUnary: + case Expression::kTypeCasting: + case Expression::kArithmetic: + case Expression::kRelational: case Expression::kFunctionCall: { - auto *funcExp = static_cast(expr); - auto &args = funcExp->args(); - for (auto &arg : args) { - if (!rewrite(arg)) { - return false; - } - } - return true; + return canPushdown(expr); } case Expression::kPrimary: case Expression::kSourceProp: @@ -405,9 +384,6 @@ bool WhereWrapper::rewrite(Expression *expr) const { } bool WhereWrapper::canPushdown(Expression *expr) const { - if (expr->isFunCallExpression()) { - return false; - } auto ectx = std::make_unique(); expr->setContext(ectx.get()); auto status = expr->prepare(); diff --git a/src/graph/test/GoTest.cpp b/src/graph/test/GoTest.cpp index 8a7317440b0..65053c8504a 100644 --- a/src/graph/test/GoTest.cpp +++ b/src/graph/test/GoTest.cpp @@ -1808,6 +1808,96 @@ TEST_P(GoTest, FilterPushdown) { std::vector> expected; ASSERT_TRUE(verifyResult(resp, expected)); } + { + auto *fmt = "GO FROM %ld OVER serve " + "WHERE udf_is_in(serve._dst, %ld, 2, 3)"; + auto query = folly::stringPrintf(fmt, + players_["Rajon Rondo"].vid(), teams_["Celtics"].vid()); + + TEST_FILTER_PUSHDOWN_REWRITE( + true, + folly::stringPrintf("udf_is_in(serve._dst,%ld,2,3)", teams_["Celtics"].vid())); + + cpp2::ExecutionResponse resp; + auto code = client_->execute(query, resp); + ASSERT_EQ(cpp2::ErrorCode::SUCCEEDED, code) << *(resp.get_error_msg()); + + std::vector expectedColNames{ + {"serve._dst"} + }; + ASSERT_TRUE(verifyColNames(resp, expectedColNames)); + + std::vector> expected = { + {teams_["Celtics"].vid()} + }; + ASSERT_TRUE(verifyResult(resp, expected)); + } + { + auto *fmt = "GO FROM %ld OVER serve " + "WHERE udf_is_in(\"test\", $$.team.name)"; + auto query = folly::stringPrintf(fmt, players_["Rajon Rondo"].vid()); + + TEST_FILTER_PUSHDOWN_REWRITE( + false, + ""); + + cpp2::ExecutionResponse resp; + auto code = client_->execute(query, resp); + ASSERT_EQ(cpp2::ErrorCode::SUCCEEDED, code) << *(resp.get_error_msg()); + + std::vector expectedColNames{ + {"serve._dst"} + }; + ASSERT_TRUE(verifyColNames(resp, expectedColNames)); + + std::vector> expected; + ASSERT_TRUE(verifyResult(resp, expected)); + } + { + auto *fmt = "GO FROM %ld OVER serve " + "WHERE udf_is_in($^.player.name, \"Tim Duncan\")"; + auto query = folly::stringPrintf(fmt, players_["Tim Duncan"].vid()); + + TEST_FILTER_PUSHDOWN_REWRITE( + true, + "udf_is_in($^.player.name,Tim Duncan)"); + + cpp2::ExecutionResponse resp; + auto code = client_->execute(query, resp); + ASSERT_EQ(cpp2::ErrorCode::SUCCEEDED, code) << *(resp.get_error_msg()); + + std::vector expectedColNames{ + {"serve._dst"} + }; + ASSERT_TRUE(verifyColNames(resp, expectedColNames)); + + std::vector> expected = { + {teams_["Spurs"].vid()} + }; + ASSERT_TRUE(verifyResult(resp, expected)); + } + { + auto *fmt = "GO FROM %ld OVER serve " + "WHERE !udf_is_in($^.player.name, \"Tim Duncan\")"; + auto query = folly::stringPrintf(fmt, players_["Tim Duncan"].vid()); + + TEST_FILTER_PUSHDOWN_REWRITE( + true, + "!(udf_is_in($^.player.name,Tim Duncan))"); + + cpp2::ExecutionResponse resp; + auto code = client_->execute(query, resp); + ASSERT_EQ(cpp2::ErrorCode::SUCCEEDED, code) << *(resp.get_error_msg()); + + std::vector expectedColNames{ + {"serve._dst"} + }; + ASSERT_TRUE(verifyColNames(resp, expectedColNames)); + + std::vector> expected; + ASSERT_TRUE(verifyResult(resp, expected)); + } + #undef TEST_FILTER_PUSHDWON_REWRITE }