Skip to content

Commit

Permalink
Allow committed offsets refresh
Browse files Browse the repository at this point in the history
  • Loading branch information
rtimush committed May 3, 2018
1 parent 631fc65 commit 9040389
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 44 deletions.
16 changes: 13 additions & 3 deletions core/src/main/scala/akka/kafka/ConsumerSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,13 @@ object ConsumerSettings {
val commitTimeWarning = config.getDuration("commit-time-warning", TimeUnit.MILLISECONDS).millis
val wakeupTimeout = config.getDuration("wakeup-timeout", TimeUnit.MILLISECONDS).millis
val maxWakeups = config.getInt("max-wakeups")
// Optional setting: only defined when `commit-refresh-interval` is present in the config.
// The value is read in MILLISECONDS so that it matches the `.millis` conversion applied to it
// (reading it in microseconds would make the resulting interval 1000 times too long).
val commitRefreshInterval: Option[FiniteDuration] =
  if (config.hasPath("commit-refresh-interval"))
    Some(config.getDuration("commit-refresh-interval", TimeUnit.MILLISECONDS).millis)
  else None
val dispatcher = config.getString("use-dispatcher")
val wakeupDebug = config.getBoolean("wakeup-debug")
new ConsumerSettings[K, V](properties, keyDeserializer, valueDeserializer,
pollInterval, pollTimeout, stopTimeout, closeTimeout, commitTimeout, wakeupTimeout, maxWakeups, dispatcher,
commitTimeWarning, wakeupDebug)
pollInterval, pollTimeout, stopTimeout, closeTimeout, commitTimeout, wakeupTimeout, maxWakeups, commitRefreshInterval,
dispatcher, commitTimeWarning, wakeupDebug)
}

/**
Expand Down Expand Up @@ -291,6 +293,7 @@ class ConsumerSettings[K, V](
val commitTimeout: FiniteDuration,
val wakeupTimeout: FiniteDuration,
val maxWakeups: Int,
val commitRefreshInterval: Option[FiniteDuration],
val dispatcher: String,
val commitTimeWarning: FiniteDuration = 1.second,
val wakeupDebug: Boolean = true
Expand Down Expand Up @@ -365,6 +368,12 @@ class ConsumerSettings[K, V](
/** Returns a copy of these settings in which `maxWakeups` is replaced by the given value. */
def withMaxWakeups(maxWakeups: Int): ConsumerSettings[K, V] = {
  copy(maxWakeups = maxWakeups)
}

/**
 * Returns a copy of these settings with the commit refresh interval set to the given duration.
 *
 * @param commitRefreshInterval the interval to store (wrapped in `Some`)
 */
def withCommitRefreshInterval(commitRefreshInterval: FiniteDuration): ConsumerSettings[K, V] = {
  val refreshEvery: Option[FiniteDuration] = Some(commitRefreshInterval)
  copy(commitRefreshInterval = refreshEvery)
}

/** Returns a copy of these settings with no commit refresh interval (`None`). */
def withoutCommitRefresh(): ConsumerSettings[K, V] = {
  val disabled: Option[FiniteDuration] = None
  copy(commitRefreshInterval = disabled)
}

/** Returns a copy of these settings in which `wakeupDebug` is replaced by the given flag. */
def withWakeupDebug(wakeupDebug: Boolean): ConsumerSettings[K, V] = {
  copy(wakeupDebug = wakeupDebug)
}

Expand All @@ -380,12 +389,13 @@ class ConsumerSettings[K, V](
commitTimeWarning: FiniteDuration = commitTimeWarning,
wakeupTimeout: FiniteDuration = wakeupTimeout,
maxWakeups: Int = maxWakeups,
commitRefreshInterval: Option[FiniteDuration] = commitRefreshInterval,
dispatcher: String = dispatcher,
wakeupDebug: Boolean = wakeupDebug
): ConsumerSettings[K, V] =
new ConsumerSettings[K, V](properties, keyDeserializer, valueDeserializer,
pollInterval, pollTimeout, stopTimeout, closeTimeout, commitTimeout, wakeupTimeout,
maxWakeups, dispatcher, commitTimeWarning, wakeupDebug)
maxWakeups, commitRefreshInterval, dispatcher, commitTimeWarning, wakeupDebug)

