From cfd4ea8923a347ee15debe1d80b945f34c2f05b9 Mon Sep 17 00:00:00 2001 From: jules Ivanic Date: Sat, 17 Jun 2023 18:43:56 +0400 Subject: [PATCH] Starting a new consumption session with a subscription which is invalid with the previous ones should fail with `InvalidSubscriptionUnion` error --- .../zio/kafka/consumer/internal/Runloop.scala | 19 +++++++++++++------ .../consumer/internal/RunloopAccess.scala | 12 ++++++------ .../consumer/internal/RunloopCommand.scala | 8 ++++++-- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala index a42688aa5d..ba7a10552d 100644 --- a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala +++ b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala @@ -51,8 +51,15 @@ private[consumer] final class Runloop private ( ) .unit - private[internal] def addSubscription(subscription: Subscription): UIO[Unit] = - commandQueue.offer(RunloopCommand.AddSubscription(subscription)).unit + private[internal] def addSubscription(subscription: Subscription): IO[InvalidSubscriptionUnion, Unit] = + for { + _ <- ZIO.logDebug(s"Add subscription $subscription") + promise <- Promise.make[InvalidSubscriptionUnion, Unit] + _ <- commandQueue.offer(RunloopCommand.AddSubscription(subscription, promise)) + _ <- ZIO.logDebug(s"Waiting for subscription $subscription") + _ <- promise.await + _ <- ZIO.logDebug(s"Done for subscription $subscription") + } yield () private[internal] def removeSubscription(subscription: Subscription): UIO[Unit] = commandQueue.offer(RunloopCommand.RemoveSubscription(subscription)).unit @@ -401,24 +408,24 @@ private[consumer] final class Runloop private ( cmd match { case req: RunloopCommand.Request => ZIO.succeed(state.addRequest(req)) case cmd: RunloopCommand.Commit => doCommit(cmd).as(state.addCommit(cmd)) - case RunloopCommand.AddSubscription(newSubscription) => + case cmd @ RunloopCommand.AddSubscription(newSubscription, _) => state.subscriptionState match { case SubscriptionState.NotSubscribed => val newSubState = SubscriptionState.Subscribed(subscriptions = Set(newSubscription), union = newSubscription) - doChangeSubscription(newSubState) + cmd.succeed *> doChangeSubscription(newSubState) case SubscriptionState.Subscribed(existingSubscriptions, _) => val subs = NonEmptyChunk.fromIterable(newSubscription, existingSubscriptions) Subscription.unionAll(subs) match { - case None => ZIO.fail(InvalidSubscriptionUnion(subs)) + case None => cmd.fail(InvalidSubscriptionUnion(subs)).as(state) case Some(union) => val newSubState = SubscriptionState.Subscribed( subscriptions = existingSubscriptions + newSubscription, union = union ) - doChangeSubscription(newSubState) + cmd.succeed *> doChangeSubscription(newSubState) } } case RunloopCommand.RemoveSubscription(subscription) => diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/RunloopAccess.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/RunloopAccess.scala index b54e72e2f8..e5f8399c15 100644 --- a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/RunloopAccess.scala +++ b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/RunloopAccess.scala @@ -5,9 +5,9 @@ import zio.kafka.consumer.diagnostics.DiagnosticEvent.Finalization import zio.kafka.consumer.diagnostics.Diagnostics import zio.kafka.consumer.internal.Runloop.ByteArrayCommittableRecord import zio.kafka.consumer.internal.RunloopAccess.PartitionAssignment -import zio.kafka.consumer.{ ConsumerSettings, Subscription } +import zio.kafka.consumer.{ ConsumerSettings, InvalidSubscriptionUnion, Subscription } import zio.stream.{ Stream, Take, UStream, ZStream } -import zio.{ durationInt, Hub, Ref, Scope, UIO, ZIO, ZLayer } +import zio.{ durationInt, Hub, IO, Ref, Scope, UIO, ZIO, ZLayer } private[internal] sealed trait RunloopState private[internal] object RunloopState { @@ -33,10 +33,10 @@ private[consumer] final class RunloopAccess private ( runloopStateRef.updateSomeAndGetZIO { case RunloopState.NotStarted if shouldStartIfNot => makeRunloop.map(RunloopState.Started.apply) } - private def withRunloopZIO[A](shouldStartIfNot: Boolean)(f: Runloop => UIO[A]): UIO[A] = + private def withRunloopZIO[E, A](shouldStartIfNot: Boolean)(f: Runloop => IO[E, A]): IO[E, A] = runloop(shouldStartIfNot).flatMap { - case RunloopState.Stopped => ZIO.unit.asInstanceOf[UIO[A]] - case RunloopState.NotStarted => ZIO.unit.asInstanceOf[UIO[A]] + case RunloopState.Stopped => ZIO.unit.asInstanceOf[IO[E, A]] + case RunloopState.NotStarted => ZIO.unit.asInstanceOf[IO[E, A]] case RunloopState.Started(runloop) => f(runloop) } @@ -71,7 +71,7 @@ private[consumer] final class RunloopAccess private ( */ def subscribe( subscription: Subscription - ): ZIO[Scope, Throwable, UStream[Take[Throwable, PartitionAssignment]]] = + ): ZIO[Scope, InvalidSubscriptionUnion, UStream[Take[Throwable, PartitionAssignment]]] = for { stream <- ZStream.fromHubScoped(partitionHub) // starts the Runloop if not already started diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/RunloopCommand.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/RunloopCommand.scala index 4bfc575eef..f737c58641 100644 --- a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/RunloopCommand.scala +++ b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/RunloopCommand.scala @@ -2,7 +2,7 @@ package zio.kafka.consumer.internal import org.apache.kafka.common.TopicPartition import zio._ -import zio.kafka.consumer.Subscription +import zio.kafka.consumer.{ InvalidSubscriptionUnion, Subscription } sealed trait RunloopCommand object RunloopCommand { @@ -27,7 +27,11 @@ object RunloopCommand { /** Used by a stream to request more records. */ final case class Request(tp: TopicPartition) extends StreamCommand - final case class AddSubscription(subscription: Subscription) extends StreamCommand + final case class AddSubscription(subscription: Subscription, cont: Promise[InvalidSubscriptionUnion, Unit]) + extends StreamCommand { + @inline def succeed: UIO[Unit] = cont.succeed(()).unit + @inline def fail(e: InvalidSubscriptionUnion): UIO[Unit] = cont.fail(e).unit + } final case class RemoveSubscription(subscription: Subscription) extends StreamCommand case object RemoveAllSubscriptions extends StreamCommand }