Skip to content

Commit

Permalink
Support multiple query statements in e2e test framework (#3437)
Browse files Browse the repository at this point in the history
* preliminary implementation to support multiple query statements

* Run clang-format

* update one test case to use multiple query statements

* fix compile error

* support multiple queries to multiple results

* Run clang-format

* remove comment

* fix msvc compile error

* address comments

* Run clang-format

---------

Co-authored-by: CI Bot <yiyun-sj@users.noreply.github.com>
  • Loading branch information
yiyun-sj and yiyun-sj authored May 6, 2024
1 parent 93ce790 commit cc43956
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 84 deletions.
26 changes: 17 additions & 9 deletions test/include/test_runner/test_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,36 @@
namespace kuzu {
namespace testing {

enum class ResultType {
OK,
HASH,
TUPLES,
CSV_FILE,
ERROR_MSG,
ERROR_REGEX,
};

struct TestQueryResult {
ResultType type;
uint64_t numTuples = 0;
// errorMsg, CSVFile, hashValue uses first element
std::vector<std::string> expectedResult;
};

struct TestStatement {
std::string logMessage;
std::string query;
uint64_t numThreads = 4;
std::string encodedJoin;
bool expectedError = false;
bool expectedErrorRegex = false;
std::string errorMessage;
bool expectedOk = false;
uint64_t expectedNumTuples = 0;
std::vector<std::string> expectedTuples;
bool enumerate = false;
bool checkOutputOrder = false;
bool checkColumnNames = false;
std::string expectedTuplesCSVFile;
std::vector<TestQueryResult> result;
// for multiple conns
std::string batchStatmentsCSVFile;
std::optional<std::string> connName;
bool reloadDBFlag = false;
bool expectHash = false;
bool importDBFlag = false;
std::string expectedHashValue;
// for export and import db
std::string importFilePath;
bool removeFileFlag = false;
Expand Down
3 changes: 2 additions & 1 deletion test/include/test_runner/test_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ class TestParser {
void genGroupName();
void parseHeader();
void parseBody();
void extractExpectedResult(TestStatement* statement);
void extractExpectedResults(TestStatement* statement);
TestQueryResult extractExpectedResultFromToken(bool checkOutputOrder);
void extractStatementBlock();
void extractDataset();
void addStatementBlock(const std::string& blockName, const std::string& testGroupName);
Expand Down
6 changes: 3 additions & 3 deletions test/include/test_runner/test_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ class TestRunner {
static bool testStatement(TestStatement* statement, main::Connection& conn,
std::string& databasePath);
static bool checkLogicalPlans(std::unique_ptr<main::PreparedStatement>& preparedStatement,
TestStatement* statement, main::Connection& conn);
TestStatement* statement, size_t resultIdx, main::Connection& conn);
static bool checkLogicalPlan(std::unique_ptr<main::PreparedStatement>& preparedStatement,
TestStatement* statement, main::Connection& conn, uint32_t planIdx);
TestStatement* statement, size_t resultIdx, main::Connection& conn, uint32_t planIdx);
static std::vector<std::string> convertResultToString(main::QueryResult& queryResult,
bool checkOutputOrder = false, bool checkColumnNames = false);
static std::string convertResultToMD5Hash(main::QueryResult& queryResult, bool checkOutputOrder,
bool checkColumnNames); // returns hash and number of values hashed
static std::string convertResultColumnsToString(main::QueryResult& queryResult);
static bool checkPlanResult(std::unique_ptr<main::QueryResult>& result,
TestStatement* statement, const std::string& planStr, uint32_t planIndex);
TestStatement* statement, size_t resultIdx, const std::string& planStr, uint32_t planIdx);
};

} // namespace testing
Expand Down
4 changes: 2 additions & 2 deletions test/test_files/extension/extension.test
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
-SKIP_MUSL
-LOG InstallExtension
-STATEMENT INSTALL httpfs;
LOAD EXTENSION httpfs;
LOAD FROM 'http://extension.kuzudb.com/dataset/test/city.csv' return *;
---- ok
-STATEMENT LOAD EXTENSION httpfs;
---- ok
-STATEMENT LOAD FROM 'http://extension.kuzudb.com/dataset/test/city.csv' return *;
---- 3
Guelph|75000
Kitchener|200000
Expand Down
2 changes: 1 addition & 1 deletion test/test_files/transaction/create_rel/violate_error.test
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@
---- 1
1
-STATEMENT COMMIT

---- ok
68 changes: 43 additions & 25 deletions test/test_runner/test_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,45 +122,63 @@ void TestParser::replaceVariables(std::string& str) {
}
}