/**
* Create a `KafkaConsumer` instance from the settings.
Expand Down
122 changes: 86 additions & 36 deletions core/src/main/scala/akka/kafka/KafkaConsumerActor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,19 @@ import java.io.{PrintWriter, StringWriter}
import java.util
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.LockSupport
import java.util.regex.Pattern

import akka.Done
import akka.actor.{Actor, ActorLogging, ActorRef, Cancellable, DeadLetterSuppression, NoSerializationVerificationNeeded, Props, Status, Terminated}
import akka.event.LoggingReceive
import org.apache.kafka.clients.consumer._
import org.apache.kafka.common.{Metric, MetricName, TopicPartition}
import org.apache.kafka.common.errors.WakeupException
import org.apache.kafka.common.{Metric, MetricName, TopicPartition}

import java.util.concurrent.locks.LockSupport

import akka.Done

import scala.util.control.{NoStackTrace, NonFatal}
import scala.collection.JavaConverters._
import scala.concurrent.duration._
import scala.util.control.{NoStackTrace, NonFatal}

object KafkaConsumerActor {
case class StoppingException() extends RuntimeException("Kafka consumer is stopping")
Expand Down Expand Up @@ -58,6 +56,12 @@ object KafkaConsumerActor {
// Poll trigger message. `receivePoll` only acts on it when `target` is the receiving actor
// instance, and re-schedules the poll task afterwards when `periodic` is true.
private[KafkaConsumerActor] final case class Poll[K, V](
target: KafkaConsumerActor[K, V], periodic: Boolean
) extends DeadLetterSuppression with NoSerializationVerificationNeeded
// Tells the consumer actor that `partition` has been assigned, together with the offset to seed
// its commit tracking with. Sent by the rebalance listener (using the consumer's current
// position) and by the manual assignment handlers (using the position or the seeked offset).
private[KafkaConsumerActor] final case class PartitionAssigned(
partition: TopicPartition, offset: Long
) extends DeadLetterSuppression with NoSerializationVerificationNeeded
// Tells the consumer actor that `partition` has been revoked; the actor reacts by dropping the
// partition from both of its commit tracking maps. Sent by the rebalance listener.
private[KafkaConsumerActor] final case class PartitionRevoked(
partition: TopicPartition
) extends DeadLetterSuppression with NoSerializationVerificationNeeded
private val number = new AtomicInteger()
def nextNumber(): Int = {
number.incrementAndGet()
Expand All @@ -70,14 +74,21 @@ object KafkaConsumerActor {
/**
 * Bundles the two rebalance callbacks into a `ListenerCallbacks` value.
 *
 * @param onAssign callback receiving the set of newly assigned partitions
 * @param onRevoke callback receiving the set of revoked partitions
 */
private[kafka] def rebalanceListener(onAssign: Set[TopicPartition] => Unit, onRevoke: Set[TopicPartition] => Unit): ListenerCallbacks =
ListenerCallbacks(onAssign, onRevoke)

private class WrappedAutoPausedListener(client: Consumer[_, _], listener: ListenerCallbacks) extends ConsumerRebalanceListener with NoSerializationVerificationNeeded {
// Rebalance listener that pauses newly assigned partitions, notifies `caller` about every
// assigned/revoked partition and forwards the events to the user supplied `listener` callbacks.
private class WrappedAutoPausedListener(client: Consumer[_, _], caller: ActorRef, listener: ListenerCallbacks) extends ConsumerRebalanceListener with NoSerializationVerificationNeeded {
import KafkaConsumerActor.Internal._
override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]): Unit = {
// Pause first, before anything else is told about the new partitions.
client.pause(partitions)
// Report each partition with the consumer's current position for it.
partitions.asScala.foreach { tp =>
caller ! PartitionAssigned(tp, client.position(tp))
}
listener.onAssign(partitions.asScala.toSet)
}

