From 2e852445a35ccd52f7299f3061f615de341cb6b4 Mon Sep 17 00:00:00 2001 From: Hieuzest Date: Wed, 3 Apr 2024 23:37:10 +0800 Subject: [PATCH] feat(minato): impl typed `Expr` and `Field` (#74) --- packages/core/src/database.ts | 100 +++- packages/core/src/driver.ts | 22 +- packages/core/src/eval.ts | 114 +++-- packages/core/src/index.ts | 1 + packages/core/src/model.ts | 143 ++++-- packages/core/src/selection.ts | 23 +- packages/core/src/type.ts | 83 +++ packages/core/src/utils.ts | 72 ++- packages/memory/src/index.ts | 4 +- packages/memory/tests/index.spec.ts | 14 + packages/mongo/src/{utils.ts => builder.ts} | 99 +++- packages/mongo/src/index.ts | 55 +- packages/mongo/tests/index.spec.ts | 5 + packages/mongo/tests/migration.spec.ts | 24 +- packages/mysql/src/builder.ts | 126 +++-- packages/mysql/src/index.ts | 87 ++-- packages/postgres/src/builder.ts | 191 +++---- packages/postgres/src/index.ts | 151 +++--- packages/sql-utils/src/index.ts | 299 ++++++----- packages/sqlite/src/builder.ts | 78 +-- packages/sqlite/src/index.ts | 54 +- packages/tests/package.json | 4 +- packages/tests/src/index.ts | 2 + packages/tests/src/model.ts | 529 ++++++++++++++++++++ packages/tests/src/object.ts | 5 + packages/tests/src/query.ts | 12 +- packages/tests/src/selection.ts | 8 + packages/tests/src/setup.ts | 60 ++- packages/tests/src/shape.ts | 3 + packages/tests/src/shims.d.ts | 46 ++ packages/tests/src/update.ts | 81 ++- 31 files changed, 1897 insertions(+), 598 deletions(-) create mode 100644 packages/core/src/type.ts rename packages/mongo/src/{utils.ts => builder.ts} (80%) create mode 100644 packages/tests/src/model.ts create mode 100644 packages/tests/src/shims.d.ts diff --git a/packages/core/src/database.ts b/packages/core/src/database.ts index 963882bc..0a12eefe 100644 --- a/packages/core/src/database.ts +++ b/packages/core/src/database.ts @@ -1,11 +1,12 @@ -import { Dict, Intersect, makeArray, MaybeArray, valueMap } from 'cosmokit' +import { Dict, Intersect, makeArray, mapValues, MaybeArray, valueMap } from 'cosmokit' import { Context, Service, Spread } from 'cordis' -import { Flatten, Indexable, Keys, Row } from './utils.ts' +import { Flatten, Indexable, Keys, randomId, Row, unravel } from './utils.ts' import { Selection } from './selection.ts' import { Field, Model } from './model.ts' import { Driver } from './driver.ts' import { Eval, Update } from './eval.ts' import { Query } from './query.ts' +import { Type } from './type.ts' type TableLike = Keys | Selection @@ -47,12 +48,13 @@ export namespace Join2 { const kTransaction = Symbol('transaction') -export class Database extends Service { +export class Database extends Service { static [Service.provide] = 'model' static [Service.immediate] = true public tables: { [K in Keys]: Model } = Object.create(null) public drivers: Record = Object.create(null) + public types: Dict = Object.create(null) public migrating = false private prepareTasks: Dict> = Object.create(null) private migrateTasks: Dict> = Object.create(null) @@ -90,21 +92,103 @@ export class Database extends Service field?.transformers?.forEach(x => driver.define(x))) + + await driver.prepare(name) } - extend>(name: K, fields: Field.Extension, config: Partial> = {}) { + extend>(name: K, fields: Field.Extension, config: Partial> = {}) { let model = this.tables[name] if (!model) { model = this.tables[name] = new Model(name) // model.driver = config.driver } + Object.entries(fields).forEach(([key, field]: [string, any]) => { + const transformer = [] + this.parseField(field, transformer, undefined, value => field = fields[key] = value) + if (typeof field === 'object') field.transformers = transformer + }) model.extend(fields, config) this.prepareTasks[name] = this.prepare(name) ;(this.ctx as Context).emit('model', name) } - migrate>(name: K, fields: Field.Extension, callback: Model.Migration) { + private parseField(field: any, transformers: Driver.Transformer[] = [], setInitial?: (value) => void, setField?: (value) => void): Type { + if (field === 'array') { + setInitial?.([]) + setField?.({ type: 'json', initial: [] }) + return Type.Array() + } else if (field === 'object') { + setInitial?.({}) + setField?.({ type: 'json', initial: {} }) + return Type.Object() + } else if (typeof field === 'string' && this.types[field]) { + transformers.push({ + types: [this.types[field].type], + load: this.types[field].load, + dump: this.types[field].dump, + }) + setInitial?.(this.types[field].initial) + setField?.(this.types[field]) + return Type.fromField(field as any) + } else if (typeof field === 'object' && field.load && field.dump) { + const name = this.define(field) + transformers.push({ + types: [name as any], + load: field.load, + dump: field.dump, + }) + // for transform type, intentionally assign a null initial on default + // setInitial?.(Field.getInitial(field.type, field.initial)) + setInitial?.(field.initial) + setField?.({ ...field, deftype: field.type, type: name }) + return Type.fromField(name as any) + } else if (typeof field === 'object' && field.type === 'object') { + const inner = unravel(field.inner, value => (value.type = 'object', value.inner ??= {})) + const initial = Object.create(null) + const res = Type.Object(mapValues(inner, (x, k) => this.parseField(x, transformers, value => initial[k] = value))) + setInitial?.(Field.getInitial('json', initial)) + setField?.({ initial: Field.getInitial('json', initial), ...field, deftype: 'json', type: res }) + return res + } else if (typeof field === 'object' && field.type === 'array') { + const res = field.inner ? Type.Array(this.parseField(field.inner, transformers)) : Type.Array() + setInitial?.([]) + setField?.({ initial: [], ...field, deftype: 'json', type: res }) + return res + } else if (typeof field === 'object') { + setInitial?.(Field.getInitial(field.type.split('(')[0], field.initial)) + setField?.(field) + return Type.fromField(field.type.split('(')[0]) + } else { + setInitial?.(Field.getInitial(field.split('(')[0])) + setField?.(field) + return Type.fromField(field.split('(')[0]) + } + } + + define, Field.Type>>(name: K, field: Field.Transform): K + define(field: Field.Transform): Field.NewType + define(name: any, field?: any) { + if (typeof name === 'object') { + field = name + name = undefined + } + + if (name && this.types[name]) throw new Error(`type "${name}" already defined`) + if (!name) while (this.types[name = '_define_' + randomId()]); + this[Context.current].effect(() => { + this.types[name] = { deftype: field.type, ...field, type: name } + return () => delete this.types[name] + }) + return name as any + } + + migrate>(name: K, fields: Field.Extension, callback: Model.Migration) { this.extend(name, fields, { callback }) } @@ -183,8 +267,8 @@ export class Database extends Service) => Promise): Promise - async withTransaction>(table: T, callback: (database: Database) => Promise): Promise + async withTransaction(callback: (database: this) => Promise): Promise + async withTransaction>(table: T, callback: (database: this) => Promise): Promise async withTransaction(arg: any, ...args: any[]) { if (this[kTransaction]) throw new Error('nested transactions are not supported') const [table, callback] = typeof arg === 'string' ? [arg, ...args] : [null, arg, ...args] diff --git a/packages/core/src/driver.ts b/packages/core/src/driver.ts index cf8ab1a1..b089cb1f 100644 --- a/packages/core/src/driver.ts +++ b/packages/core/src/driver.ts @@ -2,8 +2,9 @@ import { Awaitable, Dict, valueMap } from 'cosmokit' import { Context, Logger } from 'cordis' import { Eval, Update } from './eval.ts' import { Direction, Modifier, Selection } from './selection.ts' -import { Model } from './model.ts' +import { Field, Model } from './model.ts' import { Database } from './database.ts' +import { Type } from './type.ts' export namespace Driver { export interface Stats { @@ -31,6 +32,12 @@ export namespace Driver { modified?: number removed?: number } + + export interface Transformer { + types: Field.Type[] + dump: (value: S) => T | null + load: (value: T) => S | null + } } export namespace Driver { @@ -56,6 +63,7 @@ export abstract class Driver { public database: Database public logger: Logger + public types: Dict = Object.create(null) constructor(public ctx: Context, public config: T) { this.database = ctx.model @@ -87,8 +95,8 @@ export abstract class Driver { if (table instanceof Selection) { if (!table.args[0].fields) return table.model const model = new Model('temp') - model.fields = valueMap(table.args[0].fields, (_, key) => ({ - type: 'expr', + model.fields = valueMap(table.args[0].fields, (expr, key) => ({ + type: Type.fromTerm(expr), })) return model } @@ -99,8 +107,8 @@ export abstract class Driver { for (const field in submodel.fields) { if (submodel.fields[field]!.deprecated) continue model.fields[`${key}.${field}`] = { - type: 'expr', - expr: { $: [key, field] } as any, + expr: Eval('', [table[key].ref, field], Type.fromField(submodel.fields[field]!)), + type: Type.fromField(submodel.fields[field]!), } } } @@ -124,6 +132,10 @@ export abstract class Driver { })) }).then(hooks.finalize).catch(hooks.error) } + + define(converter: Driver.Transformer) { + converter.types.forEach(type => this.types[type] = converter) + } } export interface MigrationHooks { diff --git a/packages/core/src/eval.ts b/packages/core/src/eval.ts index 73fa7bca..037e1e6f 100644 --- a/packages/core/src/eval.ts +++ b/packages/core/src/eval.ts @@ -1,5 +1,7 @@ -import { defineProperty, Dict, isNullable, valueMap } from 'cosmokit' +import { defineProperty, isNullable, valueMap } from 'cosmokit' import { Comparable, Flatten, isComparable, makeRegExp, Row } from './utils.ts' +import { Type } from './type.ts' +import { Field } from './model.ts' export function isEvalExpr(value: any): value is Eval.Expr { return value && Object.keys(value).some(key => key.startsWith('$')) @@ -27,6 +29,7 @@ export namespace Eval { [kExpr]: true [kType]?: T [kAggr]?: A + [Type.kType]?: Type } export type Any = Comparable | Expr @@ -49,7 +52,7 @@ export namespace Eval { } export interface Static { - (key: string, value: any): Eval.Expr + (key: string, value: any, type: Type): Eval.Expr // univeral if(cond: Any, vThen: Term, vElse: Term): Expr @@ -102,6 +105,7 @@ export namespace Eval { not: Unary // typecast + literal(value: T, type?: Field.Type | Field.NewType | string): Expr number: Unary // aggregation / json @@ -114,28 +118,31 @@ export namespace Eval { size(value: (Any | Expr)[] | Expr): Expr length(value: any[] | Expr): Expr - object>(fields: T): Expr object(row: Row.Cell): Expr + object(row: Row): Expr array(value: Expr): Expr } } -export const Eval = ((key, value) => defineProperty({ ['$' + key]: value }, kExpr, true)) as Eval.Static +export const Eval = ((key, value, type) => defineProperty(defineProperty({ ['$' + key]: value }, kExpr, true), Type.kType, type)) as Eval.Static -const operators = {} as Record<`$${keyof Eval.Static}`, (args: any, data: any) => any> +const operators = Object.create(null) as Record<`$${keyof Eval.Static}`, (args: any, data: any) => any> operators['$'] = getRecursive type UnaryCallback = T extends (value: infer R) => Eval.Expr ? (value: R, data: any[]) => S : never -function unary(key: K, callback: UnaryCallback): Eval.Static[K] { +function unary(key: K, callback: UnaryCallback, type: Type | ((...args: any[]) => Type)): Eval.Static[K] { operators[`$${key}`] = callback - return ((value: any) => Eval(key, value)) as any + return ((value: any) => Eval(key, value, typeof type === 'function' ? type(value) : type)) as any } type MultivariateCallback = T extends (...args: infer R) => Eval.Expr ? (args: R, data: any) => S : never -function multary(key: K, callback: MultivariateCallback): Eval.Static[K] { +function multary( + key: K, callback: MultivariateCallback, + type: Type | ((...args: any[]) => Type), +): Eval.Static[K] { operators[`$${key}`] = callback - return (...args: any) => Eval(key, args) as any + return (...args: any) => Eval(key, args, typeof type === 'function' ? type(...args) : type) as any } type BinaryCallback = T extends (...args: any[]) => Eval.Expr ? (...args: any[]) => S : never @@ -146,10 +153,10 @@ function comparator(key: K, callback: BinaryCallbac if (isNullable(left) || isNullable(right)) return true return callback(left.valueOf(), right.valueOf()) } - return (...args: any) => Eval(key, args) as any + return (...args: any) => Eval(key, args, Type.Boolean) as any } -Eval.switch = (branches, vDefault) => Eval('switch', { branches, default: vDefault }) +Eval.switch = (branches, vDefault) => Eval('switch', { branches, default: vDefault }, Type.fromTerm(branches[0])) operators.$switch = (args, data) => { for (const branch of args.branches) { if (executeEval(data, branch.case)) return executeEval(data, branch.then) @@ -158,25 +165,26 @@ operators.$switch = (args, data) => { } // univeral -Eval.if = multary('if', ([cond, vThen, vElse], data) => executeEval(data, cond) ? executeEval(data, vThen) : executeEval(data, vElse)) -Eval.ifNull = multary('ifNull', ([value, fallback], data) => executeEval(data, value) ?? executeEval(data, fallback)) +Eval.if = multary('if', ([cond, vThen, vElse], data) => executeEval(data, cond) ? executeEval(data, vThen) + : executeEval(data, vElse), (cond, vThen, vElse) => Type.fromTerm(vThen)) +Eval.ifNull = multary('ifNull', ([value, fallback], data) => executeEval(data, value) ?? executeEval(data, fallback), (value) => Type.fromTerm(value)) // arithmetic -Eval.add = multary('add', (args, data) => args.reduce((prev, curr) => prev + executeEval(data, curr), 0)) -Eval.mul = Eval.multiply = multary('multiply', (args, data) => args.reduce((prev, curr) => prev * executeEval(data, curr), 1)) -Eval.sub = Eval.subtract = multary('subtract', ([left, right], data) => executeEval(data, left) - executeEval(data, right)) -Eval.div = Eval.divide = multary('divide', ([left, right], data) => executeEval(data, left) / executeEval(data, right)) -Eval.mod = Eval.modulo = multary('modulo', ([left, right], data) => executeEval(data, left) % executeEval(data, right)) +Eval.add = multary('add', (args, data) => args.reduce((prev, curr) => prev + executeEval(data, curr), 0), Type.Number) +Eval.mul = Eval.multiply = multary('multiply', (args, data) => args.reduce((prev, curr) => prev * executeEval(data, curr), 1), Type.Number) +Eval.sub = Eval.subtract = multary('subtract', ([left, right], data) => executeEval(data, left) - executeEval(data, right), Type.Number) +Eval.div = Eval.divide = multary('divide', ([left, right], data) => executeEval(data, left) / executeEval(data, right), Type.Number) +Eval.mod = Eval.modulo = multary('modulo', ([left, right], data) => executeEval(data, left) % executeEval(data, right), Type.Number) // mathematic -Eval.abs = unary('abs', (arg, data) => Math.abs(executeEval(data, arg))) -Eval.floor = unary('floor', (arg, data) => Math.floor(executeEval(data, arg))) -Eval.ceil = unary('ceil', (arg, data) => Math.ceil(executeEval(data, arg))) -Eval.round = unary('round', (arg, data) => Math.round(executeEval(data, arg))) -Eval.exp = unary('exp', (arg, data) => Math.exp(executeEval(data, arg))) -Eval.log = multary('log', ([left, right], data) => Math.log(executeEval(data, left)) / Math.log(executeEval(data, right ?? Math.E))) -Eval.pow = Eval.power = multary('power', ([left, right], data) => Math.pow(executeEval(data, left), executeEval(data, right))) -Eval.random = () => Eval('random', {}) +Eval.abs = unary('abs', (arg, data) => Math.abs(executeEval(data, arg)), Type.Number) +Eval.floor = unary('floor', (arg, data) => Math.floor(executeEval(data, arg)), Type.Number) +Eval.ceil = unary('ceil', (arg, data) => Math.ceil(executeEval(data, arg)), Type.Number) +Eval.round = unary('round', (arg, data) => Math.round(executeEval(data, arg)), Type.Number) +Eval.exp = unary('exp', (arg, data) => Math.exp(executeEval(data, arg)), Type.Number) +Eval.log = multary('log', ([left, right], data) => Math.log(executeEval(data, left)) / Math.log(executeEval(data, right ?? Math.E)), Type.Number) +Eval.pow = Eval.power = multary('power', ([left, right], data) => Math.pow(executeEval(data, left), executeEval(data, right)), Type.Number) +Eval.random = () => Eval('random', {}, Type.Number) operators.$random = () => Math.random() // comparison @@ -188,63 +196,73 @@ Eval.lt = comparator('lt', (left, right) => left < right) Eval.le = Eval.lte = comparator('lte', (left, right) => left <= right) // element -Eval.in = multary('in', ([value, array], data) => executeEval(data, array).includes(executeEval(data, value))) -Eval.nin = multary('nin', ([value, array], data) => !executeEval(data, array).includes(executeEval(data, value))) +Eval.in = multary('in', ([value, array], data) => executeEval(data, array).includes(executeEval(data, value)), Type.Boolean) +Eval.nin = multary('nin', ([value, array], data) => !executeEval(data, array).includes(executeEval(data, value)), Type.Boolean) // string -Eval.concat = multary('concat', (args, data) => args.map(arg => executeEval(data, arg)).join('')) -Eval.regex = multary('regex', ([value, regex], data) => makeRegExp(executeEval(data, regex)).test(executeEval(data, value))) +Eval.concat = multary('concat', (args, data) => args.map(arg => executeEval(data, arg)).join(''), Type.String) +Eval.regex = multary('regex', ([value, regex], data) => makeRegExp(executeEval(data, regex)).test(executeEval(data, value)), Type.Boolean) // logical -Eval.and = multary('and', (args, data) => args.every(arg => executeEval(data, arg))) -Eval.or = multary('or', (args, data) => args.some(arg => executeEval(data, arg))) -Eval.not = unary('not', (value, data) => !executeEval(data, value)) +Eval.and = multary('and', (args, data) => args.every(arg => executeEval(data, arg)), Type.Boolean) +Eval.or = multary('or', (args, data) => args.some(arg => executeEval(data, arg)), Type.Boolean) +Eval.not = unary('not', (value, data) => !executeEval(data, value), Type.Boolean) // typecast +Eval.literal = multary('literal', ([value, type]) => { + if (type) throw new TypeError('literal cast is not supported') + else return value +}, (value, type) => type ? Type.fromField(type) : Type.fromTerm(value)) Eval.number = unary('number', (arg, data) => { const value = executeEval(data, arg) return value instanceof Date ? Math.floor(value.valueOf() / 1000) : Number(value) -}) +}, Type.Number) + +const unwrapAggr = (expr: any) => { + const type = Type.fromTerm(expr) + return Type.getInner(type) ?? type +} // aggregation Eval.sum = unary('sum', (expr, table) => Array.isArray(table) ? table.reduce((prev, curr) => prev + executeAggr(expr, curr), 0) - : Array.from(executeEval(table, expr)).reduce((prev, curr) => prev + curr, 0)) + : Array.from(executeEval(table, expr)).reduce((prev, curr) => prev + curr, 0), Type.Number) Eval.avg = unary('avg', (expr, table) => { if (Array.isArray(table)) return table.reduce((prev, curr) => prev + executeAggr(expr, curr), 0) / table.length else { const array = Array.from(executeEval(table, expr)) return array.reduce((prev, curr) => prev + curr, 0) / array.length } -}) +}, Type.Number) Eval.max = unary('max', (expr, table) => Array.isArray(table) ? table.map(data => executeAggr(expr, data)).reduce((x, y) => x > y ? x : y, -Infinity) - : Array.from(executeEval(table, expr)).reduce((x, y) => x > y ? x : y, -Infinity)) + : Array.from(executeEval(table, expr)).reduce((x, y) => x > y ? x : y, -Infinity), (expr) => unwrapAggr(expr)) Eval.min = unary('min', (expr, table) => Array.isArray(table) ? table.map(data => executeAggr(expr, data)).reduce((x, y) => x < y ? x : y, Infinity) - : Array.from(executeEval(table, expr)).reduce((x, y) => x < y ? x : y, Infinity)) -Eval.count = unary('count', (expr, table) => new Set(table.map(data => executeAggr(expr, data))).size) + : Array.from(executeEval(table, expr)).reduce((x, y) => x < y ? x : y, Infinity), (expr) => unwrapAggr(expr)) +Eval.count = unary('count', (expr, table) => new Set(table.map(data => executeAggr(expr, data))).size, Type.Number) defineProperty(Eval, 'length', unary('length', (expr, table) => Array.isArray(table) ? table.map(data => executeAggr(expr, data)).length - : Array.from(executeEval(table, expr)).length)) + : Array.from(executeEval(table, expr)).length, Type.Number)) operators.$object = (field, table) => valueMap(field, value => executeAggr(value, table)) Eval.object = (fields) => { if (fields.$model) { - const modelFields = Object.keys(fields.$model.fields) + const modelFields: [string, Field][] = Object.entries(fields.$model.fields) const prefix: string = fields.$prefix - return Eval('object', Object.fromEntries(modelFields - .filter(path => path.startsWith(prefix)) - .map(k => [k.slice(prefix.length), fields[k.slice(prefix.length)]]), - )) + fields = Object.fromEntries(modelFields + .filter(([, field]) => !field.deprecated) + .filter(([path]) => path.startsWith(prefix)) + .map(([k]) => [k.slice(prefix.length), fields[k.slice(prefix.length)]])) + return Eval('object', fields, Type.Object(valueMap(fields, (value) => Type.fromTerm(value)))) } - return Eval('object', fields) as any + return Eval('object', fields, Type.Object(valueMap(fields, (value) => Type.fromTerm(value)))) as any } Eval.array = unary('array', (expr, table) => Array.isArray(table) ? table.map(data => executeAggr(expr, data)) - : Array.from(executeEval(table, expr))) + : Array.from(executeEval(table, expr)), (expr) => Type.Array(Type.fromTerm(expr))) -Eval.exec = unary('exec', (expr, data) => (expr.driver as any).executeSelection(expr, data)) +Eval.exec = unary('exec', (expr, data) => (expr.driver as any).executeSelection(expr, data), (expr) => Type.fromTerm(expr.args[0])) export { Eval as $ } diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 323482df..35d87689 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -7,6 +7,7 @@ export * from './eval.ts' export * from './model.ts' export * from './query.ts' export * from './selection.ts' +export * from './type.ts' export * from './utils.ts' declare module 'cordis' { diff --git a/packages/core/src/model.ts b/packages/core/src/model.ts index ace8612d..03bd6f2f 100644 --- a/packages/core/src/model.ts +++ b/packages/core/src/model.ts @@ -1,14 +1,16 @@ -import { clone, isNullable, makeArray, MaybeArray } from 'cosmokit' +import { isNullable, makeArray, MaybeArray, valueMap } from 'cosmokit' import { Database } from './database.ts' import { Eval, isEvalExpr } from './eval.ts' -import { Selection } from './selection.ts' -import { Flatten, Keys } from './utils.ts' +import { clone, Flatten, isUint8Array, Keys } from './utils.ts' +import { Type } from './type.ts' +import { Driver } from './driver.ts' export const Primary = Symbol('Primary') export type Primary = (string | number) & { [Primary]: true } export interface Field { - type: Field.Type + type: Type + deftype?: Field.Type length?: number nullable?: boolean initial?: T @@ -17,6 +19,7 @@ export interface Field { expr?: Eval.Expr legacy?: string[] deprecated?: boolean + transformers?: Driver.Transformer[] } export namespace Field { @@ -32,17 +35,51 @@ export namespace Field { : T extends string ? 'char' | 'string' | 'text' : T extends boolean ? 'boolean' : T extends Date ? 'timestamp' | 'date' | 'time' - : T extends unknown[] ? 'list' | 'json' - : T extends object ? 'json' + : T extends Uint8Array ? 'binary' + : T extends unknown[] ? 'list' | 'json' | 'array' + : T extends object ? 'json' | 'object' : 'expr' type Shorthand = S | `${S}(${any})` - type MapField = { - [K in keyof O]?: Field | Shorthand> | Selection.Callback + export type Object = { + type: 'object' + inner: Extension + } & Omit, 'type'> + + export type Array = { + type: 'array' + inner?: Definition + } & Omit, 'type'> + + export type Transform = { + type: Type + dump: (value: S) => T | null + load: (value: T) => S | null + initial?: S + } & Omit, 'type' | 'initial'> + + type Parsed = { + type: Type | Field['type'] + } & Omit, 'type'> + + export type Definition = + | (Omit, 'type'> & { type: Type }) + | Object + | (T extends (infer I)[] ? Array : never) + | Shorthand> + | Transform + | Keys + | NewType + + type MapField = { + [K in keyof O]?: Definition } - export type Extension = MapField> + export type Extension = MapField, N> + + const NewType = Symbol('newtype') + export type NewType = { [NewType]?: T } export type Config = { [K in keyof O]?: Field @@ -50,24 +87,26 @@ export namespace Field { const regexp = /^(\w+)(?:\((.+)\))?$/ - export function parse(source: string | Field): Field { - if (typeof source === 'function') return { type: 'expr', expr: source } - if (typeof source !== 'string') return { initial: null, ...source } + export function parse(source: string | Parsed): Field { + if (typeof source === 'function') throw new TypeError('view field is not supported') + if (typeof source !== 'string') { + return { + initial: null, + deftype: source.type as any, + ...source, + type: Type.isType(source.type) ? source.type : Type.fromField(source.type), + } + } // parse string definition const capture = regexp.exec(source) if (!capture) throw new TypeError('invalid field definition') const type = capture[1] as Type const args = (capture[2] || '').split(',') - const field: Field = { type } + const field: Field = { deftype: type, type: Type.fromField(type) } // set default initial value - if (field.initial === undefined) { - if (number.includes(field.type)) field.initial = 0 - if (string.includes(field.type)) field.initial = '' - if (field.type === 'list') field.initial = [] - if (field.type === 'json') field.initial = {} - } + if (field.initial === undefined) field.initial = getInitial(type) // set length information if (type === 'decimal') { @@ -79,6 +118,16 @@ export namespace Field { return field } + + export function getInitial(type: Field.Type, initial?: any) { + if (initial === undefined) { + if (Field.number.includes(type)) return 0 + if (Field.string.includes(type)) return '' + if (type === 'list') return [] + if (type === 'json') return {} + } + return initial + } } export namespace Model { @@ -102,6 +151,8 @@ export class Model { fields: Field.Config = {} migrations = new Map() + private type: Type | undefined + constructor(public name: string) { this.autoInc = false this.primary = 'id' as never @@ -125,7 +176,7 @@ export class Model { this.fields[key].deprecated = !!callback } - if (typeof this.primary === 'string' && this.fields[this.primary]?.type === 'primary') { + if (typeof this.primary === 'string' && this.fields[this.primary]?.deftype === 'primary') { this.autoInc = true } @@ -142,13 +193,15 @@ export class Model { } } - resolveValue(key: string, value: any) { + resolveValue(field: string | Field | Type, value: any) { if (isNullable(value)) return value - if (this.fields[key]?.type === 'time') { + if (typeof field === 'string') field = this.fields[field] as Field + if (field && !Type.isType(field)) field = Type.fromField(field) + if (field?.type === 'time') { const date = new Date(0) date.setHours(value.getHours(), value.getMinutes(), value.getSeconds(), value.getMilliseconds()) return date - } else if (this.fields[key]?.type === 'date') { + } else if (field?.type === 'date') { const date = new Date(value) date.setHours(0, 0, 0, 0) return date @@ -156,13 +209,38 @@ export class Model { return value } + resolveModel(obj: any, model?: Type) { + if (!model) model = this.getType() + if (isNullable(obj) || !model.inner) return obj + if (Type.isArray(model) && Array.isArray(obj)) { + return obj.map(x => this.resolveModel(x, Type.getInner(model)!)) + } + + const result = {} + for (const key in obj) { + const type = Type.getInner(model, key) + if (!type || isNullable(obj[key])) { + result[key] = obj[key] + } else if (type.type !== 'json') { + result[key] = this.resolveValue(type, obj[key]) + } else if (type.inner && Type.isArray(type) && Array.isArray(obj[key])) { + result[key] = obj[key].map(x => this.resolveModel(x, Type.getInner(type))) + } else if (type.inner) { + result[key] = this.resolveModel(obj[key], type) + } else { + result[key] = obj[key] + } + } + return result + } + format(source: object, strict = true, prefix = '', result = {} as S) { const fields = Object.keys(this.fields) Object.entries(source).map(([key, value]) => { key = prefix + key if (value === undefined) return if (fields.includes(key)) { - result[key] = this.resolveValue(key, value) + result[key] = value return } const field = fields.find(field => key.startsWith(field + '.')) @@ -176,7 +254,7 @@ export class Model { this.format(value, strict, key + '.', result) } }) - return result + return prefix === '' ? this.resolveModel(result) : result } parse(source: object, strict = true, prefix = '', result = {} as S) { @@ -192,19 +270,19 @@ export class Model { const fullKey = prefix + key, value = source[key] const field = fields.find(field => fullKey === field || fullKey.startsWith(field + '.')) if (field) { - node[segments[0]] = this.resolveValue(key, value) - } else if (!value || typeof value !== 'object' || isEvalExpr(value) || Array.isArray(value) || Object.keys(value).length === 0) { + node[segments[0]] = value + } else if (!value || typeof value !== 'object' || isEvalExpr(value) || Array.isArray(value) || isUint8Array(value) || Object.keys(value).length === 0) { if (strict) { throw new TypeError(`unknown field "${fullKey}" in model ${this.name}`) } else { - node[segments[0]] = this.resolveValue(key, value) + node[segments[0]] = value } } else { this.parse(value, strict, fullKey + '.', node[segments[0]] ??= {}) } } } - return result + return prefix === '' ? this.resolveModel(result) : result } create(data?: {}) { @@ -219,4 +297,11 @@ export class Model { } return this.parse({ ...result, ...data }) } + + getType(): Type + getType(key: string): Type | undefined + getType(key?: string): Type | undefined { + this.type ??= Type.Object(valueMap(this.fields!, field => Type.fromField(field!))) as any + return key ? Type.getInner(this.type, key) : this.type + } } diff --git a/packages/core/src/selection.ts b/packages/core/src/selection.ts index 3adfd5fa..382b3941 100644 --- a/packages/core/src/selection.ts +++ b/packages/core/src/selection.ts @@ -4,6 +4,7 @@ import { Eval, executeEval } from './eval.ts' import { Model } from './model.ts' import { Query } from './query.ts' import { Keys, randomId, Row } from './utils.ts' +import { Type } from './type.ts' declare module './eval.ts' { export namespace Eval { @@ -42,7 +43,23 @@ const createRow = (ref: string, expr = {}, prefix = '', model?: Model) => new Pr if (key === '$prefix') return prefix if (key === '$model') return model if (typeof key === 'symbol' || key in target || key.startsWith('$')) return Reflect.get(target, key) - return createRow(ref, Eval('', [ref, `${prefix}${key}`]), `${prefix}${key}.`, model) + + let type: Type + const field = model?.fields[prefix + key as string] + if (Type.getInner(expr?.[Type.kType], key)) { + // type may conatins object layout + type = Type.getInner(expr?.[Type.kType], key)! + } else if (field) { + type = Type.fromField(field) + } else if (Object.keys(model?.fields!).some(k => k.startsWith(`${prefix}${key}.`))) { + type = Type.Object(Object.fromEntries(Object.entries(model?.fields!) + .filter(([k]) => k.startsWith(`${prefix}${key}`)) + .map(([k, field]) => [k.slice(prefix.length + key.length + 1), Type.fromField(field!)]))) + } else { + // unknown field inside json + type = Type.fromField('expr') + } + return createRow(ref, Eval('', [ref, `${prefix}${key}`], type), `${prefix}${key}.`, model) }, }) @@ -218,7 +235,9 @@ export class Selection extends Executable { evaluate(callback?: any): any { const selection = new Selection(this.driver, this) if (!callback) callback = (row) => Eval.array(Eval.object(row)) - return Eval('exec', selection._action('eval', this.resolveField(callback))) + const expr = this.resolveField(callback) + if (expr['$']) defineProperty(expr, Type.kType, Type.Array(Type.fromTerm(expr))) + return Eval.exec(selection._action('eval', expr)) } execute = Keys>(cursor?: Driver.Cursor): Promise[]> diff --git a/packages/core/src/type.ts b/packages/core/src/type.ts new file mode 100644 index 00000000..fabd9f1d --- /dev/null +++ b/packages/core/src/type.ts @@ -0,0 +1,83 @@ +import { defineProperty, isNullable, mapValues } from 'cosmokit' +import { Field } from './model.ts' +import { Eval, isEvalExpr } from './eval.ts' + +export interface Type { + [Type.kType]?: true + type: Field.Type + inner?: T extends (infer I)[] ? Type : Field.Type extends 'json' ? { [key in keyof T]: Type } : never + array?: boolean +} + +export namespace Type { + export const kType = Symbol.for('minato.type') + + export const Boolean: Type = defineProperty({ type: 'boolean' }, kType, true) as any + export const Number: Type = defineProperty({ type: 'double' }, kType, true) + export const String: Type = defineProperty({ type: 'string' }, kType, true) + + type Extract = + | T extends Type ? I + : T extends Field ? I + : T extends Field.Type ? I + : T extends Eval.Term ? I + : never + + export type Object = Type + export const Object = (obj?: T): Object<{ [K in keyof T]: Extract }> => defineProperty({ + type: 'json' as any, + inner: globalThis.Object.keys(obj ?? {}).length ? mapValues(obj!, (value) => isType(value) ? value : fromField(value)) as any : undefined, + }, kType, true) + + export type Array = Type + export const Array = (type?: Type): Type.Array => defineProperty({ + type: 'json', + inner: type, + array: true, + }, kType, true) + + export function fromPrimitive(value: T): Type { + if (isNullable(value)) return fromField('expr' as any) + else if (typeof value === 'number') return Number as any + else if (typeof value === 'string') return String as any + else if (typeof value === 'boolean') return Boolean as any + else if (value instanceof Date) return fromField('timestamp' as any) + else if (ArrayBuffer.isView(value)) return fromField('binary' as any) + else if (globalThis.Array.isArray(value)) return Array(value.length ? fromPrimitive(value[0]) : undefined) as any + else if (typeof value === 'object') return fromField('json' as any) + throw new TypeError(`invalid primitive: ${value}`) + } + + export function fromField(field: Field | Field.Type): Type { + if (isType(field)) throw new TypeError(`invalid field: ${JSON.stringify(field)}`) + if (typeof field === 'string') return defineProperty({ type: field }, kType, true) + else if (field.type) return field.type + else if (field.expr?.[kType]) return field.expr[kType] + throw new TypeError(`invalid field: ${field}`) + } + + export function fromTerm(value: Eval.Term): Type { + if (isEvalExpr(value)) return value[kType] ?? fromField('expr' as any) + else return fromPrimitive(value) + } + + export function isType(value: any): value is Type { + return value?.[kType] === true + } + + export function isArray(type: Type) { + return (type.type === 'json') && type.array + } + + export function getInner(type?: Type, key?: string): Type | undefined { + if (!type?.inner) return + if (isArray(type) && isNullable(key)) return type.inner + if (isNullable(key)) return + if (type.inner[key]) return type.inner[key] + if (key.includes('.')) return key.split('.').reduce((t, k) => getInner(t, k), type) + return Object(globalThis.Object.fromEntries(globalThis.Object.entries(type.inner) + .filter(([k]) => k.startsWith(`${key}.`)) + .map(([k, v]) => [k.slice(key.length + 1), v]), + )) + } +} diff --git a/packages/core/src/utils.ts b/packages/core/src/utils.ts index f90eca1c..8de53ff7 100644 --- a/packages/core/src/utils.ts +++ b/packages/core/src/utils.ts @@ -1,4 +1,4 @@ -import { Intersect } from 'cosmokit' +import { Intersect, is, mapValues } from 'cosmokit' import { Eval } from './eval.ts' export type Values = S[keyof S] @@ -51,3 +51,73 @@ export function randomId() { export function makeRegExp(source: string | RegExp) { return source instanceof RegExp ? source : new RegExp(source) } + +export function unravel(source: object, init?: (value) => any) { + const result = {} + for (const key in source) { + let node = result + const segments = key.split('.').reverse() + for (let index = segments.length - 1; index > 0; index--) { + const segment = segments[index] + node = node[segment] ??= {} + if (init) node = init(node) + } + node[segments[0]] = source[key] + } + return result +} + +export function clone(source: T): T +export function clone(source: any) { + if (!source || typeof source !== 'object') return source + if (isUint8Array(source)) return (hasGlobalBuffer && Buffer.isBuffer(source)) ? Buffer.copyBytesFrom(source) : source.slice() + if (Array.isArray(source)) return source.map(clone) + if (is('Date', source)) return new Date(source.valueOf()) + if (is('RegExp', source)) return new RegExp(source.source, source.flags) + return mapValues(source, clone) +} + +const hasGlobalBuffer = typeof Buffer === 'function' && Buffer.prototype?._isBuffer !== true + +export function isUint8Array(value: any): value is Uint8Array { + const stringTag = value?.[Symbol.toStringTag] ?? Object.prototype.toString.call(value) + return (hasGlobalBuffer && Buffer.isBuffer(value)) + || ArrayBuffer.isView(value) + || ['ArrayBuffer', 'SharedArrayBuffer', '[object ArrayBuffer]', '[object SharedArrayBuffer]'].includes(stringTag) +} + +export function Uint8ArrayFromHex(source: string) { + if (hasGlobalBuffer) return Buffer.from(source, 'hex') + const hex = source.length % 2 === 0 ? source : source.slice(0, source.length - 1) + const buffer: number[] = [] + for (let i = 0; i < hex.length; i += 2) { + buffer.push(Number.parseInt(`${hex[i]}${hex[i + 1]}`, 16)) + } + return Uint8Array.from(buffer) +} + +export function Uint8ArrayToHex(source: Uint8Array) { + return (hasGlobalBuffer) ? toLocalUint8Array(source).toString('hex') + : Array.from(toLocalUint8Array(source), byte => byte.toString(16).padStart(2, '0')).join('') +} + +export function Uint8ArrayFromBase64(source: string) { + return (hasGlobalBuffer) ? Buffer.from(source, 'base64') : Uint8Array.from(atob(source), c => c.charCodeAt(0)) +} + +export function Uint8ArrayToBase64(source: Uint8Array) { + return (hasGlobalBuffer) ? (source as Buffer).toString('base64') : btoa(Array.from(Uint16Array.from(source), b => String.fromCharCode(b)).join('')) +} + +export function toLocalUint8Array(source: Uint8Array) { + if (hasGlobalBuffer) { + return Buffer.isBuffer(source) ? Buffer.from(source) + : ArrayBuffer.isView(source) ? Buffer.from(source.buffer, source.byteOffset, source.byteLength) + : Buffer.from(source) + } else { + const stringTag = source?.[Symbol.toStringTag] ?? Object.prototype.toString.call(source) + return stringTag === 'Uint8Array' ? source + : ArrayBuffer.isView(source) ? new Uint8Array(source.buffer.slice(source.byteOffset, source.byteOffset + source.byteLength)) + : new Uint8Array(source) + } +} diff --git a/packages/memory/src/index.ts b/packages/memory/src/index.ts index ab165a32..c1cc28d6 100644 --- a/packages/memory/src/index.ts +++ b/packages/memory/src/index.ts @@ -1,5 +1,5 @@ -import { clone, Dict, makeArray, noop, omit, pick, valueMap } from 'cosmokit' -import { Driver, Eval, executeEval, executeQuery, executeSort, executeUpdate, RuntimeError, Selection, z } from 'minato' +import { Dict, makeArray, noop, omit, pick, valueMap } from 'cosmokit' +import { clone, Driver, Eval, executeEval, executeQuery, executeSort, executeUpdate, RuntimeError, Selection, z } from 'minato' export class MemoryDriver extends Driver { static name = 'memory' diff --git a/packages/memory/tests/index.spec.ts b/packages/memory/tests/index.spec.ts index f6049ec1..8934a0ff 100644 --- a/packages/memory/tests/index.spec.ts +++ b/packages/memory/tests/index.spec.ts @@ -16,5 +16,19 @@ describe('@minatojs/driver-memory', () => { test(database, { migration: false, + model: { + fields: { + cast: false, + typeModel: false, + }, + object: { + typeModel: false, + } + }, + query: { + comparison: { + nullableComparator: false, + } + } }) }) diff --git a/packages/mongo/src/utils.ts b/packages/mongo/src/builder.ts similarity index 80% rename from packages/mongo/src/utils.ts rename to packages/mongo/src/builder.ts index 9b9c6b81..9701a51f 100644 --- a/packages/mongo/src/utils.ts +++ b/packages/mongo/src/builder.ts @@ -1,5 +1,5 @@ -import { Dict, isNullable, valueMap } from 'cosmokit' -import { Eval, isComparable, Query, Selection } from 'minato' +import { Dict, isNullable, mapValues, valueMap } from 'cosmokit' +import { Driver, Eval, isComparable, isEvalExpr, Model, Query, Selection, Type, unravel } from 'minato' import { Filter, FilterOperators, ObjectId } from 'mongodb' import MongoDriver from '.' @@ -78,7 +78,7 @@ export type EvalOperators = { const aggrKeys = ['$sum', '$avg', '$min', '$max', '$count', '$length', '$array'] -export class Transformer { +export class Builder { private counter = 0 public table!: string public walkedKeys: string[] = [] @@ -91,7 +91,7 @@ export class Transformer { private evalOperators: EvalOperators - constructor(private tables: string[], public virtualKey?: string, public recursivePrefix: string = '$') { + constructor(private driver: Driver, private tables: string[], public virtualKey?: string, public recursivePrefix: string = '$') { this.walkedKeys = [] this.evalOperators = { @@ -109,8 +109,8 @@ export class Transformer { } }, $if: (arg, group) => ({ $cond: arg.map(val => this.eval(val, group)) }), - $array: (arg, group) => this.transformEvalExpr(arg), - $object: (arg, group) => this.transformEvalExpr(arg), + + $object: (arg, group) => valueMap(arg as any, x => this.transformEvalExpr(x)), $regex: (arg, group) => ({ $regexMatch: { input: this.eval(arg[0], group), regex: this.eval(arg[1], group) } }), @@ -124,6 +124,10 @@ export class Transformer { $power: (arg, group) => ({ $pow: arg.map(val => this.eval(val, group)) }), $random: (arg, group) => ({ $rand: {} }), + $literal: (arg, group) => { + const converter = this.driver.types[arg[1] as any] + return converter ? converter.dump(arg[0]) : arg[0] + }, $number: (arg, group) => { const value = this.eval(arg, group) return { @@ -180,6 +184,7 @@ export class Transformer { return `$${name}` }, } + this.evalOperators = Object.assign(Object.create(null), this.evalOperators) } public createKey() { @@ -200,6 +205,14 @@ export class Transformer { for (const key in expr) { if (this.evalOperators[key]) { return this.evalOperators[key](expr[key], group) + } else if (key?.startsWith('$') && Eval[key.slice(1)]) { + return valueMap(expr, (value) => { + if (Array.isArray(value)) { + return value.map(val => this.eval(val, group)) + } else { + return this.eval(value, group) + } + }) } } @@ -207,13 +220,7 @@ export class Transformer { return expr.map(val => this.eval(val, group)) } - return valueMap(expr as any, (value) => { - if (Array.isArray(value)) { - return value.map(val => this.eval(val, group)) - } else { - return this.eval(value, group) - } - }) + return expr } private transformAggr(expr: any) { @@ -226,7 +233,8 @@ export class Transformer { return this.recursivePrefix + expr } - return this.transformEvalExpr(expr) + expr = this.transformEvalExpr(expr) + return typeof expr === 'object' ? unravel(expr) : expr } public flushLookups() { @@ -364,7 +372,7 @@ export class Transformer { } protected createSubquery(sel: Selection.Immutable) { - const predecessor = new Transformer(Object.keys(sel.tables)) + const predecessor = new Builder(this.driver, Object.keys(sel.tables)) predecessor.refTables = [...this.refTables, ...this.tables] predecessor.refVirtualKeys = this.refVirtualKeys return predecessor.select(sel) @@ -430,4 +438,65 @@ export class Transformer { } return this } + + dump(value: any, type: Model | Type | Eval.Expr | undefined): any { + if (!type) return value + if (isEvalExpr(type)) type = Type.fromTerm(type) + if (!Type.isType(type)) type = type.getType() + + type = Type.isType(type) ? type : Type.fromTerm(type) + const converter = this.driver.types[type?.type] + let res = value + + if (!isNullable(res) && type.inner) { + if (Type.isArray(type)) { + res = res.map(x => this.dump(x, Type.getInner(type)!)) + } else { + res = mapValues(res, (x, k) => this.dump(x, Type.getInner(type, k))) + } + } + + res = converter ? converter.dump(res) : res + return res + } + + load(value: any, type: Model | Type | Eval.Expr | undefined): any { + if (!type) return value + + if (Type.isType(type) || isEvalExpr(type)) { + type = Type.isType(type) ? type : Type.fromTerm(type) + const converter = this.driver.types[type?.inner ? 'json' : type?.type!] + let res = converter ? converter.load(value) : value + + if (!isNullable(res) && type.inner) { + if (Type.isArray(type)) { + res = res.map(x => this.load(x, Type.getInner(type as Type))) + } else { + res = mapValues(res, (x, k) => this.load(x, Type.getInner(type as Type, k))) + } + } + return res + } + + value = type.format(value) + const result = {} + for (const key in value) { + if (!(key in type.fields)) continue + result[key] = this.load(value[key], type.fields[key]!.type) + } + return type.parse(result) + } + + formatUpdateAggr(model: Type, obj: any) { + const result = {} + for (const key in obj) { + const type = Type.getInner(model, key) + if (!type || type.type !== 'json' || isNullable(obj[key]) || obj[key].$literal) result[key] = obj[key] + else if (Type.isArray(type) && Array.isArray(obj[key])) result[key] = obj[key] + else if (Object.keys(obj[key]).length === 0) result[key] = { $literal: obj[key] } + else if (type.inner) result[key] = this.formatUpdateAggr(type, obj[key]) + else result[key] = obj[key] + } + return result + } } diff --git a/packages/mongo/src/index.ts b/packages/mongo/src/index.ts index 8c711642..10aef2d1 100644 --- a/packages/mongo/src/index.ts +++ b/packages/mongo/src/index.ts @@ -2,7 +2,7 @@ import { BSONType, ClientSession, Collection, Db, IndexDescription, MongoClient, import { Dict, isNullable, makeArray, mapValues, noop, omit, pick } from 'cosmokit' import { Driver, Eval, executeUpdate, Query, RuntimeError, Selection, z } from 'minato' import { URLSearchParams } from 'url' -import { Transformer } from './utils' +import { Builder } from './builder' const tempKey = '__temp_minato_mongo__' @@ -13,6 +13,7 @@ export class MongoDriver extends Driver { public db!: Db public mongo = this + private builder: Builder = new Builder(this, []) private session?: ClientSession private _createTasks: Dict> = {} @@ -44,6 +45,12 @@ export class MongoDriver extends Driver { 'writeConcern', ])) this.db = this.client.db(this.config.database) + + this.define({ + types: ['binary'], + dump: value => value, + load: (value: any) => value ? value.buffer : value, + }) } stop() { @@ -124,7 +131,7 @@ export class MongoDriver extends Driver { if (!found) { const doc = await this.db.collection(table).findOne() if (doc) { - virtual = typeof doc._id !== 'object' || (typeof primary === 'string' && modelFields[primary]?.type === 'primary') + virtual = typeof doc._id !== 'object' || (typeof primary === 'string' && modelFields[primary]?.deftype === 'primary') } else { // Empty collection, just set meta and return fields.updateOne(meta, { $set: { virtual: useVirtualKey } }, { upsert: true }) @@ -254,14 +261,14 @@ export class MongoDriver extends Driver { public getVirtualKey(table: string) { const { primary, fields } = this.model(table) - if (typeof primary === 'string' && (this.config.optimizeIndex || fields[primary]?.type === 'primary')) { + if (typeof primary === 'string' && (this.config.optimizeIndex || fields[primary]?.deftype === 'primary')) { return primary } } private patchVirtual(table: string, row: any) { const { primary, fields } = this.model(table) - if (typeof primary === 'string' && (this.config.optimizeIndex || fields[primary]?.type === 'primary')) { + if (typeof primary === 'string' && (this.config.optimizeIndex || fields[primary]?.deftype === 'primary')) { row[primary] = row['_id'] delete row['_id'] } @@ -270,7 +277,7 @@ export class MongoDriver extends Driver { private unpatchVirtual(table: string, row: any) { const { primary, fields } = this.model(table) - if (typeof primary === 'string' && (this.config.optimizeIndex || fields[primary]?.type === 'primary')) { + if (typeof primary === 'string' && (this.config.optimizeIndex || fields[primary]?.deftype === 'primary')) { row['_id'] = row[primary] delete row[primary] } @@ -278,38 +285,39 @@ export class MongoDriver extends Driver { } private transformQuery(sel: Selection.Immutable, query: Query.Expr, table: string) { - return new Transformer(Object.keys(sel.tables), this.getVirtualKey(table)).query(query) + return new Builder(this, Object.keys(sel.tables), this.getVirtualKey(table)).query(query) } async get(sel: Selection.Immutable) { - const transformer = new Transformer(Object.keys(sel.tables)).select(sel) + const transformer = new Builder(this, Object.keys(sel.tables)).select(sel) if (!transformer) return [] this.logger.debug('%s %s', transformer.table, JSON.stringify(transformer.pipeline)) return this.db .collection(transformer.table) .aggregate(transformer.pipeline, { allowDiskUse: true, session: this.session }) - .toArray() + .toArray().then(rows => rows.map(row => this.builder.load(row, sel.model))) } async eval(sel: Selection.Immutable, expr: Eval.Expr) { - const transformer = new Transformer(Object.keys(sel.tables)).select(sel) + const transformer = new Builder(this, Object.keys(sel.tables)).select(sel) if (!transformer) return this.logger.debug('%s %s', transformer.table, JSON.stringify(transformer.pipeline)) const res = await this.db .collection(transformer.table) .aggregate(transformer.pipeline, { allowDiskUse: true, session: this.session }) .toArray() - return res.length ? res[0][transformer.evalKey!] : transformer.aggrDefault + return this.builder.load(res.length ? res[0][transformer.evalKey!] : transformer.aggrDefault, expr) } async set(sel: Selection.Mutable, update: {}) { - const { query, table } = sel + const { query, table, model } = sel const filter = this.transformQuery(sel, query, table) if (!filter) return {} const coll = this.db.collection(table) - const transformer = new Transformer(Object.keys(sel.tables), this.getVirtualKey(table), '$' + tempKey + '.') - const $set = transformer.eval(mapValues(update, (value: any) => typeof value === 'string' && value.startsWith('$') ? { $literal: value } : value)) + const transformer = new Builder(this, Object.keys(sel.tables), this.getVirtualKey(table), '$' + tempKey + '.') + const $set = this.builder.formatUpdateAggr(model.getType(), mapValues(this.builder.dump(update, model), + (value: any) => typeof value === 'string' && value.startsWith('$') ? { $literal: value } : transformer.eval(value))) const $unset = Object.entries($set) .filter(([_, value]) => typeof value === 'object') .map(([key, _]) => key) @@ -335,19 +343,19 @@ export class MongoDriver extends Driver { private shouldEnsurePrimary(table: string) { const model = this.model(table) const { primary, autoInc } = model - return typeof primary === 'string' && autoInc && model.fields[primary]?.type !== 'primary' + return typeof primary === 'string' && autoInc && model.fields[primary]?.deftype !== 'primary' } private shouldFillPrimary(table: string) { const model = this.model(table) const { primary, autoInc } = model - return typeof primary === 'string' && autoInc && model.fields[primary]?.type === 'primary' + return typeof primary === 'string' && autoInc && model.fields[primary]?.deftype === 'primary' } private async ensurePrimary(table: string, data: any[]) { const model = this.model(table) const { primary, autoInc } = model - if (typeof primary === 'string' && autoInc && model.fields[primary]?.type !== 'primary') { + if (typeof primary === 'string' && autoInc && model.fields[primary]?.deftype !== 'primary') { const missing = data.filter(item => !(primary in item)) if (!missing.length) return const doc = await this.db.collection('_fields').findOneAndUpdate( @@ -362,16 +370,14 @@ export class MongoDriver extends Driver { } async create(sel: Selection.Mutable, data: any) { - const { table } = sel + const { table, model } = sel const lastTask = Promise.resolve(this._createTasks[table]).catch(noop) return this._createTasks[table] = lastTask.then(async () => { - const model = this.model(table) const coll = this.db.collection(table) await this.ensurePrimary(table, [data]) try { - data = model.create(data) - const copy = this.unpatchVirtual(table, { ...data }) + const copy = this.unpatchVirtual(table, { ...this.builder.dump(data, model) }) const insertedId = (await coll.insertOne(copy, { session: this.session })).insertedId if (this.shouldFillPrimary(table)) { return { ...data, [model.primary as string]: insertedId } @@ -404,7 +410,7 @@ export class MongoDriver extends Driver { const item = original.find(item => keys.every(key => item[key]?.valueOf() === update[key]?.valueOf())) if (item) { const updateFields = new Set(Object.keys(update).map(key => key.split('.', 1)[0])) - const override = omit(pick(executeUpdate(item, update, ref), updateFields), keys) + const override = this.builder.dump(omit(pick(executeUpdate(item, update, ref), updateFields), keys), model) const query = this.transformQuery(sel, pick(item, keys), table) if (!query) continue bulk.find(query).updateOne({ $set: override }) @@ -414,7 +420,7 @@ export class MongoDriver extends Driver { } await this.ensurePrimary(table, insertion) for (const update of insertion) { - const copy = executeUpdate(model.create(), update, ref) + const copy = this.builder.dump(executeUpdate(model.create(), update, ref), model) bulk.insert(this.unpatchVirtual(table, copy)) } const result = await bulk.execute({ session: this.session }) @@ -426,8 +432,9 @@ export class MongoDriver extends Driver { for (const update of data) { const query = this.transformQuery(sel, pick(update, keys), table)! - const transformer = new Transformer(Object.keys(sel.tables), this.getVirtualKey(table), '$' + tempKey + '.') - const $set = transformer.eval(mapValues(update, (value: any) => typeof value === 'string' && value.startsWith('$') ? { $literal: value } : value)) + const transformer = new Builder(this, Object.keys(sel.tables), this.getVirtualKey(table), '$' + tempKey + '.') + const $set = this.builder.formatUpdateAggr(model.getType(), mapValues(this.builder.dump(update, model), + (value: any) => typeof value === 'string' && value.startsWith('$') ? { $literal: value } : transformer.eval(value))) const $unset = Object.entries($set) .filter(([_, value]) => typeof value === 'object') .map(([key, _]) => key) diff --git a/packages/mongo/tests/index.spec.ts b/packages/mongo/tests/index.spec.ts index 354dbd9d..618fb371 100644 --- a/packages/mongo/tests/index.spec.ts +++ b/packages/mongo/tests/index.spec.ts @@ -25,6 +25,11 @@ describe('@minatojs/driver-mongo', () => { }) test(database, process.argv.includes('--enable-transaction-abort') ? {} : { + model: { + object: { + aggregateNull: false, + } + }, transaction: { abort: false } diff --git a/packages/mongo/tests/migration.spec.ts b/packages/mongo/tests/migration.spec.ts index 1af62ef6..285176c0 100644 --- a/packages/mongo/tests/migration.spec.ts +++ b/packages/mongo/tests/migration.spec.ts @@ -1,8 +1,8 @@ import { $, Database, Primary } from 'minato' import { Context, ForkScope, Logger } from 'cordis' import { expect } from 'chai' -import { } from 'chai-shape' import MongoDriver from '@minatojs/driver-mongo' +import '@minatojs/tests' const logger = new Logger('mongo') @@ -97,21 +97,21 @@ describe('@minatojs/driver-mongo/migrate-virtualKey', () => { })) table.push(await database.create('temp1', { text: 'awesome bar' })) table.push(await database.create('temp1', { text: 'awesome baz' })) - await expect(database.get('temp1', {})).to.eventually.have.shape(table) + await expect(database.get('temp1', {})).to.eventually.deep.eq(table) await resetConfig(true) - await expect(database.get('temp1', {})).to.eventually.have.shape(table) + await expect(database.get('temp1', {})).to.eventually.deep.eq(table) await resetConfig(false) - await expect(database.get('temp1', {})).to.eventually.have.shape(table) + await expect(database.get('temp1', {})).to.eventually.deep.eq(table) await (Object.values(database.drivers)[0] as MongoDriver).drop('_fields') await resetConfig(true) - await expect(database.get('temp1', {})).to.eventually.have.shape(table) + await expect(database.get('temp1', {})).to.eventually.deep.eq(table) await (Object.values(database.drivers)[0] as MongoDriver).drop('_fields') await resetConfig(false) - await expect(database.get('temp1', {})).to.eventually.have.shape(table) + await expect(database.get('temp1', {})).to.eventually.deep.eq(table) }) it('using primary', async () => { @@ -137,20 +137,20 @@ describe('@minatojs/driver-mongo/migrate-virtualKey', () => { })) table.push(await database.create('temp2', { text: 'awesome bar' })) table.push(await database.create('temp2', { text: 'awesome baz' })) - await expect(database.get('temp2', {})).to.eventually.have.shape(table) + await expect(database.get('temp2', {})).to.eventually.deep.eq(table) await (Object.values(database.drivers)[0] as MongoDriver).drop('_fields') await resetConfig(true) - await expect(database.get('temp2', {})).to.eventually.have.shape(table) + await expect(database.get('temp2', {})).to.eventually.deep.eq(table) await (Object.values(database.drivers)[0] as MongoDriver).drop('_fields') await resetConfig(false) - await expect(database.get('temp2', {})).to.eventually.have.shape(table) + await expect(database.get('temp2', {})).to.eventually.deep.eq(table) // query & eval table.push(await database.create('temp2', { foreign: table[0].id })) - await expect(database.get('temp2', {})).to.eventually.have.shape(table) - await expect(database.get('temp2', { foreign: table[0].id })).to.eventually.have.shape([table[3]]) - await expect(database.get('temp2', row => $.eq(row.foreign, table[0].id!))).to.eventually.have.shape([table[3]]) + await expect(database.get('temp2', {})).to.eventually.deep.eq(table) + await expect(database.get('temp2', { foreign: table[0].id })).to.eventually.deep.eq([table[3]]) + await expect(database.get('temp2', row => $.eq(row.foreign, table[0].id!))).to.eventually.deep.eq([table[3]]) }) }) diff --git a/packages/mysql/src/builder.ts b/packages/mysql/src/builder.ts index 3ff13103..a54548d3 100644 --- a/packages/mysql/src/builder.ts +++ b/packages/mysql/src/builder.ts @@ -1,8 +1,6 @@ import { Builder, escapeId, isBracketed } from '@minatojs/sql-utils' -import { Dict, Time } from 'cosmokit' -import { Field, isEvalExpr, Model, randomId, Selection } from 'minato' - -export const DEFAULT_DATE = new Date('1970-01-01') +import { Dict, isNullable, Time } from 'cosmokit' +import { Driver, Field, isEvalExpr, isUint8Array, Model, randomId, Selection, Type, Uint8ArrayFromBase64, Uint8ArrayToBase64, Uint8ArrayToHex } from 'minato' export interface Compat { maria?: boolean @@ -27,57 +25,83 @@ export class MySQLBuilder extends Builder { prequeries: string[] = [] - constructor(tables?: Dict, private compat: Compat = {}) { - super(tables) + constructor(protected driver: Driver, tables?: Dict, private compat: Compat = {}) { + super(driver, tables) this.evalOperators.$sum = (expr) => this.createAggr(expr, value => `ifnull(sum(${value}), 0)`, undefined, value => `ifnull(minato_cfunc_sum(${value}), 0)`) this.evalOperators.$avg = (expr) => this.createAggr(expr, value => `avg(${value})`, undefined, value => `minato_cfunc_avg(${value})`) this.evalOperators.$min = (expr) => this.createAggr(expr, value => `min(${value})`, undefined, value => `minato_cfunc_min(${value})`) this.evalOperators.$max = (expr) => this.createAggr(expr, value => `max(${value})`, undefined, value => `minato_cfunc_max(${value})`) - this.define({ - types: ['list'], - dump: value => value.join(','), - load: value => value ? value.split(',') : [], - }) - - this.define({ - types: ['json'], - dump: value => JSON.stringify(value), - load: value => typeof value === 'string' ? JSON.parse(value) : value, - }) - - this.define({ - types: ['time'], - dump: value => value, - load: (value) => { - if (!value || typeof value === 'object') return value - const time = new Date(DEFAULT_DATE) - const [h, m, s] = value.split(':') - time.setHours(parseInt(h)) - time.setMinutes(parseInt(m)) - time.setSeconds(parseInt(s)) - return time + this.transformers['boolean'] = { + encode: value => `if(${value}=b'1', 1, 0)`, + decode: value => `if(${value}=1, b'1', b'0')`, + load: value => isNullable(value) ? value : !!value, + dump: value => isNullable(value) ? value : value ? 1 : 0, + } + + this.transformers['binary'] = { + encode: value => `to_base64(${value})`, + decode: value => `from_base64(${value})`, + load: value => isNullable(value) ? value : Uint8ArrayFromBase64(value), + dump: value => isNullable(value) ? value : Uint8ArrayToBase64(value), + } + + this.transformers['date'] = { + encode: value => value, + decode: value => `cast(${value} as date)`, + load: value => { + if (isNullable(value) || typeof value === 'object') return value + const parsed = new Date(value), date = new Date() + date.setFullYear(parsed.getFullYear(), parsed.getMonth(), parsed.getDate()) + date.setHours(0, 0, 0, 0) + return date + }, + dump: value => { + if (isNullable(value)) return value + const date = new Date(0) + date.setFullYear(value.getFullYear(), value.getMonth(), value.getDate()) + date.setHours(0, 0, 0, 0) + return Time.template('yyyy-MM-dd hh:mm:ss.SSS', date) + }, + } + + this.transformers['time'] = { + encode: value => value, + decode: value => `cast(${value} as time)`, + load: value => this.driver.types['time'].load(value), + dump: value => isNullable(value) ? value : Time.template('yyyy-MM-dd hh:mm:ss.SSS', value), + } + + this.transformers['timestamp'] = { + encode: value => value, + decode: value => `cast(${value} as datetime)`, + load: value => { + if (isNullable(value) || typeof value === 'object') return value + return new Date(value) }, - }) + dump: value => isNullable(value) ? value : Time.template('yyyy-MM-dd hh:mm:ss.SSS', value), + } } - escape(value: any, field?: Field) { + escapePrimitive(value: any, type?: Type) { if (value instanceof Date) { - value = Time.template('yyyy-MM-dd hh:mm:ss', value) + value = Time.template('yyyy-MM-dd hh:mm:ss.SSS', value) } else if (value instanceof RegExp) { value = value.source - } else if (!field && !!value && typeof value === 'object') { + } else if (isUint8Array(value)) { + return `X'${Uint8ArrayToHex(value)}'` + } else if (!!value && typeof value === 'object') { return `json_extract(${this.quote(JSON.stringify(value))}, '$')` } - return super.escape(value, field) + return super.escapePrimitive(value, type) } - protected jsonQuote(value: string, pure: boolean = false) { - if (pure) return this.compat.maria ? `json_extract(json_object('v', ${value}), '$.v')` : `cast(${value} as json)` - const res = this.state.sqlType === 'raw' ? (this.compat.maria ? `json_extract(json_object('v', ${value}), '$.v')` : `cast(${value} as json)`) : value - this.state.sqlType = 'json' - return res + protected encode(value: string, encoded: boolean, pure: boolean = false, type?: Type) { + return this.asEncoded(encoded === this.isEncoded() && !pure ? value : encoded + ? (this.compat.maria ? `json_extract(json_object('v', ${this.transform(value, type, 'encode')}), '$.v')` + : `cast(${this.transform(value, type, 'encode')} as json)`) + : this.transform(`json_unquote(${value})`, type, 'decode'), pure ? undefined : encoded) } protected createAggr(expr: any, aggr: (value: string) => string, nonaggr?: (value: string) => string, compat?: (value: string) => string) { @@ -90,10 +114,9 @@ export class MySQLBuilder extends Builder { protected groupArray(value: string) { if (!this.compat.maria) return super.groupArray(value) - const res = this.state.sqlType === 'json' ? `concat('[', group_concat(${value}), ']')` + const res = this.isEncoded() ? `concat('[', group_concat(${value}), ']')` : `concat('[', group_concat(json_extract(json_object('v', ${value}), '$.v')), ']')` - this.state.sqlType = 'json' - return `ifnull(${res}, json_array())` + return this.asEncoded(`ifnull(${res}, json_array())`, true) } protected parseSelection(sel: Selection) { @@ -108,17 +131,17 @@ export class MySQLBuilder extends Builder { if (!(sel.args[0] as any).$) { query = `(SELECT ${output} AS value FROM ${inner} ${isBracketed(inner) ? ref : ''})` } else { - query = `(ifnull((SELECT ${this.groupArray(output)} AS value FROM ${inner} ${isBracketed(inner) ? ref : ''}), json_array()))` + query = `(ifnull((SELECT ${this.groupArray(this.transform(output, Type.getInner(Type.fromTerm(expr)), 'encode'))} + AS value FROM ${inner} ${isBracketed(inner) ? ref : ''}), json_array()))` } if (Object.keys(refFields ?? {}).length) { const funcname = `minato_tfunc_${randomId()}` const decls = Object.values(refFields ?? {}).map(x => `${x} JSON`).join(',') - const args = Object.keys(refFields ?? {}).map(x => this.state.refFields?.[x] ?? x).map(x => this.jsonQuote(x, true)).join(',') - query = this.state.sqlType === 'json' ? `ifnull(${query}, json_array())` : this.jsonQuote(query) + const args = Object.keys(refFields ?? {}).map(x => this.state.refFields?.[x] ?? x).map(x => this.encode(x, true, true)).join(',') + query = this.isEncoded() ? `ifnull(${query}, json_array())` : this.encode(query, true) this.prequeries.push(`DROP FUNCTION IF EXISTS ${funcname}`) this.prequeries.push(`CREATE FUNCTION ${funcname} (${decls}) RETURNS JSON DETERMINISTIC RETURN ${query}`) - this.state.sqlType = 'json' - return `${funcname}(${args})` + return this.asEncoded(`${funcname}(${args})`, true) } else return query } @@ -142,7 +165,7 @@ export class MySQLBuilder extends Builder { if (!prop.startsWith(key + '.')) continue const rest = prop.slice(key.length + 1).split('.') if (rest.length === 1) continue - rest.reduce((obj, k) => obj[k] ??= {}, jsonInit) + rest.slice(0, -1).reduce((obj, k) => obj[k] ??= {}, jsonInit) } // update with json_set @@ -152,13 +175,16 @@ export class MySQLBuilder extends Builder { // json_set cannot create deeply nested property when non-exist // therefore we merge a layout to it if (Object.keys(jsonInit).length !== 0) { - value = `json_merge(${value}, ${this.quote(JSON.stringify(jsonInit))})` + value = `json_merge_patch(${this.escape(jsonInit, 'json')}, ${value})` } for (const prop in item) { if (!prop.startsWith(key + '.')) continue const rest = prop.slice(key.length + 1).split('.') - value = `json_set(${value}, '$${rest.map(key => `."${key}"`).join('')}', ${this.parseEval(item[prop])})` + const type = Type.getInner(field?.type, prop.slice(key.length + 1)) + const v = isEvalExpr(item[prop]) ? this.transform(this.parseEval(item[prop]), item[prop], 'encode') + : this.transform(this.escape(item[prop], type), type, 'encode') + value = `json_set(${value}, '$${rest.map(key => `."${key}"`).join('')}', ${v})` } if (value === valueInit) { diff --git a/packages/mysql/src/index.ts b/packages/mysql/src/index.ts index 988c2ae5..13510869 100644 --- a/packages/mysql/src/index.ts +++ b/packages/mysql/src/index.ts @@ -3,7 +3,7 @@ import type { OkPacket, Pool, PoolConfig, PoolConnection } from 'mysql' import { Dict, difference, makeArray, pick } from 'cosmokit' import { Driver, Eval, executeUpdate, Field, RuntimeError, Selection, z } from 'minato' import { escapeId, isBracketed } from '@minatojs/sql-utils' -import { Compat, DEFAULT_DATE, MySQLBuilder } from './builder' +import { Compat, MySQLBuilder } from './builder' declare module 'mysql' { interface UntypedFieldInfo { @@ -11,6 +11,8 @@ declare module 'mysql' { } } +const timeRegex = /(\d+):(\d+):(\d+)(\.(\d+))?/ + function getIntegerType(length = 4) { if (length <= 1) return 'tinyint' if (length <= 2) return 'smallint' @@ -19,12 +21,12 @@ function getIntegerType(length = 4) { return 'bigint' } -function getTypeDef({ type, length, precision, scale }: Field) { +function getTypeDef({ deftype: type, length, precision, scale }: Field) { switch (type) { case 'float': case 'double': - case 'date': - case 'time': return type + case 'date': return type + case 'time': return 'time(3)' case 'timestamp': return 'datetime(3)' case 'boolean': return 'bit' case 'integer': @@ -34,10 +36,11 @@ function getTypeDef({ type, length, precision, scale }: Field) { case 'unsigned': if ((length || 0) > 8) this.logger.warn(`type ${type}(${length}) exceeds the max supported length`) return `${getIntegerType(length)} unsigned` - case 'decimal': return `decimal(${precision}, ${scale}) unsigned` + case 'decimal': return `decimal(${precision ?? 10}, ${scale ?? 0}) unsigned` case 'char': return `char(${length || 255})` - case 'string': return `varchar(${length || 255})` - case 'text': return `text(${length || 65535})` + case 'string': return (length || 255) > 65536 ? 'longtext' : `varchar(${length || 255})` + case 'text': return (length || 255) > 65536 ? 'longtext' : `text(${length || 65535})` + case 'binary': return (length || 65537) > 65536 ? 'longblob' : `blob` case 'list': return `text(${length || 65535})` case 'json': return `text(${length || 65535})` default: throw new Error(`unsupported type: ${type}`) @@ -48,7 +51,7 @@ function isDefUpdated(field: Field, column: ColumnInfo, def: string) { const typename = def.split(/[ (]/)[0] if (typename === 'text') return !column.DATA_TYPE.endsWith('text') if (typename !== column.DATA_TYPE) return true - switch (field.type) { + switch (field.deftype) { case 'integer': case 'unsigned': case 'char': @@ -93,7 +96,7 @@ export class MySQLDriver extends Driver { static name = 'mysql' public pool!: Pool - public sql = new MySQLBuilder() + public sql: MySQLBuilder = new MySQLBuilder(this) private session?: PoolConnection private _compat: Compat = {} @@ -106,27 +109,8 @@ export class MySQLDriver extends Driver { charset: 'utf8mb4_general_ci', multipleStatements: true, typeCast: (field, next) => { - const { orgName, orgTable } = field.packet - const meta = this.database.tables[orgTable]?.fields[orgName] - - if (Field.string.includes(meta!?.type)) { - return field.string() - } else if (meta?.type === 'json') { - const source = field.string() - return source ? JSON.parse(source) : meta.initial - } else if (meta?.type === 'time') { - const source = field.string() - if (!source) return meta.initial - const time = new Date(DEFAULT_DATE) - const [h, m, s] = source.split(':') - time.setHours(parseInt(h)) - time.setMinutes(parseInt(m)) - time.setSeconds(parseInt(s)) - return time - } - if (field.type === 'BIT') { - return Boolean(field.buffer()?.readUInt8(0)) + return Boolean(field.buffer()?.readUint8(0)) } else { return next() } @@ -145,6 +129,31 @@ export class MySQLDriver extends Driver { if (this._compat.mysql57 || this._compat.maria) { await this._setupCompatFunctions() } + + this.define({ + types: ['json'], + dump: value => value as any, + load: value => typeof value === 'string' ? JSON.parse(value) : value, + }) + + this.define({ + types: ['list'], + dump: value => Array.isArray(value) ? value.join(',') : value, + load: value => value ? value.split(',') : [], + }) + + this.define({ + types: ['time'], + dump: value => value, + load: (value) => { + if (!value || typeof value === 'object') return value + const date = new Date(0) + const parsed = timeRegex.exec(value) + if (!parsed) throw Error(`unexpected time value: ${value}`) + date.setHours(+parsed[1], +parsed[2], +parsed[3], +(parsed[5] ?? 0)) + return date + }, + }) } async stop() { @@ -196,7 +205,7 @@ export class MySQLDriver extends Driver { def += (nullable ? ' ' : ' not ') + 'null' } // blob, text, geometry or json columns cannot have default values - if (initial && !typedef.startsWith('text')) { + if (initial && !typedef.startsWith('text') && !typedef.endsWith('blob')) { def += ' default ' + this.sql.escape(initial, fields[key]) } } @@ -376,28 +385,28 @@ INSERT INTO mtt VALUES(json_extract(j, concat('$[', i, ']'))); SET i=i+1; END WH async get(sel: Selection.Immutable) { const { model, tables } = sel - const builder = new MySQLBuilder(tables, this._compat) + const builder = new MySQLBuilder(this, tables, this._compat) const sql = builder.get(sel) if (!sql) return [] return Promise.all([...builder.prequeries, sql].map(x => this.queue(x))).then((data) => { - return data.at(-1).map((row) => builder.load(model, row)) + return data.at(-1).map((row) => builder.load(row, model)) }) } async eval(sel: Selection.Immutable, expr: Eval.Expr) { - const builder = new MySQLBuilder(sel.tables, this._compat) + const builder = new MySQLBuilder(this, sel.tables, this._compat) const inner = builder.get(sel.table as Selection, true, true) const output = builder.parseEval(expr, false) const ref = isBracketed(inner) ? sel.ref : '' const sql = `SELECT ${output} AS value FROM ${inner} ${ref}` return Promise.all([...builder.prequeries, sql].map(x => this.queue(x))).then((data) => { - return builder.load(data.at(-1)[0].value) + return builder.load(data.at(-1)[0].value, expr) }) } async set(sel: Selection.Mutable, data: {}) { const { model, query, table, tables, ref } = sel - const builder = new MySQLBuilder(tables, this._compat) + const builder = new MySQLBuilder(this, tables, this._compat) const filter = builder.parseQuery(query) const { fields } = model if (filter === '0') return {} @@ -415,7 +424,7 @@ INSERT INTO mtt VALUES(json_extract(j, concat('$[', i, ']'))); SET i=i+1; END WH async remove(sel: Selection.Mutable) { const { query, table, tables } = sel - const builder = new MySQLBuilder(tables, this._compat) + const builder = new MySQLBuilder(this, tables, this._compat) const filter = builder.parseQuery(query) if (filter === '0') return {} const result = await this.query(`DELETE FROM ${escapeId(table)} WHERE ` + filter) @@ -425,7 +434,7 @@ INSERT INTO mtt VALUES(json_extract(j, concat('$[', i, ']'))); SET i=i+1; END WH async create(sel: Selection.Mutable, data: {}) { const { table, model } = sel const { autoInc, primary } = model - const formatted = this.sql.dump(model, data) + const formatted = this.sql.dump(data, model) const keys = Object.keys(formatted) const header = await this.query([ `INSERT INTO ${escapeId(table)} (${keys.map(escapeId).join(', ')})`, @@ -438,7 +447,7 @@ INSERT INTO mtt VALUES(json_extract(j, concat('$[', i, ']'))); SET i=i+1; END WH async upsert(sel: Selection.Mutable, data: any[], keys: string[]) { if (!data.length) return {} const { model, table, tables, ref } = sel - const builder = new MySQLBuilder(tables, this._compat) + const builder = new MySQLBuilder(this, tables, this._compat) const merged = {} const insertion = data.map((item) => { @@ -489,6 +498,8 @@ INSERT INTO mtt VALUES(json_extract(j, concat('$[', i, ']'))); SET i=i+1; END WH `ON DUPLICATE KEY UPDATE ${update}`, ].join(' ')) const records = +(/^&Records:\s*(\d+)/.exec(result.message)?.[1] ?? result.affectedRows) + if (!result.message && !result.insertId) return { inserted: 0, matched: result.affectedRows, modified: 0 } + if (!result.message && result.affectedRows > 1) return { inserted: 0, matched: result.affectedRows / 2, modified: result.affectedRows / 2 } return { inserted: records - result.changedRows, matched: result.changedRows, modified: result.affectedRows - records } } diff --git a/packages/postgres/src/builder.ts b/packages/postgres/src/builder.ts index fe8b2395..dbf5f455 100644 --- a/packages/postgres/src/builder.ts +++ b/packages/postgres/src/builder.ts @@ -1,8 +1,9 @@ import { Builder, isBracketed } from '@minatojs/sql-utils' import { Dict, isNullable, Time } from 'cosmokit' -import { Field, isEvalExpr, Model, randomId, Selection } from 'minato' - -const timeRegex = /(\d+):(\d+):(\d+)/ +import { + Driver, Field, isEvalExpr, isUint8Array, Model, randomId, Selection, Type, + Uint8ArrayFromBase64, Uint8ArrayToBase64, Uint8ArrayToHex, unravel, +} from 'minato' export function escapeId(value: string) { return '"' + value.replace(/"/g, '""') + '"' @@ -22,24 +23,15 @@ export function formatTime(time: Date) { } export class PostgresBuilder extends Builder { - // eslint-disable-next-line no-control-regex - protected escapeRegExp = /[\0\b\t\n\r\x1a'\\]/g protected escapeMap = { - '\0': '\\0', - '\b': '\\b', - '\t': '\\t', - '\n': '\\n', - '\r': '\\r', - '\x1a': '\\Z', - '\'': '\'\'', - '\\': '\\\\', + "'": "''", } protected $true = 'TRUE' protected $false = 'FALSE' - constructor(public tables?: Dict) { - super(tables) + constructor(protected driver: Driver, public tables?: Dict) { + super(driver, tables) this.queryOperators = { ...this.queryOperators, @@ -47,7 +39,7 @@ export class PostgresBuilder extends Builder { $regexFor: (key, value) => `${this.escape(value)} ~ ${key}`, $size: (key, value) => { if (!value) return this.logicalNot(key) - if (this.state.sqlTypes?.[this.unescapeId(key)] === 'json') { + if (this.isJsonQuery(key)) { return `${this.jsonLength(key)} = ${this.escape(value)}` } else { return `${key} IS NOT NULL AND ARRAY_LENGTH(${key}, 1) = ${value}` @@ -84,10 +76,9 @@ export class PostgresBuilder extends Builder { $number: (arg) => { const value = this.parseEval(arg) - const res = this.state.sqlType === 'raw' ? `${value}::double precision` - : `extract(epoch from ${value})::bigint` - this.state.sqlType = 'raw' - return `coalesce(${res}, 0)` + const type = Type.fromTerm(arg) + const res = Field.date.includes(type.type!) ? `extract(epoch from ${value})::bigint` : `${value}::double precision` + return this.asEncoded(`coalesce(${res}, 0)`, false) }, $sum: (expr) => this.createAggr(expr, value => `coalesce(sum(${value})::double precision, 0)`, undefined, 'double precision'), @@ -95,37 +86,63 @@ export class PostgresBuilder extends Builder { $min: (expr) => this.createAggr(expr, value => `min(${value})`, undefined, 'double precision'), $max: (expr) => this.createAggr(expr, value => `max(${value})`, undefined, 'double precision'), $count: (expr) => this.createAggr(expr, value => `count(distinct ${value})::integer`), - $length: (expr) => this.createAggr(expr, value => `count(${value})::integer`, value => { - if (this.state.sqlType === 'json') { - this.state.sqlType = 'raw' - return `${this.jsonLength(value)}` - } else { - this.state.sqlType = 'raw' - return `COALESCE(ARRAY_LENGTH(${value}, 1), 0)` - } - }), + $length: (expr) => this.createAggr(expr, value => `count(${value})::integer`, + value => this.isEncoded() ? this.jsonLength(value) : this.asEncoded(`COALESCE(ARRAY_LENGTH(${value}, 1), 0)`, false), + ), $concat: (args) => `${args.map(arg => this.parseEval(arg, 'text')).join('||')}`, } - this.define({ - types: ['time'], - dump: date => date ? (typeof date === 'string' ? date : formatTime(date)) : null, - load: str => { - if (isNullable(str)) return str - const date = new Date(0) - const parsed = timeRegex.exec(str) - if (!parsed) throw Error(`unexpected time value: ${str}`) - date.setHours(+parsed[1], +parsed[2], +parsed[3]) + this.transformers['boolean'] = { + encode: value => value, + decode: value => `(${value})::boolean`, + load: value => value, + dump: value => value, + } + + this.transformers['decimal'] = { + encode: value => value, + decode: value => `(${value})::double precision`, + load: value => isNullable(value) ? value : +value, + dump: value => value, + } + + this.transformers['binary'] = { + encode: value => `encode(${value}, 'base64')`, + decode: value => `decode(${value}, 'base64')`, + load: value => isNullable(value) ? value : Uint8ArrayFromBase64(value), + dump: value => isNullable(value) ? value : Uint8ArrayToBase64(value), + } + + this.transformers['date'] = { + encode: value => value, + decode: value => `cast(${value} as date)`, + load: value => { + if (isNullable(value) || typeof value === 'object') return value + const parsed = new Date(value), date = new Date() + date.setFullYear(parsed.getFullYear(), parsed.getMonth(), parsed.getDate()) + date.setHours(0, 0, 0, 0) return date }, - }) + dump: value => isNullable(value) ? value : formatTime(value), + } - this.define({ - types: ['list'], - dump: value => '{' + value.join(',') + '}', - load: value => value, - }) + this.transformers['time'] = { + encode: value => value, + decode: value => `cast(${value} as time)`, + load: value => this.driver.types['time'].load(value), + dump: value => this.driver.types['time'].dump(value), + } + + this.transformers['timestamp'] = { + encode: value => value, + decode: value => `cast(${value} as datetime)`, + load: value => { + if (isNullable(value) || typeof value === 'object') return value + return new Date(value) + }, + dump: value => isNullable(value) ? value : formatTime(value), + } } upsert(table: string) { @@ -145,12 +162,13 @@ export class PostgresBuilder extends Builder { else if (typeof expr === 'string') return 'boolean' } - parseEval(expr: any, outtype: boolean | string = false): string { - this.state.sqlType = 'raw' + parseEval(expr: any, outtype: boolean | string = true): string { + this.state.encoded = false if (typeof expr === 'string' || typeof expr === 'number' || typeof expr === 'boolean' || expr instanceof Date || expr instanceof RegExp) { return this.escape(expr) } - return outtype ? this.jsonUnquote(this.parseEvalExpr(expr), false, typeof outtype === 'string' ? outtype : undefined) : this.parseEvalExpr(expr) + return outtype ? `(${this.encode(this.parseEvalExpr(expr), false, false, Type.fromTerm(expr), typeof outtype === 'string' ? outtype : undefined)})` + : this.parseEvalExpr(expr) } protected createRegExpQuery(key: string, value: string | RegExp) { @@ -158,8 +176,8 @@ export class PostgresBuilder extends Builder { } protected createElementQuery(key: string, value: any) { - if (this.state.sqlTypes?.[this.unescapeId(key)] === 'json') { - return this.jsonContains(key, this.quote(JSON.stringify(value))) + if (this.isJsonQuery(key)) { + return this.jsonContains(key, this.encode(value, true, true)) } else { return `${key} && ARRAY['${value}']::TEXT[]` } @@ -168,56 +186,47 @@ export class PostgresBuilder extends Builder { protected createAggr(expr: any, aggr: (value: string) => string, nonaggr?: (value: string) => string, eltype?: string) { if (!this.state.group && !nonaggr) { const value = this.parseEval(expr, false) - return `(select ${aggr(this.jsonUnquote(this.escapeId('value'), true, eltype))} from jsonb_array_elements(${value}) ${randomId()})` + return `(select ${aggr(`(${this.encode(this.escapeId('value'), false, true, undefined)})${eltype ? `::${eltype}` : ''}`)} + from jsonb_array_elements(${value}) ${randomId()})` } else { return super.createAggr(expr, aggr, nonaggr) } } protected transformJsonField(obj: string, path: string) { - this.state.sqlType = 'json' - return `jsonb_extract_path(${obj}, ${path.slice(1).replace('.', ',')})` + return this.asEncoded(`jsonb_extract_path(${obj}, ${path.slice(1).replaceAll('.', ',')})`, true) } protected jsonLength(value: string) { - return `jsonb_array_length(${value})` + return this.asEncoded(`jsonb_array_length(${value})`, false) } protected jsonContains(obj: string, value: string) { - return `(${obj} @> ${value})` + return this.asEncoded(`(${obj} @> ${value})`, false) } - protected jsonUnquote(value: string, pure: boolean = false, type?: string) { - if (pure && type) return `(jsonb_build_object('v', ${value})->>'v')::${type}` - if (this.state.sqlType === 'json') { - this.state.sqlType = 'raw' - return `(jsonb_build_object('v', ${value})->>'v')::${type}` - } - return value + protected encode(value: string, encoded: boolean, pure: boolean = false, type?: Type, outtype?: string) { + return this.asEncoded((encoded === this.isEncoded() && !pure) ? value + : encoded ? `to_jsonb(${this.transform(value, type, 'encode')})` + : this.transform(`(jsonb_build_object('v', ${value})->>'v')`, type, 'decode') + `${typeof outtype === 'string' ? `::${outtype}` : ''}` + , pure ? undefined : encoded) } - protected jsonQuote(value: string, pure: boolean = false) { - if (pure) return `to_jsonb(${value})` - if (this.state.sqlType !== 'json') { - this.state.sqlType = 'json' - return `to_jsonb(${value})` - } - return value - } - - protected groupObject(fields: any) { - const parse = (expr) => { - const value = this.parseEval(expr, false) - return this.state.sqlType === 'json' ? `to_jsonb(${value})` : `${value}` + protected groupObject(_fields: any) { + const _groupObject = (fields: any, type?: Type, prefix: string = '') => { + const parse = (expr, key) => { + const value = (!_fields[`${prefix}${key}`] && type && Type.getInner(type, key)?.inner) + ? _groupObject(expr, Type.getInner(type, key), `${prefix}${key}.`) + : this.parseEval(expr, false) + return this.isEncoded() ? this.encode(`to_jsonb(${value})`, true) : this.transform(value, expr, 'encode') + } + return `jsonb_build_object(` + Object.entries(fields).map(([key, expr]) => `'${key}', ${parse(expr, key)}`).join(',') + `)` } - const res = `jsonb_build_object(` + Object.entries(fields).map(([key, expr]) => `'${key}', ${parse(expr)}`).join(',') + `)` - this.state.sqlType = 'json' - return res + return this.asEncoded(_groupObject(unravel(_fields), this.state.type, ''), true) } protected groupArray(value: string) { - this.state.sqlType = 'json' - return `coalesce(jsonb_agg(${value}), '[]'::jsonb)` + return this.asEncoded(`coalesce(jsonb_agg(${value}), '[]'::jsonb)`, true) } protected parseSelection(sel: Selection) { @@ -229,7 +238,8 @@ export class PostgresBuilder extends Builder { if (!(sel.args[0] as any).$) { return `(SELECT ${output} AS value FROM ${inner} ${isBracketed(inner) ? ref : ''})` } else { - return `(coalesce((SELECT ${this.groupArray(output)} AS value FROM ${inner} ${isBracketed(inner) ? ref : ''}), '[]'::jsonb))` + return `(coalesce((SELECT ${this.groupArray(this.transform(output, Type.getInner(Type.fromTerm(expr)), 'encode'))} + AS value FROM ${inner} ${isBracketed(inner) ? ref : ''}), '[]'::jsonb))` } } @@ -239,15 +249,19 @@ export class PostgresBuilder extends Builder { return `'${value}'` } - escape(value: any, field?: Field) { + escapePrimitive(value: any, type?: Type) { if (value instanceof Date) { value = formatTime(value) } else if (value instanceof RegExp) { value = value.source - } else if (!field && !!value && typeof value === 'object') { + } else if (isUint8Array(value)) { + return `'\\x${Uint8ArrayToHex(value)}'::bytea` + } else if (type?.type === 'list' && Array.isArray(value)) { + return `ARRAY[${value.map(x => this.escape(x)).join(', ')}]::TEXT[]` + } else if (!!value && typeof value === 'object') { return `${this.quote(JSON.stringify(value))}::jsonb` } - return super.escape(value, field) + return super.escapePrimitive(value, type) } toUpdateExpr(item: any, key: string, field?: Field, upsert?: boolean) { @@ -279,13 +293,20 @@ export class PostgresBuilder extends Builder { // json_set cannot create deeply nested property when non-exist // therefore we merge a layout to it if (Object.keys(jsonInit).length !== 0) { - value = `(${value} || jsonb ${this.quote(JSON.stringify(jsonInit))})` + value = `(jsonb ${this.escape(jsonInit, 'json')} || ${value})` } for (const prop in item) { if (!prop.startsWith(key + '.')) continue const rest = prop.slice(key.length + 1).split('.') - value = `jsonb_set(${value}, '{${rest.map(key => `"${key}"`).join(',')}}', ${this.jsonQuote(this.parseEval(item[prop]), true)}, true)` + const type = Type.getInner(field?.type, prop.slice(key.length + 1)) + let escaped: string + + const v = isEvalExpr(item[prop]) ? this.encode(this.parseEval(item[prop]), true, true, Type.fromTerm(item[prop])) + : (escaped = this.transform(this.escape(item[prop], type), type, 'encode'), escaped.endsWith('::jsonb') ? escaped + : escaped.startsWith(`'`) ? this.encode(`(${escaped})::text`, true, true) // not passing type to prevent duplicated transform + : this.encode(escaped, true, true)) + value = `jsonb_set(${value}, '{${rest.map(key => `"${key}"`).join(',')}}', ${v}, true)` } if (value === valueInit) { diff --git a/packages/postgres/src/index.ts b/packages/postgres/src/index.ts index f3e62fd2..578554ca 100644 --- a/packages/postgres/src/index.ts +++ b/packages/postgres/src/index.ts @@ -2,7 +2,7 @@ import postgres from 'postgres' import { Dict, difference, isNullable, makeArray, pick } from 'cosmokit' import { Driver, Eval, executeUpdate, Field, Selection, z } from 'minato' import { isBracketed } from '@minatojs/sql-utils' -import { formatTime, PostgresBuilder } from './builder' +import { escapeId, formatTime, PostgresBuilder } from './builder' interface ColumnInfo { table_catalog: string @@ -55,79 +55,52 @@ interface QueryTask { reject: (reason: unknown) => void } -function escapeId(value: string) { - return '"' + value.replace(/"/g, '""') + '"' -} +const timeRegex = /(\d+):(\d+):(\d+)(\.(\d+))?/ function getTypeDef(field: Field & { autoInc?: boolean }) { - let { type, length, precision, scale, initial, autoInc } = field - let def = '' - if (['primary', 'unsigned', 'integer'].includes(type)) { - length ||= 4 - if (precision) def += `numeric(${precision}, ${scale ?? 0})` - else if (length <= 2) def += autoInc ? 'smallserial' : 'smallint' - else if (length <= 4) def += autoInc ? 'serial' : 'integer' - else { - if (length > 8) this.logger.warn(`type ${type}(${length}) exceeds the max supported length`) - def += autoInc ? 'bigserial' : 'bigint' - } - if (!isNullable(initial) && !autoInc) def += ` DEFAULT ${initial}` - } else if (type === 'decimal') { - def += `numeric(${precision}, ${scale})` - if (!isNullable(initial)) def += ` DEFAULT ${initial}` - } else if (type === 'float') { - def += 'real' - if (!isNullable(initial)) def += ` DEFAULT ${initial}` - } else if (type === 'double') { - def += 'double precision' - if (!isNullable(initial)) def += ` DEFAULT ${initial}` - } else if (type === 'char') { - def += `varchar(${length || 64}) ` - if (!isNullable(initial)) def += ` DEFAULT '${initial.replace(/'/g, "''")}'` - } else if (type === 'string') { - def += `varchar(${length || 255})` - if (!isNullable(initial)) def += ` DEFAULT '${initial.replace(/'/g, "''")}'` - } else if (type === 'text') { - def += `text` - if (!isNullable(initial)) def += ` DEFAULT '${initial.replace(/'/g, "''")}'` - } else if (type === 'boolean') { - def += 'boolean' - if (!isNullable(initial)) def += ` DEFAULT ${initial}` - } else if (type === 'list') { - def += 'text[]' - if (initial) { - def += ` DEFAULT ${transformArray(initial)}` - } - } else if (type === 'json') { - def += 'jsonb' - if (initial) def += ` DEFAULT '${JSON.stringify(initial)}'::JSONB` // TODO - } else if (type === 'date') { - def += 'timestamp with time zone' - if (initial) def += ` DEFAULT ${formatTime(initial)}` - } else if (type === 'time') { - def += 'time with time zone' - if (initial) def += ` DEFAULT ${formatTime(initial)}` - } else if (type === 'timestamp') { - def += 'timestamp with time zone' - if (initial) def += ` DEFAULT ${formatTime(initial)}` - } else throw new Error(`unsupported type: ${type}`) - - return def + let { deftype: type, length, precision, scale, autoInc } = field + switch (type) { + case 'primary': + case 'unsigned': + case 'integer': + length ||= 4 + if (precision) return `numeric(${precision}, ${scale ?? 0})` + else if (length <= 2) return autoInc ? 'smallserial' : 'smallint' + else if (length <= 4) return autoInc ? 'serial' : 'integer' + else { + if (length > 8) this.logger.warn(`type ${type}(${length}) exceeds the max supported length`) + return autoInc ? 'bigserial' : 'bigint' + } + case 'decimal': return `numeric(${precision ?? 10}, ${scale ?? 0})` + case 'float': return 'real' + case 'double': return 'double precision' + case 'char': return `varchar(${length || 64}) ` + case 'string': return `varchar(${length || 255})` + case 'text': return `text` + case 'boolean': return 'boolean' + case 'list': return 'text[]' + case 'json': return 'jsonb' + case 'date': return 'timestamp with time zone' + case 'time': return 'time with time zone' + case 'timestamp': return 'timestamp with time zone' + case 'binary': return 'bytea' + default: throw new Error(`unsupported type: ${type}`) + } } function isDefUpdated(field: Field & { autoInc?: boolean }, column: ColumnInfo, def: string) { const typename = def.split(/[ (]/)[0] if (field.autoInc) return false - if (['unsigned', 'integer'].includes(field.type)) { + if (['unsigned', 'integer'].includes(field.deftype!)) { if (column.data_type !== typename) return true } else if (typename === 'text[]') { if (column.data_type !== 'ARRAY') return true - } else if (Field.date.includes(field.type)) { + } else if (Field.date.includes(field.deftype!)) { if (column.data_type !== def) return true } else if (typename === 'varchar') { if (column.data_type !== 'character varying') return true } else if (typename !== column.data_type) return true - switch (field.type) { + switch (field.deftype) { case 'integer': case 'unsigned': case 'char': @@ -147,15 +120,11 @@ function createIndex(keys: string | string[]) { return makeArray(keys).map(escapeId).join(', ') } -function transformArray(arr: any[]) { - return `ARRAY[${arr.map(v => `'${v.replace(/'/g, "''")}'`).join(',')}]::TEXT[]` -} - export class PostgresDriver extends Driver { static name = 'postgres' public postgres!: postgres.Sql - public sql = new PostgresBuilder() + public sql = new PostgresBuilder(this) private session?: postgres.TransactionSql private _counter = 0 @@ -177,6 +146,25 @@ export class PostgresDriver extends Driver { }, ...this.config, }) + + this.define({ + types: ['json'], + dump: value => value, + load: value => value, + }) + + this.define({ + types: ['time'], + dump: date => date ? (typeof date === 'string' ? date : formatTime(date)) : null, + load: str => { + if (isNullable(str)) return str + const date = new Date(0) + const parsed = timeRegex.exec(str) + if (!parsed) throw Error(`unexpected time value: ${str}`) + date.setHours(+parsed[1], +parsed[2], +parsed[3], +(parsed[5] ?? 0)) + return date + }, + }) } async stop() { @@ -235,7 +223,7 @@ export class PostgresDriver extends Driver { // field definitions for (const key in fields) { - const { deprecated } = fields[key]! + const { deprecated, initial, nullable = true } = fields[key]! if (deprecated) continue const legacy = [key, ...fields[key]!.legacy || []] const column = columns.find(info => legacy.includes(info.column_name)) @@ -247,12 +235,13 @@ export class PostgresDriver extends Driver { } if (!column) { - create.push(`${escapeId(key)} ${typedef}`) + create.push(`${escapeId(key)} ${typedef} ${makeArray(primary).includes(key) || !nullable ? 'not null' : 'null'}` + + (initial ? ' DEFAULT ' + this.sql.escape(initial, fields[key]) : '')) } else if (shouldUpdate) { if (column.column_name !== key) rename.push(`RENAME ${escapeId(column.column_name)} TO ${escapeId(key)}`) - const [ctype, cdefault] = typedef.split('DEFAULT') - update.push(`ALTER ${escapeId(key)} TYPE ${ctype}`) - if (cdefault) update.push(`ALTER ${escapeId(key)} SET DEFAULT ${cdefault}`) + update.push(`ALTER ${escapeId(key)} TYPE ${typedef}`) + update.push(`ALTER ${escapeId(key)} ${makeArray(primary).includes(key) || !nullable ? 'SET' : 'DROP'} NOT NULL`) + if (initial) update.push(`ALTER ${escapeId(key)} SET DEFAULT ${this.sql.escape(initial, fields[key])}`) } } @@ -346,26 +335,26 @@ export class PostgresDriver extends Driver { } async get(sel: Selection.Immutable) { - const builder = new PostgresBuilder(sel.tables) + const builder = new PostgresBuilder(this, sel.tables) const query = builder.get(sel) if (!query) return [] return this.queue(query).then(data => { - return data.map(row => builder.load(sel.model, row)) + return data.map(row => builder.load(row, sel.model)) }) } async eval(sel: Selection.Immutable, expr: Eval.Expr) { - const builder = new PostgresBuilder(sel.tables) + const builder = new PostgresBuilder(this, sel.tables) const inner = builder.get(sel.table as Selection, true, true) const output = builder.parseEval(expr, false) const ref = isBracketed(inner) ? sel.ref : '' const [data] = await this.queue(`SELECT ${output} AS value FROM ${inner} ${ref}`) - return builder.load(data?.value) + return builder.load(data?.value, expr) } async set(sel: Selection.Mutable, data: {}) { const { model, query, table, tables, ref } = sel - const builder = new PostgresBuilder(tables) + const builder = new PostgresBuilder(this, tables) const filter = builder.parseQuery(query) const { fields } = model if (filter === '0') return {} @@ -382,7 +371,7 @@ export class PostgresDriver extends Driver { } async remove(sel: Selection.Mutable) { - const builder = new PostgresBuilder(sel.tables) + const builder = new PostgresBuilder(this, sel.tables) const query = builder.parseQuery(sel.query) if (query === 'FALSE') return {} const { count } = await this.query(`DELETE FROM ${sel.table} WHERE ${query}`) @@ -391,21 +380,21 @@ export class PostgresDriver extends Driver { async create(sel: Selection.Mutable, data: any) { const { table, model } = sel - const builder = new PostgresBuilder(sel.tables) - const formatted = builder.dump(model, data) + const builder = new PostgresBuilder(this, sel.tables) + const formatted = builder.dump(data, model) const keys = Object.keys(formatted) const [row] = await this.query([ `INSERT INTO ${builder.escapeId(table)} (${keys.map(builder.escapeId).join(', ')})`, - `VALUES (${keys.map(key => builder.escape(formatted[key])).join(', ')})`, + `VALUES (${keys.map(key => builder.escapePrimitive(formatted[key], model.getType(key))).join(', ')})`, `RETURNING *`, ].join(' ')) - return builder.load(model, row) + return builder.load(row, model) } async upsert(sel: Selection.Mutable, data: any[], keys: string[]) { if (!data.length) return {} const { model, table, tables, ref } = sel - const builder = new PostgresBuilder(tables) + const builder = new PostgresBuilder(this, tables) builder.upsert(table) this._counter = (this._counter + 1) % 256 diff --git a/packages/sql-utils/src/index.ts b/packages/sql-utils/src/index.ts index a8911879..9db36551 100644 --- a/packages/sql-utils/src/index.ts +++ b/packages/sql-utils/src/index.ts @@ -1,5 +1,5 @@ -import { Dict, isNullable } from 'cosmokit' -import { Eval, Field, isComparable, Model, Modifier, Query, randomId, Selection } from 'minato' +import { Dict, isNullable, mapValues } from 'cosmokit' +import { Driver, Eval, Field, isComparable, isEvalExpr, Model, Modifier, Query, randomId, Selection, Type, unravel } from 'minato' export function escapeId(value: string) { return '`' + value + '`' @@ -9,6 +9,10 @@ export function isBracketed(value: string) { return value.startsWith('(') && value.endsWith(')') } +export function isSqlJson(type?: Type) { + return type ? (type.type === 'json' || !!type.inner) : false +} + export type QueryOperators = { [K in keyof Query.FieldExpr]?: (key: string, value: NonNullable) => string } @@ -19,19 +23,31 @@ export type EvalOperators = { [K in keyof Eval.Static as `$${K}`]?: (expr: ExtractUnary>) => string } & { $: (expr: any) => string } -export interface Transformer { - types: Field.Type[] - dump: (value: S) => T | null - load: (value: T, initial?: S) => S | null +interface Transformer { + encode(value: string): string + decode(value: string): string + load(value: S): T + dump(value: T): S } -type SQLType = 'raw' | 'json' | 'list' | 'date' | 'time' | 'timestamp' - interface State { - sqlType?: SQLType - sqlTypes?: Dict + // current table ref in get() + table?: string + + // encode format of last evaluation + encoded?: boolean + encodedMap?: Dict + + // current eval expr type + type?: Type + group?: boolean tables?: Dict + + // joined tables + innerTables?: Dict + + // outter tables and fields within subquery refFields?: Dict refTables?: Dict wrappedSubquery?: boolean @@ -40,7 +56,6 @@ interface State { export class Builder { protected escapeMap = {} protected escapeRegExp?: RegExp - protected types: Dict = {} protected createEqualQuery = this.comparator('=') protected queryOperators: QueryOperators protected evalOperators: EvalOperators @@ -48,10 +63,11 @@ export class Builder { protected $true = '1' protected $false = '0' protected modifiedTable?: string + protected transformers: Dict = Object.create(null) - private readonly _timezone = `+${(new Date()).getTimezoneOffset() / -60}:00`.replace('+-', '-') + protected readonly _timezone = `+${(new Date()).getTimezoneOffset() / -60}:00`.replace('+-', '-') - constructor(tables?: Dict) { + constructor(protected driver: Driver, tables?: Dict) { this.state.tables = tables this.queryOperators = { @@ -97,7 +113,7 @@ export class Builder { }, $size: (key, value) => { if (!value) return this.logicalNot(key) - if (this.state.sqlTypes?.[this.unescapeId(key)] === 'json') { + if (this.isJsonQuery(key)) { return `${this.jsonLength(key)} = ${this.escape(value)}` } else { return `${key} AND LENGTH(${key}) - LENGTH(REPLACE(${key}, ${this.escape(',')}, ${this.escape('')})) = ${this.escape(value)} - 1` @@ -146,17 +162,17 @@ export class Builder { $lte: this.binary('<='), // membership - $in: ([key, value]) => this.createMemberQuery(this.parseEval(key), value, ''), - $nin: ([key, value]) => this.createMemberQuery(this.parseEval(key), value, ' NOT'), + $in: ([key, value]) => this.asEncoded(this.createMemberQuery(this.parseEval(key), value, ''), false), + $nin: ([key, value]) => this.asEncoded(this.createMemberQuery(this.parseEval(key), value, ' NOT'), false), // typecast + $literal: ([value, type]) => this.escape(value, type as any), $number: (arg) => { const value = this.parseEval(arg) - const res = this.state.sqlType === 'raw' ? `(0+${value})` - : this.state.sqlType === 'time' ? `unix_timestamp(convert_tz(addtime('1970-01-01 00:00:00', ${value}), '${this._timezone}', '+0:00'))` - : `unix_timestamp(convert_tz(${value}, '${this._timezone}', '+0:00'))` - this.state.sqlType = 'raw' - return `ifnull(${res}, 0)` + const type = Type.fromTerm(arg) + const res = type.type === 'time' ? `unix_timestamp(convert_tz(addtime('1970-01-01 00:00:00', ${value}), '${this._timezone}', '+0:00'))` + : ['timestamp', 'date'].includes(type.type!) ? `unix_timestamp(convert_tz(${value}, '${this._timezone}', '+0:00'))` : `(0+${value})` + return this.asEncoded(`ifnull(${res}, 0)`, false) }, // aggregation @@ -165,18 +181,11 @@ export class Builder { $min: (expr) => this.createAggr(expr, value => `min(${value})`), $max: (expr) => this.createAggr(expr, value => `max(${value})`), $count: (expr) => this.createAggr(expr, value => `count(distinct ${value})`), - $length: (expr) => this.createAggr(expr, value => `count(${value})`, value => { - if (this.state.sqlType === 'json') { - this.state.sqlType = 'raw' - return `${this.jsonLength(value)}` - } else { - this.state.sqlType = 'raw' - return `if(${value}, LENGTH(${value}) - LENGTH(REPLACE(${value}, ${this.escape(',')}, ${this.escape('')})) + 1, 0)` - } - }), + $length: (expr) => this.createAggr(expr, value => `count(${value})`, value => this.isEncoded() ? this.jsonLength(value) + : this.asEncoded(`if(${value}, LENGTH(${value}) - LENGTH(REPLACE(${value}, ${this.escape(',')}, ${this.escape('')})) + 1, 0)`, false)), $object: (fields) => this.groupObject(fields), - $array: (expr) => this.groupArray(this.parseEval(expr, false)), + $array: (expr) => this.groupArray(this.transform(this.parseEval(expr, false), expr, 'encode')), $exec: (sel) => this.parseSelection(sel as Selection), } @@ -195,8 +204,7 @@ export class Builder { if (!value.length) return notStr ? this.$true : this.$false return `${key}${notStr} in (${value.map(val => this.escape(val)).join(', ')})` } else { - const res = this.jsonContains(this.parseEval(value, false), this.jsonQuote(key, true)) - this.state.sqlType = 'raw' + const res = this.jsonContains(this.parseEval(value, false), this.encode(key, true, true)) return notStr ? this.logicalNot(res) : res } } @@ -206,13 +214,17 @@ export class Builder { } protected createElementQuery(key: string, value: any) { - if (this.state.sqlTypes?.[this.unescapeId(key)] === 'json') { - return this.jsonContains(key, this.quote(JSON.stringify(value))) + if (this.isJsonQuery(key)) { + return this.jsonContains(key, this.encode(value, true, true)) } else { return `find_in_set(${this.escape(value)}, ${key})` } } + protected isJsonQuery(key: string) { + return isSqlJson(this.state.tables![this.state.table!].fields![this.unescapeId(key)]?.type) + } + protected comparator(operator: string) { return (key: string, value: any) => { return `${key} ${operator} ${this.escape(value)}` @@ -250,34 +262,32 @@ export class Builder { if (!(sel.args[0] as any).$) { return `(SELECT ${output} AS value FROM ${inner} ${isBracketed(inner) ? ref : ''})` } else { - return `(ifnull((SELECT ${this.groupArray(output)} AS value FROM ${inner} ${isBracketed(inner) ? ref : ''}), json_array()))` + return `(ifnull((SELECT ${this.groupArray(this.transform(output, Type.getInner(Type.fromTerm(expr)), 'encode'))} + AS value FROM ${inner} ${isBracketed(inner) ? ref : ''}), json_array()))` } } protected jsonLength(value: string) { - return `json_length(${value})` + return this.asEncoded(`json_length(${value})`, false) } protected jsonContains(obj: string, value: string) { - return `json_contains(${obj}, ${value})` + return this.asEncoded(`json_contains(${obj}, ${value})`, false) } - protected jsonUnquote(value: string, pure: boolean = false) { - if (pure) return `json_unquote(${value})` - if (this.state.sqlType === 'json') { - this.state.sqlType = 'raw' - return `json_unquote(${value})` - } + protected asEncoded(value: string, encoded: boolean | undefined) { + if (encoded !== undefined) this.state.encoded = encoded return value } - protected jsonQuote(value: string, pure: boolean = false) { - if (pure) return `cast(${value} as json)` - if (this.state.sqlType !== 'json') { - this.state.sqlType = 'json' - return `cast(${value} as json)` - } - return value + protected encode(value: string, encoded: boolean, pure: boolean = false, type?: Type) { + return this.asEncoded((encoded === this.isEncoded() && !pure) ? value + : encoded ? `cast(${this.transform(value, type, 'encode')} as json)` + : `json_unquote(${this.transform(value, type, 'decode')})`, pure ? undefined : encoded) + } + + protected isEncoded(key?: string) { + return key ? this.state.encodedMap?.[key] : this.state.encoded } protected createAggr(expr: any, aggr: (value: string) => string, nonaggr?: (value: string) => string) { @@ -296,19 +306,30 @@ export class Builder { } } - protected groupObject(fields: any) { - const parse = (expr) => { - const value = this.parseEval(expr, false) - return this.state.sqlType === 'json' ? `json_extract(${value}, '$')` : `${value}` + /** + * Convert value from SQL field to JSON field + */ + protected transform(value: string, type: Type | Eval.Expr | undefined, method: 'encode' | 'decode' | 'load' | 'dump', miss?: any) { + type = Type.isType(type) ? type : Type.fromTerm(type) + const transformer = this.transformers[type.type] ?? this.transformers[this.driver.database.types[type.type]?.deftype!] + return transformer ? transformer[method](value) : (miss ?? value) + } + + protected groupObject(_fields: any) { + const _groupObject = (fields: any, type?: Type, prefix: string = '') => { + const parse = (expr, key) => { + const value = (!_fields[`${prefix}${key}`] && type && Type.getInner(type, key)?.inner) + ? _groupObject(expr, Type.getInner(type, key), `${prefix}${key}.`) + : this.parseEval(expr, false) + return this.isEncoded() ? `json_extract(${value}, '$')` : this.transform(value, expr, 'encode') + } + return `json_object(` + Object.entries(fields).map(([key, expr]) => `'${key}', ${parse(expr, key)}`).join(',') + `)` } - const res = `json_object(` + Object.entries(fields).map(([key, expr]) => `'${key}', ${parse(expr)}`).join(',') + `)` - this.state.sqlType = 'json' - return res + return this.asEncoded(_groupObject(unravel(_fields), this.state.type, ''), true) } protected groupArray(value: string) { - this.state.sqlType = 'json' - return `ifnull(json_arrayagg(${value}), json_array())` + return this.asEncoded(`ifnull(json_arrayagg(${value}), json_array())`, true) } protected parseFieldQuery(key: string, query: Query.FieldExpr) { @@ -357,9 +378,10 @@ export class Builder { } protected parseEvalExpr(expr: any) { - this.state.sqlType = 'raw' + this.state.encoded = false for (const key in expr) { if (key in this.evalOperators) { + this.state.type = Type.fromTerm(expr) return this.evalOperators[key](expr[key]) } } @@ -367,16 +389,12 @@ export class Builder { } protected transformJsonField(obj: string, path: string) { - this.state.sqlType = 'json' - return `json_extract(${obj}, '$${path}')` + return this.asEncoded(`json_extract(${obj}, '$${path}')`, true) } - protected transformKey(key: string, fields: {}, prefix: string, fullKey: string) { + protected transformKey(key: string, fields: Field.Config, prefix: string) { if (key in fields || !key.includes('.')) { - if (this.state.sqlTypes?.[key] || this.state.sqlTypes?.[fullKey]) { - this.state.sqlType = this.state.sqlTypes[key] || this.state.sqlTypes[fullKey] - } - return prefix + this.escapeId(key) + return this.asEncoded(prefix + this.escapeId(key), this.isEncoded(key) ?? isSqlJson(fields[key]?.type)) } const field = Object.keys(fields).find(k => key.startsWith(k + '.')) || key.split('.')[0] const rest = key.slice(field.length + 1).split('.') @@ -404,36 +422,41 @@ export class Builder { // the only table must be the main table || (Object.keys(this.state.tables).length === 1 && table in this.state.tables) ? '' : `${this.escapeId(table)}.`) + if (!(table in (this.state.tables || {})) && (table in (this.state.innerTables || {}))) { + const fields = this.state.innerTables?.[table]?.fields || {} + const res = (fields[key]?.expr) ? this.parseEvalExpr(fields[key]?.expr) + : this.transformKey(key, fields, `${this.escapeId(table)}.`) + return res + } + // field from outer selection if (!(table in (this.state.tables || {})) && (table in (this.state.refTables || {}))) { const fields = this.state.refTables?.[table]?.fields || {} const res = (fields[key]?.expr) ? this.parseEvalExpr(fields[key]?.expr) - : this.transformKey(key, fields, `${this.escapeId(table)}.`, `${table}.${key}`) + : this.transformKey(key, fields, `${this.escapeId(table)}.`) if (this.state.wrappedSubquery) { if (res in (this.state.refFields ?? {})) return this.state.refFields![res] const key = `minato_tvar_${randomId()}` ;(this.state.refFields ??= {})[res] = key - this.state.sqlType = 'json' - return this.escapeId(key) + return this.asEncoded(this.escapeId(key), true) } else return res } - - return this.transformKey(key, fields, prefix, `${table}.${key}`) + return this.transformKey(key, fields, prefix) } parseEval(expr: any, unquote: boolean = true): string { - this.state.sqlType = 'raw' + this.state.encoded = false if (typeof expr === 'string' || typeof expr === 'number' || typeof expr === 'boolean' || expr instanceof Date || expr instanceof RegExp) { return this.escape(expr) } - return unquote ? this.jsonUnquote(this.parseEvalExpr(expr)) : this.parseEvalExpr(expr) + return unquote ? this.encode(this.parseEvalExpr(expr), false, false, Type.fromTerm(expr)) : this.parseEvalExpr(expr) } protected saveState(extra: Partial = {}) { const thisState = this.state this.state = { refTables: { ...(this.state.refTables || {}), ...(this.state.tables || {}) }, ...extra } return () => { - thisState.sqlType = this.state.sqlType + thisState.encoded = this.state.encoded this.state = thisState } } @@ -458,32 +481,24 @@ export class Builder { get(sel: Selection.Immutable, inline = false, group = false, addref = true) { const { args, table, query, ref, model } = sel + this.state.table = ref + // get prefix let prefix: string | undefined if (typeof table === 'string') { prefix = this.escapeId(table) - this.state.sqlTypes = Object.fromEntries(Object.entries(model.fields).map(([key, field]) => { - let sqlType: SQLType = 'raw' - if (field!.type === 'json') sqlType = 'json' - else if (field!.type === 'list') sqlType = 'list' - else if (Field.date.includes(field!.type)) sqlType = field!.type as SQLType - return [key, sqlType] - })) } else if (table instanceof Selection) { prefix = this.get(table, true) if (!prefix) return } else { - const sqlTypes: Dict = {} + this.state.innerTables = Object.fromEntries(Object.values(table).map(t => [t.ref, t.model])) const joins: string[] = Object.entries(table).map(([key, table]) => { - const restore = this.saveState({ tables: table.tables }) - const t = `${this.get(table, true, false, false)} AS ${this.escapeId(key)}` - for (const [fieldKey, fieldType] of Object.entries(this.state.sqlTypes!)) { - sqlTypes[`${key}.${fieldKey}`] = fieldType - } + const restore = this.saveState({ tables: { ...table.tables } }) + const t = `${this.get(table, true, false, false)} AS ${this.escapeId(table.ref)}` restore() return t }) - this.state.sqlTypes = sqlTypes + // the leading space is to prevent from being parsed as bracketed and added ref prefix = ' ' + joins[0] + joins.slice(1, -1).map(join => ` JOIN ${join} ON ${this.$true}`).join(' ') + ` JOIN ` + joins.at(-1) const filter = this.parseEval(args[0].having) @@ -494,20 +509,20 @@ export class Builder { if (filter === this.$false) return this.state.group = group || !!args[0].group - const sqlTypes: Dict = {} + const encodedMap: Dict = {} const fields = args[0].fields ?? Object.fromEntries(Object .entries(model.fields) .filter(([, field]) => !field!.deprecated) - .map(([key]) => [key, { $: [ref, key] }])) + .map(([key, field]) => [key, field!.expr ? field!.expr : Eval('', [ref, key], Type.fromField(field!))])) const keys = Object.entries(fields).map(([key, value]) => { value = this.parseEval(value, false) - sqlTypes[key] = this.state.sqlType! + encodedMap![key] = this.state.encoded! return this.escapeId(key) === value ? this.escapeId(key) : `${value} AS ${this.escapeId(key)}` }).join(', ') // get suffix let suffix = this.suffix(args[0]) - this.state.sqlTypes = sqlTypes + this.state.encodedMap = encodedMap if (filter !== this.$true) { suffix = ` WHERE ${filter}` + suffix @@ -525,39 +540,88 @@ export class Builder { return inline ? `(${result})` : result } - define(converter: Transformer) { - converter.types.forEach(type => this.types[type] = converter) - } + /** + * Convert value from Type to Field.Type. + * @param root indicate whether the context is inside json + */ + dump(value: any, type: Model | Type | Eval.Expr | undefined, root: boolean = true): any { + if (!type) return value + + if (Type.isType(type) || isEvalExpr(type)) { + type = Type.isType(type) ? type : Type.fromTerm(type) + const converter = (type.inner || type.type === 'json') ? (root ? this.driver.types['json'] : undefined) : this.driver.types[type.type] + if (type.inner || type.type === 'json') root = false + let res = value + + if (!isNullable(res) && type.inner) { + if (Type.isArray(type)) { + res = res.map(x => this.dump(x, Type.getInner(type as Type), root)) + } else { + res = mapValues(res, (x, k) => this.dump(x, Type.getInner(type as Type, k), root)) + } + } + res = converter ? converter.dump(res) : res + if (!root) res = this.transform(res, type, 'dump') + return res + } - dump(model: Model, obj: any): any { - obj = model.format(obj) + value = type.format(value) const result = {} - for (const key in obj) { - result[key] = this.stringify(obj[key], model.fields[key]) + for (const key in value) { + const { type: ftype } = type.fields[key]! + result[key] = this.dump(value[key], ftype) } return result } - load(obj: any): any - load(model: Model, obj: any): any - load(model: any, obj?: any) { - if (!obj) { - const converter = this.types[this.state.sqlType!] - return converter ? converter.load(model) : model + /** + * Convert value from Field.Type to Type. + */ + load(value: any, type: Model | Type | Eval.Expr | undefined, root: boolean = true): any { + if (!type) return value + + if (Type.isType(type) || isEvalExpr(type)) { + type = Type.isType(type) ? type : Type.fromTerm(type) + const converter = this.driver.types[(root && value && type.type === 'json') ? 'json' : type.type] + let res = this.transform(value, type, 'load') + res = converter ? converter.load(res) : res + + if (!isNullable(res) && type.inner) { + if (Type.isArray(type)) { + res = res.map(x => this.load(x, Type.getInner(type as Type), false)) + } else { + res = mapValues(res, (x, k) => this.load(x, Type.getInner(type as Type, k), false)) + } + } + return (type.inner && !Type.isArray(type)) ? unravel(res) : res } const result = {} - for (const key in obj) { - if (!(key in model.fields)) continue - const { type, initial } = model.fields[key]! - const converter = (this.state.sqlTypes?.[key] ?? 'raw') === 'raw' ? this.types[type] : this.types[this.state.sqlTypes![key]] - result[key] = converter ? converter.load(obj[key], initial) : obj[key] + for (const key in value) { + if (!(key in type.fields)) continue + result[key] = value[key] + let subroot = root + if (subroot && result[key] && this.isEncoded(key)) { + subroot = false + result[key] = this.driver.types['json'].load(result[key]) + } + result[key] = this.load(result[key], type.fields[key]!.type, subroot) } - return model.parse(result) + return type.parse(result) + } + + /** + * Convert value from Type to SQL. + */ + escape(value: any, type?: Field | Field.Type | Type) { + type &&= (Type.isType(type) ? type : Type.fromField(type)) + return this.escapePrimitive(type ? this.dump(value, type) : value, type) } - escape(value: any, field?: Field) { - value = this.stringify(value, field) + /** + * Convert value from Field.Type to SQL. + */ + escapePrimitive(value: any, type?: Type) { if (isNullable(value)) return 'NULL' switch (typeof value) { @@ -579,11 +643,6 @@ export class Builder { return `"${value}"` } - stringify(value: any, field?: Field) { - const converter = this.types[field!?.type] - return converter ? converter.dump(value) : value - } - quote(value: string) { this.escapeRegExp ??= new RegExp(`[${Object.values(this.escapeMap).join('')}]`, 'g') let chunkIndex = this.escapeRegExp.lastIndex = 0 diff --git a/packages/sqlite/src/builder.ts b/packages/sqlite/src/builder.ts index ebef3f2b..0acd1962 100644 --- a/packages/sqlite/src/builder.ts +++ b/packages/sqlite/src/builder.ts @@ -1,14 +1,14 @@ import { Builder, escapeId } from '@minatojs/sql-utils' import { Dict, isNullable } from 'cosmokit' -import { Field, Model, randomId } from 'minato' +import { Driver, Field, isUint8Array, Model, randomId, Type, Uint8ArrayFromHex, Uint8ArrayToHex } from 'minato' export class SQLiteBuilder extends Builder { protected escapeMap = { "'": "''", } - constructor(tables?: Dict) { - super(tables) + constructor(protected driver: Driver, tables?: Dict) { + super(driver, tables) this.evalOperators.$if = (args) => `iif(${args.map(arg => this.parseEval(arg)).join(', ')})` this.evalOperators.$concat = (args) => `(${args.map(arg => this.parseEval(arg)).join('||')})` @@ -16,72 +16,48 @@ export class SQLiteBuilder extends Builder { this.evalOperators.$log = ([left, right]) => isNullable(right) ? `log(${this.parseEval(left)})` : `log(${this.parseEval(left)}) / log(${this.parseEval(right)})` - this.evalOperators.$length = (expr) => this.createAggr(expr, value => `count(${value})`, value => { - if (this.state.sqlType === 'json') { - this.state.sqlType = 'raw' - return `${this.jsonLength(value)}` - } else { - this.state.sqlType = 'raw' - return `iif(${value}, LENGTH(${value}) - LENGTH(REPLACE(${value}, ${this.escape(',')}, ${this.escape('')})) + 1, 0)` - } - }) + this.evalOperators.$length = (expr) => this.createAggr(expr, value => `count(${value})`, value => this.isEncoded() ? this.jsonLength(value) + : this.asEncoded(`iif(${value}, LENGTH(${value}) - LENGTH(REPLACE(${value}, ${this.escape(',')}, ${this.escape('')})) + 1, 0)`, false)) this.evalOperators.$number = (arg) => { + const type = Type.fromTerm(arg) const value = this.parseEval(arg) - const res = this.state.sqlType === 'raw' ? `cast(${this.parseEval(arg)} as double)` - : `cast(${value} / 1000 as integer)` - this.state.sqlType = 'raw' - return `ifnull(${res}, 0)` + const res = Field.date.includes(type.type!) ? `cast(${value} / 1000 as integer)` : `cast(${this.parseEval(arg)} as double)` + return this.asEncoded(`ifnull(${res}, 0)`, false) } - this.define({ - types: ['boolean'], - dump: value => +value, - load: (value) => !!value, - }) - - this.define({ - types: ['json'], - dump: value => JSON.stringify(value), - load: (value, initial) => value ? JSON.parse(value) : initial, - }) - - this.define({ - types: ['list'], - dump: value => Array.isArray(value) ? value.join(',') : value, - load: (value) => value ? value.split(',') : [], - }) - - this.define({ - types: ['date', 'time', 'timestamp'], - dump: value => value === null ? null : +new Date(value), - load: (value) => value === null ? null : new Date(value), - }) + this.transformers['binary'] = { + encode: value => `hex(${value})`, + decode: value => `unhex(${value})`, + load: value => isNullable(value) ? value : Uint8ArrayFromHex(value), + dump: value => isNullable(value) ? value : Uint8ArrayToHex(value), + } } - escape(value: any, field?: Field) { + escapePrimitive(value: any, type?: Type) { if (value instanceof Date) value = +value else if (value instanceof RegExp) value = value.source - return super.escape(value, field) + else if (isUint8Array(value)) return `X'${Uint8ArrayToHex(value)}'` + return super.escapePrimitive(value, type) } protected createElementQuery(key: string, value: any) { - if (this.state.sqlTypes?.[this.unescapeId(key)] === 'json') { - return this.jsonContains(key, this.quote(JSON.stringify(value))) + if (this.isJsonQuery(key)) { + return this.jsonContains(key, this.escape(value, 'json')) } else { return `(',' || ${key} || ',') LIKE ${this.escape('%,' + value + ',%')}` } } protected jsonLength(value: string) { - return `json_array_length(${value})` + return this.asEncoded(`json_array_length(${value})`, false) } protected jsonContains(obj: string, value: string) { - return `json_array_contains(${obj}, ${value})` + return this.asEncoded(`json_array_contains(${obj}, ${value})`, false) } - protected jsonUnquote(value: string, pure: boolean = false) { - return value + protected encode(value: string, encoded: boolean, pure: boolean = false, type?: Type) { + return encoded ? super.encode(value, encoded, pure, type) : value } protected createAggr(expr: any, aggr: (value: string) => string, nonaggr?: (value: string) => string) { @@ -94,13 +70,11 @@ export class SQLiteBuilder extends Builder { } protected groupArray(value: string) { - const res = this.state.sqlType === 'json' ? `('[' || group_concat(${value}) || ']')` : `('[' || group_concat(json_quote(${value})) || ']')` - this.state.sqlType = 'json' - return `ifnull(${res}, json_array())` + const res = this.isEncoded() ? `('[' || group_concat(${value}) || ']')` : `('[' || group_concat(json_quote(${value})) || ']')` + return this.asEncoded(`ifnull(${res}, json_array())`, true) } protected transformJsonField(obj: string, path: string) { - this.state.sqlType = 'raw' - return `json_extract(${obj}, '$${path}')` + return this.asEncoded(`json_extract(${obj}, '$${path}')`, false) } } diff --git a/packages/sqlite/src/index.ts b/packages/sqlite/src/index.ts index 782561f2..7ecfee07 100644 --- a/packages/sqlite/src/index.ts +++ b/packages/sqlite/src/index.ts @@ -1,5 +1,5 @@ -import { clone, deepEqual, Dict, difference, isNullable, makeArray } from 'cosmokit' -import { Driver, Eval, executeUpdate, Field, Selection, z } from 'minato' +import { deepEqual, Dict, difference, isNullable, makeArray } from 'cosmokit' +import { clone, Driver, Eval, executeUpdate, Field, Selection, toLocalUint8Array, z } from 'minato' import { escapeId } from '@minatojs/sql-utils' import { resolve } from 'node:path' import { readFile, writeFile } from 'node:fs/promises' @@ -10,7 +10,7 @@ import zhCN from './locales/zh-CN.yml' import { SQLiteBuilder } from './builder' import { pathToFileURL } from 'node:url' -function getTypeDef({ type }: Field) { +function getTypeDef({ deftype: type }: Field) { switch (type) { case 'primary': case 'boolean': @@ -27,6 +27,7 @@ function getTypeDef({ type }: Field) { case 'text': case 'list': case 'json': return `TEXT` + case 'binary': return `BLOB` } } @@ -43,7 +44,7 @@ export class SQLiteDriver extends Driver { static name = 'sqlite' db!: init.Database - sql = new SQLiteBuilder() + sql = new SQLiteBuilder(this) beforeUnload?: () => void private _transactionTask?: Promise @@ -75,7 +76,7 @@ export class SQLiteDriver extends Driver { } else { def += (nullable ? ' ' : ' NOT ') + 'NULL' if (!isNullable(initial)) { - def += ' DEFAULT ' + this.sql.escape(this.sql.dump(model, { [key]: initial })[key]) + def += ' DEFAULT ' + this.sql.escape(this.sql.dump({ [key]: initial }, model)[key]) } } columnDefs.push(def) @@ -176,6 +177,36 @@ export class SQLiteDriver extends Driver { this.db.create_function('json_array_contains', (array, value) => +(JSON.parse(array) as any[]).includes(JSON.parse(value))) this.db.create_function('modulo', (left, right) => left % right) this.db.create_function('rand', () => Math.random()) + + this.define({ + types: ['boolean'], + dump: value => isNullable(value) ? value : +value, + load: (value) => isNullable(value) ? value : !!value, + }) + + this.define({ + types: ['json'], + dump: value => JSON.stringify(value), + load: value => typeof value === 'string' ? JSON.parse(value) : value, + }) + + this.define({ + types: ['list'], + dump: value => Array.isArray(value) ? value.join(',') : value, + load: value => value ? value.split(',') : [], + }) + + this.define({ + types: ['date', 'time', 'timestamp'], + dump: value => isNullable(value) ? value as any : +new Date(value), + load: value => isNullable(value) ? value : new Date(value), + }) + + this.define({ + types: ['binary'], + dump: value => value, + load: value => value ? toLocalUint8Array(value) : value, + }) } #joinKeys(keys?: string[]) { @@ -262,19 +293,19 @@ export class SQLiteDriver extends Driver { async get(sel: Selection.Immutable) { const { model, tables } = sel - const builder = new SQLiteBuilder(tables) + const builder = new SQLiteBuilder(this, tables) const sql = builder.get(sel) if (!sql) return [] const rows = this.#all(sql) - return rows.map(row => builder.load(model, row)) + return rows.map(row => builder.load(row, model)) } async eval(sel: Selection.Immutable, expr: Eval.Expr) { - const builder = new SQLiteBuilder(sel.tables) + const builder = new SQLiteBuilder(this, sel.tables) const inner = builder.get(sel.table as Selection, true, true) const output = builder.parseEval(expr, false) const { value } = this.#get(`SELECT ${output} AS value FROM ${inner}`) - return builder.load(value) + return builder.load(value, expr) } #update(sel: Selection.Mutable, indexFields: string[], updateFields: string[], update: {}, data: {}) { @@ -282,7 +313,7 @@ export class SQLiteDriver extends Driver { const model = this.model(table) const modified = !deepEqual(clone(data), executeUpdate(data, update, ref)) if (!modified) return 0 - const row = this.sql.dump(model, data) + const row = this.sql.dump(data, model) const assignment = updateFields.map((key) => `${escapeId(key)} = ?`).join(',') const query = Object.fromEntries(indexFields.map(key => [key, row[key]])) const filter = this.sql.parseQuery(query) @@ -307,7 +338,7 @@ export class SQLiteDriver extends Driver { #create(table: string, data: {}) { const model = this.model(table) - data = this.sql.dump(model, data) + data = this.sql.dump(data, model) const keys = Object.keys(data) const sql = `INSERT INTO ${escapeId(table)} (${this.#joinKeys(keys)}) VALUES (${Array(keys.length).fill('?').join(', ')})` return this.#run(sql, keys.map(key => data[key] ?? null), () => this.#get(`SELECT last_insert_rowid() AS id`)) @@ -315,7 +346,6 @@ export class SQLiteDriver extends Driver { async create(sel: Selection.Mutable, data: {}) { const { model, table } = sel - data = model.create(data) const { id } = this.#create(table, data) const { autoInc, primary } = model if (!autoInc || Array.isArray(primary)) return data as any diff --git a/packages/tests/package.json b/packages/tests/package.json index bfa83f72..7fdd53e1 100644 --- a/packages/tests/package.json +++ b/packages/tests/package.json @@ -28,7 +28,8 @@ ], "devDependencies": { "@types/chai": "^4.3.11", - "@types/chai-as-promised": "^7.1.8" + "@types/chai-as-promised": "^7.1.8", + "@types/deep-eql": "^4.0.2" }, "peerDependencies": { "minato": "^3.0.1" @@ -36,7 +37,6 @@ "dependencies": { "chai": "^4.3.10", "chai-as-promised": "^7.1.1", - "chai-shape": "^1.0.0", "cosmokit": "^1.5.2" } } diff --git a/packages/tests/src/index.ts b/packages/tests/src/index.ts index 73337db9..af745f15 100644 --- a/packages/tests/src/index.ts +++ b/packages/tests/src/index.ts @@ -1,4 +1,5 @@ import { Database } from 'minato' +import ModelOperations from './model' import QueryOperators from './query' import UpdateOperators from './update' import ObjectOperations from './object' @@ -61,6 +62,7 @@ function createUnit(target: T, root = false): Unit { } namespace Tests { + export const model = ModelOperations export const query = QueryOperators export const update = UpdateOperators export const object = ObjectOperations diff --git a/packages/tests/src/model.ts b/packages/tests/src/model.ts new file mode 100644 index 00000000..6d5949e0 --- /dev/null +++ b/packages/tests/src/model.ts @@ -0,0 +1,529 @@ +import { isNullable, valueMap } from 'cosmokit' +import { $, Database, Field, Type } from 'minato' +import chai, { expect } from 'chai' + +interface DType { + id: number + text?: string + num?: number + double?: number + decimal?: number + bool?: boolean + list?: string[] + array?: number[] + object?: { + text?: string + num?: number + json?: { + text?: string + num?: number + }, + embed?: { + bool?: boolean + bigint?: bigint + custom?: Custom + } + } + object2?: { + text?: string + num?: number + embed?: { + bool?: boolean + bigint?: bigint + } + } + timestamp?: Date + date?: Date + time?: Date + binary?: Buffer + bigint?: bigint + bnum?: number + bnum2?: number +} + +interface DObject { + id: number + foo?: { + nested: DType + } + bar?: { + nested: DType + } + baz?: { + nested?: DType + }[] +} + +interface Custom { + a: string + b: number +} + +interface Tables { + dtypes: DType + dobjects: DObject +} + +interface Types { + bigint: bigint + custom: Custom +} + +function flatten(type: any, prefix) { + if (typeof type === 'object' && type?.type === 'object') { + const result = {} + for (const key in type.inner) { + Object.assign(result, flatten(type.inner[key]!, `${prefix}.${key}`)) + } + return result + } else { + return { [prefix]: type } as any + } +} + +function ModelOperations(database: Database) { + database.define('bigint', { + type: 'string', + dump: value => value ? value.toString() : value, + load: value => value ? BigInt(value) : value, + initial: 123n + }) + + database.define('custom', { + type: 'string', + dump: value => value ? `${value.a}|${value.b}` : value, + load: value => value ? { a: value.split('|')[0], b: +value.split('|')[1] } : value + }) + + const bnum = database.define({ + type: 'binary', + dump: value => value === undefined ? value : Buffer.from(String(value)), + load: value => value ? +value : value, + initial: 0, + }) + + const baseFields: Field.Extension = { + id: 'unsigned', + text: { + type: 'string', + initial: 'he`l"\'\\lo', + }, + num: { + type: 'integer', + initial: 233, + }, + double: { + type: 'double', + initial: 3.14, + }, + decimal: { + type: 'decimal', + scale: 3, + initial: 12413, + }, + bool: { + type: 'boolean', + initial: true, + }, + list: { + type: 'list', + initial: ['a`a', 'b"b', 'c\'c', 'd\\d'], + }, + array: { + type: 'json', + initial: [1, 2, 3], + }, + object: { + type: 'object', + inner: { + num: 'unsigned', + text: 'string', + json: 'json', + embed: { + type: 'object', + inner: { + bool: { + type: 'boolean', + initial: false, + }, + bigint: 'bigint', + custom: 'custom', + }, + }, + }, + initial: { + num: 1, + text: '2', + json: { + text: '3', + num: 4, + }, + embed: { + bool: true, + bigint: 123n, + }, + }, + }, + // dot defined object + 'object2.num': { + type: 'unsigned', + initial: 1, + }, + 'object2.text': { + type: 'string', + initial: '2', + }, + 'object2.embed.bool': { + type: 'boolean', + initial: true, + }, + 'object2.embed.bigint': 'bigint', + timestamp: { + type: 'timestamp', + initial: new Date('1970-01-01 00:00:00'), + }, + date: { + type: 'date', + initial: new Date('1970-01-01'), + }, + time: { + type: 'time', + initial: new Date('1970-01-01 12:00:00'), + }, + binary: { + type: 'binary', + initial: Buffer.from('initial buffer') + }, + bigint: 'bigint', + bnum, + bnum2: { + type: 'binary', + dump: value => value === undefined ? value : Buffer.from(String(value)), + load: value => value ? +value : value, + initial: 0, + }, + } + + const baseObject = { + type: 'object', + inner: { nested: { type: 'object', inner: baseFields } }, + initial: { nested: { id: 1 } } + } + + database.extend('dtypes', { + ...baseFields + }, { autoInc: true }) + + database.extend('dobjects', { + id: 'unsigned', + foo: baseObject, + ...flatten(baseObject, 'bar'), + baz: { + type: 'array', + inner: baseObject, + initial: [] + }, + }, { autoInc: true }) +} + +function getValue(obj: any, path: string) { + if (path.includes('.')) { + const index = path.indexOf('.') + return getValue(obj[path.slice(0, index)] ??= {}, path.slice(index + 1)) + } else { + return obj[path] + } +} + +namespace ModelOperations { + const magicBorn = new Date('1970/08/17') + + const dtypeTable: DType[] = [ + { id: 1, bool: false }, + { id: 2, text: 'pku' }, + { id: 3, num: 1989 }, + { id: 4, list: ['1', '1', '4'], array: [1, 1, 4] }, + { id: 5, object: { num: 10, text: 'ab', embed: { bool: false, bigint: 90n } } }, + { id: 6, object2: { num: 10, text: 'ab', embed: { bool: false, bigint: 90n } } }, + { id: 7, timestamp: magicBorn }, + { id: 8, date: magicBorn }, + { id: 9, time: new Date('1999-10-01 15:40:00') }, + { id: 10, binary: Buffer.from('hello') }, + { id: 11, bigint: BigInt(1e63) }, + { id: 12, decimal: 2.432 }, + { id: 13, bnum: 114514, bnum2: 12345 }, + { id: 14, object: { embed: { custom: { a: 'abc', b: 123 } } } }, + ] + + const dobjectTable: DObject[] = [ + { id: 1 }, + { id: 2, foo: { nested: { id: 1, list: ['1', '1', '4'], array: [1, 1, 4], object: { num: 10, text: 'ab', embed: { bool: false, bigint: BigInt(1e163) } }, bigint: BigInt(1e63), bnum: 114514, bnum2: 12345 } } }, + { id: 3, bar: { nested: { id: 1, list: ['1', '1', '4'], array: [1, 1, 4], object: { num: 10, text: 'ab', embed: { bool: false, bigint: BigInt(1e163) } }, bigint: BigInt(1e63), bnum: 114514, bnum2: 12345 } } }, + { id: 4, baz: [{ nested: { id: 1, list: ['1', '1', '4'], array: [1, 1, 4], object: { num: 10, text: 'ab', embed: { bool: false, bigint: BigInt(1e163) } }, bigint: BigInt(1e63), bnum: 114514, bnum2: 12345 } }, { nested: { id: 2 } }] }, + { id: 5, foo: { nested: { id: 1, list: ['1', '1', '4'], array: [1, 1, 4], object2: { num: 10, text: 'ab', embed: { bool: false, bigint: BigInt(1e163) } }, bigint: BigInt(1e63), bnum: 114514, bnum2: 12345 } } }, + { id: 6, bar: { nested: { id: 1, list: ['1', '1', '4'], array: [1, 1, 4], object2: { num: 10, text: 'ab', embed: { bool: false, bigint: BigInt(1e163) } }, bigint: BigInt(1e63), bnum: 114514, bnum2: 12345 } } }, + { id: 7, baz: [{ nested: { id: 1, list: ['1', '1', '4'], array: [1, 1, 4], object2: { num: 10, text: 'ab', embed: { bool: false, bigint: BigInt(1e163) } }, bigint: BigInt(1e63), bnum: 114514, bnum2: 12345 } }, { nested: { id: 2 } }] }, + ] + + async function setup(database: Database, name: K, table: Tables[K][]) { + await database.remove(name, {}) + const result: Tables[K][] = [] + for (const item of table) { + result.push(await database.create(name, item as any)) + } + return result + } + + interface ModelOptions { + cast?: boolean + typeModel?: boolean + aggregateNull?: boolean + } + + export const fields = function Fields(database: Database, options: ModelOptions = {}) { + const { cast = true, typeModel = true } = options + + it('basic', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + await expect(database.get('dtypes', {})).to.eventually.have.deep.members(table) + + await database.remove('dtypes', {}) + await database.upsert('dtypes', dtypeTable) + await expect(database.get('dtypes', {})).to.eventually.have.deep.members(table) + }) + + it('modifier', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + await database.remove('dtypes', {}) + await database.upsert('dtypes', dtypeTable.map(({ id }) => ({ id }))) + + await Promise.all(table.map(({ id, ...x }) => database.set('dtypes', id, x))) + await expect(database.get('dtypes', {})).to.eventually.have.deep.members(table) + }) + + it('dot notation in modifier', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + table[0].object = {} + + await database.set('dtypes', table[0].id, row => ({ + object: {} + })) + await expect(database.get('dtypes', table[0].id)).to.eventually.deep.eq([table[0]]) + + table[0].object = { + num: 123, + json: { + num: 456, + }, + embed: { + bool: true, + bigint: 123n, + custom: { + a: 'a', + b: 1, + } + } + } + + await database.set('dtypes', table[0].id, row => ({ + 'object.num': 123, + 'object.json.num': 456, + 'object.embed.bool': true, + 'object.embed.bigint': 123n, + 'object.embed.custom': { a: 'a', b: 1 }, + })) + await expect(database.get('dtypes', table[0].id)).to.eventually.deep.eq([table[0]]) + }) + + it('using expressions in modifier', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + + table[0].object!.json!.num! = 543 + (table[0].object!.json!.num ?? 0) + table[0].object!.embed!.bool! = !table[0].object!.embed!.bool! + table[0].object!.embed!.bigint = 999n + + await database.set('dtypes', table[0].id, row => ({ + 'object.json.num': $.add($.ifNull(row.object.json.num, 0), 543), + 'object.embed.bool': $.not(row.object.embed.bool), + 'object.embed.bigint': 999n, + })) + await expect(database.get('dtypes', {})).to.eventually.have.deep.members(table) + + table[0].object!.embed!.bool! = false + await database.set('dtypes', table[0].id, { + 'object.embed.bool': false, + }) + await expect(database.get('dtypes', {})).to.eventually.have.deep.members(table) + + table[0].object!.embed!.bool! = true + await database.set('dtypes', table[0].id, { + 'object.embed.bool': true, + }) + await expect(database.get('dtypes', {})).to.eventually.have.deep.members(table) + }) + + it('primitive', async () => { + expect(Type.fromTerm($.literal(123)).type).to.equal(Type.Number.type) + expect(Type.fromTerm($.literal('abc')).type).to.equal(Type.String.type) + expect(Type.fromTerm($.literal(true)).type).to.equal(Type.Boolean.type) + expect(Type.fromTerm($.literal(new Date('1970-01-01'))).type).to.equal('timestamp') + expect(Type.fromTerm($.literal(Buffer.from('hello'))).type).to.equal('binary') + expect(Type.fromTerm($.literal([1, 2, 3])).type).to.equal('json') + expect(Type.fromTerm($.literal({ a: 1 })).type).to.equal('json') + }) + + cast && it('cast newtype', async () => { + await setup(database, 'dtypes', dtypeTable) + await expect(database.get('dtypes', row => $.eq(row.bigint as any, $.literal(234n, 'bigint')))).to.eventually.have.length(0) + await expect(database.get('dtypes', row => $.eq(row.bigint as any, $.literal(BigInt(1e63), 'bigint')))).to.eventually.have.length(1) + }) + + typeModel && it('$.object encoding', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + await expect(database.eval('dtypes', row => $.array($.object(row)))).to.eventually.have.deep.members(table) + }) + + typeModel && it('$.object decoding', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + await expect(database.select('dtypes') + .project({ + obj: row => $.object(row) + }) + .project(valueMap(database.tables['dtypes'].fields as any, (field, key) => row => row.obj[key])) + .execute() + ).to.eventually.have.deep.members(table) + }) + + typeModel && it('$.array encoding on cell', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + await expect(database.eval('dtypes', row => $.array(row.object))).to.eventually.have.deep.members(table.map(x => x.object)) + await expect(database.eval('dtypes', row => $.array($.object(row.object2)))).to.eventually.have.deep.members(table.map(x => x.object2)) + }) + + it('$.array encoding', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + await Promise.all(Object.keys(database.tables['dtypes'].fields).map( + key => expect(database.eval('dtypes', row => $.array(row[key]))).to.eventually.have.deep.members(table.map(x => getValue(x, key))) + )) + }) + + it('subquery encoding', async () => { + const table = await setup(database, 'dtypes', dtypeTable) + await Promise.all(Object.keys(database.tables['dtypes'].fields).map( + key => expect(database.select('dtypes', 1) + .project({ + x: row => database.select('dtypes').evaluate(key as any) + }) + .execute() + ).to.eventually.have.shape([{ x: table.map(x => getValue(x, key)) }]) + )) + }) + } + + export const object = function ObjectFields(database: Database, options: ModelOptions = {}) { + const { aggregateNull = true, typeModel = true } = options + + it('basic', async () => { + const table = await setup(database, 'dobjects', dobjectTable) + await expect(database.get('dobjects', {})).to.eventually.have.deep.members(table) + + await database.remove('dobjects', {}) + await database.upsert('dobjects', dobjectTable) + await expect(database.get('dobjects', {})).to.eventually.have.deep.members(table) + }) + + it('modifier', async () => { + const table = await setup(database, 'dobjects', dobjectTable) + await database.remove('dobjects', {}) + await database.upsert('dobjects', dobjectTable.map(({ id }) => ({ id }))) + + await Promise.all(table.map(({ id, ...x }) => database.set('dobjects', id, x))) + await expect(database.get('dobjects', {})).to.eventually.have.deep.members(table) + }) + + it('dot notation in modifier', async () => { + const table = await setup(database, 'dobjects', dobjectTable) + + table[0].foo!.nested = { id: 1 } + await database.set('dobjects', table[0].id, row => ({ + 'foo.nested': { id: 1 } + })) + await expect(database.get('dobjects', table[0].id)).to.eventually.deep.eq([table[0]]) + + table[0].foo!.nested = { + id: 1, + timestamp: new Date('2009/10/01 15:40:00'), + date: new Date('1999/10/01'), + binary: Buffer.from('boom'), + } + table[0].bar!.nested = { + ...table[0].bar?.nested, + id: 9, + timestamp: new Date('2009/10/01 15:40:00'), + date: new Date('1999/10/01'), + binary: Buffer.from('boom'), + } + + await database.set('dobjects', table[0].id, row => ({ + 'foo.nested.timestamp': new Date('2009/10/01 15:40:00'), + 'foo.nested.date': new Date('1999/10/01'), + 'foo.nested.binary': Buffer.from('boom'), + 'bar.nested.id': 9, + 'bar.nested.timestamp': new Date('2009/10/01 15:40:00'), + 'bar.nested.date': new Date('1999/10/01'), + 'bar.nested.binary': Buffer.from('boom'), + } as any)) + await expect(database.get('dobjects', table[0].id)).to.eventually.deep.eq([table[0]]) + + table[0].baz = [{}, {}] + await database.set('dobjects', table[0].id, row => ({ + baz: [{}, {}] + })) + await expect(database.get('dobjects', table[0].id)).to.eventually.deep.eq([table[0]]) + }) + + typeModel && it('$.object encoding', async () => { + const table = await setup(database, 'dobjects', dobjectTable) + await expect(database.eval('dobjects', row => $.array($.object(row)))).to.eventually.have.deep.members(table) + }) + + typeModel && it('$.object decoding', async () => { + const table = await setup(database, 'dobjects', dobjectTable) + await expect(database.select('dobjects') + .project({ + obj: row => $.object(row) + }) + .project(valueMap(database.tables['dobjects'].fields as any, (field, key) => row => row.obj[key])) + .execute() + ).to.eventually.have.deep.members(table) + }) + + aggregateNull && it('$.array encoding', async () => { + const table = await setup(database, 'dobjects', dobjectTable) + await Promise.all(Object.keys(database.tables['dobjects'].fields).map( + key => expect(database.eval('dobjects', row => $.array(row[key]))).to.eventually.have.deep.members(table.map(x => getValue(x, key))) + )) + }) + + it('$.array encoding boxed', async () => { + const table = await setup(database, 'dobjects', dobjectTable) + await Promise.all(Object.keys(database.tables['dobjects'].fields).map( + key => expect(database.eval('dobjects', row => $.array($.object({ x: row[key] })))).to.eventually.have.deep.members(table.map(x => ({ x: getValue(x, key) }))) + )) + }) + + it('subquery encoding', async () => { + const table = await setup(database, 'dobjects', dobjectTable) + await Promise.all(['baz'].map( + key => expect(database.select('dobjects', 1) + .project({ + x: row => database.select('dobjects').evaluate(key as any) + }) + .execute() + ).to.eventually.have.shape([{ x: table.map(x => getValue(x, key)) }]) + )) + }) + } +} + +export default ModelOperations diff --git a/packages/tests/src/object.ts b/packages/tests/src/object.ts index a925245c..e7b79403 100644 --- a/packages/tests/src/object.ts +++ b/packages/tests/src/object.ts @@ -58,6 +58,11 @@ namespace ObjectOperations { { meta: { a: '666', embed: { c: 'world' } } }, ]) }) + + it('selection', async () => { + await setup(database) + await expect(database.select('object', '0').project({ x: row => row.meta.embed.c }).execute()).to.eventually.deep.equal([{ x: 'hello' }]) + }) } export const upsert = function Upsert(database: Database) { diff --git a/packages/tests/src/query.ts b/packages/tests/src/query.ts index ce67dafa..5c29370f 100644 --- a/packages/tests/src/query.ts +++ b/packages/tests/src/query.ts @@ -34,7 +34,13 @@ function QueryOperators(database: Database) { } namespace QueryOperators { - export const comparison = function Comparison(database: Database) { + interface QueryOptions { + nullableComparator?: boolean + } + + export const comparison = function Comparison(database: Database, options: QueryOptions = {}) { + const { nullableComparator = true } = options + before(async () => { await database.remove('temp1', {}) await database.create('temp1', { @@ -81,6 +87,10 @@ namespace QueryOperators { await expect(database.get('temp1', { timestamp: { $lte: new Date('1999-01-01') }, })).eventually.to.have.length(0) + + nullableComparator && await expect(database.get('temp1', + row => $.gt(row.timestamp, new Date('1999-01-01')) + )).eventually.to.have.length(1).with.nested.property('0.text').equal('awesome foo') }) it('date comparisons', async () => { diff --git a/packages/tests/src/selection.ts b/packages/tests/src/selection.ts index f7eb5a47..80815c5e 100644 --- a/packages/tests/src/selection.ts +++ b/packages/tests/src/selection.ts @@ -442,6 +442,14 @@ namespace SelectionTests { ).to.eventually.have.length(6) }) + it('access from join', async () => { + const w = x => database.select('bar').evaluate(row => $.add($.count(row.id), -6, x)) + await expect(database + .join(['foo', 'bar'] as const, (foo, bar) => $.gt(foo.id, w(bar.pid))) + .execute() + ).to.eventually.have.length(9) + }) + it('join selection', async () => { await expect(database .select( diff --git a/packages/tests/src/setup.ts b/packages/tests/src/setup.ts index b6078940..e80e20a8 100644 --- a/packages/tests/src/setup.ts +++ b/packages/tests/src/setup.ts @@ -1,7 +1,63 @@ -import { use } from 'chai' +import chai, { use } from 'chai' import promised from 'chai-as-promised' -import { } from 'chai-shape' import shape from './shape' +import { isNullable } from 'cosmokit' use(shape) use(promised) + +function type(obj) { + if (typeof obj === 'undefined') { + return 'undefined'; + } + + if (obj === null) { + return 'null'; + } + + const stringTag = obj[Symbol.toStringTag]; + if (typeof stringTag === 'string') { + return stringTag; + } + const sliceStart = 8; + const sliceEnd = -1; + return Object.prototype.toString.call(obj).slice(sliceStart, sliceEnd); +} + +function getEnumerableKeys(target) { + var keys: string[] = []; + for (var key in target) { + keys.push(key); + } + return keys; +} + +function getEnumerableSymbols(target) { + var keys: symbol[] = []; + var allKeys = Object.getOwnPropertySymbols(target); + for (var i = 0; i < allKeys.length; i += 1) { + var key = allKeys[i]; + if (Object.getOwnPropertyDescriptor(target, key)?.enumerable) { + keys.push(key); + } + } + return keys; +} + +chai.config.deepEqual = (expected, actual, options) => { + return chai.util.eql(expected, actual, { + comparator: (expected, actual) => { + if (isNullable(expected) && isNullable(actual)) return true + if (type(expected) === 'Object' && type(actual) === 'Object') { + const keys = new Set([ + ...getEnumerableKeys(expected), + ...getEnumerableKeys(actual), + ...getEnumerableSymbols(expected), + ...getEnumerableSymbols(actual), + ]) + return [...keys].every(key => chai.config.deepEqual!(expected[key], actual[key], options)) + } + return null + } + }) +} diff --git a/packages/tests/src/shape.ts b/packages/tests/src/shape.ts index 2da755b8..b6c7a6e6 100644 --- a/packages/tests/src/shape.ts +++ b/packages/tests/src/shape.ts @@ -1,3 +1,4 @@ +import { isNullable } from 'cosmokit' import { inspect } from 'util' function flag(obj, key, value?) { @@ -45,6 +46,8 @@ export = (({ Assertion }) => { return `expected to have ${expect} but got ${actual} at path ${path}` } + if (isNullable(expect) && isNullable(actual)) return + if (!expect || ['string', 'number', 'boolean', 'bigint'].includes(typeof expect)) { return formatError(inspect(expect), inspect(actual)) } diff --git a/packages/tests/src/shims.d.ts b/packages/tests/src/shims.d.ts new file mode 100644 index 00000000..53ea4190 --- /dev/null +++ b/packages/tests/src/shims.d.ts @@ -0,0 +1,46 @@ +/// +/// + +interface DeepEqualOptions { + comparator?: (leftHandOperand: T1, rightHandOperand: T2) => boolean | null; +} + +declare namespace Chai { + interface Config { + deepEqual: (( + leftHandOperand: T1, + rightHandOperand: T2, + options?: DeepEqualOptions, + ) => boolean) | null | undefined + } + + interface ChaiUtils { + eql: ( + leftHandOperand: T1, + rightHandOperand: T2, + options?: DeepEqualOptions, + ) => boolean + } + + interface Assertion { + shape(expected: any, message?: string): Assertion + } + + interface Ordered { + shape(expected: any, message?: string): Assertion + } + + interface Eventually { + shape(expected: any, message?: string): PromisedAssertion + } + + interface PromisedOrdered { + shape(expected: any, message?: string): PromisedAssertion + } +} + +declare module './shape' { + declare const ChaiShape: Chai.ChaiPlugin + + export = ChaiShape +} diff --git a/packages/tests/src/update.ts b/packages/tests/src/update.ts index eccf5d18..f77345f3 100644 --- a/packages/tests/src/update.ts +++ b/packages/tests/src/update.ts @@ -13,6 +13,8 @@ interface Bar { date?: Date time?: Date bigtext?: string + binary?: Buffer + bigint?: bigint } interface Baz { @@ -38,6 +40,12 @@ function OrmOperations(database: Database) { date: 'date', time: 'time', bigtext: 'text', + binary: 'binary', + bigint: { + type: 'string', + dump: value => value ? value.toString() : value, + load: value => value ? BigInt(value) : value, + }, }, { autoInc: true, }) @@ -65,6 +73,9 @@ namespace OrmOperations { { id: 5, timestamp: magicBorn }, { id: 6, date: magicBorn }, { id: 7, time: new Date('1970-01-01 12:00:00') }, + { id: 8, binary: Buffer.from('hello') }, + { id: 9, bigint: BigInt(1e63) }, + { id: 10, text: 'a\b\t\f\n\r\x1a\'\"\\\`b', list: ['a\b\t\f\n\r\x1a\'\"\\\`b'] }, ] const bazTable: Baz[] = [ @@ -95,8 +106,8 @@ namespace OrmOperations { await expect(database.get('temp2', { id: obj.id })).to.eventually.have.shape([obj]) } await expect(database.get('temp2', {})).to.eventually.have.shape(table) - await database.remove('temp2', { id: 7 }) - await expect(database.create('temp2', {})).to.eventually.have.shape({ id: 8 }) + await database.remove('temp2', { id: table.length }) + await expect(database.create('temp2', {})).to.eventually.have.shape({ id: table.length + 1 }) }) it('specify primary key', async () => { @@ -133,6 +144,21 @@ namespace OrmOperations { await database.create('temp2', row) await expect(database.get('temp2', 100)).to.eventually.have.nested.property('0.bigtext', row.bigtext) }) + + it('advanced type', async () => { + await setup(database, 'temp2', barTable) + await expect(database.create('temp2', { binary: Buffer.from('world') })).to.eventually.have.shape({ binary: Buffer.from('world') }) + await expect(database.get('temp2', { binary: { $exists: true } })).to.eventually.have.shape([ + { binary: Buffer.from('hello') }, + { binary: Buffer.from('world') }, + ]) + + await expect(database.create('temp2', { bigint: 1234567891011121314151617181920n })).to.eventually.have.shape({ bigint: 1234567891011121314151617181920n }) + await expect(database.get('temp2', { bigint: { $exists: true } })).to.eventually.have.shape([ + { bigint: BigInt(1e63) }, + { bigint: 1234567891011121314151617181920n }, + ]) + }) } export const set = function Set(database: Database) { @@ -175,7 +201,7 @@ namespace OrmOperations { const table = await setup(database, 'temp2', barTable) table[1].num = table[1].id * 2 table[2].num = table[2].id * 2 - await database.set('temp2', [table[1].id, table[2].id, 9], row => ({ + await database.set('temp2', [table[1].id, table[2].id, 99], row => ({ num: $.multiply(2, row.id), })) await expect(database.get('temp2', {})).to.eventually.have.shape(table) @@ -187,6 +213,15 @@ namespace OrmOperations { await database.set('temp2', row.id, { bigtext: row.bigtext }) await expect(database.get('temp2', row.id)).to.eventually.have.nested.property('0.bigtext', row.bigtext) }) + + it('advanced type', async () => { + const table = await setup(database, 'temp2', barTable) + const data1 = table.find(item => item.id === 1)! + data1.binary = Buffer.from('world') + data1.bigint = 1234567891011121314151617181920n + await database.set('temp2', { id: 1 }, { binary: Buffer.from('world'), bigint: 1234567891011121314151617181920n }) + await expect(database.get('temp2', {})).to.eventually.have.shape(table) + }) } export const upsert = function Upsert(database: Database) { @@ -195,12 +230,15 @@ namespace OrmOperations { const data = [ { id: table[0].id, text: 'thu' }, { id: table[1].id, num: 1911 }, + { id: table[2].id, list: ['2', '3', '3'] }, ] data.forEach(update => { const index = table.findIndex(obj => obj.id === update.id) table[index] = merge(table[index], update) }) - await expect(database.upsert('temp2', data)).to.eventually.have.shape({ inserted: 0, matched: 2 }) + await expect(database.upsert('temp2', data.slice(0, 2))).to.eventually.have.shape({ inserted: 0, matched: 2 }) + await expect(database.upsert('temp2', data.slice(0, 2))).to.eventually.have.shape({ inserted: 0, matched: 2 }) + await expect(database.upsert('temp2', data.slice(2))).to.eventually.have.shape({ inserted: 0, matched: 1 }) await expect(database.get('temp2', {})).to.eventually.have.shape(table) }) @@ -209,9 +247,12 @@ namespace OrmOperations { const data = [ { id: table[table.length - 1].id + 1, text: 'wm"lake' }, { id: table[table.length - 1].id + 2, text: 'by\'tower' }, + { id: table[table.length - 1].id + 3, text: 'over' }, ] table.push(...data.map(bar => merge(database.tables.temp2.create(), bar))) - await expect(database.upsert('temp2', data)).to.eventually.have.shape({ inserted: 2, matched: 0 }) + await expect(database.upsert('temp2', data.slice(0, 2))).to.eventually.have.shape({ inserted: 2, matched: 0 }) + await expect(database.upsert('temp2', data.slice(2))).to.eventually.have.shape({ inserted: 1, matched: 0 }) + await expect(database.upsert('temp2', data.slice(2))).to.eventually.have.shape({ inserted: 0, matched: 1 }) await expect(database.get('temp2', {})).to.eventually.have.shape(table) }) @@ -219,15 +260,15 @@ namespace OrmOperations { const table = await setup(database, 'temp2', barTable) const data2 = table.find(item => item.id === 2)! const data3 = table.find(item => item.id === 3)! - const data9 = table.find(item => item.id === 9) + const data99 = table.find(item => item.id === 99) data2.num = data2.id * 2 data3.num = data3.num! + 3 - expect(data9).to.be.undefined - table.push({ id: 9, num: 999 }) + expect(data99).to.be.undefined + table.push({ id: 99, num: 999 }) await expect(database.upsert('temp2', row => [ { id: 2, num: $.multiply(2, row.id) }, { id: 3, num: $.add(3, row.num) }, - { id: 9, num: 999 }, + { id: 99, num: 999 }, ])).to.eventually.have.shape({ inserted: 1, matched: 2 }) await expect(database.get('temp2', {})).to.eventually.have.shape(table) }) @@ -272,6 +313,19 @@ namespace OrmOperations { { ida: 12, idb: 'c', value: 'd' }, ], ['value'] as any)).to.eventually.have.shape({ inserted: 2, matched: 1 }) }) + + it('advanced type', async () => { + const table = await setup(database, 'temp2', barTable) + const data1 = table.find(item => item.id === 1)! + data1.binary = Buffer.from('world') + data1.bigint = 1234567891011121314151617181920n + table.push({ binary: Buffer.from('foobar'), bigint: 1234567891011121314151617181920212223n } as any) + await database.upsert('temp2', [ + { id: 1, binary: Buffer.from('world'), bigint: 1234567891011121314151617181920n }, + { binary: Buffer.from('foobar'), bigint: 1234567891011121314151617181920212223n }, + ]) + await expect(database.get('temp2', {})).to.eventually.have.shape(table) + }) } export const remove = function Remove(database: Database) { @@ -305,9 +359,18 @@ namespace OrmOperations { export const misc = function Misc(database: Database) { it('date type', async () => { const table = await setup(database, 'temp2', barTable) + await expect(database.get('temp2', {})).to.eventually.have.shape(table) await expect(database.eval('temp2', row => $.max(row.timestamp))).to.eventually.deep.eq(table[4].timestamp) await expect(database.eval('temp2', row => $.max(row.date))).to.eventually.deep.eq(table[5].date) await expect(database.eval('temp2', row => $.max(row.time))).to.eventually.deep.eq(table[6].time) + + table.push(await database.create('temp2', { + text: 'date type', + timestamp: new Date(), + date: new Date(), + time: new Date(), + })) + await expect(database.get('temp2', {})).to.eventually.have.shape(table) }) it('$.number on date types', async () => {