diff --git a/adonis-typings/database.ts b/adonis-typings/database.ts index ba84ea73..ec55ee07 100644 --- a/adonis-typings/database.ts +++ b/adonis-typings/database.ts @@ -25,7 +25,11 @@ declare module '@ioc:Adonis/Lucid/Database' { DatabaseQueryBuilderContract, } from '@ioc:Adonis/Lucid/DatabaseQueryBuilder' - import { ModelConstructorContract, ModelQueryBuilderContract } from '@ioc:Adonis/Lucid/Model' + import { + HooksContract, + ModelConstructorContract, + ModelQueryBuilderContract, + } from '@ioc:Adonis/Lucid/Model' /** * A executable query builder will always have these methods on it. @@ -152,6 +156,8 @@ declare module '@ioc:Adonis/Lucid/Database' { export interface TransactionClientContract extends QueryClientContract { knexClient: knex.Transaction, + hooks: HooksContract<'commit' | 'rollback', (client: TransactionClientContract) => void | Promise> + /** * Is transaction completed or not */ diff --git a/adonis-typings/model.ts b/adonis-typings/model.ts index 90c2090f..598f915a 100644 --- a/adonis-typings/model.ts +++ b/adonis-typings/model.ts @@ -434,4 +434,16 @@ declare module '@ioc:Adonis/Lucid/Model' { options?: ModelAdapterOptions, ): ModelQueryBuilderContract & ExcutableQueryBuilderContract } + + /** + * Shape of the hooks contract used by transaction client and models + */ + export interface HooksContract { + add (lifecycle: 'before' | 'after', event: Events, handler: Handler): this + before (event: Events, handler: Handler): this + after (event: Events, handler: Handler): this + execute (lifecycle: 'before' | 'after', event: Events, payload: any): Promise + clear (event: Events): void + clearAll (): void + } } diff --git a/src/Hooks/index.ts b/src/Hooks/index.ts new file mode 100644 index 00000000..c854eaa2 --- /dev/null +++ b/src/Hooks/index.ts @@ -0,0 +1,78 @@ +/* + * @adonisjs/lucid + * + * (c) Harminder Virk + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. +*/ + +/// + +import { HooksContract } from '@ioc:Adonis/Lucid/Model' + +/** + * A generic class to implement before and after lifecycle hooks + */ +export class Hooks implements HooksContract { + private _hooks: { + [event: string]: { + before: Set, + after: Set, + }, + } = {} + + /** + * Add hook for a given event and lifecycle + */ + public add (lifecycle: 'before' | 'after', event: Events, handler: Handler) { + this._hooks[event] = this._hooks[event] || { before: new Set(), after: new Set() } + this._hooks[event][lifecycle].add(handler) + return this + } + + /** + * Execute hooks for a given event and lifecycle + */ + public async execute (lifecycle: 'before' | 'after', event: Events, payload: any): Promise { + if (!this._hooks[event]) { + return + } + + for (let hook of this._hooks[event][lifecycle]) { + await hook(payload) + } + } + + /** + * Register before hook + */ + public before (event: Events, handler: Handler): this { + return this.add('before', event, handler) + } + + /** + * Register after hook + */ + public after (event: Events, handler: Handler): this { + return this.add('after', event, handler) + } + + /** + * Remove hooks for a given event + */ + public clear (event: Events): void { + if (!this._hooks[event]) { + return + } + + delete this._hooks[event] + } + + /** + * Remove all hooks + */ + public clearAll (): void { + this._hooks = {} + } +} diff --git a/src/TransactionClient/index.ts b/src/TransactionClient/index.ts index a2722749..725dc0bd 100644 --- a/src/TransactionClient/index.ts +++ b/src/TransactionClient/index.ts @@ -13,6 +13,7 @@ import knex from 'knex' import { TransactionClientContract } from '@ioc:Adonis/Lucid/Database' import { ProfilerRowContract, ProfilerContract } from '@ioc:Adonis/Core/Profiler' +import { Hooks } from '../Hooks' import { ModelQueryBuilder } from '../Orm/QueryBuilder' import { RawQueryBuilder } from '../Database/QueryBuilder/Raw' import { InsertQueryBuilder } from '../Database/QueryBuilder/Insert' @@ -39,6 +40,11 @@ export class TransactionClient implements TransactionClientContract { */ public profiler?: ProfilerRowContract | ProfilerContract + /** + * Reference to client hooks + */ + public hooks = new Hooks<'rollback' | 'commit', any>() + constructor ( public knexClient: knex.Transaction, public dialect: string, @@ -155,13 +161,64 @@ export class TransactionClient implements TransactionClientContract { * Commit the transaction */ public async commit () { - await this.knexClient.commit() + /** + * Execute before hooks + */ + await this.hooks.execute('before', 'commit', this) + + /** + * Commit and hold the error (if any) + */ + let commitError: any = null + try { + await this.knexClient.commit() + } catch (error) { + commitError = error + } + + /** + * Raise exception when commit fails + */ + if (commitError) { + this.hooks.clearAll() + throw commitError + } + + /** + * Execute after hooks + */ + await this.hooks.execute('after', 'commit', this) + this.hooks.clearAll() } /** * Rollback the transaction */ public async rollback () { - await this.knexClient.rollback() + /** + * Execute before hooks + */ + await this.hooks.execute('before', 'rollback', this) + + let rollbackError: any = null + try { + await this.knexClient.rollback() + } catch (error) { + rollbackError = error + } + + /** + * Raise exception when commit fails + */ + if (rollbackError) { + this.hooks.clearAll() + throw rollbackError + } + + /** + * Execute after hooks + */ + await this.hooks.execute('after', 'rollback', this) + this.hooks.clearAll() } } diff --git a/test/database/transactions.spec.ts b/test/database/transactions.spec.ts index 09f4bbb4..589413fd 100644 --- a/test/database/transactions.spec.ts +++ b/test/database/transactions.spec.ts @@ -13,6 +13,7 @@ import test from 'japa' import { Connection } from '../../src/Connection' import { QueryClient } from '../../src/QueryClient' +import { TransactionClient } from '../../src/TransactionClient' import { getConfig, setup, cleanup, resetTables, getLogger } from '../../test-helpers' test.group('Transaction | query', (group) => { @@ -98,4 +99,50 @@ test.group('Transaction | query', (group) => { assert.lengthOf(results, 1) assert.equal(results[0].username, 'virk') }) + + test('execute before and after commit hooks', async (assert) => { + const stack: string[] = [] + const connection = new Connection('primary', getConfig(), getLogger()) + connection.connect() + + const db = await new QueryClient('dual', connection).transaction() + + db.hooks.before('commit', (trx) => { + stack.push('before') + assert.instanceOf(trx, TransactionClient) + }) + + db.hooks.after('commit', (trx) => { + stack.push('after') + assert.instanceOf(trx, TransactionClient) + }) + + await db.insertQuery().table('users').insert({ username: 'virk' }) + await db.commit() + assert.deepEqual(db.hooks['_hooks'], {}) + assert.deepEqual(stack, ['before', 'after']) + }) + + test('execute before and after rollback hooks', async (assert) => { + const stack: string[] = [] + const connection = new Connection('primary', getConfig(), getLogger()) + connection.connect() + + const db = await new QueryClient('dual', connection).transaction() + + db.hooks.before('rollback', (trx) => { + stack.push('before') + assert.instanceOf(trx, TransactionClient) + }) + + db.hooks.after('rollback', (trx) => { + stack.push('after') + assert.instanceOf(trx, TransactionClient) + }) + + await db.insertQuery().table('users').insert({ username: 'virk' }) + await db.rollback() + assert.deepEqual(db.hooks['_hooks'], {}) + assert.deepEqual(stack, ['before', 'after']) + }) })