diff --git a/adonis-typings/database.ts b/adonis-typings/database.ts index bb52612b..270beeac 100644 --- a/adonis-typings/database.ts +++ b/adonis-typings/database.ts @@ -14,6 +14,7 @@ declare module '@ioc:Adonis/Lucid/Database' { import { Pool } from 'tarn' import { EventEmitter } from 'events' import { Dictionary } from 'ts-essentials' + import { MacroableConstructorContract } from 'macroable' import { ProfilerRowContract, ProfilerContract } from '@ioc:Adonis/Core/Profiler' import { @@ -587,6 +588,10 @@ declare module '@ioc:Adonis/Lucid/Database' { * database connections */ export interface DatabaseContract { + DatabaseQueryBuilder: MacroableConstructorContract, + InsertQueryBuilder: MacroableConstructorContract, + ModelQueryBuilder: MacroableConstructorContract>, + /** * Name of the primary connection defined inside `config/database.ts` * file diff --git a/package.json b/package.json index 60bc19a3..851606a0 100644 --- a/package.json +++ b/package.json @@ -49,6 +49,7 @@ "knex": "^0.20.4", "knex-dynamic-connection": "^1.0.3", "log-update": "^3.3.0", + "macroable": "^3.0.0", "pluralize": "^8.0.0", "snake-case": "^3.0.2", "ts-essentials": "^4.0.0" diff --git a/src/Database/QueryBuilder/Chainable.ts b/src/Database/QueryBuilder/Chainable.ts index 36a09919..2e34298e 100644 --- a/src/Database/QueryBuilder/Chainable.ts +++ b/src/Database/QueryBuilder/Chainable.ts @@ -10,6 +10,7 @@ /// import knex from 'knex' +import { Macroable } from 'macroable' import { ChainableContract, DBQueryCallback } from '@ioc:Adonis/Lucid/DatabaseQueryBuilder' import { RawQueryBuilder } from './Raw' @@ -20,11 +21,13 @@ import { RawQueryBuilder } from './Raw' * The API internally uses the knex query builder. However, many of methods may have * different API. */ -export abstract class Chainable implements ChainableContract { +export abstract class Chainable extends Macroable implements ChainableContract { constructor ( public $knexBuilder: knex.QueryBuilder, // Needs to be public for Executable trait private _queryCallback: DBQueryCallback, - ) {} + ) { + super() + } /** * Returns the value pair for the `whereBetween` clause diff --git a/src/Database/QueryBuilder/Database.ts b/src/Database/QueryBuilder/Database.ts index 336e332a..52456884 100644 --- a/src/Database/QueryBuilder/Database.ts +++ b/src/Database/QueryBuilder/Database.ts @@ -48,6 +48,12 @@ export class DatabaseQueryBuilder extends Chainable implements DatabaseQueryBuil super(builder, queryCallback) } + /** + * Required by macroable + */ + protected static _macros = {} + protected static _getters = {} + /** * Ensures that we are not executing `update` or `del` when using read only * client diff --git a/src/Database/QueryBuilder/Insert.ts b/src/Database/QueryBuilder/Insert.ts index b151b2bf..828954ba 100644 --- a/src/Database/QueryBuilder/Insert.ts +++ b/src/Database/QueryBuilder/Insert.ts @@ -10,6 +10,7 @@ /// import knex from 'knex' +import { Macroable } from 'macroable' import { trait } from '@poppinss/traits' import { QueryClientContract } from '@ioc:Adonis/Lucid/Database' @@ -21,10 +22,17 @@ import { Executable, ExecutableConstructor } from '../../Traits/Executable' * Exposes the API for performing SQL inserts */ @trait(Executable) -export class InsertQueryBuilder implements InsertQueryBuilderContract { +export class InsertQueryBuilder extends Macroable implements InsertQueryBuilderContract { constructor (public $knexBuilder: knex.QueryBuilder, public client: QueryClientContract) { + super() } + /** + * Required by macroable + */ + protected static _macros = {} + protected static _getters = {} + /** * Returns the client to be used for the query. Even though the insert query * is always using the `write` client, we still go through the process of diff --git a/src/Database/index.ts b/src/Database/index.ts index d44d0e8e..88e0581f 100644 --- a/src/Database/index.ts +++ b/src/Database/index.ts @@ -23,6 +23,10 @@ import { import { QueryClient } from '../QueryClient' import { ConnectionManager } from '../Connection/Manager' +import { InsertQueryBuilder } from './QueryBuilder/Insert' +import { DatabaseQueryBuilder } from './QueryBuilder/Database' +import { ModelQueryBuilder } from '../Orm/QueryBuilder' + /** * Database class exposes the API to manage multiple connections and obtain an instance * of query/transaction clients. @@ -38,6 +42,13 @@ export class Database implements DatabaseContract { */ public primaryConnectionName = this._config.connection + /** + * Reference to query builders + */ + public DatabaseQueryBuilder = DatabaseQueryBuilder + public InsertQueryBuilder = InsertQueryBuilder + public ModelQueryBuilder = ModelQueryBuilder + constructor ( private _config: DatabaseConfigContract, private _logger: LoggerContract, diff --git a/src/Orm/QueryBuilder/index.ts b/src/Orm/QueryBuilder/index.ts index b3f78100..2297e13c 100644 --- a/src/Orm/QueryBuilder/index.ts +++ b/src/Orm/QueryBuilder/index.ts @@ -52,6 +52,12 @@ export class ModelQueryBuilder extends Chainable implements ModelQueryBuilderCon */ private _preloader = new Preloader(this.model) + /** + * Required by macroable + */ + protected static _macros = {} + protected static _getters = {} + /** * Options that must be passed to all new model instances */ diff --git a/test/database/database.spec.ts b/test/database/database.spec.ts index 97d67e82..99c8e853 100644 --- a/test/database/database.spec.ts +++ b/test/database/database.spec.ts @@ -12,7 +12,7 @@ import test from 'japa' import { Database } from '../../src/Database' -import { getConfig, setup, cleanup, getLogger, getProfiler } from '../../test-helpers' +import { getConfig, setup, cleanup, getLogger, getProfiler, getDb } from '../../test-helpers' test.group('Database', (group) => { group.before(async () => { @@ -220,3 +220,63 @@ test.group('Database', (group) => { await db.manager.closeAll() }) }) + +test.group('Database | extend', (group) => { + group.before(async () => { + await setup() + }) + + group.after(async () => { + await cleanup() + }) + + test('extend database query builder by adding macros', async (assert) => { + const db = getDb() + + db.DatabaseQueryBuilder.macro('whereActive', function whereActive () { + this.where('is_active', true) + return this + }) + + const knexClient = db.connection().getReadClient() + + const { sql, bindings } = db.query().from('users')['whereActive']().toSQL() + const { sql: knexSql, bindings: knexBindings } = knexClient + .from('users') + .where('is_active', true) + .toSQL() + + assert.equal(sql, knexSql) + assert.deepEqual(bindings, knexBindings) + + await db.manager.closeAll() + }) + + test('extend insert query builder by adding macros', async (assert) => { + const db = getDb() + + db.InsertQueryBuilder.macro('returnId', function whereActive () { + this.returning('id') + return this + }) + + const knexClient = db.connection().getReadClient() + + const { sql, bindings } = db + .insertQuery() + .table('users')['returnId']() + .insert({ id: 1 }) + .toSQL() + + const { sql: knexSql, bindings: knexBindings } = knexClient + .from('users') + .returning('id') + .insert({ id: 1 }) + .toSQL() + + assert.equal(sql, knexSql) + assert.deepEqual(bindings, knexBindings) + + await db.manager.closeAll() + }) +}) diff --git a/test/orm/base-model.spec.ts b/test/orm/base-model.spec.ts index cdcff88c..20cf3e42 100644 --- a/test/orm/base-model.spec.ts +++ b/test/orm/base-model.spec.ts @@ -2170,3 +2170,77 @@ test.group('Base Model | hooks', (group) => { assert.equal(usersCount[0].total, 1) }) }) + +test.group('Base model | extend', (group) => { + group.before(async () => { + db = getDb() + BaseModel = getBaseModel(ormAdapter(db)) + }) + + group.after(async () => { + await db.manager.closeAll() + }) + + test('extend model query builder', async (assert) => { + class User extends BaseModel { + @column({ primary: true }) + public id: number + + @column() + public username: string + } + User.$boot() + + db.ModelQueryBuilder.macro('whereActive', function () { + this.where('is_active', true) + return this + }) + + const knexClient = db.connection().getReadClient() + const { sql, bindings } = User.query()['whereActive']().toSQL() + const { sql: knexSql, bindings: knexBindings } = knexClient + .from('users') + .where('is_active', true) + .toSQL() + + assert.equal(sql, knexSql) + assert.deepEqual(bindings, knexBindings) + }) + + test('extend model insert query builder', async (assert) => { + class User extends BaseModel { + @column({ primary: true }) + public id: number + + @column() + public username: string + + public $getQueryFor (_, client) { + return client.insertQuery().table('users').withId() + } + } + User.$boot() + + db.InsertQueryBuilder.macro('withId', function () { + this.returning('id') + return this + }) + + const knexClient = db.connection().getReadClient() + const user = new User() + + const { sql, bindings } = user + .$getQueryFor('insert', db.connection()) + .insert({ id: 1 }) + .toSQL() + + const { sql: knexSql, bindings: knexBindings } = knexClient + .from('users') + .returning('id') + .insert({ id: 1 }) + .toSQL() + + assert.equal(sql, knexSql) + assert.deepEqual(bindings, knexBindings) + }) +})