Skip to content

Commit

Permalink
feat(minato): support withTransaction (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hieuzest authored Nov 23, 2023
1 parent cb94b78 commit 709fb07
Show file tree
Hide file tree
Showing 10 changed files with 405 additions and 21 deletions.
35 changes: 35 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,41 @@ jobs:
files: ./coverage/coverage-final.json
name: codecov

mongo-replica:
name: mongo-replica:${{ matrix.mongo-version }} (${{ matrix.node-version }})
runs-on: ubuntu-latest

strategy:
fail-fast: false
matrix:
mongo-version:
- latest
node-version: [18, 20]

steps:
- name: Check out
uses: actions/checkout@v4
- name: Set up Node
uses: actions/setup-node@v4
with:
node-version: ${{ matrix.node-version }}
- name: Set up MongoDB
uses: supercharge/mongodb-github-action@1.10.0
with:
mongodb-version: ${{ matrix.mongo-version }}
mongodb-replica-set: test-rs
mongodb-port: 27017
- name: Install
run: yarn
- name: Unit Test
run: yarn test:json mongo --+transaction.abort
- name: Report Coverage
uses: codecov/codecov-action@v3
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage/coverage-final.json
name: codecov

test:
name: ${{ matrix.driver-name }} (${{ matrix.node-version }})
runs-on: ubuntu-latest
Expand Down
21 changes: 21 additions & 0 deletions packages/core/src/driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ type JoinCallback2<S, U extends Dict<TableLike<S>>> = (args: {
[K in keyof U]: Row<TableType<S, U[K]>>
}) => Eval.Expr<boolean>

const kTransaction = Symbol('transaction')

export class Database<S = any> {
public tables: { [K in Keys<S>]: Model<S[K]> } = Object.create(null)
public drivers: Record<keyof any, Driver> = Object.create(null)
Expand Down Expand Up @@ -184,6 +186,24 @@ export class Database<S = any> {
return await sel._action('upsert', upsert, keys).execute()
}

async withTransaction(callback: (database: Database<S>) => Promise<void>): Promise<void>
async withTransaction<T extends Keys<S>>(table: T, callback: (database: Database<S>) => Promise<void>): Promise<void>
async withTransaction(arg: any, ...args: any[]) {
if (this[kTransaction]) throw new Error('nested transactions are not supported')
const [table, callback] = typeof arg === 'string' ? [arg, ...args] : [null, arg, ...args]
const driver = this.getDriver(table)
return await driver.withTransaction(async (session) => {
const database = new Proxy(this, {
get(target, p, receiver) {
if (p === kTransaction) return true
else if (p === 'getDriver') return () => session
else return Reflect.get(target, p, receiver)
},
})
await callback(database)
})
}

async stopAll() {
const drivers = Object.values(this.drivers)
this.drivers = Object.create(null)
Expand Down Expand Up @@ -225,6 +245,7 @@ export abstract class Driver {
abstract remove(sel: Selection.Mutable): Promise<Driver.WriteResult>
abstract create(sel: Selection.Mutable, data: any): Promise<any>
abstract upsert(sel: Selection.Mutable, data: any[], keys: string[]): Promise<Driver.WriteResult>
abstract withTransaction(callback: (driver: Driver) => Promise<void>): Promise<void>

constructor(public database: Database) {}

Expand Down
8 changes: 8 additions & 0 deletions packages/memory/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,14 @@ export class MemoryDriver extends Driver {
this.$save(table)
return result
}

async withTransaction(callback: (session: Driver) => Promise<void>) {
const data = clone(this.#store)
await callback(this).then(undefined, (e) => {
this.#store = data
throw e
})
}
}

export default MemoryDriver
48 changes: 35 additions & 13 deletions packages/mongo/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { BSONType, Collection, Db, IndexDescription, MongoClient, MongoClientOptions, MongoError } from 'mongodb'
import { BSONType, ClientSession, Collection, Db, IndexDescription, MongoClient, MongoClientOptions, MongoError } from 'mongodb'
import { Dict, isNullable, makeArray, noop, omit, pick } from 'cosmokit'
import { Database, Driver, Eval, executeEval, executeUpdate, Query, RuntimeError, Selection } from '@minatojs/core'
import { URLSearchParams } from 'url'
Expand Down Expand Up @@ -46,6 +46,7 @@ export class MongoDriver extends Driver {
public db!: Db
public mongo = this

private session?: ClientSession
private _evalTasks: EvalTask[] = []
private _createTasks: Dict<Promise<void>> = {}

Expand Down Expand Up @@ -254,13 +255,13 @@ export class MongoDriver extends Driver {

async drop(table?: string) {
if (table) {
await this.db.dropCollection(table)
await this.db.dropCollection(table, { session: this.session })
return
}
await Promise.all([
'_fields',
...Object.keys(this.database.tables),
].map(name => this.db.dropCollection(name)))
].map(name => this.db.dropCollection(name, { session: this.session })))
}

private async _collStats() {
Expand All @@ -269,7 +270,7 @@ export class MongoDriver extends Driver {
const coll = this.db.collection(name)
const [{ storageStats: { count, size } }] = await coll.aggregate([{
$collStats: { storageStats: {} },
}]).toArray()
}], { session: this.session }).toArray()
return [coll.collectionName, { count, size }] as const
}))
return Object.fromEntries(entries)
Expand Down Expand Up @@ -383,7 +384,7 @@ export class MongoDriver extends Driver {
logger.debug('%s %s', result.table, JSON.stringify(result.pipeline))
return this.db
.collection(result.table)
.aggregate(result.pipeline, { allowDiskUse: true })
.aggregate(result.pipeline, { allowDiskUse: true, session: this.session })
.toArray()
}

