diff --git a/adonis-typings/model.ts b/adonis-typings/model.ts index 598f915a..b4f06ce1 100644 --- a/adonis-typings/model.ts +++ b/adonis-typings/model.ts @@ -10,7 +10,11 @@ declare module '@ioc:Adonis/Lucid/Model' { import { ChainableContract } from '@ioc:Adonis/Lucid/DatabaseQueryBuilder' import { ProfilerContract, ProfilerRowContract } from '@ioc:Adonis/Core/Profiler' - import { QueryClientContract, ExcutableQueryBuilderContract } from '@ioc:Adonis/Lucid/Database' + import { + QueryClientContract, + TransactionClientContract, + ExcutableQueryBuilderContract, + } from '@ioc:Adonis/Lucid/Database' /** * Represents a single column on the model @@ -173,6 +177,7 @@ declare module '@ioc:Adonis/Lucid/Model' { $sideloaded: ModelObject $primaryKeyValue?: any $options?: ModelOptions + $trx?: TransactionClientContract, /** * Gives an option to the end user to define constraints for update, insert diff --git a/src/Orm/Adapter/index.ts b/src/Orm/Adapter/index.ts index a0943887..6c7365b2 100644 --- a/src/Orm/Adapter/index.ts +++ b/src/Orm/Adapter/index.ts @@ -51,7 +51,7 @@ export class Adapter implements AdapterContract { */ public async insert (instance: ModelContract, attributes: any) { const modelConstructor = instance.constructor as unknown as ModelConstructorContract - const client = this._getModelClient(modelConstructor, instance.$options) + const client = instance.$trx ? instance.$trx : this._getModelClient(modelConstructor, instance.$options) const query = instance.$getQueryFor('insert', client) const result = await query.insert(attributes) @@ -65,7 +65,7 @@ export class Adapter implements AdapterContract { */ public async update (instance: ModelContract, dirty: any) { const modelConstructor = instance.constructor as unknown as ModelConstructorContract - const client = this._getModelClient(modelConstructor, instance.$options) + const client = instance.$trx ? instance.$trx : this._getModelClient(modelConstructor, instance.$options) const query = instance.$getQueryFor('update', client) await query.update(dirty) @@ -76,7 +76,7 @@ export class Adapter implements AdapterContract { */ public async delete (instance: ModelContract) { const modelConstructor = instance.constructor as unknown as ModelConstructorContract - const client = this._getModelClient(modelConstructor, instance.$options) + const client = instance.$trx ? instance.$trx : this._getModelClient(modelConstructor, instance.$options) const query = instance.$getQueryFor('delete', client) await query.del() } diff --git a/src/Orm/BaseModel/index.ts b/src/Orm/BaseModel/index.ts index a9ea1817..ce877f39 100644 --- a/src/Orm/BaseModel/index.ts +++ b/src/Orm/BaseModel/index.ts @@ -13,7 +13,7 @@ import pluralize from 'pluralize' import { isObject, snakeCase } from 'lodash' import { Exception } from '@poppinss/utils' -import { QueryClientContract } from '@ioc:Adonis/Lucid/Database' +import { QueryClientContract, TransactionClientContract } from '@ioc:Adonis/Lucid/Database' import { CacheNode, ColumnNode, @@ -383,6 +383,20 @@ export class BaseModel implements ModelContract { return new Proxy(this, proxyHandler) } + /** + * Reference to transaction that will be used for performing queries on a given + * model instance. + */ + private _trx?: TransactionClientContract + + /** + * The transaction listener listens for the `commit` and `rollback` events and + * cleansup the `$trx` reference + */ + private _transactionListener = function listener () { + this.$trx = undefined + }.bind(this) + /** * When `fill` method is called, then we may have a situation where it * removed the values which exists in `original` and hence the dirty @@ -555,6 +569,38 @@ export class BaseModel implements ModelContract { return Object.keys(this.$dirty).length > 0 } + /** + * Returns the transaction + */ + public get $trx (): TransactionClientContract | undefined { + return this._trx + } + + /** + * Set the trx to be used by the model to executing queries + */ + public set $trx (trx: TransactionClientContract | undefined) { + if (!trx) { + this._trx = undefined + return + } + + /** + * Remove old listeners + */ + if (this.$trx) { + this.$trx.removeListener('commit', this._transactionListener) + this.$trx.removeListener('rollback', this._transactionListener) + } + + /** + * Store reference to the transaction + */ + this._trx = trx + this._trx.once('commit', this._transactionListener) + this._trx.once('rollback', this._transactionListener) + } + /** * Sets the options on the model instance */ diff --git a/test/orm/adapter.spec.ts b/test/orm/adapter.spec.ts index 3df04003..9bf77238 100644 --- a/test/orm/adapter.spec.ts +++ b/test/orm/adapter.spec.ts @@ -143,4 +143,167 @@ test.group('Adapter', (group) => { assert.deepEqual(users[0].$attributes, { id: 2, username: 'nikk' }) assert.deepEqual(users[1].$attributes, { id: 1, username: 'virk' }) }) + + test('use transaction client set on the model for the insert', async (assert) => { + const BaseModel = getBaseModel(ormAdapter()) + + class User extends BaseModel { + public static $table = 'users' + + @column({ primary: true }) + public id: number + + @column() + public username: string + } + + User.$boot() + const db = getDb() + const trx = await db.transaction() + + const user = new User() + user.$trx = trx + user.username = 'virk' + await user.save() + await trx.commit() + + const totalUsers = await db.from('users').count('*', 'total') + + assert.equal(totalUsers[0].total, 1) + assert.exists(user.id) + assert.isUndefined(user.$trx) + assert.deepEqual(user.$attributes, { username: 'virk', id: user.id }) + assert.isFalse(user.$isDirty) + assert.isTrue(user.$persisted) + }) + + test('do not insert when transaction rollbacks', async (assert) => { + const BaseModel = getBaseModel(ormAdapter()) + + class User extends BaseModel { + public static $table = 'users' + + @column({ primary: true }) + public id: number + + @column() + public username: string + } + + User.$boot() + const db = getDb() + const trx = await db.transaction() + + const user = new User() + user.$trx = trx + user.username = 'virk' + await user.save() + await trx.rollback() + + const totalUsers = await db.from('users').count('*', 'total') + + assert.equal(totalUsers[0].total, 0) + assert.exists(user.id) + assert.isUndefined(user.$trx) + assert.deepEqual(user.$attributes, { username: 'virk', id: user.id }) + assert.isFalse(user.$isDirty) + assert.isTrue(user.$persisted) + }) + + test('cleanup old trx event listeners when transaction is updated', async (assert) => { + const BaseModel = getBaseModel(ormAdapter()) + + class User extends BaseModel { + public static $table = 'users' + + @column({ primary: true }) + public id: number + + @column() + public username: string + } + + User.$boot() + const db = getDb() + const trx = await db.transaction() + const trx1 = await trx.transaction() + + const user = new User() + user.$trx = trx1 + user.$trx = trx + user.username = 'virk' + + await trx1.rollback() + assert.deepEqual(user.$trx, trx) + await trx.rollback() + }) + + test('use transaction client set on the model for the update', async (assert) => { + const BaseModel = getBaseModel(ormAdapter()) + + class User extends BaseModel { + public static $table = 'users' + + @column({ primary: true }) + public id: number + + @column() + public username: string + } + User.$boot() + + const user = new User() + user.username = 'virk' + await user.save() + + assert.exists(user.id) + assert.deepEqual(user.$attributes, { username: 'virk', id: user.id }) + assert.isFalse(user.$isDirty) + assert.isTrue(user.$persisted) + + const db = getDb() + const trx = await db.transaction() + user.$trx = trx + user.username = 'nikk' + await user.save() + await trx.rollback() + + const users = await db.from('users') + assert.lengthOf(users, 1) + assert.equal(users[0].username, 'virk') + }) + + test('use transaction client set on the model for the delete', async (assert) => { + const db = getDb() + const BaseModel = getBaseModel(ormAdapter()) + + class User extends BaseModel { + public static $table = 'users' + + @column({ primary: true }) + public id: number + + @column() + public username: string + } + User.$boot() + + const user = new User() + user.username = 'virk' + await user.save() + + assert.exists(user.id) + assert.deepEqual(user.$attributes, { username: 'virk', id: user.id }) + assert.isFalse(user.$isDirty) + assert.isTrue(user.$persisted) + + const trx = await db.transaction() + user.$trx = trx + + await user.delete() + await trx.rollback() + + const users = await db.from('users').select('*') + assert.lengthOf(users, 1) + }) })