diff --git a/zio-kafka-test/src/test/scala/zio/kafka/consumer/ConsumerSpec.scala b/zio-kafka-test/src/test/scala/zio/kafka/consumer/ConsumerSpec.scala
index ca7a47495..8983813be 100644
--- a/zio-kafka-test/src/test/scala/zio/kafka/consumer/ConsumerSpec.scala
+++ b/zio-kafka-test/src/test/scala/zio/kafka/consumer/ConsumerSpec.scala
@@ -1349,96 +1349,68 @@ object ConsumerSpec extends ZIOSpecDefaultSlf4j with KafkaRandom {
         val allMessages = (1 to messageCount).map(i => s"$i" -> f"msg$i%06d")
         val (messagesBeforeRebalance, messagesAfterRebalance) = allMessages.splitAt(messageCount / 2)
 
-        def transactionalRebalanceListener(streamCompleteOnRebalanceRef: Ref[Option[Promise[Nothing, Unit]]]) =
-          RebalanceListener(
-            onAssigned = _ => ZIO.unit,
-            onRevoked = _ =>
-              streamCompleteOnRebalanceRef.get.flatMap {
-                case Some(p) =>
-                  ZIO.logDebug("onRevoked, awaiting stream completion") *>
-                    p.await.timeoutFail(new InterruptedException("Timed out waiting stream to complete"))(1.minute)
-                case None => ZIO.unit
-              },
-            onLost = _ => ZIO.logDebug("Lost some partitions")
-          )
-
         def makeCopyingTransactionalConsumer(
           name: String,
           consumerGroupId: String,
           clientId: String,
           fromTopic: String,
           toTopic: String,
-          tProducer: TransactionalProducer,
-          consumerCreated: Promise[Nothing, Unit]
+          consumerCreated: Promise[Throwable, Unit]
         ): ZIO[Kafka, Throwable, Unit] =
           ZIO.logAnnotate("consumer", name) {
-            for {
-              consumedMessagesCounter <- Ref.make(0)
-              _ <- consumedMessagesCounter.get
-                     .flatMap(consumed => ZIO.logDebug(s"Consumed so far: $consumed"))
-                     .repeat(Schedule.fixed(1.second))
-                     .fork
-              streamCompleteOnRebalanceRef <- Ref.make[Option[Promise[Nothing, Unit]]](None)
-              tConsumer <-
-                Consumer
-                  .partitionedAssignmentStream(Subscription.topics(fromTopic), Serde.string, Serde.string)
-                  .mapZIO { assignedPartitions =>
-                    for {
-                      p <- Promise.make[Nothing, Unit]
-                      _ <- streamCompleteOnRebalanceRef.set(Some(p))
-                      _ <- ZIO.logDebug(s"${assignedPartitions.size} partitions assigned")
-                      _ <- consumerCreated.succeed(())
-                      partitionStreams = assignedPartitions.map(_._2)
-                      s <- ZStream
-                             .mergeAllUnbounded(64)(partitionStreams: _*)
-                             .mapChunksZIO { records =>
-                               ZIO.scoped {
-                                 for {
-                                   t <- tProducer.createTransaction
-                                   _ <- t.produceChunkBatch(
-                                          records.map(r => new ProducerRecord(toTopic, r.key, r.value)),
-                                          Serde.string,
-                                          Serde.string,
-                                          OffsetBatch(records.map(_.offset))
-                                        )
-                                   _ <- consumedMessagesCounter.update(_ + records.size)
-                                 } yield Chunk.empty
-                               }.uninterruptible
-                             }
-                             .runDrain
-                             .ensuring {
-                               for {
-                                 _ <- streamCompleteOnRebalanceRef.set(None)
-                                 _ <- p.succeed(())
-                                 c <- consumedMessagesCounter.get
-                                 _ <- ZIO.logDebug(s"Consumed $c messages")
-                               } yield ()
-                             }
-                    } yield s
-                  }
-                  .runDrain
-                  .provideSome[Kafka](
-                    transactionalConsumer(
-                      clientId,
-                      consumerGroupId,
-                      restartStreamOnRebalancing = true,
-                      properties = Map(
-                        ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG ->
-                          implicitly[ClassTag[T]].runtimeClass.getName,
-                        ConsumerConfig.MAX_POLL_RECORDS_CONFIG -> "200"
-                      ),
-                      rebalanceListener = transactionalRebalanceListener(streamCompleteOnRebalanceRef)
+            ZIO.scoped {
+              (for {
+                consumer                <- ZIO.service[Consumer]
+                consumedMessagesCounter <- Ref.make(0)
+                _ <- consumedMessagesCounter.get
+                       .flatMap(consumed => ZIO.logDebug(s"Consumed so far: $consumed"))
+                       .repeat(Schedule.fixed(1.second))
+                       .fork
+
+                transactionalId   <- randomThing("transactional")
+                tProducerSettings <- transactionalProducerSettings(transactionalId)
+                tProducer <-
+                  TransactionalProducer.make(tProducerSettings, consumer)
+
+                tConsumer <-
+                  consumer
+                    .partitionedStream(Subscription.topics(fromTopic), Serde.string, Serde.string)
+                    .flatMapPar(Int.MaxValue) { case (_, partitionStream) =>
+                      ZStream.fromZIO(consumerCreated.succeed(())) *>
+                        partitionStream.mapChunksZIO { records =>
+                          ZIO.scoped {
+                            for {
+                              t <- tProducer.createTransaction
+                              _ <- t.produceChunkBatch(
+                                     records.map(r => new ProducerRecord(toTopic, r.key, r.value)),
+                                     Serde.string,
+                                     Serde.string,
+                                     OffsetBatch(records.map(_.offset))
+                                   )
+                              _ <- consumedMessagesCounter.update(_ + records.size)
+                            } yield Chunk.empty
+                          }
+                        }
+                    }
+                    .runDrain
+                    .tapError(e => ZIO.logError(s"Error: $e") *> consumerCreated.fail(e)) <* ZIO.logDebug("Done")
+              } yield tConsumer)
+                .provideSome[Kafka & Scope](
+                  transactionalConsumer(
+                    clientId,
+                    consumerGroupId,
+                    rebalanceSafeCommits = true,
+                    properties = Map(
+                      ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG ->
+                        implicitly[ClassTag[T]].runtimeClass.getName,
+                      ConsumerConfig.MAX_POLL_RECORDS_CONFIG -> "200"
                     )
                   )
-                  .tapError(e => ZIO.logError(s"Error: $e")) <* ZIO.logDebug("Done")
-            } yield tConsumer
+                )
+            }
           }
 
         for {
-          transactionalId   <- randomThing("transactional")
-          tProducerSettings <- transactionalProducerSettings(transactionalId)
-          tProducer         <- TransactionalProducer.make(tProducerSettings)
-
           topicA <- randomTopic
           topicB <- randomTopic
           _      <- ZIO.attempt(EmbeddedKafka.createCustomTopic(topicA, partitions = partitionCount))
@@ -1450,28 +1422,26 @@ object ConsumerSpec extends ZIOSpecDefaultSlf4j with KafkaRandom {
 
           _               <- ZIO.logDebug("Starting copier 1")
           copier1ClientId = copyingGroup + "-1"
-          copier1Created <- Promise.make[Nothing, Unit]
+          copier1Created <- Promise.make[Throwable, Unit]
           copier1 <- makeCopyingTransactionalConsumer(
                        "1",
                        copyingGroup,
                        copier1ClientId,
                        topicA,
                        topicB,
-                       tProducer,
                        copier1Created
                      ).fork
           _ <- copier1Created.await
 
           _               <- ZIO.logDebug("Starting copier 2")
           copier2ClientId = copyingGroup + "-2"
-          copier2Created <- Promise.make[Nothing, Unit]
+          copier2Created <- Promise.make[Throwable, Unit]
           copier2 <- makeCopyingTransactionalConsumer(
                        "2",
                        copyingGroup,
                        copier2ClientId,
                        topicA,
                        topicB,
-                       tProducer,
                        copier2Created
                      ).fork
           _ <- ZIO.logDebug("Waiting for copier 2 to start")
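
Note on the test change above: each copier now constructs its own TransactionalProducer from the Consumer it wraps, instead of sharing one producer created outside the consumers. A minimal sketch of that wiring, using only the new two-argument `make` from this diff (the helper name `scopedTransactionalProducer` is illustrative, not part of the change):

import zio._
import zio.kafka.consumer.Consumer
import zio.kafka.producer.{ TransactionalProducer, TransactionalProducerSettings }

// Build a producer that is tied to the consumer whose offsets it will commit.
def scopedTransactionalProducer(
  tSettings: TransactionalProducerSettings
): ZIO[Consumer & Scope, Throwable, TransactionalProducer] =
  for {
    consumer  <- ZIO.service[Consumer]                            // the consumer doing the polling
    tProducer <- TransactionalProducer.make(tSettings, consumer)  // new signature from this diff
  } yield tProducer
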
diff --git a/zio-kafka-testkit/src/main/scala/zio/kafka/testkit/KafkaTestUtils.scala b/zio-kafka-testkit/src/main/scala/zio/kafka/testkit/KafkaTestUtils.scala
index 9d6ad7a8a..aa2a926bb 100644
--- a/zio-kafka-testkit/src/main/scala/zio/kafka/testkit/KafkaTestUtils.scala
+++ b/zio-kafka-testkit/src/main/scala/zio/kafka/testkit/KafkaTestUtils.scala
@@ -53,11 +53,11 @@ object KafkaTestUtils {
    * Note: to run multiple tests in parallel, you need to use different transactional ids via
    * `transactionalProducer(transactionalId)`.
    */
-  val transactionalProducer: ZLayer[Kafka, Throwable, TransactionalProducer] =
+  val transactionalProducer: ZLayer[Kafka with Consumer, Throwable, TransactionalProducer] =
     transactionalProducer("test-transaction")
 
-  def transactionalProducer(transactionalId: String): ZLayer[Kafka, Throwable, TransactionalProducer] =
-    ZLayer.makeSome[Kafka, TransactionalProducer](
+  def transactionalProducer(transactionalId: String): ZLayer[Kafka with Consumer, Throwable, TransactionalProducer] =
+    ZLayer.makeSome[Kafka with Consumer, TransactionalProducer](
       ZLayer(transactionalProducerSettings(transactionalId)),
       TransactionalProducer.live
     )
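
Because the testkit layer now also requires a `Consumer`, call sites that previously provided only `Kafka` must put a consumer layer in the environment as well. A hedged sketch (the test effect `myTest` and the client/group ids are hypothetical; `transactionalConsumer` is the testkit helper already used in the spec above):

// `myTest` stands for any effect that uses the TransactionalProducer service.
myTest.provideSome[Kafka & Scope](
  transactionalConsumer("client-1", "group-1", rebalanceSafeCommits = true), // provides Consumer
  transactionalProducer                                                      // now needs Kafka with Consumer
)
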
diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/Consumer.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/Consumer.scala
index 3f24a6d48..e2fb26498 100644
--- a/zio-kafka/src/main/scala/zio/kafka/consumer/Consumer.scala
+++ b/zio-kafka/src/main/scala/zio/kafka/consumer/Consumer.scala
@@ -165,6 +165,9 @@ trait Consumer {
    * This method is useful when you want to use rebalance-safe-commits, but you are not committing to the Kafka
    * brokers, but to some external system, for example a relational database.
    *
+   * When this consumer is used in combination with a [[zio.kafka.producer.TransactionalProducer]], the transactional
+   * producer calls this method when the transaction is committed.
+   *
    * See also [[zio.kafka.consumer.ConsumerSettings.withRebalanceSafeCommits]].
    */
   def registerExternalCommits(offsetBatch: OffsetBatch): Task[Unit]
diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/ConsumerSettings.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/ConsumerSettings.scala
index 5db7221cd..922c9aa3e 100644
--- a/zio-kafka/src/main/scala/zio/kafka/consumer/ConsumerSettings.scala
+++ b/zio-kafka/src/main/scala/zio/kafka/consumer/ConsumerSettings.scala
@@ -234,6 +234,8 @@ final case class ConsumerSettings(
    * External commits (that is, commits to an external system, e.g. a relational database) must be registered to the
    * consumer with [[Consumer.registerExternalCommits]].
    *
+   * When this consumer is coupled to a TransactionalProducer, `rebalanceSafeCommits` must be enabled.
+   *
    * When `false`, streams for revoked partitions may continue to run even though the rebalance is not held up. Any
    * offset commits from these streams have a high chance of being delayed (commits are not possible during some phases
    * of a rebalance). The consumer that takes over the partition will likely not see these delayed commits and will
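
The two documentation additions above encode the contract: with `rebalanceSafeCommits` enabled, commits that bypass the consumer-group protocol must be reported back via `registerExternalCommits`, which is exactly what the transactional producer now does on commit. A minimal sketch of the external-system case the scaladoc mentions (`saveOffsetsToDatabase` is a hypothetical stand-in for a relational-database write):

import zio._
import zio.kafka.consumer.{ Consumer, ConsumerSettings, OffsetBatch }

// Hypothetical effect persisting offsets to an external store.
def saveOffsetsToDatabase(batch: OffsetBatch): Task[Unit] = ZIO.unit

val settings: ConsumerSettings =
  ConsumerSettings(List("localhost:9092"))
    .withGroupId("example-group")
    .withRebalanceSafeCommits(true) // must be enabled when coupled to a TransactionalProducer

def commitExternally(consumer: Consumer, batch: OffsetBatch): Task[Unit] =
  saveOffsetsToDatabase(batch) *>           // commit to the external system first
    consumer.registerExternalCommits(batch) // then tell the consumer about it
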
diff --git a/zio-kafka/src/main/scala/zio/kafka/producer/TransactionalProducer.scala b/zio-kafka/src/main/scala/zio/kafka/producer/TransactionalProducer.scala
index 1056bd29a..f9146c311 100644
--- a/zio-kafka/src/main/scala/zio/kafka/producer/TransactionalProducer.scala
+++ b/zio-kafka/src/main/scala/zio/kafka/producer/TransactionalProducer.scala
@@ -7,7 +7,7 @@ import org.apache.kafka.common.errors.InvalidGroupIdException
 import org.apache.kafka.common.serialization.ByteArraySerializer
 import zio.Cause.Fail
 import zio._
-import zio.kafka.consumer.OffsetBatch
+import zio.kafka.consumer.{ Consumer, OffsetBatch }
 
 import java.util
 import scala.jdk.CollectionConverters._
@@ -22,7 +22,8 @@ object TransactionalProducer {
 
   private final class LiveTransactionalProducer(
     live: ProducerLive,
-    semaphore: Semaphore
+    semaphore: Semaphore,
+    consumer: Consumer
   ) extends TransactionalProducer {
     private val abortTransaction: Task[Unit] = ZIO.attemptBlocking(live.p.abortTransaction())
 
@@ -48,10 +49,12 @@ object TransactionalProducer {
         }
       }
 
-      sendOffsetsToTransaction.when(offsetBatch.offsets.nonEmpty) *> ZIO.attemptBlocking(live.p.commitTransaction())
+      sendOffsetsToTransaction.when(offsetBatch.offsets.nonEmpty) *>
+        ZIO.attemptBlocking(live.p.commitTransaction()) *>
+        consumer.registerExternalCommits(offsetBatch).unit
     }
 
-    private def commitOrAbort(transaction: TransactionImpl, exit: Exit[Any, Any]): UIO[Unit] =
+    private def commitOrAbort(transaction: TransactionImpl, exit: Exit[Any, Any]): ZIO[Any, Nothing, Unit] =
       exit match {
         case Exit.Success(_) =>
           transaction.offsetBatchRef.get
@@ -74,15 +77,16 @@ object TransactionalProducer {
   def createTransaction: ZIO[TransactionalProducer & Scope, Throwable, Transaction] =
     ZIO.service[TransactionalProducer].flatMap(_.createTransaction)
 
-  val live: RLayer[TransactionalProducerSettings, TransactionalProducer] =
+  val live: RLayer[TransactionalProducerSettings with Consumer, TransactionalProducer] =
     ZLayer.scoped {
       for {
         settings <- ZIO.service[TransactionalProducerSettings]
-        producer <- make(settings)
+        consumer <- ZIO.service[Consumer]
+        producer <- make(settings, consumer)
       } yield producer
     }
 
-  def make(settings: TransactionalProducerSettings): ZIO[Scope, Throwable, TransactionalProducer] =
+  def make(settings: TransactionalProducerSettings, consumer: Consumer): ZIO[Scope, Throwable, TransactionalProducer] =
     for {
       rawProducer <- ZIO.acquireRelease(
                        ZIO.attempt(
@@ -102,5 +106,5 @@ object TransactionalProducer {
       )
       live = new ProducerLive(rawProducer, runtime, sendQueue)
       _ <- ZIO.blocking(live.sendFromQueue).forkScoped
-    } yield new LiveTransactionalProducer(live, semaphore)
+    } yield new LiveTransactionalProducer(live, semaphore, consumer)
 }
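
Taken together, the producer change closes the loop: `commitTransaction()` is now followed by `registerExternalCommits(offsetBatch)`, so the consumer's rebalance-safe-commit machinery learns about offsets that were committed through the transaction rather than through the consumer group. A sketch of the resulting consume-transform-produce loop, mirroring the updated test (the names `copyTopic`, `from` and `to` are illustrative):

import org.apache.kafka.clients.producer.ProducerRecord
import zio._
import zio.kafka.consumer.{ Consumer, OffsetBatch, Subscription }
import zio.kafka.producer.TransactionalProducer
import zio.kafka.serde.Serde

def copyTopic(
  consumer: Consumer,
  tProducer: TransactionalProducer, // built with make(settings, consumer)
  from: String,
  to: String
): Task[Unit] =
  consumer
    .partitionedStream(Subscription.topics(from), Serde.string, Serde.string)
    .flatMapPar(Int.MaxValue) { case (_, partitionStream) =>
      partitionStream.mapChunksZIO { records =>
        ZIO.scoped {
          for {
            tx <- tProducer.createTransaction // aborts or commits when the scope closes
            _  <- tx.produceChunkBatch(
                    records.map(r => new ProducerRecord(to, r.key, r.value)),
                    Serde.string,
                    Serde.string,
                    OffsetBatch(records.map(_.offset)) // offsets ride along in the transaction
                  )
          } yield Chunk.empty
        }
      }
    }
    .runDrain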