Skip to content

Commit

Permalink
feat(postgres): ensure consistency of prepare order, fix #69 (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hieuzest authored Dec 30, 2023
1 parent 5c4857b commit b15354a
Showing 1 changed file with 58 additions and 37 deletions.
95 changes: 58 additions & 37 deletions packages/postgres/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ interface TableInfo {
commit_action: null
}

interface QueryTask {
sql: string
resolve: (value: any) => void
reject: (reason: unknown) => void
}

function escapeId(value: string) {
return '"' + value.replace(/"/g, '""') + '"'
}
Expand Down Expand Up @@ -436,6 +442,7 @@ export class PostgresDriver extends Driver {

private session?: postgres.TransactionSql
private _counter = 0
private _queryTasks: QueryTask[] = []

constructor(database: Database, config: PostgresDriver.Config) {
super(database)
Expand Down Expand Up @@ -474,18 +481,39 @@ export class PostgresDriver extends Driver {
})
}

queue<T extends any[] = any[]>(sql: string, values?: any): Promise<T> {
if (this.session) {
return this.query(sql)
}

return new Promise<any>((resolve, reject) => {
this._queryTasks.push({ sql, resolve, reject })
process.nextTick(() => this._flushTasks())
})
}

private async _flushTasks() {
const tasks = this._queryTasks
if (!tasks.length) return
this._queryTasks = []

try {
let results = await this.query(tasks.map(task => task.sql).join(';\n')) as any
if (tasks.length === 1) results = [results]
tasks.forEach((task, index) => {
task.resolve(results[index])
})
} catch (error) {
tasks.forEach(task => task.reject(error))
}
}

async prepare(name: string) {
const [columns, constraints] = await Promise.all([
this.query<ColumnInfo[]>(`
SELECT *
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = ${this.sql.escape(name)}`),
this.query<ConstraintInfo[]>(`
SELECT *
FROM information_schema.table_constraints
WHERE table_schema = 'public'
AND table_name = ${this.sql.escape(name)}`),
this.queue<ColumnInfo[]>(`SELECT * FROM information_schema.columns WHERE table_schema = 'public' AND table_name = ${this.sql.escape(name)}`),
this.queue<ConstraintInfo[]>(
`SELECT * FROM information_schema.table_constraints WHERE table_schema = 'public' AND table_name = ${this.sql.escape(name)}`,
),
])

const table = this.model(name)
Expand Down Expand Up @@ -587,27 +615,19 @@ export class PostgresDriver extends Driver {
await this.query(`DROP TABLE IF EXISTS ${escapeId(table)} CASCADE`)
return
}
const tables: TableInfo[] = await this.query(`
SELECT *
FROM information_schema.tables
WHERE table_schema = 'public'`)
const tables: TableInfo[] = await this.queue(`SELECT * FROM information_schema.tables WHERE table_schema = 'public'`)
if (!tables.length) return
await this.query(`DROP TABLE IF EXISTS ${tables.map(t => escapeId(t.table_name)).join(',')} CASCADE`)
}

async stats(): Promise<Partial<Driver.Stats>> {
const names = Object.keys(this.database.tables)
const tables = (await this.query<TableInfo[]>(`
SELECT *
FROM information_schema.tables
WHERE table_schema = 'public'`))
const tables = (await this.queue<TableInfo[]>(`SELECT * FROM information_schema.tables WHERE table_schema = 'public'`))
.map(t => t.table_name).filter(name => names.includes(name))
const tableStats = await this.query(
tables.map(name => {
return `SELECT '${name}' AS name,
pg_total_relation_size('${escapeId(name)}') AS size,
COUNT(*) AS count FROM ${escapeId(name)}`
}).join(' UNION '),
const tableStats = await this.queue(
tables.map(
(name) => `SELECT '${name}' AS name, pg_total_relation_size('${escapeId(name)}') AS size, COUNT(*) AS count FROM ${escapeId(name)}`,
).join(' UNION '),
).then(s => s.map(t => [t.name, { size: +t.size, count: +t.count }]))

return {
Expand All @@ -620,7 +640,7 @@ export class PostgresDriver extends Driver {
const builder = new PostgresBuilder(sel.tables)
const query = builder.get(sel)
if (!query) return []
return this.query(query).then(data => {
return this.queue(query).then(data => {
return data.map(row => builder.load(sel.model, row))
})
}
Expand All @@ -630,7 +650,7 @@ export class PostgresDriver extends Driver {
const inner = builder.get(sel.table as Selection, true, true)
const output = builder.parseEval(expr, false)
const ref = isBracketed(inner) ? sel.ref : ''
const [data] = await this.query(`SELECT ${output} AS value FROM ${inner} ${ref}`)
const [data] = await this.queue(`SELECT ${output} AS value FROM ${inner} ${ref}`)
return builder.load(data?.value)
}

Expand Down Expand Up @@ -665,10 +685,11 @@ export class PostgresDriver extends Driver {
const builder = new PostgresBuilder(sel.tables)
const formatted = builder.dump(model, data)
const keys = Object.keys(formatted)
const [row] = await this.query(`
INSERT INTO ${builder.escapeId(table)} (${keys.map(builder.escapeId).join(', ')})
VALUES (${keys.map(key => builder.escape(formatted[key])).join(', ')})
RETURNING *`)
const [row] = await this.query([
`INSERT INTO ${builder.escapeId(table)} (${keys.map(builder.escapeId).join(', ')})`,
`VALUES (${keys.map(key => builder.escape(formatted[key])).join(', ')})`,
`RETURNING *`,
].join(' '))
return builder.load(model, row)
}

Expand Down Expand Up @@ -731,13 +752,13 @@ export class PostgresDriver extends Driver {
return `${escaped} = ${value}`
}).join(', ')

const result = await this.query(`
INSERT INTO ${builder.escapeId(table)} (${initFields.map(builder.escapeId).join(', ')})
VALUES (${insertion.map(item => formatValues(table, item, initFields)).join('), (')})
ON CONFLICT (${keys.map(builder.escapeId).join(', ')})
DO UPDATE SET ${update}, _pg_mtime = ${mtime}
RETURNING _pg_mtime as rtime
`)
const result = await this.query([
`INSERT INTO ${builder.escapeId(table)} (${initFields.map(builder.escapeId).join(', ')})`,
`VALUES (${insertion.map(item => formatValues(table, item, initFields)).join('), (')})`,
`ON CONFLICT (${keys.map(builder.escapeId).join(', ')})`,
`DO UPDATE SET ${update}, _pg_mtime = ${mtime}`,
`RETURNING _pg_mtime as rtime`,
].join(' '))
return { inserted: result.filter(({ rtime }) => +rtime !== mtime).length, matched: result.filter(({ rtime }) => +rtime === mtime).length }
}

Expand Down

0 comments on commit b15354a

Please sign in to comment.