Skip to content

Commit

Permalink
Load CDK: Remove ScopedTask interface, simplify TaskScopeProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt committed Jan 11, 2025
1 parent 09132c7 commit 16e598a
Show file tree
Hide file tree
Showing 21 changed files with 167 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ abstract class DestinationConfiguration : Configuration {

/**
* The amount of time given to implementor tasks (e.g. open, processBatch) to complete their
* current work after a failure.
* current work after a failure. Input consuming will stop right away, so this will give the
* tasks time to persist the messages already read.
*/
open val gracefulCancellationTimeoutMs: Long = 60 * 1000L // 1 minutes
open val gracefulCancellationTimeoutMs: Long = 10 * 60 * 1000L // 10 minutes

open val numProcessRecordsWorkers: Int = 2
open val processRecordsIsIO: Boolean = false
open val numProcessBatchWorkers: Int = 5
open val numProcessBatchWorkersForFileTransfer: Int = 3
open val batchQueueDepth: Int = 10
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,13 @@ class DefaultDestinationTaskLauncher(

private val closeStreamHasRun = ConcurrentHashMap<DestinationStream.Descriptor, AtomicBoolean>()

inner class TaskWrapper(
override val innerTask: ScopedTask,
) : WrappedTask<ScopedTask> {
inner class WrappedTask(
private val innerTask: Task,
) : Task {
override val isIO = innerTask.isIO
override val cancelAtEndOfSync = innerTask.cancelAtEndOfSync
override val killOnSyncFailure = innerTask.killOnSyncFailure

override suspend fun execute() {
try {
innerTask.execute()
Expand All @@ -161,16 +165,8 @@ class DefaultDestinationTaskLauncher(
}
}

inner class NoopWrapper(
override val innerTask: ScopedTask,
) : WrappedTask<ScopedTask> {
override suspend fun execute() {
innerTask.execute()
}
}

private suspend fun enqueue(task: ScopedTask, withExceptionHandling: Boolean = true) {
val wrapped = if (withExceptionHandling) TaskWrapper(task) else NoopWrapper(task)
private suspend fun launch(task: Task, withExceptionHandling: Boolean = true) {
val wrapped = if (withExceptionHandling) WrappedTask(task) else task
taskScopeProvider.launch(wrapped)
}

Expand All @@ -186,56 +182,56 @@ class DefaultDestinationTaskLauncher(
fileTransferQueue = fileTransferQueue,
destinationTaskLauncher = this,
)
enqueue(inputConsumerTask)
launch(inputConsumerTask)

// Launch the client interface setup task
log.info { "Starting startup task" }
val setupTask = setupTaskFactory.make(this)
enqueue(setupTask)
launch(setupTask)

// TODO: pluggable file transfer
if (!fileTransferEnabled) {
// Start a spill-to-disk task for each record stream
catalog.streams.forEach { stream ->
log.info { "Starting spill-to-disk task for $stream" }
val spillTask = spillToDiskTaskFactory.make(this, stream.descriptor)
enqueue(spillTask)
launch(spillTask)
}

repeat(config.numProcessRecordsWorkers) {
log.info { "Launching process records task $it" }
val task = processRecordsTaskFactory.make(this)
enqueue(task)
launch(task)
}

repeat(config.numProcessBatchWorkers) {
log.info { "Launching process batch task $it" }
val task = processBatchTaskFactory.make(this)
enqueue(task)
launch(task)
}
} else {
repeat(config.numProcessRecordsWorkers) {
log.info { "Launching process file task $it" }
enqueue(processFileTaskFactory.make(this))
launch(processFileTaskFactory.make(this))
}

repeat(config.numProcessBatchWorkersForFileTransfer) {
log.info { "Launching process batch task $it" }
val task = processBatchTaskFactory.make(this)
enqueue(task)
launch(task)
}
}

// Start flush task
log.info { "Starting timed file aggregate flush task " }
enqueue(flushTickTask)
launch(flushTickTask)

// Start the checkpoint management tasks
log.info { "Starting timed checkpoint flush task" }
enqueue(timedCheckpointFlushTask)
launch(timedCheckpointFlushTask)

log.info { "Starting checkpoint update task" }
enqueue(updateCheckpointsTask)
launch(updateCheckpointsTask)

// Await completion
if (succeeded.receive()) {
Expand All @@ -250,7 +246,7 @@ class DefaultDestinationTaskLauncher(
catalog.streams.forEach {
log.info { "Starting open stream task for $it" }
val task = openStreamTaskFactory.make(this, it)
enqueue(task)
launch(task)
}
}

Expand All @@ -276,14 +272,14 @@ class DefaultDestinationTaskLauncher(
log.info {
"Batch $wrapped is persisted: Starting flush checkpoints task for $stream"
}
enqueue(flushCheckpointsTaskFactory.make())
launch(flushCheckpointsTaskFactory.make())
}

if (streamManager.isBatchProcessingComplete()) {
if (closeStreamHasRun.getOrPut(stream) { AtomicBoolean(false) }.setOnce()) {
log.info { "Batch processing complete: Starting close stream task for $stream" }
val task = closeStreamTaskFactory.make(this, stream)
enqueue(task)
launch(task)
} else {
log.info { "Close stream task has already run, skipping." }
}
Expand All @@ -296,7 +292,7 @@ class DefaultDestinationTaskLauncher(
/** Called when a stream is closed. */
override suspend fun handleStreamClosed(stream: DestinationStream.Descriptor) {
if (teardownIsEnqueued.setOnce()) {
enqueue(teardownTaskFactory.make(this))
launch(teardownTaskFactory.make(this))
} else {
log.info { "Teardown task already enqueued, not enqueuing another one" }
}
Expand All @@ -305,15 +301,15 @@ class DefaultDestinationTaskLauncher(
override suspend fun handleException(e: Exception) {
catalog.streams
.map { failStreamTaskFactory.make(this, e, it.descriptor) }
.forEach { enqueue(it, withExceptionHandling = false) }
.forEach { launch(it, withExceptionHandling = false) }
}

override suspend fun handleFailStreamComplete(
stream: DestinationStream.Descriptor,
e: Exception
) {
if (failSyncIsEnqueued.setOnce()) {
enqueue(failSyncTaskFactory.make(this, e))
launch(failSyncTaskFactory.make(this, e))
} else {
log.info { "Teardown task already enqueued, not enqueuing another one" }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@
package io.airbyte.cdk.load.task

interface Task {
/**
* If the task performs any blocking io, even writing to local disk, it should set
* [isIO] = true. [cancelAtEndOfSync] is for long-running tasks that will otherwise
* not close. [killOnSyncFailure] is for tasks that close normally under success conditions
* but should be halted immediately on failure to permit shutdown (like input consuming).
*
* TODO: simplify this further.
*/
val isIO: Boolean
val cancelAtEndOfSync: Boolean
val killOnSyncFailure: Boolean

suspend fun execute()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,145 +6,82 @@ package io.airbyte.cdk.load.task

import io.airbyte.cdk.load.command.DestinationConfiguration
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.atomic.AtomicReference
import kotlin.system.measureTimeMillis
import kotlinx.coroutines.CompletableJob
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeoutOrNull
import org.apache.mina.util.ConcurrentHashSet

/**
* The scope in which a task should run
* - [InternalScope]:
* ```
* - internal to the task launcher
* - should not be blockable by implementor errors
* - killable w/o side effects
* ```
* - [ImplementorScope]: implemented by the destination
* ```
* - calls implementor interface
* - should not block internal tasks (esp reading from stdin)
* - should complete if possible even when failing the sync
* ```
*/
sealed interface ScopedTask : Task

interface InternalScope : ScopedTask

interface ImplementorScope : ScopedTask

/**
* Some tasks should be immediately cancelled upon any failure (for example, reading from stdin, the
* every-15-minutes flush). Those tasks should be placed into the fail-fast scope.
*/
interface KillableScope : ScopedTask

interface WrappedTask<T : Task> : Task {
val innerTask: T
interface WrappedTask : Task {
val innerTask: Task
}

@Singleton
@Secondary
class TaskScopeProvider(config: DestinationConfiguration) {
private val log = KotlinLogging.logger {}

private val timeoutMs = config.gracefulCancellationTimeoutMs

data class ControlScope(
val name: String,
val job: CompletableJob,
val dispatcher: CoroutineDispatcher
) {
val scope: CoroutineScope = CoroutineScope(dispatcher + job)
val runningJobs: AtomicLong = AtomicLong(0)
}

private val internalScope = ControlScope("internal", Job(), Dispatchers.IO)

private val implementorScope =
ControlScope(
"implementor",
Job(),
Executors.newFixedThreadPool(config.maxNumImplementorTaskThreads)
.asCoroutineDispatcher()
)

private val failFastScope = ControlScope("input", Job(), Dispatchers.IO)

suspend fun launch(task: WrappedTask<ScopedTask>) {
val scope =
when (task.innerTask) {
is InternalScope -> internalScope
is ImplementorScope -> implementorScope
is KillableScope -> failFastScope
private val supervisor = Job()
private val ioScope = CoroutineScope(Dispatchers.IO + supervisor)
private val defaultScope = CoroutineScope(Dispatchers.Default + supervisor)
private val killOnSyncFailure = ConcurrentHashSet<Job>()
private val cancelAtEndOfSync = ConcurrentHashSet<Job>()

suspend fun launch(task: Task) {
val scope = if (task.isIO) ioScope else defaultScope
val job =
scope.launch {
log.info { "Launching $task" }
task.execute()
log.info { "Task $task completed" }
}
scope.scope.launch {
var nJobs = scope.runningJobs.incrementAndGet()
log.info { "Launching task $task in scope ${scope.name} ($nJobs now running)" }
val elapsed = measureTimeMillis { task.execute() }
nJobs = scope.runningJobs.decrementAndGet()
log.info { "Task $task completed in $elapsed ms ($nJobs now running)" }
if (task.cancelAtEndOfSync) {
cancelAtEndOfSync.add(job)
}
if (task.killOnSyncFailure) {
killOnSyncFailure.add(job)
}
}

suspend fun close() {
// Under normal operation, all tasks should be complete
// (except things like force flush, which loop). So
// - it's safe to force cancel the internal tasks
// - implementor scope should join immediately
log.info { "Closing task scopes (${implementorScope.runningJobs.get()} remaining)" }
log.info { "Closing normally, canceling long-running tasks" }
cancelAtEndOfSync.forEach { it.cancel() }

val uncaughtExceptions = AtomicReference<Throwable>()
implementorScope.job.children.forEach {
log.info { "Verifying task completion" }
supervisor.children.forEach {
it.invokeOnCompletion { cause ->
if (cause != null) {
log.error { "Uncaught exception in implementor task: $cause" }
log.error { "Uncaught exception in task: $cause" }
uncaughtExceptions.set(cause)
}
}
}
implementorScope.job.complete()
implementorScope.job.join()
if (uncaughtExceptions.get() != null) {
throw IllegalStateException(
"Uncaught exceptions in implementor tasks",
uncaughtExceptions.get()
)
}
log.info {
"Implementor tasks completed, cancelling internal tasks (${internalScope.runningJobs.get()} remaining)."
throw uncaughtExceptions.get()
}
internalScope.job.cancel()
}

suspend fun kill() {
log.info { "Killing task scopes" }
// Terminate tasks which should be immediately terminated
failFastScope.job.cancel()
log.info { "Failing, killing input tasks and canceling long-running tasks" }
killOnSyncFailure.forEach { it.cancel() }
cancelAtEndOfSync.forEach { it.cancel() }

// Give the implementor tasks a chance to fail gracefully
withTimeoutOrNull(timeoutMs) {
log.info {
"Cancelled internal tasks, waiting ${timeoutMs}ms for implementor tasks to complete"
}
implementorScope.job.complete()
implementorScope.job.join()
supervisor.complete()
log.info { "Implementor tasks completed" }
}
?: run {
log.error { "Implementor tasks did not complete within ${timeoutMs}ms, cancelling" }
implementorScope.job.cancel()
supervisor.cancel()
}

log.info { "Cancelling internal tasks" }
internalScope.job.cancel()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ package io.airbyte.cdk.load.task.implementor
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskLauncher
import io.airbyte.cdk.load.task.ImplementorScope
import io.airbyte.cdk.load.task.Task
import io.airbyte.cdk.load.write.StreamLoader
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton

interface CloseStreamTask : ImplementorScope
interface CloseStreamTask : Task

/**
* Wraps @[StreamLoader.close] and marks the stream as closed in the stream manager. Also starts the
Expand All @@ -24,6 +24,9 @@ class DefaultCloseStreamTask(
val streamDescriptor: DestinationStream.Descriptor,
private val taskLauncher: DestinationTaskLauncher
) : CloseStreamTask {
override val isIO = true
override val cancelAtEndOfSync = false
override val killOnSyncFailure = false

override suspend fun execute() {
val streamLoader = syncManager.getOrAwaitStreamLoader(streamDescriptor)
Expand Down
Loading

0 comments on commit 16e598a

Please sign in to comment.