override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]): Unit = {
// The user callback runs before `caller` is told about the revoked partitions.
listener.onRevoke(partitions.asScala.toSet)
partitions.asScala.foreach { tp =>
caller ! PartitionRevoked(tp)
}
}
}
}
Expand All @@ -99,6 +110,9 @@ private[kafka] class KafkaConsumerActor[K, V](settings: ConsumerSettings[K, V])
var consumer: Consumer[K, V] = _
var subscriptions = Set.empty[SubscriptionRequest]
var commitsInProgress = 0
// Per partition: the latest offset for which a commit has been requested. Seeded with the
// assigned offset on `PartitionAssigned` and removed on `PartitionRevoked`.
var commitRequestedOffsets = Map.empty[TopicPartition, Long]
// Per partition: the latest offset reported back in a `Committed` message. Seeded with the
// assigned offset on `PartitionAssigned` and removed on `PartitionRevoked`.
var committedOffsets = Map.empty[TopicPartition, Long]
// Deadline after which `receivePoll` re-commits the tracked offsets. It is (re)set from
// `settings.commitRefreshInterval` on partition assignment and on every commit, and stays
// `None` when no interval is configured.
var commitRefreshDeadline: Option[Deadline] = None
var wakeups = 0
var stopInProgress = false
var delayedPollInFlight = false
Expand All @@ -109,13 +123,18 @@ private[kafka] class KafkaConsumerActor[K, V](settings: ConsumerSettings[K, V])
checkOverlappingRequests("Assign", sender(), tps)
val previousAssigned = consumer.assignment()
consumer.assign((tps.toSeq ++ previousAssigned.asScala).asJava)
tps.foreach { tp =>
self ! PartitionAssigned(tp, consumer.position(tp))
}
case AssignWithOffset(tps) =>
scheduleFirstPollTask()
checkOverlappingRequests("AssignWithOffset", sender(), tps.keySet)
val previousAssigned = consumer.assignment()
consumer.assign((tps.keys.toSeq ++ previousAssigned.asScala).asJava)
tps.foreach {
case (tp, offset) => consumer.seek(tp, offset)
case (tp, offset) =>
consumer.seek(tp, offset)
self ! PartitionAssigned(tp, offset)
}
case AssignOffsetsForTimes(timestampsToSearch) =>
scheduleFirstPollTask()
Expand All @@ -129,34 +148,12 @@ private[kafka] class KafkaConsumerActor[K, V](settings: ConsumerSettings[K, V])
val ts = oat.timestamp()
log.debug("Get offset {} from topic {} with timestamp {}", offset, tp, ts)
consumer.seek(tp, offset)
self ! PartitionAssigned(tp, offset)
}

case Commit(offsets) =>
val commitMap = offsets.mapValues(new OffsetAndMetadata(_))
val reply = sender()
commitsInProgress += 1
val startTime = System.nanoTime()
consumer.commitAsync(commitMap.asJava, new OffsetCommitCallback {
override def onComplete(offsets: util.Map[TopicPartition, OffsetAndMetadata], exception: Exception): Unit = {
// this is invoked on the thread calling consumer.poll which will always be the actor, so it is safe
val duration = FiniteDuration(System.nanoTime() - startTime, NANOSECONDS)
if (duration > settings.commitTimeWarning) {
log.warning("Kafka commit took longer than `commit-time-warning`: {} ms", duration.toMillis)
}
commitsInProgress -= 1
if (exception != null) reply ! Status.Failure(exception)
else reply ! Committed(offsets.asScala.toMap)
}
})
// When many requestors, e.g. many partitions with committablePartitionedSource the
// performance is much by collecting more requests/commits before performing the poll.
// That is done by sending a message to self, and thereby collect pending messages in mailbox.
if (requestors.size == 1)
poll()
else if (!delayedPollInFlight) {
delayedPollInFlight = true
self ! delayedPollMsg
}
commitRequestedOffsets ++= offsets
commit(offsets, sender())

