Adds implementation for supporting columnar batch reads from Spark. (#198)

This bypasses most of the existing translation code for the following reasons:
1.  I think there might be a memory leak because the existing code doesn't close the allocator.
2.  It avoids repeatedly recopying the schema.

I didn't delete the old code because the BigQueryRDD appears to still rely on it partially.

I also couldn't find instructions on formatting/testing (I couldn't find explicit unit tests
for the existing Arrow code; I'll update accordingly if pointers can be provided).
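
For context, here is a minimal, hypothetical sketch of how this columnar path would be exercised from Spark. The readDataFormat option value and the sample table name are assumptions for illustration, not part of this change:

  import org.apache.spark.sql.Dataset;
  import org.apache.spark.sql.Row;
  import org.apache.spark.sql.SparkSession;

  public class ArrowReadExample {
    public static void main(String[] args) {
      SparkSession spark =
          SparkSession.builder().appName("arrow-read-example").master("local[*]").getOrCreate();
      // With readDataFormat=ARROW, the scan goes through enableBatchRead() and
      // planBatchInputPartitions(), so each partition is decoded as Arrow record
      // batches (ColumnarBatch) instead of row-by-row Avro.
      Dataset<Row> df =
          spark.read()
              .format("bigquery")
              .option("table", "bigquery-public-data.samples.shakespeare") // illustrative table
              .option("readDataFormat", "ARROW")
              .load();
      df.show(10);
    }
  }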
emkornfield authored Jul 7, 2020
1 parent a209789 commit 22b41d3
Showing 4 changed files with 265 additions and 9 deletions.
ReadRowsHelper.java
@@ -15,6 +15,7 @@
*/
package com.google.cloud.bigquery.connector.common;

import com.google.api.gax.rpc.ServerStream;
import com.google.cloud.bigquery.storage.v1.BigQueryReadClient;
import com.google.cloud.bigquery.storage.v1.ReadRowsRequest;
import com.google.cloud.bigquery.storage.v1.ReadRowsResponse;
@@ -29,6 +30,7 @@ public class ReadRowsHelper {
private ReadRowsRequest.Builder request;
private int maxReadRowsRetries;
private BigQueryReadClient client;
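// Kept as a field so the underlying gRPC stream can be cancelled when this helper is closed.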
private ServerStream<ReadRowsResponse> incomingStream;

public ReadRowsHelper(
BigQueryReadClientFactory bigQueryReadClientFactory,
@@ -51,7 +53,13 @@ public Iterator<ReadRowsResponse> readRows() {

// In order to enable testing
protected Iterator<ReadRowsResponse> fetchResponses(ReadRowsRequest.Builder readRowsRequest) {
return client.readRowsCallable().call(readRowsRequest.build()).iterator();
incomingStream = client.readRowsCallable().call(readRowsRequest.build());
return incomingStream.iterator();
}

@Override
public String toString() {
return request.toString();
}

// Ported from https://github.com/GoogleCloudDataproc/spark-bigquery-connector/pull/150
@@ -89,7 +97,7 @@ public ReadRowsResponse next() {
serverResponses = helper.fetchResponses(helper.request.setOffset(readRowsCount));
retries++;
} else {
helper.client.close();
helper.close();
throw e;
}
}
@@ -100,6 +108,10 @@ public ReadRowsResponse next() {
}

public void close() {
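// Cancel any in-flight server stream before shutting down the client.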
if (incomingStream != null) {
incomingStream.cancel();
incomingStream = null;
}
if (!client.isShutdown()) {
client.close();
}
ArrowColumnBatchPartitionColumnBatchReader.java
@@ -0,0 +1,143 @@
/*
* Copyright 2018 Google Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.cloud.spark.bigquery.v2;

import com.google.cloud.bigquery.connector.common.ReadRowsHelper;
import com.google.cloud.bigquery.storage.v1.ReadRowsResponse;
import com.google.cloud.spark.bigquery.ArrowSchemaConverter;
import com.google.protobuf.ByteString;
import java.io.IOException;
import java.io.InputStream;
import java.io.SequenceInputStream;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;

class ArrowColumnBatchPartitionColumnBatchReader implements InputPartitionReader<ColumnarBatch> {
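// Upper bound, in bytes, for the Arrow child allocator backing this reader (500 MiB).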
private static final long maxAllocation = 500 * 1024 * 1024;

private final ReadRowsHelper readRowsHelper;
private final ArrowStreamReader reader;
private final BufferAllocator allocator;
private final List<String> namesInOrder;
private ColumnarBatch currentBatch;
private boolean closed = false;

static class ReadRowsResponseInputStreamEnumeration
implements java.util.Enumeration<InputStream> {
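// Presents each response's serialized Arrow record batch as an InputStream so the
// batches can be chained together with a SequenceInputStream.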
private Iterator<ReadRowsResponse> responses;
private ReadRowsResponse currentResponse;

ReadRowsResponseInputStreamEnumeration(Iterator<ReadRowsResponse> responses) {
this.responses = responses;
loadNextResponse();
}

public boolean hasMoreElements() {
return currentResponse != null;
}

public InputStream nextElement() {
if (!hasMoreElements()) {
throw new NoSuchElementException("No more responses");
}
ReadRowsResponse ret = currentResponse;
loadNextResponse();
return ret.getArrowRecordBatch().getSerializedRecordBatch().newInput();
}

void loadNextResponse() {
if (responses.hasNext()) {
currentResponse = responses.next();
} else {
currentResponse = null;
}
}
}

ArrowColumnBatchPartitionColumnBatchReader(
Iterator<ReadRowsResponse> readRowsResponses,
ByteString schema,
ReadRowsHelper readRowsHelper,
List<String> namesInOrder) {
this.allocator =
(new RootAllocator(maxAllocation))
.newChildAllocator("ArrowBinaryIterator", 0, maxAllocation);
this.readRowsHelper = readRowsHelper;
this.namesInOrder = namesInOrder;

InputStream batchStream =
new SequenceInputStream(new ReadRowsResponseInputStreamEnumeration(readRowsResponses));
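// Prepend the serialized schema so the concatenated bytes form a complete Arrow IPC
// stream (schema followed by record batches) that ArrowStreamReader can consume.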
InputStream fullStream = new SequenceInputStream(schema.newInput(), batchStream);

reader = new ArrowStreamReader(fullStream, allocator);
}

@Override
public boolean next() throws IOException {
if (closed) {
return false;
}
closed = !reader.loadNextBatch();
if (closed) {
return false;
}
VectorSchemaRoot root = reader.getVectorSchemaRoot();
if (currentBatch == null) {
// Still verifying with dev@spark, but this object should only need to be
// created once; the underlying vectors should stay the same.
ColumnVector[] columns =
namesInOrder.stream()
.map(root::getVector)
.map(ArrowSchemaConverter::new)
.toArray(ColumnVector[]::new);

currentBatch = new ColumnarBatch(columns);
}
currentBatch.setNumRows(root.getRowCount());
return true;
}

@Override
public ColumnarBatch get() {
return currentBatch;
}

@Override
public void close() throws IOException {
closed = true;
try {
readRowsHelper.close();
} catch (Exception e) {
throw new IOException("Failure closing stream: " + readRowsHelper, e);
} finally {
try {
AutoCloseables.close(reader, allocator);
} catch (Exception e) {
throw new IOException("Failure closing arrow components. stream: " + readRowsHelper, e);
}
}
}
}
ArrowInputPartition.java
@@ -0,0 +1,62 @@
/*
* Copyright 2018 Google Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.cloud.spark.bigquery.v2;

import com.google.cloud.bigquery.connector.common.BigQueryReadClientFactory;
import com.google.cloud.bigquery.connector.common.ReadRowsHelper;
import com.google.cloud.bigquery.connector.common.ReadSessionResponse;
import com.google.cloud.bigquery.storage.v1.ReadRowsRequest;
import com.google.cloud.bigquery.storage.v1.ReadRowsResponse;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import java.util.Iterator;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
import org.apache.spark.sql.vectorized.ColumnarBatch;

public class ArrowInputPartition implements InputPartition<ColumnarBatch> {

private final BigQueryReadClientFactory bigQueryReadClientFactory;
private final String streamName;
private final int maxReadRowsRetries;
private final ImmutableList<String> selectedFields;
private final ByteString serializedArrowSchema;

public ArrowInputPartition(
BigQueryReadClientFactory bigQueryReadClientFactory,
String name,
int maxReadRowsRetries,
ImmutableList<String> selectedFields,
ReadSessionResponse readSessionResponse) {
this.bigQueryReadClientFactory = bigQueryReadClientFactory;
this.streamName = name;
this.maxReadRowsRetries = maxReadRowsRetries;
this.selectedFields = selectedFields;
this.serializedArrowSchema =
readSessionResponse.getReadSession().getArrowSchema().getSerializedSchema();
}

@Override
public InputPartitionReader<ColumnarBatch> createPartitionReader() {
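// Invoked on the executor: opens the read stream and wraps the responses in a columnar reader.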
ReadRowsRequest.Builder readRowsRequest =
ReadRowsRequest.newBuilder().setReadStream(streamName);
ReadRowsHelper readRowsHelper =
new ReadRowsHelper(bigQueryReadClientFactory, readRowsRequest, maxReadRowsRetries);
Iterator<ReadRowsResponse> readRowsResponses = readRowsHelper.readRows();
return new ArrowColumnBatchPartitionColumnBatchReader(
readRowsResponses, serializedArrowSchema, readRowsHelper, selectedFields);
}
}
BigQueryDataSourceReader.java
@@ -32,12 +32,14 @@
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.spark.sql.vectorized.ColumnarBatch;

public class BigQueryDataSourceReader
implements DataSourceReader,
SupportsPushDownRequiredColumns,
SupportsPushDownFilters,
SupportsReportStatistics {
SupportsReportStatistics,
SupportsScanColumnarBatch {

private static Statistics UNKNOWN_STATISTICS =
new Statistics() {
@@ -87,9 +89,14 @@ public StructType readSchema() {
return schema.orElse(SchemaConverters.toSpark(table.getDefinition().getSchema()));
}

@Override
public boolean enableBatchRead() {
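// Columnar reads are only supported for the Arrow format and only when the projection is non-empty.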
return readSessionCreatorConfig.getReadDataFormat() == DataFormat.ARROW && !isEmptySchema();
}

@Override
public List<InputPartition<InternalRow>> planInputPartitions() {
if (schema.map(StructType::isEmpty).orElse(false)) {
if (isEmptySchema()) {
// create empty projection
return createEmptyProjectionPartitions();
}
@@ -117,10 +124,44 @@ public List<InputPartition<InternalRow>> planInputPartitions() {
.collect(Collectors.toList());
}

@Override
public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
if (!enableBatchRead()) {
throw new IllegalStateException("Batch reads should not be enabled");
}
ImmutableList<String> selectedFields =
schema
.map(requiredSchema -> ImmutableList.copyOf(requiredSchema.fieldNames()))
.orElse(ImmutableList.of());
Optional<String> filter =
emptyIfNeeded(
SparkFilterUtils.getCompiledFilter(
readSessionCreatorConfig.getReadDataFormat(), globalFilter, pushedFilters));
ReadSessionResponse readSessionResponse =
readSessionCreator.create(
tableId, selectedFields, filter, readSessionCreatorConfig.getMaxParallelism());
ReadSession readSession = readSessionResponse.getReadSession();
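// Create one columnar input partition per stream in the read session.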
return readSession.getStreamsList().stream()
.map(
stream ->
new ArrowInputPartition(
bigQueryReadClientFactory,
stream.getName(),
readSessionCreatorConfig.getMaxReadRowsRetries(),
selectedFields,
readSessionResponse))
.collect(Collectors.toList());
}

private boolean isEmptySchema() {
return schema.map(StructType::isEmpty).orElse(false);
}

private ReadRowsResponseToInternalRowIteratorConverter createConverter(
ImmutableList<String> selectedFields, ReadSessionResponse readSessionResponse) {
ReadRowsResponseToInternalRowIteratorConverter converter;
if (readSessionCreatorConfig.getReadDataFormat() == DataFormat.AVRO) {
DataFormat format = readSessionCreatorConfig.getReadDataFormat();
if (format == DataFormat.AVRO) {
Schema schema = readSessionResponse.getReadTableInfo().getDefinition().getSchema();
if (selectedFields.isEmpty()) {
// means select *
@@ -138,11 +179,9 @@ private ReadRowsResponseToInternalRowIteratorConverter createConverter(
}
return ReadRowsResponseToInternalRowIteratorConverter.avro(
schema, selectedFields, readSessionResponse.getReadSession().getAvroSchema().getSchema());
} else {
return ReadRowsResponseToInternalRowIteratorConverter.arrow(
selectedFields,
readSessionResponse.getReadSession().getArrowSchema().getSerializedSchema());
}
throw new IllegalArgumentException(
"No known converted for " + readSessionCreatorConfig.getReadDataFormat());
}

List<InputPartition<InternalRow>> createEmptyProjectionPartitions() {
