Skip to content

Commit

Permalink
fix: dataloader mem leak
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Krick <matt.krick@gmail.com>
  • Loading branch information
mattkrick committed Mar 19, 2024
1 parent 49b349b commit bf9a379
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 26 deletions.
14 changes: 8 additions & 6 deletions packages/embedder/JobQueueStream.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {Selectable, sql} from 'kysely'
import ms from 'ms'
import sleep from 'parabol-client/utils/sleep'
import {DataLoaderWorker} from 'parabol-server/graphql/graphql'
import RootDataLoader from 'parabol-server/dataloader/RootDataLoader'
import 'parabol-server/initSentry'
import getKysely from 'parabol-server/postgres/getKysely'
import {DB} from 'parabol-server/postgres/pg'
Expand All @@ -13,10 +13,10 @@ import {updateJobState} from './indexing/updateJobState'

type Job = Selectable<DB['EmbeddingsJobQueue']>
export default class JobQueueStream implements AsyncIterator<Job> {
private dataLoader: DataLoaderWorker
private dataLoader: RootDataLoader
private modelManager: ModelManager

constructor(modelManager: ModelManager, dataLoader: DataLoaderWorker) {
constructor(modelManager: ModelManager, dataLoader: RootDataLoader) {
this.modelManager = modelManager
this.dataLoader = dataLoader
}
Expand Down Expand Up @@ -92,7 +92,7 @@ export default class JobQueueStream implements AsyncIterator<Job> {
await pg
.updateTable('EmbeddingsMetadata')
.set({fullText})
.where('id', '=', metadata.id)
.where('id', '=', embeddingsMetadataId)
.execute()
}
} catch (e) {
Expand Down Expand Up @@ -128,7 +128,9 @@ export default class JobQueueStream implements AsyncIterator<Job> {
// Cannot use summarization strategy if generation model has same context length as embedding model
// We must split the text & not tokens because the endpoint doesn't support decoding input tokens
const chunks = isFullTextTooBig ? embeddingModel.splitText(fullText, 0) : [fullText]

if (isFullTextTooBig) {
console.log('we split!', chunks)
}
await Promise.all(
chunks.map(async (chunk, chunkNumber) => {
const embeddingVector = await embeddingModel.getEmbedding(chunk)
Expand All @@ -150,7 +152,7 @@ export default class JobQueueStream implements AsyncIterator<Job> {
chunkNumber: isFullTextTooBig ? chunkNumber : null
})
.onConflict((oc) =>
oc.doUpdateSet((eb) => ({
oc.column('embeddingsMetadataId').doUpdateSet((eb) => ({
embedText: eb.ref('excluded.embedText'),
embedding: eb.ref('excluded.embedding')
}))
Expand Down
3 changes: 2 additions & 1 deletion packages/embedder/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {establishPrimaryEmbedder} from './establishPrimaryEmbedder'
import {importHistoricalMetadata} from './importHistoricalMetadata'
import {getRootDataLoader} from './indexing/getRootDataLoader'
import {mergeAsyncIterators} from './mergeAsyncIterators'
import RootDataLoader from 'parabol-server/dataloader/RootDataLoader'

tracer.init({
service: `embedder`,
Expand Down Expand Up @@ -81,7 +82,7 @@ const run = async () => {
}

const incomingStream = new RedisStream('embedderStream', 'embedderConsumerGroup', embedderChannel)
const dataLoader = getRootDataLoader()
const dataLoader = new RootDataLoader({maxBatchSize: 1000})
const jobQueueStream = new JobQueueStream(modelManager, dataLoader)

console.log(`\n⚡⚡⚡️️ Server ID: ${SERVER_ID}. Embedder is ready ⚡⚡⚡️️️`)
Expand Down
4 changes: 2 additions & 2 deletions packages/embedder/indexing/createEmbeddingTextFrom.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import {Selectable} from 'kysely'
import {DataLoaderWorker} from 'parabol-server/graphql/graphql'
import {DB} from 'parabol-server/postgres/pg'

import RootDataLoader from 'parabol-server/dataloader/RootDataLoader'
import {createTextFromRetrospectiveDiscussionTopic} from './retrospectiveDiscussionTopic'

export const createEmbeddingTextFrom = async (
embeddingsMetadata: Selectable<DB['EmbeddingsMetadata']>,
dataLoader: DataLoaderWorker
dataLoader: RootDataLoader
) => {
switch (embeddingsMetadata.objectType) {
case 'retrospectiveDiscussionTopic':
Expand Down
10 changes: 0 additions & 10 deletions packages/embedder/indexing/getRootDataLoader.ts

This file was deleted.

13 changes: 6 additions & 7 deletions packages/embedder/indexing/retrospectiveDiscussionTopic.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import getRethink, {RethinkSchema} from 'parabol-server/database/rethinkDriver'
import {RethinkSchema} from 'parabol-server/database/rethinkDriver'
import Comment from 'parabol-server/database/types/Comment'
import {isMeetingRetrospective} from 'parabol-server/database/types/MeetingRetrospective'
import {DataLoaderWorker} from 'parabol-server/graphql/graphql'
import RootDataLoader from 'parabol-server/dataloader/RootDataLoader'
import prettier from 'prettier'

// Here's a generic reprentation of the text generated here:
Expand All @@ -21,14 +21,14 @@ import prettier from 'prettier'

const IGNORE_COMMENT_USER_IDS = ['parabolAIUser']

async function getPreferredNameByUserId(userId: string, dataLoader: DataLoaderWorker) {
async function getPreferredNameByUserId(userId: string, dataLoader: RootDataLoader) {
if (!userId) return 'Unknown'
const user = await dataLoader.get('users').load(userId)
return !user ? 'Unknown' : user.preferredName
}

async function formatThread(
dataLoader: DataLoaderWorker,
dataLoader: RootDataLoader,
comments: Comment[],
parentId: string | null = null,
depth = 0
Expand Down Expand Up @@ -60,13 +60,12 @@ async function formatThread(

export const createTextFromRetrospectiveDiscussionTopic = async (
discussionId: string,
dataLoader: DataLoaderWorker,
dataLoader: RootDataLoader,
textForReranking: boolean = false
) => {
const discussion = await dataLoader.get('discussions').load(discussionId)
if (!discussion) throw new Error(`Discussion not found: ${discussionId}`)
const {discussionTopicId: reflectionGroupId, meetingId, summary: discussionSummary} = discussion
const r = await getRethink()
const [newMeeting, reflectionGroup, reflections] = await Promise.all([
dataLoader.get('newMeetings').load(meetingId),
dataLoader.get('retroReflectionGroups').load(reflectionGroupId),
Expand Down Expand Up @@ -166,7 +165,7 @@ export const createTextFromRetrospectiveDiscussionTopic = async (

export const newRetroDiscussionTopicsFromNewMeeting = async (
newMeeting: RethinkSchema['NewMeeting']['type'],
dataLoader: DataLoaderWorker
dataLoader: RootDataLoader
) => {
const discussPhase = newMeeting.phases.find((phase) => phase.phaseType === 'discuss')
const orgId = (await dataLoader.get('teams').load(newMeeting.teamId))?.orgId
Expand Down

0 comments on commit bf9a379

Please sign in to comment.