Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: thread safe partition assignment handler #1729

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ private[kafka] trait SourceLogicSubscription {

/**
 * Hook for subclasses to add their own logic to the partition assignment callbacks.
 * This default implementation returns `handler` unchanged.
 *
 * Note: called from consumer actor, so the returned handler must be thread safe.
 *
 * @param handler the partition assignment handler to extend
 * @return the handler, possibly wrapped or chained with additional logic by a subclass
 */
protected def addToPartitionAssignmentHandler(handler: PartitionAssignmentHandler): PartitionAssignmentHandler =
handler
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/akka/kafka/internal/SubSourceLogic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,14 @@ private class SubSourceLogic[K, V, Msg](
/**
 * For every partition in `lastRevoked` that is not part of the new assignment, invokes
 * `filterRevokedPartitionsCB` of the matching sub source (when one exists) with that partition.
 */
override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = {
  val revokedAndNotReassigned = lastRevoked -- assignedTps
  revokedAndNotReassigned.foreach { tp =>
    // FIXME subSources is mutable internal state of logic, this is not thread safe
    subSources.get(tp).foreach(control => control.filterRevokedPartitionsCB.invoke(Set(tp)))
  }
}

/**
 * For every lost partition, invokes `filterRevokedPartitionsCB` of the matching sub source
 * (when one exists) with that partition.
 */
override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
  lostTps.foreach { tp =>
    // FIXME subSources is mutable internal state of logic, this is not thread safe
    subSources.get(tp).foreach(control => control.filterRevokedPartitionsCB.invoke(Set(tp)))
  }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,10 @@ private final class TransactionalProducerStageLogic[K, V, P](

/**
 * Handles a stage failure: aborts the current transaction, warns when a commit is still in
 * progress, tells `batchOffsets` that committing failed and finally delegates to the super
 * implementation.
 *
 * @param ex the failure that is passed on to the super implementation
 */
override def onCompletionFailure(ex: Throwable): Unit = {
  abortTransaction(s"Stage failure ($ex)")
  // The warning names this callback (`onCompletionFailure`) so it can be traced back to it.
  if (commitInProgress)
    log.warning("Stage onCompletionFailure with commit in flight")
  batchOffsets.committingFailed()
  super.onCompletionFailure(ex)
}
Expand Down
228 changes: 134 additions & 94 deletions core/src/main/scala/akka/kafka/internal/TransactionalSources.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import akka.actor.{ActorRef, Status, Terminated}
import akka.actor.Status.Failure
import akka.annotation.InternalApi
import akka.dispatch.ExecutionContexts
import akka.pattern.ask
import akka.kafka.ConsumerMessage.{PartitionOffset, TransactionalMessage}
import akka.kafka.internal.KafkaConsumerActor.Internal.Revoked
import akka.kafka.internal.SubSourceLogic._
Expand All @@ -25,8 +26,9 @@ import akka.util.Timeout
import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerGroupMetadata, ConsumerRecord, OffsetAndMetadata}
import org.apache.kafka.common.{IsolationLevel, TopicPartition}

import scala.concurrent.duration.{DurationInt, FiniteDuration}
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, Promise, TimeoutException}
import scala.util.{Success, Try}

/** Internal API */
@InternalApi
Expand Down Expand Up @@ -95,13 +97,16 @@ private[internal] abstract class TransactionalSourceLogic[K, V, Msg](shape: Sour
override protected def logSource: Class[_] = classOf[TransactionalSourceLogic[_, _, _]]

private val inFlightRecords = InFlightRecords.empty
// Function form of the async callback that runs `onRevoke`; it is handed to the blocking
// partition assignment handler created in `addToPartitionAssignmentHandler`.
private val onRevokeCB = getAsyncCallback[Revoke](onRevoke).invoke _
// Function form of the async callback that runs `onRevokeDrainDone`; `onRevoke` uses it to pass on
// the outcome of the drain request together with the originating `Revoke`.
private val onRevokeDrainDoneCB = getAsyncCallback[(Revoke, Try[RevokeDrainDone.type])](onRevokeDrainDone).invoke _

override def messageHandling = super.messageHandling.orElse(drainHandling).orElse {
case (_, Revoked(tps)) =>
inFlightRecords.revoke(tps.toSet)
}
override def messageHandling: PartialFunction[(ActorRef, Any), Unit] =
super.messageHandling.orElse(drainHandling).orElse {
case (_, Revoked(tps)) =>
inFlightRecords.revoke(tps.toSet)
}

override def shuttingDownReceive =
override def shuttingDownReceive: PartialFunction[(ActorRef, Any), Unit] =
super.shuttingDownReceive
.orElse(drainHandling)
.orElse {
Expand All @@ -113,18 +118,19 @@ private[internal] abstract class TransactionalSourceLogic[K, V, Msg](shape: Sour

private def drainHandling: PartialFunction[(ActorRef, Any), Unit] = {
case (sender, Committed(offsets)) =>
inFlightRecords.committed(offsets.iterator.map { case (k, v) => k -> (v.offset() - 1L) }.toMap)
inFlightRecords.committed(offsets.view.mapValues(v => v.offset() - 1L).toMap)
sender.tell(Done, sourceActor.ref)
case (_, CommittingFailure) => {
case (_, CommittingFailure) =>
log.info("Committing failed, resetting in flight offsets")
inFlightRecords.reset()
}
case (sender, Drain(partitions, ack, msg)) =>
if (inFlightRecords.empty(partitions)) {
log.debug(s"Partitions drained ${partitions.mkString(",")}")
if (log.isDebugEnabled)
log.debug(s"Partitions drained [{}]", partitions.mkString(","))
ack.getOrElse(sender).tell(msg, sourceActor.ref)
} else {
log.debug(s"Draining partitions {}", partitions)
if (log.isDebugEnabled)
log.debug(s"Draining partitions [{}]", partitions.mkString(", "))
materializer.scheduleOnce(
consumerSettings.drainingCheckInterval,
() => sourceActor.ref.tell(Drain(partitions, ack.orElse(Some(sender)), msg), sourceActor.ref)
Expand All @@ -151,45 +157,45 @@ private[internal] abstract class TransactionalSourceLogic[K, V, Msg](shape: Sour

override protected def addToPartitionAssignmentHandler(
handler: PartitionAssignmentHandler
): PartitionAssignmentHandler = {
// FIXME this touches mutable internal stage fields (sourceActor, stageActor, consumerActor, subSources) from
// another thread (consumer actor) not thread safe
val blockingRevokedCall = new PartitionAssignmentHandler {
override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()

// This is invoked in the KafkaConsumerActor thread when doing poll.
override def onRevoke(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
if (waitForDraining(revokedTps)) {
sourceActor.ref.tell(Revoked(revokedTps.toList), consumerActor)
} else {
sourceActor.ref.tell(Failure(new Error("Timeout while draining")), consumerActor)
consumerActor.tell(KafkaConsumerActor.Internal.StopFromStage(id), consumerActor)
}

override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
onRevoke(lostTps, consumer)

override def onStop(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()
): PartitionAssignmentHandler =
new PartitionAssignmentHelpers.Chain(handler,
createBlockingPartitionAssignmentHandler(consumerSettings, onRevokeCB))

/**
 * First step of handling a revoke: asks the stage actor to drain the revoked partitions and
 * passes the outcome, paired with the originating `revoke`, to `onRevokeDrainDoneCB`.
 */
def onRevoke(revoke: Revoke): Unit = {
  // Tricky chain of async interactions: draining is a timed async wait, and both steps need
  // to interact with the stage's internal mutable state before finally completing the promise
  // whose future the blocking partition assignment handler blocks the consumer on.
  // Simplifying is tricky since other logic depends on message-send-drain.
  if (log.isDebugEnabled) {
    log.debug("onRevoke [{}]", revoke.partitions.mkString(","))
  }
  val drainReply =
    stageActor.ref.ask(Drain(revoke.partitions, None, Drained))(consumerSettings.commitTimeout)
  // Always succeed the future and carry the drain result as a Try, so that a failed
  // (e.g. timed out) drain still reaches the callback.
  val drainOutcome =
    drainReply.transform(tryDrain => Success((revoke, tryDrain.map(_ => RevokeDrainDone))))(
      ExecutionContexts.parasitic
    )
  drainOutcome.foreach(onRevokeDrainDoneCB)(ExecutionContexts.parasitic)
}

private def waitForDraining(partitions: Set[TopicPartition]): Boolean = {
import akka.pattern.ask
implicit val timeout = Timeout(consumerSettings.commitTimeout)
try {
Await.result(ask(stageActor.ref, Drain(partitions, None, Drained)), timeout.duration)
true
} catch {
case t: Throwable =>
false
}
}
def onRevokeDrainDone(revokeDrainDone: (Revoke, Try[RevokeDrainDone.type])): Unit = {
revokeDrainDone match {
case (revoke, scala.util.Success(RevokeDrainDone)) =>
if (log.isDebugEnabled)
log.debug("onRevokeDrainDone [{}]", revoke.partitions.mkString(","))
inFlightRecords.revoke(revoke.partitions)
revoke.revokeCompletion.success(Done)
case (revoke, scala.util.Failure(ex)) =>
if (log.isDebugEnabled)
log.debug("onRevokeDrainDone, drain failed [{}] ({})", revoke.partitions.mkString(","), ex)
stageActor.ref.tell(Failure(new TimeoutException(s"Timeout while draining ($ex)")), consumerActor)
consumerActor.tell(KafkaConsumerActor.Internal.StopFromStage(id), consumerActor)
// we don't signal failure back, just completion
revoke.revokeCompletion.success(Done)
}
new PartitionAssignmentHelpers.Chain(handler, blockingRevokedCall)
}

override def requestConsumerGroupMetadata(): Future[ConsumerGroupMetadata] = {
import akka.pattern.ask
implicit val timeout: Timeout = 5.seconds // FIXME specific timeout config for this?
ask(consumerActor, KafkaConsumerActor.Internal.GetConsumerGroupMetadata)(timeout)
def requestConsumerGroupMetadata(): Future[ConsumerGroupMetadata] = {
// use some sensible existing timeout setting for this consumer actor ask
implicit val timeout: Timeout = consumerSettings.metadataRequestTimeout
consumerActor
.ask(KafkaConsumerActor.Internal.GetConsumerGroupMetadata)
.mapTo[ConsumerGroupMetadata]
}
}
Expand Down Expand Up @@ -241,49 +247,56 @@ private[kafka] final class TransactionalSubSource[K, V](

new SubSourceLogic(shape, txConsumerSettings, subscription, subSourceStageLogicFactory = factory) {

// Function form of the async callback that runs `onRevoke`; it is handed to the blocking
// partition assignment handler created in `addToPartitionAssignmentHandler`.
private val onRevokeCB = getAsyncCallback[Revoke](onRevoke).invoke _
// Function form of the async callback that runs `onRevokeDrainDone`; `onRevoke` uses it to pass on
// the combined outcome of the drain requests together with the originating `Revoke`.
private val onRevokeDrainDoneCB =
getAsyncCallback[(Revoke, Try[RevokeDrainDone.type])](onRevokeDrainDone).invoke _

override protected def addToPartitionAssignmentHandler(
handler: PartitionAssignmentHandler
): PartitionAssignmentHandler = {
// FIXME this touches mutable internal stage fields (sourceActor, stageActor, consumerActor, subSources) from
// another thread (consumer actor) not thread safe
val blockingRevokedCall = new PartitionAssignmentHandler {
override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()

// This is invoked in the KafkaConsumerActor thread when doing poll.
override def onRevoke(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
if (revokedTps.isEmpty) ()
else if (waitForDraining(revokedTps)) {
subSources.values
.map(_.controlAndStageActor.stageActor)
.foreach(_.tell(Revoked(revokedTps.toList), stageActor.ref))
} else {
sourceActor.ref.tell(Status.Failure(new Error("Timeout while draining")), stageActor.ref)
consumerActor.tell(KafkaConsumerActor.Internal.StopFromStage(id), stageActor.ref)
}

override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
onRevoke(lostTps, consumer)

override def onStop(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()
}
new PartitionAssignmentHelpers.Chain(handler, blockingRevokedCall)
): PartitionAssignmentHandler =
new PartitionAssignmentHelpers.Chain(
handler,
TransactionalSourceLogic.createBlockingPartitionAssignmentHandler(consumerSettings, onRevokeCB)
)

/**
 * First step of handling a revoke: asks the stage actor of every sub source to drain the
 * revoked partitions and passes the combined outcome, paired with the originating `revoke`,
 * to `onRevokeDrainDoneCB`.
 */
def onRevoke(revoke: Revoke): Unit = {
  // Tricky chain of async interactions: draining is a timed async wait, and both steps need
  // to interact with the stage's internal mutable state before finally completing the promise
  // whose future the blocking partition assignment handler blocks the consumer on.
  // Simplifying is tricky since other logic depends on message-send-drain.
  implicit val timeout: Timeout = Timeout(txConsumerSettings.commitTimeout)
  implicit val ec: ExecutionContext = materializer.executionContext
  if (log.isDebugEnabled) {
    log.debug("onRevoke [{}]", revoke.partitions.mkString(","))
  }
  val drainCommand = Drain(revoke.partitions, None, Drained)
  val drainReplies = subSources.values.map(_.stageActor.ask(drainCommand))
  Future
    .sequence(drainReplies)
    // Always succeed the future and carry the drain result as a Try, so that a failed
    // (e.g. timed out) drain still reaches the callback.
    .transform(tryDrainAll => Success((revoke, tryDrainAll.map(_ => RevokeDrainDone))))(
      ExecutionContexts.parasitic
    )
    .foreach(onRevokeDrainDoneCB)
}

private def waitForDraining(partitions: Set[TopicPartition]): Boolean = {
import akka.pattern.ask
implicit val timeout = Timeout(txConsumerSettings.commitTimeout)
try {
val drainCommandFutures =
subSources.values.map(_.stageActor).map(ask(_, Drain(partitions, None, Drained)))
implicit val ec = executionContext
Await.result(Future.sequence(drainCommandFutures), timeout.duration)
true
} catch {
case t: Throwable =>
false
/**
 * Second step of handling a revoke, run with the outcome of the drain requests.
 *
 * On success every sub source stage actor is told that the partitions were `Revoked`. On
 * failure the source actor receives a `Status.Failure` with a `TimeoutException` and the
 * consumer actor is told to stop. In both cases the promise of the `Revoke` is completed with
 * `Done`.
 */
def onRevokeDrainDone(revokeDrainDone: (Revoke, Try[RevokeDrainDone.type])): Unit = {
  val (revoke, drainResult) = revokeDrainDone
  drainResult match {
    case scala.util.Success(RevokeDrainDone) =>
      if (log.isDebugEnabled) {
        log.debug("onRevokeDrainDone [{}]", revoke.partitions.mkString(","))
      }
      val revokedMessage = Revoked(revoke.partitions.toList)
      for (subSourceStageActor <- subSources.values.map(_.controlAndStageActor.stageActor))
        subSourceStageActor.tell(revokedMessage, stageActor.ref)
    case scala.util.Failure(ex) =>
      if (log.isDebugEnabled) {
        log.debug("onRevokeDrainDone, drain failed [{}] ({})", revoke.partitions.mkString(","), ex)
      }
      sourceActor.ref.tell(Status.Failure(new TimeoutException("Timeout while draining")), stageActor.ref)
      consumerActor.tell(KafkaConsumerActor.Internal.StopFromStage(id), stageActor.ref)
  }
  // we don't signal failure back, just completion
  revoke.revokeCompletion.success(Done)
}
}

}
}

Expand All @@ -299,10 +312,12 @@ private object TransactionalSourceLogic {
final case class Committed(offsets: Map[TopicPartition, OffsetAndMetadata])
case object CommittingFailure

/**
 * Request to handle revoked (or lost) partitions inside the stage.
 *
 * @param partitions the partitions that were revoked or lost
 * @param revokeCompletion completed with `Done` by `onRevokeDrainDone` once the revoke has been
 *                         handled, also when draining failed; the blocking partition assignment
 *                         handler waits on its future
 */
final case class Revoke(partitions: Set[TopicPartition], revokeCompletion: Promise[Done])
/** Marker for a successful drain in the `Try` that is passed to `onRevokeDrainDone`. */
case object RevokeDrainDone

private[internal] final case class CommittedMarkerRef(sourceActor: ActorRef, commitTimeout: FiniteDuration)
extends CommittedMarker {
override def committed(offsets: Map[TopicPartition, OffsetAndMetadata]): Future[Done] = {
import akka.pattern.ask
sourceActor
.ask(Committed(offsets))(Timeout(commitTimeout))
.map(_ => Done)(ExecutionContexts.parasitic)
Expand Down Expand Up @@ -352,6 +367,28 @@ private object TransactionalSourceLogic {
override def assigned(): Set[TopicPartition] = inFlightRecords.keySet
}
}

/**
 * Creates a [[PartitionAssignmentHandler]] that, for a non-empty set of revoked or lost
 * partitions, passes a [[Revoke]] to `revokeCallback` and then blocks the calling thread until
 * the promise carried by that `Revoke` is completed, waiting at most twice the commit timeout
 * (`Await.result` throws if the promise is not completed in time).
 * `onAssign` and `onStop` do nothing.
 *
 * @param consumerSettings source of the commit timeout used to bound the wait
 * @param revokeCallback receives the `Revoke` whose promise must be completed to unblock
 */
private[internal] def createBlockingPartitionAssignmentHandler(
    consumerSettings: ConsumerSettings[_, _],
    revokeCallback: Revoke => Unit
): PartitionAssignmentHandler =
  new PartitionAssignmentHandler {
    // Hands the partitions to the callback and waits for the revoke to be completed.
    private def revokeAndAwait(tps: Set[TopicPartition]): Unit = {
      val completion = Promise[Done]()
      revokeCallback(Revoke(tps, completion))
      Await.result(completion.future, consumerSettings.commitTimeout * 2)
    }

    override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()

    // This is invoked in the KafkaConsumerActor thread when doing poll.
    override def onRevoke(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
      if (revokedTps.nonEmpty) revokeAndAwait(revokedTps)

    // Lost partitions are handled exactly like revoked ones.
    override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
      onRevoke(lostTps, consumer)

    override def onStop(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()
  }
}

@InternalApi
Expand Down Expand Up @@ -416,31 +453,34 @@ private final class TransactionalSubSourceStageLogic[K, V](

private def drainHandling: PartialFunction[(ActorRef, Any), Unit] = {
case (sender, Committed(offsets)) =>
inFlightRecords.committed(offsets.iterator.map { case (k, v) => k -> (v.offset() - 1L) }.toMap)
inFlightRecords.committed(offsets.view.mapValues(v => v.offset() - 1L).toMap)
sender ! Done
case (sender, CommittingFailure) => {
case (_, CommittingFailure) =>
log.info("Committing failed, resetting in flight offsets")
inFlightRecords.reset()
}

case (sender, Drain(partitions, ack, msg)) =>
if (inFlightRecords.empty(partitions)) {
log.debug(s"Partitions drained ${partitions.mkString(",")}")
if (log.isDebugEnabled)
log.debug(s"Partitions drained [{}]", partitions.mkString(","))
ack.getOrElse(sender) ! msg
} else {
log.debug(s"Draining partitions {}", partitions)
if (log.isDebugEnabled)
log.debug(s"Draining partitions [{}]", partitions.mkString(","))
materializer.scheduleOnce(
consumerSettings.drainingCheckInterval,
() => subSourceActor.ref.tell(Drain(partitions, ack.orElse(Some(sender)), msg), stageActor.ref)
)
}
case (sender, DrainingComplete) =>
case (_, DrainingComplete) =>
completeStage()
}

override def requestConsumerGroupMetadata(): Future[ConsumerGroupMetadata] = {
implicit val timeout: Timeout = 5.seconds // FIXME specific timeout config for this?
akka.pattern
.ask(consumerActor, KafkaConsumerActor.Internal.GetConsumerGroupMetadata)(timeout)
def requestConsumerGroupMetadata(): Future[ConsumerGroupMetadata] = {
// use some sensible existing timeout setting for this consumer actor ask
implicit val timeout: Timeout = consumerSettings.metadataRequestTimeout
consumerActor
.ask(KafkaConsumerActor.Internal.GetConsumerGroupMetadata)
.mapTo[ConsumerGroupMetadata]
}

Expand Down