From 2182903bcf4c716ff89bc0a5192f67ea2006c496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Andr=C3=A9n?= Date: Thu, 4 Apr 2024 17:02:20 +0200 Subject: [PATCH] wip: Make the PartitionAssignmentHandler thread safe Not passing the TransactionsSourceSpec though, I think something shuts down too early --- .../internal/SourceLogicSubscription.scala | 2 + .../akka/kafka/internal/SubSourceLogic.scala | 2 + .../internal/TransactionalProducerStage.scala | 4 + .../kafka/internal/TransactionalSources.scala | 228 ++++++++++-------- 4 files changed, 142 insertions(+), 94 deletions(-) diff --git a/core/src/main/scala/akka/kafka/internal/SourceLogicSubscription.scala b/core/src/main/scala/akka/kafka/internal/SourceLogicSubscription.scala index 74f3bfc5b..a4cb979c4 100644 --- a/core/src/main/scala/akka/kafka/internal/SourceLogicSubscription.scala +++ b/core/src/main/scala/akka/kafka/internal/SourceLogicSubscription.scala @@ -68,6 +68,8 @@ private[kafka] trait SourceLogicSubscription { /** * Opportunity for subclasses to add a different logic to the partition assignment callbacks. 
+ * + * Note: called from consumer actor, returned handler must be thread safe */ protected def addToPartitionAssignmentHandler(handler: PartitionAssignmentHandler): PartitionAssignmentHandler = handler diff --git a/core/src/main/scala/akka/kafka/internal/SubSourceLogic.scala b/core/src/main/scala/akka/kafka/internal/SubSourceLogic.scala index 64d8bc8e7..67c365880 100644 --- a/core/src/main/scala/akka/kafka/internal/SubSourceLogic.scala +++ b/core/src/main/scala/akka/kafka/internal/SubSourceLogic.scala @@ -305,12 +305,14 @@ private class SubSourceLogic[K, V, Msg]( override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = for { tp <- lastRevoked -- assignedTps + // FIXME subSources is mutable internal state of logic, this is not thread safe control <- subSources.get(tp) } control.filterRevokedPartitionsCB.invoke(Set(tp)) override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = for { tp <- lostTps + // FIXME subSources is mutable internal state of logic, this is not thread safe control <- subSources.get(tp) } control.filterRevokedPartitionsCB.invoke(Set(tp)) diff --git a/core/src/main/scala/akka/kafka/internal/TransactionalProducerStage.scala b/core/src/main/scala/akka/kafka/internal/TransactionalProducerStage.scala index 2c952d1d5..09b3d315a 100644 --- a/core/src/main/scala/akka/kafka/internal/TransactionalProducerStage.scala +++ b/core/src/main/scala/akka/kafka/internal/TransactionalProducerStage.scala @@ -257,6 +257,10 @@ private final class TransactionalProducerStageLogic[K, V, P]( override def onCompletionFailure(ex: Throwable): Unit = { abortTransaction(s"Stage failure ($ex)") + if (commitInProgress) + log.warning( + "Stage onCompleteFailure with commit in flight" + ) batchOffsets.committingFailed() super.onCompletionFailure(ex) } diff --git a/core/src/main/scala/akka/kafka/internal/TransactionalSources.scala b/core/src/main/scala/akka/kafka/internal/TransactionalSources.scala index 
c741cbc78..958f7c0c1 100644 --- a/core/src/main/scala/akka/kafka/internal/TransactionalSources.scala +++ b/core/src/main/scala/akka/kafka/internal/TransactionalSources.scala @@ -11,6 +11,7 @@ import akka.actor.{ActorRef, Status, Terminated} import akka.actor.Status.Failure import akka.annotation.InternalApi import akka.dispatch.ExecutionContexts +import akka.pattern.ask import akka.kafka.ConsumerMessage.{PartitionOffset, TransactionalMessage} import akka.kafka.internal.KafkaConsumerActor.Internal.Revoked import akka.kafka.internal.SubSourceLogic._ @@ -25,8 +26,9 @@ import akka.util.Timeout import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerGroupMetadata, ConsumerRecord, OffsetAndMetadata} import org.apache.kafka.common.{IsolationLevel, TopicPartition} -import scala.concurrent.duration.{DurationInt, FiniteDuration} -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, Future, Promise, TimeoutException} +import scala.util.{Success, Try} /** Internal API */ @InternalApi @@ -95,13 +97,16 @@ private[internal] abstract class TransactionalSourceLogic[K, V, Msg](shape: Sour override protected def logSource: Class[_] = classOf[TransactionalSourceLogic[_, _, _]] private val inFlightRecords = InFlightRecords.empty + private val onRevokeCB = getAsyncCallback[Revoke](onRevoke).invoke _ + private val onRevokeDrainDoneCB = getAsyncCallback[(Revoke, Try[RevokeDrainDone.type])](onRevokeDrainDone).invoke _ - override def messageHandling = super.messageHandling.orElse(drainHandling).orElse { - case (_, Revoked(tps)) => - inFlightRecords.revoke(tps.toSet) - } + override def messageHandling: PartialFunction[(ActorRef, Any), Unit] = + super.messageHandling.orElse(drainHandling).orElse { + case (_, Revoked(tps)) => + inFlightRecords.revoke(tps.toSet) + } - override def shuttingDownReceive = + override def shuttingDownReceive: PartialFunction[(ActorRef, Any), Unit] = 
super.shuttingDownReceive .orElse(drainHandling) .orElse { @@ -113,18 +118,19 @@ private[internal] abstract class TransactionalSourceLogic[K, V, Msg](shape: Sour private def drainHandling: PartialFunction[(ActorRef, Any), Unit] = { case (sender, Committed(offsets)) => - inFlightRecords.committed(offsets.iterator.map { case (k, v) => k -> (v.offset() - 1L) }.toMap) + inFlightRecords.committed(offsets.view.mapValues(v => v.offset() - 1L).toMap) sender.tell(Done, sourceActor.ref) - case (_, CommittingFailure) => { + case (_, CommittingFailure) => log.info("Committing failed, resetting in flight offsets") inFlightRecords.reset() - } case (sender, Drain(partitions, ack, msg)) => if (inFlightRecords.empty(partitions)) { - log.debug(s"Partitions drained ${partitions.mkString(",")}") + if (log.isDebugEnabled) + log.debug(s"Partitions drained [{}]", partitions.mkString(",")) ack.getOrElse(sender).tell(msg, sourceActor.ref) } else { - log.debug(s"Draining partitions {}", partitions) + if (log.isDebugEnabled) + log.debug(s"Draining partitions [{}]", partitions.mkString(", ")) materializer.scheduleOnce( consumerSettings.drainingCheckInterval, () => sourceActor.ref.tell(Drain(partitions, ack.orElse(Some(sender)), msg), sourceActor.ref) @@ -151,45 +157,45 @@ private[internal] abstract class TransactionalSourceLogic[K, V, Msg](shape: Sour override protected def addToPartitionAssignmentHandler( handler: PartitionAssignmentHandler - ): PartitionAssignmentHandler = { - // FIXME this touches mutable internal stage fields (sourceActor, stageActor, consumerActor, subSources) from - // another thread (consumer actor) not thread safe - val blockingRevokedCall = new PartitionAssignmentHandler { - override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = () - - // This is invoked in the KafkaConsumerActor thread when doing poll. 
- override def onRevoke(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = - if (waitForDraining(revokedTps)) { - sourceActor.ref.tell(Revoked(revokedTps.toList), consumerActor) - } else { - sourceActor.ref.tell(Failure(new Error("Timeout while draining")), consumerActor) - consumerActor.tell(KafkaConsumerActor.Internal.StopFromStage(id), consumerActor) - } - - override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = - onRevoke(lostTps, consumer) - - override def onStop(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = () + ): PartitionAssignmentHandler = + new PartitionAssignmentHelpers.Chain(handler, + createBlockingPartitionAssignmentHandler(consumerSettings, onRevokeCB)) + + def onRevoke(revoke: Revoke): Unit = { + // Tricky chain of async interactions - draining is a timed async wait and both steps + // need to interact with stage internal mutable state, and finally complete or fail a promise + // whose future the blocking partition assignment handler blocks the consumer on. 
+ // Simplifying is tricky since other logic depends on message-send-drain + if (log.isDebugEnabled) + log.debug("onRevoke [{}]", revoke.partitions.mkString(",")) + stageActor.ref + .ask(Drain(revoke.partitions, None, Drained))(consumerSettings.commitTimeout) + .transform(tryDrain => Success((revoke, tryDrain.map(_ => RevokeDrainDone))))(ExecutionContexts.parasitic) + .foreach(onRevokeDrainDoneCB)(ExecutionContexts.parasitic) + } - private def waitForDraining(partitions: Set[TopicPartition]): Boolean = { - import akka.pattern.ask - implicit val timeout = Timeout(consumerSettings.commitTimeout) - try { - Await.result(ask(stageActor.ref, Drain(partitions, None, Drained)), timeout.duration) - true - } catch { - case t: Throwable => - false - } - } + def onRevokeDrainDone(revokeDrainDone: (Revoke, Try[RevokeDrainDone.type])): Unit = { + revokeDrainDone match { + case (revoke, scala.util.Success(RevokeDrainDone)) => + if (log.isDebugEnabled) + log.debug("onRevokeDrainDone [{}]", revoke.partitions.mkString(",")) + inFlightRecords.revoke(revoke.partitions) + revoke.revokeCompletion.success(Done) + case (revoke, scala.util.Failure(ex)) => + if (log.isDebugEnabled) + log.debug("onRevokeDrainDone, drain failed [{}] ({})", revoke.partitions.mkString(","), ex) + stageActor.ref.tell(Failure(new TimeoutException(s"Timeout while draining ($ex)")), consumerActor) + consumerActor.tell(KafkaConsumerActor.Internal.StopFromStage(id), consumerActor) + // we don't signal failure back, just completion + revoke.revokeCompletion.success(Done) } - new PartitionAssignmentHelpers.Chain(handler, blockingRevokedCall) } - override def requestConsumerGroupMetadata(): Future[ConsumerGroupMetadata] = { - import akka.pattern.ask - implicit val timeout: Timeout = 5.seconds // FIXME specific timeout config for this? 
- ask(consumerActor, KafkaConsumerActor.Internal.GetConsumerGroupMetadata)(timeout) + def requestConsumerGroupMetadata(): Future[ConsumerGroupMetadata] = { + // use some sensible existing timeout setting for this consumer actor ask + implicit val timeout: Timeout = consumerSettings.metadataRequestTimeout + consumerActor + .ask(KafkaConsumerActor.Internal.GetConsumerGroupMetadata) .mapTo[ConsumerGroupMetadata] } } @@ -241,49 +247,56 @@ private[kafka] final class TransactionalSubSource[K, V]( new SubSourceLogic(shape, txConsumerSettings, subscription, subSourceStageLogicFactory = factory) { + private val onRevokeCB = getAsyncCallback[Revoke](onRevoke).invoke _ + private val onRevokeDrainDoneCB = + getAsyncCallback[(Revoke, Try[RevokeDrainDone.type])](onRevokeDrainDone).invoke _ + override protected def addToPartitionAssignmentHandler( handler: PartitionAssignmentHandler - ): PartitionAssignmentHandler = { - // FIXME this touches mutable internal stage fields (sourceActor, stageActor, consumerActor, subSources) from - // another thread (consumer actor) not thread safe - val blockingRevokedCall = new PartitionAssignmentHandler { - override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = () - - // This is invoked in the KafkaConsumerActor thread when doing poll. 
- override def onRevoke(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = - if (revokedTps.isEmpty) () - else if (waitForDraining(revokedTps)) { - subSources.values - .map(_.controlAndStageActor.stageActor) - .foreach(_.tell(Revoked(revokedTps.toList), stageActor.ref)) - } else { - sourceActor.ref.tell(Status.Failure(new Error("Timeout while draining")), stageActor.ref) - consumerActor.tell(KafkaConsumerActor.Internal.StopFromStage(id), stageActor.ref) - } - - override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = - onRevoke(lostTps, consumer) - - override def onStop(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = () - } - new PartitionAssignmentHelpers.Chain(handler, blockingRevokedCall) + ): PartitionAssignmentHandler = + new PartitionAssignmentHelpers.Chain( + handler, + TransactionalSourceLogic.createBlockingPartitionAssignmentHandler(consumerSettings, onRevokeCB) + ) + + def onRevoke(revoke: Revoke): Unit = { + // Tricky chain of async interactions - draining is a timed async wait and both steps + // need to interact with stage internal mutable state, and finally complete or fail a promise + // whose future the blocking partition assignment handler blocks the consumer on. 
+ // Simplifying is tricky since other logic depends on message-send-drain + implicit val timeout: Timeout = Timeout(txConsumerSettings.commitTimeout) + implicit val ec: ExecutionContext = materializer.executionContext + if (log.isDebugEnabled) + log.debug("onRevoke [{}]", revoke.partitions.mkString(",")) + val drainCommandFutures = + Future.sequence(subSources.values.map(_.stageActor.ask(Drain(revoke.partitions, None, Drained)))) + drainCommandFutures + .transform(tryDrainAll => Success((revoke, tryDrainAll.map(_ => RevokeDrainDone))))( + ExecutionContexts.parasitic + ) + .foreach(onRevokeDrainDoneCB) } - private def waitForDraining(partitions: Set[TopicPartition]): Boolean = { - import akka.pattern.ask - implicit val timeout = Timeout(txConsumerSettings.commitTimeout) - try { - val drainCommandFutures = - subSources.values.map(_.stageActor).map(ask(_, Drain(partitions, None, Drained))) - implicit val ec = executionContext - Await.result(Future.sequence(drainCommandFutures), timeout.duration) - true - } catch { - case t: Throwable => - false + def onRevokeDrainDone(revokeDrainDone: (Revoke, Try[RevokeDrainDone.type])): Unit = { + revokeDrainDone match { + case (revoke, scala.util.Success(RevokeDrainDone)) => + if (log.isDebugEnabled) + log.debug("onRevokeDrainDone [{}]", revoke.partitions.mkString(",")) + subSources.values + .map(_.controlAndStageActor.stageActor) + .foreach(_.tell(Revoked(revoke.partitions.toList), stageActor.ref)) + revoke.revokeCompletion.success(Done) + case (revoke, scala.util.Failure(ex)) => + if (log.isDebugEnabled) + log.debug("onRevokeDrainDone, drain failed [{}] ({})", revoke.partitions.mkString(","), ex) + sourceActor.ref.tell(Status.Failure(new TimeoutException("Timeout while draining")), stageActor.ref) + consumerActor.tell(KafkaConsumerActor.Internal.StopFromStage(id), stageActor.ref) + // we don't signal failure back, just completion + revoke.revokeCompletion.success(Done) } } } + } } @@ -299,10 +312,12 @@ private object 
TransactionalSourceLogic { final case class Committed(offsets: Map[TopicPartition, OffsetAndMetadata]) case object CommittingFailure + final case class Revoke(partitions: Set[TopicPartition], revokeCompletion: Promise[Done]) + case object RevokeDrainDone + private[internal] final case class CommittedMarkerRef(sourceActor: ActorRef, commitTimeout: FiniteDuration) extends CommittedMarker { override def committed(offsets: Map[TopicPartition, OffsetAndMetadata]): Future[Done] = { - import akka.pattern.ask sourceActor .ask(Committed(offsets))(Timeout(commitTimeout)) .map(_ => Done)(ExecutionContexts.parasitic) @@ -352,6 +367,28 @@ private object TransactionalSourceLogic { override def assigned(): Set[TopicPartition] = inFlightRecords.keySet } } + + private[internal] def createBlockingPartitionAssignmentHandler( + consumerSettings: ConsumerSettings[_, _], + revokeCallback: Revoke => Unit + ): PartitionAssignmentHandler = + new PartitionAssignmentHandler { + override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = () + + // This is invoked in the KafkaConsumerActor thread when doing poll. 
+ override def onRevoke(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = { + if (revokedTps.nonEmpty) { + val revokeDone = Promise[Done]() + revokeCallback(Revoke(revokedTps, revokeDone)) + Await.result(revokeDone.future, consumerSettings.commitTimeout * 2) + } + } + + override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = + onRevoke(lostTps, consumer) + + override def onStop(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = () + } } @InternalApi @@ -416,31 +453,34 @@ private final class TransactionalSubSourceStageLogic[K, V]( private def drainHandling: PartialFunction[(ActorRef, Any), Unit] = { case (sender, Committed(offsets)) => - inFlightRecords.committed(offsets.iterator.map { case (k, v) => k -> (v.offset() - 1L) }.toMap) + inFlightRecords.committed(offsets.view.mapValues(v => v.offset() - 1L).toMap) sender ! Done - case (sender, CommittingFailure) => { + case (_, CommittingFailure) => log.info("Committing failed, resetting in flight offsets") inFlightRecords.reset() - } + case (sender, Drain(partitions, ack, msg)) => if (inFlightRecords.empty(partitions)) { - log.debug(s"Partitions drained ${partitions.mkString(",")}") + if (log.isDebugEnabled) + log.debug(s"Partitions drained [{}]", partitions.mkString(",")) ack.getOrElse(sender) ! msg } else { - log.debug(s"Draining partitions {}", partitions) + if (log.isDebugEnabled) + log.debug(s"Draining partitions [{}]", partitions.mkString(",")) materializer.scheduleOnce( consumerSettings.drainingCheckInterval, () => subSourceActor.ref.tell(Drain(partitions, ack.orElse(Some(sender)), msg), stageActor.ref) ) } - case (sender, DrainingComplete) => + case (_, DrainingComplete) => completeStage() } - override def requestConsumerGroupMetadata(): Future[ConsumerGroupMetadata] = { - implicit val timeout: Timeout = 5.seconds // FIXME specific timeout config for this? 
- akka.pattern - .ask(consumerActor, KafkaConsumerActor.Internal.GetConsumerGroupMetadata)(timeout) + def requestConsumerGroupMetadata(): Future[ConsumerGroupMetadata] = { + // use some sensible existing timeout setting for this consumer actor ask + implicit val timeout: Timeout = consumerSettings.metadataRequestTimeout + consumerActor + .ask(KafkaConsumerActor.Internal.GetConsumerGroupMetadata) .mapTo[ConsumerGroupMetadata] }