Skip to content

Commit

Permalink
Track latest completed commit offset per partition
Browse files Browse the repository at this point in the history
By tracking these offsets, we can skip awaiting already-completed commits from the rebalance listener in #830.

To prevent unbounded memory usage, after a rebalance we remove the committed offset for partitions that are no longer assigned to this consumer.

Note that a commit might complete just after a partition was revoked. This is not a big issue; the offset will still be removed in the next rebalance. When the `rebalanceSafeCommits` feature is available (see #830), commits will complete in the rebalance listener and this can no longer happen.
  • Loading branch information
erikvanoosten committed Nov 5, 2023
1 parent 21361c1 commit e269a0a
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package zio.kafka.consumer.internal

import org.apache.kafka.common.TopicPartition
import zio._
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import zio.test._

/**
 * Unit tests for `Runloop.CommitOffsets`.
 *
 * The `addCommits` tests check that, per partition, the highest offset is retained and that partitions not mentioned in
 * a commit are left untouched. The `keepPartitions` test checks that offsets of partitions outside the given set are
 * dropped.
 */
object RunloopCommitOffsetsSpec extends ZIOSpecDefault {

  // Fixtures are named tp<topic number><partition number>.
  private val tp10 = new TopicPartition("t1", 0)
  private val tp11 = new TopicPartition("t1", 1)
  private val tp20 = new TopicPartition("t2", 0)
  private val tp21 = new TopicPartition("t2", 1)
  private val tp22 = new TopicPartition("t2", 2)

  override def spec: Spec[TestEnvironment with Scope, Any] =
    suite("Runloop.CommitOffsets spec")(
      test("addCommits adds to empty CommitOffsets") {
        val result = offsetsAfterCommits(Map.empty)(Map(tp10 -> 10L))
        assertTrue(result == Map(tp10 -> 10L))
      },
      test("addCommits updates offset when it is higher") {
        val result = offsetsAfterCommits(Map(tp10 -> 5L))(Map(tp10 -> 10L))
        assertTrue(result == Map(tp10 -> 10L))
      },
      test("addCommits ignores an offset when it is lower") {
        val result = offsetsAfterCommits(Map(tp10 -> 10L))(Map(tp10 -> 5L))
        assertTrue(result == Map(tp10 -> 10L))
      },
      test("addCommits keeps unrelated partitions") {
        val result = offsetsAfterCommits(Map(tp10 -> 10L))(Map(tp11 -> 11L))
        assertTrue(result == Map(tp10 -> 10L, tp11 -> 11L))
      },
      test("addCommits does it all at once") {
        val before = Map(tp10 -> 10L, tp20 -> 205L, tp21 -> 210L, tp22 -> 220L)
        val result = offsetsAfterCommits(before)(
          Map(tp11 -> 11L, tp20 -> 206L, tp21 -> 209L, tp22 -> 220L)
        )
        assertTrue(result == Map(tp10 -> 10L, tp11 -> 11L, tp20 -> 206L, tp21 -> 210L, tp22 -> 220L))
      },
      test("addCommits adds multiple commits") {
        val before = Map(tp10 -> 10L, tp20 -> 200L, tp21 -> 210L, tp22 -> 220L)
        val result = offsetsAfterCommits(before)(
          Map(tp11 -> 11L, tp20 -> 199L, tp21 -> 211L, tp22 -> 219L),
          Map(tp20 -> 198L, tp21 -> 209L, tp22 -> 221L)
        )
        assertTrue(result == Map(tp10 -> 10L, tp11 -> 11L, tp20 -> 200L, tp21 -> 211L, tp22 -> 221L))
      },
      test("keepPartitions removes some partitions") {
        val initial   = Runloop.CommitOffsets(Map(tp10 -> 10L, tp20 -> 20L))
        val remaining = initial.keepPartitions(Set(tp10)).offsets
        assertTrue(remaining == Map(tp10 -> 10L))
      }
    )

  /**
   * Builds a `CommitOffsets` from `initial`, adds one commit per entry of `commits` (in the given order, as a single
   * chunk) and returns the resulting offsets.
   */
  private def offsetsAfterCommits(initial: Map[TopicPartition, Long])(
    commits: Map[TopicPartition, Long]*
  ): Map[TopicPartition, Long] = {
    val commitChunk = Chunk.fromIterable(commits.map(makeCommit))
    Runloop.CommitOffsets(initial).addCommits(commitChunk).offsets
  }

  /**
   * Creates a commit command for the given offsets. The promise is a fresh, never-completed one; these tests do not
   * complete or await it.
   */
  private def makeCommit(offsets: Map[TopicPartition, Long]): RunloopCommand.Commit = {
    val withMetadata = for ((tp, offset) <- offsets) yield tp -> new OffsetAndMetadata(offset)
    val promise = Unsafe.unsafe { implicit unsafe =>
      Promise.unsafe.make[Throwable, Unit](FiberId.None)
    }
    RunloopCommand.Commit(withMetadata, promise)
  }
}
50 changes: 41 additions & 9 deletions zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ private[consumer] final class Runloop private (
userRebalanceListener: RebalanceListener,
restartStreamsOnRebalancing: Boolean,
currentStateRef: Ref[State],
completedCommitsRef: Ref[CommitOffsets],
fetchStrategy: FetchStrategy
) {

Expand Down Expand Up @@ -154,8 +155,11 @@ private[consumer] final class Runloop private (
val offsetsWithMetaData = offsets.map { case (tp, offset) =>
tp -> new OffsetAndMetadata(offset.offset + 1, offset.leaderEpoch, offset.metadata)
}
val cont = (e: Exit[Throwable, Unit]) => ZIO.foreachDiscard(commits)(_.cont.done(e))
val onSuccess = cont(Exit.unit) <* diagnostics.emit(DiagnosticEvent.Commit.Success(offsetsWithMetaData))
val cont = (e: Exit[Throwable, Unit]) => ZIO.foreachDiscard(commits)(_.cont.done(e))
val onSuccess =
completedCommitsRef.update(_.addCommits(commits)) *>
cont(Exit.unit) <*
diagnostics.emit(DiagnosticEvent.Commit.Success(offsetsWithMetaData))
val onFailure: Throwable => UIO[Unit] = {
case _: RebalanceInProgressException =>
for {
Expand Down Expand Up @@ -183,7 +187,7 @@ private[consumer] final class Runloop private (
ZIO.succeed(state)
} else {
val (offsets, callback, onFailure) = asyncCommitParameters(commits)
val newState = state.addCommits(commits)
val newState = state.addPendingCommits(commits)
consumer.runloopAccess { c =>
// We don't wait for the completion of the commit here, because it
// will only complete once we poll again.
Expand Down Expand Up @@ -376,6 +380,12 @@ private[consumer] final class Runloop private (
val tp = pendingRequest.tp
!(lostTps.contains(tp) || revokedTps.contains(tp) || endedStreams.exists(_.tp == tp))
}

// Remove completed commits for partitions that are no longer assigned:
// NOTE: the type annotation is needed to keep the IntelliJ compiler happy.
_ <-
completedCommitsRef.update(_.keepPartitions(updatedAssignedStreams.map(_.tp).toSet)): Task[Unit]

} yield Runloop.PollResult(
records = polledRecords,
ignoreRecordsForTps = ignoreRecordsForTps,
Expand Down Expand Up @@ -561,7 +571,7 @@ private[consumer] final class Runloop private (
}
}

private[consumer] object Runloop {
object Runloop {
private implicit final class StreamOps[R, E, A](private val stream: ZStream[R, E, A]) extends AnyVal {

/**
Expand Down Expand Up @@ -627,7 +637,7 @@ private[consumer] object Runloop {
val None: RebalanceEvent = RebalanceEvent(wasInvoked = false, Set.empty, Set.empty, Set.empty, Chunk.empty)
}

def make(
private[consumer] def make(
hasGroupId: Boolean,
consumer: ConsumerAccess,
pollTimeout: Duration,
Expand All @@ -645,8 +655,9 @@ private[consumer] object Runloop {
commandQueue <- ZIO.acquireRelease(Queue.unbounded[RunloopCommand])(_.shutdown)
lastRebalanceEvent <- Ref.Synchronized.make[Runloop.RebalanceEvent](Runloop.RebalanceEvent.None)
initialState = State.initial
currentStateRef <- Ref.make(initialState)
runtime <- ZIO.runtime[Any]
currentStateRef <- Ref.make(initialState)
completedCommitsRef <- Ref.make(CommitOffsets.empty)
runtime <- ZIO.runtime[Any]
runloop = new Runloop(
runtime = runtime,
hasGroupId = hasGroupId,
Expand All @@ -662,6 +673,7 @@ private[consumer] object Runloop {
userRebalanceListener = userRebalanceListener,
restartStreamsOnRebalancing = restartStreamsOnRebalancing,
currentStateRef = currentStateRef,
completedCommitsRef = completedCommitsRef,
fetchStrategy = fetchStrategy
)
_ <- ZIO.logDebug("Starting Runloop")
Expand All @@ -685,8 +697,8 @@ private[consumer] object Runloop {
assignedStreams: Chunk[PartitionStreamControl],
subscriptionState: SubscriptionState
) {
def addCommits(c: Chunk[RunloopCommand.Commit]): State = copy(pendingCommits = pendingCommits ++ c)
def addRequest(r: RunloopCommand.Request): State = copy(pendingRequests = pendingRequests :+ r)
def addPendingCommits(c: Chunk[RunloopCommand.Commit]): State = copy(pendingCommits = pendingCommits ++ c)
def addRequest(r: RunloopCommand.Request): State = copy(pendingRequests = pendingRequests :+ r)

def shouldPoll: Boolean =
subscriptionState.isSubscribed && (pendingRequests.nonEmpty || pendingCommits.nonEmpty || assignedStreams.isEmpty)
Expand All @@ -700,4 +712,24 @@ private[consumer] object Runloop {
subscriptionState = SubscriptionState.NotSubscribed
)
}

/**
 * Offsets per partition, built up from commits via [[addCommits]].
 *
 * Package private for unit testing.
 *
 * @param offsets
 *   the tracked offset for each partition
 */
private[internal] final case class CommitOffsets(offsets: Map[TopicPartition, Long]) {

  /**
   * Returns a copy in which the offsets of all given commits are merged in. For every partition the highest offset
   * seen (either already tracked or from one of the commits) is kept; partitions not mentioned in the commits are
   * left unchanged.
   */
  def addCommits(c: Chunk[RunloopCommand.Commit]): CommitOffsets = {
    val merged = c.foldLeft(offsets) { (soFar, commit) =>
      commit.offsets.foldLeft(soFar) { case (perPartition, (tp, offsetAndMetadata)) =>
        val candidate = offsetAndMetadata.offset()
        // An offset only ever moves forward: keep whichever of the tracked and new offset is higher.
        val highest = perPartition.get(tp).fold(candidate)(_ max candidate)
        perPartition.updated(tp, highest)
      }
    }
    CommitOffsets(merged)
  }

  /** Returns a copy that only contains the offsets of partitions in `tps`. */
  def keepPartitions(tps: Set[TopicPartition]): CommitOffsets =
    CommitOffsets(offsets.filter { case (tp, _) => tps.contains(tp) })
}

private[internal] object CommitOffsets {

  /** A `CommitOffsets` that does not track an offset for any partition. */
  val empty: CommitOffsets = CommitOffsets(offsets = Map.empty)
}
}

0 comments on commit e269a0a

Please sign in to comment.