Skip to content

Commit

Permalink
Fixes of several broken functions
Browse files Browse the repository at this point in the history
  • Loading branch information
FredyH committed Nov 8, 2021
1 parent ad93e28 commit 53e5f71
Show file tree
Hide file tree
Showing 14 changed files with 111 additions and 96 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,12 @@ Query:getData()
Query:abort()
-- Returns [Boolean]
-- Attempts to abort the query if it is still in the state QUERY_WAITING
-- Returns true if aborting was successful, false otherwise
-- Returns true if at least one running instance of the query was aborted successfully, false otherwise

Query:lastInsert()
-- Returns [Number]
-- Gets the autoincrement index of the last inserted row of the current result set

Query:status()
-- Returns [Number] (mysqloo.QUERY_* enums)
-- Gets the status of the query.

Query:affectedRows()
-- Returns [Number]
-- Gets the number of rows the query has affected (of the current result set)
Expand All @@ -198,11 +194,15 @@ Query:wait(shouldSwap)

Query:error()
-- Returns [String]
-- Gets the error caused by the query (if any).
-- Gets the error caused by the query, or "" if there was no error.

Query:hasMoreResults()
-- Returns [Boolean]
-- Returns true if the query still has more data associated with it (which means getNextResults() can be called)
-- Note: This function works unfortunately different that one would expect.
-- hasMoreResults() returns true if there is currently a result that can be popped, rather than if there is an
-- additional result that has data. However, this does make for a nicer code that handles multiple results.
-- See Examples/multi_results.lua for an example how to use it.

Query:getNextResults()
-- Returns [Table]
Expand Down
103 changes: 52 additions & 51 deletions src/BlockingQueue.h
Original file line number Diff line number Diff line change
@@ -1,70 +1,71 @@
#ifndef BLOCKING_QUEUE_
#define BLOCKING_QUEUE_

#include <deque>
#include <mutex>
#include <condition_variable>
#include <algorithm>

template <typename T>
template<typename T>
class BlockingQueue {
public:
void put(T elem) {
std::lock_guard<std::recursive_mutex> lock(mutex);
backingQueue.push_back(elem);
waitObj.notify_all();
}
void put(T elem) {
std::lock_guard<std::recursive_mutex> lock(mutex);
backingQueue.push_back(elem);
waitObj.notify_all();
}

bool empty() {
return size() == 0;
}

bool empty() {
return size() == 0;
}
bool swapToFrontIf(std::function<bool(T)> func) {
std::lock_guard<std::recursive_mutex> lock(mutex);
auto pos = std::find_if(backingQueue.begin(), backingQueue.end(), func);
if (pos != backingQueue.begin() && pos != backingQueue.end()) {
std::iter_swap(pos, backingQueue.begin());
return true;
}
return false;
}

bool swapToFrontIf(std::function<bool(T)> func) {
std::lock_guard<std::recursive_mutex> lock(mutex);
auto pos = std::find_if(backingQueue.begin(), backingQueue.end(), func);
if (pos != backingQueue.begin() && pos != backingQueue.end()) {
std::iter_swap(pos, backingQueue.begin());
return true;
}
return false;
}
bool removeIf(std::function<bool(T)> func) {
std::lock_guard<std::recursive_mutex> lock(mutex);
auto it = std::remove_if(backingQueue.begin(), backingQueue.end(), func);
bool removed = it != backingQueue.end();
backingQueue.erase(it, backingQueue.end());
return removed;
}

bool removeIf(std::function<bool(T)> func) {
std::lock_guard<std::recursive_mutex> lock(mutex);
auto pos = std::find_if(backingQueue.begin(), backingQueue.end(), func);
if (pos != backingQueue.begin() && pos != backingQueue.end()) {
backingQueue.erase(pos);
return true;
}
return false;
}
void remove(T elem) {
std::lock_guard<std::recursive_mutex> lock(mutex);
backingQueue.erase(std::remove(backingQueue.begin(), backingQueue.end(), elem), backingQueue.end());
}

void remove(T elem) {
std::lock_guard<std::recursive_mutex> lock(mutex);
backingQueue.erase(std::remove(backingQueue.begin(), backingQueue.end(), elem), backingQueue.end());
}
size_t size() {
std::lock_guard<std::recursive_mutex> lock(mutex);
return backingQueue.size();
}

size_t size() {
std::lock_guard<std::recursive_mutex> lock(mutex);
return backingQueue.size();
}
T take() {
std::unique_lock<std::recursive_mutex> lock(mutex);
while (size() == 0) waitObj.wait(lock);
auto front = backingQueue.front();
backingQueue.pop_front();
return front;
}

T take() {
std::unique_lock<std::recursive_mutex> lock(mutex);
while (size() == 0) waitObj.wait(lock);
auto front = backingQueue.front();
backingQueue.pop_front();
return front;
}
std::deque<T> clear() {
std::lock_guard<std::recursive_mutex> lock(mutex);
std::deque<T> returnQueue = backingQueue;
backingQueue.clear();
return returnQueue;
}

std::deque<T> clear() {
std::lock_guard<std::recursive_mutex> lock(mutex);
std::deque<T> returnQueue = backingQueue;
backingQueue.clear();
return returnQueue;
}
private:
std::deque<T> backingQueue;
std::recursive_mutex mutex;
std::condition_variable_any waitObj;
std::deque<T> backingQueue{};
std::recursive_mutex mutex{};
std::condition_variable_any waitObj{};
};

