Skip to content

Commit

Permalink
Provide Consumer upon creation of TransactionalProducer
Browse files Browse the repository at this point in the history
  • Loading branch information
svroonland committed Jan 4, 2025
1 parent 901a536 commit 9426c99
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 62 deletions.
100 changes: 52 additions & 48 deletions zio-kafka-test/src/test/scala/zio/kafka/consumer/ConsumerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import zio.kafka.consumer.diagnostics.DiagnosticEvent.Finalization.{
SubscriptionFinalized
}
import zio.kafka.consumer.diagnostics.{ DiagnosticEvent, Diagnostics }
import zio.kafka.producer.{ Producer, TransactionalProducer }
import zio.kafka.producer.{ Producer, TransactionalProducer, TransactionalProducerSettings }
import zio.kafka.serde.Serde
import zio.kafka.testkit.KafkaTestUtils._
import zio.kafka.testkit.{ Kafka, KafkaRandom }
Expand Down Expand Up @@ -1240,62 +1240,66 @@ object ConsumerSpec extends ZIOSpecDefaultSlf4j with KafkaRandom {
clientId: String,
fromTopic: String,
toTopic: String,
tProducer: TransactionalProducer,
tProducerSettings: TransactionalProducerSettings,
consumerCreated: Promise[Nothing, Unit]
): ZIO[Kafka, Throwable, Unit] =
ZIO.logAnnotate("consumer", name) {
for {
consumedMessagesCounter <- Ref.make(0)
_ <- consumedMessagesCounter.get
.flatMap(consumed => ZIO.logDebug(s"Consumed so far: $consumed"))
.repeat(Schedule.fixed(1.second))
.fork
tConsumer <-
Consumer
.partitionedStream(Subscription.topics(fromTopic), Serde.string, Serde.string)
.flatMapPar(Int.MaxValue) { case (_, partitionStream) =>
ZStream.fromZIO(consumerCreated.succeed(())) *>
partitionStream.mapChunksZIO { records =>
ZIO.scoped {
for {
t <- tProducer.createTransaction
_ <- t.produceChunkBatch(
records.map(r => new ProducerRecord(toTopic, r.key, r.value)),
Serde.string,
Serde.string,
OffsetBatch(records.map(_.offset))
)
_ <- consumedMessagesCounter.update(_ + records.size)
} yield Chunk.empty
ZIO.scoped {
(for {
consumedMessagesCounter <- Ref.make(0)
_ <- consumedMessagesCounter.get
.flatMap(consumed => ZIO.logDebug(s"Consumed so far: $consumed"))
.repeat(Schedule.fixed(1.second))
.fork

tProducer <-
TransactionalProducer.make(tProducerSettings)

tConsumer <-
Consumer
.partitionedStream(Subscription.topics(fromTopic), Serde.string, Serde.string)
.flatMapPar(Int.MaxValue) { case (_, partitionStream) =>
ZStream.fromZIO(consumerCreated.succeed(())) *>
partitionStream.mapChunksZIO { records =>
ZIO.scoped {
for {
t <- tProducer.createTransaction
_ <- t.produceChunkBatch(
records.map(r => new ProducerRecord(toTopic, r.key, r.value)),
Serde.string,
Serde.string,
OffsetBatch(records.map(_.offset))
)
_ <- consumedMessagesCounter.update(_ + records.size)
} yield Chunk.empty
}
}
}
}
.runDrain
.provideSome[Kafka](
transactionalConsumer(
clientId,
consumerGroupId,
rebalanceSafeCommits = true,
properties = Map(
ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG ->
implicitly[ClassTag[T]].runtimeClass.getName,
ConsumerConfig.MAX_POLL_RECORDS_CONFIG -> "200"
)
}
.runDrain
.tapError(e => ZIO.logError(s"Error: $e")) <* ZIO.logDebug("Done")
} yield tConsumer)
.provideSome[Kafka & Scope](
transactionalConsumer(
clientId,
consumerGroupId,
rebalanceSafeCommits = true,
properties = Map(
ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG ->
implicitly[ClassTag[T]].runtimeClass.getName,
ConsumerConfig.MAX_POLL_RECORDS_CONFIG -> "200"
)
)
.tapError(e => ZIO.logError(s"Error: $e")) <* ZIO.logDebug("Done")
} yield tConsumer
)
}
}

for {
transactionalId <- randomThing("transactional")
tProducerSettings <- transactionalProducerSettings(transactionalId)
tProducer <- TransactionalProducer.make(tProducerSettings)

topicA <- randomTopic
topicB <- randomTopic
_ <- ZIO.attempt(EmbeddedKafka.createCustomTopic(topicA, partitions = partitionCount))
_ <- ZIO.attempt(EmbeddedKafka.createCustomTopic(topicB, partitions = partitionCount))
topicA <- randomTopic
topicB <- randomTopic
_ <- ZIO.attempt(EmbeddedKafka.createCustomTopic(topicA, partitions = partitionCount))
_ <- ZIO.attempt(EmbeddedKafka.createCustomTopic(topicB, partitions = partitionCount))

_ <- produceMany(topicA, messagesBeforeRebalance)

Expand All @@ -1310,7 +1314,7 @@ object ConsumerSpec extends ZIOSpecDefaultSlf4j with KafkaRandom {
copier1ClientId,
topicA,
topicB,
tProducer,
tProducerSettings,
copier1Created
).fork
_ <- copier1Created.await
Expand All @@ -1324,7 +1328,7 @@ object ConsumerSpec extends ZIOSpecDefaultSlf4j with KafkaRandom {
copier2ClientId,
topicA,
topicB,
tProducer,
tProducerSettings,
copier2Created
).fork
_ <- ZIO.logDebug("Waiting for copier 2 to start")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ object KafkaTestUtils {
* Note: to run multiple tests in parallel, you need to use different transactional ids via
* `transactionalProducer(transactionalId)`.
*/
val transactionalProducer: ZLayer[Kafka, Throwable, TransactionalProducer] =
val transactionalProducer: ZLayer[Kafka with Consumer, Throwable, TransactionalProducer] =
transactionalProducer("test-transaction")

def transactionalProducer(transactionalId: String): ZLayer[Kafka, Throwable, TransactionalProducer] =
ZLayer.makeSome[Kafka, TransactionalProducer](
def transactionalProducer(transactionalId: String): ZLayer[Kafka with Consumer, Throwable, TransactionalProducer] =
ZLayer.makeSome[Kafka with Consumer, TransactionalProducer](
ZLayer(transactionalProducerSettings(transactionalId)),
TransactionalProducer.live
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import java.util
import scala.jdk.CollectionConverters._

trait TransactionalProducer {
def createTransaction: ZIO[Scope with Consumer, Throwable, Transaction]
def createTransaction: ZIO[Scope, Throwable, Transaction]
}

object TransactionalProducer {
Expand All @@ -22,11 +22,12 @@ object TransactionalProducer {

private final class LiveTransactionalProducer(
live: ProducerLive,
semaphore: Semaphore
semaphore: Semaphore,
consumer: Consumer
) extends TransactionalProducer {
private val abortTransaction: Task[Unit] = ZIO.attemptBlocking(live.p.abortTransaction())

private def commitTransactionWithOffsets(offsetBatch: OffsetBatch): ZIO[Consumer, Throwable, Unit] = {
private def commitTransactionWithOffsets(offsetBatch: OffsetBatch): ZIO[Any, Throwable, Unit] = {
val sendOffsetsToTransaction: Task[Unit] =
ZIO.suspend {
@inline def invalidGroupIdException: IO[InvalidGroupIdException, Nothing] =
Expand All @@ -50,10 +51,10 @@ object TransactionalProducer {

sendOffsetsToTransaction.when(offsetBatch.offsets.nonEmpty) *>
ZIO.attemptBlocking(live.p.commitTransaction()) *>
ZIO.serviceWithZIO[Consumer](_.registerOffsetsCommittedInTransaction(offsetBatch)).unit
consumer.registerOffsetsCommittedInTransaction(offsetBatch).unit
}

private def commitOrAbort(transaction: TransactionImpl, exit: Exit[Any, Any]): ZIO[Consumer, Nothing, Unit] =
private def commitOrAbort(transaction: TransactionImpl, exit: Exit[Any, Any]): ZIO[Any, Nothing, Unit] =
exit match {
case Exit.Success(_) =>
transaction.offsetBatchRef.get
Expand All @@ -62,7 +63,7 @@ object TransactionalProducer {
case Exit.Failure(_) => abortTransaction.retryN(5).orDie
}

override def createTransaction: ZIO[Scope with Consumer, Throwable, Transaction] =
override def createTransaction: ZIO[Scope, Throwable, Transaction] =
semaphore.withPermitScoped *> {
ZIO.acquireReleaseExit {
for {
Expand All @@ -74,18 +75,18 @@ object TransactionalProducer {
}
}

def createTransaction: ZIO[TransactionalProducer & Scope & Consumer, Throwable, Transaction] =
def createTransaction: ZIO[TransactionalProducer & Scope, Throwable, Transaction] =
ZIO.service[TransactionalProducer].flatMap(_.createTransaction)

val live: RLayer[TransactionalProducerSettings, TransactionalProducer] =
val live: RLayer[TransactionalProducerSettings with Consumer, TransactionalProducer] =
ZLayer.scoped {
for {
settings <- ZIO.service[TransactionalProducerSettings]
producer <- make(settings)
} yield producer
}

def make(settings: TransactionalProducerSettings): ZIO[Scope, Throwable, TransactionalProducer] =
def make(settings: TransactionalProducerSettings): ZIO[Scope with Consumer, Throwable, TransactionalProducer] =
for {
rawProducer <- ZIO.acquireRelease(
ZIO.attempt(
Expand All @@ -104,6 +105,7 @@ object TransactionalProducer {
settings.producerSettings.sendBufferSize
)
live = new ProducerLive(rawProducer, runtime, sendQueue)
_ <- ZIO.blocking(live.sendFromQueue).forkScoped
} yield new LiveTransactionalProducer(live, semaphore)
_ <- ZIO.blocking(live.sendFromQueue).forkScoped
consumer <- ZIO.service[Consumer]
} yield new LiveTransactionalProducer(live, semaphore, consumer)
}

0 comments on commit 9426c99

Please sign in to comment.