void TestParser::extractExpectedResult(TestStatement* statement) {
void TestParser::extractExpectedResults(TestStatement* statement) {
do {
tokenize();
if (currentToken.type == TokenType::EMPTY) {
continue;
}
if (currentToken.type != TokenType::RESULT) {
setCursorToPreviousLine();
return;
}
statement->result.push_back(extractExpectedResultFromToken(statement->checkOutputOrder));
} while (nextLine());
}

TestQueryResult TestParser::extractExpectedResultFromToken(bool checkOutputOrder) {
checkMinimumParams(1);
std::string result = currentToken.params[1];
TestQueryResult queryResult;
if (result == "ok") {
statement->expectedOk = true;
queryResult.type = ResultType::OK;
} else if (result == "error") {
statement->expectedError = true;
statement->errorMessage = extractTextBeforeNextStatement();
replaceVariables(statement->errorMessage);
queryResult.type = ResultType::ERROR_MSG;
queryResult.expectedResult.push_back(extractTextBeforeNextStatement());
replaceVariables(queryResult.expectedResult[0]);
} else if (result == "error(regex)") {
statement->expectedErrorRegex = true;
statement->errorMessage = extractTextBeforeNextStatement();
replaceVariables(statement->errorMessage);
queryResult.type = ResultType::ERROR_REGEX;
queryResult.expectedResult.push_back(extractTextBeforeNextStatement());
replaceVariables(queryResult.expectedResult[0]);
} else if (result.substr(0, 4) == "hash") {
statement->expectHash = true;
queryResult.type = ResultType::HASH;
checkMinimumParams(1);
nextLine();
tokenize();
statement->expectedNumTuples = stoi(currentToken.params[0]);
statement->expectedHashValue = currentToken.params.back();
queryResult.numTuples = stoi(currentToken.params[0]);
queryResult.expectedResult.push_back(currentToken.params.back());
} else {
checkMinimumParams(1);
statement->expectedNumTuples = stoi(result);
queryResult.numTuples = stoi(result);
nextLine();
if (line.starts_with("<FILE>:")) {
statement->expectedTuplesCSVFile = TestHelper::appendKuzuRootPath(
(std::filesystem::path(TestHelper::TEST_ANSWERS_PATH) / line.substr(7)).string());
return;
}
setCursorToPreviousLine();
for (auto i = 0u; i < statement->expectedNumTuples; i++) {
nextLine();
replaceVariables(line);
statement->expectedTuples.push_back(line);
}
if (!statement->checkOutputOrder) { // order is not important for result
sort(statement->expectedTuples.begin(), statement->expectedTuples.end());
queryResult.type = ResultType::CSV_FILE;
queryResult.expectedResult.push_back(TestHelper::appendKuzuRootPath(
(std::filesystem::path(TestHelper::TEST_ANSWERS_PATH) / line.substr(7)).string()));
} else {
queryResult.type = ResultType::TUPLES;
setCursorToPreviousLine();
for (auto i = 0u; i < queryResult.numTuples; i++) {
nextLine();
replaceVariables(line);
queryResult.expectedResult.push_back(line);
}
if (!checkOutputOrder) { // order is not important for result
sort(queryResult.expectedResult.begin(), queryResult.expectedResult.end());
}
}
}
return queryResult;
}

std::string TestParser::extractTextBeforeNextStatement(bool ignoreLineBreak) {
Expand Down Expand Up @@ -247,7 +265,7 @@ TestStatement* TestParser::extractStatement(TestStatement* statement,
break;
}
case TokenType::RESULT: {
extractExpectedResult(statement);
extractExpectedResults(statement);
return statement;
}
case TokenType::CHECK_ORDER: {
Expand Down
101 changes: 58 additions & 43 deletions test/test_runner/test_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,92 +51,107 @@ bool TestRunner::testStatement(TestStatement* statement, Connection& conn,
parsedStatements = conn.getClientContext()->parseQuery(statement->query);
} catch (std::exception& exception) {
auto errorPreparedStatement = conn.preparedStatementWithError(exception.what());
return checkLogicalPlan(errorPreparedStatement, statement, conn, 0);
return checkLogicalPlan(errorPreparedStatement, statement, 0, conn, 0);
}
if (parsedStatements.empty()) {
auto errorPreparedStatement =
conn.preparedStatementWithError("Connection Exception: Query is empty.");
return checkLogicalPlan(errorPreparedStatement, statement, conn, 0);
return checkLogicalPlan(errorPreparedStatement, statement, 0, conn, 0);
}
if (parsedStatements.size() > 1) {
throw TestException("Current test framework does not support multiple query statements!");
}
auto parsedStatement = std::move(parsedStatements[0]);
if (statement->encodedJoin.empty()) {
preparedStatement = conn.prepareNoLock(parsedStatement, statement->enumerate);
} else {
preparedStatement = conn.prepareNoLock(parsedStatement, true, statement->encodedJoin);
}
// Check for wrong statements
if (!statement->expectedError && !statement->expectedErrorRegex &&
!preparedStatement->isSuccess()) {
spdlog::error(preparedStatement->getErrorMessage());
return false;

size_t numParsed = parsedStatements.size();
for (size_t i = 0; i < numParsed; i++) {
auto parsedStatement = std::move(parsedStatements[i]);
if (statement->encodedJoin.empty()) {
preparedStatement = conn.prepareNoLock(parsedStatement, statement->enumerate);
} else {
preparedStatement = conn.prepareNoLock(parsedStatement, true, statement->encodedJoin);
}
// Check for wrong statements
ResultType resultType = statement->result[i].type;
if (resultType != ResultType::ERROR_MSG && resultType != ResultType::ERROR_REGEX &&
!preparedStatement->isSuccess()) {
spdlog::error(preparedStatement->getErrorMessage());
return false;
}
if (!checkLogicalPlans(preparedStatement, statement, i, conn)) {
return false;
}
}
return checkLogicalPlans(preparedStatement, statement, conn);
return true;
}

bool TestRunner::checkLogicalPlans(std::unique_ptr<PreparedStatement>& preparedStatement,
TestStatement* statement, Connection& conn) {
TestStatement* statement, size_t resultIdx, Connection& conn) {
auto numPlans = preparedStatement->logicalPlans.size();
auto numPassedPlans = 0u;
if (numPlans == 0) {
return checkLogicalPlan(preparedStatement, statement, conn, 0);
return checkLogicalPlan(preparedStatement, statement, resultIdx, conn, 0);
}
for (auto i = 0u; i < numPlans; ++i) {
if (checkLogicalPlan(preparedStatement, statement, conn, i)) {
if (checkLogicalPlan(preparedStatement, statement, resultIdx, conn, i)) {
numPassedPlans++;
}
}
return numPassedPlans == numPlans;
}

bool TestRunner::checkLogicalPlan(std::unique_ptr<PreparedStatement>& preparedStatement,
TestStatement* statement, Connection& conn, uint32_t planIdx) {
TestStatement* statement, size_t resultIdx, Connection& conn, uint32_t planIdx) {
auto result = conn.executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get(), planIdx);
if (statement->expectedError) {
std::string expectedError = StringUtils::rtrim(result->getErrorMessage());
if (statement->errorMessage == expectedError) {
TestQueryResult& testAnswer = statement->result[resultIdx];
std::string expectedError;
switch (testAnswer.type) {
case ResultType::OK: {
return result->isSuccess();
}
case ResultType::ERROR_MSG: {
expectedError = StringUtils::rtrim(result->getErrorMessage());
if (testAnswer.expectedResult[0] == expectedError) {
return true;
}
spdlog::info("EXPECTED ERROR: {}", expectedError);
} else if (statement->expectedErrorRegex) {
std::string expectedError = StringUtils::rtrim(result->getErrorMessage());
std::regex pattern(statement->errorMessage);
if (std::regex_match(expectedError, pattern)) {
break;
}
case ResultType::ERROR_REGEX: {
expectedError = StringUtils::rtrim(result->getErrorMessage());
if (std::regex_match(expectedError, std::regex(testAnswer.expectedResult[0]))) {
return true;
}
spdlog::info("EXPECTED ERROR: {}", expectedError);
} else if (statement->expectedOk && result->isSuccess()) {
return true;
} else {
break;
}
default: {
if (!preparedStatement->success) {
spdlog::info("Query compilation failed with error: {}",
preparedStatement->getErrorMessage());
return false;
}
auto planStr = preparedStatement->logicalPlans[planIdx]->toString();
if (checkPlanResult(result, statement, planStr, planIdx)) {
if (checkPlanResult(result, statement, resultIdx, planStr, planIdx)) {
return true;
}
break;
}
}
return false;
}

bool TestRunner::checkPlanResult(std::unique_ptr<QueryResult>& result, TestStatement* statement,
const std::string& planStr, uint32_t planIdx) {

if (!statement->expectedTuplesCSVFile.empty()) {
std::ifstream expectedTuplesFile(statement->expectedTuplesCSVFile);
size_t resultIdx, const std::string& planStr, uint32_t planIdx) {
TestQueryResult& testAnswer = statement->result[resultIdx];
if (testAnswer.type == ResultType::CSV_FILE) {
std::ifstream expectedTuplesFile(testAnswer.expectedResult[0]);
if (!expectedTuplesFile.is_open()) {
throw TestException("Cannot open file: " + statement->expectedTuplesCSVFile);
throw TestException("Cannot open file: " + testAnswer.expectedResult[0]);
}
std::string line;
testAnswer.expectedResult.clear();
while (std::getline(expectedTuplesFile, line)) {
statement->expectedTuples.push_back(line);
testAnswer.expectedResult.push_back(line);
}
if (!statement->checkOutputOrder) {
sort(statement->expectedTuples.begin(), statement->expectedTuples.end());
sort(testAnswer.expectedResult.begin(), testAnswer.expectedResult.end());
}
}
std::vector<std::string> resultTuples = TestRunner::convertResultToString(*result,
Expand All @@ -145,11 +160,11 @@ bool TestRunner::checkPlanResult(std::unique_ptr<QueryResult>& result, TestState
if (statement->checkColumnNames) {
actualNumTuples++;
}
if (statement->expectHash) {
if (testAnswer.type == ResultType::HASH) {
std::string resultHash = TestRunner::convertResultToMD5Hash(*result,
statement->checkOutputOrder, statement->checkColumnNames);
if (resultTuples.size() == actualNumTuples && resultHash == statement->expectedHashValue &&
resultTuples.size() == statement->expectedNumTuples) {
if (resultTuples.size() == actualNumTuples && resultHash == testAnswer.expectedResult[0] &&
resultTuples.size() == testAnswer.numTuples) {
spdlog::info("PLAN{} PASSED in {}ms.", planIdx,
result->getQuerySummary()->getExecutionTime());
return true;
Expand All @@ -164,7 +179,7 @@ bool TestRunner::checkPlanResult(std::unique_ptr<QueryResult>& result, TestState
return false;
}
}
if (resultTuples.size() == actualNumTuples && resultTuples == statement->expectedTuples) {
if (resultTuples.size() == actualNumTuples && resultTuples == testAnswer.expectedResult) {
spdlog::info("PLAN{} PASSED in {}ms.", planIdx,
result->getQuerySummary()->getExecutionTime());
return true;
Expand Down

0 comments on commit cc43956

Please sign in to comment.