#endif
12 changes: 4 additions & 8 deletions src/lua/LuaDatabase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ MYSQLOO_LUA_FUNCTION(abortAllQueries) {
auto database = LuaObject::getLuaObject<LuaDatabase>(LUA);
auto abortedQueries = database->m_database->abortAllQueries();
for (const auto& pair: abortedQueries) {
LuaIQuery::runAbortedCallback(LUA, pair.second);
LuaIQuery::finishQueryData(LUA, pair.first, pair.second);
}
LUA->PushNumber((double) abortedQueries.size());
Expand All @@ -215,6 +216,7 @@ MYSQLOO_LUA_FUNCTION(ping) {
MYSQLOO_LUA_FUNCTION(wait) {
auto database = LuaObject::getLuaObject<LuaDatabase>(LUA);
database->m_database->wait();
database->think(LUA); //To set callback data, run callbacks
return 0;
}

Expand Down Expand Up @@ -284,7 +286,7 @@ void LuaDatabase::createMetaTable(ILuaBase *LUA) {

void LuaDatabase::think(ILuaBase *LUA) {
//Connection callbacks
auto database = this->m_database.get();
auto database = this->m_database;
if (database->isConnectionDone() && !this->m_dbCallbackRan && this->m_tableReference != 0) {
LUA->ReferencePush(this->m_tableReference);
if (database->connectionSuccessful()) {
Expand Down Expand Up @@ -313,13 +315,7 @@ void LuaDatabase::think(ILuaBase *LUA) {
//Run callbacks of finished queries
auto finishedQueries = database->takeFinishedQueries();
for (auto &pair: finishedQueries) {
auto data = pair.second;
if (data->m_tableReference != 0) {
LUA->ReferencePush(data->m_tableReference);
auto luaQuery = LuaIQuery::getLuaObject<LuaIQuery>(LUA, -1);
LUA->Pop();
luaQuery->runCallback(LUA, data);
}
LuaQuery::runCallback(LUA, pair.first, pair.second);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/lua/LuaDatabase.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class LuaDatabase : public LuaObject {

static int create(lua_State *L);

void think(ILuaBase *lua);
void think(ILuaBase *LUA);

int m_tableReference = 0;
std::shared_ptr<Database> m_database;
Expand Down
26 changes: 21 additions & 5 deletions src/lua/LuaIQuery.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@

#include "LuaIQuery.h"
#include "LuaQuery.h"
#include "LuaTransaction.h"
#include "LuaDatabase.h"


MYSQLOO_LUA_FUNCTION(start) {
Expand All @@ -22,6 +25,13 @@ MYSQLOO_LUA_FUNCTION(wait) {
}
auto query = LuaIQuery::getLuaObject<LuaIQuery>(LUA);
query->m_query->wait(shouldSwap);
if (query->m_databaseReference != 0) {
LUA->ReferencePush(query->m_databaseReference);
auto database = LuaObject::getLuaObject<LuaDatabase>(LUA, -1);
database->think(LUA);
LUA->Pop();
}

return 0;
}

Expand Down Expand Up @@ -50,8 +60,10 @@ MYSQLOO_LUA_FUNCTION(abort) {
auto abortedData = query->m_query->abort();
for (auto &data: abortedData) {
LuaIQuery::runAbortedCallback(LUA, data);
LuaIQuery::finishQueryData(LUA, query->m_query, data);
}
return 0;
LUA->PushBool(!abortedData.empty());
return 1;
}

void LuaIQuery::runAbortedCallback(ILuaBase *LUA, const std::shared_ptr<IQueryData> &data) {
Expand Down Expand Up @@ -140,8 +152,8 @@ void LuaIQuery::finishQueryData(GarrysMod::Lua::ILuaBase *LUA, const std::shared
data->m_tableReference = 0;
}

void LuaIQuery::runCallback(ILuaBase *LUA, const std::shared_ptr<IQueryData> &data) {
m_query->setCallbackData(data);
void LuaIQuery::runCallback(ILuaBase *LUA, const std::shared_ptr<IQuery> &iQuery, const std::shared_ptr<IQueryData> &data) {
iQuery->setCallbackData(data);

auto status = data->getResultStatus();
switch (status) {
Expand All @@ -151,11 +163,15 @@ void LuaIQuery::runCallback(ILuaBase *LUA, const std::shared_ptr<IQueryData> &da
runErrorCallback(LUA, data);
break;
case QUERY_SUCCESS:
runSuccessCallback(LUA, data);
if (auto query = std::dynamic_pointer_cast<Query>(iQuery)) {
LuaQuery::runSuccessCallback(LUA, query, std::dynamic_pointer_cast<QueryData>(data));
} else if (auto transaction = std::dynamic_pointer_cast<Transaction>(query)) {
LuaTransaction::runSuccessCallback(LUA, transaction, std::dynamic_pointer_cast<TransactionData>(data));
}
break;
}

LuaIQuery::finishQueryData(LUA, m_query, data);
LuaIQuery::finishQueryData(LUA, iQuery, data);
}

void LuaIQuery::onDestroyedByLua(ILuaBase *LUA) {
Expand Down
4 changes: 1 addition & 3 deletions src/lua/LuaIQuery.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@ class LuaIQuery : public LuaObject {
//The table is at the top
virtual std::shared_ptr<IQueryData> buildQueryData(ILuaBase *LUA, int stackPosition) = 0;

virtual void runSuccessCallback(ILuaBase *LUA, const std::shared_ptr<IQueryData> &data) = 0;

static void referenceCallbacks(ILuaBase *LUA, int stackPosition, IQueryData &data);

static void runAbortedCallback(ILuaBase *LUA, const std::shared_ptr<IQueryData> &data);

static void runErrorCallback(ILuaBase *LUA, const std::shared_ptr<IQueryData> &data);

void runCallback(ILuaBase *LUA, const std::shared_ptr<IQueryData> &data);
static void runCallback(ILuaBase *LUA, const std::shared_ptr<IQuery> &query, const std::shared_ptr<IQueryData> &data);

static void finishQueryData(ILuaBase *LUA, const std::shared_ptr<IQuery> &query, const std::shared_ptr<IQueryData> &data);

Expand Down
4 changes: 4 additions & 0 deletions src/lua/LuaObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ LUA_FUNCTION(luaObjectGc) {
LUA_CLASS_FUNCTION(LuaObject, luaObjectThink) {
std::unordered_set<LuaDatabase*> databasesCopy = *LuaDatabase::luaDatabases;
for (auto &database: databasesCopy) {
if (LuaDatabase::luaDatabases->find(database) == LuaDatabase::luaDatabases->end()) {
//This means the database instance was collected during the think hook and is thus invalid.
continue;
}
database->think(LUA);
}
return 0;
Expand Down
11 changes: 5 additions & 6 deletions src/lua/LuaQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,10 @@ static void runOnDataCallbacks(
}


void LuaQuery::runSuccessCallback(ILuaBase *LUA, const std::shared_ptr<IQueryData> &data) {
auto query = std::dynamic_pointer_cast<Query>(m_query);
auto queryData = std::dynamic_pointer_cast<QueryData>(data);
void LuaQuery::runSuccessCallback(ILuaBase *LUA, const std::shared_ptr<Query>& query, const std::shared_ptr<QueryData> &data) {
//Need to clear old data, if it exists
freeDataReference(LUA, *query);
int dataReference = LuaQuery::createDataReference(LUA, *query, *queryData);
int dataReference = LuaQuery::createDataReference(LUA, *query, *data);
runOnDataCallbacks(LUA, query, data, dataReference);

if (!LuaIQuery::pushCallbackReference(LUA, data->m_successReference, data->m_tableReference,
Expand Down Expand Up @@ -136,7 +134,7 @@ MYSQLOO_LUA_FUNCTION(lastInsert) {

MYSQLOO_LUA_FUNCTION(getData) {
auto luaQuery = LuaQuery::getLuaObject<LuaQuery>(LUA);
auto query = (Query *) luaQuery->m_query.get();
auto query = std::dynamic_pointer_cast<Query>(luaQuery->m_query);
if (!query->hasCallbackData() || query->callbackQueryData->getResultStatus() == QUERY_ERROR) {
LUA->PushNil();
} else {
Expand All @@ -155,7 +153,8 @@ MYSQLOO_LUA_FUNCTION(hasMoreResults) {

LUA_FUNCTION(getNextResults) {
auto luaQuery = LuaQuery::getLuaObject<LuaQuery>(LUA);
auto query = (Query *) luaQuery->m_query.get();
auto query = std::dynamic_pointer_cast<Query>(luaQuery->m_query);
LuaQuery::freeDataReference(LUA, *query);
query->getNextResults();
return 0;
}
Expand Down
2 changes: 1 addition & 1 deletion src/lua/LuaQuery.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LuaQuery : public LuaIQuery {

static void createMetaTable(ILuaBase *LUA);

void runSuccessCallback(ILuaBase *LUA, const std::shared_ptr<IQueryData> &data) override;
static void runSuccessCallback(ILuaBase *LUA, const std::shared_ptr<Query>& query, const std::shared_ptr<QueryData> &data);

std::shared_ptr<IQueryData> buildQueryData(ILuaBase *LUA, int stackPosition) override;

Expand Down
3 changes: 2 additions & 1 deletion src/lua/LuaTransaction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ std::shared_ptr<IQueryData> LuaTransaction::buildQueryData(ILuaBase *LUA, int st
return data;
}

void LuaTransaction::runSuccessCallback(ILuaBase *LUA, const std::shared_ptr<IQueryData> &data) {
void LuaTransaction::runSuccessCallback(ILuaBase *LUA, const std::shared_ptr<Transaction> &transaction,
const std::shared_ptr<TransactionData> &data) {
auto transactionData = std::dynamic_pointer_cast<TransactionData>(data);
if (transactionData->m_tableReference == 0) return;
transactionData->setStatus(QUERY_COMPLETE);
Expand Down
3 changes: 2 additions & 1 deletion src/lua/LuaTransaction.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class LuaTransaction : public LuaIQuery {

static void createMetaTable(ILuaBase *LUA);

void runSuccessCallback(ILuaBase *LUA, const std::shared_ptr<IQueryData> &data) override;
static void runSuccessCallback(ILuaBase *LUA, const std::shared_ptr<Transaction> &transaction,
const std::shared_ptr<TransactionData> &data);

explicit LuaTransaction(const std::shared_ptr<Transaction> &transaction, int databaseRef) : LuaIQuery(
std::static_pointer_cast<IQuery>(transaction), "MySQLOO Transaction", databaseRef
Expand Down
6 changes: 2 additions & 4 deletions src/mysql/Database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,14 @@ size_t Database::queueSize() {
std::deque<std::pair<std::shared_ptr<IQuery>, std::shared_ptr<IQueryData>>> Database::abortAllQueries() {
auto canceledQueries = queryQueue.clear();
for (auto &pair: canceledQueries) {
if (!pair.first || !pair.second) continue;
auto data = pair.second;
data->setStatus(QUERY_ABORTED);
}
return canceledQueries;
}

/* Waits for the connection of the database to finish by blocking the current thread until the connect thread finished.
/* Waits for the connection of the database to finish by blocking the current thread until the connection thread finished.
*/
void Database::wait() {
if (!startedConnecting) {
Expand Down Expand Up @@ -200,9 +201,6 @@ void Database::shutdown() {
* database thread to end.
*/
void Database::disconnect(bool wait) {
if (m_status != DATABASE_CONNECTED) {
throw MySQLOOException("Database not connected.");
}
shutdown();
if (wait && m_thread.joinable()) {
m_thread.join();
Expand Down
Loading

0 comments on commit 53e5f71

Please sign in to comment.