Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework multiple query result #3191

Merged
merged 3 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/c_api/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ kuzu_query_result* kuzu_connection_query(kuzu_connection* connection, const char
if (query_result == nullptr) {
return nullptr;
}
auto* c_query_result = new kuzu_query_result;
auto* c_query_result = (kuzu_query_result*)malloc(sizeof(kuzu_query_result));
c_query_result->_query_result = query_result;
c_query_result->_is_owned_by_cpp = false;
return c_query_result;
} catch (Exception& e) { return nullptr; }
}
Expand All @@ -64,7 +65,7 @@ kuzu_prepared_statement* kuzu_connection_prepare(kuzu_connection* connection, co
if (prepared_statement == nullptr) {
return nullptr;
}
auto* c_prepared_statement = new kuzu_prepared_statement;
auto* c_prepared_statement = (kuzu_prepared_statement*)malloc(sizeof(kuzu_prepared_statement));
c_prepared_statement->_prepared_statement = prepared_statement;
c_prepared_statement->_bound_values =
new std::unordered_map<std::string, std::unique_ptr<Value>>;
Expand Down Expand Up @@ -92,8 +93,9 @@ kuzu_query_result* kuzu_connection_execute(
if (query_result == nullptr) {
return nullptr;
}
auto* c_query_result = new kuzu_query_result;
auto* c_query_result = (kuzu_query_result*)malloc(sizeof(kuzu_query_result));
c_query_result->_query_result = query_result;
c_query_result->_is_owned_by_cpp = false;
return c_query_result;
}

Expand Down
2 changes: 1 addition & 1 deletion src/c_api/prepared_statement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ void kuzu_prepared_statement_destroy(kuzu_prepared_statement* prepared_statement
delete static_cast<std::unordered_map<std::string, std::unique_ptr<Value>>*>(
prepared_statement->_bound_values);
}
delete prepared_statement;
free(prepared_statement);
}

bool kuzu_prepared_statement_allow_active_transaction(kuzu_prepared_statement* prepared_statement) {
Expand Down
22 changes: 20 additions & 2 deletions src/c_api/query_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ void kuzu_query_result_destroy(kuzu_query_result* query_result) {
return;
}
if (query_result->_query_result != nullptr) {
delete static_cast<QueryResult*>(query_result->_query_result);
if (!query_result->_is_owned_by_cpp) {
delete static_cast<QueryResult*>(query_result->_query_result);
}
}
delete query_result;
free(query_result);
}

