From 71702f16d3d9d6d62b55defc31e623dc0442f433 Mon Sep 17 00:00:00 2001 From: Michele Riva Date: Sun, 28 Jul 2024 12:30:57 -0700 Subject: [PATCH] feat: speeds up insertion for number properties (#756) --- packages/orama/src/components/facets.ts | 12 ++- packages/orama/src/components/index.ts | 27 +++-- packages/orama/src/methods/insert.ts | 18 +++- packages/orama/src/trees/avl.ts | 66 +++++++----- packages/orama/src/types.ts | 4 +- packages/orama/tests/distinct.test.ts | 2 +- packages/orama/tests/enum.test.ts | 20 ++-- packages/orama/tests/insert.test.ts | 60 +++++++++++ packages/orama/tests/search.hybrid.test.ts | 114 ++++++++++----------- packages/orama/tests/tree.avl.test.ts | 4 +- 10 files changed, 212 insertions(+), 115 deletions(-) diff --git a/packages/orama/src/components/facets.ts b/packages/orama/src/components/facets.ts index 7e89a186c..626b04cf7 100644 --- a/packages/orama/src/components/facets.ts +++ b/packages/orama/src/components/facets.ts @@ -96,7 +96,11 @@ export async function getFacets( case 'string[]': { const alreadyInsertedValues = new Set() const innerType = propertyType === 'boolean[]' ? 'boolean' : 'string' - const calculateBooleanStringOrEnumFacet = calculateBooleanStringOrEnumFacetBuilder(facetValues, innerType, alreadyInsertedValues) + const calculateBooleanStringOrEnumFacet = calculateBooleanStringOrEnumFacetBuilder( + facetValues, + innerType, + alreadyInsertedValues + ) for (const v of facetValue as Array) { calculateBooleanStringOrEnumFacet(v) } @@ -140,13 +144,13 @@ function calculateNumberFacetBuilder( if (alreadyInsertedValues?.has(value)) { continue } - + if (facetValue >= range.from && facetValue <= range.to) { if (values[value] === undefined) { values[value] = 1 } else { values[value]++ - + alreadyInsertedValues?.add(value) } } @@ -159,7 +163,7 @@ function calculateBooleanStringOrEnumFacetBuilder( propertyType: 'string' | 'boolean' | 'enum', alreadyInsertedValues?: Set ) { - const defaultValue = (propertyType === 'boolean' ? 'false' : '') + const defaultValue = propertyType === 'boolean' ? 'false' : '' return (facetValue: FacetValue) => { // String or boolean based facets const value = facetValue?.toString() ?? defaultValue diff --git a/packages/orama/src/components/index.ts b/packages/orama/src/components/index.ts index 509994c56..372dc5d04 100644 --- a/packages/orama/src/components/index.ts +++ b/packages/orama/src/components/index.ts @@ -22,6 +22,7 @@ import type { VectorType, WhereCondition } from '../types.js' +import type { InsertOptions } from '../methods/insert.js' import { createError } from '../errors.js' import { create as avlCreate, @@ -293,7 +294,8 @@ function insertScalarBuilder( id: DocumentID, language: string | undefined, tokenizer: Tokenizer, - docsCount: number + docsCount: number, + options?: InsertOptions ) { return async (value: SearchableValue): Promise => { const internalId = getInternalDocumentId(index.sharedInternalDocumentStore, id) @@ -305,7 +307,8 @@ function insertScalarBuilder( break } case 'AVL': { - avlInsert(node, value as number, [internalId]) + const avlRebalanceThreshold = options?.avlRebalanceThreshold ?? 1 + avlInsert(node, value as number, [internalId], avlRebalanceThreshold) break } case 'Radix': { @@ -341,13 +344,14 @@ export async function insert( schemaType: SearchableType, language: string | undefined, tokenizer: Tokenizer, - docsCount: number + docsCount: number, + options?: InsertOptions ): Promise { if (isVectorType(schemaType)) { return insertVector(index, prop, value as number[] | Float32Array, id) } - const insertScalar = insertScalarBuilder(implementation, index, prop, id, language, tokenizer, docsCount) + const insertScalar = insertScalarBuilder(implementation, index, prop, id, language, tokenizer, docsCount, options) if (!isArrayType(schemaType)) { return insertScalar(value) @@ -553,7 +557,10 @@ export async function searchByWhereClause docIDs)) + safeArrayPush( + filtersMap[param], + ids.flatMap(({ docIDs }) => docIDs) + ) } else { const { coordinates, @@ -562,7 +569,10 @@ export async function searchByWhereClause docIDs)) + safeArrayPush( + filtersMap[param], + ids.flatMap(({ docIDs }) => docIDs) + ) } continue @@ -588,7 +598,10 @@ export async function searchByWhereClause( orama: T, doc: PartialSchemaDeep>, language?: string, - skipHooks?: boolean + skipHooks?: boolean, + options?: InsertOptions ): Promise { const errorProperty = await orama.validateSchema(doc, orama.schema) if (errorProperty) { throw createError('SCHEMA_VALIDATION_FAILURE', errorProperty) } - return innerInsert(orama, doc, language, skipHooks) + return innerInsert(orama, doc, language, skipHooks, options) } const ENUM_TYPE = new Set(['enum', 'enum[]']) @@ -26,7 +31,8 @@ async function innerInsert( orama: T, doc: PartialSchemaDeep>, language?: string, - skipHooks?: boolean + skipHooks?: boolean, + options?: InsertOptions ): Promise { const { index, docs } = orama.data @@ -111,7 +117,8 @@ async function innerInsert( expectedType, language, orama.tokenizer, - docsCount + docsCount, + options ) await orama.index.afterInsert?.( orama.data.index, @@ -199,7 +206,8 @@ export async function innerInsertMultiple( for (const doc of batch) { try { - const id = await insert(orama, doc, language, skipHooks) + const options = { avlRebalanceThreshold: batch.length } + const id = await insert(orama, doc, language, skipHooks, options) ids.push(id) } catch (err) { reject(err) diff --git a/packages/orama/src/trees/avl.ts b/packages/orama/src/trees/avl.ts index 1b56e8f46..01cc97f4f 100644 --- a/packages/orama/src/trees/avl.ts +++ b/packages/orama/src/trees/avl.ts @@ -122,14 +122,14 @@ export function rangeSearch(node: RootNode, min: K, max: K): V { export function greaterThan(node: RootNode, key: K, inclusive = false): V { const result: V[] = [] - if (node === null) return result as V; + if (node === null) return result as V const stack: Array>> = [node.root] while (stack.length > 0) { const node = stack.pop() if (!node) { - continue; + continue } if (inclusive && node.k >= key) { @@ -149,14 +149,14 @@ export function greaterThan(node: RootNode, key: K, inclusive = fals export function lessThan(node: RootNode, key: K, inclusive = false): V { const result: V[] = [] - if (node === null) return result as V; + if (node === null) return result as V const stack: Array>> = [node.root] while (stack.length > 0) { const node = stack.pop() if (!node) { - continue; + continue } if (inclusive && node.k <= key) { @@ -198,9 +198,12 @@ export function create(key: K, value: V): RootNode { } } -export function insert(rootNode: RootNode, key: K, newValue: V[]): void { +let insertCount = 0 + +export function insert(rootNode: RootNode, key: K, newValue: V[], rebalanceThreshold = 500): void { function insertNode(node: Nullable>, key: K, newValue: V[]): Node { if (node === null) { + insertCount++ return { k: key, v: newValue, @@ -215,38 +218,49 @@ export function insert(rootNode: RootNode, key: K, newValue: V[]): } else if (key > node.k) { node.r = insertNode(node.r, key, newValue) } else { - for (const value of newValue) { - node.v.push(value) - } + node.v.push(...newValue) return node } - node.h = 1 + Math.max(getHeight(node.l), getHeight(node.r)) + // Rebalance the tree if the insert count reaches the threshold. + // This will improve insertion performance since we won't be rebalancing the tree on every insert. + // When inserting docs using `insertMultiple`, the threshold will be set to the number of docs being inserted. + // We can force rebalancing the tree by setting the threshold to 1 (default). + if (insertCount % rebalanceThreshold === 0) { + console.log(`Rebalancing tree after ${insertCount} inserts...`) + return rebalanceNode(node, key) + } - const balanceFactor = getHeight(node.l) - getHeight(node.r) + return node + } - if (balanceFactor > 1 && key < node.l!.k) { - return rotateRight(node) - } + rootNode.root = insertNode(rootNode.root, key, newValue) +} - if (balanceFactor < -1 && key > node.r!.k) { - return rotateLeft(node) - } +function rebalanceNode(node: Node, key: K): Node { + node.h = 1 + Math.max(getHeight(node.l), getHeight(node.r)) - if (balanceFactor > 1 && key > node.l!.k) { - node.l = rotateLeft(node.l!) - return rotateRight(node) - } + const balanceFactor = getHeight(node.l) - getHeight(node.r) - if (balanceFactor < -1 && key < node.r!.k) { - node.r = rotateRight(node.r!) - return rotateLeft(node) - } + if (balanceFactor > 1 && key < node.l!.k) { + return rotateRight(node) + } - return node + if (balanceFactor < -1 && key > node.r!.k) { + return rotateLeft(node) } - rootNode.root = insertNode(rootNode.root, key, newValue) + if (balanceFactor > 1 && key > node.l!.k) { + node.l = rotateLeft(node.l!) + return rotateRight(node) + } + + if (balanceFactor < -1 && key < node.r!.k) { + node.r = rotateRight(node.r!) + return rotateLeft(node) + } + + return node } function getHeight(node: Nullable>): number { diff --git a/packages/orama/src/types.ts b/packages/orama/src/types.ts index 665e54675..8df8f112b 100644 --- a/packages/orama/src/types.ts +++ b/packages/orama/src/types.ts @@ -1,3 +1,4 @@ +import type { InsertOptions } from './methods/insert.js' import { MODE_FULLTEXT_SEARCH, MODE_HYBRID_SEARCH, MODE_VECTOR_SEARCH } from './constants.js' import { DocumentsStore } from './components/documents-store.js' import { Index } from './components/index.js' @@ -924,7 +925,8 @@ export interface IIndex { schemaType: SearchableType, language: string | undefined, tokenizer: Tokenizer, - docsCount: number + docsCount: number, + options?: InsertOptions ) => SyncOrAsyncValue afterInsert?: IIndexInsertOrRemoveHookFunction diff --git a/packages/orama/tests/distinct.test.ts b/packages/orama/tests/distinct.test.ts index 405f93fc4..567d78cba 100644 --- a/packages/orama/tests/distinct.test.ts +++ b/packages/orama/tests/distinct.test.ts @@ -84,7 +84,7 @@ async function createDb() { color: 'string', rank: 'number', isPromoted: 'boolean' - } + } as const }) const ids = await insertMultiple(db, [ diff --git a/packages/orama/tests/enum.test.ts b/packages/orama/tests/enum.test.ts index 4e4e29cf8..3b1704bb5 100644 --- a/packages/orama/tests/enum.test.ts +++ b/packages/orama/tests/enum.test.ts @@ -17,7 +17,7 @@ t.test('enum', async (t) => { const db = await create({ schema: { categoryId: 'enum' - } + } as const }) const c1 = await insert(db, { @@ -153,7 +153,7 @@ t.test('enum', async (t) => { const db = await create({ schema: { categoryId: 'enum' - } + } as const }) const c1 = await insert(db, { categoryId: 1 }) const c11 = await insert(db, { categoryId: 1 }) @@ -187,7 +187,7 @@ t.test('enum', async (t) => { const db1 = await create({ schema: { categoryId: 'enum' - } + } as const }) const [c1, c11, c2, c3, c5] = await insertMultiple(db1, [ { categoryId: 1 }, @@ -202,7 +202,7 @@ t.test('enum', async (t) => { const db2 = await create({ schema: { categoryId: 'enum' - } + } as const }) await load(db2, dump) @@ -239,7 +239,7 @@ t.test('enum', async (t) => { title: 'string', year: 'number', categoryId: 'enum' - } + } as const }) const [c1] = await insertMultiple(filmDb, [ { title: 'The Shawshank Redemption', year: 1994, categoryId: 1 }, @@ -303,7 +303,7 @@ t.test('enum[]', async (t) => { const db = await create({ schema: { tags: 'enum[]' - } + } as const }) const cGreenBlue = await insert(db, { @@ -395,7 +395,7 @@ t.test('enum[]', async (t) => { const db = await create({ schema: { tags: 'enum[]' - } + } as const }) const c1 = await insert(db, { tags: ['green', 'blue'] }) const c11 = await insert(db, { tags: ['blue', 'green'] }) @@ -429,7 +429,7 @@ t.test('enum[]', async (t) => { const db1 = await create({ schema: { tags: 'enum[]' - } + } as const }) const [c1, c11] = await insertMultiple(db1, [ { tags: ['green'] }, @@ -464,7 +464,7 @@ t.test('enum[]', async (t) => { term: '', where: { tags: { containsAll: [] } - } + } as const }) t.equal(result2.hits.length, 0) t.strictSame( @@ -481,7 +481,7 @@ t.test('enum[]', async (t) => { title: 'string', year: 'number', tags: 'enum[]' - } + } as const }) const [, , , c4] = await insertMultiple(filmDb, [ { title: 'The Shawshank Redemption', year: 1994, tags: ['drama', 'crime'] }, diff --git a/packages/orama/tests/insert.test.ts b/packages/orama/tests/insert.test.ts index 7465ae369..fafb64f58 100644 --- a/packages/orama/tests/insert.test.ts +++ b/packages/orama/tests/insert.test.ts @@ -523,6 +523,66 @@ t.test('insertMultiple method', (t) => { t.equal(after - before > expectedTime, true) }) + t.test('should correctly rebalance AVL tree once the threshold is reached', async (t) => { + t.plan(4) + + const db = await create({ + schema: { + id: 'string', + name: 'string', + number: 'number' + } as const + }) + + function getRandomNumberExcept(n: number): number { + const exceptions = [25, 250] + + if (exceptions.includes(n)) { + return n + } + + let random = Math.floor(Math.random() * 1000) + + while (exceptions.includes(random) || random === n) { + random = Math.floor(Math.random() * 1000) + } + + return random + } + + const docs = Array.from({ length: 1000 }, (_, i) => ({ + id: i.toString(), + name: `name-${i}`, + number: getRandomNumberExcept(i) + })) + + await insertMultiple(db, docs, 200) + + const results25 = await search(db, { + term: 'name-25', + where: { + number: { + eq: 25 + } + } + }) + + const results250 = await search(db, { + term: 'name', + where: { + number: { + eq: 250 + } + } + }) + + t.equal(results25.count, 1) + t.equal(results25.hits[0].document.id, '25') + + t.equal(results250.count, 1) + t.equal(results250.hits[0].document.id, '250') + }) + t.end() }) diff --git a/packages/orama/tests/search.hybrid.test.ts b/packages/orama/tests/search.hybrid.test.ts index 3ae9fde6b..55fb58078 100644 --- a/packages/orama/tests/search.hybrid.test.ts +++ b/packages/orama/tests/search.hybrid.test.ts @@ -201,101 +201,99 @@ t.test('hybrid search', async (t) => { t.test('should correctly paginate the results with a where clause', async (t) => { const db = await create({ schema: { - text: "string", - embedding: "vector[5]", - number: "number", - itemId: "string", - } as const, - }); + text: 'string', + embedding: 'vector[5]', + number: 'number', + itemId: 'string' + } as const + }) await insertMultiple(db, [ - { "text": "hello world", "itemId": "1", "embedding": [1, 2, 3, 4, 5], "number": 1 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 2 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 3 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 4 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 5 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 6 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 7 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 8 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 9 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 10 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 11 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 12 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 13 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 14 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 15 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 16 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 17 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 18 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 19 }, - { "text": "hello there", "itemId": "1", "embedding": [1, 2, 3, 4, 4], "number": 20 } - ] - ); + { text: 'hello world', itemId: '1', embedding: [1, 2, 3, 4, 5], number: 1 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 2 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 3 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 4 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 5 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 6 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 7 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 8 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 9 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 10 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 11 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 12 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 13 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 14 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 15 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 16 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 17 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 18 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 19 }, + { text: 'hello there', itemId: '1', embedding: [1, 2, 3, 4, 4], number: 20 } + ]) const page1 = await search(db, { - term: "hello there", - mode: "hybrid", + term: 'hello there', + mode: 'hybrid', where: { - itemId: "1", + itemId: '1' }, vector: { - property: "embedding", - value: [1, 2, 3, 4, 4], + property: 'embedding', + value: [1, 2, 3, 4, 4] }, similarity: 0.5, limit: 2, - offset: 0, - }); + offset: 0 + }) const page2 = await search(db, { - term: "hello there", - mode: "hybrid", + term: 'hello there', + mode: 'hybrid', where: { - itemId: "1", + itemId: '1' }, vector: { - property: "embedding", - value: [1, 2, 3, 4, 4], + property: 'embedding', + value: [1, 2, 3, 4, 4] }, similarity: 0.5, limit: 2, - offset: 1, - }); + offset: 1 + }) const page3 = await search(db, { - term: "hello there", - mode: "hybrid", + term: 'hello there', + mode: 'hybrid', where: { - itemId: "1", + itemId: '1' }, vector: { - property: "embedding", - value: [1, 2, 3, 4, 4], + property: 'embedding', + value: [1, 2, 3, 4, 4] }, similarity: 0.5, limit: 2, - offset: 2, - }); + offset: 2 + }) const page4 = await search(db, { - term: "hello there", - mode: "hybrid", + term: 'hello there', + mode: 'hybrid', where: { - itemId: "1", + itemId: '1' }, vector: { - property: "embedding", - value: [1, 2, 3, 4, 4], + property: 'embedding', + value: [1, 2, 3, 4, 4] }, similarity: 0.5, limit: 10, - offset: 5, + offset: 5 }) t.equal(page1.hits.length, 2) t.equal(page2.hits.length, 2) t.equal(page3.hits.length, 2) t.equal(page4.hits.length, 10) - t.equal(page1.hits[0].document.number, 2) t.equal(page1.hits[1].document.number, 3) @@ -304,11 +302,9 @@ t.test('should correctly paginate the results with a where clause', async (t) => t.equal(page3.hits[0].document.number, 4) t.equal(page3.hits[1].document.number, 5) - + t.equal(page1.count, 20) t.equal(page2.count, 20) t.equal(page3.count, 20) t.equal(page4.count, 20) }) - - diff --git a/packages/orama/tests/tree.avl.test.ts b/packages/orama/tests/tree.avl.test.ts index 8ee9d5a33..ef4f71e9e 100644 --- a/packages/orama/tests/tree.avl.test.ts +++ b/packages/orama/tests/tree.avl.test.ts @@ -123,7 +123,7 @@ t.test('AVL Tree', (t) => { insert(tree, 20, ['quuz']) insert(tree, 12, ['corge']) - t.same(greaterThan(tree, 10), ['quuz', 'corge', 'qux']) + t.same(greaterThan(tree, 10), ['qux', 'quuz', 'corge']) }) t.test('lessThan', (t) => { @@ -138,6 +138,6 @@ t.test('AVL Tree', (t) => { insert(tree, 20, ['quuz']) insert(tree, 12, ['corge']) - t.same(lessThan(tree, 10), ['bar', 'foo', 'quux']) + t.same(lessThan(tree, 10), ['foo', 'bar', 'quux']) }) })