Expand Down Expand Up @@ -420,7 +421,7 @@ export class MongoDriver extends Driver {
try {
const results = await this.db
.collection('_fields')
.aggregate(stages, { allowDiskUse: true })
.aggregate(stages, { allowDiskUse: true, session: this.session })
.toArray()
data = Object.assign({}, ...results)
} catch (error) {
Expand Down Expand Up @@ -455,15 +456,15 @@ export class MongoDriver extends Driver {
...$unset.length ? [{ $unset }] : [],
{ $set },
...transformer.walkedKeys.length ? [{ $unset: [tempKey] }] : [],
])
], { session: this.session })
return { matched: result.matchedCount, modified: result.modifiedCount }
}

async remove(sel: Selection.Mutable) {
const { query, table } = sel
const filter = this.transformQuery(query, table)
if (!filter) return {}
const result = await this.db.collection(table).deleteMany(filter)
const result = await this.db.collection(table).deleteMany(filter, { session: this.session })
return { removed: result.deletedCount }
}

Expand All @@ -488,7 +489,7 @@ export class MongoDriver extends Driver {
const doc = await this.db.collection('_fields').findOneAndUpdate(
{ table, field: primary },
{ $inc: { autoInc: missing.length } },
{ upsert: true },
{ session: this.session, upsert: true },
)
for (let i = 1; i <= missing.length; i++) {
missing[i - 1][primary] = (doc!.autoInc ?? 0) + i
Expand All @@ -507,7 +508,7 @@ export class MongoDriver extends Driver {
try {
data = model.create(data)
const copy = this.unpatchVirtual(table, { ...data })
const insertedId = (await coll.insertOne(copy)).insertedId
const insertedId = (await coll.insertOne(copy, { session: this.session })).insertedId
if (this.shouldFillPrimary(table)) {
return { ...data, [model.primary as string]: insertedId }
} else return data
Expand All @@ -531,7 +532,7 @@ export class MongoDriver extends Driver {
$or: data.map((item) => {
return this.transformQuery(pick(item, keys), table)!
}),
}).toArray()).map(row => this.patchVirtual(table, row))
}, { session: this.session }).toArray()).map(row => this.patchVirtual(table, row))

const bulk = coll.initializeUnorderedBulkOp()
const insertion: any[] = []
Expand All @@ -552,7 +553,7 @@ export class MongoDriver extends Driver {
const copy = executeUpdate(model.create(), update, ref)
bulk.insert(this.unpatchVirtual(table, copy))
}
const result = await bulk.execute()
const result = await bulk.execute({ session: this.session })
return { inserted: result.insertedCount + result.upsertedCount, matched: result.matchedCount, modified: result.modifiedCount }
} else {
const bulk = coll.initializeUnorderedBulkOp()
Expand All @@ -578,10 +579,31 @@ export class MongoDriver extends Driver {
...transformer.walkedKeys.length ? [{ $unset: [tempKey] }] : [],
])
}
const result = await bulk.execute()
const result = await bulk.execute({ session: this.session })
return { inserted: result.insertedCount + result.upsertedCount, matched: result.matchedCount, modified: result.modifiedCount }
}
}

async withTransaction(callback: (session: Driver) => Promise<void>) {
await this.client.withSession(async (session) => {
const driver = new Proxy(this, {
get(target, p, receiver) {
if (p === 'session') return session
else return Reflect.get(target, p, receiver)
},
})
await session.withTransaction(async () => callback(driver)).catch(async e => {
if (e instanceof MongoError && e.code === 20 && e.message.includes('Transaction numbers')) {
logger.warn(`MongoDB is currently running as standalone server, transaction is disabled.
Convert to replicaSet to enable the feature.
See https://www.mongodb.com/docs/manual/tutorial/convert-standalone-to-replica-set/`)
await callback(this)
return
}
throw e
})
})
}
}

export default MongoDriver
6 changes: 5 additions & 1 deletion packages/mongo/tests/index.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,9 @@ describe('@minatojs/driver-mongo', () => {
logger.level = 2
})

