From 53e5f7172bdc30b4dc0f75c1db18745012b5b051 Mon Sep 17 00:00:00 2001 From: Fredy Date: Mon, 8 Nov 2021 03:50:15 +0100 Subject: [PATCH] Fixes of several broken functions --- README.md | 12 ++--- src/BlockingQueue.h | 103 +++++++++++++++++++------------------ src/lua/LuaDatabase.cpp | 12 ++--- src/lua/LuaDatabase.h | 2 +- src/lua/LuaIQuery.cpp | 26 ++++++++-- src/lua/LuaIQuery.h | 4 +- src/lua/LuaObject.cpp | 4 ++ src/lua/LuaQuery.cpp | 11 ++-- src/lua/LuaQuery.h | 2 +- src/lua/LuaTransaction.cpp | 3 +- src/lua/LuaTransaction.h | 3 +- src/mysql/Database.cpp | 6 +-- src/mysql/IQuery.cpp | 11 ++-- src/mysql/Query.cpp | 8 +-- 14 files changed, 111 insertions(+), 96 deletions(-) diff --git a/README.md b/README.md index 378696a..4776926 100644 --- a/README.md +++ b/README.md @@ -170,16 +170,12 @@ Query:getData() Query:abort() -- Returns [Boolean] -- Attempts to abort the query if it is still in the state QUERY_WAITING --- Returns true if aborting was successful, false otherwise +-- Returns true if at least one running instance of the query was aborted successfully, false otherwise Query:lastInsert() -- Returns [Number] -- Gets the autoincrement index of the last inserted row of the current result set -Query:status() --- Returns [Number] (mysqloo.QUERY_* enums) --- Gets the status of the query. - Query:affectedRows() -- Returns [Number] -- Gets the number of rows the query has affected (of the current result set) @@ -198,11 +194,15 @@ Query:wait(shouldSwap) Query:error() -- Returns [String] --- Gets the error caused by the query (if any). +-- Gets the error caused by the query, or "" if there was no error. Query:hasMoreResults() -- Returns [Boolean] -- Returns true if the query still has more data associated with it (which means getNextResults() can be called) +-- Note: This function works unfortunately different that one would expect. +-- hasMoreResults() returns true if there is currently a result that can be popped, rather than if there is an +-- additional result that has data. However, this does make for a nicer code that handles multiple results. +-- See Examples/multi_results.lua for an example how to use it. Query:getNextResults() -- Returns [Table] diff --git a/src/BlockingQueue.h b/src/BlockingQueue.h index 32ef2c1..d3e13ac 100644 --- a/src/BlockingQueue.h +++ b/src/BlockingQueue.h @@ -1,70 +1,71 @@ #ifndef BLOCKING_QUEUE_ #define BLOCKING_QUEUE_ + #include #include #include #include -template +template class BlockingQueue { public: - void put(T elem) { - std::lock_guard lock(mutex); - backingQueue.push_back(elem); - waitObj.notify_all(); - } + void put(T elem) { + std::lock_guard lock(mutex); + backingQueue.push_back(elem); + waitObj.notify_all(); + } + + bool empty() { + return size() == 0; + } - bool empty() { - return size() == 0; - } + bool swapToFrontIf(std::function func) { + std::lock_guard lock(mutex); + auto pos = std::find_if(backingQueue.begin(), backingQueue.end(), func); + if (pos != backingQueue.begin() && pos != backingQueue.end()) { + std::iter_swap(pos, backingQueue.begin()); + return true; + } + return false; + } - bool swapToFrontIf(std::function func) { - std::lock_guard lock(mutex); - auto pos = std::find_if(backingQueue.begin(), backingQueue.end(), func); - if (pos != backingQueue.begin() && pos != backingQueue.end()) { - std::iter_swap(pos, backingQueue.begin()); - return true; - } - return false; - } + bool removeIf(std::function func) { + std::lock_guard lock(mutex); + auto it = std::remove_if(backingQueue.begin(), backingQueue.end(), func); + bool removed = it != backingQueue.end(); + backingQueue.erase(it, backingQueue.end()); + return removed; + } - bool removeIf(std::function func) { - std::lock_guard lock(mutex); - auto pos = std::find_if(backingQueue.begin(), backingQueue.end(), func); - if (pos != backingQueue.begin() && pos != backingQueue.end()) { - backingQueue.erase(pos); - return true; - } - return false; - } + void remove(T elem) { + std::lock_guard lock(mutex); + backingQueue.erase(std::remove(backingQueue.begin(), backingQueue.end(), elem), backingQueue.end()); + } - void remove(T elem) { - std::lock_guard lock(mutex); - backingQueue.erase(std::remove(backingQueue.begin(), backingQueue.end(), elem), backingQueue.end()); - } + size_t size() { + std::lock_guard lock(mutex); + return backingQueue.size(); + } - size_t size() { - std::lock_guard lock(mutex); - return backingQueue.size(); - } + T take() { + std::unique_lock lock(mutex); + while (size() == 0) waitObj.wait(lock); + auto front = backingQueue.front(); + backingQueue.pop_front(); + return front; + } - T take() { - std::unique_lock lock(mutex); - while (size() == 0) waitObj.wait(lock); - auto front = backingQueue.front(); - backingQueue.pop_front(); - return front; - } + std::deque clear() { + std::lock_guard lock(mutex); + std::deque returnQueue = backingQueue; + backingQueue.clear(); + return returnQueue; + } - std::deque clear() { - std::lock_guard lock(mutex); - std::deque returnQueue = backingQueue; - backingQueue.clear(); - return returnQueue; - } private: - std::deque backingQueue; - std::recursive_mutex mutex; - std::condition_variable_any waitObj; + std::deque backingQueue{}; + std::recursive_mutex mutex{}; + std::condition_variable_any waitObj{}; }; + #endif diff --git a/src/lua/LuaDatabase.cpp b/src/lua/LuaDatabase.cpp index 75acd7d..2795424 100644 --- a/src/lua/LuaDatabase.cpp +++ b/src/lua/LuaDatabase.cpp @@ -194,6 +194,7 @@ MYSQLOO_LUA_FUNCTION(abortAllQueries) { auto database = LuaObject::getLuaObject(LUA); auto abortedQueries = database->m_database->abortAllQueries(); for (const auto& pair: abortedQueries) { + LuaIQuery::runAbortedCallback(LUA, pair.second); LuaIQuery::finishQueryData(LUA, pair.first, pair.second); } LUA->PushNumber((double) abortedQueries.size()); @@ -215,6 +216,7 @@ MYSQLOO_LUA_FUNCTION(ping) { MYSQLOO_LUA_FUNCTION(wait) { auto database = LuaObject::getLuaObject(LUA); database->m_database->wait(); + database->think(LUA); //To set callback data, run callbacks return 0; } @@ -284,7 +286,7 @@ void LuaDatabase::createMetaTable(ILuaBase *LUA) { void LuaDatabase::think(ILuaBase *LUA) { //Connection callbacks - auto database = this->m_database.get(); + auto database = this->m_database; if (database->isConnectionDone() && !this->m_dbCallbackRan && this->m_tableReference != 0) { LUA->ReferencePush(this->m_tableReference); if (database->connectionSuccessful()) { @@ -313,13 +315,7 @@ void LuaDatabase::think(ILuaBase *LUA) { //Run callbacks of finished queries auto finishedQueries = database->takeFinishedQueries(); for (auto &pair: finishedQueries) { - auto data = pair.second; - if (data->m_tableReference != 0) { - LUA->ReferencePush(data->m_tableReference); - auto luaQuery = LuaIQuery::getLuaObject(LUA, -1); - LUA->Pop(); - luaQuery->runCallback(LUA, data); - } + LuaQuery::runCallback(LUA, pair.first, pair.second); } } diff --git a/src/lua/LuaDatabase.h b/src/lua/LuaDatabase.h index 24de9dd..8b693d5 100644 --- a/src/lua/LuaDatabase.h +++ b/src/lua/LuaDatabase.h @@ -14,7 +14,7 @@ class LuaDatabase : public LuaObject { static int create(lua_State *L); - void think(ILuaBase *lua); + void think(ILuaBase *LUA); int m_tableReference = 0; std::shared_ptr m_database; diff --git a/src/lua/LuaIQuery.cpp b/src/lua/LuaIQuery.cpp index 21e4a1c..dfde04d 100644 --- a/src/lua/LuaIQuery.cpp +++ b/src/lua/LuaIQuery.cpp @@ -1,5 +1,8 @@ #include "LuaIQuery.h" +#include "LuaQuery.h" +#include "LuaTransaction.h" +#include "LuaDatabase.h" MYSQLOO_LUA_FUNCTION(start) { @@ -22,6 +25,13 @@ MYSQLOO_LUA_FUNCTION(wait) { } auto query = LuaIQuery::getLuaObject(LUA); query->m_query->wait(shouldSwap); + if (query->m_databaseReference != 0) { + LUA->ReferencePush(query->m_databaseReference); + auto database = LuaObject::getLuaObject(LUA, -1); + database->think(LUA); + LUA->Pop(); + } + return 0; } @@ -50,8 +60,10 @@ MYSQLOO_LUA_FUNCTION(abort) { auto abortedData = query->m_query->abort(); for (auto &data: abortedData) { LuaIQuery::runAbortedCallback(LUA, data); + LuaIQuery::finishQueryData(LUA, query->m_query, data); } - return 0; + LUA->PushBool(!abortedData.empty()); + return 1; } void LuaIQuery::runAbortedCallback(ILuaBase *LUA, const std::shared_ptr &data) { @@ -140,8 +152,8 @@ void LuaIQuery::finishQueryData(GarrysMod::Lua::ILuaBase *LUA, const std::shared data->m_tableReference = 0; } -void LuaIQuery::runCallback(ILuaBase *LUA, const std::shared_ptr &data) { - m_query->setCallbackData(data); +void LuaIQuery::runCallback(ILuaBase *LUA, const std::shared_ptr &iQuery, const std::shared_ptr &data) { + iQuery->setCallbackData(data); auto status = data->getResultStatus(); switch (status) { @@ -151,11 +163,15 @@ void LuaIQuery::runCallback(ILuaBase *LUA, const std::shared_ptr &da runErrorCallback(LUA, data); break; case QUERY_SUCCESS: - runSuccessCallback(LUA, data); + if (auto query = std::dynamic_pointer_cast(iQuery)) { + LuaQuery::runSuccessCallback(LUA, query, std::dynamic_pointer_cast(data)); + } else if (auto transaction = std::dynamic_pointer_cast(query)) { + LuaTransaction::runSuccessCallback(LUA, transaction, std::dynamic_pointer_cast(data)); + } break; } - LuaIQuery::finishQueryData(LUA, m_query, data); + LuaIQuery::finishQueryData(LUA, iQuery, data); } void LuaIQuery::onDestroyedByLua(ILuaBase *LUA) { diff --git a/src/lua/LuaIQuery.h b/src/lua/LuaIQuery.h index f0ae93f..032a4da 100644 --- a/src/lua/LuaIQuery.h +++ b/src/lua/LuaIQuery.h @@ -22,15 +22,13 @@ class LuaIQuery : public LuaObject { //The table is at the top virtual std::shared_ptr buildQueryData(ILuaBase *LUA, int stackPosition) = 0; - virtual void runSuccessCallback(ILuaBase *LUA, const std::shared_ptr &data) = 0; - static void referenceCallbacks(ILuaBase *LUA, int stackPosition, IQueryData &data); static void runAbortedCallback(ILuaBase *LUA, const std::shared_ptr &data); static void runErrorCallback(ILuaBase *LUA, const std::shared_ptr &data); - void runCallback(ILuaBase *LUA, const std::shared_ptr &data); + static void runCallback(ILuaBase *LUA, const std::shared_ptr &query, const std::shared_ptr &data); static void finishQueryData(ILuaBase *LUA, const std::shared_ptr &query, const std::shared_ptr &data); diff --git a/src/lua/LuaObject.cpp b/src/lua/LuaObject.cpp index e925133..a674e90 100644 --- a/src/lua/LuaObject.cpp +++ b/src/lua/LuaObject.cpp @@ -28,6 +28,10 @@ LUA_FUNCTION(luaObjectGc) { LUA_CLASS_FUNCTION(LuaObject, luaObjectThink) { std::unordered_set databasesCopy = *LuaDatabase::luaDatabases; for (auto &database: databasesCopy) { + if (LuaDatabase::luaDatabases->find(database) == LuaDatabase::luaDatabases->end()) { + //This means the database instance was collected during the think hook and is thus invalid. + continue; + } database->think(LUA); } return 0; diff --git a/src/lua/LuaQuery.cpp b/src/lua/LuaQuery.cpp index ef74782..661db06 100644 --- a/src/lua/LuaQuery.cpp +++ b/src/lua/LuaQuery.cpp @@ -102,12 +102,10 @@ static void runOnDataCallbacks( } -void LuaQuery::runSuccessCallback(ILuaBase *LUA, const std::shared_ptr &data) { - auto query = std::dynamic_pointer_cast(m_query); - auto queryData = std::dynamic_pointer_cast(data); +void LuaQuery::runSuccessCallback(ILuaBase *LUA, const std::shared_ptr& query, const std::shared_ptr &data) { //Need to clear old data, if it exists freeDataReference(LUA, *query); - int dataReference = LuaQuery::createDataReference(LUA, *query, *queryData); + int dataReference = LuaQuery::createDataReference(LUA, *query, *data); runOnDataCallbacks(LUA, query, data, dataReference); if (!LuaIQuery::pushCallbackReference(LUA, data->m_successReference, data->m_tableReference, @@ -136,7 +134,7 @@ MYSQLOO_LUA_FUNCTION(lastInsert) { MYSQLOO_LUA_FUNCTION(getData) { auto luaQuery = LuaQuery::getLuaObject(LUA); - auto query = (Query *) luaQuery->m_query.get(); + auto query = std::dynamic_pointer_cast(luaQuery->m_query); if (!query->hasCallbackData() || query->callbackQueryData->getResultStatus() == QUERY_ERROR) { LUA->PushNil(); } else { @@ -155,7 +153,8 @@ MYSQLOO_LUA_FUNCTION(hasMoreResults) { LUA_FUNCTION(getNextResults) { auto luaQuery = LuaQuery::getLuaObject(LUA); - auto query = (Query *) luaQuery->m_query.get(); + auto query = std::dynamic_pointer_cast(luaQuery->m_query); + LuaQuery::freeDataReference(LUA, *query); query->getNextResults(); return 0; } diff --git a/src/lua/LuaQuery.h b/src/lua/LuaQuery.h index 486e378..439557c 100644 --- a/src/lua/LuaQuery.h +++ b/src/lua/LuaQuery.h @@ -15,7 +15,7 @@ class LuaQuery : public LuaIQuery { static void createMetaTable(ILuaBase *LUA); - void runSuccessCallback(ILuaBase *LUA, const std::shared_ptr &data) override; + static void runSuccessCallback(ILuaBase *LUA, const std::shared_ptr& query, const std::shared_ptr &data); std::shared_ptr buildQueryData(ILuaBase *LUA, int stackPosition) override; diff --git a/src/lua/LuaTransaction.cpp b/src/lua/LuaTransaction.cpp index 6d5fda5..083901c 100644 --- a/src/lua/LuaTransaction.cpp +++ b/src/lua/LuaTransaction.cpp @@ -84,7 +84,8 @@ std::shared_ptr LuaTransaction::buildQueryData(ILuaBase *LUA, int st return data; } -void LuaTransaction::runSuccessCallback(ILuaBase *LUA, const std::shared_ptr &data) { +void LuaTransaction::runSuccessCallback(ILuaBase *LUA, const std::shared_ptr &transaction, + const std::shared_ptr &data) { auto transactionData = std::dynamic_pointer_cast(data); if (transactionData->m_tableReference == 0) return; transactionData->setStatus(QUERY_COMPLETE); diff --git a/src/lua/LuaTransaction.h b/src/lua/LuaTransaction.h index 09620ae..4d42f0d 100644 --- a/src/lua/LuaTransaction.h +++ b/src/lua/LuaTransaction.h @@ -13,7 +13,8 @@ class LuaTransaction : public LuaIQuery { static void createMetaTable(ILuaBase *LUA); - void runSuccessCallback(ILuaBase *LUA, const std::shared_ptr &data) override; + static void runSuccessCallback(ILuaBase *LUA, const std::shared_ptr &transaction, + const std::shared_ptr &data); explicit LuaTransaction(const std::shared_ptr &transaction, int databaseRef) : LuaIQuery( std::static_pointer_cast(transaction), "MySQLOO Transaction", databaseRef diff --git a/src/mysql/Database.cpp b/src/mysql/Database.cpp index 3362fd9..47932ce 100644 --- a/src/mysql/Database.cpp +++ b/src/mysql/Database.cpp @@ -94,13 +94,14 @@ size_t Database::queueSize() { std::deque, std::shared_ptr>> Database::abortAllQueries() { auto canceledQueries = queryQueue.clear(); for (auto &pair: canceledQueries) { + if (!pair.first || !pair.second) continue; auto data = pair.second; data->setStatus(QUERY_ABORTED); } return canceledQueries; } -/* Waits for the connection of the database to finish by blocking the current thread until the connect thread finished. +/* Waits for the connection of the database to finish by blocking the current thread until the connection thread finished. */ void Database::wait() { if (!startedConnecting) { @@ -200,9 +201,6 @@ void Database::shutdown() { * database thread to end. */ void Database::disconnect(bool wait) { - if (m_status != DATABASE_CONNECTED) { - throw MySQLOOException("Database not connected."); - } shutdown(); if (wait && m_thread.joinable()) { m_thread.join(); diff --git a/src/mysql/IQuery.cpp b/src/mysql/IQuery.cpp index bd7f9d7..8f9bac7 100644 --- a/src/mysql/IQuery.cpp +++ b/src/mysql/IQuery.cpp @@ -65,12 +65,13 @@ void IQuery::wait(bool shouldSwap) { } } -//Returns the error message produced by the mysql query or 0 if there is none +//Returns the error message produced by the mysql query or "" if there is none std::string IQuery::error() const { - if (!hasCallbackData()) { - throw MySQLOOException("Query not started"); + auto currentQueryData = callbackQueryData; + if (!currentQueryData) { + return ""; } - return callbackQueryData->getError(); + return currentQueryData->getError(); } //Attempts to abort the query, returns true if it was able to stop at least one query in time, false otherwise @@ -85,7 +86,7 @@ std::vector> IQuery::abort() { //aren't in the query queue bool wasRemoved = database->queryQueue.removeIf( [&](std::pair, std::shared_ptr> const &p) { - return p.second.get() == data.get(); + return p.second == data; }); if (wasRemoved) { data->setStatus(QUERY_ABORTED); diff --git a/src/mysql/Query.cpp b/src/mysql/Query.cpp index 8afa086..fb69064 100644 --- a/src/mysql/Query.cpp +++ b/src/mysql/Query.cpp @@ -50,7 +50,7 @@ bool Query::hasMoreResults() { if (!hasCallbackData()) { throw MySQLOOException("Query not completed yet"); } - auto* data = (QueryData*) callbackQueryData.get(); + auto data = std::dynamic_pointer_cast(this->callbackQueryData); return data->hasMoreResults(); } @@ -59,7 +59,7 @@ void Query::getNextResults() { if (!hasCallbackData()) { throw MySQLOOException("Query not completed yet"); } - auto* data = (QueryData*) callbackQueryData.get(); + auto data = std::dynamic_pointer_cast(this->callbackQueryData); if (!data->getNextResults()) { throw MySQLOOException("Query doesn't have any more results"); } @@ -70,7 +70,7 @@ my_ulonglong Query::lastInsert() { if (!hasCallbackData()) { return 0; } - auto* data = (QueryData*) this->callbackQueryData.get(); + auto data = std::dynamic_pointer_cast(this->callbackQueryData); //Calling lastInsert() after query was executed but before the callback is run can cause race conditions return data->getLastInsertID(); } @@ -81,7 +81,7 @@ my_ulonglong Query::affectedRows() { if (!hasCallbackData()) { return 0; } - auto* data = (QueryData*) this->callbackQueryData.get(); + auto data = std::dynamic_pointer_cast(this->callbackQueryData); //Calling affectedRows() after query was executed but before the callback is run can cause race conditions return data->getAffectedRows(); }