bool kuzu_query_result_is_success(kuzu_query_result* query_result) {
Expand Down Expand Up @@ -69,6 +71,22 @@ bool kuzu_query_result_has_next(kuzu_query_result* query_result) {
return static_cast<QueryResult*>(query_result->_query_result)->hasNext();
}

bool kuzu_query_result_has_next_query_result(kuzu_query_result* query_result) {
return static_cast<QueryResult*>(query_result->_query_result)->hasNextQueryResult();
}

kuzu_query_result* kuzu_query_result_get_next_query_result(kuzu_query_result* query_result) {
auto next_query_result =
static_cast<QueryResult*>(query_result->_query_result)->getNextQueryResult();
if (next_query_result == nullptr) {
return nullptr;
}
auto* c_query_result = (kuzu_query_result*)malloc(sizeof(kuzu_query_result));
c_query_result->_query_result = next_query_result;
c_query_result->_is_owned_by_cpp = true;
return c_query_result;
}

kuzu_flat_tuple* kuzu_query_result_get_next(kuzu_query_result* query_result) {
auto flat_tuple = static_cast<QueryResult*>(query_result->_query_result)->getNext();
auto* flat_tuple_c = (kuzu_flat_tuple*)malloc(sizeof(kuzu_flat_tuple));
Expand Down
15 changes: 15 additions & 0 deletions src/include/c_api/kuzu.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ typedef struct {
*/
typedef struct {
void* _query_result;
bool _is_owned_by_cpp;
} kuzu_query_result;

/**
Expand Down Expand Up @@ -630,6 +631,20 @@ KUZU_C_API bool kuzu_query_result_has_next(kuzu_query_result* query_result);
* @param query_result The query result instance to return.
*/
KUZU_C_API kuzu_flat_tuple* kuzu_query_result_get_next(kuzu_query_result* query_result);
/**
* @brief Returns true if we have not consumed all query results, false otherwise. Use this function
* for loop results of multiple query statements
* @param query_result The query result instance to check.
*/
KUZU_C_API bool kuzu_query_result_has_next_query_result(kuzu_query_result* query_result);
/**
* @brief Returns the next query result. Use this function to loop multiple query statements'
* results.
* @param query_result The query result instance to return.
*/
KUZU_C_API kuzu_query_result* kuzu_query_result_get_next_query_result(
kuzu_query_result* query_result);

/**
* @brief Returns the query result as a string.
* @param query_result The query result instance to return.
Expand Down
22 changes: 17 additions & 5 deletions src/include/main/query_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class QueryResult {
QueryResult* currentResult;

public:
QueryResultIterator() = default;

explicit QueryResultIterator(QueryResult* startResult) : currentResult(startResult) {}

void operator++() {
Expand All @@ -46,9 +48,9 @@ class QueryResult {
}
}

bool isEnd() { return currentResult == nullptr; }
bool isEnd() const { return currentResult == nullptr; }

QueryResult* getCurrentResult() { return currentResult; }
QueryResult* getCurrentResult() const { return currentResult; }
};

public:
Expand Down Expand Up @@ -97,15 +99,22 @@ class QueryResult {
* @return whether there are more tuples to read.
*/
KUZU_API bool hasNext() const;
std::unique_ptr<QueryResult> nextQueryResult;
/**
* @return whether there are more query results to read.
*/
KUZU_API bool hasNextQueryResult() const;
/**
* @return get next query result to read (for multiple query statements).
*/
KUZU_API QueryResult* getNextQueryResult();

std::string toSingleQueryString();
std::unique_ptr<QueryResult> nextQueryResult;
/**
* @return next flat tuple in the query result.
*/
KUZU_API std::shared_ptr<processor::FlatTuple> getNext();
/**
* @return string of query result.
* @return string of first query result.
*/
KUZU_API std::string toString();

Expand Down Expand Up @@ -158,6 +167,9 @@ class QueryResult {

// execution statistics
std::unique_ptr<QuerySummary> querySummary;

// query iterator
QueryResultIterator queryResultIterator;
};

} // namespace main
Expand Down
1 change: 1 addition & 0 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ std::unique_ptr<QueryResult> ClientContext::queryResultWithError(std::string_vie
queryResult->success = false;
queryResult->errMsg = errMsg;
queryResult->nextQueryResult = nullptr;
queryResult->queryResultIterator = QueryResult::QueryResultIterator{queryResult.get()};
return queryResult;
}

Expand Down
26 changes: 16 additions & 10 deletions src/main/query_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ QueryResult::QueryResult(const PreparedSummary& preparedSummary) {
querySummary = std::make_unique<QuerySummary>();
querySummary->setPreparedSummary(preparedSummary);
nextQueryResult = nullptr;
queryResultIterator = QueryResultIterator{this};
}

QueryResult::~QueryResult() = default;
Expand Down Expand Up @@ -152,6 +153,21 @@ bool QueryResult::hasNext() const {
return iterator->hasNextFlatTuple();
}

bool QueryResult::hasNextQueryResult() const {
if (!queryResultIterator.isEnd()) {
return true;
}
return false;
}

QueryResult* QueryResult::getNextQueryResult() {
++queryResultIterator;
if (!queryResultIterator.isEnd()) {
return queryResultIterator.getCurrentResult();
}
return nullptr;
}

std::shared_ptr<FlatTuple> QueryResult::getNext() {
if (!hasNext()) {
throw RuntimeException(
Expand All @@ -163,16 +179,6 @@ std::shared_ptr<FlatTuple> QueryResult::getNext() {
}

std::string QueryResult::toString() {
std::string result;
QueryResultIterator it(this);
while (!it.isEnd()) {
result += it.getCurrentResult()->toSingleQueryString() + "\n";
++it;
}
return result;
}

std::string QueryResult::toSingleQueryString() {
std::string result;
if (isSuccess()) {
// print header
Expand Down
27 changes: 27 additions & 0 deletions test/c_api/query_result_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,30 @@ TEST_F(CApiQueryResultTest, ResetIterator) {

kuzu_query_result_destroy(result);
}

TEST_F(CApiQueryResultTest, MultipleQuery) {
auto connection = getConnection();
auto result = kuzu_connection_query(connection, "return 1; return 2; return 3;");
ASSERT_TRUE(kuzu_query_result_is_success(result));

auto str = kuzu_query_result_to_string(result);
ASSERT_EQ(std::string(str), "1\n1\n");
kuzu_destroy_string(str);

ASSERT_TRUE(kuzu_query_result_has_next_query_result(result));
auto next_query_result = kuzu_query_result_get_next_query_result(result);
ASSERT_TRUE(kuzu_query_result_is_success(next_query_result));
str = kuzu_query_result_to_string(next_query_result);
ASSERT_EQ(std::string(str), "2\n2\n");
kuzu_destroy_string(str);
kuzu_query_result_destroy(next_query_result);

next_query_result = kuzu_query_result_get_next_query_result(result);
ASSERT_TRUE(kuzu_query_result_is_success(next_query_result));
str = kuzu_query_result_to_string(next_query_result);
ASSERT_EQ(std::string(str), "3\n3\n");
kuzu_destroy_string(str);
kuzu_query_result_destroy(next_query_result);

kuzu_query_result_destroy(result);
}
14 changes: 11 additions & 3 deletions test/main/connection_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,18 @@ TEST_F(ApiTest, MultipleQuery) {
"MATCH (a:person) RETURN a.fName; MATCH (a:person)-[:knows]->(b:person) RETURN count(*);");
ASSERT_TRUE(result->isSuccess());

result =
conn->query("CREATE NODE TABLE N(ID INT64, PRIMARY KEY(ID));CREATE REL TABLE E(FROM N TO "
"N, MANY_MANY);MATCH (a:N)-[:E]->(b:N) WHERE a.ID = 0 return b.ID;");
result = conn->query("CREATE NODE TABLE Test(name STRING, age INT64, PRIMARY KEY(name));CREATE "
"(:Test {name: 'Alice', age: 25});"
"MATCH (a:Test) where a.name='Alice' return a.age;");
ASSERT_TRUE(result->isSuccess());

result = conn->query("return 1; return 2; return 3;");
ASSERT_TRUE(result->isSuccess());
ASSERT_EQ(result->toString(), "1\n1\n");
ASSERT_TRUE(result->hasNextQueryResult());
ASSERT_EQ(result->getNextQueryResult()->toString(), "2\n2\n");
ASSERT_TRUE(result->hasNextQueryResult());
ASSERT_EQ(result->getNextQueryResult()->toString(), "3\n3\n");
}

TEST_F(ApiTest, Prepare) {
Expand Down
Loading