diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
index e94dc2cb6071..bd919ddbdade 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
@@ -21,13 +21,9 @@
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import com.datastax.driver.core.Cluster;
-import com.datastax.driver.core.ColumnMetadata;
import com.datastax.driver.core.ConsistencyLevel;
import com.datastax.driver.core.PlainTextAuthProvider;
import com.datastax.driver.core.QueryOptions;
-import com.datastax.driver.core.ResultSet;
-import com.datastax.driver.core.ResultSetFuture;
-import com.datastax.driver.core.Row;
import com.datastax.driver.core.Session;
import com.datastax.driver.core.SocketOptions;
import com.datastax.driver.core.policies.DCAwareRoundRobinPolicy;
@@ -36,32 +32,32 @@
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collections;
-import java.util.Iterator;
import java.util.List;
-import java.util.NoSuchElementException;
import java.util.Optional;
+import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
+import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.SerializableFunction;
-import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -89,6 +85,11 @@
*
* }
*
+ * <p>Alternatively, one may use {@code CassandraIO.<Person>readAll()
+ * .withCoder(SerializableCoder.of(Person.class))} to query a subset of the Cassandra database by
+ * creating a PCollection of {@code CassandraIO.Read}, each with its own query or
+ * {@code RingRange}.
+ *
* Writing to Apache Cassandra
*
* {@code CassandraIO} provides a sink to write a collection of entities to Apache Cassandra.
@@ -137,6 +138,11 @@ public static Read read() {
return new AutoValue_CassandraIO_Read.Builder().build();
}
+ /** Provide a {@link ReadAll} {@link PTransform} that runs each input {@link Read}. */
+ public static ReadAll readAll() {
+ return new AutoValue_CassandraIO_ReadAll.Builder().build();
+ }
+
/** Provide a {@link Write} {@link PTransform} to write data to a Cassandra database. */
public static Write write() {
return Write.builder(MutationType.WRITE).build();
@@ -186,6 +192,9 @@ public abstract static class Read extends PTransform>
abstract @Nullable SerializableFunction mapperFactoryFn();
+ @Nullable
+ abstract ValueProvider> ringRanges();
+
abstract Builder builder();
/** Specify the hosts of the Apache Cassandra instances. */
@@ -371,6 +380,21 @@ public Read withMapperFactoryFn(SerializableFunction mapperF
return builder().setMapperFactoryFn(mapperFactory).build();
}
+ /**
+ * Restrict the read to the given token ranges of the Cassandra ring.
+ *
+ * @param ringRange the ring ranges to read, must not be null
+ */
+ public Read withRingRanges(Set ringRange) {
+   checkArgument(ringRange != null, "ringRange can not be null");
+   return withRingRanges(ValueProvider.StaticValueProvider.of(ringRange));
+ }
+
+ /** Same as {@link #withRingRanges(Set)} but accepting a {@link ValueProvider}. */
+ public Read withRingRanges(ValueProvider> ringRange) {
+   return builder().setRingRanges(ringRange).build();
+ }
+
@Override
public PCollection expand(PBegin input) {
checkArgument((hosts() != null && port() != null), "WithHosts() and withPort() are required");
@@ -379,7 +396,69 @@ public PCollection expand(PBegin input) {
checkArgument(entity() != null, "withEntity() is required");
checkArgument(coder() != null, "withCoder() is required");
- return input.apply(org.apache.beam.sdk.io.Read.from(new CassandraSource<>(this, null)));
+ PCollection> splits =
+ input
+ .apply(Create.of(this))
+ .apply("Create Splits", ParDo.of(new SplitFn()))
+ .setCoder(SerializableCoder.of(new TypeDescriptor>() {}));
+
+ return splits.apply("ReadAll", CassandraIO.readAll().withCoder(coder()));
+ }
+
+ private static class SplitFn extends DoFn, Read> {
+ @ProcessElement
+ public void process(
+     @Element CassandraIO.Read read, OutputReceiver> outputReceiver) {
+   Set ringRanges = getRingRanges(read);
+   for (RingRange rr : ringRanges) {
+     // Emit one copy of the read per ring range so that each range is read as its own element.
+     outputReceiver.output(read.withRingRanges(ImmutableSet.of(rr)));
+   }
+ }
+ /** Connects to the cluster and derives the ring ranges to read from its metadata. */
+ private static Set getRingRanges(Read read) {
+ try (Cluster cluster =
+ getCluster(
+ read.hosts(),
+ read.port(),
+ read.username(),
+ read.password(),
+ read.localDc(),
+ read.consistencyLevel(),
+ read.connectTimeout(),
+ read.readTimeout())) {
+ if (isMurmur3Partitioner(cluster)) {
+ LOG.info("Murmur3Partitioner detected, splitting");
+ Integer splitCount;
+ if (read.minNumberOfSplits() != null && read.minNumberOfSplits().get() != null) {
+ splitCount = read.minNumberOfSplits().get();
+ } else {
+ splitCount = cluster.getMetadata().getAllHosts().size();
+ }
+ List tokens =
+ cluster.getMetadata().getTokenRanges().stream()
+ .map(tokenRange -> new BigInteger(tokenRange.getEnd().getValue().toString()))
+ .collect(Collectors.toList());
+ SplitGenerator splitGenerator =
+ new SplitGenerator(cluster.getMetadata().getPartitioner());
+ // generateSplits returns lists of ranges; flatten them into a single set of ranges.
+ return splitGenerator.generateSplits(splitCount, tokens).stream()
+ .flatMap(List::stream)
+ .collect(Collectors.toSet());
+
+ } else {
+ LOG.warn(
+ "Only Murmur3Partitioner is supported for splitting, using an unique source for "
+ + "the read");
+ String partitioner = cluster.getMetadata().getPartitioner();
+ RingRange totalRingRange =
+ RingRange.of(
+ SplitGenerator.getRangeMin(partitioner),
+ SplitGenerator.getRangeMax(partitioner));
+ return Collections.singleton(totalRingRange);
+ }
+ }
+ }
}
@AutoValue.Builder
@@ -418,6 +497,8 @@ abstract static class Builder {
abstract Optional> mapperFactoryFn();
+ abstract Builder setRingRanges(ValueProvider> ringRange);
+
abstract Read autoBuild();
public Read build() {
@@ -429,390 +510,6 @@ public Read build() {
}
}
- @VisibleForTesting
- static class CassandraSource extends BoundedSource {
- final Read spec;
- final List splitQueries;
- // split source ached size - can't be calculated when already split
- Long estimatedSize;
- private static final String MURMUR3PARTITIONER = "org.apache.cassandra.dht.Murmur3Partitioner";
-
- CassandraSource(Read spec, List splitQueries) {
- this(spec, splitQueries, null);
- }
-
- private CassandraSource(Read spec, List splitQueries, Long estimatedSize) {
- this.estimatedSize = estimatedSize;
- this.spec = spec;
- this.splitQueries = splitQueries;
- }
-
- @Override
- public Coder getOutputCoder() {
- return spec.coder();
- }
-
- @Override
- public BoundedReader createReader(PipelineOptions pipelineOptions) {
- return new CassandraReader(this);
- }
-
- @Override
- public List> split(
- long desiredBundleSizeBytes, PipelineOptions pipelineOptions) {
- try (Cluster cluster =
- getCluster(
- spec.hosts(),
- spec.port(),
- spec.username(),
- spec.password(),
- spec.localDc(),
- spec.consistencyLevel(),
- spec.connectTimeout(),
- spec.readTimeout())) {
- if (isMurmur3Partitioner(cluster)) {
- LOG.info("Murmur3Partitioner detected, splitting");
- return splitWithTokenRanges(
- spec, desiredBundleSizeBytes, getEstimatedSizeBytes(pipelineOptions), cluster);
- } else {
- LOG.warn(
- "Only Murmur3Partitioner is supported for splitting, using a unique source for "
- + "the read");
- return Collections.singletonList(
- new CassandraIO.CassandraSource<>(spec, Collections.singletonList(buildQuery(spec))));
- }
- }
- }
-
- private static String buildQuery(Read spec) {
- return (spec.query() == null)
- ? String.format("SELECT * FROM %s.%s", spec.keyspace().get(), spec.table().get())
- : spec.query().get().toString();
- }
-
- /**
- * Compute the number of splits based on the estimated size and the desired bundle size, and
- * create several sources.
- */
- private List> splitWithTokenRanges(
- CassandraIO.Read spec,
- long desiredBundleSizeBytes,
- long estimatedSizeBytes,
- Cluster cluster) {
- long numSplits =
- getNumSplits(desiredBundleSizeBytes, estimatedSizeBytes, spec.minNumberOfSplits());
- LOG.info("Number of desired splits is {}", numSplits);
-
- SplitGenerator splitGenerator = new SplitGenerator(cluster.getMetadata().getPartitioner());
- List tokens =
- cluster.getMetadata().getTokenRanges().stream()
- .map(tokenRange -> new BigInteger(tokenRange.getEnd().getValue().toString()))
- .collect(Collectors.toList());
- List> splits = splitGenerator.generateSplits(numSplits, tokens);
- LOG.info("{} splits were actually generated", splits.size());
-
- final String partitionKey =
- cluster.getMetadata().getKeyspace(spec.keyspace().get()).getTable(spec.table().get())
- .getPartitionKey().stream()
- .map(ColumnMetadata::getName)
- .collect(Collectors.joining(","));
-
- List tokenRanges =
- getTokenRanges(cluster, spec.keyspace().get(), spec.table().get());
- final long estimatedSize = getEstimatedSizeBytesFromTokenRanges(tokenRanges) / splits.size();
-
- List> sources = new ArrayList<>();
- for (List split : splits) {
- List queries = new ArrayList<>();
- for (RingRange range : split) {
- if (range.isWrapping()) {
- // A wrapping range is one that overlaps from the end of the partitioner range and its
- // start (ie : when the start token of the split is greater than the end token)
- // We need to generate two queries here : one that goes from the start token to the end
- // of
- // the partitioner range, and the other from the start of the partitioner range to the
- // end token of the split.
- queries.add(generateRangeQuery(spec, partitionKey, range.getStart(), null));
- // Generation of the second query of the wrapping range
- queries.add(generateRangeQuery(spec, partitionKey, null, range.getEnd()));
- } else {
- queries.add(generateRangeQuery(spec, partitionKey, range.getStart(), range.getEnd()));
- }
- }
- sources.add(new CassandraIO.CassandraSource<>(spec, queries, estimatedSize));
- }
- return sources;
- }
-
- private static String generateRangeQuery(
- Read spec, String partitionKey, BigInteger rangeStart, BigInteger rangeEnd) {
- final String rangeFilter =
- Joiner.on(" AND ")
- .skipNulls()
- .join(
- rangeStart == null
- ? null
- : String.format("(token(%s) >= %d)", partitionKey, rangeStart),
- rangeEnd == null
- ? null
- : String.format("(token(%s) < %d)", partitionKey, rangeEnd));
- final String query =
- (spec.query() == null)
- ? buildQuery(spec) + " WHERE " + rangeFilter
- : buildQuery(spec) + " AND " + rangeFilter;
- LOG.debug("CassandraIO generated query : {}", query);
- return query;
- }
-
- private static long getNumSplits(
- long desiredBundleSizeBytes,
- long estimatedSizeBytes,
- @Nullable ValueProvider minNumberOfSplits) {
- long numSplits =
- desiredBundleSizeBytes > 0 ? (estimatedSizeBytes / desiredBundleSizeBytes) : 1;
- if (numSplits <= 0) {
- LOG.warn("Number of splits is less than 0 ({}), fallback to 1", numSplits);
- numSplits = 1;
- }
- return minNumberOfSplits != null ? Math.max(numSplits, minNumberOfSplits.get()) : numSplits;
- }
-
- /**
- * Returns cached estimate for split or if missing calculate size for whole table. Highly
- * innacurate if query is specified.
- *
- * @param pipelineOptions
- * @return
- */
- @Override
- public long getEstimatedSizeBytes(PipelineOptions pipelineOptions) {
- if (estimatedSize != null) {
- return estimatedSize;
- } else {
- try (Cluster cluster =
- getCluster(
- spec.hosts(),
- spec.port(),
- spec.username(),
- spec.password(),
- spec.localDc(),
- spec.consistencyLevel(),
- spec.connectTimeout(),
- spec.readTimeout())) {
- if (isMurmur3Partitioner(cluster)) {
- try {
- List tokenRanges =
- getTokenRanges(cluster, spec.keyspace().get(), spec.table().get());
- this.estimatedSize = getEstimatedSizeBytesFromTokenRanges(tokenRanges);
- return this.estimatedSize;
- } catch (Exception e) {
- LOG.warn("Can't estimate the size", e);
- return 0L;
- }
- } else {
- LOG.warn("Only Murmur3 partitioner is supported, can't estimate the size");
- return 0L;
- }
- }
- }
- }
-
- @VisibleForTesting
- static long getEstimatedSizeBytesFromTokenRanges(List tokenRanges) {
- long size = 0L;
- for (TokenRange tokenRange : tokenRanges) {
- size = size + tokenRange.meanPartitionSize * tokenRange.partitionCount;
- }
- return Math.round(size / getRingFraction(tokenRanges));
- }
-
- @Override
- public void populateDisplayData(DisplayData.Builder builder) {
- super.populateDisplayData(builder);
- if (spec.hosts() != null) {
- builder.add(DisplayData.item("hosts", spec.hosts().toString()));
- }
- if (spec.port() != null) {
- builder.add(DisplayData.item("port", spec.port()));
- }
- builder.addIfNotNull(DisplayData.item("keyspace", spec.keyspace()));
- builder.addIfNotNull(DisplayData.item("table", spec.table()));
- builder.addIfNotNull(DisplayData.item("username", spec.username()));
- builder.addIfNotNull(DisplayData.item("localDc", spec.localDc()));
- builder.addIfNotNull(DisplayData.item("consistencyLevel", spec.consistencyLevel()));
- }
- // ------------- CASSANDRA SOURCE UTIL METHODS ---------------//
-
- /**
- * Gets the list of token ranges that a table occupies on a give Cassandra node.
- *
- * NB: This method is compatible with Cassandra 2.1.5 and greater.
- */
- private static List getTokenRanges(Cluster cluster, String keyspace, String table) {
- try (Session session = cluster.newSession()) {
- ResultSet resultSet =
- session.execute(
- "SELECT range_start, range_end, partitions_count, mean_partition_size FROM "
- + "system.size_estimates WHERE keyspace_name = ? AND table_name = ?",
- keyspace,
- table);
-
- ArrayList tokenRanges = new ArrayList<>();
- for (Row row : resultSet) {
- TokenRange tokenRange =
- new TokenRange(
- row.getLong("partitions_count"),
- row.getLong("mean_partition_size"),
- new BigInteger(row.getString("range_start")),
- new BigInteger(row.getString("range_end")));
- tokenRanges.add(tokenRange);
- }
- // The table may not contain the estimates yet
- // or have partitions_count and mean_partition_size fields = 0
- // if the data was just inserted and the amount of data in the table was small.
- // This is very common situation during tests,
- // when we insert a few rows and immediately query them.
- // However, for tiny data sets the lack of size estimates is not a problem at all,
- // because we don't want to split tiny data anyways.
- // Therefore, we're not issuing a warning if the result set was empty
- // or mean_partition_size and partitions_count = 0.
- return tokenRanges;
- }
- }
-
- /** Compute the percentage of token addressed compared with the whole tokens in the cluster. */
- @VisibleForTesting
- static double getRingFraction(List tokenRanges) {
- double ringFraction = 0;
- for (TokenRange tokenRange : tokenRanges) {
- ringFraction =
- ringFraction
- + (distance(tokenRange.rangeStart, tokenRange.rangeEnd).doubleValue()
- / SplitGenerator.getRangeSize(MURMUR3PARTITIONER).doubleValue());
- }
- return ringFraction;
- }
-
- /**
- * Check if the current partitioner is the Murmur3 (default in Cassandra version newer than 2).
- */
- @VisibleForTesting
- static boolean isMurmur3Partitioner(Cluster cluster) {
- return MURMUR3PARTITIONER.equals(cluster.getMetadata().getPartitioner());
- }
-
- /** Measure distance between two tokens. */
- @VisibleForTesting
- static BigInteger distance(BigInteger left, BigInteger right) {
- return (right.compareTo(left) > 0)
- ? right.subtract(left)
- : right.subtract(left).add(SplitGenerator.getRangeSize(MURMUR3PARTITIONER));
- }
-
- /**
- * Represent a token range in Cassandra instance, wrapping the partition count, size and token
- * range.
- */
- @VisibleForTesting
- static class TokenRange {
- private final long partitionCount;
- private final long meanPartitionSize;
- private final BigInteger rangeStart;
- private final BigInteger rangeEnd;
-
- TokenRange(
- long partitionCount, long meanPartitionSize, BigInteger rangeStart, BigInteger rangeEnd) {
- this.partitionCount = partitionCount;
- this.meanPartitionSize = meanPartitionSize;
- this.rangeStart = rangeStart;
- this.rangeEnd = rangeEnd;
- }
- }
-
- private class CassandraReader extends BoundedSource.BoundedReader {
- private final CassandraIO.CassandraSource source;
- private Cluster cluster;
- private Session session;
- private Iterator iterator;
- private T current;
-
- CassandraReader(CassandraSource source) {
- this.source = source;
- }
-
- @Override
- public boolean start() {
- LOG.debug("Starting Cassandra reader");
- cluster =
- getCluster(
- source.spec.hosts(),
- source.spec.port(),
- source.spec.username(),
- source.spec.password(),
- source.spec.localDc(),
- source.spec.consistencyLevel(),
- source.spec.connectTimeout(),
- source.spec.readTimeout());
- session = cluster.connect(source.spec.keyspace().get());
- LOG.debug("Queries: " + source.splitQueries);
- List futures = new ArrayList<>();
- for (String query : source.splitQueries) {
- futures.add(session.executeAsync(query));
- }
-
- final Mapper mapper = getMapper(session, source.spec.entity());
-
- for (ResultSetFuture result : futures) {
- if (iterator == null) {
- iterator = mapper.map(result.getUninterruptibly());
- } else {
- iterator = Iterators.concat(iterator, mapper.map(result.getUninterruptibly()));
- }
- }
-
- return advance();
- }
-
- @Override
- public boolean advance() {
- if (iterator.hasNext()) {
- current = iterator.next();
- return true;
- }
- current = null;
- return false;
- }
-
- @Override
- public void close() {
- LOG.debug("Closing Cassandra reader");
- if (session != null) {
- session.close();
- }
- if (cluster != null) {
- cluster.close();
- }
- }
-
- @Override
- public T getCurrent() throws NoSuchElementException {
- if (current == null) {
- throw new NoSuchElementException();
- }
- return current;
- }
-
- @Override
- public CassandraIO.CassandraSource getCurrentSource() {
- return source;
- }
-
- private Mapper getMapper(Session session, Class enitity) {
- return source.spec.mapperFactoryFn().apply(session);
- }
- }
- }
-
/** Specify the mutation type: either write or delete. */
public enum MutationType {
WRITE,
@@ -1179,7 +876,7 @@ public void teardown() throws Exception {
}
/** Get a Cassandra cluster using hosts and port. */
- private static Cluster getCluster(
+ static Cluster getCluster(
ValueProvider> hosts,
ValueProvider port,
ValueProvider username,
@@ -1301,4 +998,53 @@ private void waitForFuturesToFinish() throws ExecutionException, InterruptedExce
}
}
}
+
+ /**
+ * A {@link PTransform} that executes each {@link Read} of its input against Apache Cassandra
+ * and outputs the entities read. See {@link CassandraIO} for usage and configuration.
+ */
+ @AutoValue
+ public abstract static class ReadAll extends PTransform>, PCollection> {
+ @AutoValue.Builder
+ abstract static class Builder {
+ // Implemented by AutoValue; build() only delegates to the generated autoBuild().
+ abstract Builder setCoder(Coder coder);
+
+ abstract ReadAll autoBuild();
+
+ public ReadAll build() {
+ return autoBuild();
+ }
+ }
+
+ @Nullable
+ abstract Coder coder();
+
+ abstract Builder builder();
+
+ /** Specify the {@link Coder} used to serialize the entity in the {@link PCollection}. */
+ public ReadAll withCoder(Coder coder) {
+ checkArgument(coder != null, "coder can not be null");
+ return builder().setCoder(coder).build();
+ }
+
+ @Override
+ public PCollection expand(PCollection> input) {
+ checkArgument(coder() != null, "withCoder() is required");
+ return input
+ .apply("Reshuffle", Reshuffle.viaRandomKey())
+ .apply("Read", ParDo.of(new ReadFn<>()))
+ .setCoder(this.coder());
+ }
+ }
+
+ /**
+ * Check if the current partitioner is the Murmur3 (default in Cassandra version newer than 2).
+ */
+ @VisibleForTesting
+ private static boolean isMurmur3Partitioner(Cluster cluster) {
+ return MURMUR3PARTITIONER.equals(cluster.getMetadata().getPartitioner());
+ }
+ // Class name of the only partitioner for which ring splits are generated.
+ private static final String MURMUR3PARTITIONER = "org.apache.cassandra.dht.Murmur3Partitioner";
}
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ConnectionManager.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ConnectionManager.java
new file mode 100644
index 000000000000..5091ac40b936
--- /dev/null
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ConnectionManager.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.cassandra;
+
+import com.datastax.driver.core.Cluster;
+import com.datastax.driver.core.Session;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.beam.sdk.io.cassandra.CassandraIO.Read;
+import org.apache.beam.sdk.options.ValueProvider;
+
+/**
+ * Caches Cassandra {@link Cluster} and {@link Session} instances in static maps so that
+ * {@link Read}s with the same connection settings share one connection.
+ *
+ * <p>Clusters are keyed by hosts, port, username, local DC and consistency level; sessions are
+ * keyed by the cluster key plus the keyspace. Everything cached is closed by a JVM shutdown hook.
+ */
+@SuppressWarnings({
+  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
+})
+public class ConnectionManager {
+
+  /** Separates the components of a cache key so that adjacent values cannot run together. */
+  private static final String KEY_SEPARATOR = "|";
+
+  private static final ConcurrentHashMap<String, Cluster> clusterMap = new ConcurrentHashMap<>();
+  private static final ConcurrentHashMap<String, Session> sessionMap = new ConcurrentHashMap<>();
+
+  static {
+    Runtime.getRuntime()
+        .addShutdownHook(
+            new Thread(
+                () -> {
+                  for (Session session : sessionMap.values()) {
+                    if (!session.isClosed()) {
+                      session.close();
+                    }
+                  }
+                  // The cached clusters are not closed by closing sessions.
+                  for (Cluster cluster : clusterMap.values()) {
+                    if (!cluster.isClosed()) {
+                      cluster.close();
+                    }
+                  }
+                }));
+  }
+
+  /** Builds the cache key identifying the cluster connection described by {@code read}. */
+  private static String readToClusterHash(Read<?> read) {
+    return String.join(",", Objects.requireNonNull(read.hosts()).get())
+        + KEY_SEPARATOR
+        + Objects.requireNonNull(read.port()).get()
+        + KEY_SEPARATOR
+        + safeVPGet(read.username())
+        + KEY_SEPARATOR
+        + safeVPGet(read.localDc())
+        + KEY_SEPARATOR
+        + safeVPGet(read.consistencyLevel());
+  }
+
+  /** Builds the cache key of a session: the cluster key followed by the keyspace. */
+  private static String readToSessionHash(Read<?> read) {
+    return readToClusterHash(read)
+        + KEY_SEPARATOR
+        + Objects.requireNonNull(read.keyspace()).get();
+  }
+
+  /**
+   * Returns a session connected to the keyspace of {@code read}, creating the cluster and the
+   * session on first use and reusing them afterwards.
+   *
+   * @param read the read whose hosts, port and keyspace must be set
+   * @return the cached or newly created session
+   */
+  static Session getSession(Read<?> read) {
+    Cluster cluster =
+        clusterMap.computeIfAbsent(
+            readToClusterHash(read),
+            k ->
+                CassandraIO.getCluster(
+                    Objects.requireNonNull(read.hosts()),
+                    Objects.requireNonNull(read.port()),
+                    read.username(),
+                    read.password(),
+                    read.localDc(),
+                    read.consistencyLevel(),
+                    read.connectTimeout(),
+                    read.readTimeout()));
+    return sessionMap.computeIfAbsent(
+        readToSessionHash(read),
+        k -> cluster.connect(Objects.requireNonNull(read.keyspace()).get()));
+  }
+
+  /** Returns the provider's value, or an empty string when the provider itself is null. */
+  private static String safeVPGet(ValueProvider<String> s) {
+    return s != null ? s.get() : "";
+  }
+}
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapper.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapper.java
index 8f6d5781eac6..92ec2c58d8b2 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapper.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapper.java
@@ -34,7 +34,7 @@
})
class DefaultObjectMapper implements Mapper, Serializable {
- private transient com.datastax.driver.mapping.Mapper mapper;
+ private final transient com.datastax.driver.mapping.Mapper mapper;
DefaultObjectMapper(com.datastax.driver.mapping.Mapper mapper) {
this.mapper = mapper;
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapperFactory.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapperFactory.java
index 7976665905b7..ef75ff312aca 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapperFactory.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/DefaultObjectMapperFactory.java
@@ -34,7 +34,7 @@
class DefaultObjectMapperFactory implements SerializableFunction {
private transient MappingManager mappingManager;
- Class entity;
+ final Class entity;
DefaultObjectMapperFactory(Class entity) {
this.entity = entity;
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ReadFn.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ReadFn.java
new file mode 100644
index 000000000000..193cdf0a3d8c
--- /dev/null
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ReadFn.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.cassandra;
+
+import com.datastax.driver.core.ColumnMetadata;
+import com.datastax.driver.core.PreparedStatement;
+import com.datastax.driver.core.ResultSet;
+import com.datastax.driver.core.Session;
+import com.datastax.driver.core.Token;
+import java.util.Iterator;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.io.cassandra.CassandraIO.Read;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A {@link DoFn} that runs each incoming {@link Read} against Cassandra and outputs the entities
+ * that the read's mapper produces from the result set.
+ *
+ * <p>If the read carries ring ranges, one token-restricted query is executed per range (two for a
+ * range that wraps around the ring); otherwise a single unrestricted query is executed.
+ */
+@SuppressWarnings({
+  "rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
+  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
+})
+class ReadFn<T> extends DoFn<Read<T>, T> {
+
+  private static final Logger LOG = LoggerFactory.getLogger(ReadFn.class);
+
+  /**
+   * Executes {@code read} and emits every mapped row to {@code receiver}.
+   *
+   * @param read the read to execute
+   * @param receiver where the mapped entities are output
+   * @throws RuntimeException any failure raised while querying; it is logged and rethrown so that
+   *     the rows of a failed read are not silently dropped
+   */
+  @ProcessElement
+  public void processElement(@Element Read<T> read, OutputReceiver<T> receiver) {
+    try {
+      Session session = ConnectionManager.getSession(read);
+      Mapper<T> mapper = read.mapperFactoryFn().apply(session);
+
+      if (read.ringRanges() == null) {
+        // No token restriction: run the user query (or SELECT *) as is, without a range filter.
+        outputResults(session.execute(buildInitialQuery(read)), mapper, receiver);
+        return;
+      }
+
+      String partitionKey =
+          session.getCluster().getMetadata().getKeyspace(read.keyspace().get())
+              .getTable(read.table().get()).getPartitionKey().stream()
+              .map(ColumnMetadata::getName)
+              .collect(Collectors.joining(","));
+      Set<RingRange> ringRanges = read.ringRanges().get();
+
+      // Prepared lazily so that each distinct statement is prepared at most once per element.
+      PreparedStatement boundedQuery = null;
+      PreparedStatement fromStartQuery = null;
+      PreparedStatement untilEndQuery = null;
+      for (RingRange rr : ringRanges) {
+        Token startToken = session.getCluster().getMetadata().newToken(rr.getStart().toString());
+        Token endToken = session.getCluster().getMetadata().newToken(rr.getEnd().toString());
+        if (rr.isWrapping()) {
+          // A wrapping range runs past the end of the ring and continues from its beginning, so
+          // "token >= start AND token < end" would match nothing. Read it as two half-open
+          // queries instead: [start, ring end] and [ring start, end).
+          if (fromStartQuery == null) {
+            fromStartQuery = session.prepare(generateRangeQuery(read, partitionKey, true, false));
+            untilEndQuery = session.prepare(generateRangeQuery(read, partitionKey, false, true));
+          }
+          outputResults(
+              session.execute(fromStartQuery.bind().setToken(0, startToken)), mapper, receiver);
+          outputResults(
+              session.execute(untilEndQuery.bind().setToken(0, endToken)), mapper, receiver);
+        } else {
+          if (boundedQuery == null) {
+            boundedQuery = session.prepare(generateRangeQuery(read, partitionKey, true, true));
+          }
+          outputResults(
+              session.execute(
+                  boundedQuery.bind().setToken(0, startToken).setToken(1, endToken)),
+              mapper,
+              receiver);
+        }
+      }
+    } catch (RuntimeException ex) {
+      LOG.error("Error while reading from Cassandra", ex);
+      throw ex;
+    }
+  }
+
+  /** Maps every row of {@code resultSet} with {@code mapper} and outputs the mapped entities. */
+  private void outputResults(ResultSet resultSet, Mapper<T> mapper, OutputReceiver<T> receiver) {
+    Iterator<T> iter = mapper.map(resultSet);
+    while (iter.hasNext()) {
+      receiver.output(iter.next());
+    }
+  }
+
+  /**
+   * Builds the query of {@code spec} restricted to a token range of the partition key.
+   *
+   * @param spec the read providing the base query
+   * @param partitionKey comma-separated partition key column names
+   * @param hasLowerBound whether to add the {@code token(..) >= ?} condition
+   * @param hasUpperBound whether to add the {@code token(..) < ?} condition
+   * @return the CQL text, with one bind marker per requested bound
+   */
+  private static String generateRangeQuery(
+      Read<?> spec, String partitionKey, boolean hasLowerBound, boolean hasUpperBound) {
+    final String rangeFilter =
+        Joiner.on(" AND ")
+            .skipNulls()
+            .join(
+                hasLowerBound ? String.format("(token(%s) >= ?)", partitionKey) : null,
+                hasUpperBound ? String.format("(token(%s) < ?)", partitionKey) : null);
+    // NOTE(review): a user-supplied query is assumed to already end in a WHERE clause, as the
+    // range filter is appended to it with AND.
+    final String combinedQuery =
+        buildInitialQuery(spec) + (spec.query() == null ? " WHERE " : " AND ") + rangeFilter;
+    LOG.debug("CassandraIO generated query : {}", combinedQuery);
+    return combinedQuery;
+  }
+
+  /** Returns the user-supplied query, or a {@code SELECT *} on the read's table if none. */
+  private static String buildInitialQuery(Read<?> spec) {
+    return (spec.query() == null)
+        ? String.format("SELECT * FROM %s.%s", spec.keyspace().get(), spec.table().get())
+        : spec.query().get().toString();
+  }
+}
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java
index b5f94d7c0c9b..c83e47fcac9a 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java
@@ -17,23 +17,31 @@
*/
package org.apache.beam.sdk.io.cassandra;
+import java.io.Serializable;
import java.math.BigInteger;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.annotations.Experimental.Kind;
/** Models a Cassandra token range. */
-final class RingRange {
+@Experimental(Kind.SOURCE_SINK)
+@SuppressWarnings({
+ "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
+})
+public final class RingRange implements Serializable {
private final BigInteger start;
private final BigInteger end;
- RingRange(BigInteger start, BigInteger end) {
+ private RingRange(BigInteger start, BigInteger end) {
this.start = start;
this.end = end;
}
- BigInteger getStart() {
+ public BigInteger getStart() {
return start;
}
- BigInteger getEnd() {
+ public BigInteger getEnd() {
return end;
}
@@ -55,4 +63,34 @@ public boolean isWrapping() {
public String toString() {
return String.format("(%s,%s]", start.toString(), end.toString());
}
+
+ /** Creates a ring range with the given {@code start} and {@code end} tokens. */
+ public static RingRange of(BigInteger start, BigInteger end) {
+   return new RingRange(start, end);
+ }
+
+ /** Ring ranges are equal when they have equal start tokens and equal end tokens. */
+ @Override
+ public boolean equals(@Nullable Object o) {
+   if (o == this) {
+     return true;
+   }
+   if (o == null || o.getClass() != getClass()) {
+     return false;
+   }
+   RingRange that = (RingRange) o;
+   return sameToken(getStart(), that.getStart()) && sameToken(getEnd(), that.getEnd());
+ }
+
+ private static boolean sameToken(@Nullable BigInteger left, @Nullable BigInteger right) {
+   return left == null ? right == null : left.equals(right);
+ }
+
+ /** Hashes the start and end tokens, a null token counting as 0. */
+ @Override
+ public int hashCode() {
+   int startHash = getStart() == null ? 0 : getStart().hashCode();
+   int endHash = getEnd() == null ? 0 : getEnd().hashCode();
+   return 31 * startHash + endHash;
+ }
}
diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java
index de494212d560..bc1205a28797 100644
--- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java
+++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java
@@ -39,22 +39,22 @@ final class SplitGenerator {
this.partitioner = partitioner;
}
+ /**
+  * Returns the lower token bound used for the given partitioner: 0 when the name ends with
+  * "RandomPartitioner", -2^63 when it ends with "Murmur3Partitioner".
+  *
+  * @param partitioner partitioner class name; only its suffix is inspected
+  * @throws UnsupportedOperationException if the name ends with neither suffix
+  */
- private static BigInteger getRangeMin(String partitioner) {
+ static BigInteger getRangeMin(String partitioner) {
if (partitioner.endsWith("RandomPartitioner")) {
return BigInteger.ZERO;
} else if (partitioner.endsWith("Murmur3Partitioner")) {
- return new BigInteger("2").pow(63).negate();
+ return BigInteger.valueOf(2).pow(63).negate();
} else {
throw new UnsupportedOperationException(
"Unsupported partitioner. " + "Only Random and Murmur3 are supported");
}
}
- private static BigInteger getRangeMax(String partitioner) {
+ static BigInteger getRangeMax(String partitioner) {
if (partitioner.endsWith("RandomPartitioner")) {
- return new BigInteger("2").pow(127).subtract(BigInteger.ONE);
+ return BigInteger.valueOf(2).pow(127).subtract(BigInteger.ONE);
} else if (partitioner.endsWith("Murmur3Partitioner")) {
- return new BigInteger("2").pow(63).subtract(BigInteger.ONE);
+ return BigInteger.valueOf(2).pow(63).subtract(BigInteger.ONE);
} else {
throw new UnsupportedOperationException(
"Unsupported partitioner. " + "Only Random and Murmur3 are supported");
@@ -84,7 +84,7 @@ List> generateSplits(long totalSplitCount, List ring
BigInteger start = ringTokens.get(i);
BigInteger stop = ringTokens.get((i + 1) % tokenRangeCount);
- if (!inRange(start) || !inRange(stop)) {
+ if (!isInRange(start) || !isInRange(stop)) {
throw new RuntimeException(
String.format("Tokens (%s,%s) not in range of %s", start, stop, partitioner));
}
@@ -127,7 +127,7 @@ List> generateSplits(long totalSplitCount, List ring
// Append the splits between the endpoints
for (int j = 0; j < splitCount; j++) {
- splits.add(new RingRange(endpointTokens.get(j), endpointTokens.get(j + 1)));
+ splits.add(RingRange.of(endpointTokens.get(j), endpointTokens.get(j + 1)));
LOG.debug("Split #{}: [{},{})", j + 1, endpointTokens.get(j), endpointTokens.get(j + 1));
}
}
@@ -144,7 +144,7 @@ List> generateSplits(long totalSplitCount, List ring
return coalesceSplits(getTargetSplitSize(totalSplitCount), splits);
}
+ /** Returns true when {@code token} lies between rangeMin and rangeMax, both bounds included. */
- private boolean inRange(BigInteger token) {
+ private boolean isInRange(BigInteger token) {
return !(token.compareTo(rangeMin) < 0 || token.compareTo(rangeMax) > 0);
}
diff --git a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java
index a52808b0e0a1..131ce83b48dd 100644
--- a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java
+++ b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOTest.java
@@ -18,23 +18,18 @@
package org.apache.beam.sdk.io.cassandra;
import static junit.framework.TestCase.assertTrue;
-import static org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.distance;
-import static org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.getEstimatedSizeBytesFromTokenRanges;
-import static org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.getRingFraction;
-import static org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.isMurmur3Partitioner;
-import static org.apache.beam.sdk.testing.SourceTestUtils.readFromSource;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.greaterThan;
-import static org.hamcrest.Matchers.lessThan;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import com.datastax.driver.core.Cluster;
+import com.datastax.driver.core.Metadata;
+import com.datastax.driver.core.ProtocolVersion;
import com.datastax.driver.core.ResultSet;
import com.datastax.driver.core.Row;
import com.datastax.driver.core.Session;
+import com.datastax.driver.core.TypeCodec;
import com.datastax.driver.core.exceptions.NoHostAvailableException;
-import com.datastax.driver.core.querybuilder.QueryBuilder;
+import com.datastax.driver.mapping.annotations.ClusteringColumn;
import com.datastax.driver.mapping.annotations.Column;
import com.datastax.driver.mapping.annotations.Computed;
import com.datastax.driver.mapping.annotations.PartitionKey;
@@ -44,10 +39,14 @@
import java.io.IOException;
import java.io.Serializable;
import java.math.BigInteger;
+import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
@@ -61,13 +60,8 @@
import javax.management.remote.JMXConnectorFactory;
import javax.management.remote.JMXServiceURL;
import org.apache.beam.sdk.coders.SerializableCoder;
-import org.apache.beam.sdk.io.BoundedSource;
-import org.apache.beam.sdk.io.cassandra.CassandraIO.CassandraSource.TokenRange;
import org.apache.beam.sdk.io.common.NetworkTestHelper;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
-import org.apache.beam.sdk.testing.SourceTestUtils;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
@@ -82,7 +76,6 @@
import org.apache.cassandra.service.StorageServiceMBean;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.junit.AfterClass;
-import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Rule;
@@ -99,7 +92,7 @@
"rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
})
public class CassandraIOTest implements Serializable {
- private static final long NUM_ROWS = 20L;
+ private static final long NUM_ROWS = 22L;
private static final String CASSANDRA_KEYSPACE = "beam_ks";
private static final String CASSANDRA_HOST = "127.0.0.1";
private static final String CASSANDRA_TABLE = "scientist";
@@ -190,39 +183,44 @@ private static void insertData() throws Exception {
LOG.info("Create Cassandra tables");
session.execute(
String.format(
- "CREATE TABLE IF NOT EXISTS %s.%s(person_id int, person_name text, PRIMARY KEY"
- + "(person_id));",
+ "CREATE TABLE IF NOT EXISTS %s.%s(person_department text, person_id int, person_name text, PRIMARY KEY"
+ + "((person_department), person_id));",
CASSANDRA_KEYSPACE, CASSANDRA_TABLE));
session.execute(
String.format(
- "CREATE TABLE IF NOT EXISTS %s.%s(person_id int, person_name text, PRIMARY KEY"
- + "(person_id));",
+ "CREATE TABLE IF NOT EXISTS %s.%s(person_department text, person_id int, person_name text, PRIMARY KEY"
+ + "((person_department), person_id));",
CASSANDRA_KEYSPACE, CASSANDRA_TABLE_WRITE));
LOG.info("Insert records");
- String[] scientists = {
- "Einstein",
- "Darwin",
- "Copernicus",
- "Pasteur",
- "Curie",
- "Faraday",
- "Newton",
- "Bohr",
- "Galilei",
- "Maxwell"
+ String[][] scientists = {
+ new String[] {"phys", "Einstein"},
+ new String[] {"bio", "Darwin"},
+ new String[] {"phys", "Copernicus"},
+ new String[] {"bio", "Pasteur"},
+ new String[] {"bio", "Curie"},
+ new String[] {"phys", "Faraday"},
+ new String[] {"math", "Newton"},
+ new String[] {"phys", "Bohr"},
+ new String[] {"phys", "Galileo"},
+ new String[] {"math", "Maxwell"},
+ new String[] {"logic", "Russel"},
};
for (int i = 0; i < NUM_ROWS; i++) {
int index = i % scientists.length;
- session.execute(
+ String insertStr =
String.format(
- "INSERT INTO %s.%s(person_id, person_name) values("
+ "INSERT INTO %s.%s(person_department, person_id, person_name) values("
+ + "'"
+ + scientists[index][0]
+ + "', "
+ i
+ ", '"
- + scientists[index]
+ + scientists[index][1]
+ "');",
CASSANDRA_KEYSPACE,
- CASSANDRA_TABLE));
+ CASSANDRA_TABLE);
+ session.execute(insertStr);
}
flushMemTablesAndRefreshSizeEstimates();
}
@@ -276,25 +274,6 @@ private static void disableAutoCompaction() throws Exception {
Thread.sleep(JMX_CONF_TIMEOUT);
}
- @Test
- public void testEstimatedSizeBytes() throws Exception {
- PipelineOptions pipelineOptions = PipelineOptionsFactory.create();
- CassandraIO.Read read =
- CassandraIO.read()
- .withHosts(Collections.singletonList(CASSANDRA_HOST))
- .withPort(cassandraPort)
- .withKeyspace(CASSANDRA_KEYSPACE)
- .withTable(CASSANDRA_TABLE);
- CassandraIO.CassandraSource source = new CassandraIO.CassandraSource<>(read, null);
- long estimatedSizeBytes = source.getEstimatedSizeBytes(pipelineOptions);
- // the size is non determanistic in Cassandra backend: checks that estimatedSizeBytes >= 12960L
- // -20% && estimatedSizeBytes <= 12960L +20%
- assertThat(
- "wrong estimated size in " + CASSANDRA_KEYSPACE + "/" + CASSANDRA_TABLE,
- estimatedSizeBytes,
- greaterThan(0L));
- }
-
@Test
public void testRead() throws Exception {
PCollection output =
@@ -304,6 +283,7 @@ public void testRead() throws Exception {
.withPort(cassandraPort)
.withKeyspace(CASSANDRA_KEYSPACE)
.withTable(CASSANDRA_TABLE)
+ .withMinNumberOfSplits(50)
.withCoder(SerializableCoder.of(Scientist.class))
.withEntity(Scientist.class));
@@ -321,9 +301,113 @@ public KV apply(Scientist scientist) {
PAssert.that(mapped.apply("Count occurrences per scientist", Count.perKey()))
.satisfies(
input -> {
+ int count = 0;
for (KV element : input) {
+ count++;
assertEquals(element.getKey(), NUM_ROWS / 10, element.getValue().longValue());
}
+ assertEquals(11, count);
+ return null;
+ });
+
+ pipeline.run();
+ }
+
+ /**
+  * Builds a {@code Scientist} read of the test table on the local test host and port, configured
+  * with the given ring ranges. The ranges are collected into a {@code HashSet}, so equal ranges
+  * passed more than once are kept only once.
+  */
+ private CassandraIO.Read getReadWithRingRange(RingRange... rr) {
+ return CassandraIO.read()
+ .withHosts(Collections.singletonList(CASSANDRA_HOST))
+ .withPort(cassandraPort)
+ .withRingRanges(new HashSet<>(Arrays.asList(rr)))
+ .withKeyspace(CASSANDRA_KEYSPACE)
+ .withTable(CASSANDRA_TABLE)
+ .withCoder(SerializableCoder.of(Scientist.class))
+ .withEntity(Scientist.class);
+ }
+
+ /**
+  * Builds a {@code Scientist} read of the test table on the local test host and port, configured
+  * with the given CQL query.
+  */
+ private CassandraIO.Read getReadWithQuery(String query) {
+ return CassandraIO.read()
+ .withHosts(Collections.singletonList(CASSANDRA_HOST))
+ .withPort(cassandraPort)
+ .withQuery(query)
+ .withKeyspace(CASSANDRA_KEYSPACE)
+ .withTable(CASSANDRA_TABLE)
+ .withCoder(SerializableCoder.of(Scientist.class))
+ .withEntity(Scientist.class);
+ }
+
+ /**
+  * Feeds two query-based reads through {@code CassandraIO.readAll()} and checks that exactly the
+  * two rows selected by those queries come back.
+  */
+ @Test
+ public void testReadAllQuery() {
+ // insertData() writes row 0 as ("phys", "Einstein").
+ String physQuery =
+ String.format(
+ "SELECT * From %s.%s WHERE person_department='phys' AND person_id=0;",
+ CASSANDRA_KEYSPACE, CASSANDRA_TABLE);
+
+ // insertData() writes row 6 as ("math", "Newton").
+ String mathQuery =
+ String.format(
+ "SELECT * From %s.%s WHERE person_department='math' AND person_id=6;",
+ CASSANDRA_KEYSPACE, CASSANDRA_TABLE);
+
+ // Each element of the input collection is a fully configured read carrying its own query.
+ PCollection output =
+ pipeline
+ .apply(Create.of(getReadWithQuery(physQuery), getReadWithQuery(mathQuery)))
+ .apply(
+ CassandraIO.readAll().withCoder(SerializableCoder.of(Scientist.class)));
+
+ // Project every result to its name so the expected rows can be listed as plain strings.
+ PCollection mapped =
+ output.apply(
+ MapElements.via(
+ new SimpleFunction() {
+ @Override
+ public String apply(Scientist scientist) {
+ return scientist.name;
+ }
+ }));
+ PAssert.that(mapped).containsInAnyOrder("Einstein", "Newton");
+ // The global count guards against additional rows being returned.
+ PAssert.thatSingleton(output.apply("count", Count.globally())).isEqualTo(2L);
+ pipeline.run();
+ }
+
+ @Test
+ public void testReadAllRingRange() {
+ RingRange physRR =
+ fromEncodedKey(
+ cluster.getMetadata(), TypeCodec.varchar().serialize("phys", ProtocolVersion.V3));
+
+ RingRange mathRR =
+ fromEncodedKey(
+ cluster.getMetadata(), TypeCodec.varchar().serialize("math", ProtocolVersion.V3));
+
+ RingRange logicRR =
+ fromEncodedKey(
+ cluster.getMetadata(), TypeCodec.varchar().serialize("logic", ProtocolVersion.V3));
+
+ PCollection output =
+ pipeline
+ .apply(Create.of(getReadWithRingRange(physRR), getReadWithRingRange(mathRR, logicRR)))
+ .apply(
+ CassandraIO.readAll().withCoder(SerializableCoder.of(Scientist.class)));
+
+ PCollection> mapped =
+ output.apply(
+ MapElements.via(
+ new SimpleFunction>() {
+ @Override
+ public KV apply(Scientist scientist) {
+ return KV.of(scientist.department, scientist.id);
+ }
+ }));
+
+ PAssert.that(mapped.apply("Count occurrences per department", Count.perKey()))
+ .satisfies(
+ input -> {
+ HashMap map = new HashMap<>();
+ for (KV element : input) {
+ map.put(element.getKey(), element.getValue());
+ }
+ assertEquals(3, map.size()); // do we have all three departments
+ assertEquals(map.get("phys"), 10L, 0L);
+ assertEquals(map.get("math"), 4L, 0L);
+ assertEquals(map.get("logic"), 2L, 0L);
return null;
});
@@ -339,8 +423,9 @@ public void testReadWithQuery() throws Exception {
.withPort(cassandraPort)
.withKeyspace(CASSANDRA_KEYSPACE)
.withTable(CASSANDRA_TABLE)
+ .withMinNumberOfSplits(20)
.withQuery(
- "select person_id, writetime(person_name) from beam_ks.scientist where person_id=10")
+ "select person_id, writetime(person_name) from beam_ks.scientist where person_id=10 AND person_department='logic'")
.withCoder(SerializableCoder.of(Scientist.class))
.withEntity(Scientist.class));
@@ -365,6 +450,7 @@ public void testWrite() {
ScientistWrite scientist = new ScientistWrite();
scientist.id = i;
scientist.name = "Name " + i;
+ scientist.department = "bio";
data.add(scientist);
}
@@ -485,52 +571,6 @@ public void testCustomMapperImplDelete() {
assertEquals(1, counter.intValue());
}
- @Test
- public void testSplit() throws Exception {
- PipelineOptions options = PipelineOptionsFactory.create();
- CassandraIO.Read read =
- CassandraIO.read()
- .withHosts(Collections.singletonList(CASSANDRA_HOST))
- .withPort(cassandraPort)
- .withKeyspace(CASSANDRA_KEYSPACE)
- .withTable(CASSANDRA_TABLE)
- .withEntity(Scientist.class)
- .withCoder(SerializableCoder.of(Scientist.class));
-
- // initialSource will be read without splitting (which does not happen in production)
- // so we need to provide splitQueries to avoid NPE in source.reader.start()
- String splitQuery = QueryBuilder.select().from(CASSANDRA_KEYSPACE, CASSANDRA_TABLE).toString();
- CassandraIO.CassandraSource initialSource =
- new CassandraIO.CassandraSource<>(read, Collections.singletonList(splitQuery));
- int desiredBundleSizeBytes = 2048;
- long estimatedSize = initialSource.getEstimatedSizeBytes(options);
- List> splits = initialSource.split(desiredBundleSizeBytes, options);
- SourceTestUtils.assertSourcesEqualReferenceSource(initialSource, splits, options);
- float expectedNumSplitsloat =
- (float) initialSource.getEstimatedSizeBytes(options) / desiredBundleSizeBytes;
- long sum = 0;
-
- for (BoundedSource subSource : splits) {
- sum += subSource.getEstimatedSizeBytes(options);
- }
-
- // due to division and cast estimateSize != sum but will be close. Exact equals checked below
- assertEquals((long) (estimatedSize / splits.size()) * splits.size(), sum);
-
- int expectedNumSplits = (int) Math.ceil(expectedNumSplitsloat);
- assertEquals("Wrong number of splits", expectedNumSplits, splits.size());
- int emptySplits = 0;
- for (BoundedSource subSource : splits) {
- if (readFromSource(subSource, options).isEmpty()) {
- emptySplits += 1;
- }
- }
- assertThat(
- "There are too many empty splits, parallelism is sub-optimal",
- emptySplits,
- lessThan((int) (ACCEPTABLE_EMPTY_SPLITS_PERCENTAGE * splits.size())));
- }
-
private List getRows(String table) {
ResultSet result =
session.execute(
@@ -545,6 +585,7 @@ public void testDelete() throws Exception {
Scientist einstein = new Scientist();
einstein.id = 0;
+ einstein.department = "phys";
einstein.name = "Einstein";
pipeline
.apply(Create.of(einstein))
@@ -561,7 +602,8 @@ public void testDelete() throws Exception {
// re-insert suppressed doc to make the test autonomous
session.execute(
String.format(
- "INSERT INTO %s.%s(person_id, person_name) values("
+ "INSERT INTO %s.%s(person_department, person_id, person_name) values("
+ + "'phys', "
+ einstein.id
+ ", '"
+ einstein.name
@@ -570,58 +612,6 @@ public void testDelete() throws Exception {
CASSANDRA_TABLE));
}
- @Test
- public void testValidPartitioner() {
- Assert.assertTrue(isMurmur3Partitioner(cluster));
- }
-
- @Test
- public void testDistance() {
- BigInteger distance = distance(new BigInteger("10"), new BigInteger("100"));
- assertEquals(BigInteger.valueOf(90), distance);
-
- distance = distance(new BigInteger("100"), new BigInteger("10"));
- assertEquals(new BigInteger("18446744073709551526"), distance);
- }
-
- @Test
- public void testRingFraction() {
- // simulate a first range taking "half" of the available tokens
- List tokenRanges = new ArrayList<>();
- tokenRanges.add(new TokenRange(1, 1, BigInteger.valueOf(Long.MIN_VALUE), new BigInteger("0")));
- assertEquals(0.5, getRingFraction(tokenRanges), 0);
-
- // add a second range to cover all tokens available
- tokenRanges.add(new TokenRange(1, 1, new BigInteger("0"), BigInteger.valueOf(Long.MAX_VALUE)));
- assertEquals(1.0, getRingFraction(tokenRanges), 0);
- }
-
- @Test
- public void testEstimatedSizeBytesFromTokenRanges() {
- List tokenRanges = new ArrayList<>();
- // one partition containing all tokens, the size is actually the size of the partition
- tokenRanges.add(
- new TokenRange(
- 1, 1000, BigInteger.valueOf(Long.MIN_VALUE), BigInteger.valueOf(Long.MAX_VALUE)));
- assertEquals(1000, getEstimatedSizeBytesFromTokenRanges(tokenRanges));
-
- // one partition with half of the tokens, we estimate the size to the double of this partition
- tokenRanges = new ArrayList<>();
- tokenRanges.add(
- new TokenRange(1, 1000, BigInteger.valueOf(Long.MIN_VALUE), new BigInteger("0")));
- assertEquals(2000, getEstimatedSizeBytesFromTokenRanges(tokenRanges));
-
- // we have three partitions covering all tokens, the size is the sum of partition size *
- // partition count
- tokenRanges = new ArrayList<>();
- tokenRanges.add(
- new TokenRange(1, 1000, BigInteger.valueOf(Long.MIN_VALUE), new BigInteger("-3")));
- tokenRanges.add(new TokenRange(1, 1000, new BigInteger("-2"), new BigInteger("10000")));
- tokenRanges.add(
- new TokenRange(2, 3000, new BigInteger("10001"), BigInteger.valueOf(Long.MAX_VALUE)));
- assertEquals(8000, getEstimatedSizeBytesFromTokenRanges(tokenRanges));
- }
-
/** Simple Cassandra entity used in read tests. */
@Table(name = CASSANDRA_TABLE, keyspace = CASSANDRA_KEYSPACE)
static class Scientist implements Serializable {
@@ -632,10 +622,14 @@ static class Scientist implements Serializable {
@Computed("writetime(person_name)")
Long nameTs;
- @PartitionKey()
+ @ClusteringColumn()
@Column(name = "person_id")
int id;
+ @PartitionKey
+ @Column(name = "person_department")
+ String department;
+
@Override
public String toString() {
return id + ":" + name;
@@ -650,7 +644,9 @@ public boolean equals(@Nullable Object o) {
return false;
}
Scientist scientist = (Scientist) o;
- return id == scientist.id && Objects.equal(name, scientist.name);
+ return id == scientist.id
+ && Objects.equal(name, scientist.name)
+ && Objects.equal(department, scientist.department);
}
@Override
@@ -659,6 +655,11 @@ public int hashCode() {
}
}
+ /**
+  * Builds the ring range {@code (t, t + 1)} where {@code t} is the token that {@code metadata}
+  * computes for the given serialized key components.
+  *
+  * <p>The token value is cast to {@code long}, so this assumes a partitioner whose tokens are
+  * {@code Long} values.
+  */
+ private static RingRange fromEncodedKey(Metadata metadata, ByteBuffer... bb) {
+   long tokenValue = (long) metadata.newToken(bb).getValue();
+   BigInteger rangeStart = BigInteger.valueOf(tokenValue);
+   BigInteger rangeEnd = rangeStart.add(BigInteger.ONE);
+   return RingRange.of(rangeStart, rangeEnd);
+ }
+
private static final String CASSANDRA_TABLE_WRITE = "scientist_write";
/** Simple Cassandra entity used in write tests. */
@Table(name = CASSANDRA_TABLE_WRITE, keyspace = CASSANDRA_KEYSPACE)