case s: SubscriptionRequest =>
subscriptions = subscriptions + s
Expand Down Expand Up @@ -184,6 +181,18 @@ private[kafka] class KafkaConsumerActor[K, V](settings: ConsumerSettings[K, V])
self ! delayedPollMsg
}

case PartitionAssigned(partition, offset) =>
commitRequestedOffsets += partition -> commitRequestedOffsets.getOrElse(partition, offset)
committedOffsets += partition -> committedOffsets.getOrElse(partition, offset)
commitRefreshDeadline = settings.commitRefreshInterval.map(_.fromNow)

case PartitionRevoked(partition) =>
commitRequestedOffsets -= partition
committedOffsets -= partition

case Committed(offsets) =>
committedOffsets ++= offsets.mapValues(_.offset())

case Stop =>
if (commitsInProgress == 0) {
context.stop(self)
Expand All @@ -207,9 +216,9 @@ private[kafka] class KafkaConsumerActor[K, V](settings: ConsumerSettings[K, V])

subscription match {
case Subscribe(topics, listener) =>
consumer.subscribe(topics.toList.asJava, new WrappedAutoPausedListener(consumer, listener))
consumer.subscribe(topics.toList.asJava, new WrappedAutoPausedListener(consumer, self, listener))
case SubscribePattern(pattern, listener) =>
consumer.subscribe(Pattern.compile(pattern), new WrappedAutoPausedListener(consumer, listener))
consumer.subscribe(Pattern.compile(pattern), new WrappedAutoPausedListener(consumer, self, listener))
}
}

Expand Down Expand Up @@ -264,6 +273,14 @@ private[kafka] class KafkaConsumerActor[K, V](settings: ConsumerSettings[K, V])

