Skip to content

Commit

Permalink
Rework multiple query result (#3191)
Browse files Browse the repository at this point in the history
* multiple query results

* Multiple query results C API fix (#3195)

---------

Co-authored-by: 囧囧 <liuc223@gmail.com>
  • Loading branch information
hououou and mewim authored Apr 4, 2024
1 parent 37de692 commit ec6e309
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 24 deletions.
8 changes: 5 additions & 3 deletions src/c_api/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ kuzu_query_result* kuzu_connection_query(kuzu_connection* connection, const char
if (query_result == nullptr) {
return nullptr;
}
auto* c_query_result = new kuzu_query_result;
auto* c_query_result = (kuzu_query_result*)malloc(sizeof(kuzu_query_result));
c_query_result->_query_result = query_result;
c_query_result->_is_owned_by_cpp = false;
return c_query_result;
} catch (Exception& e) { return nullptr; }
}
Expand All @@ -64,7 +65,7 @@ kuzu_prepared_statement* kuzu_connection_prepare(kuzu_connection* connection, co
if (prepared_statement == nullptr) {
return nullptr;
}
auto* c_prepared_statement = new kuzu_prepared_statement;
auto* c_prepared_statement = (kuzu_prepared_statement*)malloc(sizeof(kuzu_prepared_statement));
c_prepared_statement->_prepared_statement = prepared_statement;
c_prepared_statement->_bound_values =
new std::unordered_map<std::string, std::unique_ptr<Value>>;
Expand Down Expand Up @@ -92,8 +93,9 @@ kuzu_query_result* kuzu_connection_execute(
if (query_result == nullptr) {
return nullptr;
}
auto* c_query_result = new kuzu_query_result;
auto* c_query_result = (kuzu_query_result*)malloc(sizeof(kuzu_query_result));
c_query_result->_query_result = query_result;
c_query_result->_is_owned_by_cpp = false;
return c_query_result;
}

Expand Down
2 changes: 1 addition & 1 deletion src/c_api/prepared_statement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ void kuzu_prepared_statement_destroy(kuzu_prepared_statement* prepared_statement
delete static_cast<std::unordered_map<std::string, std::unique_ptr<Value>>*>(
prepared_statement->_bound_values);
}
delete prepared_statement;
free(prepared_statement);
}

bool kuzu_prepared_statement_allow_active_transaction(kuzu_prepared_statement* prepared_statement) {
Expand Down
22 changes: 20 additions & 2 deletions src/c_api/query_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ void kuzu_query_result_destroy(kuzu_query_result* query_result) {
return;
}
if (query_result->_query_result != nullptr) {
delete static_cast<QueryResult*>(query_result->_query_result);
if (!query_result->_is_owned_by_cpp) {
delete static_cast<QueryResult*>(query_result->_query_result);
}
}
delete query_result;
free(query_result);
}

bool kuzu_query_result_is_success(kuzu_query_result* query_result) {
Expand Down Expand Up @@ -69,6 +71,22 @@ bool kuzu_query_result_has_next(kuzu_query_result* query_result) {
return static_cast<QueryResult*>(query_result->_query_result)->hasNext();
}

bool kuzu_query_result_has_next_query_result(kuzu_query_result* query_result) {
return static_cast<QueryResult*>(query_result->_query_result)->hasNextQueryResult();
}

kuzu_query_result* kuzu_query_result_get_next_query_result(kuzu_query_result* query_result) {
auto next_query_result =
static_cast<QueryResult*>(query_result->_query_result)->getNextQueryResult();
if (next_query_result == nullptr) {
return nullptr;
}
auto* c_query_result = (kuzu_query_result*)malloc(sizeof(kuzu_query_result));
c_query_result->_query_result = next_query_result;
c_query_result->_is_owned_by_cpp = true;
return c_query_result;
}

kuzu_flat_tuple* kuzu_query_result_get_next(kuzu_query_result* query_result) {
auto flat_tuple = static_cast<QueryResult*>(query_result->_query_result)->getNext();
auto* flat_tuple_c = (kuzu_flat_tuple*)malloc(sizeof(kuzu_flat_tuple));
Expand Down
15 changes: 15 additions & 0 deletions src/include/c_api/kuzu.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ typedef struct {
*/
typedef struct {
void* _query_result;
bool _is_owned_by_cpp;
} kuzu_query_result;

/**
Expand Down Expand Up @@ -630,6 +631,20 @@ KUZU_C_API bool kuzu_query_result_has_next(kuzu_query_result* query_result);
* @param query_result The query result instance to return.
*/
KUZU_C_API kuzu_flat_tuple* kuzu_query_result_get_next(kuzu_query_result* query_result);
/**
* @brief Returns true if we have not consumed all query results, false otherwise. Use this function
* for loop results of multiple query statements
* @param query_result The query result instance to check.
*/
KUZU_C_API bool kuzu_query_result_has_next_query_result(kuzu_query_result* query_result);
/**
* @brief Returns the next query result. Use this function to loop multiple query statements'
* results.
* @param query_result The query result instance to return.
*/
KUZU_C_API kuzu_query_result* kuzu_query_result_get_next_query_result(
kuzu_query_result* query_result);

/**
* @brief Returns the query result as a string.
* @param query_result The query result instance to return.
Expand Down
22 changes: 17 additions & 5 deletions src/include/main/query_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class QueryResult {
QueryResult* currentResult;

public:
QueryResultIterator() = default;

explicit QueryResultIterator(QueryResult* startResult) : currentResult(startResult) {}

void operator++() {
Expand All @@ -46,9 +48,9 @@ class QueryResult {
}
}

bool isEnd() { return currentResult == nullptr; }
bool isEnd() const { return currentResult == nullptr; }

QueryResult* getCurrentResult() { return currentResult; }
QueryResult* getCurrentResult() const { return currentResult; }
};

public:
Expand Down Expand Up @@ -97,15 +99,22 @@ class QueryResult {
* @return whether there are more tuples to read.
*/
KUZU_API bool hasNext() const;
std::unique_ptr<QueryResult> nextQueryResult;
/**
* @return whether there are more query results to read.
*/
KUZU_API bool hasNextQueryResult() const;
/**
* @return get next query result to read (for multiple query statements).
*/
KUZU_API QueryResult* getNextQueryResult();

std::string toSingleQueryString();
std::unique_ptr<QueryResult> nextQueryResult;
/**
* @return next flat tuple in the query result.
*/
KUZU_API std::shared_ptr<processor::FlatTuple> getNext();
/**
* @return string of query result.
* @return string of first query result.
*/
KUZU_API std::string toString();

Expand Down Expand Up @@ -158,6 +167,9 @@ class QueryResult {

// execution statistics
std::unique_ptr<QuerySummary> querySummary;

// query iterator
QueryResultIterator queryResultIterator;
};

} // namespace main
Expand Down
1 change: 1 addition & 0 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ std::unique_ptr<QueryResult> ClientContext::queryResultWithError(std::string_vie
queryResult->success = false;
queryResult->errMsg = errMsg;
queryResult->nextQueryResult = nullptr;
queryResult->queryResultIterator = QueryResult::QueryResultIterator{queryResult.get()};
return queryResult;
}

Expand Down
26 changes: 16 additions & 10 deletions src/main/query_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ QueryResult::QueryResult(const PreparedSummary& preparedSummary) {
querySummary = std::make_unique<QuerySummary>();
querySummary->setPreparedSummary(preparedSummary);
nextQueryResult = nullptr;
queryResultIterator = QueryResultIterator{this};
}

QueryResult::~QueryResult() = default;
Expand Down Expand Up @@ -152,6 +153,21 @@ bool QueryResult::hasNext() const {
return iterator->hasNextFlatTuple();
}

bool QueryResult::hasNextQueryResult() const {
if (!queryResultIterator.isEnd()) {
return true;
}
return false;
}

QueryResult* QueryResult::getNextQueryResult() {
++queryResultIterator;
if (!queryResultIterator.isEnd()) {
return queryResultIterator.getCurrentResult();
}
return nullptr;
}

std::shared_ptr<FlatTuple> QueryResult::getNext() {
if (!hasNext()) {
throw RuntimeException(
Expand All @@ -163,16 +179,6 @@ std::shared_ptr<FlatTuple> QueryResult::getNext() {
}

std::string QueryResult::toString() {
std::string result;
QueryResultIterator it(this);
while (!it.isEnd()) {
result += it.getCurrentResult()->toSingleQueryString() + "\n";
++it;
}
return result;
}

std::string QueryResult::toSingleQueryString() {
std::string result;
if (isSuccess()) {
// print header
Expand Down
27 changes: 27 additions & 0 deletions test/c_api/query_result_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,30 @@ TEST_F(CApiQueryResultTest, ResetIterator) {

kuzu_query_result_destroy(result);
}

TEST_F(CApiQueryResultTest, MultipleQuery) {
auto connection = getConnection();
auto result = kuzu_connection_query(connection, "return 1; return 2; return 3;");
ASSERT_TRUE(kuzu_query_result_is_success(result));

auto str = kuzu_query_result_to_string(result);
ASSERT_EQ(std::string(str), "1\n1\n");
kuzu_destroy_string(str);

ASSERT_TRUE(kuzu_query_result_has_next_query_result(result));
auto next_query_result = kuzu_query_result_get_next_query_result(result);
ASSERT_TRUE(kuzu_query_result_is_success(next_query_result));
str = kuzu_query_result_to_string(next_query_result);
ASSERT_EQ(std::string(str), "2\n2\n");
kuzu_destroy_string(str);
kuzu_query_result_destroy(next_query_result);

next_query_result = kuzu_query_result_get_next_query_result(result);
ASSERT_TRUE(kuzu_query_result_is_success(next_query_result));
str = kuzu_query_result_to_string(next_query_result);
ASSERT_EQ(std::string(str), "3\n3\n");
kuzu_destroy_string(str);
kuzu_query_result_destroy(next_query_result);

kuzu_query_result_destroy(result);
}
14 changes: 11 additions & 3 deletions test/main/connection_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,18 @@ TEST_F(ApiTest, MultipleQuery) {
"MATCH (a:person) RETURN a.fName; MATCH (a:person)-[:knows]->(b:person) RETURN count(*);");
ASSERT_TRUE(result->isSuccess());

result =
conn->query("CREATE NODE TABLE N(ID INT64, PRIMARY KEY(ID));CREATE REL TABLE E(FROM N TO "
"N, MANY_MANY);MATCH (a:N)-[:E]->(b:N) WHERE a.ID = 0 return b.ID;");
result = conn->query("CREATE NODE TABLE Test(name STRING, age INT64, PRIMARY KEY(name));CREATE "
"(:Test {name: 'Alice', age: 25});"
"MATCH (a:Test) where a.name='Alice' return a.age;");
ASSERT_TRUE(result->isSuccess());

result = conn->query("return 1; return 2; return 3;");
ASSERT_TRUE(result->isSuccess());
ASSERT_EQ(result->toString(), "1\n1\n");
ASSERT_TRUE(result->hasNextQueryResult());
ASSERT_EQ(result->getNextQueryResult()->toString(), "2\n2\n");
ASSERT_TRUE(result->hasNextQueryResult());
ASSERT_EQ(result->getNextQueryResult()->toString(), "3\n3\n");
}

TEST_F(ApiTest, Prepare) {
Expand Down

0 comments on commit ec6e309

Please sign in to comment.