Skip to content

Commit

Permalink
[BEAM-9008] adds CassandraIO.readAll
Browse files Browse the repository at this point in the history
  • Loading branch information
vmarquez committed Sep 7, 2021
1 parent 3eaa041 commit e12fc33
Show file tree
Hide file tree
Showing 8 changed files with 555 additions and 567 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.cassandra;

import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.Session;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.beam.sdk.io.cassandra.CassandraIO.Read;
import org.apache.beam.sdk.options.ValueProvider;

@SuppressWarnings({
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
public class ConnectionManager {

private static final ConcurrentHashMap<String, Cluster> clusterMap =
new ConcurrentHashMap<String, Cluster>();
private static final ConcurrentHashMap<String, Session> sessionMap =
new ConcurrentHashMap<String, Session>();

static {
Runtime.getRuntime()
.addShutdownHook(
new Thread(
() -> {
for (Session session : sessionMap.values()) {
if (!session.isClosed()) {
session.close();
}
}
}));
}

private static String readToClusterHash(Read<?> read) {
return Objects.requireNonNull(read.hosts()).get().stream().reduce(",", (a, b) -> a + b)
+ Objects.requireNonNull(read.port()).get()
+ safeVPGet(read.localDc())
+ safeVPGet(read.consistencyLevel());
}

private static String readToSessionHash(Read<?> read) {
return readToClusterHash(read) + read.keyspace().get();
}

static Session getSession(Read<?> read) {
Cluster cluster =
clusterMap.computeIfAbsent(
readToClusterHash(read),
k ->
CassandraIO.getCluster(
Objects.requireNonNull(read.hosts()),
Objects.requireNonNull(read.port()),
read.username(),
read.password(),
read.localDc(),
read.consistencyLevel(),
read.connectTimeout(),
read.readTimeout()));
return sessionMap.computeIfAbsent(
readToSessionHash(read),
k -> cluster.connect(Objects.requireNonNull(read.keyspace()).get()));
}

private static String safeVPGet(ValueProvider<String> s) {
return s != null ? s.get() : "";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
})
class DefaultObjectMapper<T> implements Mapper<T>, Serializable {

private transient com.datastax.driver.mapping.Mapper<T> mapper;
private final transient com.datastax.driver.mapping.Mapper<T> mapper;

DefaultObjectMapper(com.datastax.driver.mapping.Mapper mapper) {
this.mapper = mapper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
class DefaultObjectMapperFactory<T> implements SerializableFunction<Session, Mapper> {

private transient MappingManager mappingManager;
Class<T> entity;
final Class<T> entity;

DefaultObjectMapperFactory(Class<T> entity) {
this.entity = entity;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.cassandra;

import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.ColumnMetadata;
import com.datastax.driver.core.PreparedStatement;
import com.datastax.driver.core.ResultSet;
import com.datastax.driver.core.Session;
import com.datastax.driver.core.Token;
import java.util.Collections;
import java.util.Iterator;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.beam.sdk.io.cassandra.CassandraIO.Read;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@SuppressWarnings({
"rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
class ReadFn<T> extends DoFn<Read<T>, T> {

private static final Logger LOG = LoggerFactory.getLogger(ReadFn.class);

@ProcessElement
public void processElement(@Element Read<T> read, OutputReceiver<T> receiver) {
try {
Session session = ConnectionManager.getSession(read);
Mapper<T> mapper = read.mapperFactoryFn().apply(session);
String partitionKey =
session.getCluster().getMetadata().getKeyspace(read.keyspace().get())
.getTable(read.table().get()).getPartitionKey().stream()
.map(ColumnMetadata::getName)
.collect(Collectors.joining(","));

String query = generateRangeQuery(read, partitionKey, read.ringRanges() != null);
PreparedStatement preparedStatement = session.prepare(query);
Set<RingRange> ringRanges =
read.ringRanges() == null ? Collections.emptySet() : read.ringRanges().get();

for (RingRange rr : ringRanges) {
Token startToken = session.getCluster().getMetadata().newToken(rr.getStart().toString());
Token endToken = session.getCluster().getMetadata().newToken(rr.getEnd().toString());
ResultSet rs =
session.execute(preparedStatement.bind().setToken(0, startToken).setToken(1, endToken));
Iterator<T> iter = mapper.map(rs);
while (iter.hasNext()) {
T n = iter.next();
receiver.output(n);
}
}

if (read.ringRanges() == null) {
ResultSet rs = session.execute(preparedStatement.bind());
Iterator<T> iter = mapper.map(rs);
while (iter.hasNext()) {
receiver.output(iter.next());
}
}
} catch (Exception ex) {
LOG.error("error", ex);
}
}

private Session getSession(Read<T> read) {
Cluster cluster =
CassandraIO.getCluster(
read.hosts(),
read.port(),
read.username(),
read.password(),
read.localDc(),
read.consistencyLevel(),
read.connectTimeout(),
read.readTimeout());

return cluster.connect(read.keyspace().get());
}

private static String generateRangeQuery(
Read<?> spec, String partitionKey, Boolean hasRingRange) {
final String rangeFilter =
(hasRingRange)
? Joiner.on(" AND ")
.skipNulls()
.join(
String.format("(token(%s) >= ?)", partitionKey),
String.format("(token(%s) < ?)", partitionKey))
: "";
final String combinedQuery = buildInitialQuery(spec, hasRingRange) + rangeFilter;
LOG.debug("CassandraIO generated query : {}", combinedQuery);
return combinedQuery;
}

private static String buildInitialQuery(Read<?> spec, Boolean hasRingRange) {
return (spec.query() == null)
? String.format("SELECT * FROM %s.%s", spec.keyspace().get(), spec.table().get())
+ " WHERE "
: spec.query().get() + (hasRingRange ? " AND " : "");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,31 @@
*/
package org.apache.beam.sdk.io.cassandra;

import java.io.Serializable;
import java.math.BigInteger;
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;

/** Models a Cassandra token range. */
final class RingRange {
@Experimental(Kind.SOURCE_SINK)
@SuppressWarnings({
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
public final class RingRange implements Serializable {
private final BigInteger start;
private final BigInteger end;

RingRange(BigInteger start, BigInteger end) {
private RingRange(BigInteger start, BigInteger end) {
this.start = start;
this.end = end;
}

BigInteger getStart() {
public BigInteger getStart() {
return start;
}

BigInteger getEnd() {
public BigInteger getEnd() {
return end;
}

Expand All @@ -55,4 +63,34 @@ public boolean isWrapping() {
public String toString() {
return String.format("(%s,%s]", start.toString(), end.toString());
}

public static RingRange of(BigInteger start, BigInteger end) {
return new RingRange(start, end);
}

@Override
public boolean equals(@Nullable Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}

RingRange ringRange = (RingRange) o;

if (getStart() != null
? !getStart().equals(ringRange.getStart())
: ringRange.getStart() != null) {
return false;
}
return getEnd() != null ? getEnd().equals(ringRange.getEnd()) : ringRange.getEnd() == null;
}

@Override
public int hashCode() {
int result = getStart() != null ? getStart().hashCode() : 0;
result = 31 * result + (getEnd() != null ? getEnd().hashCode() : 0);
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,22 @@ final class SplitGenerator {
this.partitioner = partitioner;
}

private static BigInteger getRangeMin(String partitioner) {
static BigInteger getRangeMin(String partitioner) {
if (partitioner.endsWith("RandomPartitioner")) {
return BigInteger.ZERO;
} else if (partitioner.endsWith("Murmur3Partitioner")) {
return new BigInteger("2").pow(63).negate();
return BigInteger.valueOf(2).pow(63).negate();
} else {
throw new UnsupportedOperationException(
"Unsupported partitioner. " + "Only Random and Murmur3 are supported");
}
}

private static BigInteger getRangeMax(String partitioner) {
static BigInteger getRangeMax(String partitioner) {
if (partitioner.endsWith("RandomPartitioner")) {
return new BigInteger("2").pow(127).subtract(BigInteger.ONE);
return BigInteger.valueOf(2).pow(127).subtract(BigInteger.ONE);
} else if (partitioner.endsWith("Murmur3Partitioner")) {
return new BigInteger("2").pow(63).subtract(BigInteger.ONE);
return BigInteger.valueOf(2).pow(63).subtract(BigInteger.ONE);
} else {
throw new UnsupportedOperationException(
"Unsupported partitioner. " + "Only Random and Murmur3 are supported");
Expand Down Expand Up @@ -84,7 +84,7 @@ List<List<RingRange>> generateSplits(long totalSplitCount, List<BigInteger> ring
BigInteger start = ringTokens.get(i);
BigInteger stop = ringTokens.get((i + 1) % tokenRangeCount);

if (!inRange(start) || !inRange(stop)) {
if (!isInRange(start) || !isInRange(stop)) {
throw new RuntimeException(
String.format("Tokens (%s,%s) not in range of %s", start, stop, partitioner));
}
Expand Down Expand Up @@ -127,7 +127,7 @@ List<List<RingRange>> generateSplits(long totalSplitCount, List<BigInteger> ring

// Append the splits between the endpoints
for (int j = 0; j < splitCount; j++) {
splits.add(new RingRange(endpointTokens.get(j), endpointTokens.get(j + 1)));
splits.add(RingRange.of(endpointTokens.get(j), endpointTokens.get(j + 1)));
LOG.debug("Split #{}: [{},{})", j + 1, endpointTokens.get(j), endpointTokens.get(j + 1));
}
}
Expand All @@ -144,7 +144,7 @@ List<List<RingRange>> generateSplits(long totalSplitCount, List<BigInteger> ring
return coalesceSplits(getTargetSplitSize(totalSplitCount), splits);
}

private boolean inRange(BigInteger token) {
private boolean isInRange(BigInteger token) {
return !(token.compareTo(rangeMin) < 0 || token.compareTo(rangeMax) > 0);
}

Expand Down
Loading

0 comments on commit e12fc33

Please sign in to comment.