private def receivePoll(p: Poll[_, _]): Unit = {
if (p.target == this) {
if (commitRefreshDeadline.exists(_.isOverdue())) {
val refreshOffsets = committedOffsets.filter {
case (tp, offset) =>
commitRequestedOffsets.get(tp).contains(offset)
}
log.debug("Refreshing committed offsets: {}", refreshOffsets)
commit(refreshOffsets, context.system.deadLetters)
}
poll()
if (p.periodic)
currentPollTask = schedulePollTask()
Expand All @@ -278,7 +295,7 @@ private[kafka] class KafkaConsumerActor[K, V](settings: ConsumerSettings[K, V])

def poll(): Unit = {
val wakeupTask = context.system.scheduler.scheduleOnce(settings.wakeupTimeout) {
log.warning("KafkaConsumer poll has exceeded wake up timeout ({}ms). Waking up consumer to avoid thread starvation.", settings.wakeupTimeout.toMillis)
log.warning("KafkaConsumer poll has exceeded wake up timeout ({}ms). Waking up consumer to avoid thread starvation.", settings.wakeupTimeout.toMillis)
if (settings.wakeupDebug) {
val stacks = Thread.getAllStackTraces.asScala.map { case (k, v) => s"$k\n ${v.mkString("\n")}" }.mkString("\n\n")
log.warning("Wake up has been triggered. Dumping stacks: {}", stacks)
Expand Down Expand Up @@ -410,6 +427,39 @@ private[kafka] class KafkaConsumerActor[K, V](settings: ConsumerSettings[K, V])
}
}

/**
 * Asynchronously commits the given offsets and reports the outcome to `reply`.
 *
 * Resets the commit refresh deadline from `settings.commitRefreshInterval`, issues
 * `commitAsync` on the consumer and then triggers a poll (directly or via a delayed poll
 * message), because the commit callback is invoked from within `consumer.poll`.
 * On success the resulting `Committed` message is sent to this actor and to `reply`;
 * on failure `reply` receives a `Status.Failure` carrying the exception.
 *
 * @param offsets offsets to commit, keyed by partition
 * @param reply   recipient of the commit result; it is used as given, so callers that do not
 *                want a reply (e.g. the periodic refresh) can pass dead letters
 */
private def commit(offsets: Map[TopicPartition, Long], reply: ActorRef): Unit = {
  commitRefreshDeadline = settings.commitRefreshInterval.map(_.fromNow)
  // Build a strict map; `mapValues` would only give a lazy view over `offsets`.
  val commitMap = offsets.map { case (tp, offset) => tp -> new OffsetAndMetadata(offset) }
  // NOTE: `reply` must not be re-bound to `sender()` here. Doing so would ignore the
  // parameter and route results of internally triggered commits to the wrong actor.
  commitsInProgress += 1
  val startTime = System.nanoTime()
  consumer.commitAsync(commitMap.asJava, new OffsetCommitCallback {
    override def onComplete(acked: util.Map[TopicPartition, OffsetAndMetadata], exception: Exception): Unit = {
      // this is invoked on the thread calling consumer.poll which will always be the actor, so it is safe
      val duration = FiniteDuration(System.nanoTime() - startTime, NANOSECONDS)
      if (duration > settings.commitTimeWarning) {
        log.warning("Kafka commit took longer than `commit-time-warning`: {} ms", duration.toMillis)
      }
      commitsInProgress -= 1
      if (exception != null) reply ! Status.Failure(exception)
      else {
        val committed = Committed(acked.asScala.toMap)
        // Let the actor update its own committed-offset tracking as well.
        self ! committed
        reply ! committed
      }
    }
  })
  // With many requestors, e.g. many partitions with committablePartitionedSource, the
  // performance is much better when more requests/commits are collected before performing
  // the poll. That is done by sending a message to self, and thereby collecting pending
  // messages in the mailbox.
  if (requestors.size == 1)
    poll()
  else if (!delayedPollInFlight) {
    delayedPollInFlight = true
    self ! delayedPollMsg
  }
}

private def processResult(partitionsToFetch: Set[TopicPartition], rawResult: ConsumerRecords[K, V]): Unit = {
if (!rawResult.isEmpty) {
//check the we got only requested partitions and did not drop any messages
Expand Down
2 changes: 1 addition & 1 deletion core/src/test/scala/akka/kafka/internal/ConsumerTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class ConsumerTest(_system: ActorSystem)

def testSource(mock: ConsumerMock[K, V], groupId: String = "group1", topics: Set[String] = Set("topic")): Source[CommittableMessage[K, V], Control] = {
val settings = new ConsumerSettings(Map(ConsumerConfig.GROUP_ID_CONFIG -> groupId), Some(new StringDeserializer), Some(new StringDeserializer),
1.milli, 1.milli, 1.second, closeTimeout, 1.second, 5.seconds, 3, "akka.kafka.default-dispatcher", 1.second, true) {
1.milli, 1.milli, 1.second, closeTimeout, 1.second, 5.seconds, 3, None, "akka.kafka.default-dispatcher", 1.second, true) {
override def createKafkaConsumer(): KafkaConsumer[K, V] = {
mock.mock
}
Expand Down
80 changes: 76 additions & 4 deletions core/src/test/scala/akka/kafka/scaladsl/IntegrationSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ import java.util.concurrent.{ConcurrentLinkedQueue, TimeUnit}
import akka.actor.ActorSystem
import akka.kafka.ConsumerMessage.CommittableOffsetBatch
import akka.kafka.ProducerMessage.Message
import akka.kafka.Subscriptions.TopicSubscription
import akka.kafka.test.Utils._
import akka.kafka._
import akka.kafka.test.Utils._
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.{Keep, Sink, Source}
import akka.stream.testkit.TestSubscriber
Expand All @@ -23,7 +22,6 @@ import akka.{Done, NotUsed}
import net.manub.embeddedkafka.{EmbeddedKafka, EmbeddedKafkaConfig}
import org.apache.kafka.clients.consumer.ConsumerConfig
import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord}
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer, StringDeserializer, StringSerializer}
import org.scalactic.TypeCheckedTripleEquals
import org.scalatest._
Expand All @@ -41,7 +39,11 @@ class IntegrationSpec extends TestKit(ActorSystem("IntegrationSpec"))
implicit val stageStoppingTimeout = StageStoppingTimeout(15.seconds)
implicit val mat = ActorMaterializer()(system)
implicit val ec = system.dispatcher
implicit val embeddedKafkaConfig = EmbeddedKafkaConfig(9092, 2181, Map("offsets.topic.replication.factor" -> "1"))
// Embedded broker on port 9092 (ZooKeeper on 2181). Offset retention is shortened to one minute
// and checked every 100 ms, which the retention-period test in this spec relies on when it
// sleeps for 70 seconds.
implicit val embeddedKafkaConfig = EmbeddedKafkaConfig(9092, 2181, Map(
"offsets.topic.replication.factor" -> "1",
"offsets.retention.minutes" -> "1",
"offsets.retention.check.interval.ms" -> "100"
))
val bootstrapServers = s"localhost:${embeddedKafkaConfig.kafkaPort}"
val InitialMsg = "initial msg in topic, required to create the topic before any consumer subscribes to it"

Expand Down Expand Up @@ -230,6 +232,76 @@ class IntegrationSpec extends TestKit(ActorSystem("IntegrationSpec"))
Await.result(control.isShutdown, remainingOrDefault)
}

// Checks that a consumer group still resumes from its committed offset after idling for longer
// than the broker's offset retention, with a 5 second commit refresh interval configured.
"resume consumer from committed offset after retention period" in assertAllStagesStopped {
val topic1 = createTopic(1)
val group1 = createGroup(1)
// NOTE(review): group2 is not referenced anywhere else in this test.
val group2 = createGroup(2)

givenInitializedTopic(topic1)

// NOTE: If no partition is specified but a key is present a partition will be chosen
// using a hash of the key. If neither key nor partition is present a partition
// will be assigned in a round-robin fashion.

Source(1 to 100)
.map(n => new ProducerRecord(topic1, partition0, null: Array[Byte], n.toString))
.runWith(Producer.plainSink(producerSettings))

// Values whose offsets have been committed by the first consumer.
val committedElements = new ConcurrentLinkedQueue[Int]()

val consumerSettings = createConsumerSettings(group1).withCommitRefreshInterval(5.seconds)

val (control, probe1) = Consumer.committableSource(consumerSettings, Subscriptions.topics(topic1))
.filterNot(_.record.value == InitialMsg)
.mapAsync(10) { elem =>
elem.committableOffset.commitScaladsl().map { _ =>
committedElements.add(elem.record.value.toInt)
Done
}
}
.toMat(TestSink.probe)(Keep.both)
.run()

probe1
.request(25)
.expectNextN(25).toSet should be(Set(Done))

// Stay idle for 70 s, i.e. longer than the 1 minute offset retention configured for the
// embedded broker, while the first consumer is still running.
Thread.sleep(70000)

probe1.cancel()
Await.result(control.isShutdown, remainingOrDefault)

// A second consumer in the same group must continue after the last committed element.
val probe2 = Consumer.committableSource(consumerSettings, Subscriptions.topics(topic1))
.map(_.record.value)
.runWith(TestSink.probe)

// Note that due to buffers and mapAsync(10) the committed offset is more
// than 26, and that is not wrong

// some concurrent publish
Source(101 to 200)
.map(n => new ProducerRecord(topic1, partition0, null: Array[Byte], n.toString))
.runWith(Producer.plainSink(producerSettings))

probe2
.request(100)
.expectNextN(((committedElements.asScala.max + 1) to 100).map(_.toString))

// Idle past the retention period again; the second consumer never commits anything itself.
Thread.sleep(70000)

probe2.cancel()

// A third consumer is expected to start from the same committed position as the second one.
val probe3 = Consumer.committableSource(consumerSettings, Subscriptions.topics(topic1))
.map(_.record.value)
.runWith(TestSink.probe)

probe3
.request(100)
.expectNextN(((committedElements.asScala.max + 1) to 100).map(_.toString))

probe3.cancel()
}

"handle commit without demand" in assertAllStagesStopped {
val topic1 = createTopic(1)
val group1 = createGroup(1)
Expand Down

0 comments on commit 9040389

Please sign in to comment.