test(database)
test(database, process.argv.includes('--enable-transaction-abort') ? {} : {
transaction: {
abort: false
}
})
})
27 changes: 24 additions & 3 deletions packages/mysql/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { createPool, format } from '@vlasky/mysql'
import type { OkPacket, Pool, PoolConfig } from 'mysql'
import type { OkPacket, Pool, PoolConfig, PoolConnection } from 'mysql'
import { Dict, difference, makeArray, pick, Time } from 'cosmokit'
import { Database, Driver, Eval, executeUpdate, Field, isEvalExpr, Model, RuntimeError, Selection } from '@minatojs/core'
import { Builder, escapeId } from '@minatojs/sql-utils'
Expand Down Expand Up @@ -225,6 +225,7 @@ export class MySQLDriver extends Driver {
public config: MySQLDriver.Config
public sql: MySQLBuilder

private session?: PoolConnection
private _compat: Compat = {}
private _queryTasks: QueryTask[] = []

Expand Down Expand Up @@ -443,7 +444,7 @@ INSERT INTO mtt VALUES(json_extract(j, concat('$[', i, ']'))); SET i=i+1; END WH
const error = new Error()
return new Promise((resolve, reject) => {
if (debug) logger.debug('> %s', sql)
this.pool.query(sql, (err: Error, results) => {
;(this.session ?? this.pool).query(sql, (err: Error, results) => {
if (!err) return resolve(results)
logger.warn('> %s', sql)
if (err['code'] === 'ER_DUP_ENTRY') {
Expand All @@ -457,7 +458,7 @@ INSERT INTO mtt VALUES(json_extract(j, concat('$[', i, ']'))); SET i=i+1; END WH

queue<T = any>(sql: string, values?: any): Promise<T> {
sql = format(sql, values)
if (!this.config.multipleStatements) {
if (this.session || !this.config.multipleStatements) {
return this.query(sql)
}

Expand Down Expand Up @@ -623,6 +624,26 @@ INSERT INTO mtt VALUES(json_extract(j, concat('$[', i, ']'))); SET i=i+1; END WH
const records = +(/^&Records:\s*(\d+)/.exec(result.message)?.[1] ?? result.affectedRows)
return { inserted: records - result.changedRows, matched: result.changedRows, modified: result.affectedRows - records }
}

async withTransaction(callback: (session: Driver) => Promise<void>) {
return new Promise<void>((resolve, reject) => {
this.pool.getConnection((err, conn) => {
if (err) {
logger.warn('getConnection failed: ', err)
return
}
const driver = new Proxy(this, {
get(target, p, receiver) {
if (p === 'session') return conn
else return Reflect.get(target, p, receiver)
},
})
conn.beginTransaction(() => callback(driver)
.then(() => conn.commit(() => resolve()), (e) => conn.rollback(() => reject(e)))
.finally(() => conn.release()))
})
})
}
}

export default MySQLDriver
23 changes: 19 additions & 4 deletions packages/postgres/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,8 @@ export class PostgresDriver extends Driver {
public sql!: postgres.Sql
public config: PostgresDriver.Config

#counter = 0
private session?: postgres.TransactionSql
private _counter = 0

constructor(database: Database, config: PostgresDriver.Config) {
super(database)
Expand All @@ -391,7 +392,7 @@ export class PostgresDriver extends Driver {
}

async query(sql: string) {
return await this.sql.unsafe(sql).catch(e => {
return await (this.session ?? this.sql).unsafe(sql).catch(e => {
logger.warn('> %s', sql)
throw e
})
Expand Down Expand Up @@ -557,8 +558,8 @@ export class PostgresDriver extends Driver {
const builder = new PostgresBuilder(tables)
builder.upsert(table)

this.#counter = (this.#counter + 1) % 256
const mtime = Date.now() * 256 + this.#counter
this._counter = (this._counter + 1) % 256
const mtime = Date.now() * 256 + this._counter
const merged = {}
const insertion = data.map((item) => {
Object.assign(merged, item)
Expand Down Expand Up @@ -618,6 +619,20 @@ export class PostgresDriver extends Driver {
`)
return { inserted: result.filter(({ rtime }) => +rtime !== mtime).length, matched: result.filter(({ rtime }) => +rtime === mtime).length }
}

async withTransaction(callback: (session: Driver) => Promise<void>) {
return await this.sql.begin(async (conn) => {
const driver = new Proxy(this, {
get(target, p, receiver) {
if (p === 'session') return conn
else return Reflect.get(target, p, receiver)
},
})

await callback(driver)
await conn.unsafe(`COMMIT`)
})
}
}

export default PostgresDriver
10 changes: 10 additions & 0 deletions packages/sqlite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ export class SQLiteDriver extends Driver {
sql: Builder
beforeUnload?: () => void

private _transactionTask?

constructor(database: Database, public config: SQLiteDriver.Config) {
super(database)

Expand Down Expand Up @@ -438,6 +440,14 @@ export class SQLiteDriver extends Driver {
}
return result
}

async withTransaction(callback: (session: Driver) => Promise<void>) {
if (this._transactionTask) await this._transactionTask
return this._transactionTask = new Promise<void>((resolve, reject) => {
this.#run('BEGIN TRANSACTION')
callback(this).then(() => resolve(this.#run('COMMIT')), (e) => (this.#run('ROLLBACK'), reject(e)))
})
}
}

export default SQLiteDriver
Loading

0 comments on commit 709fb07

Please sign in to comment.