diff --git a/src/c_api/connection.cpp b/src/c_api/connection.cpp index acd9b91e3e8..754db57a58e 100644 --- a/src/c_api/connection.cpp +++ b/src/c_api/connection.cpp @@ -52,8 +52,9 @@ kuzu_query_result* kuzu_connection_query(kuzu_connection* connection, const char if (query_result == nullptr) { return nullptr; } - auto* c_query_result = new kuzu_query_result; + auto* c_query_result = (kuzu_query_result*)malloc(sizeof(kuzu_query_result)); c_query_result->_query_result = query_result; + c_query_result->_is_owned_by_cpp = false; return c_query_result; } catch (Exception& e) { return nullptr; } } @@ -64,7 +65,7 @@ kuzu_prepared_statement* kuzu_connection_prepare(kuzu_connection* connection, co if (prepared_statement == nullptr) { return nullptr; } - auto* c_prepared_statement = new kuzu_prepared_statement; + auto* c_prepared_statement = (kuzu_prepared_statement*)malloc(sizeof(kuzu_prepared_statement)); c_prepared_statement->_prepared_statement = prepared_statement; c_prepared_statement->_bound_values = new std::unordered_map>; @@ -92,8 +93,9 @@ kuzu_query_result* kuzu_connection_execute( if (query_result == nullptr) { return nullptr; } - auto* c_query_result = new kuzu_query_result; + auto* c_query_result = (kuzu_query_result*)malloc(sizeof(kuzu_query_result)); c_query_result->_query_result = query_result; + c_query_result->_is_owned_by_cpp = false; return c_query_result; } diff --git a/src/c_api/prepared_statement.cpp b/src/c_api/prepared_statement.cpp index c6d5e485eb9..5c1464459f9 100644 --- a/src/c_api/prepared_statement.cpp +++ b/src/c_api/prepared_statement.cpp @@ -25,7 +25,7 @@ void kuzu_prepared_statement_destroy(kuzu_prepared_statement* prepared_statement delete static_cast>*>( prepared_statement->_bound_values); } - delete prepared_statement; + free(prepared_statement); } bool kuzu_prepared_statement_allow_active_transaction(kuzu_prepared_statement* prepared_statement) { diff --git a/src/c_api/query_result.cpp b/src/c_api/query_result.cpp index 8bdc49d2b38..ec839c95aef 100644 --- a/src/c_api/query_result.cpp +++ b/src/c_api/query_result.cpp @@ -12,9 +12,11 @@ void kuzu_query_result_destroy(kuzu_query_result* query_result) { return; } if (query_result->_query_result != nullptr) { - delete static_cast(query_result->_query_result); + if (!query_result->_is_owned_by_cpp) { + delete static_cast(query_result->_query_result); + } } - delete query_result; + free(query_result); } bool kuzu_query_result_is_success(kuzu_query_result* query_result) { @@ -69,6 +71,22 @@ bool kuzu_query_result_has_next(kuzu_query_result* query_result) { return static_cast(query_result->_query_result)->hasNext(); } +bool kuzu_query_result_has_next_query_result(kuzu_query_result* query_result) { + return static_cast(query_result->_query_result)->hasNextQueryResult(); +} + +kuzu_query_result* kuzu_query_result_get_next_query_result(kuzu_query_result* query_result) { + auto next_query_result = + static_cast(query_result->_query_result)->getNextQueryResult(); + if (next_query_result == nullptr) { + return nullptr; + } + auto* c_query_result = (kuzu_query_result*)malloc(sizeof(kuzu_query_result)); + c_query_result->_query_result = next_query_result; + c_query_result->_is_owned_by_cpp = true; + return c_query_result; +} + kuzu_flat_tuple* kuzu_query_result_get_next(kuzu_query_result* query_result) { auto flat_tuple = static_cast(query_result->_query_result)->getNext(); auto* flat_tuple_c = (kuzu_flat_tuple*)malloc(sizeof(kuzu_flat_tuple)); diff --git a/src/include/c_api/kuzu.h b/src/include/c_api/kuzu.h index f27677d303c..ca1ef568aaf 100644 --- a/src/include/c_api/kuzu.h +++ b/src/include/c_api/kuzu.h @@ -154,6 +154,7 @@ typedef struct { */ typedef struct { void* _query_result; + bool _is_owned_by_cpp; } kuzu_query_result; /** @@ -630,6 +631,20 @@ KUZU_C_API bool kuzu_query_result_has_next(kuzu_query_result* query_result); * @param query_result The query result instance to return. */ KUZU_C_API kuzu_flat_tuple* kuzu_query_result_get_next(kuzu_query_result* query_result); +/** + * @brief Returns true if we have not consumed all query results, false otherwise. Use this function + * for loop results of multiple query statements + * @param query_result The query result instance to check. + */ +KUZU_C_API bool kuzu_query_result_has_next_query_result(kuzu_query_result* query_result); +/** + * @brief Returns the next query result. Use this function to loop multiple query statements' + * results. + * @param query_result The query result instance to return. + */ +KUZU_C_API kuzu_query_result* kuzu_query_result_get_next_query_result( + kuzu_query_result* query_result); + /** * @brief Returns the query result as a string. * @param query_result The query result instance to return. diff --git a/src/include/main/query_result.h b/src/include/main/query_result.h index de517b4c844..7cac6b82d14 100644 --- a/src/include/main/query_result.h +++ b/src/include/main/query_result.h @@ -38,6 +38,8 @@ class QueryResult { QueryResult* currentResult; public: + QueryResultIterator() = default; + explicit QueryResultIterator(QueryResult* startResult) : currentResult(startResult) {} void operator++() { @@ -46,9 +48,9 @@ class QueryResult { } } - bool isEnd() { return currentResult == nullptr; } + bool isEnd() const { return currentResult == nullptr; } - QueryResult* getCurrentResult() { return currentResult; } + QueryResult* getCurrentResult() const { return currentResult; } }; public: @@ -97,15 +99,22 @@ class QueryResult { * @return whether there are more tuples to read. */ KUZU_API bool hasNext() const; - std::unique_ptr nextQueryResult; + /** + * @return whether there are more query results to read. + */ + KUZU_API bool hasNextQueryResult() const; + /** + * @return get next query result to read (for multiple query statements). + */ + KUZU_API QueryResult* getNextQueryResult(); - std::string toSingleQueryString(); + std::unique_ptr nextQueryResult; /** * @return next flat tuple in the query result. */ KUZU_API std::shared_ptr getNext(); /** - * @return string of query result. + * @return string of first query result. */ KUZU_API std::string toString(); @@ -158,6 +167,9 @@ class QueryResult { // execution statistics std::unique_ptr querySummary; + + // query iterator + QueryResultIterator queryResultIterator; }; } // namespace main diff --git a/src/main/client_context.cpp b/src/main/client_context.cpp index 1315025e15f..41368387aa1 100644 --- a/src/main/client_context.cpp +++ b/src/main/client_context.cpp @@ -262,6 +262,7 @@ std::unique_ptr ClientContext::queryResultWithError(std::string_vie queryResult->success = false; queryResult->errMsg = errMsg; queryResult->nextQueryResult = nullptr; + queryResult->queryResultIterator = QueryResult::QueryResultIterator{queryResult.get()}; return queryResult; } diff --git a/src/main/query_result.cpp b/src/main/query_result.cpp index af96d6038cd..bfffc8628b0 100644 --- a/src/main/query_result.cpp +++ b/src/main/query_result.cpp @@ -55,6 +55,7 @@ QueryResult::QueryResult(const PreparedSummary& preparedSummary) { querySummary = std::make_unique(); querySummary->setPreparedSummary(preparedSummary); nextQueryResult = nullptr; + queryResultIterator = QueryResultIterator{this}; } QueryResult::~QueryResult() = default; @@ -152,6 +153,21 @@ bool QueryResult::hasNext() const { return iterator->hasNextFlatTuple(); } +bool QueryResult::hasNextQueryResult() const { + if (!queryResultIterator.isEnd()) { + return true; + } + return false; +} + +QueryResult* QueryResult::getNextQueryResult() { + ++queryResultIterator; + if (!queryResultIterator.isEnd()) { + return queryResultIterator.getCurrentResult(); + } + return nullptr; +} + std::shared_ptr QueryResult::getNext() { if (!hasNext()) { throw RuntimeException( @@ -163,16 +179,6 @@ std::shared_ptr QueryResult::getNext() { } std::string QueryResult::toString() { - std::string result; - QueryResultIterator it(this); - while (!it.isEnd()) { - result += it.getCurrentResult()->toSingleQueryString() + "\n"; - ++it; - } - return result; -} - -std::string QueryResult::toSingleQueryString() { std::string result; if (isSuccess()) { // print header diff --git a/test/c_api/query_result_test.cpp b/test/c_api/query_result_test.cpp index 48dd7be9f24..128d786ade3 100644 --- a/test/c_api/query_result_test.cpp +++ b/test/c_api/query_result_test.cpp @@ -176,3 +176,30 @@ TEST_F(CApiQueryResultTest, ResetIterator) { kuzu_query_result_destroy(result); } + +TEST_F(CApiQueryResultTest, MultipleQuery) { + auto connection = getConnection(); + auto result = kuzu_connection_query(connection, "return 1; return 2; return 3;"); + ASSERT_TRUE(kuzu_query_result_is_success(result)); + + auto str = kuzu_query_result_to_string(result); + ASSERT_EQ(std::string(str), "1\n1\n"); + kuzu_destroy_string(str); + + ASSERT_TRUE(kuzu_query_result_has_next_query_result(result)); + auto next_query_result = kuzu_query_result_get_next_query_result(result); + ASSERT_TRUE(kuzu_query_result_is_success(next_query_result)); + str = kuzu_query_result_to_string(next_query_result); + ASSERT_EQ(std::string(str), "2\n2\n"); + kuzu_destroy_string(str); + kuzu_query_result_destroy(next_query_result); + + next_query_result = kuzu_query_result_get_next_query_result(result); + ASSERT_TRUE(kuzu_query_result_is_success(next_query_result)); + str = kuzu_query_result_to_string(next_query_result); + ASSERT_EQ(std::string(str), "3\n3\n"); + kuzu_destroy_string(str); + kuzu_query_result_destroy(next_query_result); + + kuzu_query_result_destroy(result); +} diff --git a/test/main/connection_test.cpp b/test/main/connection_test.cpp index d7bda29f578..8a6538b19b1 100644 --- a/test/main/connection_test.cpp +++ b/test/main/connection_test.cpp @@ -114,10 +114,18 @@ TEST_F(ApiTest, MultipleQuery) { "MATCH (a:person) RETURN a.fName; MATCH (a:person)-[:knows]->(b:person) RETURN count(*);"); ASSERT_TRUE(result->isSuccess()); - result = - conn->query("CREATE NODE TABLE N(ID INT64, PRIMARY KEY(ID));CREATE REL TABLE E(FROM N TO " - "N, MANY_MANY);MATCH (a:N)-[:E]->(b:N) WHERE a.ID = 0 return b.ID;"); + result = conn->query("CREATE NODE TABLE Test(name STRING, age INT64, PRIMARY KEY(name));CREATE " + "(:Test {name: 'Alice', age: 25});" + "MATCH (a:Test) where a.name='Alice' return a.age;"); + ASSERT_TRUE(result->isSuccess()); + + result = conn->query("return 1; return 2; return 3;"); ASSERT_TRUE(result->isSuccess()); + ASSERT_EQ(result->toString(), "1\n1\n"); + ASSERT_TRUE(result->hasNextQueryResult()); + ASSERT_EQ(result->getNextQueryResult()->toString(), "2\n2\n"); + ASSERT_TRUE(result->hasNextQueryResult()); + ASSERT_EQ(result->getNextQueryResult()->toString(), "3\n3\n"); } TEST_F(ApiTest, Prepare) {