From 28de83f93d6989780b8b21d5d8a8e71bfd3bee83 Mon Sep 17 00:00:00 2001 From: Whitney Young Date: Tue, 12 May 2015 14:14:09 -0700 Subject: [PATCH] Support for user functions. Fixes #140. --- src/database.cc | 150 ++++++++++++++++++++++++++++++++++++ src/database.h | 31 ++++++++ test/user_functions.test.js | 107 +++++++++++++++++++++++++ 3 files changed, 288 insertions(+) create mode 100644 test/user_functions.test.js diff --git a/src/database.cc b/src/database.cc index d34b865a8..45203c234 100644 --- a/src/database.cc +++ b/src/database.cc @@ -5,6 +5,10 @@ #include "database.h" #include "statement.h" +#ifndef SQLITE_DETERMINISTIC +#define SQLITE_DETERMINISTIC 0x800 +#endif + using namespace node_sqlite3; Persistent Database::constructor_template; @@ -24,6 +28,7 @@ void Database::Init(Handle target) { NODE_SET_PROTOTYPE_METHOD(t, "serialize", Serialize); NODE_SET_PROTOTYPE_METHOD(t, "parallelize", Parallelize); NODE_SET_PROTOTYPE_METHOD(t, "configure", Configure); + NODE_SET_PROTOTYPE_METHOD(t, "registerFunction", RegisterFunction); NODE_SET_GETTER(t, "open", OpenGetter); @@ -356,6 +361,151 @@ NAN_METHOD(Database::Configure) { NanReturnValue(args.This()); } +NAN_METHOD(Database::RegisterFunction) { + NanScope(); + Database* db = ObjectWrap::Unwrap(args.This()); + + REQUIRE_ARGUMENTS(2); + REQUIRE_ARGUMENT_STRING(0, functionName); + REQUIRE_ARGUMENT_FUNCTION(1, callback); + + FunctionBaton *baton = new FunctionBaton(db, *functionName, callback); + sqlite3_create_function( + db->_handle, + *functionName, + -1, // arbitrary number of args + SQLITE_UTF8 | SQLITE_DETERMINISTIC, + baton, + FunctionEnqueue, + NULL, + NULL); + + uv_mutex_init(&baton->mutex); + uv_cond_init(&baton->condition); + uv_async_init(uv_default_loop(), &baton->async, (uv_async_cb)Database::AsyncFunctionProcessQueue); + + NanReturnValue(args.This()); +} + +void Database::FunctionEnqueue(sqlite3_context *context, int argc, sqlite3_value **argv) { + // the JS function can only be safely executed on the main thread + // (uv_default_loop), so setup an invocation w/ the relevant information, + // enqueue it and signal the main thread to process the invocation queue. + // sqlite3 requires the result to be set before this function returns, so + // wait for the invocation to be completed. + FunctionBaton *baton = (FunctionBaton *)sqlite3_user_data(context); + FunctionInvocation invocation = { .context = context, .argc = argc, .argv = argv }; + + uv_async_send(&baton->async); + uv_mutex_lock(&baton->mutex); + baton->queue.push(&invocation); + while (!invocation.complete) { + uv_cond_wait(&baton->condition, &baton->mutex); + } + uv_mutex_unlock(&baton->mutex); +} + +void Database::AsyncFunctionProcessQueue(uv_async_t *async) { + FunctionBaton *baton = (FunctionBaton *)async->data; + + for (;;) { + FunctionInvocation *invocation = NULL; + + uv_mutex_lock(&baton->mutex); + if (!baton->queue.empty()) { + invocation = baton->queue.front(); + baton->queue.pop(); + } + uv_mutex_unlock(&baton->mutex); + + if (!invocation) { break; } + + Database::FunctionExecute(baton, invocation); + + uv_mutex_lock(&baton->mutex); + invocation->complete = true; + uv_cond_signal(&baton->condition); // allow paused thread to complete + uv_mutex_unlock(&baton->mutex); + } +} + +void Database::FunctionExecute(FunctionBaton *baton, FunctionInvocation *invocation) { + NanScope(); + + Database *db = baton->db; + Local cb = NanNew(baton->callback); + sqlite3_context *context = invocation->context; + sqlite3_value **values = invocation->argv; + int argc = invocation->argc; + + if (!cb.IsEmpty() && cb->IsFunction()) { + + // build the argument list for the function call + typedef Local LocalValue; + std::vector argv; + for (int i = 0; i < argc; i++) { + sqlite3_value *value = values[i]; + int type = sqlite3_value_type(value); + Local arg; + switch(type) { + case SQLITE_INTEGER: { + arg = NanNew(sqlite3_value_int64(value)); + } break; + case SQLITE_FLOAT: { + arg = NanNew(sqlite3_value_double(value)); + } break; + case SQLITE_TEXT: { + const char* text = (const char*)sqlite3_value_text(value); + int length = sqlite3_value_bytes(value); + arg = NanNew(text, length); + } break; + case SQLITE_BLOB: { + const void *blob = sqlite3_value_blob(value); + int length = sqlite3_value_bytes(value); + arg = NanNew(NanNewBufferHandle((char *)blob, length)); + } break; + case SQLITE_NULL: { + arg = NanNew(NanNull()); + } break; + } + + argv.push_back(arg); + } + + Local result = TRY_CATCH_CALL(NanObjectWrapHandle(db), cb, argc, argv.data()); + + // process the result + if (result->IsString() || result->IsRegExp()) { + String::Utf8Value value(result->ToString()); + sqlite3_result_text(context, *value, value.length(), SQLITE_TRANSIENT); + } + else if (result->IsInt32()) { + sqlite3_result_int(context, result->Int32Value()); + } + else if (result->IsNumber() || result->IsDate()) { + sqlite3_result_double(context, result->NumberValue()); + } + else if (result->IsBoolean()) { + sqlite3_result_int(context, result->BooleanValue()); + } + else if (result->IsNull() || result->IsUndefined()) { + sqlite3_result_null(context); + } + else if (Buffer::HasInstance(result)) { + Local buffer = result->ToObject(); + sqlite3_result_blob(context, + Buffer::Data(buffer), + Buffer::Length(buffer), + SQLITE_TRANSIENT); + } + else { + std::string message("invalid return type in user function"); + message = message + " " + baton->name; + sqlite3_result_error(context, message.c_str(), message.length()); + } + } +} + void Database::SetBusyTimeout(Baton* baton) { assert(baton->db->open); assert(baton->db->_handle); diff --git a/src/database.h b/src/database.h index af83ee715..75a940dc2 100644 --- a/src/database.h +++ b/src/database.h @@ -69,6 +69,32 @@ class Database : public ObjectWrap { Baton(db_, cb_), filename(filename_) {} }; + struct FunctionInvocation { + sqlite3_context *context; + sqlite3_value **argv; + int argc; + bool complete; + }; + + struct FunctionBaton { + Database* db; + std::string name; + Persistent callback; + uv_async_t async; + uv_mutex_t mutex; + uv_cond_t condition; + std::queue queue; + + FunctionBaton(Database* db_, const char* name_, Handle cb_) : + db(db_), name(name_) { + async.data = this; + NanAssignPersistent(callback, cb_); + } + virtual ~FunctionBaton() { + NanDisposePersistent(callback); + } + }; + typedef void (*Work_Callback)(Baton* baton); struct Call { @@ -152,6 +178,11 @@ class Database : public ObjectWrap { static NAN_METHOD(Configure); + static NAN_METHOD(RegisterFunction); + static void FunctionEnqueue(sqlite3_context *context, int argc, sqlite3_value **argv); + static void FunctionExecute(FunctionBaton *baton, FunctionInvocation *invocation); + static void AsyncFunctionProcessQueue(uv_async_t *async); + static void SetBusyTimeout(Baton* baton); static void RegisterTraceCallback(Baton* baton); diff --git a/test/user_functions.test.js b/test/user_functions.test.js new file mode 100644 index 000000000..9542e8b59 --- /dev/null +++ b/test/user_functions.test.js @@ -0,0 +1,107 @@ +var sqlite3 = require('..'); +var assert = require('assert'); + +describe('user functions', function() { + var db; + before(function(done) { db = new sqlite3.Database(':memory:', done); }); + + it('should allow registration of user functions', function() { + db.registerFunction('MY_UPPERCASE', function(value) { + return value.toUpperCase(); + }); + db.registerFunction('MY_STRING_JOIN', function(value1, value2) { + return [value1, value2].join(' '); + }); + db.registerFunction('MY_Add', function(value1, value2) { + return value1 + value2; + }); + db.registerFunction('MY_REGEX', function(regex, value) { + return !!value.match(new RegExp(regex)); + }); + db.registerFunction('MY_REGEX_VALUE', function(regex, value) { + return /match things/i; + }); + db.registerFunction('MY_ERROR', function(value) { + throw new Error('This function always throws'); + }); + db.registerFunction('MY_UNHANDLED_TYPE', function(value) { + return {}; + }); + db.registerFunction('MY_NOTHING', function(value) { + + }); + }); + + it('should process user functions with one arg', function(done) { + db.all('SELECT MY_UPPERCASE("hello") AS txt', function(err, rows) { + if (err) throw err; + assert.equal(rows.length, 1); + assert.equal(rows[0].txt, 'HELLO') + done(); + }); + }); + + it('should process user functions with two args', function(done) { + db.all('SELECT MY_STRING_JOIN("hello", "world") AS val', function(err, rows) { + if (err) throw err; + assert.equal(rows.length, 1); + assert.equal(rows[0].val, 'hello world'); + done(); + }); + }); + + it('should process user functions with number args', function(done) { + db.all('SELECT MY_ADD(1, 2) AS val', function(err, rows) { + if (err) throw err; + assert.equal(rows.length, 1); + assert.equal(rows[0].val, 3); + done(); + }); + }); + + it('allows writing of a regex function', function(done) { + db.all('SELECT MY_REGEX("colou?r", "color") AS val', function(err, rows) { + if (err) throw err; + assert.equal(rows.length, 1); + assert.equal(Boolean(rows[0].val), true); + done(); + }); + }); + + it('converts returned regex instances to strings', function(done) { + db.all('SELECT MY_REGEX_VALUE() AS val', function(err, rows) { + if (err) throw err; + assert.equal(rows.length, 1); + assert.equal(rows[0].val, '/match things/i'); + done(); + }); + }); + + it.skip('reports errors thrown in functions', function(done) { + db.all('SELECT MY_ERROR() AS val', function(err, rows) { + assert.equal(err.message, 'This function always throws'); + assert.equal(rows, undefined); + done(); + }); + }); + + it('reports errors for unhandled types', function(done) { + db.all('SELECT MY_UNHANDLED_TYPE() AS val', function(err, rows) { + assert.equal(err.message, 'SQLITE_ERROR: invalid return type in ' + + 'user function MY_UNHANDLED_TYPE'); + assert.equal(rows, undefined); + done(); + }); + }); + + it('allows no return value from functions', function(done) { + db.all('SELECT MY_NOTHING() AS val', function(err, rows) { + if (err) throw err; + assert.equal(rows.length, 1); + assert.equal(rows[0].val, undefined); + done(); + }); + }); + + after(function(done) { db.close(done); }); +});