diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
index 6ee3d9d96ef6..75b9cfc9a74c 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
@@ -32,6 +32,7 @@
 import org.apache.beam.runners.core.metrics.DistributionCell;
 import org.apache.beam.runners.core.metrics.DistributionData;
 import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.Pipeline.PipelineVisitor;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.SerializableCoder;
@@ -40,6 +41,9 @@
 import org.apache.beam.sdk.io.range.OffsetRange;
 import org.apache.beam.sdk.metrics.MetricName;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
+import org.apache.beam.sdk.options.ExperimentalOptions;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
@@ -62,6 +66,7 @@
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
 import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.clients.consumer.MockConsumer;
@@ -526,6 +531,18 @@ public void testProcessElementWhenTopicPartitionIsRemoved() throws Exception {
     assertEquals(ProcessContinuation.stop(), result);
   }
 
+  @Test
+  public void testSDFCommitOffsetEnabled() {
+    OffSetsVisitor visitor = testCommittingOffsets(true);
+    Assert.assertTrue(visitor.foundOffsetTransform);
+  }
+
+  @Test
+  public void testSDFCommitOffsetNotEnabled() {
+    OffSetsVisitor visitor = testCommittingOffsets(false);
+    Assert.assertFalse(visitor.foundOffsetTransform);
+  }
+
   @Test
   public void testProcessElementWhenTopicPartitionIsStopped() throws Exception {
     MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
@@ -688,4 +705,48 @@ public void visitValue(PValue value, Node producer) {
       }
     }
   }
+
+  private OffSetsVisitor testCommittingOffsets(boolean enableOffsets) {
+    // Force the Kafka read to use the SDF implementation.
+    PipelineOptions pipelineOptions = PipelineOptionsFactory.create();
+    ExperimentalOptions.addExperiment(
+        pipelineOptions.as(ExperimentalOptions.class), "use_sdf_read");
+
+    Pipeline p = Pipeline.create(pipelineOptions);
+    KafkaIO.Read<String, String> read =
+        KafkaIO.<String, String>read()
+            .withKeyDeserializer(StringDeserializer.class)
+            .withValueDeserializer(StringDeserializer.class)
+            .withConsumerConfigUpdates(
+                new ImmutableMap.Builder<String, Object>()
+                    .put(ConsumerConfig.GROUP_ID_CONFIG, "group_id_1")
+                    .build())
+            .withBootstrapServers("bootstrap_server")
+            .withTopic("test-topic");
+
+    // Committing offsets should expand into the KafkaCommitOffset composite.
+    if (enableOffsets) {
+      read = read.commitOffsetsInFinalize();
+    }
+
+    p.apply(read.withoutMetadata());
+    OffSetsVisitor visitor = new OffSetsVisitor();
+    p.traverseTopologically(visitor);
+    return visitor;
+  }
+
+  /** Records whether a KafkaCommitOffset output appears anywhere in the pipeline graph. */
+  static class OffSetsVisitor extends PipelineVisitor.Defaults {
+    boolean foundOffsetTransform = false;
+
+    @Override
+    public void visitValue(PValue value, Node producer) {
+      if (value instanceof PCollection) {
+        PCollection<?> pc = (PCollection<?>) value;
+        if (pc.getName().contains("KafkaCommitOffset")) {
+          foundOffsetTransform = true;
+        }
+      }
+    }
+  }
 }
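
Note for reviewers: the user-facing setup this test exercises looks roughly like the sketch below — a KafkaIO read on the SDF path with commitOffsetsInFinalize() enabled. This is an illustrative sketch, not part of the patch; the class name, bootstrap server, topic, and group id are placeholders I've chosen, and commitOffsetsInFinalize() needs a consumer group.id to be set, which is why both the test above and this sketch set ConsumerConfig.GROUP_ID_CONFIG.

import java.util.Collections;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.kafka.KafkaIO;
import org.apache.beam.sdk.options.ExperimentalOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.common.serialization.StringDeserializer;

// Hypothetical example class, not part of the patch.
public class KafkaCommitOffsetsSketch {
  public static void main(String[] args) {
    // Route KafkaIO onto the SDF-based read implementation, as the test does.
    PipelineOptions options = PipelineOptionsFactory.fromArgs(args).create();
    ExperimentalOptions.addExperiment(options.as(ExperimentalOptions.class), "use_sdf_read");

    Pipeline p = Pipeline.create(options);

    // With a group.id set and commitOffsetsInFinalize() applied, the expanded
    // pipeline contains the KafkaCommitOffset composite that the new visitor
    // detects by name.
    PCollection<KV<String, String>> records =
        p.apply(
            KafkaIO.<String, String>read()
                .withBootstrapServers("localhost:9092") // placeholder
                .withTopic("my-topic") // placeholder
                .withKeyDeserializer(StringDeserializer.class)
                .withValueDeserializer(StringDeserializer.class)
                .withConsumerConfigUpdates(
                    Collections.<String, Object>singletonMap(
                        ConsumerConfig.GROUP_ID_CONFIG, "my-group")) // placeholder
                .commitOffsetsInFinalize()
                .withoutMetadata());

    p.run().waitUntilFinish();
  }
}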