From 5ff9a751946ddc497b16283dde9746b7e5237da2 Mon Sep 17 00:00:00 2001 From: David Rabinowitz Date: Tue, 23 Jun 2020 13:52:52 -0700 Subject: [PATCH 1/9] Adding acceptance test on Dataproc (#193) In order to run the test: `sbt package acceptance:test` --- build.sbt | 82 +++++-- .../common/BigQueryCredentialsSupplier.java | 2 +- .../v2/BigQueryInputPartitionReaderTest.java | 5 +- .../acceptance/AcceptanceTestContext.java | 39 ++++ .../acceptance/AcceptanceTestUtils.java | 136 ++++++++++++ .../DataprocAcceptanceTestBase.java | 204 ++++++++++++++++++ .../DataprocImage13AcceptanceTest.java | 38 ++++ .../DataprocImage14AcceptanceTest.java | 38 ++++ .../DataprocImage15AcceptanceTest.java | 38 ++++ .../resources/acceptance/read_shakespeare.py | 37 ++++ 10 files changed, 600 insertions(+), 19 deletions(-) create mode 100644 published/src/test/java/com/google/spark/bigquery/acceptance/AcceptanceTestContext.java create mode 100644 published/src/test/java/com/google/spark/bigquery/acceptance/AcceptanceTestUtils.java create mode 100644 published/src/test/java/com/google/spark/bigquery/acceptance/DataprocAcceptanceTestBase.java create mode 100644 published/src/test/java/com/google/spark/bigquery/acceptance/DataprocImage13AcceptanceTest.java create mode 100644 published/src/test/java/com/google/spark/bigquery/acceptance/DataprocImage14AcceptanceTest.java create mode 100644 published/src/test/java/com/google/spark/bigquery/acceptance/DataprocImage15AcceptanceTest.java create mode 100644 published/src/test/resources/acceptance/read_shakespeare.py diff --git a/build.sbt b/build.sbt index 0c5d2f3e82..5817ac6fe2 100644 --- a/build.sbt +++ b/build.sbt @@ -1,3 +1,18 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ lazy val scala211Version = "2.11.12" lazy val scala212Version = "2.12.10" lazy val sparkVersion = "2.4.0" @@ -9,14 +24,29 @@ lazy val commonSettings = Seq( crossScalaVersions := Seq(scala211Version, scala212Version) ) +// scalastyle:off // For https://github.com/GoogleCloudPlatform/spark-bigquery-connector/issues/72 // Based on // https://github.com/sbt/sbt-assembly/#q-despite-the-concerned-friends-i-still-want-publish-fat-jars-what-advice-do-you-have +// scalastyle:on lazy val root = (project in file(".")) .disablePlugins(AssemblyPlugin) .settings(commonSettings, skip in publish := true) .aggregate(connector, fatJar, published) +lazy val commonTestDependencies = Seq( + "io.grpc" % "grpc-alts" % "1.30.0", + "io.grpc" % "grpc-netty-shaded" % "1.30.0", + "com.google.api" % "gax-grpc" % "1.57.0", + "com.google.guava" % "guava" % "29.0-jre", + + "org.scalatest" %% "scalatest" % "3.1.0" % "test", + "org.mockito" %% "mockito-scala-scalatest" % "1.10.0" % "test", + "junit" % "junit" % "4.13" % "test", + "com.novocode" % "junit-interface" % "0.11" % "test", + "com.google.truth" % "truth" % "1.0.1" % "test" +) + lazy val connector = (project in file("connector")) .enablePlugins(BuildInfoPlugin) .configs(ITest) @@ -36,10 +66,11 @@ lazy val connector = (project in file("connector")) IO.write(file, s"scala.version=${scalaVersion.value}\n") Seq(file) }.taskValue, - libraryDependencies ++= Seq( + libraryDependencies ++= (commonTestDependencies ++ Seq( "org.apache.spark" %% "spark-core" % sparkVersion % "provided", "org.apache.spark" %% "spark-sql" % sparkVersion % "provided", - "org.slf4j" % "slf4j-api" % "1.7.25" % "provided", + "org.slf4j" % "slf4j-api" % "1.7.16" % "provided", + "aopalliance" % "aopalliance" % "1.0" % "provided", "org.codehaus.jackson" % "jackson-core-asl" % "1.9.13" % "provided", "org.codehaus.jackson" % "jackson-mapper-asl" % "1.9.13" % "provided", "org.apache.arrow" % "arrow-vector" % "0.16.0", @@ -49,10 +80,6 @@ lazy val connector = (project in file("connector")) "com.google.cloud" % "google-cloud-bigquery" % "1.116.1", "com.google.cloud" % "google-cloud-bigquerystorage" % "0.133.2-beta", // Keep in sync with com.google.cloud - "io.grpc" % "grpc-alts" % "1.29.0", - "io.grpc" % "grpc-netty-shaded" % "1.29.0", - "com.google.api" % "gax-grpc" % "1.56.0", - "com.google.guava" % "guava" % "29.0-jre", "com.fasterxml.jackson.core" % "jackson-databind" % "2.11.0", "com.fasterxml.jackson.module" % "jackson-module-paranamer" % "2.11.0", "com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.11.0", @@ -60,14 +87,13 @@ lazy val connector = (project in file("connector")) // runtime // scalastyle:off - "com.google.cloud.bigdataoss" % "gcs-connector" % "hadoop2-2.0.0" % "runtime" classifier("shaded"), + "com.google.cloud.bigdataoss" % "gcs-connector" % "hadoop2-2.0.0" % "runtime" classifier("shaded") + exclude("com.google.cloud.bigdataoss", "util-hadoop"), // scalastyle:on // test - "org.scalatest" %% "scalatest" % "3.1.0" % "test", - "org.mockito" %% "mockito-scala-scalatest" % "1.10.0" % "test", - "org.apache.spark" %% "spark-avro" % sparkVersion % "test", - "com.google.truth" % "truth" % "1.0.1" % "test") + "org.apache.spark" %% "spark-avro" % sparkVersion % "test" + )) .map(_.excludeAll(excludedOrgs.map(ExclusionRule(_)): _*)) ) @@ -88,8 +114,9 @@ lazy val fatJar = project case PathList(ps@_*) if ps.last.endsWith(".proto") => MergeStrategy.discard // Relocate netty-tcnative.so. This is necessary even though gRPC shades it, because we shade // gRPC. 
- case PathList("META-INF", "native", f) if f.contains("netty_tcnative") => RelocationMergeStrategy( - path => path.replace("native/lib", s"native/lib${relocationPrefix.replace('.', '_')}_")) + case PathList("META-INF", "native", f) if f.contains("netty_tcnative") => + RelocationMergeStrategy(path => + path.replace("native/lib", s"native/lib${relocationPrefix.replace('.', '_')}_")) // Relocate GRPC service registries case PathList("META-INF", "services", _) => ServiceResourceMergeStrategy(renamed, @@ -99,13 +126,29 @@ lazy val fatJar = project ) .dependsOn(connector) - +val publishVerified = taskKey[Seq[String]]("Published signed artifact after acceptance test") lazy val published = project + .configs(AcceptanceTest) .settings( commonSettings, publishSettings, name := "spark-bigquery-with-dependencies", - packageBin in Compile := (assembly in(fatJar, Compile)).value + packageBin in Compile := (assembly in(fatJar, Compile)).value, + test in AcceptanceTest := (test in AcceptanceTest).dependsOn(packageBin in Compile).value, + // publishSigned in SbtPgp := publishSigned.dependsOn(test in AcceptanceTest).value, + inConfig(AcceptanceTest)(Defaults.testTasks), + testOptions in Test := Seq(Tests.Filter(unitFilter)), + testOptions in AcceptanceTest := Seq(Tests.Filter(acceptanceFilter)), + publishVerified := { Seq( + (packageBin in Compile).value.toString, + (test in AcceptanceTest).value.toString + ) }, + libraryDependencies ++= (commonTestDependencies ++ Seq( + "com.google.cloud" % "google-cloud-dataproc" % "1.0.0" % "test", + "com.google.cloud" % "google-cloud-storage" % "1.109.1" % "test" + )) + .map(_.excludeAll(excludedOrgs.map(ExclusionRule(_)): _*)) + ) lazy val myPackage = "com.google.cloud.spark.bigquery" @@ -156,10 +199,17 @@ lazy val ITest = config("it") extend Test (test in Test) := ((test in Test) dependsOn scalastyle.in(Test).toTask("")).value parallelExecution in ITest := false -def unitFilter(name: String): Boolean = (name endsWith "Suite") && !itFilter(name) +// Default IntegrationTest config uses separate test directory, build files +lazy val AcceptanceTest = config("acceptance") extend Test +parallelExecution in AcceptanceTest := false + +def unitFilter(name: String): Boolean = + (name.endsWith("Suite") || name.endsWith("Test")) && !itFilter(name) && !acceptanceFilter(name) def itFilter(name: String): Boolean = name endsWith "ITSuite" +def acceptanceFilter(name: String): Boolean = name endsWith "AcceptanceTest" + lazy val publishSettings = Seq( homepage := Some(url("https://github.com/GoogleCloudPlatform/spark-bigquery-connector")), scmInfo := Some(ScmInfo(url("https://github.com/GoogleCloudPlatform/spark-bigquery-connector"), diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryCredentialsSupplier.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryCredentialsSupplier.java index 131d7344f4..e035b3343e 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryCredentialsSupplier.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryCredentialsSupplier.java @@ -69,7 +69,7 @@ private static Credentials createCredentialsFromFile(String file) { } } - private static Credentials createDefaultCredentials() { + public static Credentials createDefaultCredentials() { try { return GoogleCredentials.getApplicationDefault(); } catch (IOException e) { diff --git a/connector/src/test/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartitionReaderTest.java 
b/connector/src/test/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartitionReaderTest.java index a3340803f1..c42c34c4c4 100644 --- a/connector/src/test/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartitionReaderTest.java +++ b/connector/src/test/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartitionReaderTest.java @@ -18,9 +18,11 @@ import com.google.cloud.bigquery.*; import com.google.cloud.bigquery.storage.v1.ReadRowsResponse; +import com.google.cloud.bigquery.storage.v1.StreamStats; import com.google.cloud.spark.bigquery.ReadRowsResponseToInternalRowIteratorConverter; import com.google.common.collect.ImmutableList; import com.google.protobuf.TextFormat; +import org.apache.log4j.Logger; import org.apache.spark.sql.catalyst.InternalRow; import org.junit.Test; @@ -55,8 +57,7 @@ public class BigQueryInputPartitionReaderTest { ); private static final String ALL_TYPES_TABLE_READ_ROWS_RESPONSE_STR = - "status {\n" + - " fraction_consumed: 0.5\n" + + "stats {\n" + " progress {\n" + " at_response_end: 0.5\n" + " }\n" + diff --git a/published/src/test/java/com/google/spark/bigquery/acceptance/AcceptanceTestContext.java b/published/src/test/java/com/google/spark/bigquery/acceptance/AcceptanceTestContext.java new file mode 100644 index 0000000000..c2ce6db35b --- /dev/null +++ b/published/src/test/java/com/google/spark/bigquery/acceptance/AcceptanceTestContext.java @@ -0,0 +1,39 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.spark.bigquery.acceptance; + +public class AcceptanceTestContext { + + final String testId; + final String clusterId; + final String connectorJarUri; + final String testBaseGcsDir; + + public AcceptanceTestContext(String testId, String clusterId) { + this.testId = testId; + this.clusterId = clusterId; + this.testBaseGcsDir = AcceptanceTestUtils.createTestBaseGcsDir(testId); + this.connectorJarUri = this.testBaseGcsDir + "/connector.jar"; + } + + public String getScriptUri(String testName) { + return testBaseGcsDir + "/" + testName + "/script.py"; + } + + public String getResultsDirUri(String testName) { + return testBaseGcsDir + "/" + testName + "/results"; + } +} diff --git a/published/src/test/java/com/google/spark/bigquery/acceptance/AcceptanceTestUtils.java b/published/src/test/java/com/google/spark/bigquery/acceptance/AcceptanceTestUtils.java new file mode 100644 index 0000000000..87ad483469 --- /dev/null +++ b/published/src/test/java/com/google/spark/bigquery/acceptance/AcceptanceTestUtils.java @@ -0,0 +1,136 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.spark.bigquery.acceptance; + +import com.google.cloud.WriteChannel; +import com.google.cloud.storage.*; +import com.google.common.io.ByteStreams; + +import java.io.*; +import java.net.URI; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.attribute.FileTime; +import java.util.Comparator; +import java.util.stream.StreamSupport; + +public class AcceptanceTestUtils { + + // must be set in order to run the acceptance test + private static final String BUCKET = System.getenv("ACCEPTANCE_TEST_BUCKET"); + + static Storage storage = + new StorageOptions.DefaultStorageFactory().create(StorageOptions.getDefaultInstance()); + + public static Path getAssemblyJar(Path targetDir) { + try { + return Files.list(targetDir) + .filter(Files::isRegularFile) + .filter(AcceptanceTestUtils::isAssemblyJar) + .max(Comparator.comparing(AcceptanceTestUtils::lastModifiedTime)) + .get(); + } catch (IOException e) { + throw new UncheckedIOException(e.getMessage(), e); + } + } + + private static boolean isAssemblyJar(Path path) { + String name = path.toFile().getName(); + return name.endsWith(".jar") && name.contains("-assembly-"); + } + + private static FileTime lastModifiedTime(Path path) { + try { + return Files.getLastModifiedTime(path); + } catch (IOException e) { + throw new UncheckedIOException(e.getMessage(), e); + } + } + + public static BlobId copyToGcs(Path source, String destinationUri, String contentType) + throws Exception { + File sourceFile = source.toFile(); + try (FileInputStream sourceInputStream = new FileInputStream(sourceFile)) { + FileChannel sourceFileChannel = sourceInputStream.getChannel(); + MappedByteBuffer sourceContent = + sourceFileChannel.map(FileChannel.MapMode.READ_ONLY, 0, sourceFile.length()); + return uploadToGcs(sourceContent, destinationUri, contentType); + } catch (IOException e) { + throw new UncheckedIOException( + String.format("Failed to write '%s' to '%s'", source, destinationUri), e); + } + } + + public static BlobId uploadToGcs(InputStream source, String destinationUri, String contentType) + throws Exception { + try { + ByteBuffer sourceContent = ByteBuffer.wrap(ByteStreams.toByteArray(source)); + return uploadToGcs(sourceContent, destinationUri, contentType); + } catch (IOException e) { + throw new UncheckedIOException(String.format("Failed to write to '%s'", destinationUri), e); + } + } + + public static BlobId uploadToGcs(ByteBuffer content, String destinationUri, String contentType) + throws Exception { + URI uri = new URI(destinationUri); + BlobId blobId = BlobId.of(uri.getHost(), uri.getPath().substring(1)); + BlobInfo blobInfo = BlobInfo.newBuilder(blobId).setContentType(contentType).build(); + try (WriteChannel writer = storage.writer(blobInfo)) { + writer.write(content); + } catch (IOException e) { + throw new UncheckedIOException(String.format("Failed to write to '%s'", destinationUri), e); + } + return blobId; + } + + public static String 
createTestBaseGcsDir(String testId) { + return String.format("gs://%s/tests/%s", BUCKET, testId); + } + + public static String getCsv(String resultsDirUri) throws Exception { + URI uri = new URI(resultsDirUri); + Blob csvBlob = + StreamSupport.stream( + storage + .list(uri.getHost(), Storage.BlobListOption.prefix(uri.getPath().substring(1))) + .iterateAll() + .spliterator(), + false) + .filter(blob -> blob.getName().endsWith("csv")) + .findFirst() + .get(); + return new String(storage.readAllBytes(csvBlob.getBlobId()), StandardCharsets.UTF_8); + } + + protected static void deleteGcsDir(String testBaseGcsDir) throws Exception { + URI uri = new URI(testBaseGcsDir); + BlobId[] blobIds = + StreamSupport.stream( + storage + .list(uri.getHost(), Storage.BlobListOption.prefix(uri.getPath().substring(1))) + .iterateAll() + .spliterator(), + false) + .map(Blob::getBlobId) + .toArray(BlobId[]::new); + storage.delete(blobIds); + } +} diff --git a/published/src/test/java/com/google/spark/bigquery/acceptance/DataprocAcceptanceTestBase.java b/published/src/test/java/com/google/spark/bigquery/acceptance/DataprocAcceptanceTestBase.java new file mode 100644 index 0000000000..9c37ebc77d --- /dev/null +++ b/published/src/test/java/com/google/spark/bigquery/acceptance/DataprocAcceptanceTestBase.java @@ -0,0 +1,204 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.spark.bigquery.acceptance; + +import com.google.cloud.dataproc.v1.*; +import org.junit.AssumptionViolatedException; +import org.junit.Test; +import scala.util.Properties; + +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import static com.google.common.truth.Truth.assertThat; + +public class DataprocAcceptanceTestBase { + + public static final String US_CENTRAL_1_DATAPROC_ENDPOINT = + "us-central1-dataproc.googleapis.com:443"; + private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT"); + private static final String REGION = "us-central1"; + private AcceptanceTestContext context; + + protected DataprocAcceptanceTestBase(AcceptanceTestContext context) { + this.context = context; + } + + protected static AcceptanceTestContext setup(String scalaVersion, String dataprocImageVersion) + throws Exception { + // this line will abort the test for the wrong scala version + String runtimeScalaVersion = Properties.versionNumberString(); + if (!runtimeScalaVersion.startsWith(scalaVersion)) { + throw new AssumptionViolatedException( + String.format( + "Test is for scala %s, Runtime is scala %s", scalaVersion, runtimeScalaVersion)); + } + String testId = + String.format( + "%s-%s%s", + System.currentTimeMillis(), + dataprocImageVersion.charAt(0), + dataprocImageVersion.charAt(2)); + String clusterName = createClusterIfNeeded(dataprocImageVersion, testId); + AcceptanceTestContext acceptanceTestContext = new AcceptanceTestContext(testId, clusterName); + uploadConnectorJar(scalaVersion, acceptanceTestContext.connectorJarUri); + return acceptanceTestContext; + } + + protected static void tearDown(AcceptanceTestContext context) throws Exception { + if (context != null) { + terminateCluster(context.clusterId); + AcceptanceTestUtils.deleteGcsDir(context.testBaseGcsDir); + } + } + + protected static String createClusterIfNeeded(String dataprocImageVersion, String testId) + throws Exception { + String clusterName = generateClusterName(dataprocImageVersion, testId); + cluster( + client -> + client + .createClusterAsync( + PROJECT_ID, REGION, createCluster(clusterName, dataprocImageVersion)) + .get()); + return clusterName; + } + + protected static void terminateCluster(String clusterName) throws Exception { + cluster(client -> client.deleteClusterAsync(PROJECT_ID, REGION, clusterName).get()); + } + + private static void cluster(ThrowingConsumer command) throws Exception { + try (ClusterControllerClient clusterControllerClient = + ClusterControllerClient.create( + ClusterControllerSettings.newBuilder() + .setEndpoint("us-central1-dataproc.googleapis.com:443") + .build())) { + command.accept(clusterControllerClient); + } + } + + private static String generateClusterName(String dataprocImageVersion, String testId) { + return String.format("spark-bigquery-it-%s", testId); + } + + private static Cluster createCluster(String clusterName, String dataprocImageVersion) { + return Cluster.newBuilder() + .setClusterName(clusterName) + .setProjectId(PROJECT_ID) + .setConfig( + ClusterConfig.newBuilder() + .setGceClusterConfig( + GceClusterConfig.newBuilder() + .setNetworkUri("default") + .setZoneUri(REGION + "-a")) + .setMasterConfig( + InstanceGroupConfig.newBuilder() + .setNumInstances(1) + .setMachineTypeUri("n1-standard-4") + .setDiskConfig( + DiskConfig.newBuilder() + .setBootDiskType("pd-standard") + .setBootDiskSizeGb(300) + .setNumLocalSsds(0))) + 
.setWorkerConfig( + InstanceGroupConfig.newBuilder() + .setNumInstances(2) + .setMachineTypeUri("n1-standard-4") + .setDiskConfig( + DiskConfig.newBuilder() + .setBootDiskType("pd-standard") + .setBootDiskSizeGb(300) + .setNumLocalSsds(0))) + .setSoftwareConfig( + SoftwareConfig.newBuilder().setImageVersion(dataprocImageVersion))) + .build(); + } + + private static void uploadConnectorJar(String scalaVersion, String connectorJarUri) + throws Exception { + Path targetDir = Paths.get(String.format("fatJar/target/scala-%s/", scalaVersion)); + Path assemblyJar = AcceptanceTestUtils.getAssemblyJar(targetDir); + AcceptanceTestUtils.copyToGcs(assemblyJar, connectorJarUri, "application/java-archive"); + } + + @Test + public void testRead() throws Exception { + String testName = "test-read"; + AcceptanceTestUtils.uploadToGcs( + getClass().getResourceAsStream("/acceptance/read_shakespeare.py"), + context.getScriptUri(testName), + "text/x-python"); + Job job = + Job.newBuilder() + .setPlacement(JobPlacement.newBuilder().setClusterName(context.clusterId)) + .setPysparkJob( + PySparkJob.newBuilder() + .setMainPythonFileUri(context.getScriptUri(testName)) + .addJarFileUris(context.connectorJarUri) + .addArgs(context.getResultsDirUri(testName))) + .build(); + Job result = runAndWait(job, Duration.ofSeconds(60)); + assertThat(result.getStatus().getState()).isEqualTo(JobStatus.State.DONE); + String output = AcceptanceTestUtils.getCsv(context.getResultsDirUri(testName)); + assertThat(output.trim()).isEqualTo("spark,10"); + } + + private Job runAndWait(Job job, Duration timeout) throws Exception { + try (JobControllerClient jobControllerClient = + JobControllerClient.create( + JobControllerSettings.newBuilder() + .setEndpoint(US_CENTRAL_1_DATAPROC_ENDPOINT) + .build())) { + Job request = jobControllerClient.submitJob(PROJECT_ID, REGION, job); + String jobId = request.getReference().getJobId(); + CompletableFuture finishedJobFuture = + CompletableFuture.supplyAsync( + () -> waitForJobCompletion(jobControllerClient, PROJECT_ID, REGION, jobId)); + Job jobInfo = finishedJobFuture.get(timeout.getSeconds(), TimeUnit.SECONDS); + return jobInfo; + } + } + + Job waitForJobCompletion( + JobControllerClient jobControllerClient, String projectId, String region, String jobId) { + while (true) { + // Poll the service periodically until the Job is in a finished state. + Job jobInfo = jobControllerClient.getJob(projectId, region, jobId); + switch (jobInfo.getStatus().getState()) { + case DONE: + case CANCELLED: + case ERROR: + return jobInfo; + default: + try { + // Wait a second in between polling attempts. + TimeUnit.SECONDS.sleep(1); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + } + + @FunctionalInterface + private interface ThrowingConsumer { + void accept(T t) throws Exception; + } +} diff --git a/published/src/test/java/com/google/spark/bigquery/acceptance/DataprocImage13AcceptanceTest.java b/published/src/test/java/com/google/spark/bigquery/acceptance/DataprocImage13AcceptanceTest.java new file mode 100644 index 0000000000..2f69785fdd --- /dev/null +++ b/published/src/test/java/com/google/spark/bigquery/acceptance/DataprocImage13AcceptanceTest.java @@ -0,0 +1,38 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.spark.bigquery.acceptance; + +import org.junit.AfterClass; +import org.junit.BeforeClass; + +public class DataprocImage13AcceptanceTest extends DataprocAcceptanceTestBase { + + private static AcceptanceTestContext context; + + public DataprocImage13AcceptanceTest() { + super(context); + } + + @BeforeClass + public static void setup() throws Exception { + context = DataprocAcceptanceTestBase.setup("2.11", "1.3-debian9"); + } + + @AfterClass + public static void tearDown() throws Exception { + DataprocAcceptanceTestBase.tearDown(context); + } +} diff --git a/published/src/test/java/com/google/spark/bigquery/acceptance/DataprocImage14AcceptanceTest.java b/published/src/test/java/com/google/spark/bigquery/acceptance/DataprocImage14AcceptanceTest.java new file mode 100644 index 0000000000..95d9fd639e --- /dev/null +++ b/published/src/test/java/com/google/spark/bigquery/acceptance/DataprocImage14AcceptanceTest.java @@ -0,0 +1,38 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.spark.bigquery.acceptance; + +import org.junit.AfterClass; +import org.junit.BeforeClass; + +public class DataprocImage14AcceptanceTest extends DataprocAcceptanceTestBase { + + private static AcceptanceTestContext context; + + public DataprocImage14AcceptanceTest() { + super(context); + } + + @BeforeClass + public static void setup() throws Exception { + context = DataprocAcceptanceTestBase.setup("2.11", "1.4-debian9"); + } + + @AfterClass + public static void tearDown() throws Exception { + DataprocAcceptanceTestBase.tearDown(context); + } +} diff --git a/published/src/test/java/com/google/spark/bigquery/acceptance/DataprocImage15AcceptanceTest.java b/published/src/test/java/com/google/spark/bigquery/acceptance/DataprocImage15AcceptanceTest.java new file mode 100644 index 0000000000..3852083558 --- /dev/null +++ b/published/src/test/java/com/google/spark/bigquery/acceptance/DataprocImage15AcceptanceTest.java @@ -0,0 +1,38 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.spark.bigquery.acceptance; + +import org.junit.AfterClass; +import org.junit.BeforeClass; + +public class DataprocImage15AcceptanceTest extends DataprocAcceptanceTestBase { + + private static AcceptanceTestContext context; + + public DataprocImage15AcceptanceTest() { + super(context); + } + + @BeforeClass + public static void setup() throws Exception { + context = DataprocAcceptanceTestBase.setup("2.12", "1.5-debian10"); + } + + @AfterClass + public static void tearDown() throws Exception { + DataprocAcceptanceTestBase.tearDown(context); + } +} diff --git a/published/src/test/resources/acceptance/read_shakespeare.py b/published/src/test/resources/acceptance/read_shakespeare.py new file mode 100644 index 0000000000..7d2372726c --- /dev/null +++ b/published/src/test/resources/acceptance/read_shakespeare.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python +# Copyright 2018 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from pyspark.sql import SparkSession + +spark = SparkSession.builder.appName('Shakespeare on Spark').getOrCreate() + +table = 'bigquery-public-data.samples.shakespeare' +df = spark.read.format('bigquery').load(table) +# Only these columns will be read +df = df.select('word', 'word_count') +# The filters that are allowed will be automatically pushed down. +# Those that are not will be computed client side +df = df.where("word_count > 0 AND word='spark'") +# Further processing is done inside Spark +df = df.groupBy('word').sum('word_count') + +print('The resulting schema is') +df.printSchema() + +print('Spark mentions in Shakespeare') +df.show() + +df.coalesce(1).write.csv(sys.argv[1]) \ No newline at end of file From beaed86db475550d39575c48ecf3a4d8c439fc07 Mon Sep 17 00:00:00 2001 From: Gaurangi94 Date: Tue, 23 Jun 2020 15:51:05 -0700 Subject: [PATCH 2/9] Added support for materialized views (#192) --- cloudbuild/cloudbuild.yaml | 2 ++ .../bigquery/connector/common/BigQueryClient.java | 4 ++-- .../bigquery/connector/common/ReadSessionCreator.java | 2 +- .../spark/bigquery/BigQueryRelationProvider.scala | 4 ++-- .../bigquery/direct/DirectBigQueryRelation.scala | 8 ++++++-- .../bigquery/it/SparkBigQueryEndToEndITSuite.scala | 11 ++++++++++- 6 files changed, 23 insertions(+), 8 deletions(-) diff --git a/cloudbuild/cloudbuild.yaml b/cloudbuild/cloudbuild.yaml index c36af8d7fa..bb15d2eaac 100644 --- a/cloudbuild/cloudbuild.yaml +++ b/cloudbuild/cloudbuild.yaml @@ -15,6 +15,8 @@ steps: id: 'integration-tests' entrypoint: 'sbt' args: ['it:test'] + env: + - 'GOOGLE_CLOUD_PROJECT=${_GOOGLE_CLOUD_PROJECT}' # Tests take around 13 mins in general. 
timeout: 1200s diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClient.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClient.java index 8236f5c265..9de36bc4a6 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClient.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClient.java @@ -73,7 +73,7 @@ public TableInfo getSupportedTable(TableId tableId, boolean viewsEnabled, String if (TableDefinition.Type.TABLE == tableType) { return table; } - if (TableDefinition.Type.VIEW == tableType) { + if (TableDefinition.Type.VIEW == tableType || TableDefinition.Type.MATERIALIZED_VIEW == tableType) { if (viewsEnabled) { return table; } else { @@ -167,7 +167,7 @@ public long calculateTableSize(TableInfo tableInfo, Optional filter) { TableDefinition.Type type = tableInfo.getDefinition().getType(); if (type == TableDefinition.Type.TABLE && !filter.isPresent()) { return tableInfo.getNumRows().longValue(); - } else if (type == TableDefinition.Type.VIEW || + } else if (type == TableDefinition.Type.VIEW || type == TableDefinition.Type.MATERIALIZED_VIEW || (type == TableDefinition.Type.TABLE && filter.isPresent())) { // run a query String table = fullTableName(tableInfo.getTableId()); diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionCreator.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionCreator.java index 9c03bc6d65..9b53516b5f 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionCreator.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionCreator.java @@ -126,7 +126,7 @@ TableInfo getActualTable( if (TableDefinition.Type.TABLE == tableType) { return table; } - if (TableDefinition.Type.VIEW == tableType) { + if (TableDefinition.Type.VIEW == tableType || TableDefinition.Type.MATERIALIZED_VIEW == tableType) { if (!config.viewsEnabled) { throw new BigQueryConnectorException(UNSUPPORTED, format( "Views are not enabled. You can enable views by setting '%s' to true. 
Notice additional cost may occur.", diff --git a/connector/src/main/scala/com/google/cloud/spark/bigquery/BigQueryRelationProvider.scala b/connector/src/main/scala/com/google/cloud/spark/bigquery/BigQueryRelationProvider.scala index c0a10eaffe..49fd0b9581 100644 --- a/connector/src/main/scala/com/google/cloud/spark/bigquery/BigQueryRelationProvider.scala +++ b/connector/src/main/scala/com/google/cloud/spark/bigquery/BigQueryRelationProvider.scala @@ -16,7 +16,7 @@ package com.google.cloud.spark.bigquery import com.google.auth.Credentials -import com.google.cloud.bigquery.TableDefinition.Type.{TABLE, VIEW} +import com.google.cloud.bigquery.TableDefinition.Type.{MATERIALIZED_VIEW, TABLE, VIEW} import com.google.cloud.bigquery.{BigQuery, BigQueryOptions, TableDefinition} import com.google.cloud.spark.bigquery.direct.DirectBigQueryRelation import org.apache.spark.sql.sources._ @@ -57,7 +57,7 @@ class BigQueryRelationProvider( .getOrElse(sys.error(s"Table $tableName not found")) table.getDefinition[TableDefinition].getType match { case TABLE => new DirectBigQueryRelation(opts, table)(sqlContext) - case VIEW => if (opts.viewsEnabled) { + case VIEW | MATERIALIZED_VIEW => if (opts.viewsEnabled) { new DirectBigQueryRelation(opts, table)(sqlContext) } else { sys.error( diff --git a/connector/src/main/scala/com/google/cloud/spark/bigquery/direct/DirectBigQueryRelation.scala b/connector/src/main/scala/com/google/cloud/spark/bigquery/direct/DirectBigQueryRelation.scala index 55ad03dc0f..a901afb042 100644 --- a/connector/src/main/scala/com/google/cloud/spark/bigquery/direct/DirectBigQueryRelation.scala +++ b/connector/src/main/scala/com/google/cloud/spark/bigquery/direct/DirectBigQueryRelation.scala @@ -193,7 +193,9 @@ private[bigquery] class DirectBigQueryRelation( ): TableInfo = { val tableDefinition = table.getDefinition[TableDefinition] val tableType = tableDefinition.getType - if(options.viewsEnabled && TableDefinition.Type.VIEW == tableType) { + if(options.viewsEnabled && + (TableDefinition.Type.VIEW == tableType || + TableDefinition.Type.MATERIALIZED_VIEW == tableType)) { // get it from the view val querySql = createSql(tableDefinition.getSchema, requiredColumns, filtersString) logDebug(s"querySql is $querySql") @@ -274,7 +276,9 @@ private[bigquery] class DirectBigQueryRelation( def getNumBytes(tableDefinition: TableDefinition): Long = { val tableType = tableDefinition.getType - if (options.viewsEnabled && TableDefinition.Type.VIEW == tableType) { + if (options.viewsEnabled && + (TableDefinition.Type.VIEW == tableType || + TableDefinition.Type.MATERIALIZED_VIEW == tableType)) { sqlContext.sparkSession.sessionState.conf.defaultSizeInBytes } else { tableDefinition.asInstanceOf[StandardTableDefinition].getNumBytes diff --git a/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndITSuite.scala b/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndITSuite.scala index 28fced0568..be99d16f01 100644 --- a/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndITSuite.scala +++ b/connector/src/test/scala/com/google/cloud/spark/bigquery/it/SparkBigQueryEndToEndITSuite.scala @@ -346,7 +346,7 @@ class SparkBigQueryEndToEndITSuite extends FunSuite countResults should equal(countAfterCollect) } */ - + test("read data types. 
DataSource %s".format(dataSourceFormat)) { val allTypesTable = readAllTypesTable(dataSourceFormat) val expectedRow = spark.range(1).select(TestConstants.ALL_TYPES_TABLE_COLS: _*).head.toSeq @@ -542,6 +542,15 @@ class SparkBigQueryEndToEndITSuite extends FunSuite assert(df.schema == allTypesTable.schema) } + test("query materialized view") { + var df = spark.read.format("bigquery") + .option("table", "bigquery-public-data:ethereum_blockchain.live_logs") + .option("viewsEnabled", "true") + .option("viewMaterializationProject", System.getenv("GOOGLE_CLOUD_PROJECT")) + .option("viewMaterializationDataset", testDataset) + .load() + } + test("write to bq - adding the settings to spark.conf" ) { spark.conf.set("temporaryGcsBucket", temporaryGcsBucket) val df = initialData From a2097892a5201bdadc62ddad509dd86c67a5784d Mon Sep 17 00:00:00 2001 From: David Rabinowitz Date: Mon, 6 Jul 2020 17:14:35 -0700 Subject: [PATCH 3/9] Applying Google Java format on compile (#203) --- .../connector/common/BigQueryClient.java | 298 ++++---- .../common/BigQueryClientModule.java | 93 +-- .../connector/common/BigQueryConfig.java | 14 +- .../common/BigQueryConnectorException.java | 36 +- .../common/BigQueryCredentialsSupplier.java | 88 +-- .../connector/common/BigQueryErrorCode.java | 27 +- .../common/BigQueryReadClientFactory.java | 45 +- .../connector/common/BigQueryUtil.java | 112 ++- .../connector/common/ReadRowsHelper.java | 137 ++-- .../connector/common/ReadSessionCreator.java | 311 ++++---- .../common/ReadSessionCreatorConfig.java | 112 +-- .../connector/common/ReadSessionResponse.java | 24 +- .../common/UserAgentHeaderProvider.java | 16 +- .../connector/common/UserAgentProvider.java | 3 +- .../connector/common/VersionProvider.java | 3 +- .../spark/bigquery/ArrowBinaryIterator.java | 169 ++--- .../spark/bigquery/AvroBinaryIterator.java | 75 +- ...esponseToInternalRowIteratorConverter.java | 83 ++- .../spark/bigquery/SchemaConverters.java | 298 ++++---- .../spark/bigquery/SparkBigQueryConfig.java | 667 +++++++++--------- ...arkBigQueryConnectorUserAgentProvider.java | 118 ++-- ...SparkBigQueryConnectorVersionProvider.java | 26 +- .../spark/bigquery/SparkFilterUtils.java | 339 ++++----- .../bigquery/examples/JavaShakespeare.java | 50 +- .../bigquery/v2/BigQueryDataSourceReader.java | 303 ++++---- .../bigquery/v2/BigQueryDataSourceV2.java | 46 +- ...BigQueryEmptyProjectionInputPartition.java | 16 +- ...ryEmptyProjectionInputPartitionReader.java | 46 +- .../bigquery/v2/BigQueryInputPartition.java | 44 +- .../v2/BigQueryInputPartitionReader.java | 64 +- .../v2/SparkBigQueryConnectorModule.java | 88 +-- project/plugins.sbt | 4 +- 32 files changed, 1915 insertions(+), 1840 deletions(-) diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClient.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClient.java index 9de36bc4a6..716bcec555 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClient.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClient.java @@ -40,146 +40,160 @@ // presto converts the dataset and table names to lower case, while BigQuery is case sensitive // the mappings here keep the mappings public class BigQueryClient { - private final BigQuery bigQuery; - private final Optional materializationProject; - private final Optional materializationDataset; - - BigQueryClient(BigQuery bigQuery, Optional materializationProject, Optional materializationDataset) { - 
this.bigQuery = bigQuery; - this.materializationProject = materializationProject; - this.materializationDataset = materializationDataset; - } - - // return empty if no filters are used - private static Optional createWhereClause(String[] filters) { - if (filters.length == 0) { - return Optional.empty(); - } - return Optional.of(Stream.of(filters).collect(Collectors.joining(") AND (", "(", ")"))); - } - - public TableInfo getTable(TableId tableId) { - return bigQuery.getTable(tableId); - } - - public TableInfo getSupportedTable(TableId tableId, boolean viewsEnabled, String viewEnabledParamName) { - TableInfo table = getTable(tableId); - if (table == null) { - return null; - } - - TableDefinition tableDefinition = table.getDefinition(); - TableDefinition.Type tableType = tableDefinition.getType(); - if (TableDefinition.Type.TABLE == tableType) { - return table; - } - if (TableDefinition.Type.VIEW == tableType || TableDefinition.Type.MATERIALIZED_VIEW == tableType) { - if (viewsEnabled) { - return table; - } else { - throw new BigQueryConnectorException(UNSUPPORTED, format( - "Views are not enabled. You can enable views by setting '%s' to true. Notice additional cost may occur.", - viewEnabledParamName)); - } - } - // not regular table or a view - throw new BigQueryConnectorException(UNSUPPORTED, format("Table type '%s' of table '%s.%s' is not supported", - tableType, table.getTableId().getDataset(), table.getTableId().getTable())); - } - - DatasetId toDatasetId(TableId tableId) { - return DatasetId.of(tableId.getProject(), tableId.getDataset()); - } - - String getProjectId() { - return bigQuery.getOptions().getProjectId(); - } - - Iterable listDatasets(String projectId) { - return bigQuery.listDatasets(projectId).iterateAll(); - } - - Iterable listTables(DatasetId datasetId, TableDefinition.Type... types) { - Set allowedTypes = ImmutableSet.copyOf(types); - Iterable
allTables = bigQuery.listTables(datasetId).iterateAll(); - return StreamSupport.stream(allTables.spliterator(), false) - .filter(table -> allowedTypes.contains(table.getDefinition().getType())) - .collect(toImmutableList()); - } - - TableId createDestinationTable(TableId tableId) { - String project = materializationProject.orElse(tableId.getProject()); - String dataset = materializationDataset.orElse(tableId.getDataset()); - DatasetId datasetId = DatasetId.of(project, dataset); - String name = format("_bqc_%s", randomUUID().toString().toLowerCase(ENGLISH).replace("-", "")); - return TableId.of(datasetId.getProject(), datasetId.getDataset(), name); - } - - Table update(TableInfo table) { - return bigQuery.update(table); - } - - Job create(JobInfo jobInfo) { - return bigQuery.create(jobInfo); - } - - TableResult query(String sql) { - try { - return bigQuery.query(QueryJobConfiguration.of(sql)); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new BigQueryException(BaseHttpServiceException.UNKNOWN_CODE, format("Failed to run the query [%s]", sql), e); - } - } - - String createSql(TableId table, ImmutableList requiredColumns, String[] filters) { - String columns = requiredColumns.isEmpty() ? "*" : - requiredColumns.stream().map(column -> format("`%s`", column)).collect(joining(",")); - - String whereClause = createWhereClause(filters) - .map(clause -> "WHERE " + clause) - .orElse(""); - - return createSql(table, columns, filters); - } - - // assuming the SELECT part is properly formatted, can be used to call functions such as COUNT and SUM - String createSql(TableId table, String formattedQuery, String[] filters) { - String tableName = fullTableName(table); - - String whereClause = createWhereClause(filters) - .map(clause -> "WHERE " + clause) - .orElse(""); - - return format("SELECT %s FROM `%s` %s", formattedQuery, tableName, whereClause); - } - - String fullTableName(TableId tableId) { - return format("%s.%s.%s", tableId.getProject(), tableId.getDataset(), tableId.getTable()); - } - - public long calculateTableSize(TableId tableId, Optional filter) { - return calculateTableSize(getTable(tableId), filter); - } - - public long calculateTableSize(TableInfo tableInfo, Optional filter) { - try { - TableDefinition.Type type = tableInfo.getDefinition().getType(); - if (type == TableDefinition.Type.TABLE && !filter.isPresent()) { - return tableInfo.getNumRows().longValue(); - } else if (type == TableDefinition.Type.VIEW || type == TableDefinition.Type.MATERIALIZED_VIEW || - (type == TableDefinition.Type.TABLE && filter.isPresent())) { - // run a query - String table = fullTableName(tableInfo.getTableId()); - String sql = format("SELECT COUNT(*) from `%s` WHERE %s", table, filter.get()); - TableResult result = bigQuery.query(QueryJobConfiguration.of(sql)); - return result.iterateAll().iterator().next().get(0).getLongValue(); - } else { - throw new IllegalArgumentException(format("Unsupported table type %s for table %s", - type, fullTableName(tableInfo.getTableId()))); - } - } catch (InterruptedException e) { - throw new BigQueryConnectorException("Querying table size was interrupted on the client side", e); - } - } + private final BigQuery bigQuery; + private final Optional materializationProject; + private final Optional materializationDataset; + + BigQueryClient( + BigQuery bigQuery, + Optional materializationProject, + Optional materializationDataset) { + this.bigQuery = bigQuery; + this.materializationProject = materializationProject; + 
this.materializationDataset = materializationDataset; + } + + // return empty if no filters are used + private static Optional createWhereClause(String[] filters) { + if (filters.length == 0) { + return Optional.empty(); + } + return Optional.of(Stream.of(filters).collect(Collectors.joining(") AND (", "(", ")"))); + } + + public TableInfo getTable(TableId tableId) { + return bigQuery.getTable(tableId); + } + + public TableInfo getSupportedTable( + TableId tableId, boolean viewsEnabled, String viewEnabledParamName) { + TableInfo table = getTable(tableId); + if (table == null) { + return null; + } + + TableDefinition tableDefinition = table.getDefinition(); + TableDefinition.Type tableType = tableDefinition.getType(); + if (TableDefinition.Type.TABLE == tableType) { + return table; + } + if (TableDefinition.Type.VIEW == tableType + || TableDefinition.Type.MATERIALIZED_VIEW == tableType) { + if (viewsEnabled) { + return table; + } else { + throw new BigQueryConnectorException( + UNSUPPORTED, + format( + "Views are not enabled. You can enable views by setting '%s' to true. Notice additional cost may occur.", + viewEnabledParamName)); + } + } + // not regular table or a view + throw new BigQueryConnectorException( + UNSUPPORTED, + format( + "Table type '%s' of table '%s.%s' is not supported", + tableType, table.getTableId().getDataset(), table.getTableId().getTable())); + } + + DatasetId toDatasetId(TableId tableId) { + return DatasetId.of(tableId.getProject(), tableId.getDataset()); + } + + String getProjectId() { + return bigQuery.getOptions().getProjectId(); + } + + Iterable listDatasets(String projectId) { + return bigQuery.listDatasets(projectId).iterateAll(); + } + + Iterable
listTables(DatasetId datasetId, TableDefinition.Type... types) { + Set<TableDefinition.Type> allowedTypes = ImmutableSet.copyOf(types); + Iterable<Table>
allTables = bigQuery.listTables(datasetId).iterateAll(); + return StreamSupport.stream(allTables.spliterator(), false) + .filter(table -> allowedTypes.contains(table.getDefinition().getType())) + .collect(toImmutableList()); + } + + TableId createDestinationTable(TableId tableId) { + String project = materializationProject.orElse(tableId.getProject()); + String dataset = materializationDataset.orElse(tableId.getDataset()); + DatasetId datasetId = DatasetId.of(project, dataset); + String name = format("_bqc_%s", randomUUID().toString().toLowerCase(ENGLISH).replace("-", "")); + return TableId.of(datasetId.getProject(), datasetId.getDataset(), name); + } + + Table update(TableInfo table) { + return bigQuery.update(table); + } + + Job create(JobInfo jobInfo) { + return bigQuery.create(jobInfo); + } + + TableResult query(String sql) { + try { + return bigQuery.query(QueryJobConfiguration.of(sql)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new BigQueryException( + BaseHttpServiceException.UNKNOWN_CODE, format("Failed to run the query [%s]", sql), e); + } + } + + String createSql(TableId table, ImmutableList requiredColumns, String[] filters) { + String columns = + requiredColumns.isEmpty() + ? "*" + : requiredColumns.stream().map(column -> format("`%s`", column)).collect(joining(",")); + + String whereClause = createWhereClause(filters).map(clause -> "WHERE " + clause).orElse(""); + + return createSql(table, columns, filters); + } + + // assuming the SELECT part is properly formatted, can be used to call functions such as COUNT and + // SUM + String createSql(TableId table, String formattedQuery, String[] filters) { + String tableName = fullTableName(table); + + String whereClause = createWhereClause(filters).map(clause -> "WHERE " + clause).orElse(""); + + return format("SELECT %s FROM `%s` %s", formattedQuery, tableName, whereClause); + } + + String fullTableName(TableId tableId) { + return format("%s.%s.%s", tableId.getProject(), tableId.getDataset(), tableId.getTable()); + } + + public long calculateTableSize(TableId tableId, Optional filter) { + return calculateTableSize(getTable(tableId), filter); + } + + public long calculateTableSize(TableInfo tableInfo, Optional filter) { + try { + TableDefinition.Type type = tableInfo.getDefinition().getType(); + if (type == TableDefinition.Type.TABLE && !filter.isPresent()) { + return tableInfo.getNumRows().longValue(); + } else if (type == TableDefinition.Type.VIEW + || type == TableDefinition.Type.MATERIALIZED_VIEW + || (type == TableDefinition.Type.TABLE && filter.isPresent())) { + // run a query + String table = fullTableName(tableInfo.getTableId()); + String sql = format("SELECT COUNT(*) from `%s` WHERE %s", table, filter.get()); + TableResult result = bigQuery.query(QueryJobConfiguration.of(sql)); + return result.iterateAll().iterator().next().get(0).getLongValue(); + } else { + throw new IllegalArgumentException( + format( + "Unsupported table type %s for table %s", + type, fullTableName(tableInfo.getTableId()))); + } + } catch (InterruptedException e) { + throw new BigQueryConnectorException( + "Querying table size was interrupted on the client side", e); + } + } } diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClientModule.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClientModule.java index 3edd6c1357..d8196da639 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClientModule.java +++ 
b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryClientModule.java @@ -24,51 +24,62 @@ public class BigQueryClientModule implements Module { - @Provides - @Singleton - public static UserAgentHeaderProvider createUserAgentHeaderProvider(UserAgentProvider versionProvider) { - return new UserAgentHeaderProvider(versionProvider.getUserAgent()); - } + @Provides + @Singleton + public static UserAgentHeaderProvider createUserAgentHeaderProvider( + UserAgentProvider versionProvider) { + return new UserAgentHeaderProvider(versionProvider.getUserAgent()); + } - // Note that at this point the config has been validated, which means that option 2 or option 3 will always be valid - static String calculateBillingProjectId(Optional configParentProjectId, Credentials credentials) { - // 1. Get from configuration - if (configParentProjectId.isPresent()) { - return configParentProjectId.get(); - } - // 2. Get from the provided credentials, but only ServiceAccountCredentials contains the project id. - // All other credentials types (User, AppEngine, GCE, CloudShell, etc.) take it from the environment - if (credentials instanceof ServiceAccountCredentials) { - return ((ServiceAccountCredentials) credentials).getProjectId(); - } - // 3. No configuration was provided, so get the default from the environment - return BigQueryOptions.getDefaultProjectId(); + // Note that at this point the config has been validated, which means that option 2 or option 3 + // will always be valid + static String calculateBillingProjectId( + Optional configParentProjectId, Credentials credentials) { + // 1. Get from configuration + if (configParentProjectId.isPresent()) { + return configParentProjectId.get(); } - - @Override - public void configure(Binder binder) { - // BigQuery related - binder.bind(BigQueryReadClientFactory.class).in(Scopes.SINGLETON); + // 2. Get from the provided credentials, but only ServiceAccountCredentials contains the project + // id. + // All other credentials types (User, AppEngine, GCE, CloudShell, etc.) take it from the + // environment + if (credentials instanceof ServiceAccountCredentials) { + return ((ServiceAccountCredentials) credentials).getProjectId(); } + // 3. 
No configuration was provided, so get the default from the environment + return BigQueryOptions.getDefaultProjectId(); + } - @Provides - @Singleton - public BigQueryCredentialsSupplier provideBigQueryCredentialsSupplier(BigQueryConfig config) { - return new BigQueryCredentialsSupplier(config.getAccessToken(), config.getCredentialsKey(), config.getCredentialsFile()); - } + @Override + public void configure(Binder binder) { + // BigQuery related + binder.bind(BigQueryReadClientFactory.class).in(Scopes.SINGLETON); + } - @Provides - @Singleton - public BigQueryClient provideBigQueryClient(BigQueryConfig config, UserAgentHeaderProvider userAgentHeaderProvider, BigQueryCredentialsSupplier bigQueryCredentialsSupplier) { - String billingProjectId = calculateBillingProjectId(config.getParentProjectId(), bigQueryCredentialsSupplier.getCredentials()); - BigQueryOptions.Builder options = BigQueryOptions.newBuilder() - .setHeaderProvider(userAgentHeaderProvider) - .setProjectId(billingProjectId) - .setCredentials(bigQueryCredentialsSupplier.getCredentials()); - return new BigQueryClient( - options.build().getService(), - config.getMaterializationProject(), - config.getMaterializationDataset()); - } + @Provides + @Singleton + public BigQueryCredentialsSupplier provideBigQueryCredentialsSupplier(BigQueryConfig config) { + return new BigQueryCredentialsSupplier( + config.getAccessToken(), config.getCredentialsKey(), config.getCredentialsFile()); + } + @Provides + @Singleton + public BigQueryClient provideBigQueryClient( + BigQueryConfig config, + UserAgentHeaderProvider userAgentHeaderProvider, + BigQueryCredentialsSupplier bigQueryCredentialsSupplier) { + String billingProjectId = + calculateBillingProjectId( + config.getParentProjectId(), bigQueryCredentialsSupplier.getCredentials()); + BigQueryOptions.Builder options = + BigQueryOptions.newBuilder() + .setHeaderProvider(userAgentHeaderProvider) + .setProjectId(billingProjectId) + .setCredentials(bigQueryCredentialsSupplier.getCredentials()); + return new BigQueryClient( + options.build().getService(), + config.getMaterializationProject(), + config.getMaterializationDataset()); + } } diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryConfig.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryConfig.java index 54be69bd9a..a5fb40c66c 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryConfig.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryConfig.java @@ -19,17 +19,17 @@ public interface BigQueryConfig { - Optional getCredentialsKey(); + Optional getCredentialsKey(); - Optional getCredentialsFile(); + Optional getCredentialsFile(); - Optional getAccessToken(); + Optional getAccessToken(); - Optional getParentProjectId(); + Optional getParentProjectId(); - boolean isViewsEnabled(); + boolean isViewsEnabled(); - Optional getMaterializationProject(); + Optional getMaterializationProject(); - Optional getMaterializationDataset(); + Optional getMaterializationDataset(); } diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryConnectorException.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryConnectorException.java index 7f4256efc0..0f57f83e88 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryConnectorException.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryConnectorException.java 
@@ -19,27 +19,27 @@ public class BigQueryConnectorException extends RuntimeException { - final BigQueryErrorCode errorCode; + final BigQueryErrorCode errorCode; - public BigQueryConnectorException(String message) { - this(UNKNOWN, message); - } + public BigQueryConnectorException(String message) { + this(UNKNOWN, message); + } - public BigQueryConnectorException(String message, Throwable cause) { - this(UNKNOWN, message, cause); - } + public BigQueryConnectorException(String message, Throwable cause) { + this(UNKNOWN, message, cause); + } - public BigQueryConnectorException(BigQueryErrorCode errorCode, String message) { - super(message); - this.errorCode = errorCode; - } + public BigQueryConnectorException(BigQueryErrorCode errorCode, String message) { + super(message); + this.errorCode = errorCode; + } - public BigQueryConnectorException(BigQueryErrorCode errorCode, String message, Throwable cause) { - super(message, cause); - this.errorCode = errorCode; - } + public BigQueryConnectorException(BigQueryErrorCode errorCode, String message, Throwable cause) { + super(message, cause); + this.errorCode = errorCode; + } - public BigQueryErrorCode getErrorCode() { - return errorCode; - } + public BigQueryErrorCode getErrorCode() { + return errorCode; + } } diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryCredentialsSupplier.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryCredentialsSupplier.java index e035b3343e..276fac3c7e 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryCredentialsSupplier.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryCredentialsSupplier.java @@ -29,55 +29,59 @@ import static com.google.cloud.bigquery.connector.common.BigQueryUtil.firstPresent; public class BigQueryCredentialsSupplier { - private final Optional accessToken; - private final Optional credentialsKey; - private final Optional credentialsFile; - private final Credentials credentials; + private final Optional accessToken; + private final Optional credentialsKey; + private final Optional credentialsFile; + private final Credentials credentials; - public BigQueryCredentialsSupplier( - Optional accessToken, - Optional credentialsKey, - Optional credentialsFile) { - this.accessToken = accessToken; - this.credentialsKey = credentialsKey; - this.credentialsFile = credentialsFile; - // lazy creation, cache once it's created - Optional credentialsFromAccessToken = credentialsKey.map(BigQueryCredentialsSupplier::createCredentialsFromAccessToken); - Optional credentialsFromKey = credentialsKey.map(BigQueryCredentialsSupplier::createCredentialsFromKey); - Optional credentialsFromFile = credentialsFile.map(BigQueryCredentialsSupplier::createCredentialsFromFile); - this.credentials = firstPresent(credentialsFromAccessToken, credentialsFromKey, credentialsFromFile) - .orElse(createDefaultCredentials()); - } + public BigQueryCredentialsSupplier( + Optional accessToken, + Optional credentialsKey, + Optional credentialsFile) { + this.accessToken = accessToken; + this.credentialsKey = credentialsKey; + this.credentialsFile = credentialsFile; + // lazy creation, cache once it's created + Optional credentialsFromAccessToken = + credentialsKey.map(BigQueryCredentialsSupplier::createCredentialsFromAccessToken); + Optional credentialsFromKey = + credentialsKey.map(BigQueryCredentialsSupplier::createCredentialsFromKey); + Optional credentialsFromFile = + 
credentialsFile.map(BigQueryCredentialsSupplier::createCredentialsFromFile); + this.credentials = + firstPresent(credentialsFromAccessToken, credentialsFromKey, credentialsFromFile) + .orElse(createDefaultCredentials()); + } - private static Credentials createCredentialsFromAccessToken(String accessToken) { - return GoogleCredentials.create(new AccessToken(accessToken, null)); - } + private static Credentials createCredentialsFromAccessToken(String accessToken) { + return GoogleCredentials.create(new AccessToken(accessToken, null)); + } - private static Credentials createCredentialsFromKey(String key) { - try { - return GoogleCredentials.fromStream(new ByteArrayInputStream(Base64.decodeBase64(key))); - } catch (IOException e) { - throw new UncheckedIOException("Failed to create Credentials from key", e); - } + private static Credentials createCredentialsFromKey(String key) { + try { + return GoogleCredentials.fromStream(new ByteArrayInputStream(Base64.decodeBase64(key))); + } catch (IOException e) { + throw new UncheckedIOException("Failed to create Credentials from key", e); } + } - private static Credentials createCredentialsFromFile(String file) { - try { - return GoogleCredentials.fromStream(new FileInputStream(file)); - } catch (IOException e) { - throw new UncheckedIOException("Failed to create Credentials from file", e); - } + private static Credentials createCredentialsFromFile(String file) { + try { + return GoogleCredentials.fromStream(new FileInputStream(file)); + } catch (IOException e) { + throw new UncheckedIOException("Failed to create Credentials from file", e); } + } - public static Credentials createDefaultCredentials() { - try { - return GoogleCredentials.getApplicationDefault(); - } catch (IOException e) { - throw new UncheckedIOException("Failed to create default Credentials", e); - } + public static Credentials createDefaultCredentials() { + try { + return GoogleCredentials.getApplicationDefault(); + } catch (IOException e) { + throw new UncheckedIOException("Failed to create default Credentials", e); } + } - Credentials getCredentials() { - return credentials; - } + Credentials getCredentials() { + return credentials; + } } diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryErrorCode.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryErrorCode.java index 7bf4ebf73f..07a13e13aa 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryErrorCode.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryErrorCode.java @@ -16,21 +16,20 @@ package com.google.cloud.bigquery.connector.common; public enum BigQueryErrorCode { - BIGQUERY_VIEW_DESTINATION_TABLE_CREATION_FAILED(0), - BIGQUERY_DATETIME_PARSING_ERROR(1), - BIGQUERY_FAILED_TO_EXECUTE_QUERY(2), - // Should be last - UNSUPPORTED(9998), - UNKNOWN(9999); + BIGQUERY_VIEW_DESTINATION_TABLE_CREATION_FAILED(0), + BIGQUERY_DATETIME_PARSING_ERROR(1), + BIGQUERY_FAILED_TO_EXECUTE_QUERY(2), + // Should be last + UNSUPPORTED(9998), + UNKNOWN(9999); - final int code; + final int code; - BigQueryErrorCode(int code) { - this.code = code; - } - - public int getCode() { - return code; - } + BigQueryErrorCode(int code) { + this.code = code; + } + public int getCode() { + return code; + } } diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryReadClientFactory.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryReadClientFactory.java index 
27b44df24e..c1df2cbac4 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryReadClientFactory.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryReadClientFactory.java @@ -31,28 +31,31 @@ * short lived clients that can be closed independently. */ public class BigQueryReadClientFactory implements Serializable { - private final Credentials credentials; - // using the user agent as HeaderProvider is not serializable - private final UserAgentHeaderProvider userAgentHeaderProvider; + private final Credentials credentials; + // using the user agent as HeaderProvider is not serializable + private final UserAgentHeaderProvider userAgentHeaderProvider; - @Inject - public BigQueryReadClientFactory(BigQueryCredentialsSupplier bigQueryCredentialsSupplier, UserAgentHeaderProvider userAgentHeaderProvider) { - // using Guava's optional as it is serializable - this.credentials = bigQueryCredentialsSupplier.getCredentials(); - this.userAgentHeaderProvider = userAgentHeaderProvider; - } + @Inject + public BigQueryReadClientFactory( + BigQueryCredentialsSupplier bigQueryCredentialsSupplier, + UserAgentHeaderProvider userAgentHeaderProvider) { + // using Guava's optional as it is serializable + this.credentials = bigQueryCredentialsSupplier.getCredentials(); + this.userAgentHeaderProvider = userAgentHeaderProvider; + } - BigQueryReadClient createBigQueryReadClient() { - try { - BigQueryReadSettings.Builder clientSettings = BigQueryReadSettings.newBuilder() - .setTransportChannelProvider( - BigQueryReadSettings.defaultGrpcTransportProviderBuilder() - .setHeaderProvider(userAgentHeaderProvider) - .build()) - .setCredentialsProvider(FixedCredentialsProvider.create(credentials)); - return BigQueryReadClient.create(clientSettings.build()); - } catch (IOException e) { - throw new UncheckedIOException("Error creating BigQueryStorageClient", e); - } + BigQueryReadClient createBigQueryReadClient() { + try { + BigQueryReadSettings.Builder clientSettings = + BigQueryReadSettings.newBuilder() + .setTransportChannelProvider( + BigQueryReadSettings.defaultGrpcTransportProviderBuilder() + .setHeaderProvider(userAgentHeaderProvider) + .build()) + .setCredentialsProvider(FixedCredentialsProvider.create(credentials)); + return BigQueryReadClient.create(clientSettings.build()); + } catch (IOException e) { + throw new UncheckedIOException("Error creating BigQueryStorageClient", e); } + } } diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryUtil.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryUtil.java index e536e11e90..47a236e9ed 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryUtil.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/BigQueryUtil.java @@ -33,69 +33,67 @@ import static java.lang.String.format; public class BigQueryUtil { - static final ImmutableSet INTERNAL_ERROR_MESSAGES = ImmutableSet.of( - "HTTP/2 error code: INTERNAL_ERROR", - "Connection closed with unknown cause", - "Received unexpected EOS on DATA frame from server"); - private static final String PROJECT_PATTERN = "\\S+"; - private static final String DATASET_PATTERN = "\\w+"; - // Allow all non-whitespace beside ':' and '.'. - // These confuse the qualified table parsing. - private static final String TABLE_PATTERN = "[\\S&&[^.:]]+"; - /** - * Regex for an optionally fully qualified table. - *
<p>
- * Must match 'project.dataset.table' OR the legacy 'project:dataset.table' OR 'dataset.table' - * OR 'table'. - */ - private static final Pattern QUALIFIED_TABLE_REGEX = - Pattern.compile(format("^(((%s)[:.])?(%s)\\.)?(%s)$$", PROJECT_PATTERN, DATASET_PATTERN, TABLE_PATTERN)); + static final ImmutableSet INTERNAL_ERROR_MESSAGES = + ImmutableSet.of( + "HTTP/2 error code: INTERNAL_ERROR", + "Connection closed with unknown cause", + "Received unexpected EOS on DATA frame from server"); + private static final String PROJECT_PATTERN = "\\S+"; + private static final String DATASET_PATTERN = "\\w+"; + // Allow all non-whitespace beside ':' and '.'. + // These confuse the qualified table parsing. + private static final String TABLE_PATTERN = "[\\S&&[^.:]]+"; + /** + * Regex for an optionally fully qualified table. + * + *
<p>
Must match 'project.dataset.table' OR the legacy 'project:dataset.table' OR 'dataset.table' + * OR 'table'. + */ + private static final Pattern QUALIFIED_TABLE_REGEX = + Pattern.compile( + format("^(((%s)[:.])?(%s)\\.)?(%s)$$", PROJECT_PATTERN, DATASET_PATTERN, TABLE_PATTERN)); - private BigQueryUtil() { - } - - static boolean isRetryable(Throwable cause) { - return getCausalChain(cause).stream().anyMatch(BigQueryUtil::isRetryableInternalError); - } + private BigQueryUtil() {} - static boolean isRetryableInternalError(Throwable t) { - if (t instanceof StatusRuntimeException) { - StatusRuntimeException statusRuntimeException = (StatusRuntimeException) t; - return statusRuntimeException.getStatus().getCode() == Status.Code.INTERNAL && - INTERNAL_ERROR_MESSAGES.stream() - .anyMatch(message -> statusRuntimeException.getMessage().contains(message)); - } - return false; - } + static boolean isRetryable(Throwable cause) { + return getCausalChain(cause).stream().anyMatch(BigQueryUtil::isRetryableInternalError); + } - static BigQueryException convertToBigQueryException(BigQueryError error) { - return new BigQueryException(UNKNOWN_CODE, error.getMessage(), error); + static boolean isRetryableInternalError(Throwable t) { + if (t instanceof StatusRuntimeException) { + StatusRuntimeException statusRuntimeException = (StatusRuntimeException) t; + return statusRuntimeException.getStatus().getCode() == Status.Code.INTERNAL + && INTERNAL_ERROR_MESSAGES.stream() + .anyMatch(message -> statusRuntimeException.getMessage().contains(message)); } + return false; + } - // returns the first present optional, empty if all parameters are empty - public static Optional firstPresent(Optional... optionals) { - return Stream.of(optionals) - .flatMap(Streams::stream) - .findFirst(); - } + static BigQueryException convertToBigQueryException(BigQueryError error) { + return new BigQueryException(UNKNOWN_CODE, error.getMessage(), error); + } - public static TableId parseTableId( - String rawTable, - Optional dataset, - Optional project) { - Matcher matcher = QUALIFIED_TABLE_REGEX.matcher(rawTable); - if (!matcher.matches()) { - throw new IllegalArgumentException( - format("Invalid Table ID '%s'. Must match '%s'", rawTable, QUALIFIED_TABLE_REGEX)); - } - String table = matcher.group(5); - Optional parsedDataset = Optional.ofNullable(matcher.group(4)); - Optional parsedProject = Optional.ofNullable(matcher.group(3)); + // returns the first present optional, empty if all parameters are empty + public static Optional firstPresent(Optional... optionals) { + return Stream.of(optionals).flatMap(Streams::stream).findFirst(); + } - String actualDataset = firstPresent(parsedDataset, dataset).orElseThrow(() -> - new IllegalArgumentException("'dataset' not parsed or provided.")); - return firstPresent(parsedProject, project) - .map(p -> TableId.of(p, actualDataset, table)) - .orElse(TableId.of(actualDataset, table)); + public static TableId parseTableId( + String rawTable, Optional dataset, Optional project) { + Matcher matcher = QUALIFIED_TABLE_REGEX.matcher(rawTable); + if (!matcher.matches()) { + throw new IllegalArgumentException( + format("Invalid Table ID '%s'. 
Must match '%s'", rawTable, QUALIFIED_TABLE_REGEX)); } + String table = matcher.group(5); + Optional parsedDataset = Optional.ofNullable(matcher.group(4)); + Optional parsedProject = Optional.ofNullable(matcher.group(3)); + + String actualDataset = + firstPresent(parsedDataset, dataset) + .orElseThrow(() -> new IllegalArgumentException("'dataset' not parsed or provided.")); + return firstPresent(parsedProject, project) + .map(p -> TableId.of(p, actualDataset, table)) + .orElse(TableId.of(actualDataset, table)); + } } diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadRowsHelper.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadRowsHelper.java index 6cd0e94c2f..1c99a5c889 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadRowsHelper.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadRowsHelper.java @@ -25,86 +25,83 @@ import static java.util.Objects.requireNonNull; public class ReadRowsHelper { - private BigQueryReadClientFactory bigQueryReadClientFactory; - private ReadRowsRequest.Builder request; - private int maxReadRowsRetries; - private BigQueryReadClient client; + private BigQueryReadClientFactory bigQueryReadClientFactory; + private ReadRowsRequest.Builder request; + private int maxReadRowsRetries; + private BigQueryReadClient client; - public ReadRowsHelper( - BigQueryReadClientFactory bigQueryReadClientFactory, - ReadRowsRequest.Builder request, - int maxReadRowsRetries) { - this.bigQueryReadClientFactory = requireNonNull(bigQueryReadClientFactory, "bigQueryReadClientFactory cannot be null"); - this.request = requireNonNull(request, "request cannot be null"); - this.maxReadRowsRetries = maxReadRowsRetries; - } + public ReadRowsHelper( + BigQueryReadClientFactory bigQueryReadClientFactory, + ReadRowsRequest.Builder request, + int maxReadRowsRetries) { + this.bigQueryReadClientFactory = + requireNonNull(bigQueryReadClientFactory, "bigQueryReadClientFactory cannot be null"); + this.request = requireNonNull(request, "request cannot be null"); + this.maxReadRowsRetries = maxReadRowsRetries; + } - public Iterator readRows() { - if (client != null) { - client.close(); - } - client = bigQueryReadClientFactory.createBigQueryReadClient(); - Iterator serverResponses = fetchResponses(request); - return new ReadRowsIterator(this, serverResponses); + public Iterator readRows() { + if (client != null) { + client.close(); } + client = bigQueryReadClientFactory.createBigQueryReadClient(); + Iterator serverResponses = fetchResponses(request); + return new ReadRowsIterator(this, serverResponses); + } - // In order to enable testing - protected Iterator fetchResponses(ReadRowsRequest.Builder readRowsRequest) { - return client.readRowsCallable() - .call(readRowsRequest.build()) - .iterator(); - } + // In order to enable testing + protected Iterator fetchResponses(ReadRowsRequest.Builder readRowsRequest) { + return client.readRowsCallable().call(readRowsRequest.build()).iterator(); + } - // Ported from https://github.com/GoogleCloudDataproc/spark-bigquery-connector/pull/150 - static class ReadRowsIterator implements Iterator { - ReadRowsHelper helper; - Iterator serverResponses; - long readRowsCount; - int retries; - - public ReadRowsIterator( - ReadRowsHelper helper, - Iterator serverResponses) { - this.helper = helper; - this.serverResponses = serverResponses; - } + // Ported from https://github.com/GoogleCloudDataproc/spark-bigquery-connector/pull/150 + static class 
ReadRowsIterator implements Iterator { + ReadRowsHelper helper; + Iterator serverResponses; + long readRowsCount; + int retries; - @Override - public boolean hasNext() { - boolean hasNext = serverResponses.hasNext(); - if (!hasNext && !helper.client.isShutdown()) { - helper.client.close(); - } - return hasNext; - } + public ReadRowsIterator(ReadRowsHelper helper, Iterator serverResponses) { + this.helper = helper; + this.serverResponses = serverResponses; + } - @Override - public ReadRowsResponse next() { - do { - try { - ReadRowsResponse response = serverResponses.next(); - readRowsCount += response.getRowCount(); - //logDebug(s"read ${response.getSerializedSize} bytes"); - return response; - } catch (Exception e) { - // if relevant, retry the read, from the last read position - if (BigQueryUtil.isRetryable(e) && retries < helper.maxReadRowsRetries) { - serverResponses = helper.fetchResponses(helper.request.setOffset(readRowsCount)); - retries++; - } else { - helper.client.close(); - throw e; - } - } - } while (serverResponses.hasNext()); + @Override + public boolean hasNext() { + boolean hasNext = serverResponses.hasNext(); + if (!hasNext && !helper.client.isShutdown()) { + helper.client.close(); + } + return hasNext; + } - throw new NoSuchElementException("No more server responses"); + @Override + public ReadRowsResponse next() { + do { + try { + ReadRowsResponse response = serverResponses.next(); + readRowsCount += response.getRowCount(); + // logDebug(s"read ${response.getSerializedSize} bytes"); + return response; + } catch (Exception e) { + // if relevant, retry the read, from the last read position + if (BigQueryUtil.isRetryable(e) && retries < helper.maxReadRowsRetries) { + serverResponses = helper.fetchResponses(helper.request.setOffset(readRowsCount)); + retries++; + } else { + helper.client.close(); + throw e; + } } + } while (serverResponses.hasNext()); + + throw new NoSuchElementException("No more server responses"); } + } - public void close() { - if (!client.isShutdown()) { - client.close(); - } + public void close() { + if (!client.isShutdown()) { + client.close(); } + } } diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionCreator.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionCreator.java index 9b53516b5f..d0beb4d150 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionCreator.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionCreator.java @@ -40,163 +40,176 @@ // A helper class, also handles view materialization public class ReadSessionCreator { - /** - * Default parallelism to 1 reader per 400MB, which should be about the maximum allowed by the - * BigQuery Storage API. The number of partitions returned may be significantly less depending - * on a number of factors. 
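A worked example of the sizing rule described in this Javadoc, as a sketch with assumed numbers (not part of the patch):

// One read stream per 400 MB by default: a 10 GB table with no explicit maxParallelism
// requests max(10_000_000_000 / 400_000_000, 1) = 25 streams, while a 100 MB table
// still requests max(0, 1) = 1 stream.
long tableBytes = 10_000_000_000L;          // assumed table size
int bytesPerPartition = 400 * 1000 * 1000;  // DEFAULT_BYTES_PER_PARTITION
int requestedStreams = Math.max((int) (tableBytes / bytesPerPartition), 1);  // 25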
- */ - private static final int DEFAULT_BYTES_PER_PARTITION = 400 * 1000 * 1000; - - private static final Logger log = LoggerFactory.getLogger(ReadSessionCreator.class); - - private static Cache destinationTableCache = - CacheBuilder.newBuilder() - .expireAfterWrite(15, TimeUnit.MINUTES) - .maximumSize(1000) - .build(); - - private final ReadSessionCreatorConfig config; - private final BigQueryClient bigQueryClient; - private final BigQueryReadClientFactory bigQueryReadClientFactory; - - public ReadSessionCreator( - ReadSessionCreatorConfig config, - BigQueryClient bigQueryClient, - BigQueryReadClientFactory bigQueryReadClientFactory) { - this.config = config; - this.bigQueryClient = bigQueryClient; - this.bigQueryReadClientFactory = bigQueryReadClientFactory; + /** + * Default parallelism to 1 reader per 400MB, which should be about the maximum allowed by the + * BigQuery Storage API. The number of partitions returned may be significantly less depending on + * a number of factors. + */ + private static final int DEFAULT_BYTES_PER_PARTITION = 400 * 1000 * 1000; + + private static final Logger log = LoggerFactory.getLogger(ReadSessionCreator.class); + + private static Cache destinationTableCache = + CacheBuilder.newBuilder().expireAfterWrite(15, TimeUnit.MINUTES).maximumSize(1000).build(); + + private final ReadSessionCreatorConfig config; + private final BigQueryClient bigQueryClient; + private final BigQueryReadClientFactory bigQueryReadClientFactory; + + public ReadSessionCreator( + ReadSessionCreatorConfig config, + BigQueryClient bigQueryClient, + BigQueryReadClientFactory bigQueryReadClientFactory) { + this.config = config; + this.bigQueryClient = bigQueryClient; + this.bigQueryReadClientFactory = bigQueryReadClientFactory; + } + + static int getMaxNumPartitionsRequested( + OptionalInt maxParallelism, StandardTableDefinition tableDefinition) { + return maxParallelism.orElse( + Math.max((int) (tableDefinition.getNumBytes() / DEFAULT_BYTES_PER_PARTITION), 1)); + } + + public ReadSessionResponse create( + TableId table, + ImmutableList selectedFields, + Optional filter, + OptionalInt maxParallelism) { + TableInfo tableDetails = bigQueryClient.getTable(table); + + TableInfo actualTable = getActualTable(tableDetails, selectedFields, filter); + StandardTableDefinition tableDefinition = actualTable.getDefinition(); + + try (BigQueryReadClient bigQueryReadClient = + bigQueryReadClientFactory.createBigQueryReadClient()) { + ReadSession.TableReadOptions.Builder readOptions = + ReadSession.TableReadOptions.newBuilder().addAllSelectedFields(selectedFields); + filter.ifPresent(readOptions::setRowRestriction); + + String tablePath = toTablePath(actualTable.getTableId()); + + ReadSession readSession = + bigQueryReadClient.createReadSession( + CreateReadSessionRequest.newBuilder() + .setParent("projects/" + bigQueryClient.getProjectId()) + .setReadSession( + ReadSession.newBuilder() + .setDataFormat(config.readDataFormat) + .setReadOptions(readOptions) + .setTable(tablePath)) + .setMaxStreamCount(getMaxNumPartitionsRequested(maxParallelism, tableDefinition)) + .build()); + + return new ReadSessionResponse(readSession, actualTable); } - - static int getMaxNumPartitionsRequested(OptionalInt maxParallelism, StandardTableDefinition tableDefinition) { - return maxParallelism.orElse(Math.max( - (int) (tableDefinition.getNumBytes() / DEFAULT_BYTES_PER_PARTITION), 1)); + } + + String toTablePath(TableId tableId) { + return format( + "projects/%s/datasets/%s/tables/%s", + tableId.getProject(), 
tableId.getDataset(), tableId.getTable()); + } + + TableInfo getActualTable( + TableInfo table, ImmutableList requiredColumns, Optional filter) { + String[] filters = filter.map(Stream::of).orElseGet(Stream::empty).toArray(String[]::new); + return getActualTable(table, requiredColumns, filters); + } + + TableInfo getActualTable( + TableInfo table, ImmutableList requiredColumns, String[] filters) { + TableDefinition tableDefinition = table.getDefinition(); + TableDefinition.Type tableType = tableDefinition.getType(); + if (TableDefinition.Type.TABLE == tableType) { + return table; } - - public ReadSessionResponse create( - TableId table, - ImmutableList selectedFields, - Optional filter, - OptionalInt maxParallelism) { - TableInfo tableDetails = bigQueryClient.getTable(table); - - TableInfo actualTable = getActualTable(tableDetails, selectedFields, filter); - StandardTableDefinition tableDefinition = actualTable.getDefinition(); - - try (BigQueryReadClient bigQueryReadClient = bigQueryReadClientFactory.createBigQueryReadClient()) { - ReadSession.TableReadOptions.Builder readOptions = ReadSession.TableReadOptions.newBuilder() - .addAllSelectedFields(selectedFields); - filter.ifPresent(readOptions::setRowRestriction); - - String tablePath = toTablePath(actualTable.getTableId()); - - ReadSession readSession = bigQueryReadClient.createReadSession( - CreateReadSessionRequest.newBuilder() - .setParent("projects/" + bigQueryClient.getProjectId()) - .setReadSession(ReadSession.newBuilder() - .setDataFormat(config.readDataFormat) - .setReadOptions(readOptions) - .setTable(tablePath)) - .setMaxStreamCount(getMaxNumPartitionsRequested(maxParallelism, tableDefinition)) - .build()); - - return new ReadSessionResponse(readSession, actualTable); - } + if (TableDefinition.Type.VIEW == tableType + || TableDefinition.Type.MATERIALIZED_VIEW == tableType) { + if (!config.viewsEnabled) { + throw new BigQueryConnectorException( + UNSUPPORTED, + format( + "Views are not enabled. You can enable views by setting '%s' to true. 
Notice additional cost may occur.", + config.viewEnabledParamName)); + } + // get it from the view + String querySql = bigQueryClient.createSql(table.getTableId(), requiredColumns, filters); + log.debug("querySql is %s", querySql); + try { + return destinationTableCache.get( + querySql, + new DestinationTableBuilder(bigQueryClient, config, querySql, table.getTableId())); + } catch (ExecutionException e) { + throw new BigQueryConnectorException( + BIGQUERY_VIEW_DESTINATION_TABLE_CREATION_FAILED, "Error creating destination table", e); + } + } else { + // not regular table or a view + throw new BigQueryConnectorException( + UNSUPPORTED, + format( + "Table type '%s' of table '%s.%s' is not supported", + tableType, table.getTableId().getDataset(), table.getTableId().getTable())); } - - String toTablePath(TableId tableId) { - return format("projects/%s/datasets/%s/tables/%s", - tableId.getProject(), tableId.getDataset(), tableId.getTable()); + } + + static class DestinationTableBuilder implements Callable { + final BigQueryClient bigQueryClient; + final ReadSessionCreatorConfig config; + final String querySql; + final TableId table; + + DestinationTableBuilder( + BigQueryClient bigQueryClient, + ReadSessionCreatorConfig config, + String querySql, + TableId table) { + this.bigQueryClient = bigQueryClient; + this.config = config; + this.querySql = querySql; + this.table = table; } - TableInfo getActualTable( - TableInfo table, - ImmutableList requiredColumns, - Optional filter) { - String[] filters = filter.map(Stream::of).orElseGet(Stream::empty).toArray(String[]::new); - return getActualTable(table, requiredColumns, filters); + @Override + public TableInfo call() { + return createTableFromQuery(); } - TableInfo getActualTable( - TableInfo table, - ImmutableList requiredColumns, - String[] filters) { - TableDefinition tableDefinition = table.getDefinition(); - TableDefinition.Type tableType = tableDefinition.getType(); - if (TableDefinition.Type.TABLE == tableType) { - return table; - } - if (TableDefinition.Type.VIEW == tableType || TableDefinition.Type.MATERIALIZED_VIEW == tableType) { - if (!config.viewsEnabled) { - throw new BigQueryConnectorException(UNSUPPORTED, format( - "Views are not enabled. You can enable views by setting '%s' to true. Notice additional cost may occur.", - config.viewEnabledParamName)); - } - // get it from the view - String querySql = bigQueryClient.createSql(table.getTableId(), requiredColumns, filters); - log.debug("querySql is %s", querySql); - try { - return destinationTableCache.get(querySql, new DestinationTableBuilder(bigQueryClient, config, querySql, table.getTableId())); - } catch (ExecutionException e) { - throw new BigQueryConnectorException(BIGQUERY_VIEW_DESTINATION_TABLE_CREATION_FAILED, "Error creating destination table", e); - } - } else { - // not regular table or a view - throw new BigQueryConnectorException(UNSUPPORTED, format("Table type '%s' of table '%s.%s' is not supported", - tableType, table.getTableId().getDataset(), table.getTableId().getTable())); - } + TableInfo createTableFromQuery() { + TableId destinationTable = bigQueryClient.createDestinationTable(table); + log.debug("destinationTable is %s", destinationTable); + JobInfo jobInfo = + JobInfo.of( + QueryJobConfiguration.newBuilder(querySql) + .setDestinationTable(destinationTable) + .build()); + log.debug("running query %s", jobInfo); + Job job = waitForJob(bigQueryClient.create(jobInfo)); + log.debug("job has finished. 
%s", job); + if (job.getStatus().getError() != null) { + throw convertToBigQueryException(job.getStatus().getError()); + } + // add expiration time to the table + TableInfo createdTable = bigQueryClient.getTable(destinationTable); + long expirationTime = + createdTable.getCreationTime() + + TimeUnit.HOURS.toMillis(config.viewExpirationTimeInHours); + Table updatedTable = + bigQueryClient.update(createdTable.toBuilder().setExpirationTime(expirationTime).build()); + return updatedTable; } - static class DestinationTableBuilder - implements Callable { - final BigQueryClient bigQueryClient; - final ReadSessionCreatorConfig config; - final String querySql; - final TableId table; - - DestinationTableBuilder(BigQueryClient bigQueryClient, ReadSessionCreatorConfig config, String querySql, TableId table) { - this.bigQueryClient = bigQueryClient; - this.config = config; - this.querySql = querySql; - this.table = table; - } - - @Override - public TableInfo call() { - return createTableFromQuery(); - } - - TableInfo createTableFromQuery() { - TableId destinationTable = bigQueryClient.createDestinationTable(table); - log.debug("destinationTable is %s", destinationTable); - JobInfo jobInfo = JobInfo.of( - QueryJobConfiguration - .newBuilder(querySql) - .setDestinationTable(destinationTable) - .build()); - log.debug("running query %s", jobInfo); - Job job = waitForJob(bigQueryClient.create(jobInfo)); - log.debug("job has finished. %s", job); - if (job.getStatus().getError() != null) { - throw convertToBigQueryException(job.getStatus().getError()); - } - // add expiration time to the table - TableInfo createdTable = bigQueryClient.getTable(destinationTable); - long expirationTime = createdTable.getCreationTime() + - TimeUnit.HOURS.toMillis(config.viewExpirationTimeInHours); - Table updatedTable = bigQueryClient.update(createdTable.toBuilder() - .setExpirationTime(expirationTime) - .build()); - return updatedTable; - } - - Job waitForJob(Job job) { - try { - return job.waitFor(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new BigQueryException(BaseServiceException.UNKNOWN_CODE, format("Job %s has been interrupted", job.getJobId()), e); - } - } + Job waitForJob(Job job) { + try { + return job.waitFor(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new BigQueryException( + BaseServiceException.UNKNOWN_CODE, + format("Job %s has been interrupted", job.getJobId()), + e); + } } + } } diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionCreatorConfig.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionCreatorConfig.java index 5cdb348884..81440ba378 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionCreatorConfig.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionCreatorConfig.java @@ -21,70 +21,70 @@ import java.util.OptionalInt; public class ReadSessionCreatorConfig { - final boolean viewsEnabled; - final Optional materializationProject; - final Optional materializationDataset; - final String viewEnabledParamName; - final int viewExpirationTimeInHours; - final DataFormat readDataFormat; - final int maxReadRowsRetries; - final OptionalInt maxParallelism; - final int defaultParallelism; + final boolean viewsEnabled; + final Optional materializationProject; + final Optional materializationDataset; + final String viewEnabledParamName; + final int viewExpirationTimeInHours; + final DataFormat 
readDataFormat; + final int maxReadRowsRetries; + final OptionalInt maxParallelism; + final int defaultParallelism; - public ReadSessionCreatorConfig( - boolean viewsEnabled, - Optional materializationProject, - Optional materializationDataset, - int viewExpirationTimeInHours, - DataFormat readDataFormat, - int maxReadRowsRetries, - String viewEnabledParamName, - OptionalInt maxParallelism, - int defaultParallelism) { - this.viewsEnabled = viewsEnabled; - this.materializationProject = materializationProject; - this.materializationDataset = materializationDataset; - this.viewEnabledParamName = viewEnabledParamName; - this.viewExpirationTimeInHours = viewExpirationTimeInHours; - this.readDataFormat = readDataFormat; - this.maxReadRowsRetries = maxReadRowsRetries; - this.maxParallelism = maxParallelism; - this.defaultParallelism = defaultParallelism; - } + public ReadSessionCreatorConfig( + boolean viewsEnabled, + Optional materializationProject, + Optional materializationDataset, + int viewExpirationTimeInHours, + DataFormat readDataFormat, + int maxReadRowsRetries, + String viewEnabledParamName, + OptionalInt maxParallelism, + int defaultParallelism) { + this.viewsEnabled = viewsEnabled; + this.materializationProject = materializationProject; + this.materializationDataset = materializationDataset; + this.viewEnabledParamName = viewEnabledParamName; + this.viewExpirationTimeInHours = viewExpirationTimeInHours; + this.readDataFormat = readDataFormat; + this.maxReadRowsRetries = maxReadRowsRetries; + this.maxParallelism = maxParallelism; + this.defaultParallelism = defaultParallelism; + } - public boolean isViewsEnabled() { - return viewsEnabled; - } + public boolean isViewsEnabled() { + return viewsEnabled; + } - public Optional getMaterializationProject() { - return materializationProject; - } + public Optional getMaterializationProject() { + return materializationProject; + } - public Optional getMaterializationDataset() { - return materializationDataset; - } + public Optional getMaterializationDataset() { + return materializationDataset; + } - public String getViewEnabledParamName() { - return viewEnabledParamName; - } + public String getViewEnabledParamName() { + return viewEnabledParamName; + } - public int getViewExpirationTimeInHours() { - return viewExpirationTimeInHours; - } + public int getViewExpirationTimeInHours() { + return viewExpirationTimeInHours; + } - public DataFormat getReadDataFormat() { - return readDataFormat; - } + public DataFormat getReadDataFormat() { + return readDataFormat; + } - public int getMaxReadRowsRetries() { - return maxReadRowsRetries; - } + public int getMaxReadRowsRetries() { + return maxReadRowsRetries; + } - public OptionalInt getMaxParallelism() { - return maxParallelism; - } + public OptionalInt getMaxParallelism() { + return maxParallelism; + } - public int getDefaultParallelism() { - return defaultParallelism; - } + public int getDefaultParallelism() { + return defaultParallelism; + } } diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionResponse.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionResponse.java index ff1befe126..0bf956a372 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionResponse.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadSessionResponse.java @@ -20,19 +20,19 @@ public class ReadSessionResponse { - private final ReadSession readSession; - private final TableInfo readTableInfo; + 
private final ReadSession readSession; + private final TableInfo readTableInfo; - public ReadSessionResponse(ReadSession readSession, TableInfo readTableInfo) { - this.readSession = readSession; - this.readTableInfo = readTableInfo; - } + public ReadSessionResponse(ReadSession readSession, TableInfo readTableInfo) { + this.readSession = readSession; + this.readTableInfo = readTableInfo; + } - public ReadSession getReadSession() { - return readSession; - } + public ReadSession getReadSession() { + return readSession; + } - public TableInfo getReadTableInfo() { - return readTableInfo; - } + public TableInfo getReadTableInfo() { + return readTableInfo; + } } diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/UserAgentHeaderProvider.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/UserAgentHeaderProvider.java index e0d27484fc..9cbe208958 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/UserAgentHeaderProvider.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/UserAgentHeaderProvider.java @@ -23,14 +23,14 @@ public class UserAgentHeaderProvider implements HeaderProvider, Serializable { - private final String userAgent; + private final String userAgent; - public UserAgentHeaderProvider(String userAgent) { - this.userAgent = userAgent; - } + public UserAgentHeaderProvider(String userAgent) { + this.userAgent = userAgent; + } - @Override - public Map getHeaders() { - return ImmutableMap.of("user-agent", userAgent); - } + @Override + public Map getHeaders() { + return ImmutableMap.of("user-agent", userAgent); + } } diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/UserAgentProvider.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/UserAgentProvider.java index 85aab8d455..83bac2d65a 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/UserAgentProvider.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/UserAgentProvider.java @@ -18,6 +18,5 @@ @FunctionalInterface public interface UserAgentProvider { - String getUserAgent(); - + String getUserAgent(); } diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/VersionProvider.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/VersionProvider.java index 0c5cbaf26f..a07fec010c 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/VersionProvider.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/VersionProvider.java @@ -18,6 +18,5 @@ @FunctionalInterface public interface VersionProvider { - String getVersion(); - + String getVersion(); } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/ArrowBinaryIterator.java b/connector/src/main/java/com/google/cloud/spark/bigquery/ArrowBinaryIterator.java index 0d5f4f17a7..e118fc8035 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/ArrowBinaryIterator.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/ArrowBinaryIterator.java @@ -38,99 +38,104 @@ public class ArrowBinaryIterator implements Iterator { - private static long maxAllocation = Long.MAX_VALUE; - ArrowReaderIterator arrowReaderIterator; - Iterator currentIterator; - List columnsInOrder; - - public ArrowBinaryIterator(List columnsInOrder, ByteString schema, ByteString rowsInBytes) { - BufferAllocator allocator = (new RootAllocator(maxAllocation)).newChildAllocator("ArrowBinaryIterator", - 0, 
maxAllocation); - - SequenceInputStream bytesWithSchemaStream = new SequenceInputStream( - new ByteArrayInputStream(schema.toByteArray()), - new ByteArrayInputStream(rowsInBytes.toByteArray())); - - ArrowStreamReader arrowStreamReader = new ArrowStreamReader(bytesWithSchemaStream, allocator); - arrowReaderIterator = new ArrowReaderIterator(arrowStreamReader); - currentIterator = ImmutableList.of().iterator(); - this.columnsInOrder = columnsInOrder; + private static long maxAllocation = Long.MAX_VALUE; + ArrowReaderIterator arrowReaderIterator; + Iterator currentIterator; + List columnsInOrder; + + public ArrowBinaryIterator( + List columnsInOrder, ByteString schema, ByteString rowsInBytes) { + BufferAllocator allocator = + (new RootAllocator(maxAllocation)) + .newChildAllocator("ArrowBinaryIterator", 0, maxAllocation); + + SequenceInputStream bytesWithSchemaStream = + new SequenceInputStream( + new ByteArrayInputStream(schema.toByteArray()), + new ByteArrayInputStream(rowsInBytes.toByteArray())); + + ArrowStreamReader arrowStreamReader = new ArrowStreamReader(bytesWithSchemaStream, allocator); + arrowReaderIterator = new ArrowReaderIterator(arrowStreamReader); + currentIterator = ImmutableList.of().iterator(); + this.columnsInOrder = columnsInOrder; + } + + @Override + public boolean hasNext() { + while (!currentIterator.hasNext()) { + if (!arrowReaderIterator.hasNext()) { + return false; + } + currentIterator = toArrowRows(arrowReaderIterator.next(), columnsInOrder); } - @Override - public boolean hasNext() { - while (!currentIterator.hasNext()) { - if (!arrowReaderIterator.hasNext()) { - return false; - } - currentIterator = toArrowRows(arrowReaderIterator.next(), columnsInOrder); - } - - return currentIterator.hasNext(); - } - - @Override - public InternalRow next() { - return currentIterator.next(); - } - - private Iterator toArrowRows(VectorSchemaRoot root, List namesInOrder) { - ColumnVector[] columns = namesInOrder.stream() - .map(name -> root.getVector(name)) - .map(vector -> new ArrowSchemaConverter(vector)) - .collect(Collectors.toList()).toArray(new ColumnVector[0]); - - ColumnarBatch batch = new ColumnarBatch(columns); - batch.setNumRows(root.getRowCount()); - return batch.rowIterator(); - } + return currentIterator.hasNext(); + } + + @Override + public InternalRow next() { + return currentIterator.next(); + } + + private Iterator toArrowRows(VectorSchemaRoot root, List namesInOrder) { + ColumnVector[] columns = + namesInOrder.stream() + .map(name -> root.getVector(name)) + .map(vector -> new ArrowSchemaConverter(vector)) + .collect(Collectors.toList()) + .toArray(new ColumnVector[0]); + + ColumnarBatch batch = new ColumnarBatch(columns); + batch.setNumRows(root.getRowCount()); + return batch.rowIterator(); + } } class ArrowReaderIterator implements Iterator { - private static final Logger log = LoggerFactory.getLogger(AvroBinaryIterator.class); - boolean closed = false; - VectorSchemaRoot current = null; - ArrowReader reader; + private static final Logger log = LoggerFactory.getLogger(AvroBinaryIterator.class); + boolean closed = false; + VectorSchemaRoot current = null; + ArrowReader reader; - public ArrowReaderIterator(ArrowReader reader) { - this.reader = reader; - } + public ArrowReaderIterator(ArrowReader reader) { + this.reader = reader; + } - @Override - public boolean hasNext() { - if (current != null) { - return true; - } - - if (closed) { - return false; - } - - try { - boolean res = reader.loadNextBatch(); - if (res) { - current = reader.getVectorSchemaRoot(); - } 
else { - ensureClosed(); - } - return res; - } catch (IOException e) { - throw new UncheckedIOException("Failed to load the next arrow batch", e); - } + @Override + public boolean hasNext() { + if (current != null) { + return true; } - @Override - public VectorSchemaRoot next() { - VectorSchemaRoot res = current; - current = null; - return res; + if (closed) { + return false; } - private void ensureClosed() throws IOException { - if (!closed) { - reader.close(); - closed = true; - } + try { + boolean res = reader.loadNextBatch(); + if (res) { + current = reader.getVectorSchemaRoot(); + } else { + ensureClosed(); + } + return res; + } catch (IOException e) { + throw new UncheckedIOException("Failed to load the next arrow batch", e); + } + } + + @Override + public VectorSchemaRoot next() { + VectorSchemaRoot res = current; + current = null; + return res; + } + + private void ensureClosed() throws IOException { + if (!closed) { + reader.close(); + closed = true; } + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/AvroBinaryIterator.java b/connector/src/main/java/com/google/cloud/spark/bigquery/AvroBinaryIterator.java index e434014f22..5ae447941f 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/AvroBinaryIterator.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/AvroBinaryIterator.java @@ -32,46 +32,47 @@ public class AvroBinaryIterator implements Iterator { - private static final Logger log = LoggerFactory.getLogger(AvroBinaryIterator.class); - GenericDatumReader reader; - List columnsInOrder; - BinaryDecoder in; - Schema bqSchema; + private static final Logger log = LoggerFactory.getLogger(AvroBinaryIterator.class); + GenericDatumReader reader; + List columnsInOrder; + BinaryDecoder in; + Schema bqSchema; - /** - * An iterator for scanning over rows serialized in Avro format - * - * @param bqSchema Schema of underlying BigQuery source - * @param columnsInOrder Sequence of columns in the schema - * @param schema Schema in avro format - * @param rowsInBytes Rows serialized in binary format for Avro - */ - public AvroBinaryIterator(Schema bqSchema, - List columnsInOrder, - org.apache.avro.Schema schema, - ByteString rowsInBytes) { - reader = new GenericDatumReader(schema); - this.bqSchema = bqSchema; - this.columnsInOrder = columnsInOrder; - in = new DecoderFactory().binaryDecoder(rowsInBytes.toByteArray(), null); - } + /** + * An iterator for scanning over rows serialized in Avro format + * + * @param bqSchema Schema of underlying BigQuery source + * @param columnsInOrder Sequence of columns in the schema + * @param schema Schema in avro format + * @param rowsInBytes Rows serialized in binary format for Avro + */ + public AvroBinaryIterator( + Schema bqSchema, + List columnsInOrder, + org.apache.avro.Schema schema, + ByteString rowsInBytes) { + reader = new GenericDatumReader(schema); + this.bqSchema = bqSchema; + this.columnsInOrder = columnsInOrder; + in = new DecoderFactory().binaryDecoder(rowsInBytes.toByteArray(), null); + } - @Override - public boolean hasNext() { - try { - return !in.isEnd(); - } catch (IOException e) { - throw new UncheckedIOException(e); - } + @Override + public boolean hasNext() { + try { + return !in.isEnd(); + } catch (IOException e) { + throw new UncheckedIOException(e); } + } - @Override - public InternalRow next() { - try { - return SchemaConverters.convertToInternalRow(bqSchema, - columnsInOrder, (GenericRecord) reader.read(null, in)); - } catch (IOException e) { - throw new UncheckedIOException(e); 
- } + @Override + public InternalRow next() { + try { + return SchemaConverters.convertToInternalRow( + bqSchema, columnsInOrder, (GenericRecord) reader.read(null, in)); + } catch (IOException e) { + throw new UncheckedIOException(e); } + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/ReadRowsResponseToInternalRowIteratorConverter.java b/connector/src/main/java/com/google/cloud/spark/bigquery/ReadRowsResponseToInternalRowIteratorConverter.java index d319c7830d..20e62106eb 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/ReadRowsResponseToInternalRowIteratorConverter.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/ReadRowsResponseToInternalRowIteratorConverter.java @@ -26,59 +26,56 @@ public interface ReadRowsResponseToInternalRowIteratorConverter { - static ReadRowsResponseToInternalRowIteratorConverter avro( - final com.google.cloud.bigquery.Schema bqSchema, - final List columnsInOrder, - final String rawAvroSchema) { - return new Avro(bqSchema, columnsInOrder, rawAvroSchema); - } + static ReadRowsResponseToInternalRowIteratorConverter avro( + final com.google.cloud.bigquery.Schema bqSchema, + final List columnsInOrder, + final String rawAvroSchema) { + return new Avro(bqSchema, columnsInOrder, rawAvroSchema); + } - static ReadRowsResponseToInternalRowIteratorConverter arrow( - final List columnsInOrder, - final ByteString arrowSchema) { - return new Arrow(columnsInOrder, arrowSchema); - } + static ReadRowsResponseToInternalRowIteratorConverter arrow( + final List columnsInOrder, final ByteString arrowSchema) { + return new Arrow(columnsInOrder, arrowSchema); + } - Iterator convert(ReadRowsResponse response); + Iterator convert(ReadRowsResponse response); - class Avro implements ReadRowsResponseToInternalRowIteratorConverter, Serializable { + class Avro implements ReadRowsResponseToInternalRowIteratorConverter, Serializable { - private final com.google.cloud.bigquery.Schema bqSchema; - private final List columnsInOrder; - private final String rawAvroSchema; + private final com.google.cloud.bigquery.Schema bqSchema; + private final List columnsInOrder; + private final String rawAvroSchema; - public Avro(Schema bqSchema, List columnsInOrder, String rawAvroSchema) { - this.bqSchema = bqSchema; - this.columnsInOrder = columnsInOrder; - this.rawAvroSchema = rawAvroSchema; - } + public Avro(Schema bqSchema, List columnsInOrder, String rawAvroSchema) { + this.bqSchema = bqSchema; + this.columnsInOrder = columnsInOrder; + this.rawAvroSchema = rawAvroSchema; + } - @Override - public Iterator convert(ReadRowsResponse response) { - return new AvroBinaryIterator( - bqSchema, - columnsInOrder, - new org.apache.avro.Schema.Parser().parse(rawAvroSchema), - response.getAvroRows().getSerializedBinaryRows()); - } + @Override + public Iterator convert(ReadRowsResponse response) { + return new AvroBinaryIterator( + bqSchema, + columnsInOrder, + new org.apache.avro.Schema.Parser().parse(rawAvroSchema), + response.getAvroRows().getSerializedBinaryRows()); } + } - class Arrow implements ReadRowsResponseToInternalRowIteratorConverter, Serializable { + class Arrow implements ReadRowsResponseToInternalRowIteratorConverter, Serializable { - private final List columnsInOrder; - private final ByteString arrowSchema; + private final List columnsInOrder; + private final ByteString arrowSchema; - public Arrow(List columnsInOrder, ByteString arrowSchema) { - this.columnsInOrder = columnsInOrder; - this.arrowSchema = arrowSchema; - } + public Arrow(List 
columnsInOrder, ByteString arrowSchema) { + this.columnsInOrder = columnsInOrder; + this.arrowSchema = arrowSchema; + } - @Override - public Iterator convert(ReadRowsResponse response) { - return new ArrowBinaryIterator( - columnsInOrder, - arrowSchema, - response.getArrowRecordBatch().getSerializedRecordBatch()); - } + @Override + public Iterator convert(ReadRowsResponse response) { + return new ArrowBinaryIterator( + columnsInOrder, arrowSchema, response.getArrowRecordBatch().getSerializedRecordBatch()); } + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java b/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java index 19f9e0648b..4aa646aa5d 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java @@ -36,170 +36,172 @@ import java.util.stream.Collectors; public class SchemaConverters { - // Numeric is a fixed precision Decimal Type with 38 digits of precision and 9 digits of scale. - // See https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#numeric-type - private final static int BQ_NUMERIC_PRECISION = 38; - private final static int BQ_NUMERIC_SCALE = 9; - private final static DecimalType NUMERIC_SPARK_TYPE = DataTypes.createDecimalType( - BQ_NUMERIC_PRECISION, BQ_NUMERIC_SCALE); - - /** - * Convert a BigQuery schema to a Spark schema - */ - public static StructType toSpark(Schema schema) { - List fieldList = schema.getFields().stream() - .map(SchemaConverters::convert).collect(Collectors.toList()); - StructType structType = new StructType(fieldList.toArray(new StructField[0])); - - return structType; + // Numeric is a fixed precision Decimal Type with 38 digits of precision and 9 digits of scale. + // See https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#numeric-type + private static final int BQ_NUMERIC_PRECISION = 38; + private static final int BQ_NUMERIC_SCALE = 9; + private static final DecimalType NUMERIC_SPARK_TYPE = + DataTypes.createDecimalType(BQ_NUMERIC_PRECISION, BQ_NUMERIC_SCALE); + + /** Convert a BigQuery schema to a Spark schema */ + public static StructType toSpark(Schema schema) { + List fieldList = + schema.getFields().stream().map(SchemaConverters::convert).collect(Collectors.toList()); + StructType structType = new StructType(fieldList.toArray(new StructField[0])); + + return structType; + } + + public static InternalRow convertToInternalRow( + Schema schema, List namesInOrder, GenericRecord record) { + return convertAll(schema.getFields(), record, namesInOrder); + } + + static Object convert(Field field, Object value) { + if (value == null) { + return null; } - public static InternalRow convertToInternalRow(Schema schema, List namesInOrder, GenericRecord record) { - return convertAll(schema.getFields(), record, namesInOrder); + if (field.getMode() == Field.Mode.REPEATED) { + // rather than recurring down we strip off the repeated mode + // Due to serialization issues, reconstruct the type using reflection: + // See: https://github.com/googleapis/google-cloud-java/issues/3942 + LegacySQLTypeName fType = LegacySQLTypeName.valueOfStrict(field.getType().name()); + Field nestedField = + Field.newBuilder(field.getName(), fType, field.getSubFields()) + // As long as this is not repeated it works, but technically arrays cannot contain + // nulls, so select required instead of nullable. 
+ .setMode(Field.Mode.REQUIRED) + .build(); + + List valueList = (List) value; + + return new GenericArrayData( + valueList.stream().map(v -> convert(nestedField, v)).collect(Collectors.toList())); } - static Object convert(Field field, Object value) { - if (value == null) { - return null; - } - - if (field.getMode() == Field.Mode.REPEATED) { - // rather than recurring down we strip off the repeated mode - // Due to serialization issues, reconstruct the type using reflection: - // See: https://github.com/googleapis/google-cloud-java/issues/3942 - LegacySQLTypeName fType = LegacySQLTypeName.valueOfStrict(field.getType().name()); - Field nestedField = Field.newBuilder(field.getName(), fType, field.getSubFields()) - // As long as this is not repeated it works, but technically arrays cannot contain - // nulls, so select required instead of nullable. - .setMode(Field.Mode.REQUIRED) - .build(); - - List valueList = (List) value; - - return new GenericArrayData(valueList.stream().map(v -> convert(nestedField, v)).collect(Collectors.toList())); - } - - if (LegacySQLTypeName.INTEGER.equals(field.getType()) || - LegacySQLTypeName.FLOAT.equals(field.getType()) || - LegacySQLTypeName.BOOLEAN.equals(field.getType()) || - LegacySQLTypeName.DATE.equals(field.getType()) || - LegacySQLTypeName.TIME.equals(field.getType()) || - LegacySQLTypeName.TIMESTAMP.equals(field.getType())) { - return value; - } - - if (LegacySQLTypeName.STRING.equals(field.getType()) || - LegacySQLTypeName.DATETIME.equals(field.getType()) || - LegacySQLTypeName.GEOGRAPHY.equals(field.getType())) { - return UTF8String.fromBytes(((Utf8) value).getBytes()); - } - - if (LegacySQLTypeName.BYTES.equals(field.getType())) { - return getBytes((ByteBuffer) value); - } - - if (LegacySQLTypeName.NUMERIC.equals(field.getType())) { - byte[] bytes = getBytes((ByteBuffer) value); - BigDecimal b = new BigDecimal(new BigInteger(bytes), BQ_NUMERIC_SCALE); - Decimal d = Decimal.apply(b, BQ_NUMERIC_PRECISION, BQ_NUMERIC_SCALE); - - return d; - } - - if (LegacySQLTypeName.RECORD.equals(field.getType())) { - return convertAll(field.getSubFields(), - (GenericRecord) value, - field.getSubFields().stream().map(f -> f.getName()).collect(Collectors.toList())); - } - - throw new IllegalStateException("Unexpected type: " + field.getType()); + if (LegacySQLTypeName.INTEGER.equals(field.getType()) + || LegacySQLTypeName.FLOAT.equals(field.getType()) + || LegacySQLTypeName.BOOLEAN.equals(field.getType()) + || LegacySQLTypeName.DATE.equals(field.getType()) + || LegacySQLTypeName.TIME.equals(field.getType()) + || LegacySQLTypeName.TIMESTAMP.equals(field.getType())) { + return value; } - private static byte[] getBytes(ByteBuffer buf) { - byte[] bytes = new byte[buf.remaining()]; - buf.get(bytes); + if (LegacySQLTypeName.STRING.equals(field.getType()) + || LegacySQLTypeName.DATETIME.equals(field.getType()) + || LegacySQLTypeName.GEOGRAPHY.equals(field.getType())) { + return UTF8String.fromBytes(((Utf8) value).getBytes()); + } + + if (LegacySQLTypeName.BYTES.equals(field.getType())) { + return getBytes((ByteBuffer) value); + } + + if (LegacySQLTypeName.NUMERIC.equals(field.getType())) { + byte[] bytes = getBytes((ByteBuffer) value); + BigDecimal b = new BigDecimal(new BigInteger(bytes), BQ_NUMERIC_SCALE); + Decimal d = Decimal.apply(b, BQ_NUMERIC_PRECISION, BQ_NUMERIC_SCALE); + + return d; + } - return bytes; + if (LegacySQLTypeName.RECORD.equals(field.getType())) { + return convertAll( + field.getSubFields(), + (GenericRecord) value, + 
field.getSubFields().stream().map(f -> f.getName()).collect(Collectors.toList())); } - // Schema is not recursive so add helper for sequence of fields - static GenericInternalRow convertAll(FieldList fieldList, - GenericRecord record, - List namesInOrder) { + throw new IllegalStateException("Unexpected type: " + field.getType()); + } + + private static byte[] getBytes(ByteBuffer buf) { + byte[] bytes = new byte[buf.remaining()]; + buf.get(bytes); + + return bytes; + } - Map fieldMap = new HashMap<>(); + // Schema is not recursive so add helper for sequence of fields + static GenericInternalRow convertAll( + FieldList fieldList, GenericRecord record, List namesInOrder) { - fieldList.stream().forEach(field -> - fieldMap.put(field.getName(), convert(field, record.get(field.getName())))); + Map fieldMap = new HashMap<>(); - Object[] values = new Object[namesInOrder.size()]; - for (int i = 0; i < namesInOrder.size(); i++) { - values[i] = fieldMap.get(namesInOrder.get(i)); - } + fieldList.stream() + .forEach( + field -> fieldMap.put(field.getName(), convert(field, record.get(field.getName())))); + + Object[] values = new Object[namesInOrder.size()]; + for (int i = 0; i < namesInOrder.size(); i++) { + values[i] = fieldMap.get(namesInOrder.get(i)); + } - return new GenericInternalRow(values); + return new GenericInternalRow(values); + } + + /** + * Create a function that converts an Avro row with the given BigQuery schema to a Spark SQL row + * + *
<p>
The conversion is based on the BigQuery schema, not Avro Schema, because the Avro schema is + * very painful to use. + * + *
<p>
Not guaranteed to be stable across all versions of Spark. + */ + private static StructField convert(Field field) { + DataType dataType = getDataType(field); + boolean nullable = true; + + if (field.getMode() == Field.Mode.REQUIRED) { + nullable = false; + } else if (field.getMode() == Field.Mode.REPEATED) { + dataType = new ArrayType(dataType, true); } - /** - * Create a function that converts an Avro row with the given BigQuery schema to a Spark SQL row - *
<p>
- * The conversion is based on the BigQuery schema, not Avro Schema, because the Avro schema is - * very painful to use. - *
<p>
- * Not guaranteed to be stable across all versions of Spark. - */ - - private static StructField convert(Field field) { - DataType dataType = getDataType(field); - boolean nullable = true; - - if (field.getMode() == Field.Mode.REQUIRED) { - nullable = false; - } else if (field.getMode() == Field.Mode.REPEATED) { - dataType = new ArrayType(dataType, true); - } - - MetadataBuilder metadata = new MetadataBuilder(); - if (field.getDescription() != null) { - metadata.putString("description", field.getDescription()); - } - - return new StructField(field.getName(), dataType, nullable, metadata.build()); + MetadataBuilder metadata = new MetadataBuilder(); + if (field.getDescription() != null) { + metadata.putString("description", field.getDescription()); } - private static DataType getDataType(Field field) { - - if (LegacySQLTypeName.INTEGER.equals(field.getType())) { - return DataTypes.LongType; - } else if (LegacySQLTypeName.FLOAT.equals(field.getType())) { - return DataTypes.DoubleType; - } else if (LegacySQLTypeName.NUMERIC.equals(field.getType())) { - return NUMERIC_SPARK_TYPE; - } else if (LegacySQLTypeName.STRING.equals(field.getType())) { - return DataTypes.StringType; - } else if (LegacySQLTypeName.BOOLEAN.equals(field.getType())) { - return DataTypes.BooleanType; - } else if (LegacySQLTypeName.BYTES.equals(field.getType())) { - return DataTypes.BinaryType; - } else if (LegacySQLTypeName.DATE.equals(field.getType())) { - return DataTypes.DateType; - } else if (LegacySQLTypeName.TIMESTAMP.equals(field.getType())) { - return DataTypes.TimestampType; - } else if (LegacySQLTypeName.TIME.equals(field.getType())) { - return DataTypes.LongType; - // TODO(#5): add a timezone to allow parsing to timestamp - // This can be safely cast to TimestampType, but doing so causes the date to be inferred - // as the current date. It's safer to leave as a stable string and give the user the - // option of casting themselves. 
- } else if (LegacySQLTypeName.DATETIME.equals(field.getType())) { - return DataTypes.StringType; - } else if (LegacySQLTypeName.RECORD.equals(field.getType())) { - List structFields = field.getSubFields().stream().map(SchemaConverters::convert).collect(Collectors.toList()); - return new StructType(structFields.toArray(new StructField[0])); - } else if (LegacySQLTypeName.GEOGRAPHY.equals(field.getType())) { - return DataTypes.StringType; - } else { - throw new IllegalStateException("Unexpected type: " + field.getType()); - } + return new StructField(field.getName(), dataType, nullable, metadata.build()); + } + + private static DataType getDataType(Field field) { + + if (LegacySQLTypeName.INTEGER.equals(field.getType())) { + return DataTypes.LongType; + } else if (LegacySQLTypeName.FLOAT.equals(field.getType())) { + return DataTypes.DoubleType; + } else if (LegacySQLTypeName.NUMERIC.equals(field.getType())) { + return NUMERIC_SPARK_TYPE; + } else if (LegacySQLTypeName.STRING.equals(field.getType())) { + return DataTypes.StringType; + } else if (LegacySQLTypeName.BOOLEAN.equals(field.getType())) { + return DataTypes.BooleanType; + } else if (LegacySQLTypeName.BYTES.equals(field.getType())) { + return DataTypes.BinaryType; + } else if (LegacySQLTypeName.DATE.equals(field.getType())) { + return DataTypes.DateType; + } else if (LegacySQLTypeName.TIMESTAMP.equals(field.getType())) { + return DataTypes.TimestampType; + } else if (LegacySQLTypeName.TIME.equals(field.getType())) { + return DataTypes.LongType; + // TODO(#5): add a timezone to allow parsing to timestamp + // This can be safely cast to TimestampType, but doing so causes the date to be inferred + // as the current date. It's safer to leave as a stable string and give the user the + // option of casting themselves. 
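The comment above explains why DATETIME values are surfaced as plain strings rather than inferred timestamps, leaving the cast to the user. A minimal sketch of that user-side cast, assuming a hypothetical table with a DATETIME column named event_time (the table and column names are placeholders, not part of this patch):

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.to_timestamp;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class DatetimeCastExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("datetime-cast").getOrCreate();

    // DATETIME arrives as StringType per getDataType above; cast explicitly when needed.
    Dataset<Row> df =
        spark.read().format("bigquery").option("table", "my_dataset.my_table").load();
    Dataset<Row> withTimestamp = df.withColumn("event_ts", to_timestamp(col("event_time")));
    withTimestamp.printSchema();
  }
}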
+ } else if (LegacySQLTypeName.DATETIME.equals(field.getType())) { + return DataTypes.StringType; + } else if (LegacySQLTypeName.RECORD.equals(field.getType())) { + List structFields = + field.getSubFields().stream().map(SchemaConverters::convert).collect(Collectors.toList()); + return new StructType(structFields.toArray(new StructField[0])); + } else if (LegacySQLTypeName.GEOGRAPHY.equals(field.getType())) { + return DataTypes.StringType; + } else { + throw new IllegalStateException("Unexpected type: " + field.getType()); } + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConfig.java b/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConfig.java index 41318bb06e..b16d76e036 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConfig.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConfig.java @@ -41,337 +41,344 @@ import static java.lang.String.format; import static java.util.stream.Collectors.joining; - public class SparkBigQueryConfig implements BigQueryConfig { - public static final String VIEWS_ENABLED_OPTION = "viewsEnabled"; - @VisibleForTesting - static final DataFormat DEFAULT_READ_DATA_FORMAT = DataFormat.AVRO; - @VisibleForTesting - static final FormatOptions DEFAULT_INTERMEDIATE_FORMAT = FormatOptions.parquet(); - private static final String GCS_CONFIG_CREDENTIALS_FILE_PROPERTY = "google.cloud.auth.service.account.json.keyfile"; - private static final String GCS_CONFIG_PROJECT_ID_PROPERTY = "fs.gs.project.id"; - private static final String INTERMEDIATE_FORMAT_OPTION = "intermediateFormat"; - private static final String READ_DATA_FORMAT_OPTION = "readDataFormat"; - private static final ImmutableList PERMITTED_READ_DATA_FORMATS = ImmutableList.of( - DataFormat.ARROW.toString(), DataFormat.AVRO.toString()); - private static final ImmutableList PERMITTED_INTERMEDIATE_FORMATS = ImmutableList.of( - FormatOptions.orc(), FormatOptions.parquet()); - - private static final Supplier> DEFAULT_FALLBACK = () -> Optional.empty(); - - TableId tableId; - Optional parentProjectId; - Optional credentialsKey; - Optional credentialsFile; - Optional accessToken; - Optional filter = Optional.empty(); - Optional schema = Optional.empty(); - OptionalInt maxParallelism = OptionalInt.empty(); - int defaultParallelism = 1; - Optional temporaryGcsBucket = Optional.empty(); - FormatOptions intermediateFormat = DEFAULT_INTERMEDIATE_FORMAT; - DataFormat readDataFormat = DEFAULT_READ_DATA_FORMAT; - boolean combinePushedDownFilters = true; - boolean viewsEnabled = false; - Optional materializationProject = Optional.empty(); - Optional materializationDataset = Optional.empty(); - Optional partitionField = Optional.empty(); - OptionalLong partitionExpirationMs = OptionalLong.empty(); - Optional partitionRequireFilter = Optional.empty(); - Optional partitionType = Optional.empty(); - Optional clusteredFields = Optional.empty(); - Optional createDisposition = Optional.empty(); - boolean optimizedEmptyProjection = true; - ImmutableList loadSchemaUpdateOptions = ImmutableList.of(); - int viewExpirationTimeInHours = 24; - int maxReadRowsRetries = 3; - - private SparkBigQueryConfig() { - // empty - } - - public static SparkBigQueryConfig from( - DataSourceOptions options, - ImmutableMap globalOptions, - Configuration hadoopConfiguration, - int defaultParallelism) { - SparkBigQueryConfig config = new SparkBigQueryConfig(); - - String tableParam = getRequiredOption(options, "table"); - Optional 
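To make the SchemaConverters mapping above concrete (INTEGER to LongType, NUMERIC to a 38/9 decimal, REPEATED fields to arrays, and so on), here is a small sketch that builds a BigQuery schema and converts it; the field names are invented for illustration and are not part of this patch:

import com.google.cloud.bigquery.Field;
import com.google.cloud.bigquery.LegacySQLTypeName;
import com.google.cloud.bigquery.Schema;
import com.google.cloud.spark.bigquery.SchemaConverters;
import org.apache.spark.sql.types.StructType;

public class SchemaMappingExample {
  public static void main(String[] args) {
    Schema bigQuerySchema =
        Schema.of(
            Field.of("word", LegacySQLTypeName.STRING),
            Field.of("word_count", LegacySQLTypeName.INTEGER),
            Field.of("score", LegacySQLTypeName.NUMERIC),
            Field.newBuilder("tags", LegacySQLTypeName.STRING)
                .setMode(Field.Mode.REPEATED)
                .build());

    // Expected per getDataType above: word -> StringType, word_count -> LongType,
    // score -> DecimalType(38, 9), tags -> ArrayType(StringType, containsNull = true).
    StructType sparkSchema = SchemaConverters.toSpark(bigQuerySchema);
    System.out.println(sparkSchema.treeString());
  }
}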
datasetParam = getOption(options, "dataset"); - Optional projectParam = firstPresent(getOption(options, "project"), - Optional.ofNullable(hadoopConfiguration.get(GCS_CONFIG_PROJECT_ID_PROPERTY))); - config.tableId = parseTableId(tableParam, datasetParam, projectParam); - config.parentProjectId = getAnyOption(globalOptions, options, "parentProject"); - config.credentialsKey = getAnyOption(globalOptions, options, "credentials"); - config.credentialsFile = firstPresent(getAnyOption(globalOptions, options, "credentialsFile"), - Optional.ofNullable(hadoopConfiguration.get(GCS_CONFIG_CREDENTIALS_FILE_PROPERTY))); - config.accessToken = getAnyOption(globalOptions, options, "gcpAccessToken"); - config.filter = getOption(options, "filter"); - config.maxParallelism = toOptionalInt(getOptionFromMultipleParams( - options, ImmutableList.of("maxParallelism", "parallelism"), DEFAULT_FALLBACK) + public static final String VIEWS_ENABLED_OPTION = "viewsEnabled"; + @VisibleForTesting static final DataFormat DEFAULT_READ_DATA_FORMAT = DataFormat.AVRO; + + @VisibleForTesting + static final FormatOptions DEFAULT_INTERMEDIATE_FORMAT = FormatOptions.parquet(); + + private static final String GCS_CONFIG_CREDENTIALS_FILE_PROPERTY = + "google.cloud.auth.service.account.json.keyfile"; + private static final String GCS_CONFIG_PROJECT_ID_PROPERTY = "fs.gs.project.id"; + private static final String INTERMEDIATE_FORMAT_OPTION = "intermediateFormat"; + private static final String READ_DATA_FORMAT_OPTION = "readDataFormat"; + private static final ImmutableList PERMITTED_READ_DATA_FORMATS = + ImmutableList.of(DataFormat.ARROW.toString(), DataFormat.AVRO.toString()); + private static final ImmutableList PERMITTED_INTERMEDIATE_FORMATS = + ImmutableList.of(FormatOptions.orc(), FormatOptions.parquet()); + + private static final Supplier> DEFAULT_FALLBACK = () -> Optional.empty(); + + TableId tableId; + Optional parentProjectId; + Optional credentialsKey; + Optional credentialsFile; + Optional accessToken; + Optional filter = Optional.empty(); + Optional schema = Optional.empty(); + OptionalInt maxParallelism = OptionalInt.empty(); + int defaultParallelism = 1; + Optional temporaryGcsBucket = Optional.empty(); + FormatOptions intermediateFormat = DEFAULT_INTERMEDIATE_FORMAT; + DataFormat readDataFormat = DEFAULT_READ_DATA_FORMAT; + boolean combinePushedDownFilters = true; + boolean viewsEnabled = false; + Optional materializationProject = Optional.empty(); + Optional materializationDataset = Optional.empty(); + Optional partitionField = Optional.empty(); + OptionalLong partitionExpirationMs = OptionalLong.empty(); + Optional partitionRequireFilter = Optional.empty(); + Optional partitionType = Optional.empty(); + Optional clusteredFields = Optional.empty(); + Optional createDisposition = Optional.empty(); + boolean optimizedEmptyProjection = true; + ImmutableList loadSchemaUpdateOptions = ImmutableList.of(); + int viewExpirationTimeInHours = 24; + int maxReadRowsRetries = 3; + + private SparkBigQueryConfig() { + // empty + } + + public static SparkBigQueryConfig from( + DataSourceOptions options, + ImmutableMap globalOptions, + Configuration hadoopConfiguration, + int defaultParallelism) { + SparkBigQueryConfig config = new SparkBigQueryConfig(); + + String tableParam = getRequiredOption(options, "table"); + Optional datasetParam = getOption(options, "dataset"); + Optional projectParam = + firstPresent( + getOption(options, "project"), + Optional.ofNullable(hadoopConfiguration.get(GCS_CONFIG_PROJECT_ID_PROPERTY))); + 
config.tableId = parseTableId(tableParam, datasetParam, projectParam); + config.parentProjectId = getAnyOption(globalOptions, options, "parentProject"); + config.credentialsKey = getAnyOption(globalOptions, options, "credentials"); + config.credentialsFile = + firstPresent( + getAnyOption(globalOptions, options, "credentialsFile"), + Optional.ofNullable(hadoopConfiguration.get(GCS_CONFIG_CREDENTIALS_FILE_PROPERTY))); + config.accessToken = getAnyOption(globalOptions, options, "gcpAccessToken"); + config.filter = getOption(options, "filter"); + config.maxParallelism = + toOptionalInt( + getOptionFromMultipleParams( + options, ImmutableList.of("maxParallelism", "parallelism"), DEFAULT_FALLBACK) .map(Integer::valueOf)); - config.defaultParallelism = defaultParallelism; - config.temporaryGcsBucket = getAnyOption(globalOptions, options, "temporaryGcsBucket"); - config.intermediateFormat = getAnyOption(globalOptions, options, INTERMEDIATE_FORMAT_OPTION) - .map(String::toUpperCase) - .map(FormatOptions::of) - .orElse(DEFAULT_INTERMEDIATE_FORMAT); - if (!PERMITTED_INTERMEDIATE_FORMATS.contains(config.intermediateFormat)) { - throw new IllegalArgumentException( - format("Intermediate format '%s' is not supported. Supported formats are %s", - config.intermediateFormat.getType(), PERMITTED_INTERMEDIATE_FORMATS.stream().map(FormatOptions::getType).collect(joining(",")) - )); - } - String readDataFormatParam = getAnyOption(globalOptions, options, READ_DATA_FORMAT_OPTION) - .map(String::toUpperCase) - .orElse(DEFAULT_READ_DATA_FORMAT.toString()); - if (!PERMITTED_READ_DATA_FORMATS.contains(readDataFormatParam)) { - throw new IllegalArgumentException( - format("Data read format '%s' is not supported. Supported formats are '%s'", readDataFormatParam, String.join(",", PERMITTED_READ_DATA_FORMATS)) - ); - } - config.readDataFormat = DataFormat.valueOf(readDataFormatParam); - config.combinePushedDownFilters = getAnyBooleanOption( - globalOptions, options, "combinePushedDownFilters", true); - config.viewsEnabled = getAnyBooleanOption( - globalOptions, options, VIEWS_ENABLED_OPTION, false); - config.materializationProject = - getAnyOption(globalOptions, options, - ImmutableList.of("materializationProject", "viewMaterializationProject")); - config.materializationDataset = - getAnyOption(globalOptions, options, - ImmutableList.of("materializationDataset", "viewMaterializationDataset")); - - config.partitionField = getOption(options, "partitionField"); - config.partitionExpirationMs = toOptionalLong(getOption(options, "partitionExpirationMs").map(Long::valueOf)); - config.partitionRequireFilter = getOption(options, "partitionRequireFilter").map(Boolean::valueOf); - config.partitionType = getOption(options, "partitionType"); - config.clusteredFields = getOption(options, "clusteredFields").map(s -> s.split(",")); - - config.createDisposition = getOption(options, "createDisposition") - .map(String::toUpperCase) - .map(JobInfo.CreateDisposition::valueOf); - - config.optimizedEmptyProjection = getAnyBooleanOption( - globalOptions, options, "optimizedEmptyProjection", true); - - boolean allowFieldAddition = getAnyBooleanOption( - globalOptions, options, "allowFieldAddition", false); - boolean allowFieldRelaxation = getAnyBooleanOption( - globalOptions, options, "allowFieldRelaxation", false); - ImmutableList.Builder loadSchemaUpdateOptions = ImmutableList.builder(); - if (allowFieldAddition) { - loadSchemaUpdateOptions.add(JobInfo.SchemaUpdateOption.ALLOW_FIELD_ADDITION); - } - if (allowFieldRelaxation) { - 
loadSchemaUpdateOptions.add(JobInfo.SchemaUpdateOption.ALLOW_FIELD_RELAXATION); - } - config.loadSchemaUpdateOptions = loadSchemaUpdateOptions.build(); - - return config; - } - - private static OptionalInt toOptionalInt(Optional o) { - return o.map(Stream::of).orElse(Stream.empty()).mapToInt(Integer::intValue).findFirst(); - } - - private static OptionalLong toOptionalLong(Optional o) { - return o.map(Stream::of).orElse(Stream.empty()).mapToLong(Long::longValue).findFirst(); - } - - private static Supplier defaultBilledProject() { - return () -> BigQueryOptions.getDefaultInstance().getProjectId(); - } - - private static String getRequiredOption( - DataSourceOptions options, - String name) { - return getOption(options, name, DEFAULT_FALLBACK) - .orElseThrow(() -> new IllegalArgumentException(format("Option %s required.", name))); - } - - private static String getRequiredOption( - DataSourceOptions options, - String name, - Supplier fallback) { - return getOption(options, name, DEFAULT_FALLBACK).orElseGet(fallback); - } - - private static Optional getOption( - DataSourceOptions options, - String name) { - return getOption(options, name, DEFAULT_FALLBACK); - } - - private static Optional getOption( - DataSourceOptions options, - String name, - Supplier> fallback) { - return firstPresent(options.get(name), fallback.get()); - } - - private static Optional getOptionFromMultipleParams( - DataSourceOptions options, - Collection names, - Supplier> fallback) { - return names.stream().map(name -> getOption(options, name)) - .filter(Optional::isPresent) - .findFirst() - .orElseGet(fallback); - } - - private static Optional getAnyOption( - ImmutableMap globalOptions, - DataSourceOptions options, - String name) { - return Optional.ofNullable(options.get(name).orElse(globalOptions.get(name))); - } - - // gives the option to support old configurations as fallback - // Used to provide backward compatibility - private static Optional getAnyOption( - ImmutableMap globalOptions, - DataSourceOptions options, - Collection names) { - return names.stream() - .map(name -> getAnyOption(globalOptions, options, name)) - .filter(optional -> optional.isPresent()) - .findFirst() - .orElse(Optional.empty()); - } - - private static boolean getAnyBooleanOption(ImmutableMap globalOptions, - DataSourceOptions options, - String name, - boolean defaultValue) { - return getAnyOption(globalOptions, options, name) - .map(Boolean::valueOf) - .orElse(defaultValue); - } - - public TableId getTableId() { - return tableId; - } - - @Override - public Optional getParentProjectId() { - return parentProjectId; - } - - @Override - public Optional getCredentialsKey() { - return credentialsKey; - } - - @Override - public Optional getCredentialsFile() { - return credentialsKey; - } - - @Override - public Optional getAccessToken() { - return accessToken; - - } - - public Optional getFilter() { - return filter; - } - - public Optional getSchema() { - return schema; - } - - public OptionalInt getMaxParallelism() { - return maxParallelism; - } - - public int getDefaultParallelism() { - return defaultParallelism; - } - - public Optional getTemporaryGcsBucket() { - return temporaryGcsBucket; - } - - public FormatOptions getIntermediateFormat() { - return intermediateFormat; - } - - public DataFormat getReadDataFormat() { - return readDataFormat; - } - - public boolean isCombinePushedDownFilters() { - return combinePushedDownFilters; - } - - public boolean isViewsEnabled() { - return viewsEnabled; - } - - @Override - public Optional 
getMaterializationProject() { - return materializationProject; - } - - @Override - public Optional getMaterializationDataset() { - return materializationDataset; - } - - public Optional getPartitionField() { - return partitionField; - } - - public OptionalLong getPartitionExpirationMs() { - return partitionExpirationMs; - } - - public Optional getPartitionRequireFilter() { - return partitionRequireFilter; - } - - public Optional getPartitionType() { - return partitionType; - } - - public Optional getClusteredFields() { - return clusteredFields; - } - - public Optional getCreateDisposition() { - return createDisposition; - } - - public boolean isOptimizedEmptyProjection() { - return optimizedEmptyProjection; - } - - public ImmutableList getLoadSchemaUpdateOptions() { - return loadSchemaUpdateOptions; - } - - public int getViewExpirationTimeInHours() { - return viewExpirationTimeInHours; - } - - public int getMaxReadRowsRetries() { - return maxReadRowsRetries; - } - - public ReadSessionCreatorConfig toReadSessionCreatorConfig() { - return new ReadSessionCreatorConfig( - viewsEnabled, - materializationProject, - materializationDataset, - viewExpirationTimeInHours, - readDataFormat, - maxReadRowsRetries, - VIEWS_ENABLED_OPTION, - maxParallelism, - defaultParallelism); - } + config.defaultParallelism = defaultParallelism; + config.temporaryGcsBucket = getAnyOption(globalOptions, options, "temporaryGcsBucket"); + config.intermediateFormat = + getAnyOption(globalOptions, options, INTERMEDIATE_FORMAT_OPTION) + .map(String::toUpperCase) + .map(FormatOptions::of) + .orElse(DEFAULT_INTERMEDIATE_FORMAT); + if (!PERMITTED_INTERMEDIATE_FORMATS.contains(config.intermediateFormat)) { + throw new IllegalArgumentException( + format( + "Intermediate format '%s' is not supported. Supported formats are %s", + config.intermediateFormat.getType(), + PERMITTED_INTERMEDIATE_FORMATS.stream() + .map(FormatOptions::getType) + .collect(joining(",")))); + } + String readDataFormatParam = + getAnyOption(globalOptions, options, READ_DATA_FORMAT_OPTION) + .map(String::toUpperCase) + .orElse(DEFAULT_READ_DATA_FORMAT.toString()); + if (!PERMITTED_READ_DATA_FORMATS.contains(readDataFormatParam)) { + throw new IllegalArgumentException( + format( + "Data read format '%s' is not supported. 
Supported formats are '%s'", + readDataFormatParam, String.join(",", PERMITTED_READ_DATA_FORMATS))); + } + config.readDataFormat = DataFormat.valueOf(readDataFormatParam); + config.combinePushedDownFilters = + getAnyBooleanOption(globalOptions, options, "combinePushedDownFilters", true); + config.viewsEnabled = getAnyBooleanOption(globalOptions, options, VIEWS_ENABLED_OPTION, false); + config.materializationProject = + getAnyOption( + globalOptions, + options, + ImmutableList.of("materializationProject", "viewMaterializationProject")); + config.materializationDataset = + getAnyOption( + globalOptions, + options, + ImmutableList.of("materializationDataset", "viewMaterializationDataset")); + + config.partitionField = getOption(options, "partitionField"); + config.partitionExpirationMs = + toOptionalLong(getOption(options, "partitionExpirationMs").map(Long::valueOf)); + config.partitionRequireFilter = + getOption(options, "partitionRequireFilter").map(Boolean::valueOf); + config.partitionType = getOption(options, "partitionType"); + config.clusteredFields = getOption(options, "clusteredFields").map(s -> s.split(",")); + + config.createDisposition = + getOption(options, "createDisposition") + .map(String::toUpperCase) + .map(JobInfo.CreateDisposition::valueOf); + + config.optimizedEmptyProjection = + getAnyBooleanOption(globalOptions, options, "optimizedEmptyProjection", true); + + boolean allowFieldAddition = + getAnyBooleanOption(globalOptions, options, "allowFieldAddition", false); + boolean allowFieldRelaxation = + getAnyBooleanOption(globalOptions, options, "allowFieldRelaxation", false); + ImmutableList.Builder loadSchemaUpdateOptions = + ImmutableList.builder(); + if (allowFieldAddition) { + loadSchemaUpdateOptions.add(JobInfo.SchemaUpdateOption.ALLOW_FIELD_ADDITION); + } + if (allowFieldRelaxation) { + loadSchemaUpdateOptions.add(JobInfo.SchemaUpdateOption.ALLOW_FIELD_RELAXATION); + } + config.loadSchemaUpdateOptions = loadSchemaUpdateOptions.build(); + + return config; + } + + private static OptionalInt toOptionalInt(Optional o) { + return o.map(Stream::of).orElse(Stream.empty()).mapToInt(Integer::intValue).findFirst(); + } + + private static OptionalLong toOptionalLong(Optional o) { + return o.map(Stream::of).orElse(Stream.empty()).mapToLong(Long::longValue).findFirst(); + } + + private static Supplier defaultBilledProject() { + return () -> BigQueryOptions.getDefaultInstance().getProjectId(); + } + + private static String getRequiredOption(DataSourceOptions options, String name) { + return getOption(options, name, DEFAULT_FALLBACK) + .orElseThrow(() -> new IllegalArgumentException(format("Option %s required.", name))); + } + + private static String getRequiredOption( + DataSourceOptions options, String name, Supplier fallback) { + return getOption(options, name, DEFAULT_FALLBACK).orElseGet(fallback); + } + + private static Optional getOption(DataSourceOptions options, String name) { + return getOption(options, name, DEFAULT_FALLBACK); + } + + private static Optional getOption( + DataSourceOptions options, String name, Supplier> fallback) { + return firstPresent(options.get(name), fallback.get()); + } + + private static Optional getOptionFromMultipleParams( + DataSourceOptions options, Collection names, Supplier> fallback) { + return names.stream() + .map(name -> getOption(options, name)) + .filter(Optional::isPresent) + .findFirst() + .orElseGet(fallback); + } + + private static Optional getAnyOption( + ImmutableMap globalOptions, DataSourceOptions options, String name) { + return 
Optional.ofNullable(options.get(name).orElse(globalOptions.get(name))); + } + + // gives the option to support old configurations as fallback + // Used to provide backward compatibility + private static Optional getAnyOption( + ImmutableMap globalOptions, + DataSourceOptions options, + Collection names) { + return names.stream() + .map(name -> getAnyOption(globalOptions, options, name)) + .filter(optional -> optional.isPresent()) + .findFirst() + .orElse(Optional.empty()); + } + + private static boolean getAnyBooleanOption( + ImmutableMap globalOptions, + DataSourceOptions options, + String name, + boolean defaultValue) { + return getAnyOption(globalOptions, options, name).map(Boolean::valueOf).orElse(defaultValue); + } + + public TableId getTableId() { + return tableId; + } + + @Override + public Optional getParentProjectId() { + return parentProjectId; + } + + @Override + public Optional getCredentialsKey() { + return credentialsKey; + } + + @Override + public Optional getCredentialsFile() { + return credentialsKey; + } + + @Override + public Optional getAccessToken() { + return accessToken; + } + + public Optional getFilter() { + return filter; + } + + public Optional getSchema() { + return schema; + } + + public OptionalInt getMaxParallelism() { + return maxParallelism; + } + + public int getDefaultParallelism() { + return defaultParallelism; + } + + public Optional getTemporaryGcsBucket() { + return temporaryGcsBucket; + } + + public FormatOptions getIntermediateFormat() { + return intermediateFormat; + } + + public DataFormat getReadDataFormat() { + return readDataFormat; + } + + public boolean isCombinePushedDownFilters() { + return combinePushedDownFilters; + } + + public boolean isViewsEnabled() { + return viewsEnabled; + } + + @Override + public Optional getMaterializationProject() { + return materializationProject; + } + + @Override + public Optional getMaterializationDataset() { + return materializationDataset; + } + + public Optional getPartitionField() { + return partitionField; + } + + public OptionalLong getPartitionExpirationMs() { + return partitionExpirationMs; + } + + public Optional getPartitionRequireFilter() { + return partitionRequireFilter; + } + + public Optional getPartitionType() { + return partitionType; + } + + public Optional getClusteredFields() { + return clusteredFields; + } + + public Optional getCreateDisposition() { + return createDisposition; + } + + public boolean isOptimizedEmptyProjection() { + return optimizedEmptyProjection; + } + + public ImmutableList getLoadSchemaUpdateOptions() { + return loadSchemaUpdateOptions; + } + + public int getViewExpirationTimeInHours() { + return viewExpirationTimeInHours; + } + + public int getMaxReadRowsRetries() { + return maxReadRowsRetries; + } + + public ReadSessionCreatorConfig toReadSessionCreatorConfig() { + return new ReadSessionCreatorConfig( + viewsEnabled, + materializationProject, + materializationDataset, + viewExpirationTimeInHours, + readDataFormat, + maxReadRowsRetries, + VIEWS_ENABLED_OPTION, + maxParallelism, + defaultParallelism); + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConnectorUserAgentProvider.java b/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConnectorUserAgentProvider.java index c0bc716e89..c57b4697a5 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConnectorUserAgentProvider.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConnectorUserAgentProvider.java 
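For reference, the option names parsed by SparkBigQueryConfig.from above are the ones users set on the Spark reader. A hypothetical read that exercises several of them (the project, dataset, table and filter values are placeholders):

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class ReadOptionsExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("read-options").getOrCreate();

    Dataset<Row> df =
        spark
            .read()
            .format("bigquery")
            .option("table", "my-project.my_dataset.my_table")   // required
            .option("parentProject", "my-billing-project")       // project billed for the read
            .option("readDataFormat", "ARROW")                   // AVRO (default) or ARROW
            .option("maxParallelism", "40")                      // "parallelism" also accepted
            .option("viewsEnabled", "true")                      // required when reading views
            .option("materializationDataset", "my_temp_dataset") // where view results are staged
            .option("filter", "word_count > 10")                 // pushed to BigQuery
            .load();

    df.printSchema();
  }
}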
@@ -33,67 +33,73 @@ import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; -/** - * Provides the versions of the client environment in an anonymous way. - */ +/** Provides the versions of the client environment in an anonymous way. */ public class SparkBigQueryConnectorUserAgentProvider implements UserAgentProvider { - @VisibleForTesting - static String GCP_REGION_PART = getGcpRegion().map(region -> " region/" + region).orElse(""); - @VisibleForTesting - static String DATAPROC_IMAGE_PART = Optional.ofNullable(System.getenv("DATAPROC_IMAGE_VERSION")) - .map(image -> " dataproc-image/" + image) - .orElse(""); - private static String CONNECTOR_VERSION = BuildInfo.version(); - // In order to avoid using SparkContext or SparkSession, we are going directly to the source - private static String SPARK_VERSION = org.apache.spark.package$.MODULE$.SPARK_VERSION(); - private static String JAVA_VERSION = System.getProperty("java.runtime.version"); - private static String SCALA_VERSION = Properties.versionNumberString(); - static final String USER_AGENT = format("spark-bigquery-connector/%s spark/%s java/%s scala/%s%s%s", - CONNECTOR_VERSION, - SPARK_VERSION, - JAVA_VERSION, - SCALA_VERSION, - GCP_REGION_PART, - DATAPROC_IMAGE_PART - ); + @VisibleForTesting + static String GCP_REGION_PART = getGcpRegion().map(region -> " region/" + region).orElse(""); - private String dataSourceVersion; + @VisibleForTesting + static String DATAPROC_IMAGE_PART = + Optional.ofNullable(System.getenv("DATAPROC_IMAGE_VERSION")) + .map(image -> " dataproc-image/" + image) + .orElse(""); - public SparkBigQueryConnectorUserAgentProvider(String dataSourceVersion) { - this.dataSourceVersion = dataSourceVersion; - } + private static String CONNECTOR_VERSION = BuildInfo.version(); + // In order to avoid using SparkContext or SparkSession, we are going directly to the source + private static String SPARK_VERSION = org.apache.spark.package$.MODULE$.SPARK_VERSION(); + private static String JAVA_VERSION = System.getProperty("java.runtime.version"); + private static String SCALA_VERSION = Properties.versionNumberString(); + static final String USER_AGENT = + format( + "spark-bigquery-connector/%s spark/%s java/%s scala/%s%s%s", + CONNECTOR_VERSION, + SPARK_VERSION, + JAVA_VERSION, + SCALA_VERSION, + GCP_REGION_PART, + DATAPROC_IMAGE_PART); - // Queries the GCE metadata server - @VisibleForTesting - static Optional getGcpRegion() { - RequestConfig config = RequestConfig.custom() - .setConnectTimeout(100) - .setConnectionRequestTimeout(100) - .setSocketTimeout(100).build(); - CloseableHttpClient httpClient = HttpClients.custom().setDefaultRequestConfig(config).build(); - HttpGet httpGet = new HttpGet("http://metadata.google.internal/computeMetadata/v1/instance/zone"); - httpGet.addHeader("Metadata-Flavor", "Google"); - try (CloseableHttpResponse response = httpClient.execute(httpGet)) { - if (response.getStatusLine().getStatusCode() == 200) { - String body = CharStreams.toString(new InputStreamReader(response.getEntity().getContent(), UTF_8)); - return Optional.of(body.substring(body.lastIndexOf('/') + 1)); - } else { - return Optional.empty(); - } - } catch (Exception e) { - return Optional.empty(); - } finally { - try { - Closeables.close(httpClient, true); - } catch (IOException e) { - // nothing to do - } - } - } + private String dataSourceVersion; - @Override - public String getUserAgent() { - return USER_AGENT + " datasource/" + dataSourceVersion; + public 
SparkBigQueryConnectorUserAgentProvider(String dataSourceVersion) { + this.dataSourceVersion = dataSourceVersion; + } + + // Queries the GCE metadata server + @VisibleForTesting + static Optional getGcpRegion() { + RequestConfig config = + RequestConfig.custom() + .setConnectTimeout(100) + .setConnectionRequestTimeout(100) + .setSocketTimeout(100) + .build(); + CloseableHttpClient httpClient = HttpClients.custom().setDefaultRequestConfig(config).build(); + HttpGet httpGet = + new HttpGet("http://metadata.google.internal/computeMetadata/v1/instance/zone"); + httpGet.addHeader("Metadata-Flavor", "Google"); + try (CloseableHttpResponse response = httpClient.execute(httpGet)) { + if (response.getStatusLine().getStatusCode() == 200) { + String body = + CharStreams.toString(new InputStreamReader(response.getEntity().getContent(), UTF_8)); + return Optional.of(body.substring(body.lastIndexOf('/') + 1)); + } else { + return Optional.empty(); + } + } catch (Exception e) { + return Optional.empty(); + } finally { + try { + Closeables.close(httpClient, true); + } catch (IOException e) { + // nothing to do + } } + } + + @Override + public String getUserAgent() { + return USER_AGENT + " datasource/" + dataSourceVersion; + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConnectorVersionProvider.java b/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConnectorVersionProvider.java index 2fdcd666a8..74e5b0cd14 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConnectorVersionProvider.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/SparkBigQueryConnectorVersionProvider.java @@ -23,19 +23,19 @@ public class SparkBigQueryConnectorVersionProvider implements VersionProvider { - private SparkContext sparkContext; + private SparkContext sparkContext; - public SparkBigQueryConnectorVersionProvider(SparkContext sparkContext) { - this.sparkContext = sparkContext; - } + public SparkBigQueryConnectorVersionProvider(SparkContext sparkContext) { + this.sparkContext = sparkContext; + } - @Override - public String getVersion() { - return format("spark-bigquery-connector/%s spark/%s java/%s scala/%s", - BuildInfo.version(), - sparkContext.version(), - System.getProperty("java.runtime.version"), - Properties.versionNumberString() - ); - } + @Override + public String getVersion() { + return format( + "spark-bigquery-connector/%s spark/%s java/%s scala/%s", + BuildInfo.version(), + sparkContext.version(), + System.getProperty("java.runtime.version"), + Properties.versionNumberString()); + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/SparkFilterUtils.java b/connector/src/main/java/com/google/cloud/spark/bigquery/SparkFilterUtils.java index b434bf9c58..fe62528663 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/SparkFilterUtils.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/SparkFilterUtils.java @@ -31,174 +31,177 @@ public class SparkFilterUtils { - private SparkFilterUtils() { - } - - public static boolean isHandled(Filter filter, DataFormat readDataFormat) { - if (filter instanceof EqualTo || - filter instanceof GreaterThan || - filter instanceof GreaterThanOrEqual || - filter instanceof LessThan || - filter instanceof LessThanOrEqual || - filter instanceof In || - filter instanceof IsNull || - filter instanceof IsNotNull || - filter instanceof StringStartsWith || - filter instanceof StringEndsWith || - filter instanceof StringContains) { - return 
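getGcpRegion above makes a best-effort call to the GCE metadata server and keeps only the last path segment of the returned zone resource, which then lands in the user agent as the region/ token. A tiny sketch of just that string handling, with the response body hard-coded to an assumed example value:

public class ZoneSuffixExample {
  public static void main(String[] args) {
    // The metadata endpoint returns a full resource path for the instance zone.
    // Example value assumed for illustration only.
    String body = "projects/123456789/zones/us-central1-a";

    // Same extraction as getGcpRegion: keep everything after the last '/'.
    String suffix = body.substring(body.lastIndexOf('/') + 1);
    System.out.println(suffix); // us-central1-a
  }
}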
true; - } - // There is no direct equivalent of EqualNullSafe in Google standard SQL. - if (filter instanceof EqualNullSafe) { - return false; - } - if (filter instanceof And) { - And and = (And) filter; - return isHandled(and.left(), readDataFormat) && isHandled(and.right(), readDataFormat); - } - if (filter instanceof Or) { - Or or = (Or) filter; - return readDataFormat == DataFormat.AVRO - && isHandled(or.left(), readDataFormat) - && isHandled(or.right(), readDataFormat); - } - if (filter instanceof Not) { - return isHandled(((Not) filter).child(), readDataFormat); - } - return false; - } - - public static Iterable handledFilters(DataFormat readDataFormat, Filter... filters) { - return handledFilters(readDataFormat, ImmutableList.copyOf(filters)); - } - - public static Iterable handledFilters(DataFormat readDataFormat, Iterable filters) { - return StreamSupport.stream(filters.spliterator(), false) - .filter(f -> isHandled(f, readDataFormat)) - .collect(Collectors.toList()); - } - - public static Iterable unhandledFilters(DataFormat readDataFormat, Filter... filters) { - return unhandledFilters(readDataFormat, ImmutableList.copyOf(filters)); - } - - public static Iterable unhandledFilters(DataFormat readDataFormat, Iterable filters) { - return StreamSupport.stream(filters.spliterator(), false) - .filter(f -> !isHandled(f, readDataFormat)) - .collect(Collectors.toList()); - } - - public static String getCompiledFilter( - DataFormat readDataFormat, - Optional configFilter, - Filter... pushedFilters) { - String compiledPushedFilter = compileFilters(handledFilters( - readDataFormat, ImmutableList.copyOf(pushedFilters))); - return Stream.of( - configFilter, - compiledPushedFilter.length() == 0 ? Optional.empty() : Optional.of(compiledPushedFilter)) - .filter(Optional::isPresent) - .map(filter -> "(" + filter.get() + ")") - .collect(Collectors.joining(" AND ")); - - } - - // Mostly copied from JDBCRDD.scala - public static String compileFilter(Filter filter) { - if (filter instanceof EqualTo) { - EqualTo equalTo = (EqualTo) filter; - return format("%s = %s", quote(equalTo.attribute()), compileValue(equalTo.value())); - } - if (filter instanceof GreaterThan) { - GreaterThan greaterThan = (GreaterThan) filter; - return format("%s > %s", quote(greaterThan.attribute()), compileValue(greaterThan.value())); - } - if (filter instanceof GreaterThanOrEqual) { - GreaterThanOrEqual greaterThanOrEqual = (GreaterThanOrEqual) filter; - return format("%s >= %s", quote(greaterThanOrEqual.attribute()), compileValue(greaterThanOrEqual.value())); - } - if (filter instanceof LessThan) { - LessThan lessThan = (LessThan) filter; - return format("%s < %s", quote(lessThan.attribute()), compileValue(lessThan.value())); - } - if (filter instanceof LessThanOrEqual) { - LessThanOrEqual lessThanOrEqual = (LessThanOrEqual) filter; - return format("%s <>>= %s", quote(lessThanOrEqual.attribute()), compileValue(lessThanOrEqual.value())); - } - if (filter instanceof In) { - In in = (In) filter; - return format("%s IN UNNEST(%s)", quote(in.attribute()), compileValue(in.values())); - } - if (filter instanceof IsNull) { - IsNull isNull = (IsNull) filter; - return format("%s IS NULL", quote(isNull.attribute())); - } - if (filter instanceof IsNotNull) { - IsNotNull isNotNull = (IsNotNull) filter; - return format("%s IS NOT NULL", quote(isNotNull.attribute())); - } - if (filter instanceof And) { - And and = (And) filter; - return format("(%s) AND (%s)", compileFilter(and.left()), compileFilter(and.right())); - } - if (filter 
instanceof Or) { - Or or = (Or) filter; - return format("(%s) OR (%s)", compileFilter(or.left()), compileFilter(or.right())); - } - if (filter instanceof Not) { - Not not = (Not) filter; - return format("(NOT (%s))", compileFilter(not.child())); - } - if (filter instanceof StringStartsWith) { - StringStartsWith stringStartsWith = (StringStartsWith) filter; - return format("%s LIKE '%s%%'", quote(stringStartsWith.attribute()), escape(stringStartsWith.value())); - } - if (filter instanceof StringEndsWith) { - StringEndsWith stringEndsWith = (StringEndsWith) filter; - return format("%s LIKE '%%%s'", quote(stringEndsWith.attribute()), escape(stringEndsWith.value())); - } - if (filter instanceof StringContains) { - StringContains stringContains = (StringContains) filter; - return format("%s LIKE '%%%s%%'", quote(stringContains.attribute()), escape(stringContains.value())); - } - - throw new IllegalArgumentException(format("Invalid filter: %s", filter)); - } - - public static String compileFilters(Iterable filters) { - return StreamSupport.stream(filters.spliterator(), false) - .map(SparkFilterUtils::compileFilter) - .sorted() - .collect(Collectors.joining(" AND ")); - } - - /** - * Converts value to SQL expression. - */ - static String compileValue(Object value) { - if (value == null) { - return null; - } - if (value instanceof String) { - return "'" + escape((String) value) + "'"; - } - if (value instanceof Timestamp || value instanceof Date) { - return "'" + value + "'"; - } - if (value instanceof Object[]) { - return Arrays.stream((Object[]) value) - .map(SparkFilterUtils::compileValue) - .collect(Collectors.joining(", ", "[", "]")); - } - return value.toString(); - } - - static String escape(String value) { - return value.replace("'", "\\'"); - } - - static String quote(String value) { - return "`" + value + "`"; + private SparkFilterUtils() {} + + public static boolean isHandled(Filter filter, DataFormat readDataFormat) { + if (filter instanceof EqualTo + || filter instanceof GreaterThan + || filter instanceof GreaterThanOrEqual + || filter instanceof LessThan + || filter instanceof LessThanOrEqual + || filter instanceof In + || filter instanceof IsNull + || filter instanceof IsNotNull + || filter instanceof StringStartsWith + || filter instanceof StringEndsWith + || filter instanceof StringContains) { + return true; } + // There is no direct equivalent of EqualNullSafe in Google standard SQL. + if (filter instanceof EqualNullSafe) { + return false; + } + if (filter instanceof And) { + And and = (And) filter; + return isHandled(and.left(), readDataFormat) && isHandled(and.right(), readDataFormat); + } + if (filter instanceof Or) { + Or or = (Or) filter; + return readDataFormat == DataFormat.AVRO + && isHandled(or.left(), readDataFormat) + && isHandled(or.right(), readDataFormat); + } + if (filter instanceof Not) { + return isHandled(((Not) filter).child(), readDataFormat); + } + return false; + } + + public static Iterable handledFilters(DataFormat readDataFormat, Filter... filters) { + return handledFilters(readDataFormat, ImmutableList.copyOf(filters)); + } + + public static Iterable handledFilters( + DataFormat readDataFormat, Iterable filters) { + return StreamSupport.stream(filters.spliterator(), false) + .filter(f -> isHandled(f, readDataFormat)) + .collect(Collectors.toList()); + } + + public static Iterable unhandledFilters(DataFormat readDataFormat, Filter... 
filters) { + return unhandledFilters(readDataFormat, ImmutableList.copyOf(filters)); + } + + public static Iterable unhandledFilters( + DataFormat readDataFormat, Iterable filters) { + return StreamSupport.stream(filters.spliterator(), false) + .filter(f -> !isHandled(f, readDataFormat)) + .collect(Collectors.toList()); + } + + public static String getCompiledFilter( + DataFormat readDataFormat, Optional configFilter, Filter... pushedFilters) { + String compiledPushedFilter = + compileFilters(handledFilters(readDataFormat, ImmutableList.copyOf(pushedFilters))); + return Stream.of( + configFilter, + compiledPushedFilter.length() == 0 + ? Optional.empty() + : Optional.of(compiledPushedFilter)) + .filter(Optional::isPresent) + .map(filter -> "(" + filter.get() + ")") + .collect(Collectors.joining(" AND ")); + } + + // Mostly copied from JDBCRDD.scala + public static String compileFilter(Filter filter) { + if (filter instanceof EqualTo) { + EqualTo equalTo = (EqualTo) filter; + return format("%s = %s", quote(equalTo.attribute()), compileValue(equalTo.value())); + } + if (filter instanceof GreaterThan) { + GreaterThan greaterThan = (GreaterThan) filter; + return format("%s > %s", quote(greaterThan.attribute()), compileValue(greaterThan.value())); + } + if (filter instanceof GreaterThanOrEqual) { + GreaterThanOrEqual greaterThanOrEqual = (GreaterThanOrEqual) filter; + return format( + "%s >= %s", + quote(greaterThanOrEqual.attribute()), compileValue(greaterThanOrEqual.value())); + } + if (filter instanceof LessThan) { + LessThan lessThan = (LessThan) filter; + return format("%s < %s", quote(lessThan.attribute()), compileValue(lessThan.value())); + } + if (filter instanceof LessThanOrEqual) { + LessThanOrEqual lessThanOrEqual = (LessThanOrEqual) filter; + return format( + "%s <>>= %s", quote(lessThanOrEqual.attribute()), compileValue(lessThanOrEqual.value())); + } + if (filter instanceof In) { + In in = (In) filter; + return format("%s IN UNNEST(%s)", quote(in.attribute()), compileValue(in.values())); + } + if (filter instanceof IsNull) { + IsNull isNull = (IsNull) filter; + return format("%s IS NULL", quote(isNull.attribute())); + } + if (filter instanceof IsNotNull) { + IsNotNull isNotNull = (IsNotNull) filter; + return format("%s IS NOT NULL", quote(isNotNull.attribute())); + } + if (filter instanceof And) { + And and = (And) filter; + return format("(%s) AND (%s)", compileFilter(and.left()), compileFilter(and.right())); + } + if (filter instanceof Or) { + Or or = (Or) filter; + return format("(%s) OR (%s)", compileFilter(or.left()), compileFilter(or.right())); + } + if (filter instanceof Not) { + Not not = (Not) filter; + return format("(NOT (%s))", compileFilter(not.child())); + } + if (filter instanceof StringStartsWith) { + StringStartsWith stringStartsWith = (StringStartsWith) filter; + return format( + "%s LIKE '%s%%'", quote(stringStartsWith.attribute()), escape(stringStartsWith.value())); + } + if (filter instanceof StringEndsWith) { + StringEndsWith stringEndsWith = (StringEndsWith) filter; + return format( + "%s LIKE '%%%s'", quote(stringEndsWith.attribute()), escape(stringEndsWith.value())); + } + if (filter instanceof StringContains) { + StringContains stringContains = (StringContains) filter; + return format( + "%s LIKE '%%%s%%'", quote(stringContains.attribute()), escape(stringContains.value())); + } + + throw new IllegalArgumentException(format("Invalid filter: %s", filter)); + } + + public static String compileFilters(Iterable filters) { + return 
StreamSupport.stream(filters.spliterator(), false) + .map(SparkFilterUtils::compileFilter) + .sorted() + .collect(Collectors.joining(" AND ")); + } + + /** Converts value to SQL expression. */ + static String compileValue(Object value) { + if (value == null) { + return null; + } + if (value instanceof String) { + return "'" + escape((String) value) + "'"; + } + if (value instanceof Timestamp || value instanceof Date) { + return "'" + value + "'"; + } + if (value instanceof Object[]) { + return Arrays.stream((Object[]) value) + .map(SparkFilterUtils::compileValue) + .collect(Collectors.joining(", ", "[", "]")); + } + return value.toString(); + } + + static String escape(String value) { + return value.replace("'", "\\'"); + } + static String quote(String value) { + return "`" + value + "`"; + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/examples/JavaShakespeare.java b/connector/src/main/java/com/google/cloud/spark/bigquery/examples/JavaShakespeare.java index e2bb2fa925..6fe8dd02f9 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/examples/JavaShakespeare.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/examples/JavaShakespeare.java @@ -21,30 +21,36 @@ public class JavaShakespeare { - public static void main(String[] args) { - SparkSession spark = SparkSession.builder() - .appName("spark-bigquery-demo") - .getOrCreate(); + public static void main(String[] args) { + SparkSession spark = SparkSession.builder().appName("spark-bigquery-demo").getOrCreate(); - // Use the Cloud Storage bucket for temporary BigQuery export data used - // by the connector. This assumes the Cloud Storage connector for - // Hadoop is configured. - String bucket = spark.sparkContext().hadoopConfiguration().get("fs.gs.system.bucket"); - spark.conf().set("temporaryGcsBucket", bucket); + // Use the Cloud Storage bucket for temporary BigQuery export data used + // by the connector. This assumes the Cloud Storage connector for + // Hadoop is configured. + String bucket = spark.sparkContext().hadoopConfiguration().get("fs.gs.system.bucket"); + spark.conf().set("temporaryGcsBucket", bucket); - // Load data in from BigQuery. - Dataset wordsDF = spark.read().format("bigquery") - .option("table", "bigquery-public-data.samples.shakespeare").load().cache(); - wordsDF.show(); - wordsDF.printSchema(); - wordsDF.createOrReplaceTempView("words"); + // Load data in from BigQuery. + Dataset wordsDF = + spark + .read() + .format("bigquery") + .option("table", "bigquery-public-data.samples.shakespeare") + .load() + .cache(); + wordsDF.show(); + wordsDF.printSchema(); + wordsDF.createOrReplaceTempView("words"); - // Perform word count. - Dataset wordCountDF = spark.sql( - "SELECT word, SUM(word_count) AS word_count FROM words GROUP BY word"); + // Perform word count. 
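To show what the SparkFilterUtils translation above produces, a short sketch that compiles two pushed-down Spark filters into a Standard SQL predicate; the column names are invented:

import com.google.cloud.spark.bigquery.SparkFilterUtils;
import com.google.common.collect.ImmutableList;
import org.apache.spark.sql.sources.EqualTo;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.GreaterThan;

public class FilterCompilationExample {
  public static void main(String[] args) {
    ImmutableList<Filter> pushed =
        ImmutableList.of(new EqualTo("corpus", "hamlet"), new GreaterThan("word_count", 10));

    // Attributes are back-quoted, string values single-quoted and escaped, and the
    // compiled filters are sorted and joined with AND.
    String predicate = SparkFilterUtils.compileFilters(pushed);
    System.out.println(predicate); // `corpus` = 'hamlet' AND `word_count` > 10
  }
}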
+ Dataset wordCountDF = + spark.sql("SELECT word, SUM(word_count) AS word_count FROM words GROUP BY word"); - // Saving the data to BigQuery - wordCountDF.write().format("bigquery").option("table", "wordcount_dataset.wordcount_output") - .save(); - } + // Saving the data to BigQuery + wordCountDF + .write() + .format("bigquery") + .option("table", "wordcount_dataset.wordcount_output") + .save(); + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java index e1d20b0035..68dfaab217 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java @@ -33,176 +33,183 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -public class BigQueryDataSourceReader implements - DataSourceReader, +public class BigQueryDataSourceReader + implements DataSourceReader, SupportsPushDownRequiredColumns, SupportsPushDownFilters, SupportsReportStatistics { - private static Statistics UNKNOWN_STATISTICS = new Statistics() { + private static Statistics UNKNOWN_STATISTICS = + new Statistics() { @Override public OptionalLong sizeInBytes() { - return OptionalLong.empty(); + return OptionalLong.empty(); } @Override public OptionalLong numRows() { - return OptionalLong.empty(); + return OptionalLong.empty(); } - }; - - private final TableInfo table; - private final TableId tableId; - private final ReadSessionCreatorConfig readSessionCreatorConfig; - private final BigQueryClient bigQueryClient; - private final BigQueryReadClientFactory bigQueryReadClientFactory; - private final ReadSessionCreator readSessionCreator; - private final Optional globalFilter; - private Optional schema; - private Filter[] pushedFilters = new Filter[]{}; - - public BigQueryDataSourceReader( - TableInfo table, - BigQueryClient bigQueryClient, - BigQueryReadClientFactory bigQueryReadClientFactory, - ReadSessionCreatorConfig readSessionCreatorConfig, - Optional globalFilter, - Optional schema) { - this.table = table; - this.tableId = table.getTableId(); - this.readSessionCreatorConfig = readSessionCreatorConfig; - this.bigQueryClient = bigQueryClient; - this.bigQueryReadClientFactory = bigQueryReadClientFactory; - this.readSessionCreator = new ReadSessionCreator(readSessionCreatorConfig, bigQueryClient, bigQueryReadClientFactory); - this.globalFilter = globalFilter; - this.schema = schema; + }; + + private final TableInfo table; + private final TableId tableId; + private final ReadSessionCreatorConfig readSessionCreatorConfig; + private final BigQueryClient bigQueryClient; + private final BigQueryReadClientFactory bigQueryReadClientFactory; + private final ReadSessionCreator readSessionCreator; + private final Optional globalFilter; + private Optional schema; + private Filter[] pushedFilters = new Filter[] {}; + + public BigQueryDataSourceReader( + TableInfo table, + BigQueryClient bigQueryClient, + BigQueryReadClientFactory bigQueryReadClientFactory, + ReadSessionCreatorConfig readSessionCreatorConfig, + Optional globalFilter, + Optional schema) { + this.table = table; + this.tableId = table.getTableId(); + this.readSessionCreatorConfig = readSessionCreatorConfig; + this.bigQueryClient = bigQueryClient; + this.bigQueryReadClientFactory = bigQueryReadClientFactory; + this.readSessionCreator = + new ReadSessionCreator(readSessionCreatorConfig, 
bigQueryClient, bigQueryReadClientFactory); + this.globalFilter = globalFilter; + this.schema = schema; + } + + @Override + public StructType readSchema() { + // TODO: rely on Java code + return schema.orElse(SchemaConverters.toSpark(table.getDefinition().getSchema())); + } + + @Override + public List> planInputPartitions() { + if (schema.map(StructType::isEmpty).orElse(false)) { + // create empty projection + return createEmptyProjectionPartitions(); } - @Override - public StructType readSchema() { - // TODO: rely on Java code - return schema.orElse(SchemaConverters.toSpark(table.getDefinition().getSchema())); - } - - @Override - public List> planInputPartitions() { - if (schema.map(StructType::isEmpty).orElse(false)) { - // create empty projection - return createEmptyProjectionPartitions(); - } - - ImmutableList selectedFields = schema - .map(requiredSchema -> ImmutableList.copyOf(requiredSchema.fieldNames())) - .orElse(ImmutableList.of()); - Optional filter = emptyIfNeeded(SparkFilterUtils.getCompiledFilter( + ImmutableList selectedFields = + schema + .map(requiredSchema -> ImmutableList.copyOf(requiredSchema.fieldNames())) + .orElse(ImmutableList.of()); + Optional filter = + emptyIfNeeded( + SparkFilterUtils.getCompiledFilter( readSessionCreatorConfig.getReadDataFormat(), globalFilter, pushedFilters)); - ReadSessionResponse readSessionResponse = readSessionCreator.create( - tableId, selectedFields, filter, readSessionCreatorConfig.getMaxParallelism()); - ReadSession readSession = readSessionResponse.getReadSession(); - return readSession.getStreamsList().stream() - .map(stream -> new BigQueryInputPartition( - bigQueryReadClientFactory, - stream.getName(), - readSessionCreatorConfig.getMaxReadRowsRetries(), - createConverter(selectedFields, readSessionResponse))) - .collect(Collectors.toList()); - } - - private ReadRowsResponseToInternalRowIteratorConverter createConverter( - ImmutableList selectedFields, ReadSessionResponse readSessionResponse) { - ReadRowsResponseToInternalRowIteratorConverter converter; - if (readSessionCreatorConfig.getReadDataFormat() == DataFormat.AVRO) { - Schema schema = readSessionResponse.getReadTableInfo().getDefinition().getSchema(); - if (selectedFields.isEmpty()) { - // means select * - selectedFields = schema.getFields().stream() - .map(Field::getName) - .collect(ImmutableList.toImmutableList()); - } else { - Set requiredColumnSet = ImmutableSet.copyOf(selectedFields); - schema = Schema.of(schema.getFields().stream() - .filter(field -> requiredColumnSet.contains(field.getName())) - .collect(Collectors.toList())); - } - return ReadRowsResponseToInternalRowIteratorConverter.avro( - schema, - selectedFields, - readSessionResponse.getReadSession().getAvroSchema().getSchema()); - } else { - return ReadRowsResponseToInternalRowIteratorConverter.arrow( - selectedFields, - readSessionResponse.getReadSession().getArrowSchema().getSerializedSchema()); - } - } - - List> createEmptyProjectionPartitions() { - long rowCount = bigQueryClient.calculateTableSize(tableId, globalFilter); - int partitionsCount = readSessionCreatorConfig.getDefaultParallelism(); - int partitionSize = (int) (rowCount / partitionsCount); - InputPartition[] partitions = IntStream - .range(0, partitionsCount) - .mapToObj(ignored -> new BigQueryEmptyProjectionInputPartition(partitionSize)) - .toArray(BigQueryEmptyProjectionInputPartition[]::new); - int firstPartitionSize = partitionSize + (int) (rowCount % partitionsCount); - partitions[0] = new 
BigQueryEmptyProjectionInputPartition(firstPartitionSize); - return ImmutableList.copyOf(partitions); + ReadSessionResponse readSessionResponse = + readSessionCreator.create( + tableId, selectedFields, filter, readSessionCreatorConfig.getMaxParallelism()); + ReadSession readSession = readSessionResponse.getReadSession(); + return readSession.getStreamsList().stream() + .map( + stream -> + new BigQueryInputPartition( + bigQueryReadClientFactory, + stream.getName(), + readSessionCreatorConfig.getMaxReadRowsRetries(), + createConverter(selectedFields, readSessionResponse))) + .collect(Collectors.toList()); + } + + private ReadRowsResponseToInternalRowIteratorConverter createConverter( + ImmutableList selectedFields, ReadSessionResponse readSessionResponse) { + ReadRowsResponseToInternalRowIteratorConverter converter; + if (readSessionCreatorConfig.getReadDataFormat() == DataFormat.AVRO) { + Schema schema = readSessionResponse.getReadTableInfo().getDefinition().getSchema(); + if (selectedFields.isEmpty()) { + // means select * + selectedFields = + schema.getFields().stream() + .map(Field::getName) + .collect(ImmutableList.toImmutableList()); + } else { + Set requiredColumnSet = ImmutableSet.copyOf(selectedFields); + schema = + Schema.of( + schema.getFields().stream() + .filter(field -> requiredColumnSet.contains(field.getName())) + .collect(Collectors.toList())); + } + return ReadRowsResponseToInternalRowIteratorConverter.avro( + schema, selectedFields, readSessionResponse.getReadSession().getAvroSchema().getSchema()); + } else { + return ReadRowsResponseToInternalRowIteratorConverter.arrow( + selectedFields, + readSessionResponse.getReadSession().getArrowSchema().getSerializedSchema()); } - - @Override - public Filter[] pushFilters(Filter[] filters) { - List handledFilters = new ArrayList<>(); - List unhandledFilters = new ArrayList<>(); - for (Filter filter : filters) { - if (SparkFilterUtils.isHandled(filter, readSessionCreatorConfig.getReadDataFormat())) { - handledFilters.add(filter); - } else { - unhandledFilters.add(filter); - } - } - pushedFilters = handledFilters.stream().toArray(Filter[]::new); - return unhandledFilters.stream().toArray(Filter[]::new); - } - - @Override - public Filter[] pushedFilters() { - return pushedFilters; - } - - @Override - public void pruneColumns(StructType requiredSchema) { - this.schema = Optional.ofNullable(requiredSchema); - } - - Optional emptyIfNeeded(String value) { - return (value == null || value.length() == 0) ? - Optional.empty() : Optional.of(value); - } - - @Override - public Statistics estimateStatistics() { - return table.getDefinition().getType() == TableDefinition.Type.TABLE ? 
- new StandardTableStatistics(table.getDefinition()) : - UNKNOWN_STATISTICS; - + } + + List> createEmptyProjectionPartitions() { + long rowCount = bigQueryClient.calculateTableSize(tableId, globalFilter); + int partitionsCount = readSessionCreatorConfig.getDefaultParallelism(); + int partitionSize = (int) (rowCount / partitionsCount); + InputPartition[] partitions = + IntStream.range(0, partitionsCount) + .mapToObj(ignored -> new BigQueryEmptyProjectionInputPartition(partitionSize)) + .toArray(BigQueryEmptyProjectionInputPartition[]::new); + int firstPartitionSize = partitionSize + (int) (rowCount % partitionsCount); + partitions[0] = new BigQueryEmptyProjectionInputPartition(firstPartitionSize); + return ImmutableList.copyOf(partitions); + } + + @Override + public Filter[] pushFilters(Filter[] filters) { + List handledFilters = new ArrayList<>(); + List unhandledFilters = new ArrayList<>(); + for (Filter filter : filters) { + if (SparkFilterUtils.isHandled(filter, readSessionCreatorConfig.getReadDataFormat())) { + handledFilters.add(filter); + } else { + unhandledFilters.add(filter); + } } + pushedFilters = handledFilters.stream().toArray(Filter[]::new); + return unhandledFilters.stream().toArray(Filter[]::new); + } + + @Override + public Filter[] pushedFilters() { + return pushedFilters; + } + + @Override + public void pruneColumns(StructType requiredSchema) { + this.schema = Optional.ofNullable(requiredSchema); + } + + Optional emptyIfNeeded(String value) { + return (value == null || value.length() == 0) ? Optional.empty() : Optional.of(value); + } + + @Override + public Statistics estimateStatistics() { + return table.getDefinition().getType() == TableDefinition.Type.TABLE + ? new StandardTableStatistics(table.getDefinition()) + : UNKNOWN_STATISTICS; + } } class StandardTableStatistics implements Statistics { - private StandardTableDefinition tableDefinition; + private StandardTableDefinition tableDefinition; - public StandardTableStatistics(StandardTableDefinition tableDefinition) { - this.tableDefinition = tableDefinition; - } + public StandardTableStatistics(StandardTableDefinition tableDefinition) { + this.tableDefinition = tableDefinition; + } - @Override - public OptionalLong sizeInBytes() { - return OptionalLong.of(tableDefinition.getNumBytes()); - } + @Override + public OptionalLong sizeInBytes() { + return OptionalLong.of(tableDefinition.getNumBytes()); + } - @Override - public OptionalLong numRows() { - return OptionalLong.of(tableDefinition.getNumRows()); - } + @Override + public OptionalLong numRows() { + return OptionalLong.of(tableDefinition.getNumRows()); + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceV2.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceV2.java index 602a3f31ac..130bfb3b63 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceV2.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceV2.java @@ -29,29 +29,29 @@ public class BigQueryDataSourceV2 implements DataSourceV2, ReadSupport { - @Override - public DataSourceReader createReader(StructType schema, DataSourceOptions options) { - SparkSession spark = getDefaultSparkSessionOrCreate(); - - Injector injector = Guice.createInjector( - new BigQueryClientModule(), - new SparkBigQueryConnectorModule(spark, options, Optional.ofNullable(schema))); - - BigQueryDataSourceReader reader = injector.getInstance(BigQueryDataSourceReader.class); - return reader; - } - - 
private SparkSession getDefaultSparkSessionOrCreate() { - scala.Option defaultSpareSession = SparkSession.getDefaultSession(); - if (defaultSpareSession.isDefined()) { - return defaultSpareSession.get(); - } - return SparkSession.builder().appName("spark-bigquery-connector").getOrCreate(); + @Override + public DataSourceReader createReader(StructType schema, DataSourceOptions options) { + SparkSession spark = getDefaultSparkSessionOrCreate(); + + Injector injector = + Guice.createInjector( + new BigQueryClientModule(), + new SparkBigQueryConnectorModule(spark, options, Optional.ofNullable(schema))); + + BigQueryDataSourceReader reader = injector.getInstance(BigQueryDataSourceReader.class); + return reader; + } + + private SparkSession getDefaultSparkSessionOrCreate() { + scala.Option defaultSpareSession = SparkSession.getDefaultSession(); + if (defaultSpareSession.isDefined()) { + return defaultSpareSession.get(); } + return SparkSession.builder().appName("spark-bigquery-connector").getOrCreate(); + } - @Override - public DataSourceReader createReader(DataSourceOptions options) { - return createReader(null, options); - } + @Override + public DataSourceReader createReader(DataSourceOptions options) { + return createReader(null, options); + } } - diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryEmptyProjectionInputPartition.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryEmptyProjectionInputPartition.java index 383c049f17..72220cf36a 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryEmptyProjectionInputPartition.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryEmptyProjectionInputPartition.java @@ -21,14 +21,14 @@ public class BigQueryEmptyProjectionInputPartition implements InputPartition { - final int partitionSize; + final int partitionSize; - public BigQueryEmptyProjectionInputPartition(int partitionSize) { - this.partitionSize = partitionSize; - } + public BigQueryEmptyProjectionInputPartition(int partitionSize) { + this.partitionSize = partitionSize; + } - @Override - public InputPartitionReader createPartitionReader() { - return new BigQueryEmptyProjectionInputPartitionReader(partitionSize); - } + @Override + public InputPartitionReader createPartitionReader() { + return new BigQueryEmptyProjectionInputPartitionReader(partitionSize); + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryEmptyProjectionInputPartitionReader.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryEmptyProjectionInputPartitionReader.java index 7f7808c6f5..ce5d5bb9c1 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryEmptyProjectionInputPartitionReader.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryEmptyProjectionInputPartitionReader.java @@ -22,27 +22,27 @@ class BigQueryEmptyProjectionInputPartitionReader implements InputPartitionReader { - final int partitionSize; - int currentIndex; - - BigQueryEmptyProjectionInputPartitionReader(int partitionSize) { - this.partitionSize = partitionSize; - this.currentIndex = 0; - } - - @Override - public boolean next() throws IOException { - return currentIndex < partitionSize; - } - - @Override - public InternalRow get() { - currentIndex++; - return InternalRow.empty(); - } - - @Override - public void close() throws IOException { - // empty - } + final int partitionSize; + int currentIndex; + + BigQueryEmptyProjectionInputPartitionReader(int 
partitionSize) { + this.partitionSize = partitionSize; + this.currentIndex = 0; + } + + @Override + public boolean next() throws IOException { + return currentIndex < partitionSize; + } + + @Override + public InternalRow get() { + currentIndex++; + return InternalRow.empty(); + } + + @Override + public void close() throws IOException { + // empty + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartition.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartition.java index 9a9f832ef1..79831c601f 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartition.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartition.java @@ -28,27 +28,29 @@ public class BigQueryInputPartition implements InputPartition { - private final BigQueryReadClientFactory bigQueryReadClientFactory; - private final String streamName; - private final int maxReadRowsRetries; - private final ReadRowsResponseToInternalRowIteratorConverter converter; + private final BigQueryReadClientFactory bigQueryReadClientFactory; + private final String streamName; + private final int maxReadRowsRetries; + private final ReadRowsResponseToInternalRowIteratorConverter converter; - public BigQueryInputPartition( - BigQueryReadClientFactory bigQueryReadClientFactory, - String streamName, - int maxReadRowsRetries, - ReadRowsResponseToInternalRowIteratorConverter converter) { - this.bigQueryReadClientFactory = bigQueryReadClientFactory; - this.streamName = streamName; - this.maxReadRowsRetries = maxReadRowsRetries; - this.converter = converter; - } + public BigQueryInputPartition( + BigQueryReadClientFactory bigQueryReadClientFactory, + String streamName, + int maxReadRowsRetries, + ReadRowsResponseToInternalRowIteratorConverter converter) { + this.bigQueryReadClientFactory = bigQueryReadClientFactory; + this.streamName = streamName; + this.maxReadRowsRetries = maxReadRowsRetries; + this.converter = converter; + } - @Override - public InputPartitionReader createPartitionReader() { - ReadRowsRequest.Builder readRowsRequest = ReadRowsRequest.newBuilder().setReadStream(streamName); - ReadRowsHelper readRowsHelper = new ReadRowsHelper(bigQueryReadClientFactory, readRowsRequest, maxReadRowsRetries); - Iterator readRowsResponses = readRowsHelper.readRows(); - return new BigQueryInputPartitionReader(readRowsResponses, converter, readRowsHelper); - } + @Override + public InputPartitionReader createPartitionReader() { + ReadRowsRequest.Builder readRowsRequest = + ReadRowsRequest.newBuilder().setReadStream(streamName); + ReadRowsHelper readRowsHelper = + new ReadRowsHelper(bigQueryReadClientFactory, readRowsRequest, maxReadRowsRetries); + Iterator readRowsResponses = readRowsHelper.readRows(); + return new BigQueryInputPartitionReader(readRowsResponses, converter, readRowsHelper); + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartitionReader.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartitionReader.java index c09651f149..d915d317c6 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartitionReader.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryInputPartitionReader.java @@ -27,41 +27,41 @@ class BigQueryInputPartitionReader implements InputPartitionReader { - private Iterator readRowsResponses; - private ReadRowsResponseToInternalRowIteratorConverter converter; - private 
ReadRowsHelper readRowsHelper; - private Iterator rows = ImmutableList.of().iterator(); - private InternalRow currentRow; + private Iterator readRowsResponses; + private ReadRowsResponseToInternalRowIteratorConverter converter; + private ReadRowsHelper readRowsHelper; + private Iterator rows = ImmutableList.of().iterator(); + private InternalRow currentRow; - BigQueryInputPartitionReader( - Iterator readRowsResponses, - ReadRowsResponseToInternalRowIteratorConverter converter, - ReadRowsHelper readRowsHelper) { - this.readRowsResponses = readRowsResponses; - this.converter = converter; - this.readRowsHelper = readRowsHelper; - } + BigQueryInputPartitionReader( + Iterator readRowsResponses, + ReadRowsResponseToInternalRowIteratorConverter converter, + ReadRowsHelper readRowsHelper) { + this.readRowsResponses = readRowsResponses; + this.converter = converter; + this.readRowsHelper = readRowsHelper; + } - @Override - public boolean next() throws IOException { - while (!rows.hasNext()) { - if (!readRowsResponses.hasNext()) { - return false; - } - ReadRowsResponse readRowsResponse = readRowsResponses.next(); - rows = converter.convert(readRowsResponse); - } - currentRow = rows.next(); - return true; + @Override + public boolean next() throws IOException { + while (!rows.hasNext()) { + if (!readRowsResponses.hasNext()) { + return false; + } + ReadRowsResponse readRowsResponse = readRowsResponses.next(); + rows = converter.convert(readRowsResponse); } + currentRow = rows.next(); + return true; + } - @Override - public InternalRow get() { - return currentRow; - } + @Override + public InternalRow get() { + return currentRow; + } - @Override - public void close() throws IOException { - readRowsHelper.close(); - } + @Override + public void close() throws IOException { + readRowsHelper.close(); + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/SparkBigQueryConnectorModule.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/SparkBigQueryConnectorModule.java index e53797a997..758fdafd7b 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/SparkBigQueryConnectorModule.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/SparkBigQueryConnectorModule.java @@ -37,53 +37,53 @@ public class SparkBigQueryConnectorModule implements Module { - private final SparkSession spark; - private final DataSourceOptions options; - private final Optional schema; + private final SparkSession spark; + private final DataSourceOptions options; + private final Optional schema; - public SparkBigQueryConnectorModule( - SparkSession spark, - DataSourceOptions options, - Optional schema) { - this.spark = spark; - this.options = options; - this.schema = schema; - } + public SparkBigQueryConnectorModule( + SparkSession spark, DataSourceOptions options, Optional schema) { + this.spark = spark; + this.options = options; + this.schema = schema; + } - @Override - public void configure(Binder binder) { - binder.bind(BigQueryConfig.class).toProvider(this::provideSparkBigQueryConfig); - } + @Override + public void configure(Binder binder) { + binder.bind(BigQueryConfig.class).toProvider(this::provideSparkBigQueryConfig); + } - @Singleton - @Provides - public SparkBigQueryConfig provideSparkBigQueryConfig() { - return SparkBigQueryConfig.from(options, - ImmutableMap.copyOf(mapAsJavaMap(spark.conf().getAll())), - spark.sparkContext().hadoopConfiguration(), - spark.sparkContext().defaultParallelism()); - } + @Singleton + @Provides + public SparkBigQueryConfig 
provideSparkBigQueryConfig() { + return SparkBigQueryConfig.from( + options, + ImmutableMap.copyOf(mapAsJavaMap(spark.conf().getAll())), + spark.sparkContext().hadoopConfiguration(), + spark.sparkContext().defaultParallelism()); + } - @Singleton - @Provides - public BigQueryDataSourceReader provideDataSourceReader( - BigQueryClient bigQueryClient, - BigQueryReadClientFactory bigQueryReadClientFactory, - SparkBigQueryConfig config) { - TableInfo tableInfo = bigQueryClient.getSupportedTable(config.getTableId(), config.isViewsEnabled(), - SparkBigQueryConfig.VIEWS_ENABLED_OPTION); - return new BigQueryDataSourceReader(tableInfo, - bigQueryClient, - bigQueryReadClientFactory, - config.toReadSessionCreatorConfig(), - config.getFilter(), - schema); - } - - @Singleton - @Provides - public UserAgentProvider provideUserAgentProvider() { - return new SparkBigQueryConnectorUserAgentProvider("v2"); - } + @Singleton + @Provides + public BigQueryDataSourceReader provideDataSourceReader( + BigQueryClient bigQueryClient, + BigQueryReadClientFactory bigQueryReadClientFactory, + SparkBigQueryConfig config) { + TableInfo tableInfo = + bigQueryClient.getSupportedTable( + config.getTableId(), config.isViewsEnabled(), SparkBigQueryConfig.VIEWS_ENABLED_OPTION); + return new BigQueryDataSourceReader( + tableInfo, + bigQueryClient, + bigQueryReadClientFactory, + config.toReadSessionCreatorConfig(), + config.getFilter(), + schema); + } + @Singleton + @Provides + public UserAgentProvider provideUserAgentProvider() { + return new SparkBigQueryConnectorUserAgentProvider("v2"); + } } diff --git a/project/plugins.sbt b/project/plugins.sbt index f58cd60a49..fcc0c0d9ff 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -8,4 +8,6 @@ addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.7.0") addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "2.3") -addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.1.1") \ No newline at end of file +addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.1.1") + +addSbtPlugin("com.lightbend.sbt" % "sbt-java-formatter" % "0.4.4") From 40fedf882b70339848dc88c1c4bc5d92dce6e872 Mon Sep 17 00:00:00 2001 From: Yuval Medina Date: Fri, 19 Jun 2020 21:58:00 +0000 Subject: [PATCH 4/9] Created Spark-BigQuery schema converter and created BigQuery schema - ProtoSchema converter. Now awaiting comprehensive tests before merging with master. Fixing SparkBigQueryConnectorUserAgentProvider initialization bug (#186) prepare release 0.16.1 prepare for next development iteration Sectioned the schema converter file for easier readability. Added a Table creation method. Wrote comprehensive tests to check YuvalSchemaConverters. Now needs to improve equality testing: assertEquals does not check for more than superficial equality, so if further testing is to be done without the help of logs, it would be useful to write an equality function for schemas. Spark->BQ Schema working correctly. Blocked out Map functionality, as it is not supported. Made SchemaConverters, Schema-unit-tests more readable. Improved use of BigQuery library functions/iteration in SchemaConverters Renamed SchemaConverters file, about to merge into David's SchemaConverters. Improved unit tests to check the toBigQueryColumn method, instead of the more abstract toBigQuerySchema (in order to check each data type is working correctly. Tackling toProtoRows converter. BigQuery->ProtoSchema converter is passing all unit tests. Merged my (YuvalMedina) schema converters with David's (davidrabinowitz) SchemaConverters under spark.bigquery. 
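For orientation, a minimal sketch of how the converters introduced in this patch might be chained end to end, using only the public signatures that appear in the diff below (SchemaConverters.toBigQuerySchema, ProtobufUtils.toProtoSchema, ProtobufUtils.toProtoRows); the class name, schema, and row values are illustrative assumptions rather than part of the change:

package com.google.cloud.spark.bigquery;

import com.google.cloud.bigquery.Schema;
import com.google.cloud.bigquery.storage.v1alpha2.ProtoBufProto;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.types.*;

public class SchemaConversionSketch {
  public static void main(String[] args) throws Exception {
    // Illustrative two-column Spark schema.
    StructType sparkSchema =
        new StructType()
            .add(new StructField("word", DataTypes.StringType, false, Metadata.empty()))
            .add(new StructField("word_count", DataTypes.LongType, true, Metadata.empty()));

    // Spark schema -> BigQuery schema (SchemaConverters.toBigQuerySchema, added below).
    Schema bqSchema = SchemaConverters.toBigQuerySchema(sparkSchema);

    // BigQuery schema -> ProtoSchema for the Storage Write API (ProtobufUtils.toProtoSchema).
    ProtoBufProto.ProtoSchema protoSchema = ProtobufUtils.toProtoSchema(bqSchema);

    // Spark rows -> ProtoRows; the row is built with GenericInternalRow exactly as the
    // unit tests added in this patch build theirs.
    InternalRow row = new GenericInternalRow(new Object[] {"hamlet", 42L});
    ProtoBufProto.ProtoRows protoRows =
        ProtobufUtils.toProtoRows(sparkSchema, new InternalRow[] {row});

    System.out.println(protoSchema.getProtoDescriptor());
    System.out.println(protoRows.getSerializedRowsCount());
  }
}

The sketch mirrors the flow exercised by the new ProtobufUtilsTest and SchemaConverterTest suites; it is not an API promised by the patch itself.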
Renamed my schema converters to SchemaConvertersDevelopment, in which I will continue working on a ProtoRows converter. SchemaConvertersDevelopment is passing all tests on Spark -> Protobuf Descriptor conversion, even on nested structs. Unit tests need to be written to test actual row conversion (Spark values -> Protobuf values). Minor fixes to SchemaConverters.java: code needs to be smoothed out. ProtoRows converter is passing 10 unit tests, sparkRowToProtoRow test must be revised to confirm that ProtoRows conversion is fully working. All functions doing Spark InternalRow -> ProtoRow and BigQuery Schema -> ProtoSchema conversions were migrated from SchemaConverters.java to ProtoBufUtils.java. SchemaConverters.java now contains both Spark -> BigQuery as well as the original BigQuery -> Spark conversions. ProtoBufUtilsTests.java was created to test the functions in ProtoBufUtils separately. All conversion suites for Spark -> BigQuery, BigQuery -> ProtoSchema, and Spark rows -> ProtoRows are working correctly, and comprehensive tests were written. SchemaConvertersSuite.scala, which tests BigQuery -> Spark conversions, was translated into .java, and merged with SchemaConvertersTests.java. Cleaned up the SchemaConverter tests that were translated from Scala. Added a nesting-depth limit to Records created by the Spark->BigQuery converter. Deleted unnecessary comments. Deleted a leftover TODO comment in SchemaConvertersTests. Deleted some unnecessary tests. Last commit before write-support implementation. Made minor edits according to davidrab@'s comments. Added license heading to all files that were created. Need to test if binary types are converted correctly to protobuf format. --- .../cloud/spark/bigquery/ProtobufUtils.java | 389 +++++++++++++++++ .../spark/bigquery/SchemaConverters.java | 114 ++++- .../spark/bigquery/ProtobufUtilsTest.java | 403 ++++++++++++++++++ .../spark/bigquery/SchemaConverterTest.java | 279 ++++++++++++ 4 files changed, 1178 insertions(+), 7 deletions(-) create mode 100644 connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java create mode 100644 connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java create mode 100644 connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java b/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java new file mode 100644 index 0000000000..669bab27da --- /dev/null +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java @@ -0,0 +1,389 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package com.google.cloud.spark.bigquery; + +import avro.shaded.com.google.common.base.Preconditions; +import com.google.cloud.bigquery.Field; +import com.google.cloud.bigquery.FieldList; +import com.google.cloud.bigquery.LegacySQLTypeName; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.storage.v1alpha2.ProtoBufProto; +import com.google.cloud.bigquery.storage.v1alpha2.ProtoSchemaConverter; +import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.DescriptorProtos; +import com.google.protobuf.Descriptors; +import com.google.protobuf.DynamicMessage; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.types.*; + +import java.util.Arrays; +import java.util.stream.Collectors; + +public class ProtobufUtils { + + // The maximum nesting depth of a BigQuery RECORD: + private static final int MAX_BIGQUERY_NESTED_DEPTH = 15; + // For every message, a nested type is name "STRUCT"+i, where i is the + // number of the corresponding field that is of this type in the containing message. + private static final String RESERVED_NESTED_TYPE_NAME = "STRUCT"; + private static final String MAPTYPE_ERROR_MESSAGE = "MapType is unsupported."; + + /** + * BigQuery Schema ==> ProtoSchema converter utils: + */ + public static ProtoBufProto.ProtoSchema toProtoSchema (Schema schema) throws Exception { + try{ + Descriptors.Descriptor descriptor = toDescriptor(schema); + ProtoBufProto.ProtoSchema protoSchema = ProtoSchemaConverter.convert(descriptor); + return protoSchema; + } catch (Descriptors.DescriptorValidationException e){ + throw new Exception("Could not build Proto-Schema from Spark schema.", e); // TODO: right exception to throw? + } + } + + private static Descriptors.Descriptor toDescriptor (Schema schema) throws Descriptors.DescriptorValidationException { + DescriptorProtos.DescriptorProto.Builder descriptorBuilder = DescriptorProtos.DescriptorProto.newBuilder() + .setName("Schema"); + + FieldList fields = schema.getFields(); + + DescriptorProtos.DescriptorProto descriptorProto = buildDescriptorProtoWithFields(descriptorBuilder, fields, 0); + + return createDescriptorFromProto(descriptorProto); + } + + private static Descriptors.Descriptor createDescriptorFromProto(DescriptorProtos.DescriptorProto descriptorProto) + throws Descriptors.DescriptorValidationException { + DescriptorProtos.FileDescriptorProto fileDescriptorProto = DescriptorProtos.FileDescriptorProto + .newBuilder() + .addMessageType(descriptorProto) + .build(); + + Descriptors.Descriptor descriptor = Descriptors.FileDescriptor + .buildFrom(fileDescriptorProto, new Descriptors.FileDescriptor[]{}) + .getMessageTypes() + .get(0); + + return descriptor; + } + + @VisibleForTesting + protected static DescriptorProtos.DescriptorProto buildDescriptorProtoWithFields( + DescriptorProtos.DescriptorProto.Builder descriptorBuilder, FieldList fields, int depth){ + Preconditions.checkArgument(depth < MAX_BIGQUERY_NESTED_DEPTH, + "Tried to convert a BigQuery schema that exceeded BigQuery maximum nesting depth"); + int messageNumber = 1; + for (Field field : fields) { + String fieldName = field.getName(); + DescriptorProtos.FieldDescriptorProto.Label fieldLabel = toProtoFieldLabel(field.getMode()); + FieldList subFields = field.getSubFields(); + + if (field.getType() == LegacySQLTypeName.RECORD){ + String recordTypeName = "RECORD"+messageNumber; // TODO: Change or assert this to be a reserved name. 
No column can have this name. + DescriptorProtos.DescriptorProto.Builder nestedFieldTypeBuilder = descriptorBuilder.addNestedTypeBuilder(); + nestedFieldTypeBuilder.setName(recordTypeName); + DescriptorProtos.DescriptorProto nestedFieldType = buildDescriptorProtoWithFields( + nestedFieldTypeBuilder, subFields, depth+1); + + descriptorBuilder.addField(createProtoFieldBuilder(fieldName, fieldLabel, messageNumber) + .setTypeName(recordTypeName)); + } + else { + DescriptorProtos.FieldDescriptorProto.Type fieldType = toProtoFieldType(field.getType()); + descriptorBuilder.addField(createProtoFieldBuilder(fieldName, fieldLabel, messageNumber, fieldType)); + } + messageNumber++; + } + return descriptorBuilder.build(); + } + + private static DescriptorProtos.FieldDescriptorProto.Builder createProtoFieldBuilder( + String fieldName, DescriptorProtos.FieldDescriptorProto.Label fieldLabel, int messageNumber) { + return DescriptorProtos.FieldDescriptorProto + .newBuilder() + .setName(fieldName) + .setLabel(fieldLabel) + .setNumber(messageNumber); + } + + @VisibleForTesting + protected static DescriptorProtos.FieldDescriptorProto.Builder createProtoFieldBuilder( + String fieldName, DescriptorProtos.FieldDescriptorProto.Label fieldLabel, int messageNumber, + DescriptorProtos.FieldDescriptorProto.Type fieldType) { + return DescriptorProtos.FieldDescriptorProto + .newBuilder() + .setName(fieldName) + .setLabel(fieldLabel) + .setNumber(messageNumber) + .setType(fieldType); + } + + private static DescriptorProtos.FieldDescriptorProto.Label toProtoFieldLabel(Field.Mode mode) { + switch (mode) { + case NULLABLE: + return DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL; + case REPEATED: + return DescriptorProtos.FieldDescriptorProto.Label.LABEL_REPEATED; + case REQUIRED: + return DescriptorProtos.FieldDescriptorProto.Label.LABEL_REQUIRED; + default: + throw new IllegalArgumentException("A BigQuery Field Mode was invalid: "+mode.name()); + } + } + + // NOTE: annotations for DATETIME and TIMESTAMP objects are currently unsupported for external users, + // but if they become available, it would be advisable to append an annotation to the protoFieldBuilder + // for these and other types. 
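+ // In the conversion below, BigQuery INTEGER, DATE, DATETIME and TIMESTAMP all map to proto int64;
+ // GEOGRAPHY, BYTES and NUMERIC map to proto bytes; FLOAT maps to double; BOOLEAN and STRING map
+ // directly. RECORD is handled by the caller, so reaching this method with a RECORD is an error.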
+ private static DescriptorProtos.FieldDescriptorProto.Type toProtoFieldType(LegacySQLTypeName bqType) { + DescriptorProtos.FieldDescriptorProto.Type protoFieldType; + if (LegacySQLTypeName.INTEGER.equals(bqType) || + LegacySQLTypeName.DATE.equals(bqType) || + LegacySQLTypeName.DATETIME.equals(bqType) || + LegacySQLTypeName.TIMESTAMP.equals(bqType)) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64; + } + if (LegacySQLTypeName.BOOLEAN.equals(bqType)){ + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_BOOL; + } + if (LegacySQLTypeName.STRING.equals(bqType)) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING; + } + if (LegacySQLTypeName.GEOGRAPHY.equals(bqType) || + LegacySQLTypeName.BYTES.equals(bqType) || + LegacySQLTypeName.NUMERIC.equals(bqType)) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES; + } + if (LegacySQLTypeName.FLOAT.equals(bqType)) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_DOUBLE; + } + else { + if (LegacySQLTypeName.RECORD.equals(bqType)) { + throw new IllegalStateException("Program attempted to return an atomic data-type for a RECORD"); + } + throw new IllegalArgumentException("Unexpected type: " + bqType.name()); + } + } + + + /** + * Spark Row --> ProtoRows converter utils: + * To be used by the DataWriters facing the BigQuery Storage Write API + */ + public static ProtoBufProto.ProtoRows toProtoRows(StructType sparkSchema, InternalRow[] rows) { + try { + Descriptors.Descriptor schemaDescriptor = toDescriptor(sparkSchema); + ProtoBufProto.ProtoRows.Builder protoRows = ProtoBufProto.ProtoRows.newBuilder(); + for (InternalRow row : rows) { + DynamicMessage rowMessage = createSingleRowMessage(sparkSchema, + schemaDescriptor, row); + protoRows.addSerializedRows(rowMessage.toByteString()); + } + return protoRows.build(); + } catch (Exception e) { + throw new RuntimeException("Could not convert Internal Rows to Proto Rows.", e); + } + } + + public static DynamicMessage createSingleRowMessage(StructType schema, + Descriptors.Descriptor schemaDescriptor, + InternalRow row) { + + DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(schemaDescriptor); + + for(int i = 1; i <= schemaDescriptor.getFields().size(); i++) { + StructField sparkField = schema.fields()[i-1]; + DataType sparkType = sparkField.dataType(); + if (sparkType instanceof StructType) { + messageBuilder.setField(schemaDescriptor.findFieldByNumber(i), + createSingleRowMessage((StructType)sparkType, + schemaDescriptor.findNestedTypeByName(RESERVED_NESTED_TYPE_NAME +i), + (InternalRow)row.get(i-1, sparkType))); + } + else { + messageBuilder.setField(schemaDescriptor.findFieldByNumber(i), + convert(sparkField, + row.get(i-1, sparkType))); + } + } + + return messageBuilder.build(); + } + + public static Descriptors.Descriptor toDescriptor (StructType schema) + throws Descriptors.DescriptorValidationException { + DescriptorProtos.DescriptorProto.Builder descriptorBuilder = DescriptorProtos.DescriptorProto.newBuilder() + .setName("Schema"); + + StructField[] fields = schema.fields(); + + DescriptorProtos.DescriptorProto descriptorProto = buildDescriptorProtoWithFields(descriptorBuilder, fields, 0); + + return createDescriptorFromProto(descriptorProto); + } + + @VisibleForTesting + protected static Object convert (StructField sparkField, Object sparkValue) { + if (sparkValue == null) { + if (!sparkField.nullable()) { + throw new IllegalArgumentException("Non-nullable field was null."); + } + else { + return null; + } + } + + DataType 
fieldType = sparkField.dataType(); + + if (fieldType instanceof ArrayType) { + ArrayType arrayType = (ArrayType)fieldType; + boolean containsNull = arrayType.containsNull(); // elements can be null. + DataType elementType = arrayType.elementType(); + + ArrayData arrayData = (ArrayData)sparkValue; + Object[] sparkValues = arrayData.toObjectArray(elementType); + + return Arrays.stream(sparkValues).map(value -> { + Preconditions.checkArgument(containsNull || value != null, + "Encountered a null value inside a non-null-containing array."); + return toAtomicProtoRowValue(elementType, value); + } ).collect(Collectors.toList()); + } else if (fieldType instanceof StructType) { + throw new IllegalStateException("Method did not expect to convert a StructType instance."); + } else if (fieldType instanceof MapType) { + throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); + } + else { + return toAtomicProtoRowValue(fieldType, sparkValue); + } + } + + /* + Takes a value in Spark format and converts it into ProtoRows format (to eventually be given to BigQuery). + */ + private static Object toAtomicProtoRowValue(DataType sparkType, Object value) { + if (sparkType instanceof ByteType || + sparkType instanceof ShortType || + sparkType instanceof IntegerType || + sparkType instanceof LongType || + sparkType instanceof TimestampType || + sparkType instanceof DateType) { + return ((Number)value).longValue(); + } + + if (sparkType instanceof FloatType || + sparkType instanceof DoubleType || + sparkType instanceof DecimalType) { + return ((Number)value).doubleValue(); // TODO: should decimal be converted to double? Or a Bytes type containing extra width? + } + + if (sparkType instanceof BooleanType) { + return ((Boolean)value).booleanValue(); // TODO: can be unboxed? + } + + if (sparkType instanceof StringType || + sparkType instanceof BinaryType) { + return value; // TODO: verify correct method for extracting this value. + } + + throw new IllegalStateException("Unexpected type: " + sparkType); + } + + private static DescriptorProtos.DescriptorProto buildDescriptorProtoWithFields( + DescriptorProtos.DescriptorProto.Builder descriptorBuilder, StructField[] fields, int depth) { + Preconditions.checkArgument(depth < MAX_BIGQUERY_NESTED_DEPTH, + "Spark Schema exceeds BigQuery maximum nesting depth."); + int messageNumber = 1; + for (StructField field : fields) { + String fieldName = field.name(); + DescriptorProtos.FieldDescriptorProto.Label fieldLabel = field.nullable() ? + DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL : + DescriptorProtos.FieldDescriptorProto.Label.LABEL_REQUIRED; + DescriptorProtos.FieldDescriptorProto.Type fieldType; + + DataType sparkType = field.dataType(); + if (sparkType instanceof MapType) { + throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); + } + if (sparkType instanceof StructType) { + StructType structType = (StructType)sparkType; + String nestedName = RESERVED_NESTED_TYPE_NAME +messageNumber; // TODO: this should be a reserved name. No column can have this name. 
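+ // A nested StructType becomes a nested message type named RESERVED_NESTED_TYPE_NAME + messageNumber
+ // ("STRUCT1", "STRUCT2", ...); the enclosing field is then declared with that type name instead of
+ // an atomic proto type.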
+ StructField[] subFields = structType.fields(); + + DescriptorProtos.DescriptorProto.Builder nestedFieldTypeBuilder = descriptorBuilder.addNestedTypeBuilder() + .setName(nestedName); + buildDescriptorProtoWithFields(nestedFieldTypeBuilder, subFields, depth+1); + + descriptorBuilder.addField(createProtoFieldBuilder(fieldName, fieldLabel, messageNumber) + .setTypeName(nestedName)); + messageNumber++; + continue; + } + + if (sparkType instanceof ArrayType) { + ArrayType arrayType = (ArrayType)sparkType; + /* DescriptorProtos.FieldDescriptorProto.Label elementLabel = arrayType.containsNull() ? + DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL : + DescriptorProtos.FieldDescriptorProto.Label.LABEL_REQUIRED; TODO: how to support null instances inside an array (repeated field) in BigQuery?*/ + fieldType = sparkAtomicTypeToProtoFieldType(arrayType.elementType()); + fieldLabel = DescriptorProtos.FieldDescriptorProto.Label.LABEL_REPEATED; + + } else { + fieldType = sparkAtomicTypeToProtoFieldType(sparkType); + } + descriptorBuilder.addField( + createProtoFieldBuilder(fieldName, fieldLabel, messageNumber, fieldType)); + messageNumber++; + } + return descriptorBuilder.build(); + } + + // NOTE: annotations for DATETIME and TIMESTAMP objects are currently unsupported for external users, + // but if they become available, it would be advisable to append an annotation to the protoFieldBuilder + // for these and other types. + // This function only converts atomic Spark DataTypes (MapType, ArrayType, and StructType will throw an error). + private static DescriptorProtos.FieldDescriptorProto.Type sparkAtomicTypeToProtoFieldType(DataType sparkType) { + if (sparkType instanceof ByteType || + sparkType instanceof ShortType || + sparkType instanceof IntegerType || + sparkType instanceof LongType || + sparkType instanceof TimestampType || + sparkType instanceof DateType) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64; + } + + if (sparkType instanceof FloatType || + sparkType instanceof DoubleType || + sparkType instanceof DecimalType) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_DOUBLE; + // TODO: an annotation to distinguish between decimals that are doubles, and decimals that are NUMERIC (Bytes types) + } + + if (sparkType instanceof BooleanType) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_BOOL; + } + + if (sparkType instanceof BinaryType) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES; + } + + if (sparkType instanceof StringType) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING; + } + + throw new IllegalStateException("Unexpected type: " + sparkType); + } +} diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java b/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java index 4aa646aa5d..b0598e3845 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java @@ -15,10 +15,9 @@ */ package com.google.cloud.spark.bigquery; -import com.google.cloud.bigquery.Field; -import com.google.cloud.bigquery.FieldList; -import com.google.cloud.bigquery.LegacySQLTypeName; -import com.google.cloud.bigquery.Schema; +import avro.shaded.com.google.common.base.Preconditions; +import com.google.cloud.bigquery.*; +import com.google.common.annotations.VisibleForTesting; import org.apache.avro.generic.GenericRecord; import org.apache.avro.util.Utf8; import 
org.apache.spark.sql.catalyst.InternalRow; @@ -30,9 +29,7 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.stream.Collectors; public class SchemaConverters { @@ -42,6 +39,9 @@ public class SchemaConverters { private static final int BQ_NUMERIC_SCALE = 9; private static final DecimalType NUMERIC_SPARK_TYPE = DataTypes.createDecimalType(BQ_NUMERIC_PRECISION, BQ_NUMERIC_SCALE); + // The maximum nesting depth of a BigQuery RECORD: + private static final int MAX_BIGQUERY_NESTED_DEPTH = 15; + private static final String MAPTYPE_ERROR_MESSAGE = "MapType is unsupported."; /** Convert a BigQuery schema to a Spark schema */ public static StructType toSpark(Schema schema) { @@ -204,4 +204,104 @@ private static DataType getDataType(Field field) { throw new IllegalStateException("Unexpected type: " + field.getType()); } } + + + /** + * Spark ==> BigQuery Schema Converter utils: + */ + public static Schema toBigQuerySchema (StructType sparkSchema) { + FieldList bigQueryFields = sparkToBigQueryFields(sparkSchema, 0); + return Schema.of(bigQueryFields); + } + + /** + * Returns a FieldList of all the Spark StructField objects, converted to BigQuery Field objects + */ + private static FieldList sparkToBigQueryFields (StructType sparkStruct, int depth){ + Preconditions.checkArgument(depth < MAX_BIGQUERY_NESTED_DEPTH, + "Spark Schema exceeds BigQuery maximum nesting depth."); + List bqFields = new ArrayList<>(); + for (StructField field : sparkStruct.fields()){ + bqFields.add(makeBigQueryColumn(field, depth)); + } + return FieldList.of(bqFields); + } + + /** + * Converts a single StructField to a BigQuery Field (column). + */ + @VisibleForTesting + protected static Field makeBigQueryColumn (StructField sparkField, int depth) { + DataType sparkType = sparkField.dataType(); + String fieldName = sparkField.name(); + Field.Mode fieldMode = (sparkField.nullable()) ? 
Field.Mode.NULLABLE : Field.Mode.REQUIRED; + String description; + FieldList subFields = null; + LegacySQLTypeName fieldType; + + if (sparkType instanceof ArrayType) { + ArrayType arrayType = (ArrayType)sparkType; + LegacySQLTypeName elementType = toBigQueryType(arrayType.elementType()); + fieldType = elementType; + fieldMode = Field.Mode.REPEATED; + } + else if (sparkType instanceof MapType) { + throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); + } + else if (sparkType instanceof StructType) { + subFields = sparkToBigQueryFields((StructType)sparkType, depth+1); + fieldType = LegacySQLTypeName.RECORD; + } + else { + fieldType = toBigQueryType(sparkType); + } + + try { + description = sparkField.metadata().getString("description"); + } + catch (NoSuchElementException e) { + return createBigQueryFieldBuilder(fieldName, fieldType, fieldMode, subFields).build(); + } + + return createBigQueryFieldBuilder(fieldName, fieldType, fieldMode, subFields) + .setDescription(description).build(); + } + + @VisibleForTesting + protected static LegacySQLTypeName toBigQueryType (DataType elementType) { + if (elementType instanceof BinaryType) { + return LegacySQLTypeName.BYTES; + } if (elementType instanceof ByteType || + elementType instanceof ShortType || + elementType instanceof IntegerType || + elementType instanceof LongType) { + return LegacySQLTypeName.INTEGER; + } if (elementType instanceof BooleanType) { + return LegacySQLTypeName.BOOLEAN; + } if (elementType instanceof FloatType || + elementType instanceof DoubleType) { + return LegacySQLTypeName.FLOAT; + } if (elementType instanceof DecimalType) { + DecimalType decimalType = (DecimalType)elementType; + if (decimalType.precision() <= BQ_NUMERIC_PRECISION && + decimalType.scale() <= BQ_NUMERIC_SCALE) { + return LegacySQLTypeName.NUMERIC; + } else { + throw new IllegalArgumentException("Decimal type is too wide to fit in BigQuery Numeric format"); // TODO + } + } if (elementType instanceof StringType) { + return LegacySQLTypeName.STRING; + } if (elementType instanceof TimestampType) { + return LegacySQLTypeName.TIMESTAMP; + } if (elementType instanceof DateType) { // TODO: TIME & DATETIME in BigQuery + return LegacySQLTypeName.DATE; + } else { + throw new IllegalArgumentException("Data type not expected in toBQType: "+elementType.simpleString()); + } + } + + private static Field.Builder createBigQueryFieldBuilder (String name, LegacySQLTypeName type, Field.Mode mode, FieldList subFields){ + return Field.newBuilder(name, type, subFields) + .setMode(mode); + } } diff --git a/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java b/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java new file mode 100644 index 0000000000..b9e81b2dca --- /dev/null +++ b/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java @@ -0,0 +1,403 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.cloud.spark.bigquery; + +import com.google.cloud.bigquery.Field; +import com.google.cloud.bigquery.FieldList; +import com.google.cloud.bigquery.LegacySQLTypeName; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.storage.v1alpha2.ProtoBufProto; +import com.google.cloud.bigquery.storage.v1alpha2.ProtoSchemaConverter; +import com.google.protobuf.DescriptorProtos; +import com.google.protobuf.Descriptors; +import com.google.protobuf.DynamicMessage; +import org.apache.log4j.Level; +import org.apache.log4j.LogManager; +import org.apache.log4j.Logger; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.AssumptionViolatedException; +import org.junit.Test; + +import static com.google.cloud.spark.bigquery.ProtobufUtils.*; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +public class ProtobufUtilsTest { + + private final Logger logger = LogManager.getLogger("com.google.cloud.spark"); + + @Test + public void testBigQueryRecordToDescriptor() throws Exception { + logger.setLevel(Level.DEBUG); + + DescriptorProtos.DescriptorProto expected = NESTED_STRUCT_DESCRIPTOR.setName("Struct").build(); + DescriptorProtos.DescriptorProto converted = buildDescriptorProtoWithFields(DescriptorProtos.DescriptorProto.newBuilder() + .setName("Struct"), BIGQUERY_NESTED_STRUCT_FIELD.getSubFields(), 0); + + assertThat(converted).isEqualTo(expected); + } + + @Test + public void testBigQueryToProtoSchema() throws Exception { + logger.setLevel(Level.DEBUG); + + ProtoBufProto.ProtoSchema converted = toProtoSchema(BIG_BIGQUERY_SCHEMA); + ProtoBufProto.ProtoSchema expected = ProtoSchemaConverter.convert( + Descriptors.FileDescriptor.buildFrom( + DescriptorProtos.FileDescriptorProto.newBuilder() + .addMessageType( + DescriptorProtos.DescriptorProto.newBuilder() + .addField(PROTO_INTEGER_FIELD.clone().setNumber(1)) + .addField(PROTO_STRING_FIELD.clone().setNumber(2)) + .addField(PROTO_ARRAY_FIELD.clone().setNumber(3)) + .addNestedType(NESTED_STRUCT_DESCRIPTOR.clone()) + .addField(PROTO_STRUCT_FIELD.clone().setNumber(4)) + .addField(PROTO_BYTES_FIELD.clone().setName("Geography").setNumber(5)) + .addField(PROTO_DOUBLE_FIELD.clone().setName("Float").setNumber(6)) + .addField(PROTO_BOOLEAN_FIELD.clone().setNumber(7)) + .setName("Schema").build() + ).build(), new Descriptors.FileDescriptor[]{} + ).getMessageTypes().get(0) + ); + + logger.debug("Expected schema: "+expected.getProtoDescriptor()); + logger.debug("Actual schema: "+converted.getProtoDescriptor()); + + for(int i = 0; i < 7; i++){ + assertThat(converted.getProtoDescriptor().getField(i)).isEqualTo(expected.getProtoDescriptor().getField(i)); + } + } + + @Test + public void testSparkIntegerSchemaToDescriptor() throws Exception { + logger.setLevel(Level.DEBUG); + + StructType schema = new StructType().add(SPARK_INTEGER_FIELD); + DescriptorProtos.DescriptorProto converted = toDescriptor(schema).toProto(); + + DescriptorProtos.DescriptorProto expected = DESCRIPTOR_PROTO_INTEGER; + + assertThat(converted).isEqualTo(expected); + } + + @Test + public void testSparkStringSchemaToDescriptor() throws Exception { + logger.setLevel(Level.DEBUG); + + StructType schema = new 
StructType().add(SPARK_STRING_FIELD); + DescriptorProtos.DescriptorProto converted = toDescriptor(schema).toProto(); + + DescriptorProtos.DescriptorProto expected = DESCRIPTOR_PROTO_STRING; + + assertThat(converted).isEqualTo(expected); + } + + @Test + public void testSparkArraySchemaToDescriptor() throws Exception { + logger.setLevel(Level.DEBUG); + + StructType schema = new StructType().add(SPARK_ARRAY_FIELD); + DescriptorProtos.DescriptorProto converted = toDescriptor(schema).toProto(); + + DescriptorProtos.DescriptorProto expected = DESCRIPTOR_PROTO_ARRAY; + + assertThat(converted).isEqualTo(expected); + } + + @Test + public void testSparkNestedStructSchemaToDescriptor() throws Exception { + logger.setLevel(Level.DEBUG); + + StructType schema = new StructType().add(SPARK_NESTED_STRUCT_FIELD); + DescriptorProtos.DescriptorProto converted = toDescriptor(schema).toProto(); + + DescriptorProtos.DescriptorProto expected = DESCRIPTOR_PROTO_STRUCT; + + assertThat(converted).isEqualTo(expected); + } + + @Test + public void testSparkArrayRowToDynamicMessage() throws Exception { + logger.setLevel(Level.DEBUG); + + StructType schema = new StructType().add(SPARK_ARRAY_FIELD); + DynamicMessage converted = createSingleRowMessage(schema, toDescriptor(schema), + ARRAY_INTERNAL_ROW); + DynamicMessage expected = ARRAY_ROW_MESSAGE; + + assertThat(converted.toString()).isEqualTo(expected.toString()); + } + + @Test + public void testSparkStructRowToDynamicMessage() throws Exception { + logger.setLevel(Level.DEBUG); + + StructType schema = new StructType().add(SPARK_NESTED_STRUCT_FIELD); + DynamicMessage converted = createSingleRowMessage(schema, toDescriptor(schema), + STRUCT_INTERNAL_ROW); + DynamicMessage expected = StructRowMessage; + + assertThat(converted.toString()).isEqualTo(expected.toString()); + } + + @Test + public void testSparkRowToProtoRow() throws Exception { + logger.setLevel(Level.DEBUG); + + ProtoBufProto.ProtoRows converted = toProtoRows(BIG_SPARK_SCHEMA, + new InternalRow[]{ + new GenericInternalRow(new Object[]{ + 1, + "A", + ArrayData.toArrayData(new int[]{0,1,2}), + INTERNAL_STRUCT_DATA, + 3.14, + true + })} + ); + + ProtoBufProto.ProtoRows expected = MyProtoRows; + + assertThat(converted.getSerializedRows(0).toByteArray()).isEqualTo(expected.getSerializedRows(0).toByteArray()); + } + + @Test + public void testSettingARequiredFieldAsNull() throws Exception { + logger.setLevel(Level.DEBUG); + + try { + convert(SPARK_STRING_FIELD, null); + fail("Convert did not assert field's /'Required/' status"); + } catch (IllegalArgumentException e){} + try { + convert(new StructField("String", DataTypes.StringType, true, Metadata.empty()), + null); + } catch (Exception e) { + fail("A nullable field could not be set to null."); + } + } + + + + private final StructType MY_STRUCT = DataTypes.createStructType( + new StructField[]{new StructField("Number", DataTypes.IntegerType, + true, Metadata.empty()), + new StructField("String", DataTypes.StringType, + true, Metadata.empty())}); + + private final StructField SPARK_INTEGER_FIELD = new StructField("Number", DataTypes.IntegerType, + true, Metadata.empty()); + private final StructField SPARK_STRING_FIELD = new StructField("String", DataTypes.StringType, + false, Metadata.empty()); + private final StructField SPARK_NESTED_STRUCT_FIELD = new StructField("Struct", MY_STRUCT, + true, Metadata.empty()); + private final StructField SPARK_ARRAY_FIELD = new StructField("Array", + DataTypes.createArrayType(DataTypes.IntegerType), + true, Metadata.empty()); + 
private final StructField SPARK_DOUBLE_FIELD = new StructField("Double", DataTypes.DoubleType, + true, Metadata.empty()); + private final StructField SPARK_BOOLEAN_FIELD = new StructField("Boolean", DataTypes.BooleanType, + true, Metadata.empty()); + + private final StructType BIG_SPARK_SCHEMA = new StructType() + .add(SPARK_INTEGER_FIELD) + .add(SPARK_STRING_FIELD) + .add(SPARK_ARRAY_FIELD) + .add(SPARK_NESTED_STRUCT_FIELD) + .add(SPARK_DOUBLE_FIELD) + .add(SPARK_BOOLEAN_FIELD); + + + private final Field BIGQUERY_INTEGER_FIELD = Field.newBuilder("Number", LegacySQLTypeName.INTEGER, + (FieldList)null).setMode(Field.Mode.NULLABLE).build(); + private final Field BIGQUERY_STRING_FIELD = Field.newBuilder("String", LegacySQLTypeName.STRING, (FieldList) null) + .setMode(Field.Mode.REQUIRED).build(); + private final Field BIGQUERY_NESTED_STRUCT_FIELD = Field.newBuilder("Struct", LegacySQLTypeName.RECORD, + Field.newBuilder("Number", LegacySQLTypeName.INTEGER, (FieldList) null) + .setMode(Field.Mode.NULLABLE).build(), + Field.newBuilder("String", LegacySQLTypeName.STRING, (FieldList) null) + .setMode(Field.Mode.NULLABLE).build()) + .setMode(Field.Mode.NULLABLE).build(); + private final Field BIGQUERY_ARRAY_FIELD = Field.newBuilder("Array", LegacySQLTypeName.INTEGER, (FieldList) null) + .setMode(Field.Mode.REPEATED).build(); + private final Field BIGQUERY_GEOGRAPHY_FIELD = Field.newBuilder("Geography", LegacySQLTypeName.GEOGRAPHY, (FieldList)null) + .setMode(Field.Mode.NULLABLE).build(); + private final Field BIGQUERY_FLOAT_FIELD = Field.newBuilder("Float", LegacySQLTypeName.FLOAT, (FieldList)null) + .setMode(Field.Mode.NULLABLE).build(); + private final Field BIGQUERY_BOOLEAN_FIELD = Field.newBuilder("Boolean", LegacySQLTypeName.BOOLEAN, (FieldList)null) + .setMode(Field.Mode.NULLABLE).build(); + + private final Schema BIG_BIGQUERY_SCHEMA = Schema.of(BIGQUERY_INTEGER_FIELD, BIGQUERY_STRING_FIELD, BIGQUERY_ARRAY_FIELD, BIGQUERY_NESTED_STRUCT_FIELD, + BIGQUERY_GEOGRAPHY_FIELD, BIGQUERY_FLOAT_FIELD, BIGQUERY_BOOLEAN_FIELD); + + + private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_INTEGER_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + .setName("Number") + .setNumber(1) + .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64) + .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL); + private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_STRING_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + .setName("String") + .setNumber(1) + .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING) + .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_REQUIRED); + private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_ARRAY_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + .setName("Array") + .setNumber(1) + .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64) + .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_REPEATED); + private final DescriptorProtos.DescriptorProto.Builder NESTED_STRUCT_DESCRIPTOR = DescriptorProtos.DescriptorProto.newBuilder() + .setName("STRUCT1") + .addField(PROTO_INTEGER_FIELD.clone()) + .addField(PROTO_STRING_FIELD.clone().setNumber(2) + .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL)); + private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_STRUCT_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + .setName("Struct") + .setNumber(1) + .setTypeName("STRUCT1") + 
.setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL); + private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_BYTES_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + .setName("Bytes") + .setNumber(1) + .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES) + .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL); + private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_DOUBLE_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + .setName("Double") + .setNumber(1) + .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_DOUBLE) + .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL); + private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_BOOLEAN_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + .setName("Boolean") + .setNumber(1) + .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_BOOL) + .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL); + + + private final DescriptorProtos.DescriptorProto DESCRIPTOR_PROTO_INTEGER = DescriptorProtos.DescriptorProto.newBuilder() + .addField(PROTO_INTEGER_FIELD).setName("Schema").build(); + private final DescriptorProtos.DescriptorProto DESCRIPTOR_PROTO_STRING = DescriptorProtos.DescriptorProto.newBuilder() + .addField(PROTO_STRING_FIELD).setName("Schema").build(); + private final DescriptorProtos.DescriptorProto DESCRIPTOR_PROTO_ARRAY = DescriptorProtos.DescriptorProto.newBuilder() + .addField(PROTO_ARRAY_FIELD).setName("Schema").build(); + private final DescriptorProtos.DescriptorProto DESCRIPTOR_PROTO_STRUCT = DescriptorProtos.DescriptorProto.newBuilder() + .addNestedType(NESTED_STRUCT_DESCRIPTOR).addField(PROTO_STRUCT_FIELD).setName("Schema").build(); + + private final InternalRow INTEGER_INTERNAL_ROW = new GenericInternalRow(new Object[]{1}); + private final InternalRow STRING_INTERNAL_ROW = new GenericInternalRow(new Object[]{"A"}); + private final InternalRow ARRAY_INTERNAL_ROW = new GenericInternalRow(new Object[]{ArrayData.toArrayData( + new int[]{0,1,2})}); + private final InternalRow INTERNAL_STRUCT_DATA = new GenericInternalRow(new Object[]{1, "A"}); + private final InternalRow STRUCT_INTERNAL_ROW = new GenericInternalRow(new Object[]{INTERNAL_STRUCT_DATA}); + + + private Descriptors.Descriptor INTEGER_SCHEMA_DESCRIPTOR = createIntegerSchemaDescriptor(); + private Descriptors.Descriptor createIntegerSchemaDescriptor() { + try { + return toDescriptor( + new StructType().add(SPARK_INTEGER_FIELD) + ); + } catch (Descriptors.DescriptorValidationException e) { + throw new AssumptionViolatedException("Could not create INTEGER_SCHEMA_DESCRIPTOR", e); + } + } + private Descriptors.Descriptor STRING_SCHEMA_DESCRIPTOR = createStringSchemaDescriptor(); + private Descriptors.Descriptor createStringSchemaDescriptor() { + try { + return toDescriptor( + new StructType().add(SPARK_STRING_FIELD) + ); + } catch (Descriptors.DescriptorValidationException e) { + throw new AssumptionViolatedException("Could not create STRING_SCHEMA_DESCRIPTOR", e); + } + } + private Descriptors.Descriptor ARRAY_SCHEMA_DESCRIPTOR = createArraySchemaDescriptor(); + private Descriptors.Descriptor createArraySchemaDescriptor() { + try { + return toDescriptor( + new StructType().add(SPARK_ARRAY_FIELD) + ); + } catch (Descriptors.DescriptorValidationException e) { + throw new AssumptionViolatedException("Could not create ARRAY_SCHEMA_DESCRIPTOR", e); + } + } + private Descriptors.Descriptor STRUCT_SCHEMA_DESCRIPTOR = 
createStructSchemaDescriptor(); + private Descriptors.Descriptor createStructSchemaDescriptor() { + try { + return toDescriptor( + new StructType().add(SPARK_NESTED_STRUCT_FIELD) + ); + } catch (Descriptors.DescriptorValidationException e) { + throw new AssumptionViolatedException("Could not create STRUCT_SCHEMA_DESCRIPTOR", e); + } + } + + + private final DynamicMessage INTEGER_ROW_MESSAGE = DynamicMessage.newBuilder(INTEGER_SCHEMA_DESCRIPTOR) + .setField(INTEGER_SCHEMA_DESCRIPTOR.findFieldByNumber(1), 1L).build(); + private final DynamicMessage STRING_ROW_MESSAGE = DynamicMessage.newBuilder(STRING_SCHEMA_DESCRIPTOR) + .setField(STRING_SCHEMA_DESCRIPTOR.findFieldByNumber(1), "A").build(); + private final DynamicMessage ARRAY_ROW_MESSAGE = DynamicMessage.newBuilder(ARRAY_SCHEMA_DESCRIPTOR) + .addRepeatedField(ARRAY_SCHEMA_DESCRIPTOR.findFieldByNumber(1), 0L) + .addRepeatedField(ARRAY_SCHEMA_DESCRIPTOR.findFieldByNumber(1), 1L) + .addRepeatedField(ARRAY_SCHEMA_DESCRIPTOR.findFieldByNumber(1), 2L).build(); + private DynamicMessage StructRowMessage = createStructRowMessage(); + private DynamicMessage createStructRowMessage() { + try{ + return DynamicMessage.newBuilder(STRUCT_SCHEMA_DESCRIPTOR) + .setField(STRUCT_SCHEMA_DESCRIPTOR.findFieldByNumber(1), createSingleRowMessage( + MY_STRUCT, toDescriptor(MY_STRUCT), INTERNAL_STRUCT_DATA + )).build(); + } catch (Descriptors.DescriptorValidationException e) { + throw new AssumptionViolatedException("Could not create STRUCT_ROW_MESSAGE", e); + } + } + + + private Descriptors.Descriptor BIG_SCHEMA_ROW_DESCRIPTOR = createBigSchemaRowDescriptor(); + private Descriptors.Descriptor createBigSchemaRowDescriptor() { + try { + return toDescriptor(BIG_SPARK_SCHEMA); + } catch (Descriptors.DescriptorValidationException e) { + throw new AssumptionViolatedException("Could not create BIG_SCHEMA_ROW_DESCRIPTOR", e); + } + } + private ProtoBufProto.ProtoRows MyProtoRows = createMyProtoRows(); + private ProtoBufProto.ProtoRows createMyProtoRows() { + try { + return ProtoBufProto.ProtoRows.newBuilder().addSerializedRows( + DynamicMessage.newBuilder(BIG_SCHEMA_ROW_DESCRIPTOR) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(1), 1L) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(2), "A") + .addRepeatedField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(3), 0L) + .addRepeatedField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(3), 1L) + .addRepeatedField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(3), 2L) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(4), + createSingleRowMessage( + MY_STRUCT, toDescriptor(MY_STRUCT), INTERNAL_STRUCT_DATA)) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(5), 3.14) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(6), true) + .build().toByteString()).build(); + } catch (Descriptors.DescriptorValidationException e) { + throw new AssumptionViolatedException("Could not create MY_PROTO_ROWS", e); + } + } +} diff --git a/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java b/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java new file mode 100644 index 0000000000..e721403f80 --- /dev/null +++ b/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java @@ -0,0 +1,279 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.spark.bigquery; + +import com.google.cloud.bigquery.Field; +import com.google.cloud.bigquery.FieldList; +import com.google.cloud.bigquery.LegacySQLTypeName; +import com.google.cloud.bigquery.Schema; +import org.apache.log4j.Level; +import org.apache.log4j.LogManager; +import org.apache.log4j.Logger; +import org.apache.spark.sql.types.*; +import org.junit.Test; + +import static com.google.cloud.spark.bigquery.SchemaConverters.*; +import static com.google.common.truth.Truth.*; +import static org.junit.Assert.fail; + +public class SchemaConverterTest { + + // Numeric is a fixed precision Decimal Type with 38 digits of precision and 9 digits of scale. + // See https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#numeric-type + private final static int BQ_NUMERIC_PRECISION = 38; + private final static int BQ_NUMERIC_SCALE = 9; + private final static DecimalType NUMERIC_SPARK_TYPE = DataTypes.createDecimalType( + BQ_NUMERIC_PRECISION, BQ_NUMERIC_SCALE); + // The maximum nesting depth of a BigQuery RECORD: + private static final int MAX_BIGQUERY_NESTED_DEPTH = 15; + + private final Logger logger = LogManager.getLogger("com.google.cloud.spark"); + + /* + BigQuery -> Spark tests, translated from SchemaConvertersSuite.scala + */ + @Test + public void testEmptySchemaBigQueryToSparkConversion() throws Exception { + Schema bqSchema = Schema.of(); + StructType expected = new StructType(); + StructType result = SchemaConverters.toSpark(bqSchema); + assertThat(result).isEqualTo(expected); + } + + @Test + public void testSingleFieldSchemaBigQueryToSparkConversion() throws Exception { + Schema bqSchema = Schema.of(Field.of("foo", LegacySQLTypeName.STRING)); + StructType expected = new StructType() + .add(new StructField("foo", DataTypes.StringType, true, Metadata.empty())); + StructType result = SchemaConverters.toSpark(bqSchema); + assertThat(result).isEqualTo(expected); + } + + @Test + public void testFullFieldSchemaBigQueryToSparkConversion() throws Exception { + Schema bqSchema = BIG_BIGQUERY_SCHEMA2; + + StructType expected = BIG_SPARK_SCHEMA2; + + StructType result = SchemaConverters.toSpark(bqSchema); + assertThat(result).isEqualTo(expected); + } + + @Test + public void testFieldHasDescriptionBigQueryToSpark() throws Exception { + Schema bqSchema = Schema.of( + Field.newBuilder("name", LegacySQLTypeName.STRING) + .setDescription("foo") + .setMode(Field.Mode.NULLABLE) + .build()); + StructType expected = new StructType() + .add(new StructField("name", DataTypes.StringType, true, + (new MetadataBuilder()).putString("description", "foo").build())); + + StructType result = SchemaConverters.toSpark(bqSchema); + assertThat(result).isEqualTo(expected); + } + + @Test + public void testSparkStructFieldToBigQuery() throws Exception { + logger.setLevel(Level.DEBUG); + + Field expected = BIGQUERY_NESTED_STRUCT_FIELD; + Field converted = makeBigQueryColumn(SPARK_NESTED_STRUCT_FIELD, 0); + + assertThat(converted).isEqualTo(expected); + } + + @Test + public void testSparkToBQSchema() throws Exception { + logger.setLevel(Level.DEBUG); + + StructType 
schema = BIG_SPARK_SCHEMA; + Schema expected = BIG_BIGQUERY_SCHEMA; + + Schema converted = toBigQuerySchema(schema); + + for(int i = 0; i < converted.getFields().size(); i++){ + assertThat(converted.getFields().get(i)).isEqualTo(expected.getFields().get(i)); + } + } + + @Test + public void testSparkMapException() throws Exception { + logger.setLevel(Level.DEBUG); + + try { + makeBigQueryColumn(SPARK_MAP_FIELD, 0); + fail("Did not throw an error for an unsupported map-type"); + } catch (IllegalArgumentException e) {} + } + + @Test + public void testDecimalTypeConversion() throws Exception { + logger.setLevel(Level.DEBUG); + + assertThat(toBigQueryType(NUMERIC_SPARK_TYPE)).isEqualTo(LegacySQLTypeName.NUMERIC); + + try { + DecimalType wayTooBig = DataTypes.createDecimalType(38,38); + toBigQueryType(wayTooBig); + fail("Did not throw an error for a decimal that's too wide for big-query"); + } catch (IllegalArgumentException e) {} + } + + @Test + public void testTimeTypesConversions() throws Exception { + logger.setLevel(Level.DEBUG); + + assertThat(toBigQueryType(DataTypes.TimestampType)).isEqualTo(LegacySQLTypeName.TIMESTAMP); + assertThat(toBigQueryType(DataTypes.DateType)).isEqualTo(LegacySQLTypeName.DATE); + } + + @Test + public void testDescriptionConversion() throws Exception { + logger.setLevel(Level.DEBUG); + + String description = "I love bananas"; + Field result = makeBigQueryColumn(new StructField("Field", DataTypes.IntegerType, + true, new MetadataBuilder().putString("description", description).build()), 0); + + assertThat(result.getDescription().equals(description)); + } + + @Test + public void testMaximumNestingDepthError() throws Exception { + logger.setLevel(Level.DEBUG); + + StructType inner = new StructType(); + StructType superRecursiveSchema = inner; + for(int i = 0; i < MAX_BIGQUERY_NESTED_DEPTH+1; i++){ + StructType outer = new StructType() + .add(new StructField("struct"+i, superRecursiveSchema, true, Metadata.empty())); + superRecursiveSchema = outer; + } + + try { + makeBigQueryColumn(superRecursiveSchema.fields()[0], 0); + fail("Did not detect super-recursive schema of depth = 16."); + } + catch (IllegalArgumentException e) {} + } + + + + + public final StructType MY_STRUCT = DataTypes.createStructType( + new StructField[]{new StructField("Number", DataTypes.IntegerType, + true, Metadata.empty()), + new StructField("String", DataTypes.StringType, + true, Metadata.empty())}); + + public final StructField SPARK_INTEGER_FIELD = new StructField("Number", DataTypes.IntegerType, + true, Metadata.empty()); + public final StructField SPARK_STRING_FIELD = new StructField("String", DataTypes.StringType, + false, Metadata.empty()); + public final StructField SPARK_NESTED_STRUCT_FIELD = new StructField("Struct", MY_STRUCT, + true, Metadata.empty()); + public final StructField SPARK_ARRAY_FIELD = new StructField("Array", + DataTypes.createArrayType(DataTypes.IntegerType), + true, Metadata.empty()); + public final StructField SPARK_MAP_FIELD = new StructField("Map", + DataTypes.createMapType(DataTypes.IntegerType, DataTypes.StringType), + true, Metadata.empty()); + public final StructField SPARK_DOUBLE_FIELD = new StructField("Float", DataTypes.DoubleType, + true, Metadata.empty()); + public final StructField SPARK_BOOLEAN_FIELD = new StructField("Boolean", DataTypes.BooleanType, + true, Metadata.empty()); + public final StructField SPARK_NUMERIC_FIELD = new StructField("Numeric", NUMERIC_SPARK_TYPE, + true, Metadata.empty()); + + public final StructType BIG_SPARK_SCHEMA = new 
StructType() + .add(SPARK_INTEGER_FIELD) + .add(SPARK_STRING_FIELD) + .add(SPARK_ARRAY_FIELD) + .add(SPARK_NESTED_STRUCT_FIELD) + .add(SPARK_DOUBLE_FIELD) + .add(SPARK_BOOLEAN_FIELD) + .add(SPARK_NUMERIC_FIELD); + + public final StructType BIG_SPARK_SCHEMA2 = new StructType() + .add(new StructField("foo", DataTypes.StringType,true, Metadata.empty())) + .add(new StructField("bar", DataTypes.LongType, true, Metadata.empty())) + .add(new StructField("required", DataTypes.BooleanType, false, Metadata.empty())) + .add(new StructField("binary_arr", DataTypes.createArrayType(DataTypes.BinaryType, true), + true, Metadata.empty())) + .add(new StructField("float", DataTypes.DoubleType, true, Metadata.empty())) + .add(new StructField("numeric", DataTypes.createDecimalType(38, 9), + true, Metadata.empty())) + .add(new StructField("date", DataTypes.DateType, true, Metadata.empty())) + .add(new StructField("times", new StructType() + .add(new StructField("time", DataTypes.LongType, true, Metadata.empty())) + .add(new StructField("timestamp", DataTypes.TimestampType, true, Metadata.empty())) + .add(new StructField("datetime", DataTypes.StringType, true, Metadata.empty())), + true, Metadata.empty())); + + public final Schema BIG_BIGQUERY_SCHEMA2 = Schema.of( + Field.of("foo", LegacySQLTypeName.STRING), + Field.of("bar", LegacySQLTypeName.INTEGER), + Field.newBuilder("required", LegacySQLTypeName.BOOLEAN).setMode(Field.Mode.REQUIRED).build(), + Field.newBuilder("binary_arr", LegacySQLTypeName.BYTES).setMode(Field.Mode.REPEATED).build(), + Field.of("float", LegacySQLTypeName.FLOAT), + Field.of("numeric", LegacySQLTypeName.NUMERIC), + Field.of("date", LegacySQLTypeName.DATE), + Field.of("times", LegacySQLTypeName.RECORD, + Field.of("time", LegacySQLTypeName.TIME), + Field.of("timestamp", LegacySQLTypeName.TIMESTAMP), + Field.of("datetime", LegacySQLTypeName.DATETIME))); + + + public final Field BIGQUERY_INTEGER_FIELD = Field.newBuilder("Number", LegacySQLTypeName.INTEGER, + (FieldList)null).setMode(Field.Mode.NULLABLE).build(); + public final Field BIGQUERY_STRING_FIELD = Field.newBuilder("String", LegacySQLTypeName.STRING, (FieldList) null) + .setMode(Field.Mode.REQUIRED).build(); + public final Field BIGQUERY_NESTED_STRUCT_FIELD = Field.newBuilder("Struct", LegacySQLTypeName.RECORD, + Field.newBuilder("Number", LegacySQLTypeName.INTEGER, (FieldList) null) + .setMode(Field.Mode.NULLABLE).build(), + Field.newBuilder("String", LegacySQLTypeName.STRING, (FieldList) null) + .setMode(Field.Mode.NULLABLE).build()) + .setMode(Field.Mode.NULLABLE).build(); + public final Field BIGQUERY_ARRAY_FIELD = Field.newBuilder("Array", LegacySQLTypeName.INTEGER, (FieldList) null) + .setMode(Field.Mode.REPEATED).build(); + public final Field BIGQUERY_FLOAT_FIELD = Field.newBuilder("Float", LegacySQLTypeName.FLOAT, (FieldList)null) + .setMode(Field.Mode.NULLABLE).build(); + public final Field BIGQUERY_BOOLEAN_FIELD = Field.newBuilder("Boolean", LegacySQLTypeName.BOOLEAN, (FieldList)null) + .setMode(Field.Mode.NULLABLE).build(); + public final Field BIGQUERY_NUMERIC_FIELD = Field.newBuilder("Numeric", LegacySQLTypeName.NUMERIC, (FieldList)null) + .setMode(Field.Mode.NULLABLE).build(); + + public final Schema BIG_BIGQUERY_SCHEMA = Schema.of(BIGQUERY_INTEGER_FIELD, BIGQUERY_STRING_FIELD, BIGQUERY_ARRAY_FIELD, BIGQUERY_NESTED_STRUCT_FIELD, + BIGQUERY_FLOAT_FIELD, BIGQUERY_BOOLEAN_FIELD, BIGQUERY_NUMERIC_FIELD); + + /* TODO: create SchemaConverters.convert() from BigQuery -> Spark test. 
Translate specific test from SchemaIteratorSuite.scala + private final List BIG_SCHEMA_NAMES_INORDER = Arrays.asList( + new String[]{"Number", "String", "Array", "Struct", "Float", "Boolean", "Numeric"}); + + private final org.apache.avro.Schema AVRO_SCHEMA = createAvroSchema(); + private final org.apache.avro.Schema createAvroSchema() throws AssumptionViolatedException { + try { + org.apache.avro.Schema avroSchema = new org.apache.avro.Schema.Parser(). + parse(this.getClass().getResourceAsStream("/alltypes.avroschema.json")); + return avroSchema; + } catch (IOException e) { + throw new AssumptionViolatedException("Could not create AVRO_SCHEMA", e); + } + } + */ +} From 22b41d32fdbd49a99b627c45496490ca5ffa755b Mon Sep 17 00:00:00 2001 From: emkornfield Date: Tue, 7 Jul 2020 07:38:29 -0700 Subject: [PATCH 5/9] Adds implementation for supporting columnar batch reads from Spark. (#198) This bypasses most of the existing translation code for the following reasons: 1. I think there might be a memory leak because the existing code doesn't close the allocator. 2. This avoids continuously recopying the schema. I didn't delete the old code because it appears the BigQueryRDD still relies on it partially. I also couldn't find instructions on formatting/testing (I couldn't find explicit unit tests for existing arrow code, I'll update accordingly if pointers can be provided). --- .../connector/common/ReadRowsHelper.java | 16 +- .../v2/ArrowColumnBatchPartitionReader.java | 143 ++++++++++++++++++ .../bigquery/v2/ArrowInputPartition.java | 62 ++++++++ .../bigquery/v2/BigQueryDataSourceReader.java | 53 ++++++- 4 files changed, 265 insertions(+), 9 deletions(-) create mode 100644 connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowColumnBatchPartitionReader.java create mode 100644 connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowInputPartition.java diff --git a/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadRowsHelper.java b/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadRowsHelper.java index 1c99a5c889..6ad3fbd233 100644 --- a/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadRowsHelper.java +++ b/connector/src/main/java/com/google/cloud/bigquery/connector/common/ReadRowsHelper.java @@ -15,6 +15,7 @@ */ package com.google.cloud.bigquery.connector.common; +import com.google.api.gax.rpc.ServerStream; import com.google.cloud.bigquery.storage.v1.BigQueryReadClient; import com.google.cloud.bigquery.storage.v1.ReadRowsRequest; import com.google.cloud.bigquery.storage.v1.ReadRowsResponse; @@ -29,6 +30,7 @@ public class ReadRowsHelper { private ReadRowsRequest.Builder request; private int maxReadRowsRetries; private BigQueryReadClient client; + private ServerStream incomingStream; public ReadRowsHelper( BigQueryReadClientFactory bigQueryReadClientFactory, @@ -51,7 +53,13 @@ public Iterator readRows() { // In order to enable testing protected Iterator fetchResponses(ReadRowsRequest.Builder readRowsRequest) { - return client.readRowsCallable().call(readRowsRequest.build()).iterator(); + incomingStream = client.readRowsCallable().call(readRowsRequest.build()); + return incomingStream.iterator(); + } + + @Override + public String toString() { + return request.toString(); } // Ported from https://github.com/GoogleCloudDataproc/spark-bigquery-connector/pull/150 @@ -89,7 +97,7 @@ public ReadRowsResponse next() { serverResponses = helper.fetchResponses(helper.request.setOffset(readRowsCount)); retries++; } else { - 
helper.client.close(); + helper.close(); throw e; } } @@ -100,6 +108,10 @@ public ReadRowsResponse next() { } public void close() { + if (incomingStream != null) { + incomingStream.cancel(); + incomingStream = null; + } if (!client.isShutdown()) { client.close(); } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowColumnBatchPartitionReader.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowColumnBatchPartitionReader.java new file mode 100644 index 0000000000..b66b68acef --- /dev/null +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowColumnBatchPartitionReader.java @@ -0,0 +1,143 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.cloud.spark.bigquery.v2; + +import com.google.cloud.bigquery.connector.common.ReadRowsHelper; +import com.google.cloud.bigquery.storage.v1.ReadRowsResponse; +import com.google.cloud.spark.bigquery.ArrowSchemaConverter; +import com.google.protobuf.ByteString; +import java.io.IOException; +import java.io.InputStream; +import java.io.SequenceInputStream; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; + +class ArrowColumnBatchPartitionColumnBatchReader implements InputPartitionReader { + private static final long maxAllocation = 500 * 1024 * 1024; + + private final ReadRowsHelper readRowsHelper; + private final ArrowStreamReader reader; + private final BufferAllocator allocator; + private final List namesInOrder; + private ColumnarBatch currentBatch; + private boolean closed = false; + + static class ReadRowsResponseInputStreamEnumeration + implements java.util.Enumeration { + private Iterator responses; + private ReadRowsResponse currentResponse; + + ReadRowsResponseInputStreamEnumeration(Iterator responses) { + this.responses = responses; + loadNextResponse(); + } + + public boolean hasMoreElements() { + return currentResponse != null; + } + + public InputStream nextElement() { + if (!hasMoreElements()) { + throw new NoSuchElementException("No more responses"); + } + ReadRowsResponse ret = currentResponse; + loadNextResponse(); + return ret.getArrowRecordBatch().getSerializedRecordBatch().newInput(); + } + + void loadNextResponse() { + if (responses.hasNext()) { + currentResponse = responses.next(); + } else { + currentResponse = null; + } + } + } + + ArrowColumnBatchPartitionColumnBatchReader( + Iterator readRowsResponses, + ByteString schema, + ReadRowsHelper readRowsHelper, + List namesInOrder) { + this.allocator = + (new 
RootAllocator(maxAllocation)) + .newChildAllocator("ArrowBinaryIterator", 0, maxAllocation); + this.readRowsHelper = readRowsHelper; + this.namesInOrder = namesInOrder; + + InputStream batchStream = + new SequenceInputStream(new ReadRowsResponseInputStreamEnumeration(readRowsResponses)); + InputStream fullStream = new SequenceInputStream(schema.newInput(), batchStream); + + reader = new ArrowStreamReader(fullStream, allocator); + } + + @Override + public boolean next() throws IOException { + if (closed) { + return false; + } + closed = !reader.loadNextBatch(); + if (closed) { + return false; + } + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + if (currentBatch == null) { + // trying to verify from dev@spark but this object + // should only need to get created once. The underlying + // vectors should stay the same. + ColumnVector[] columns = + namesInOrder.stream() + .map(root::getVector) + .map(ArrowSchemaConverter::new) + .toArray(ColumnVector[]::new); + + currentBatch = new ColumnarBatch(columns); + } + currentBatch.setNumRows(root.getRowCount()); + return true; + } + + @Override + public ColumnarBatch get() { + return currentBatch; + } + + @Override + public void close() throws IOException { + closed = true; + try { + readRowsHelper.close(); + } catch (Exception e) { + throw new IOException("Failure closing stream: " + readRowsHelper, e); + } finally { + try { + AutoCloseables.close(reader, allocator); + } catch (Exception e) { + throw new IOException("Failure closing arrow components. stream: " + readRowsHelper, e); + } + } + } +} diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowInputPartition.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowInputPartition.java new file mode 100644 index 0000000000..e1565e5e17 --- /dev/null +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/ArrowInputPartition.java @@ -0,0 +1,62 @@ +/* + * Copyright 2018 Google Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.cloud.spark.bigquery.v2; + +import com.google.cloud.bigquery.connector.common.BigQueryReadClientFactory; +import com.google.cloud.bigquery.connector.common.ReadRowsHelper; +import com.google.cloud.bigquery.connector.common.ReadSessionResponse; +import com.google.cloud.bigquery.storage.v1.ReadRowsRequest; +import com.google.cloud.bigquery.storage.v1.ReadRowsResponse; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import java.util.Iterator; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +public class ArrowInputPartition implements InputPartition { + + private final BigQueryReadClientFactory bigQueryReadClientFactory; + private final String streamName; + private final int maxReadRowsRetries; + private final ImmutableList selectedFields; + private final ByteString serializedArrowSchema; + + public ArrowInputPartition( + BigQueryReadClientFactory bigQueryReadClientFactory, + String name, + int maxReadRowsRetries, + ImmutableList selectedFields, + ReadSessionResponse readSessionResponse) { + this.bigQueryReadClientFactory = bigQueryReadClientFactory; + this.streamName = name; + this.maxReadRowsRetries = maxReadRowsRetries; + this.selectedFields = selectedFields; + this.serializedArrowSchema = + readSessionResponse.getReadSession().getArrowSchema().getSerializedSchema(); + } + + @Override + public InputPartitionReader createPartitionReader() { + ReadRowsRequest.Builder readRowsRequest = + ReadRowsRequest.newBuilder().setReadStream(streamName); + ReadRowsHelper readRowsHelper = + new ReadRowsHelper(bigQueryReadClientFactory, readRowsRequest, maxReadRowsRetries); + Iterator readRowsResponses = readRowsHelper.readRows(); + return new ArrowColumnBatchPartitionColumnBatchReader( + readRowsResponses, serializedArrowSchema, readRowsHelper, selectedFields); + } +} diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java index 68dfaab217..19332899d9 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/v2/BigQueryDataSourceReader.java @@ -32,12 +32,14 @@ import java.util.*; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.spark.sql.vectorized.ColumnarBatch; public class BigQueryDataSourceReader implements DataSourceReader, SupportsPushDownRequiredColumns, SupportsPushDownFilters, - SupportsReportStatistics { + SupportsReportStatistics, + SupportsScanColumnarBatch { private static Statistics UNKNOWN_STATISTICS = new Statistics() { @@ -87,9 +89,14 @@ public StructType readSchema() { return schema.orElse(SchemaConverters.toSpark(table.getDefinition().getSchema())); } + @Override + public boolean enableBatchRead() { + return readSessionCreatorConfig.getReadDataFormat() == DataFormat.ARROW && !isEmptySchema(); + } + @Override public List> planInputPartitions() { - if (schema.map(StructType::isEmpty).orElse(false)) { + if (isEmptySchema()) { // create empty projection return createEmptyProjectionPartitions(); } @@ -117,10 +124,44 @@ public List> planInputPartitions() { .collect(Collectors.toList()); } + @Override + public List> planBatchInputPartitions() { + if (!enableBatchRead()) { + throw new IllegalStateException("Batch 
reads should not be enabled"); + } + ImmutableList selectedFields = + schema + .map(requiredSchema -> ImmutableList.copyOf(requiredSchema.fieldNames())) + .orElse(ImmutableList.of()); + Optional filter = + emptyIfNeeded( + SparkFilterUtils.getCompiledFilter( + readSessionCreatorConfig.getReadDataFormat(), globalFilter, pushedFilters)); + ReadSessionResponse readSessionResponse = + readSessionCreator.create( + tableId, selectedFields, filter, readSessionCreatorConfig.getMaxParallelism()); + ReadSession readSession = readSessionResponse.getReadSession(); + return readSession.getStreamsList().stream() + .map( + stream -> + new ArrowInputPartition( + bigQueryReadClientFactory, + stream.getName(), + readSessionCreatorConfig.getMaxReadRowsRetries(), + selectedFields, + readSessionResponse)) + .collect(Collectors.toList()); + } + + private boolean isEmptySchema() { + return schema.map(StructType::isEmpty).orElse(false); + } + private ReadRowsResponseToInternalRowIteratorConverter createConverter( ImmutableList selectedFields, ReadSessionResponse readSessionResponse) { ReadRowsResponseToInternalRowIteratorConverter converter; - if (readSessionCreatorConfig.getReadDataFormat() == DataFormat.AVRO) { + DataFormat format = readSessionCreatorConfig.getReadDataFormat(); + if (format == DataFormat.AVRO) { Schema schema = readSessionResponse.getReadTableInfo().getDefinition().getSchema(); if (selectedFields.isEmpty()) { // means select * @@ -138,11 +179,9 @@ private ReadRowsResponseToInternalRowIteratorConverter createConverter( } return ReadRowsResponseToInternalRowIteratorConverter.avro( schema, selectedFields, readSessionResponse.getReadSession().getAvroSchema().getSchema()); - } else { - return ReadRowsResponseToInternalRowIteratorConverter.arrow( - selectedFields, - readSessionResponse.getReadSession().getArrowSchema().getSerializedSchema()); } + throw new IllegalArgumentException( + "No known converter for " + readSessionCreatorConfig.getReadDataFormat()); } List> createEmptyProjectionPartitions() { From 601685b085193416c66248d4a99b44cc910293cf Mon Sep 17 00:00:00 2001 From: Yuval Medina Date: Tue, 7 Jul 2020 19:07:57 +0000 Subject: [PATCH 6/9] Integrated all of DavidRab's suggestions --- .../cloud/spark/bigquery/ProtobufUtils.java | 36 +++++--- .../spark/bigquery/SchemaConverters.java | 12 +-- .../spark/bigquery/ProtobufUtilsTest.java | 91 ++++++++++++------- .../spark/bigquery/SchemaConverterTest.java | 25 +++-- 4 files changed, 103 insertions(+), 61 deletions(-) diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java b/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java index 669bab27da..ee5300e1df 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java @@ -30,6 +30,8 @@ import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.types.*; +import java.sql.Date; +import java.sql.Timestamp; import java.util.Arrays; import java.util.stream.Collectors; @@ -261,10 +263,6 @@ protected static Object convert (StructField sparkField, Object sparkValue) { "Encountered a null value inside a non-null-containing array."); return toAtomicProtoRowValue(elementType, value); } ).collect(Collectors.toList()); - } else if (fieldType instanceof StructType) { - throw new IllegalStateException("Method did not expect to convert a StructType instance."); - } else if (fieldType instanceof MapType) { - throw new 
IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); } else { return toAtomicProtoRowValue(fieldType, sparkValue); @@ -278,9 +276,7 @@ private static Object toAtomicProtoRowValue(DataType sparkType, Object value) { if (sparkType instanceof ByteType || sparkType instanceof ShortType || sparkType instanceof IntegerType || - sparkType instanceof LongType || - sparkType instanceof TimestampType || - sparkType instanceof DateType) { + sparkType instanceof LongType) { return ((Number)value).longValue(); } @@ -290,13 +286,22 @@ private static Object toAtomicProtoRowValue(DataType sparkType, Object value) { return ((Number)value).doubleValue(); // TODO: should decimal be converted to double? Or a Bytes type containing extra width? } - if (sparkType instanceof BooleanType) { - return ((Boolean)value).booleanValue(); // TODO: can be unboxed? + if (sparkType instanceof TimestampType) { + return Timestamp.valueOf((String)value).getTime(); // + } + + if (sparkType instanceof DateType) { + return Date.valueOf((String)value).getTime(); } - if (sparkType instanceof StringType || + if (sparkType instanceof BooleanType || + sparkType instanceof StringType || sparkType instanceof BinaryType) { - return value; // TODO: verify correct method for extracting this value. + return value; + } + + if (sparkType instanceof MapType) { + throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); } throw new IllegalStateException("Unexpected type: " + sparkType); @@ -315,9 +320,6 @@ private static DescriptorProtos.DescriptorProto buildDescriptorProtoWithFields( DescriptorProtos.FieldDescriptorProto.Type fieldType; DataType sparkType = field.dataType(); - if (sparkType instanceof MapType) { - throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); - } if (sparkType instanceof StructType) { StructType structType = (StructType)sparkType; String nestedName = RESERVED_NESTED_TYPE_NAME +messageNumber; // TODO: this should be a reserved name. No column can have this name. @@ -354,7 +356,7 @@ private static DescriptorProtos.DescriptorProto buildDescriptorProtoWithFields( // NOTE: annotations for DATETIME and TIMESTAMP objects are currently unsupported for external users, // but if they become available, it would be advisable to append an annotation to the protoFieldBuilder // for these and other types. - // This function only converts atomic Spark DataTypes (MapType, ArrayType, and StructType will throw an error). 
+ // This function only converts atomic Spark DataTypes private static DescriptorProtos.FieldDescriptorProto.Type sparkAtomicTypeToProtoFieldType(DataType sparkType) { if (sparkType instanceof ByteType || sparkType instanceof ShortType || @@ -384,6 +386,10 @@ private static DescriptorProtos.FieldDescriptorProto.Type sparkAtomicTypeToProto return DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING; } + if (sparkType instanceof MapType) { + throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); + } + throw new IllegalStateException("Unexpected type: " + sparkType); } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java b/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java index b0598e3845..66889e1532 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java @@ -245,9 +245,6 @@ protected static Field makeBigQueryColumn (StructField sparkField, int depth) { fieldType = elementType; fieldMode = Field.Mode.REPEATED; } - else if (sparkType instanceof MapType) { - throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); - } else if (sparkType instanceof StructType) { subFields = sparkToBigQueryFields((StructType)sparkType, depth+1); fieldType = LegacySQLTypeName.RECORD; @@ -287,15 +284,18 @@ protected static LegacySQLTypeName toBigQueryType (DataType elementType) { decimalType.scale() <= BQ_NUMERIC_SCALE) { return LegacySQLTypeName.NUMERIC; } else { - throw new IllegalArgumentException("Decimal type is too wide to fit in BigQuery Numeric format"); // TODO + throw new IllegalArgumentException("Decimal type is too wide to fit in BigQuery Numeric format"); } } if (elementType instanceof StringType) { return LegacySQLTypeName.STRING; } if (elementType instanceof TimestampType) { return LegacySQLTypeName.TIMESTAMP; - } if (elementType instanceof DateType) { // TODO: TIME & DATETIME in BigQuery + } if (elementType instanceof DateType) { return LegacySQLTypeName.DATE; - } else { + } if (elementType instanceof MapType) { + throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); + } + else { throw new IllegalArgumentException("Data type not expected in toBQType: "+elementType.simpleString()); } } diff --git a/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java b/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java index b9e81b2dca..3143376097 100644 --- a/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java +++ b/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java @@ -30,19 +30,28 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.*; import org.junit.AssumptionViolatedException; import org.junit.Test; +import java.sql.Date; +import java.sql.Timestamp; + import static com.google.cloud.spark.bigquery.ProtobufUtils.*; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; public class ProtobufUtilsTest { + // Numeric is a fixed precision Decimal Type with 38 digits of precision and 9 digits of scale. 
+ // See https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#numeric-type + private final static int BQ_NUMERIC_PRECISION = 38; + private final static int BQ_NUMERIC_SCALE = 9; + private final static DecimalType NUMERIC_SPARK_TYPE = DataTypes.createDecimalType( + BQ_NUMERIC_PRECISION, BQ_NUMERIC_SCALE); + // The maximum nesting depth of a BigQuery RECORD: + private static final int MAX_BIGQUERY_NESTED_DEPTH = 15; + private final Logger logger = LogManager.getLogger("com.google.cloud.spark"); @Test @@ -71,9 +80,10 @@ public void testBigQueryToProtoSchema() throws Exception { .addField(PROTO_ARRAY_FIELD.clone().setNumber(3)) .addNestedType(NESTED_STRUCT_DESCRIPTOR.clone()) .addField(PROTO_STRUCT_FIELD.clone().setNumber(4)) - .addField(PROTO_BYTES_FIELD.clone().setName("Geography").setNumber(5)) - .addField(PROTO_DOUBLE_FIELD.clone().setName("Float").setNumber(6)) - .addField(PROTO_BOOLEAN_FIELD.clone().setNumber(7)) + .addField(PROTO_DOUBLE_FIELD.clone().setName("Float").setNumber(5)) + .addField(PROTO_BOOLEAN_FIELD.clone().setNumber(6)) + .addField(PROTO_BYTES_FIELD.clone().setNumber(7)) + .addField(PROTO_INTEGER_FIELD.clone().setName("Date").setNumber(8)) .setName("Schema").build() ).build(), new Descriptors.FileDescriptor[]{} ).getMessageTypes().get(0) @@ -82,7 +92,7 @@ public void testBigQueryToProtoSchema() throws Exception { logger.debug("Expected schema: "+expected.getProtoDescriptor()); logger.debug("Actual schema: "+converted.getProtoDescriptor()); - for(int i = 0; i < 7; i++){ + for(int i = 0; i < expected.getProtoDescriptor().getFieldList().size(); i++){ assertThat(converted.getProtoDescriptor().getField(i)).isEqualTo(expected.getProtoDescriptor().getField(i)); } } @@ -171,11 +181,13 @@ public void testSparkRowToProtoRow() throws Exception { ArrayData.toArrayData(new int[]{0,1,2}), INTERNAL_STRUCT_DATA, 3.14, - true + true, + new byte[]{11, 0x7F}, + "2020-07-07" })} ); - ProtoBufProto.ProtoRows expected = MyProtoRows; + ProtoBufProto.ProtoRows expected = MY_PROTO_ROWS; assertThat(converted.getSerializedRows(0).toByteArray()).isEqualTo(expected.getSerializedRows(0).toByteArray()); } @@ -204,50 +216,61 @@ public void testSettingARequiredFieldAsNull() throws Exception { new StructField("String", DataTypes.StringType, true, Metadata.empty())}); - private final StructField SPARK_INTEGER_FIELD = new StructField("Number", DataTypes.IntegerType, + public final StructField SPARK_INTEGER_FIELD = new StructField("Number", DataTypes.IntegerType, true, Metadata.empty()); - private final StructField SPARK_STRING_FIELD = new StructField("String", DataTypes.StringType, + public final StructField SPARK_STRING_FIELD = new StructField("String", DataTypes.StringType, false, Metadata.empty()); - private final StructField SPARK_NESTED_STRUCT_FIELD = new StructField("Struct", MY_STRUCT, + public final StructField SPARK_NESTED_STRUCT_FIELD = new StructField("Struct", MY_STRUCT, true, Metadata.empty()); - private final StructField SPARK_ARRAY_FIELD = new StructField("Array", + public final StructField SPARK_ARRAY_FIELD = new StructField("Array", DataTypes.createArrayType(DataTypes.IntegerType), true, Metadata.empty()); - private final StructField SPARK_DOUBLE_FIELD = new StructField("Double", DataTypes.DoubleType, + public final StructField SPARK_DOUBLE_FIELD = new StructField("Float", DataTypes.DoubleType, + true, Metadata.empty()); + public final StructField SPARK_BOOLEAN_FIELD = new StructField("Boolean", DataTypes.BooleanType, + true, Metadata.empty()); + public final StructField 
SPARK_BINARY_FIELD = new StructField("Binary", DataTypes.BinaryType, true, Metadata.empty()); - private final StructField SPARK_BOOLEAN_FIELD = new StructField("Boolean", DataTypes.BooleanType, + public final StructField SPARK_DATE_FIELD = new StructField("Date", DataTypes.DateType, + true, Metadata.empty()); + public final StructField SPARK_MAP_FIELD = new StructField("Map", + DataTypes.createMapType(DataTypes.IntegerType, DataTypes.StringType), true, Metadata.empty()); - private final StructType BIG_SPARK_SCHEMA = new StructType() + public final StructType BIG_SPARK_SCHEMA = new StructType() .add(SPARK_INTEGER_FIELD) .add(SPARK_STRING_FIELD) .add(SPARK_ARRAY_FIELD) .add(SPARK_NESTED_STRUCT_FIELD) .add(SPARK_DOUBLE_FIELD) - .add(SPARK_BOOLEAN_FIELD); + .add(SPARK_BOOLEAN_FIELD) + .add(SPARK_BINARY_FIELD) + .add(SPARK_DATE_FIELD); - private final Field BIGQUERY_INTEGER_FIELD = Field.newBuilder("Number", LegacySQLTypeName.INTEGER, + public final Field BIGQUERY_INTEGER_FIELD = Field.newBuilder("Number", LegacySQLTypeName.INTEGER, (FieldList)null).setMode(Field.Mode.NULLABLE).build(); - private final Field BIGQUERY_STRING_FIELD = Field.newBuilder("String", LegacySQLTypeName.STRING, (FieldList) null) + public final Field BIGQUERY_STRING_FIELD = Field.newBuilder("String", LegacySQLTypeName.STRING, (FieldList) null) .setMode(Field.Mode.REQUIRED).build(); - private final Field BIGQUERY_NESTED_STRUCT_FIELD = Field.newBuilder("Struct", LegacySQLTypeName.RECORD, + public final Field BIGQUERY_NESTED_STRUCT_FIELD = Field.newBuilder("Struct", LegacySQLTypeName.RECORD, Field.newBuilder("Number", LegacySQLTypeName.INTEGER, (FieldList) null) .setMode(Field.Mode.NULLABLE).build(), Field.newBuilder("String", LegacySQLTypeName.STRING, (FieldList) null) .setMode(Field.Mode.NULLABLE).build()) .setMode(Field.Mode.NULLABLE).build(); - private final Field BIGQUERY_ARRAY_FIELD = Field.newBuilder("Array", LegacySQLTypeName.INTEGER, (FieldList) null) + public final Field BIGQUERY_ARRAY_FIELD = Field.newBuilder("Array", LegacySQLTypeName.INTEGER, (FieldList) null) .setMode(Field.Mode.REPEATED).build(); - private final Field BIGQUERY_GEOGRAPHY_FIELD = Field.newBuilder("Geography", LegacySQLTypeName.GEOGRAPHY, (FieldList)null) + public final Field BIGQUERY_FLOAT_FIELD = Field.newBuilder("Float", LegacySQLTypeName.FLOAT, (FieldList)null) + .setMode(Field.Mode.NULLABLE).build(); + public final Field BIGQUERY_BOOLEAN_FIELD = Field.newBuilder("Boolean", LegacySQLTypeName.BOOLEAN, (FieldList)null) .setMode(Field.Mode.NULLABLE).build(); - private final Field BIGQUERY_FLOAT_FIELD = Field.newBuilder("Float", LegacySQLTypeName.FLOAT, (FieldList)null) + public final Field BIGQUERY_BYTES_FIELD = Field.newBuilder("Binary", LegacySQLTypeName.BYTES, (FieldList)null) .setMode(Field.Mode.NULLABLE).build(); - private final Field BIGQUERY_BOOLEAN_FIELD = Field.newBuilder("Boolean", LegacySQLTypeName.BOOLEAN, (FieldList)null) + public final Field BIGQUERY_DATE_FIELD = Field.newBuilder("Date", LegacySQLTypeName.DATE, (FieldList)null) .setMode(Field.Mode.NULLABLE).build(); - private final Schema BIG_BIGQUERY_SCHEMA = Schema.of(BIGQUERY_INTEGER_FIELD, BIGQUERY_STRING_FIELD, BIGQUERY_ARRAY_FIELD, BIGQUERY_NESTED_STRUCT_FIELD, - BIGQUERY_GEOGRAPHY_FIELD, BIGQUERY_FLOAT_FIELD, BIGQUERY_BOOLEAN_FIELD); + public final Schema BIG_BIGQUERY_SCHEMA = Schema.of(BIGQUERY_INTEGER_FIELD, BIGQUERY_STRING_FIELD, BIGQUERY_ARRAY_FIELD, + BIGQUERY_NESTED_STRUCT_FIELD, BIGQUERY_FLOAT_FIELD, BIGQUERY_BOOLEAN_FIELD, BIGQUERY_BYTES_FIELD, BIGQUERY_DATE_FIELD); 
private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_INTEGER_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() @@ -275,11 +298,6 @@ public void testSettingARequiredFieldAsNull() throws Exception { .setNumber(1) .setTypeName("STRUCT1") .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL); - private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_BYTES_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() - .setName("Bytes") - .setNumber(1) - .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES) - .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL); private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_DOUBLE_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() .setName("Double") .setNumber(1) @@ -290,6 +308,11 @@ public void testSettingARequiredFieldAsNull() throws Exception { .setNumber(1) .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_BOOL) .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL); + private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_BYTES_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + .setName("Binary") + .setNumber(1) + .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES) + .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL); private final DescriptorProtos.DescriptorProto DESCRIPTOR_PROTO_INTEGER = DescriptorProtos.DescriptorProto.newBuilder() @@ -380,7 +403,7 @@ private Descriptors.Descriptor createBigSchemaRowDescriptor() { throw new AssumptionViolatedException("Could not create BIG_SCHEMA_ROW_DESCRIPTOR", e); } } - private ProtoBufProto.ProtoRows MyProtoRows = createMyProtoRows(); + private ProtoBufProto.ProtoRows MY_PROTO_ROWS = createMyProtoRows(); private ProtoBufProto.ProtoRows createMyProtoRows() { try { return ProtoBufProto.ProtoRows.newBuilder().addSerializedRows( @@ -395,6 +418,8 @@ private ProtoBufProto.ProtoRows createMyProtoRows() { MY_STRUCT, toDescriptor(MY_STRUCT), INTERNAL_STRUCT_DATA)) .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(5), 3.14) .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(6), true) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(7), new byte[]{11, 0x7F}) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(8), 1594080000000L) .build().toByteString()).build(); } catch (Descriptors.DescriptorValidationException e) { throw new AssumptionViolatedException("Could not create MY_PROTO_ROWS", e); diff --git a/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java b/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java index e721403f80..d6bbf659b8 100644 --- a/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java +++ b/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java @@ -106,7 +106,7 @@ public void testSparkToBQSchema() throws Exception { Schema converted = toBigQuerySchema(schema); - for(int i = 0; i < converted.getFields().size(); i++){ + for(int i = 0; i < expected.getFields().size(); i++){ assertThat(converted.getFields().get(i)).isEqualTo(expected.getFields().get(i)); } } @@ -190,15 +190,19 @@ public void testMaximumNestingDepthError() throws Exception { public final StructField SPARK_ARRAY_FIELD = new StructField("Array", DataTypes.createArrayType(DataTypes.IntegerType), true, Metadata.empty()); - public final StructField SPARK_MAP_FIELD = new StructField("Map", - DataTypes.createMapType(DataTypes.IntegerType, 
DataTypes.StringType), - true, Metadata.empty()); public final StructField SPARK_DOUBLE_FIELD = new StructField("Float", DataTypes.DoubleType, true, Metadata.empty()); public final StructField SPARK_BOOLEAN_FIELD = new StructField("Boolean", DataTypes.BooleanType, true, Metadata.empty()); public final StructField SPARK_NUMERIC_FIELD = new StructField("Numeric", NUMERIC_SPARK_TYPE, true, Metadata.empty()); + public final StructField SPARK_BINARY_FIELD = new StructField("Binary", DataTypes.BinaryType, + true, Metadata.empty()); + public final StructField SPARK_DATE_FIELD = new StructField("Date", DataTypes.DateType, + true, Metadata.empty()); + public final StructField SPARK_MAP_FIELD = new StructField("Map", + DataTypes.createMapType(DataTypes.IntegerType, DataTypes.StringType), + true, Metadata.empty()); public final StructType BIG_SPARK_SCHEMA = new StructType() .add(SPARK_INTEGER_FIELD) @@ -207,7 +211,9 @@ public void testMaximumNestingDepthError() throws Exception { .add(SPARK_NESTED_STRUCT_FIELD) .add(SPARK_DOUBLE_FIELD) .add(SPARK_BOOLEAN_FIELD) - .add(SPARK_NUMERIC_FIELD); + .add(SPARK_NUMERIC_FIELD) + .add(SPARK_BINARY_FIELD) + .add(SPARK_DATE_FIELD); public final StructType BIG_SPARK_SCHEMA2 = new StructType() .add(new StructField("foo", DataTypes.StringType,true, Metadata.empty())) @@ -257,9 +263,14 @@ public void testMaximumNestingDepthError() throws Exception { .setMode(Field.Mode.NULLABLE).build(); public final Field BIGQUERY_NUMERIC_FIELD = Field.newBuilder("Numeric", LegacySQLTypeName.NUMERIC, (FieldList)null) .setMode(Field.Mode.NULLABLE).build(); + public final Field BIGQUERY_BYTES_FIELD = Field.newBuilder("Binary", LegacySQLTypeName.BYTES, (FieldList)null) + .setMode(Field.Mode.NULLABLE).build(); + public final Field BIGQUERY_DATE_FIELD = Field.newBuilder("Date", LegacySQLTypeName.DATE, (FieldList)null) + .setMode(Field.Mode.NULLABLE).build(); - public final Schema BIG_BIGQUERY_SCHEMA = Schema.of(BIGQUERY_INTEGER_FIELD, BIGQUERY_STRING_FIELD, BIGQUERY_ARRAY_FIELD, BIGQUERY_NESTED_STRUCT_FIELD, - BIGQUERY_FLOAT_FIELD, BIGQUERY_BOOLEAN_FIELD, BIGQUERY_NUMERIC_FIELD); + public final Schema BIG_BIGQUERY_SCHEMA = Schema.of(BIGQUERY_INTEGER_FIELD, BIGQUERY_STRING_FIELD, BIGQUERY_ARRAY_FIELD, + BIGQUERY_NESTED_STRUCT_FIELD, BIGQUERY_FLOAT_FIELD, BIGQUERY_BOOLEAN_FIELD, BIGQUERY_NUMERIC_FIELD, + BIGQUERY_BYTES_FIELD, BIGQUERY_DATE_FIELD); /* TODO: create SchemaConverters.convert() from BigQuery -> Spark test. 
Translate specific test from SchemaIteratorSuite.scala private final List BIG_SCHEMA_NAMES_INORDER = Arrays.asList( From 9d1afe37798d41bba4d130bf023aa8e67b5c785c Mon Sep 17 00:00:00 2001 From: Yuval Medina Date: Tue, 7 Jul 2020 22:42:22 +0000 Subject: [PATCH 7/9] Changed tests as well --- .../cloud/spark/bigquery/ProtobufUtils.java | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java b/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java index ee5300e1df..7fa5a40cf2 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java @@ -16,6 +16,7 @@ package com.google.cloud.spark.bigquery; import avro.shaded.com.google.common.base.Preconditions; +import com.google.cloud.ByteArray; import com.google.cloud.bigquery.Field; import com.google.cloud.bigquery.FieldList; import com.google.cloud.bigquery.LegacySQLTypeName; @@ -29,6 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.UTF8String; import java.sql.Date; import java.sql.Timestamp; @@ -216,6 +218,10 @@ public static DynamicMessage createSingleRowMessage(StructType schema, (InternalRow)row.get(i-1, sparkType))); } else { + Object converted = convert(sparkField, row.get(i-1, sparkType)); + if (converted == null) { + continue; + } messageBuilder.setField(schemaDescriptor.findFieldByNumber(i), convert(sparkField, row.get(i-1, sparkType))); @@ -292,14 +298,20 @@ private static Object toAtomicProtoRowValue(DataType sparkType, Object value) { if (sparkType instanceof DateType) { return Date.valueOf((String)value).getTime(); - } + } // TODO: CalendarInterval - if (sparkType instanceof BooleanType || - sparkType instanceof StringType || - sparkType instanceof BinaryType) { + if (sparkType instanceof BooleanType) { return value; } + if (sparkType instanceof StringType) { + return new String(((UTF8String)value).getBytes()); + } + + if (sparkType instanceof BinaryType) { + return ((ByteArray)value).toByteArray(); + } + if (sparkType instanceof MapType) { throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); } From 50d57a564df85fd0405bee22cee572da470ec6a9 Mon Sep 17 00:00:00 2001 From: Yuval Medina Date: Tue, 7 Jul 2020 22:43:12 +0000 Subject: [PATCH 8/9] Changed tests as well --- .../com/google/cloud/spark/bigquery/ProtobufUtils.java | 7 ++----- .../google/cloud/spark/bigquery/ProtobufUtilsTest.java | 8 +++++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java b/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java index 7fa5a40cf2..c03769ca5d 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java @@ -300,7 +300,8 @@ private static Object toAtomicProtoRowValue(DataType sparkType, Object value) { return Date.valueOf((String)value).getTime(); } // TODO: CalendarInterval - if (sparkType instanceof BooleanType) { + if (sparkType instanceof BooleanType || + sparkType instanceof BinaryType) { return value; } @@ -308,10 +309,6 @@ private static Object toAtomicProtoRowValue(DataType sparkType, Object value) { return new String(((UTF8String)value).getBytes()); } - if 
(sparkType instanceof BinaryType) { - return ((ByteArray)value).toByteArray(); - } - if (sparkType instanceof MapType) { throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); } diff --git a/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java b/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java index 3143376097..f1cfe37799 100644 --- a/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java +++ b/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java @@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.ByteArray; +import org.apache.spark.unsafe.types.UTF8String; import org.junit.AssumptionViolatedException; import org.junit.Test; @@ -177,7 +179,7 @@ public void testSparkRowToProtoRow() throws Exception { new InternalRow[]{ new GenericInternalRow(new Object[]{ 1, - "A", + UTF8String.fromString("A"), ArrayData.toArrayData(new int[]{0,1,2}), INTERNAL_STRUCT_DATA, 3.14, @@ -325,10 +327,10 @@ public void testSettingARequiredFieldAsNull() throws Exception { .addNestedType(NESTED_STRUCT_DESCRIPTOR).addField(PROTO_STRUCT_FIELD).setName("Schema").build(); private final InternalRow INTEGER_INTERNAL_ROW = new GenericInternalRow(new Object[]{1}); - private final InternalRow STRING_INTERNAL_ROW = new GenericInternalRow(new Object[]{"A"}); + private final InternalRow STRING_INTERNAL_ROW = new GenericInternalRow(new Object[]{UTF8String.fromString("A")}); private final InternalRow ARRAY_INTERNAL_ROW = new GenericInternalRow(new Object[]{ArrayData.toArrayData( new int[]{0,1,2})}); - private final InternalRow INTERNAL_STRUCT_DATA = new GenericInternalRow(new Object[]{1, "A"}); + private final InternalRow INTERNAL_STRUCT_DATA = new GenericInternalRow(new Object[]{1, UTF8String.fromString("A")}); private final InternalRow STRUCT_INTERNAL_ROW = new GenericInternalRow(new Object[]{INTERNAL_STRUCT_DATA}); From 814a1bf9cfbb6f2046741f20aa07ba5e93c84268 Mon Sep 17 00:00:00 2001 From: Yuval Medina Date: Thu, 9 Jul 2020 19:59:26 +0000 Subject: [PATCH 9/9] Added functionality to support more complex Spark types (such as StructTypes within ArrayTypes) in SchemaConverters and ProtobufUtils. There are known issues with Timestamp conversion into BigQuery format when integrating with BigQuery Storage Write API. 
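For context, a rough sketch of the kind of schema this change is meant to handle follows. It is illustrative only and not part of the patch: the class and column names are made up, it assumes `SchemaConverters.toBigQuerySchema` is reachable the same way the new SchemaConverterTest reaches it (a static import from the same package), and the expected RECORD/REPEATED mapping is taken from the description above rather than verified here.

package com.google.cloud.spark.bigquery;

import com.google.cloud.bigquery.Field;
import com.google.cloud.bigquery.Schema;
import org.apache.spark.sql.types.*;

import static com.google.cloud.spark.bigquery.SchemaConverters.toBigQuerySchema;

public class NestedArrayConversionSketch {
  public static void main(String[] args) {
    // A StructType nested inside an ArrayType -- the shape called out above.
    StructType element = new StructType()
        .add(new StructField("Number", DataTypes.IntegerType, true, Metadata.empty()))
        .add(new StructField("String", DataTypes.StringType, true, Metadata.empty()));
    StructType sparkSchema = new StructType()
        .add(new StructField("Items", DataTypes.createArrayType(element), true, Metadata.empty()));

    // Expected to come back as a single REPEATED RECORD column with two sub-fields.
    Schema bqSchema = toBigQuerySchema(sparkSchema);
    Field items = bqSchema.getFields().get(0);
    System.out.println(items.getType());      // RECORD
    System.out.println(items.getMode());      // REPEATED
    System.out.println(items.getSubFields()); // Number, String
  }
}
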
--- .../cloud/spark/bigquery/ProtobufUtils.java | 679 +++++++++--------- .../spark/bigquery/SchemaConverters.java | 186 ++--- .../spark/bigquery/ProtobufUtilsTest.java | 222 +++--- .../spark/bigquery/SchemaConverterTest.java | 80 +-- 4 files changed, 560 insertions(+), 607 deletions(-) diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java b/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java index c03769ca5d..ee77ef08b0 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/ProtobufUtils.java @@ -15,8 +15,7 @@ */ package com.google.cloud.spark.bigquery; -import avro.shaded.com.google.common.base.Preconditions; -import com.google.cloud.ByteArray; +import com.google.common.base.Preconditions; import com.google.cloud.bigquery.Field; import com.google.cloud.bigquery.FieldList; import com.google.cloud.bigquery.LegacySQLTypeName; @@ -27,378 +26,392 @@ import com.google.protobuf.DescriptorProtos; import com.google.protobuf.Descriptors; import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Message; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.UTF8String; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -import java.sql.Date; -import java.sql.Timestamp; -import java.util.Arrays; -import java.util.stream.Collectors; +import java.util.ArrayList; +import java.util.List; public class ProtobufUtils { - // The maximum nesting depth of a BigQuery RECORD: - private static final int MAX_BIGQUERY_NESTED_DEPTH = 15; - // For every message, a nested type is name "STRUCT"+i, where i is the - // number of the corresponding field that is of this type in the containing message. - private static final String RESERVED_NESTED_TYPE_NAME = "STRUCT"; - private static final String MAPTYPE_ERROR_MESSAGE = "MapType is unsupported."; - - /** - * BigQuery Schema ==> ProtoSchema converter utils: - */ - public static ProtoBufProto.ProtoSchema toProtoSchema (Schema schema) throws Exception { - try{ - Descriptors.Descriptor descriptor = toDescriptor(schema); - ProtoBufProto.ProtoSchema protoSchema = ProtoSchemaConverter.convert(descriptor); - return protoSchema; - } catch (Descriptors.DescriptorValidationException e){ - throw new Exception("Could not build Proto-Schema from Spark schema.", e); // TODO: right exception to throw? - } + static final Logger logger = LoggerFactory.getLogger(ProtobufUtils.class); + + // The maximum nesting depth of a BigQuery RECORD: + private static final int MAX_BIGQUERY_NESTED_DEPTH = 15; + // For every message, a nested type is name "STRUCT"+i, where i is the + // number of the corresponding field that is of this type in the containing message. 
+ private static final String RESERVED_NESTED_TYPE_NAME = "STRUCT"; + private static final String MAPTYPE_ERROR_MESSAGE = "MapType is unsupported."; + + /** BigQuery Schema ==> ProtoSchema converter utils: */ + public static ProtoBufProto.ProtoSchema toProtoSchema(Schema schema) + throws IllegalArgumentException { + try { + Descriptors.Descriptor descriptor = toDescriptor(schema); + ProtoBufProto.ProtoSchema protoSchema = ProtoSchemaConverter.convert(descriptor); + return protoSchema; + } catch (Descriptors.DescriptorValidationException e) { + throw new IllegalArgumentException("Could not build Proto-Schema from Spark schema.", e); } - - private static Descriptors.Descriptor toDescriptor (Schema schema) throws Descriptors.DescriptorValidationException { - DescriptorProtos.DescriptorProto.Builder descriptorBuilder = DescriptorProtos.DescriptorProto.newBuilder() - .setName("Schema"); - - FieldList fields = schema.getFields(); - - DescriptorProtos.DescriptorProto descriptorProto = buildDescriptorProtoWithFields(descriptorBuilder, fields, 0); - - return createDescriptorFromProto(descriptorProto); + } + + private static Descriptors.Descriptor toDescriptor(Schema schema) + throws Descriptors.DescriptorValidationException { + DescriptorProtos.DescriptorProto.Builder descriptorBuilder = + DescriptorProtos.DescriptorProto.newBuilder().setName("Schema"); + + FieldList fields = schema.getFields(); + + DescriptorProtos.DescriptorProto descriptorProto = + buildDescriptorProtoWithFields(descriptorBuilder, fields, 0); + + return createDescriptorFromProto(descriptorProto); + } + + private static Descriptors.Descriptor createDescriptorFromProto( + DescriptorProtos.DescriptorProto descriptorProto) + throws Descriptors.DescriptorValidationException { + DescriptorProtos.FileDescriptorProto fileDescriptorProto = + DescriptorProtos.FileDescriptorProto.newBuilder().addMessageType(descriptorProto).build(); + + Descriptors.Descriptor descriptor = + Descriptors.FileDescriptor.buildFrom( + fileDescriptorProto, new Descriptors.FileDescriptor[] {}) + .getMessageTypes() + .get(0); + + return descriptor; + } + + @VisibleForTesting + protected static DescriptorProtos.DescriptorProto buildDescriptorProtoWithFields( + DescriptorProtos.DescriptorProto.Builder descriptorBuilder, FieldList fields, int depth) { + Preconditions.checkArgument( + depth < MAX_BIGQUERY_NESTED_DEPTH, + "Tried to convert a BigQuery schema that exceeded BigQuery maximum nesting depth"); + int messageNumber = 1; + for (Field field : fields) { + String fieldName = field.getName(); + DescriptorProtos.FieldDescriptorProto.Label fieldLabel = toProtoFieldLabel(field.getMode()); + FieldList subFields = field.getSubFields(); + + if (field.getType() == LegacySQLTypeName.RECORD) { + String recordTypeName = + RESERVED_NESTED_TYPE_NAME + + messageNumber; // TODO: Change or assert this to be a reserved name. No column can + // have this name. 
+ DescriptorProtos.DescriptorProto.Builder nestedFieldTypeBuilder = + descriptorBuilder.addNestedTypeBuilder(); + nestedFieldTypeBuilder.setName(recordTypeName); + DescriptorProtos.DescriptorProto nestedFieldType = + buildDescriptorProtoWithFields(nestedFieldTypeBuilder, subFields, depth + 1); + + descriptorBuilder.addField( + createProtoFieldBuilder(fieldName, fieldLabel, messageNumber) + .setTypeName(recordTypeName)); + } else { + DescriptorProtos.FieldDescriptorProto.Type fieldType = toProtoFieldType(field.getType()); + descriptorBuilder.addField( + createProtoFieldBuilder(fieldName, fieldLabel, messageNumber, fieldType)); + } + messageNumber++; } - - private static Descriptors.Descriptor createDescriptorFromProto(DescriptorProtos.DescriptorProto descriptorProto) - throws Descriptors.DescriptorValidationException { - DescriptorProtos.FileDescriptorProto fileDescriptorProto = DescriptorProtos.FileDescriptorProto - .newBuilder() - .addMessageType(descriptorProto) - .build(); - - Descriptors.Descriptor descriptor = Descriptors.FileDescriptor - .buildFrom(fileDescriptorProto, new Descriptors.FileDescriptor[]{}) - .getMessageTypes() - .get(0); - - return descriptor; + return descriptorBuilder.build(); + } + + private static DescriptorProtos.FieldDescriptorProto.Builder createProtoFieldBuilder( + String fieldName, DescriptorProtos.FieldDescriptorProto.Label fieldLabel, int messageNumber) { + return DescriptorProtos.FieldDescriptorProto.newBuilder() + .setName(fieldName) + .setLabel(fieldLabel) + .setNumber(messageNumber); + } + + @VisibleForTesting + protected static DescriptorProtos.FieldDescriptorProto.Builder createProtoFieldBuilder( + String fieldName, + DescriptorProtos.FieldDescriptorProto.Label fieldLabel, + int messageNumber, + DescriptorProtos.FieldDescriptorProto.Type fieldType) { + return DescriptorProtos.FieldDescriptorProto.newBuilder() + .setName(fieldName) + .setLabel(fieldLabel) + .setNumber(messageNumber) + .setType(fieldType); + } + + private static DescriptorProtos.FieldDescriptorProto.Label toProtoFieldLabel(Field.Mode mode) { + switch (mode) { + case NULLABLE: + return DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL; + case REPEATED: + return DescriptorProtos.FieldDescriptorProto.Label.LABEL_REPEATED; + case REQUIRED: + return DescriptorProtos.FieldDescriptorProto.Label.LABEL_REQUIRED; + default: + throw new IllegalArgumentException("A BigQuery Field Mode was invalid: " + mode.name()); } - - @VisibleForTesting - protected static DescriptorProtos.DescriptorProto buildDescriptorProtoWithFields( - DescriptorProtos.DescriptorProto.Builder descriptorBuilder, FieldList fields, int depth){ - Preconditions.checkArgument(depth < MAX_BIGQUERY_NESTED_DEPTH, - "Tried to convert a BigQuery schema that exceeded BigQuery maximum nesting depth"); - int messageNumber = 1; - for (Field field : fields) { - String fieldName = field.getName(); - DescriptorProtos.FieldDescriptorProto.Label fieldLabel = toProtoFieldLabel(field.getMode()); - FieldList subFields = field.getSubFields(); - - if (field.getType() == LegacySQLTypeName.RECORD){ - String recordTypeName = "RECORD"+messageNumber; // TODO: Change or assert this to be a reserved name. No column can have this name. 
- DescriptorProtos.DescriptorProto.Builder nestedFieldTypeBuilder = descriptorBuilder.addNestedTypeBuilder(); - nestedFieldTypeBuilder.setName(recordTypeName); - DescriptorProtos.DescriptorProto nestedFieldType = buildDescriptorProtoWithFields( - nestedFieldTypeBuilder, subFields, depth+1); - - descriptorBuilder.addField(createProtoFieldBuilder(fieldName, fieldLabel, messageNumber) - .setTypeName(recordTypeName)); - } - else { - DescriptorProtos.FieldDescriptorProto.Type fieldType = toProtoFieldType(field.getType()); - descriptorBuilder.addField(createProtoFieldBuilder(fieldName, fieldLabel, messageNumber, fieldType)); - } - messageNumber++; - } - return descriptorBuilder.build(); + } + + // NOTE: annotations for DATETIME and TIMESTAMP objects are currently unsupported for external + // users, + // but if they become available, it would be advisable to append an annotation to the + // protoFieldBuilder + // for these and other types. + private static DescriptorProtos.FieldDescriptorProto.Type toProtoFieldType( + LegacySQLTypeName bqType) { + DescriptorProtos.FieldDescriptorProto.Type protoFieldType; + if (LegacySQLTypeName.INTEGER.equals(bqType) + || LegacySQLTypeName.DATE.equals(bqType) + || LegacySQLTypeName.DATETIME.equals(bqType) + || LegacySQLTypeName.TIMESTAMP.equals(bqType)) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64; } - - private static DescriptorProtos.FieldDescriptorProto.Builder createProtoFieldBuilder( - String fieldName, DescriptorProtos.FieldDescriptorProto.Label fieldLabel, int messageNumber) { - return DescriptorProtos.FieldDescriptorProto - .newBuilder() - .setName(fieldName) - .setLabel(fieldLabel) - .setNumber(messageNumber); + if (LegacySQLTypeName.BOOLEAN.equals(bqType)) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_BOOL; } - - @VisibleForTesting - protected static DescriptorProtos.FieldDescriptorProto.Builder createProtoFieldBuilder( - String fieldName, DescriptorProtos.FieldDescriptorProto.Label fieldLabel, int messageNumber, - DescriptorProtos.FieldDescriptorProto.Type fieldType) { - return DescriptorProtos.FieldDescriptorProto - .newBuilder() - .setName(fieldName) - .setLabel(fieldLabel) - .setNumber(messageNumber) - .setType(fieldType); + if (LegacySQLTypeName.STRING.equals(bqType)) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING; } - - private static DescriptorProtos.FieldDescriptorProto.Label toProtoFieldLabel(Field.Mode mode) { - switch (mode) { - case NULLABLE: - return DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL; - case REPEATED: - return DescriptorProtos.FieldDescriptorProto.Label.LABEL_REPEATED; - case REQUIRED: - return DescriptorProtos.FieldDescriptorProto.Label.LABEL_REQUIRED; - default: - throw new IllegalArgumentException("A BigQuery Field Mode was invalid: "+mode.name()); - } + if (LegacySQLTypeName.GEOGRAPHY.equals(bqType) + || LegacySQLTypeName.BYTES.equals(bqType) + || LegacySQLTypeName.NUMERIC.equals(bqType)) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES; } - - // NOTE: annotations for DATETIME and TIMESTAMP objects are currently unsupported for external users, - // but if they become available, it would be advisable to append an annotation to the protoFieldBuilder - // for these and other types. 
- private static DescriptorProtos.FieldDescriptorProto.Type toProtoFieldType(LegacySQLTypeName bqType) { - DescriptorProtos.FieldDescriptorProto.Type protoFieldType; - if (LegacySQLTypeName.INTEGER.equals(bqType) || - LegacySQLTypeName.DATE.equals(bqType) || - LegacySQLTypeName.DATETIME.equals(bqType) || - LegacySQLTypeName.TIMESTAMP.equals(bqType)) { - return DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64; - } - if (LegacySQLTypeName.BOOLEAN.equals(bqType)){ - return DescriptorProtos.FieldDescriptorProto.Type.TYPE_BOOL; - } - if (LegacySQLTypeName.STRING.equals(bqType)) { - return DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING; - } - if (LegacySQLTypeName.GEOGRAPHY.equals(bqType) || - LegacySQLTypeName.BYTES.equals(bqType) || - LegacySQLTypeName.NUMERIC.equals(bqType)) { - return DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES; - } - if (LegacySQLTypeName.FLOAT.equals(bqType)) { - return DescriptorProtos.FieldDescriptorProto.Type.TYPE_DOUBLE; - } - else { - if (LegacySQLTypeName.RECORD.equals(bqType)) { - throw new IllegalStateException("Program attempted to return an atomic data-type for a RECORD"); - } - throw new IllegalArgumentException("Unexpected type: " + bqType.name()); - } + if (LegacySQLTypeName.FLOAT.equals(bqType)) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_DOUBLE; + } else { + if (LegacySQLTypeName.RECORD.equals(bqType)) { + throw new IllegalStateException( + "Program attempted to return an atomic data-type for a RECORD"); + } + throw new IllegalArgumentException("Unexpected type: " + bqType.name()); } - - - /** - * Spark Row --> ProtoRows converter utils: - * To be used by the DataWriters facing the BigQuery Storage Write API - */ - public static ProtoBufProto.ProtoRows toProtoRows(StructType sparkSchema, InternalRow[] rows) { - try { - Descriptors.Descriptor schemaDescriptor = toDescriptor(sparkSchema); - ProtoBufProto.ProtoRows.Builder protoRows = ProtoBufProto.ProtoRows.newBuilder(); - for (InternalRow row : rows) { - DynamicMessage rowMessage = createSingleRowMessage(sparkSchema, - schemaDescriptor, row); - protoRows.addSerializedRows(rowMessage.toByteString()); - } - return protoRows.build(); - } catch (Exception e) { - throw new RuntimeException("Could not convert Internal Rows to Proto Rows.", e); - } + } + + /** + * Spark Row --> ProtoRows converter utils: To be used by the DataWriters facing the BigQuery + * Storage Write API + */ + public static ProtoBufProto.ProtoRows toProtoRows(StructType sparkSchema, InternalRow[] rows) { + try { + Descriptors.Descriptor schemaDescriptor = toDescriptor(sparkSchema); + ProtoBufProto.ProtoRows.Builder protoRows = ProtoBufProto.ProtoRows.newBuilder(); + for (InternalRow row : rows) { + DynamicMessage rowMessage = buildSingleRowMessage(sparkSchema, schemaDescriptor, row); + protoRows.addSerializedRows(rowMessage.toByteString()); + } + return protoRows.build(); + } catch (Exception e) { + throw new RuntimeException("Could not convert Internal Rows to Proto Rows.", e); } + } - public static DynamicMessage createSingleRowMessage(StructType schema, - Descriptors.Descriptor schemaDescriptor, - InternalRow row) { - - DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(schemaDescriptor); - - for(int i = 1; i <= schemaDescriptor.getFields().size(); i++) { - StructField sparkField = schema.fields()[i-1]; - DataType sparkType = sparkField.dataType(); - if (sparkType instanceof StructType) { - messageBuilder.setField(schemaDescriptor.findFieldByNumber(i), - 
createSingleRowMessage((StructType)sparkType, - schemaDescriptor.findNestedTypeByName(RESERVED_NESTED_TYPE_NAME +i), - (InternalRow)row.get(i-1, sparkType))); - } - else { - Object converted = convert(sparkField, row.get(i-1, sparkType)); - if (converted == null) { - continue; - } - messageBuilder.setField(schemaDescriptor.findFieldByNumber(i), - convert(sparkField, - row.get(i-1, sparkType))); - } - } + public static DynamicMessage buildSingleRowMessage( + StructType schema, Descriptors.Descriptor schemaDescriptor, InternalRow row) { + DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(schemaDescriptor); - return messageBuilder.build(); - } + for (int fieldIndex = 1; fieldIndex <= schemaDescriptor.getFields().size(); fieldIndex++) { + StructField sparkField = schema.fields()[fieldIndex - 1]; + DataType sparkType = sparkField.dataType(); - public static Descriptors.Descriptor toDescriptor (StructType schema) - throws Descriptors.DescriptorValidationException { - DescriptorProtos.DescriptorProto.Builder descriptorBuilder = DescriptorProtos.DescriptorProto.newBuilder() - .setName("Schema"); + Object sparkValue = row.get(fieldIndex - 1, sparkType); + boolean nullable = sparkField.nullable(); + Descriptors.Descriptor nestedTypeDescriptor = + schemaDescriptor.findNestedTypeByName(RESERVED_NESTED_TYPE_NAME + fieldIndex); + Object protoValue = + convertToProtoRowValue(sparkType, sparkValue, nullable, nestedTypeDescriptor); - StructField[] fields = schema.fields(); + logger.debug("Converted value {} to proto-value: {}", sparkValue, protoValue); - DescriptorProtos.DescriptorProto descriptorProto = buildDescriptorProtoWithFields(descriptorBuilder, fields, 0); + if (protoValue == null) { + continue; + } - return createDescriptorFromProto(descriptorProto); + messageBuilder.setField(schemaDescriptor.findFieldByNumber(fieldIndex), protoValue); } - @VisibleForTesting - protected static Object convert (StructField sparkField, Object sparkValue) { - if (sparkValue == null) { - if (!sparkField.nullable()) { - throw new IllegalArgumentException("Non-nullable field was null."); - } - else { - return null; - } - } - - DataType fieldType = sparkField.dataType(); - - if (fieldType instanceof ArrayType) { - ArrayType arrayType = (ArrayType)fieldType; - boolean containsNull = arrayType.containsNull(); // elements can be null. - DataType elementType = arrayType.elementType(); - - ArrayData arrayData = (ArrayData)sparkValue; - Object[] sparkValues = arrayData.toObjectArray(elementType); - - return Arrays.stream(sparkValues).map(value -> { - Preconditions.checkArgument(containsNull || value != null, - "Encountered a null value inside a non-null-containing array."); - return toAtomicProtoRowValue(elementType, value); - } ).collect(Collectors.toList()); - } - else { - return toAtomicProtoRowValue(fieldType, sparkValue); - } + return messageBuilder.build(); + } + + public static Descriptors.Descriptor toDescriptor(StructType schema) + throws Descriptors.DescriptorValidationException { + DescriptorProtos.DescriptorProto.Builder descriptorBuilder = + DescriptorProtos.DescriptorProto.newBuilder().setName("Schema"); + + DescriptorProtos.DescriptorProto descriptorProto = + buildDescriptorProtoWithFields(descriptorBuilder, schema.fields(), 0); + + return createDescriptorFromProto(descriptorProto); + } + + /* + Takes a value in Spark format and converts it into ProtoRows format (to eventually be given to BigQuery). 
+ */ + private static Object convertToProtoRowValue( + DataType sparkType, + Object sparkValue, + boolean nullable, + Descriptors.Descriptor nestedTypeDescriptor) { + logger.debug("Converting type: {}", sparkType.json()); + if (sparkValue == null) { + if (!nullable) { + throw new IllegalArgumentException("Non-nullable field was null."); + } else { + return null; + } } - /* - Takes a value in Spark format and converts it into ProtoRows format (to eventually be given to BigQuery). - */ - private static Object toAtomicProtoRowValue(DataType sparkType, Object value) { - if (sparkType instanceof ByteType || - sparkType instanceof ShortType || - sparkType instanceof IntegerType || - sparkType instanceof LongType) { - return ((Number)value).longValue(); - } - - if (sparkType instanceof FloatType || - sparkType instanceof DoubleType || - sparkType instanceof DecimalType) { - return ((Number)value).doubleValue(); // TODO: should decimal be converted to double? Or a Bytes type containing extra width? - } - - if (sparkType instanceof TimestampType) { - return Timestamp.valueOf((String)value).getTime(); // + if (sparkType instanceof ArrayType) { + ArrayType arrayType = (ArrayType) sparkType; + DataType elementType = arrayType.elementType(); + Object[] sparkArrayData = ((ArrayData) sparkValue).toObjectArray(elementType); + boolean containsNull = arrayType.containsNull(); + List protoValue = new ArrayList<>(); + for (Object sparkElement : sparkArrayData) { + Object converted = + convertToProtoRowValue(elementType, sparkElement, containsNull, nestedTypeDescriptor); + if (converted == null) { + continue; } + protoValue.add(converted); + } + return protoValue; + } - if (sparkType instanceof DateType) { - return Date.valueOf((String)value).getTime(); - } // TODO: CalendarInterval - - if (sparkType instanceof BooleanType || - sparkType instanceof BinaryType) { - return value; - } + if (sparkType instanceof StructType) { + return buildSingleRowMessage( + (StructType) sparkType, nestedTypeDescriptor, (InternalRow) sparkValue); + } - if (sparkType instanceof StringType) { - return new String(((UTF8String)value).getBytes()); - } + if (sparkType instanceof ByteType + || sparkType instanceof ShortType + || sparkType instanceof IntegerType + || sparkType instanceof LongType + || sparkType instanceof TimestampType + || sparkType instanceof DateType) { + return ((Number) sparkValue).longValue(); + } // TODO: CalendarInterval + + if (sparkType instanceof FloatType || sparkType instanceof DoubleType) { + return ((Number) sparkValue).doubleValue(); + } - if (sparkType instanceof MapType) { - throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); - } + if (sparkType instanceof DecimalType) { + return ((Decimal) sparkValue).toDouble(); + } - throw new IllegalStateException("Unexpected type: " + sparkType); + if (sparkType instanceof BooleanType || sparkType instanceof BinaryType) { + return sparkValue; } - private static DescriptorProtos.DescriptorProto buildDescriptorProtoWithFields( - DescriptorProtos.DescriptorProto.Builder descriptorBuilder, StructField[] fields, int depth) { - Preconditions.checkArgument(depth < MAX_BIGQUERY_NESTED_DEPTH, - "Spark Schema exceeds BigQuery maximum nesting depth."); - int messageNumber = 1; - for (StructField field : fields) { - String fieldName = field.name(); - DescriptorProtos.FieldDescriptorProto.Label fieldLabel = field.nullable() ? 
- DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL : - DescriptorProtos.FieldDescriptorProto.Label.LABEL_REQUIRED; - DescriptorProtos.FieldDescriptorProto.Type fieldType; - - DataType sparkType = field.dataType(); - if (sparkType instanceof StructType) { - StructType structType = (StructType)sparkType; - String nestedName = RESERVED_NESTED_TYPE_NAME +messageNumber; // TODO: this should be a reserved name. No column can have this name. - StructField[] subFields = structType.fields(); - - DescriptorProtos.DescriptorProto.Builder nestedFieldTypeBuilder = descriptorBuilder.addNestedTypeBuilder() - .setName(nestedName); - buildDescriptorProtoWithFields(nestedFieldTypeBuilder, subFields, depth+1); - - descriptorBuilder.addField(createProtoFieldBuilder(fieldName, fieldLabel, messageNumber) - .setTypeName(nestedName)); - messageNumber++; - continue; - } - - if (sparkType instanceof ArrayType) { - ArrayType arrayType = (ArrayType)sparkType; - /* DescriptorProtos.FieldDescriptorProto.Label elementLabel = arrayType.containsNull() ? - DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL : - DescriptorProtos.FieldDescriptorProto.Label.LABEL_REQUIRED; TODO: how to support null instances inside an array (repeated field) in BigQuery?*/ - fieldType = sparkAtomicTypeToProtoFieldType(arrayType.elementType()); - fieldLabel = DescriptorProtos.FieldDescriptorProto.Label.LABEL_REPEATED; - - } else { - fieldType = sparkAtomicTypeToProtoFieldType(sparkType); - } - descriptorBuilder.addField( - createProtoFieldBuilder(fieldName, fieldLabel, messageNumber, fieldType)); - messageNumber++; - } - return descriptorBuilder.build(); + if (sparkType instanceof StringType) { + return new String(((UTF8String) sparkValue).getBytes()); } - // NOTE: annotations for DATETIME and TIMESTAMP objects are currently unsupported for external users, - // but if they become available, it would be advisable to append an annotation to the protoFieldBuilder - // for these and other types. - // This function only converts atomic Spark DataTypes - private static DescriptorProtos.FieldDescriptorProto.Type sparkAtomicTypeToProtoFieldType(DataType sparkType) { - if (sparkType instanceof ByteType || - sparkType instanceof ShortType || - sparkType instanceof IntegerType || - sparkType instanceof LongType || - sparkType instanceof TimestampType || - sparkType instanceof DateType) { - return DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64; - } + if (sparkType instanceof MapType) { + throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); + } - if (sparkType instanceof FloatType || - sparkType instanceof DoubleType || - sparkType instanceof DecimalType) { - return DescriptorProtos.FieldDescriptorProto.Type.TYPE_DOUBLE; - // TODO: an annotation to distinguish between decimals that are doubles, and decimals that are NUMERIC (Bytes types) - } + throw new IllegalStateException("Unexpected type: " + sparkType); + } + + private static DescriptorProtos.DescriptorProto buildDescriptorProtoWithFields( + DescriptorProtos.DescriptorProto.Builder descriptorBuilder, StructField[] fields, int depth) { + Preconditions.checkArgument( + depth < MAX_BIGQUERY_NESTED_DEPTH, "Spark Schema exceeds BigQuery maximum nesting depth."); + int messageNumber = 1; + for (StructField field : fields) { + String fieldName = field.name(); + DescriptorProtos.FieldDescriptorProto.Label fieldLabel = + field.nullable() + ? 
DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL + : DescriptorProtos.FieldDescriptorProto.Label.LABEL_REQUIRED; + + DataType sparkType = field.dataType(); + + if (sparkType instanceof ArrayType) { + ArrayType arrayType = (ArrayType) sparkType; + /* DescriptorProtos.FieldDescriptorProto.Label elementLabel = arrayType.containsNull() ? + DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL : + DescriptorProtos.FieldDescriptorProto.Label.LABEL_REQUIRED; TODO: how to support null instances inside an array (repeated field) in BigQuery?*/ + sparkType = arrayType.elementType(); + fieldLabel = DescriptorProtos.FieldDescriptorProto.Label.LABEL_REPEATED; + } + + DescriptorProtos.FieldDescriptorProto.Builder protoFieldBuilder; + if (sparkType instanceof StructType) { + StructType structType = (StructType) sparkType; + String nestedName = + RESERVED_NESTED_TYPE_NAME + + messageNumber; // TODO: this should be a reserved name. No column can have this + // name. + StructField[] subFields = structType.fields(); + + DescriptorProtos.DescriptorProto.Builder nestedFieldTypeBuilder = + descriptorBuilder.addNestedTypeBuilder().setName(nestedName); + buildDescriptorProtoWithFields(nestedFieldTypeBuilder, subFields, depth + 1); + + protoFieldBuilder = + createProtoFieldBuilder(fieldName, fieldLabel, messageNumber).setTypeName(nestedName); + } else { + DescriptorProtos.FieldDescriptorProto.Type fieldType = + sparkAtomicTypeToProtoFieldType(sparkType); + protoFieldBuilder = + createProtoFieldBuilder(fieldName, fieldLabel, messageNumber, fieldType); + } + descriptorBuilder.addField(protoFieldBuilder); + messageNumber++; + } + return descriptorBuilder.build(); + } + + // NOTE: annotations for DATETIME and TIMESTAMP objects are currently unsupported for external + // users, + // but if they become available, it would be advisable to append an annotation to the + // protoFieldBuilder + // for these and other types. 
+ // This function only converts atomic Spark DataTypes + private static DescriptorProtos.FieldDescriptorProto.Type sparkAtomicTypeToProtoFieldType( + DataType sparkType) { + if (sparkType instanceof ByteType + || sparkType instanceof ShortType + || sparkType instanceof IntegerType + || sparkType instanceof LongType + || sparkType instanceof TimestampType + || sparkType instanceof DateType) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64; + } - if (sparkType instanceof BooleanType) { - return DescriptorProtos.FieldDescriptorProto.Type.TYPE_BOOL; - } + if (sparkType instanceof FloatType + || sparkType instanceof DoubleType + || sparkType instanceof DecimalType) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_DOUBLE; + /* TODO: an annotation to distinguish between decimals that are doubles, and decimals that are + NUMERIC (Bytes types) */ + } - if (sparkType instanceof BinaryType) { - return DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES; - } + if (sparkType instanceof BooleanType) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_BOOL; + } - if (sparkType instanceof StringType) { - return DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING; - } + if (sparkType instanceof BinaryType) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES; + } - if (sparkType instanceof MapType) { - throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); - } + if (sparkType instanceof StringType) { + return DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING; + } - throw new IllegalStateException("Unexpected type: " + sparkType); + if (sparkType instanceof MapType) { + throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); } + + throw new IllegalStateException("Unexpected type: " + sparkType); + } } diff --git a/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java b/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java index 66889e1532..86575f6eda 100644 --- a/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java +++ b/connector/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java @@ -39,9 +39,9 @@ public class SchemaConverters { private static final int BQ_NUMERIC_SCALE = 9; private static final DecimalType NUMERIC_SPARK_TYPE = DataTypes.createDecimalType(BQ_NUMERIC_PRECISION, BQ_NUMERIC_SCALE); - // The maximum nesting depth of a BigQuery RECORD: - private static final int MAX_BIGQUERY_NESTED_DEPTH = 15; - private static final String MAPTYPE_ERROR_MESSAGE = "MapType is unsupported."; + // The maximum nesting depth of a BigQuery RECORD: + private static final int MAX_BIGQUERY_NESTED_DEPTH = 15; + private static final String MAPTYPE_ERROR_MESSAGE = "MapType is unsupported."; /** Convert a BigQuery schema to a Spark schema */ public static StructType toSpark(Schema schema) { @@ -205,103 +205,105 @@ private static DataType getDataType(Field field) { } } + /** Spark ==> BigQuery Schema Converter utils: */ + public static Schema toBigQuerySchema(StructType sparkSchema) { + FieldList bigQueryFields = sparkToBigQueryFields(sparkSchema, 0); + return Schema.of(bigQueryFields); + } - /** - * Spark ==> BigQuery Schema Converter utils: - */ - public static Schema toBigQuerySchema (StructType sparkSchema) { - FieldList bigQueryFields = sparkToBigQueryFields(sparkSchema, 0); - return Schema.of(bigQueryFields); + /** + * Returns a FieldList of all the Spark StructField objects, converted to BigQuery Field objects + */ + private static FieldList 
sparkToBigQueryFields(StructType sparkStruct, int depth) { + Preconditions.checkArgument( + depth < MAX_BIGQUERY_NESTED_DEPTH, "Spark Schema exceeds BigQuery maximum nesting depth."); + List bqFields = new ArrayList<>(); + for (StructField field : sparkStruct.fields()) { + bqFields.add(makeBigQueryColumn(field, depth)); } + return FieldList.of(bqFields); + } - /** - * Returns a FieldList of all the Spark StructField objects, converted to BigQuery Field objects - */ - private static FieldList sparkToBigQueryFields (StructType sparkStruct, int depth){ - Preconditions.checkArgument(depth < MAX_BIGQUERY_NESTED_DEPTH, - "Spark Schema exceeds BigQuery maximum nesting depth."); - List bqFields = new ArrayList<>(); - for (StructField field : sparkStruct.fields()){ - bqFields.add(makeBigQueryColumn(field, depth)); - } - return FieldList.of(bqFields); + /** Converts a single StructField to a BigQuery Field (column). */ + @VisibleForTesting + protected static Field makeBigQueryColumn(StructField sparkField, int depth) { + DataType sparkType = sparkField.dataType(); + String fieldName = sparkField.name(); + Field.Mode fieldMode = (sparkField.nullable()) ? Field.Mode.NULLABLE : Field.Mode.REQUIRED; + String description; + FieldList subFields = null; + LegacySQLTypeName fieldType; + + if (sparkType instanceof ArrayType) { + ArrayType arrayType = (ArrayType) sparkType; + + fieldMode = Field.Mode.REPEATED; + sparkType = arrayType.elementType(); } - /** - * Converts a single StructField to a BigQuery Field (column). - */ - @VisibleForTesting - protected static Field makeBigQueryColumn (StructField sparkField, int depth) { - DataType sparkType = sparkField.dataType(); - String fieldName = sparkField.name(); - Field.Mode fieldMode = (sparkField.nullable()) ? Field.Mode.NULLABLE : Field.Mode.REQUIRED; - String description; - FieldList subFields = null; - LegacySQLTypeName fieldType; - - if (sparkType instanceof ArrayType) { - ArrayType arrayType = (ArrayType)sparkType; - LegacySQLTypeName elementType = toBigQueryType(arrayType.elementType()); - fieldType = elementType; - fieldMode = Field.Mode.REPEATED; - } - else if (sparkType instanceof StructType) { - subFields = sparkToBigQueryFields((StructType)sparkType, depth+1); - fieldType = LegacySQLTypeName.RECORD; - } - else { - fieldType = toBigQueryType(sparkType); - } - - try { - description = sparkField.metadata().getString("description"); - } - catch (NoSuchElementException e) { - return createBigQueryFieldBuilder(fieldName, fieldType, fieldMode, subFields).build(); - } - - return createBigQueryFieldBuilder(fieldName, fieldType, fieldMode, subFields) - .setDescription(description).build(); + if (sparkType instanceof StructType) { + subFields = sparkToBigQueryFields((StructType) sparkType, depth + 1); + fieldType = LegacySQLTypeName.RECORD; + } else { + fieldType = toBigQueryType(sparkType); } - @VisibleForTesting - protected static LegacySQLTypeName toBigQueryType (DataType elementType) { - if (elementType instanceof BinaryType) { - return LegacySQLTypeName.BYTES; - } if (elementType instanceof ByteType || - elementType instanceof ShortType || - elementType instanceof IntegerType || - elementType instanceof LongType) { - return LegacySQLTypeName.INTEGER; - } if (elementType instanceof BooleanType) { - return LegacySQLTypeName.BOOLEAN; - } if (elementType instanceof FloatType || - elementType instanceof DoubleType) { - return LegacySQLTypeName.FLOAT; - } if (elementType instanceof DecimalType) { - DecimalType decimalType = (DecimalType)elementType; - if 
(decimalType.precision() <= BQ_NUMERIC_PRECISION && - decimalType.scale() <= BQ_NUMERIC_SCALE) { - return LegacySQLTypeName.NUMERIC; - } else { - throw new IllegalArgumentException("Decimal type is too wide to fit in BigQuery Numeric format"); - } - } if (elementType instanceof StringType) { - return LegacySQLTypeName.STRING; - } if (elementType instanceof TimestampType) { - return LegacySQLTypeName.TIMESTAMP; - } if (elementType instanceof DateType) { - return LegacySQLTypeName.DATE; - } if (elementType instanceof MapType) { - throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); - } - else { - throw new IllegalArgumentException("Data type not expected in toBQType: "+elementType.simpleString()); - } + try { + description = sparkField.metadata().getString("description"); + } catch (NoSuchElementException e) { + return createBigQueryFieldBuilder(fieldName, fieldType, fieldMode, subFields).build(); } - private static Field.Builder createBigQueryFieldBuilder (String name, LegacySQLTypeName type, Field.Mode mode, FieldList subFields){ - return Field.newBuilder(name, type, subFields) - .setMode(mode); + return createBigQueryFieldBuilder(fieldName, fieldType, fieldMode, subFields) + .setDescription(description) + .build(); + } + + @VisibleForTesting + protected static LegacySQLTypeName toBigQueryType(DataType elementType) { + if (elementType instanceof BinaryType) { + return LegacySQLTypeName.BYTES; + } + if (elementType instanceof ByteType + || elementType instanceof ShortType + || elementType instanceof IntegerType + || elementType instanceof LongType) { + return LegacySQLTypeName.INTEGER; + } + if (elementType instanceof BooleanType) { + return LegacySQLTypeName.BOOLEAN; + } + if (elementType instanceof FloatType || elementType instanceof DoubleType) { + return LegacySQLTypeName.FLOAT; } + if (elementType instanceof DecimalType) { + DecimalType decimalType = (DecimalType) elementType; + if (decimalType.precision() <= BQ_NUMERIC_PRECISION + && decimalType.scale() <= BQ_NUMERIC_SCALE) { + return LegacySQLTypeName.NUMERIC; + } else { + throw new IllegalArgumentException( + "Decimal type is too wide to fit in BigQuery Numeric format"); + } + } + if (elementType instanceof StringType) { + return LegacySQLTypeName.STRING; + } + if (elementType instanceof TimestampType) { + return LegacySQLTypeName.TIMESTAMP; + } + if (elementType instanceof DateType) { + return LegacySQLTypeName.DATE; + } + if (elementType instanceof MapType) { + throw new IllegalArgumentException(MAPTYPE_ERROR_MESSAGE); + } else { + throw new IllegalArgumentException("Data type not expected: " + elementType.simpleString()); + } + } + + private static Field.Builder createBigQueryFieldBuilder( + String name, LegacySQLTypeName type, Field.Mode mode, FieldList subFields) { + return Field.newBuilder(name, type, subFields).setMode(mode); + } } diff --git a/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java b/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java index f1cfe37799..9542498f0e 100644 --- a/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java +++ b/connector/src/test/java/com/google/cloud/spark/bigquery/ProtobufUtilsTest.java @@ -24,6 +24,7 @@ import com.google.protobuf.DescriptorProtos; import com.google.protobuf.Descriptors; import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Message; import org.apache.log4j.Level; import org.apache.log4j.LogManager; import org.apache.log4j.Logger; @@ -31,14 +32,10 @@ import 
org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.types.ByteArray; import org.apache.spark.unsafe.types.UTF8String; import org.junit.AssumptionViolatedException; import org.junit.Test; -import java.sql.Date; -import java.sql.Timestamp; - import static com.google.cloud.spark.bigquery.ProtobufUtils.*; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; @@ -86,6 +83,7 @@ public void testBigQueryToProtoSchema() throws Exception { .addField(PROTO_BOOLEAN_FIELD.clone().setNumber(6)) .addField(PROTO_BYTES_FIELD.clone().setNumber(7)) .addField(PROTO_INTEGER_FIELD.clone().setName("Date").setNumber(8)) + .addField(PROTO_INTEGER_FIELD.clone().setName("TimeStamp").setNumber(9)) .setName("Schema").build() ).build(), new Descriptors.FileDescriptor[]{} ).getMessageTypes().get(0) @@ -99,73 +97,13 @@ public void testBigQueryToProtoSchema() throws Exception { } } - @Test - public void testSparkIntegerSchemaToDescriptor() throws Exception { - logger.setLevel(Level.DEBUG); - - StructType schema = new StructType().add(SPARK_INTEGER_FIELD); - DescriptorProtos.DescriptorProto converted = toDescriptor(schema).toProto(); - - DescriptorProtos.DescriptorProto expected = DESCRIPTOR_PROTO_INTEGER; - - assertThat(converted).isEqualTo(expected); - } - - @Test - public void testSparkStringSchemaToDescriptor() throws Exception { - logger.setLevel(Level.DEBUG); - - StructType schema = new StructType().add(SPARK_STRING_FIELD); - DescriptorProtos.DescriptorProto converted = toDescriptor(schema).toProto(); - - DescriptorProtos.DescriptorProto expected = DESCRIPTOR_PROTO_STRING; - - assertThat(converted).isEqualTo(expected); - } - - @Test - public void testSparkArraySchemaToDescriptor() throws Exception { - logger.setLevel(Level.DEBUG); - - StructType schema = new StructType().add(SPARK_ARRAY_FIELD); - DescriptorProtos.DescriptorProto converted = toDescriptor(schema).toProto(); - - DescriptorProtos.DescriptorProto expected = DESCRIPTOR_PROTO_ARRAY; - - assertThat(converted).isEqualTo(expected); - } - - @Test - public void testSparkNestedStructSchemaToDescriptor() throws Exception { - logger.setLevel(Level.DEBUG); - - StructType schema = new StructType().add(SPARK_NESTED_STRUCT_FIELD); - DescriptorProtos.DescriptorProto converted = toDescriptor(schema).toProto(); - - DescriptorProtos.DescriptorProto expected = DESCRIPTOR_PROTO_STRUCT; - - assertThat(converted).isEqualTo(expected); - } - - @Test - public void testSparkArrayRowToDynamicMessage() throws Exception { - logger.setLevel(Level.DEBUG); - - StructType schema = new StructType().add(SPARK_ARRAY_FIELD); - DynamicMessage converted = createSingleRowMessage(schema, toDescriptor(schema), - ARRAY_INTERNAL_ROW); - DynamicMessage expected = ARRAY_ROW_MESSAGE; - - assertThat(converted.toString()).isEqualTo(expected.toString()); - } - @Test public void testSparkStructRowToDynamicMessage() throws Exception { logger.setLevel(Level.DEBUG); StructType schema = new StructType().add(SPARK_NESTED_STRUCT_FIELD); - DynamicMessage converted = createSingleRowMessage(schema, toDescriptor(schema), - STRUCT_INTERNAL_ROW); + Descriptors.Descriptor schemaDescriptor = toDescriptor(schema); + Message converted = buildSingleRowMessage(schema, schemaDescriptor, STRUCT_INTERNAL_ROW); DynamicMessage expected = StructRowMessage; assertThat(converted.toString()).isEqualTo(expected.toString()); @@ -185,7 +123,8 @@ public void 
testSparkRowToProtoRow() throws Exception { 3.14, true, new byte[]{11, 0x7F}, - "2020-07-07" + 1594080000000L, + 1594080000000L })} ); @@ -199,12 +138,17 @@ public void testSettingARequiredFieldAsNull() throws Exception { logger.setLevel(Level.DEBUG); try { - convert(SPARK_STRING_FIELD, null); + ProtoBufProto.ProtoRows converted = toProtoRows(new StructType() + .add(new StructField("String", DataTypes.StringType, false, Metadata.empty())), + new InternalRow[]{ + new GenericInternalRow(new Object[]{null})}); fail("Convert did not assert field's /'Required/' status"); - } catch (IllegalArgumentException e){} + } catch (Exception e){} try { - convert(new StructField("String", DataTypes.StringType, true, Metadata.empty()), - null); + ProtoBufProto.ProtoRows converted = toProtoRows(new StructType() + .add(new StructField("String", DataTypes.StringType, true, Metadata.empty())), + new InternalRow[]{ + new GenericInternalRow(new Object[]{null})}); } catch (Exception e) { fail("A nullable field could not be set to null."); } @@ -212,7 +156,7 @@ public void testSettingARequiredFieldAsNull() throws Exception { - private final StructType MY_STRUCT = DataTypes.createStructType( + public final StructType MY_STRUCT = DataTypes.createStructType( new StructField[]{new StructField("Number", DataTypes.IntegerType, true, Metadata.empty()), new StructField("String", DataTypes.StringType, @@ -235,8 +179,7 @@ public void testSettingARequiredFieldAsNull() throws Exception { true, Metadata.empty()); public final StructField SPARK_DATE_FIELD = new StructField("Date", DataTypes.DateType, true, Metadata.empty()); - public final StructField SPARK_MAP_FIELD = new StructField("Map", - DataTypes.createMapType(DataTypes.IntegerType, DataTypes.StringType), + public final StructField SPARK_TIMESTAMP_FIELD = new StructField("TimeStamp", DataTypes.TimestampType, true, Metadata.empty()); public final StructType BIG_SPARK_SCHEMA = new StructType() @@ -247,7 +190,8 @@ public void testSettingARequiredFieldAsNull() throws Exception { .add(SPARK_DOUBLE_FIELD) .add(SPARK_BOOLEAN_FIELD) .add(SPARK_BINARY_FIELD) - .add(SPARK_DATE_FIELD); + .add(SPARK_DATE_FIELD) + .add(SPARK_TIMESTAMP_FIELD); public final Field BIGQUERY_INTEGER_FIELD = Field.newBuilder("Number", LegacySQLTypeName.INTEGER, @@ -270,72 +214,75 @@ public void testSettingARequiredFieldAsNull() throws Exception { .setMode(Field.Mode.NULLABLE).build(); public final Field BIGQUERY_DATE_FIELD = Field.newBuilder("Date", LegacySQLTypeName.DATE, (FieldList)null) .setMode(Field.Mode.NULLABLE).build(); + public final Field BIGQUERY_TIMESTAMP_FIELD = Field.newBuilder("TimeStamp", LegacySQLTypeName.TIMESTAMP, (FieldList)null) + .setMode(Field.Mode.NULLABLE).build(); public final Schema BIG_BIGQUERY_SCHEMA = Schema.of(BIGQUERY_INTEGER_FIELD, BIGQUERY_STRING_FIELD, BIGQUERY_ARRAY_FIELD, - BIGQUERY_NESTED_STRUCT_FIELD, BIGQUERY_FLOAT_FIELD, BIGQUERY_BOOLEAN_FIELD, BIGQUERY_BYTES_FIELD, BIGQUERY_DATE_FIELD); + BIGQUERY_NESTED_STRUCT_FIELD, BIGQUERY_FLOAT_FIELD, BIGQUERY_BOOLEAN_FIELD, BIGQUERY_BYTES_FIELD, BIGQUERY_DATE_FIELD, + BIGQUERY_TIMESTAMP_FIELD); - private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_INTEGER_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + public final DescriptorProtos.FieldDescriptorProto.Builder PROTO_INTEGER_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() .setName("Number") .setNumber(1) .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64) .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL); - 
private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_STRING_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + public final DescriptorProtos.FieldDescriptorProto.Builder PROTO_STRING_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() .setName("String") .setNumber(1) .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING) .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_REQUIRED); - private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_ARRAY_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + public final DescriptorProtos.FieldDescriptorProto.Builder PROTO_ARRAY_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() .setName("Array") .setNumber(1) .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64) .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_REPEATED); - private final DescriptorProtos.DescriptorProto.Builder NESTED_STRUCT_DESCRIPTOR = DescriptorProtos.DescriptorProto.newBuilder() + public final DescriptorProtos.DescriptorProto.Builder NESTED_STRUCT_DESCRIPTOR = DescriptorProtos.DescriptorProto.newBuilder() .setName("STRUCT1") .addField(PROTO_INTEGER_FIELD.clone()) .addField(PROTO_STRING_FIELD.clone().setNumber(2) .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL)); - private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_STRUCT_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + public final DescriptorProtos.FieldDescriptorProto.Builder PROTO_STRUCT_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() .setName("Struct") .setNumber(1) .setTypeName("STRUCT1") .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL); - private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_DOUBLE_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + public final DescriptorProtos.FieldDescriptorProto.Builder PROTO_DOUBLE_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() .setName("Double") .setNumber(1) .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_DOUBLE) .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL); - private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_BOOLEAN_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + public final DescriptorProtos.FieldDescriptorProto.Builder PROTO_BOOLEAN_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() .setName("Boolean") .setNumber(1) .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_BOOL) .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL); - private final DescriptorProtos.FieldDescriptorProto.Builder PROTO_BYTES_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() + public final DescriptorProtos.FieldDescriptorProto.Builder PROTO_BYTES_FIELD = DescriptorProtos.FieldDescriptorProto.newBuilder() .setName("Binary") .setNumber(1) .setType(DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES) .setLabel(DescriptorProtos.FieldDescriptorProto.Label.LABEL_OPTIONAL); - private final DescriptorProtos.DescriptorProto DESCRIPTOR_PROTO_INTEGER = DescriptorProtos.DescriptorProto.newBuilder() + public final DescriptorProtos.DescriptorProto DESCRIPTOR_PROTO_INTEGER = DescriptorProtos.DescriptorProto.newBuilder() .addField(PROTO_INTEGER_FIELD).setName("Schema").build(); - private final DescriptorProtos.DescriptorProto DESCRIPTOR_PROTO_STRING = DescriptorProtos.DescriptorProto.newBuilder() + public final DescriptorProtos.DescriptorProto DESCRIPTOR_PROTO_STRING = DescriptorProtos.DescriptorProto.newBuilder() 
.addField(PROTO_STRING_FIELD).setName("Schema").build(); - private final DescriptorProtos.DescriptorProto DESCRIPTOR_PROTO_ARRAY = DescriptorProtos.DescriptorProto.newBuilder() + public final DescriptorProtos.DescriptorProto DESCRIPTOR_PROTO_ARRAY = DescriptorProtos.DescriptorProto.newBuilder() .addField(PROTO_ARRAY_FIELD).setName("Schema").build(); - private final DescriptorProtos.DescriptorProto DESCRIPTOR_PROTO_STRUCT = DescriptorProtos.DescriptorProto.newBuilder() + public final DescriptorProtos.DescriptorProto DESCRIPTOR_PROTO_STRUCT = DescriptorProtos.DescriptorProto.newBuilder() .addNestedType(NESTED_STRUCT_DESCRIPTOR).addField(PROTO_STRUCT_FIELD).setName("Schema").build(); - private final InternalRow INTEGER_INTERNAL_ROW = new GenericInternalRow(new Object[]{1}); - private final InternalRow STRING_INTERNAL_ROW = new GenericInternalRow(new Object[]{UTF8String.fromString("A")}); - private final InternalRow ARRAY_INTERNAL_ROW = new GenericInternalRow(new Object[]{ArrayData.toArrayData( + public final InternalRow INTEGER_INTERNAL_ROW = new GenericInternalRow(new Object[]{1}); + public final InternalRow STRING_INTERNAL_ROW = new GenericInternalRow(new Object[]{UTF8String.fromString("A")}); + public final InternalRow ARRAY_INTERNAL_ROW = new GenericInternalRow(new Object[]{ArrayData.toArrayData( new int[]{0,1,2})}); - private final InternalRow INTERNAL_STRUCT_DATA = new GenericInternalRow(new Object[]{1, UTF8String.fromString("A")}); - private final InternalRow STRUCT_INTERNAL_ROW = new GenericInternalRow(new Object[]{INTERNAL_STRUCT_DATA}); + public final InternalRow INTERNAL_STRUCT_DATA = new GenericInternalRow(new Object[]{1, UTF8String.fromString("A")}); + public final InternalRow STRUCT_INTERNAL_ROW = new GenericInternalRow(new Object[]{INTERNAL_STRUCT_DATA}); - private Descriptors.Descriptor INTEGER_SCHEMA_DESCRIPTOR = createIntegerSchemaDescriptor(); - private Descriptors.Descriptor createIntegerSchemaDescriptor() { + public Descriptors.Descriptor INTEGER_SCHEMA_DESCRIPTOR = createIntegerSchemaDescriptor(); + public Descriptors.Descriptor createIntegerSchemaDescriptor() { try { return toDescriptor( new StructType().add(SPARK_INTEGER_FIELD) @@ -344,8 +291,8 @@ private Descriptors.Descriptor createIntegerSchemaDescriptor() { throw new AssumptionViolatedException("Could not create INTEGER_SCHEMA_DESCRIPTOR", e); } } - private Descriptors.Descriptor STRING_SCHEMA_DESCRIPTOR = createStringSchemaDescriptor(); - private Descriptors.Descriptor createStringSchemaDescriptor() { + public Descriptors.Descriptor STRING_SCHEMA_DESCRIPTOR = createStringSchemaDescriptor(); + public Descriptors.Descriptor createStringSchemaDescriptor() { try { return toDescriptor( new StructType().add(SPARK_STRING_FIELD) @@ -354,8 +301,8 @@ private Descriptors.Descriptor createStringSchemaDescriptor() { throw new AssumptionViolatedException("Could not create STRING_SCHEMA_DESCRIPTOR", e); } } - private Descriptors.Descriptor ARRAY_SCHEMA_DESCRIPTOR = createArraySchemaDescriptor(); - private Descriptors.Descriptor createArraySchemaDescriptor() { + public Descriptors.Descriptor ARRAY_SCHEMA_DESCRIPTOR = createArraySchemaDescriptor(); + public Descriptors.Descriptor createArraySchemaDescriptor() { try { return toDescriptor( new StructType().add(SPARK_ARRAY_FIELD) @@ -364,8 +311,8 @@ private Descriptors.Descriptor createArraySchemaDescriptor() { throw new AssumptionViolatedException("Could not create ARRAY_SCHEMA_DESCRIPTOR", e); } } - private Descriptors.Descriptor STRUCT_SCHEMA_DESCRIPTOR = 
createStructSchemaDescriptor(); - private Descriptors.Descriptor createStructSchemaDescriptor() { + public Descriptors.Descriptor STRUCT_SCHEMA_DESCRIPTOR = createStructSchemaDescriptor(); + public Descriptors.Descriptor createStructSchemaDescriptor() { try { return toDescriptor( new StructType().add(SPARK_NESTED_STRUCT_FIELD) @@ -375,56 +322,53 @@ private Descriptors.Descriptor createStructSchemaDescriptor() { } } + Descriptors.Descriptor STRUCT_DESCRIPTOR = makeStructDescriptor(); + public Descriptors.Descriptor makeStructDescriptor() throws AssumptionViolatedException { + try { + return toDescriptor(MY_STRUCT); + } + catch(Descriptors.DescriptorValidationException e) { + throw new AssumptionViolatedException("Could not create STRUCT_DESCRIPTOR.", e); + } + } + - private final DynamicMessage INTEGER_ROW_MESSAGE = DynamicMessage.newBuilder(INTEGER_SCHEMA_DESCRIPTOR) + public final DynamicMessage INTEGER_ROW_MESSAGE = DynamicMessage.newBuilder(INTEGER_SCHEMA_DESCRIPTOR) .setField(INTEGER_SCHEMA_DESCRIPTOR.findFieldByNumber(1), 1L).build(); - private final DynamicMessage STRING_ROW_MESSAGE = DynamicMessage.newBuilder(STRING_SCHEMA_DESCRIPTOR) + public final DynamicMessage STRING_ROW_MESSAGE = DynamicMessage.newBuilder(STRING_SCHEMA_DESCRIPTOR) .setField(STRING_SCHEMA_DESCRIPTOR.findFieldByNumber(1), "A").build(); - private final DynamicMessage ARRAY_ROW_MESSAGE = DynamicMessage.newBuilder(ARRAY_SCHEMA_DESCRIPTOR) + public final DynamicMessage ARRAY_ROW_MESSAGE = DynamicMessage.newBuilder(ARRAY_SCHEMA_DESCRIPTOR) .addRepeatedField(ARRAY_SCHEMA_DESCRIPTOR.findFieldByNumber(1), 0L) .addRepeatedField(ARRAY_SCHEMA_DESCRIPTOR.findFieldByNumber(1), 1L) .addRepeatedField(ARRAY_SCHEMA_DESCRIPTOR.findFieldByNumber(1), 2L).build(); - private DynamicMessage StructRowMessage = createStructRowMessage(); - private DynamicMessage createStructRowMessage() { - try{ - return DynamicMessage.newBuilder(STRUCT_SCHEMA_DESCRIPTOR) - .setField(STRUCT_SCHEMA_DESCRIPTOR.findFieldByNumber(1), createSingleRowMessage( - MY_STRUCT, toDescriptor(MY_STRUCT), INTERNAL_STRUCT_DATA - )).build(); - } catch (Descriptors.DescriptorValidationException e) { - throw new AssumptionViolatedException("Could not create STRUCT_ROW_MESSAGE", e); - } - } + public DynamicMessage StructRowMessage = DynamicMessage.newBuilder(STRUCT_SCHEMA_DESCRIPTOR) + .setField(STRUCT_SCHEMA_DESCRIPTOR.findFieldByNumber(1), buildSingleRowMessage( + MY_STRUCT, STRUCT_DESCRIPTOR, INTERNAL_STRUCT_DATA + )).build(); - private Descriptors.Descriptor BIG_SCHEMA_ROW_DESCRIPTOR = createBigSchemaRowDescriptor(); - private Descriptors.Descriptor createBigSchemaRowDescriptor() { + public Descriptors.Descriptor BIG_SCHEMA_ROW_DESCRIPTOR = createBigSchemaRowDescriptor(); + public Descriptors.Descriptor createBigSchemaRowDescriptor() { try { return toDescriptor(BIG_SPARK_SCHEMA); } catch (Descriptors.DescriptorValidationException e) { throw new AssumptionViolatedException("Could not create BIG_SCHEMA_ROW_DESCRIPTOR", e); } } - private ProtoBufProto.ProtoRows MY_PROTO_ROWS = createMyProtoRows(); - private ProtoBufProto.ProtoRows createMyProtoRows() { - try { - return ProtoBufProto.ProtoRows.newBuilder().addSerializedRows( - DynamicMessage.newBuilder(BIG_SCHEMA_ROW_DESCRIPTOR) - .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(1), 1L) - .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(2), "A") - .addRepeatedField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(3), 0L) - .addRepeatedField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(3), 1L) - 
.addRepeatedField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(3), 2L) - .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(4), - createSingleRowMessage( - MY_STRUCT, toDescriptor(MY_STRUCT), INTERNAL_STRUCT_DATA)) - .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(5), 3.14) - .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(6), true) - .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(7), new byte[]{11, 0x7F}) - .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(8), 1594080000000L) - .build().toByteString()).build(); - } catch (Descriptors.DescriptorValidationException e) { - throw new AssumptionViolatedException("Could not create MY_PROTO_ROWS", e); - } - } + public ProtoBufProto.ProtoRows MY_PROTO_ROWS = ProtoBufProto.ProtoRows.newBuilder().addSerializedRows( + DynamicMessage.newBuilder(BIG_SCHEMA_ROW_DESCRIPTOR) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(1), 1L) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(2), "A") + .addRepeatedField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(3), 0L) + .addRepeatedField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(3), 1L) + .addRepeatedField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(3), 2L) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(4), + buildSingleRowMessage( + MY_STRUCT, STRUCT_DESCRIPTOR, INTERNAL_STRUCT_DATA)) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(5), 3.14) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(6), true) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(7), new byte[]{11, 0x7F}) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(8), 1594080000000L) + .setField(BIG_SCHEMA_ROW_DESCRIPTOR.findFieldByNumber(9), 1594080000000L) + .build().toByteString()).build(); } diff --git a/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java b/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java index d6bbf659b8..00ea1a98b7 100644 --- a/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java +++ b/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java @@ -87,16 +87,9 @@ public void testFieldHasDescriptionBigQueryToSpark() throws Exception { assertThat(result).isEqualTo(expected); } - @Test - public void testSparkStructFieldToBigQuery() throws Exception { - logger.setLevel(Level.DEBUG); - - Field expected = BIGQUERY_NESTED_STRUCT_FIELD; - Field converted = makeBigQueryColumn(SPARK_NESTED_STRUCT_FIELD, 0); - - assertThat(converted).isEqualTo(expected); - } - + /* + Spark -> BigQuery conversion tests: + */ @Test public void testSparkToBQSchema() throws Exception { logger.setLevel(Level.DEBUG); @@ -194,12 +187,12 @@ public void testMaximumNestingDepthError() throws Exception { true, Metadata.empty()); public final StructField SPARK_BOOLEAN_FIELD = new StructField("Boolean", DataTypes.BooleanType, true, Metadata.empty()); - public final StructField SPARK_NUMERIC_FIELD = new StructField("Numeric", NUMERIC_SPARK_TYPE, - true, Metadata.empty()); public final StructField SPARK_BINARY_FIELD = new StructField("Binary", DataTypes.BinaryType, true, Metadata.empty()); public final StructField SPARK_DATE_FIELD = new StructField("Date", DataTypes.DateType, true, Metadata.empty()); + public final StructField SPARK_TIMESTAMP_FIELD = new StructField("TimeStamp", DataTypes.TimestampType, + true, Metadata.empty()); public final StructField SPARK_MAP_FIELD = new StructField("Map", DataTypes.createMapType(DataTypes.IntegerType, DataTypes.StringType), true, Metadata.empty()); @@ 
diff --git a/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java b/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java
index d6bbf659b8..00ea1a98b7 100644
--- a/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java
+++ b/connector/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java
@@ -87,16 +87,9 @@ public void testFieldHasDescriptionBigQueryToSpark() throws Exception {
     assertThat(result).isEqualTo(expected);
   }
 
-  @Test
-  public void testSparkStructFieldToBigQuery() throws Exception {
-    logger.setLevel(Level.DEBUG);
-
-    Field expected = BIGQUERY_NESTED_STRUCT_FIELD;
-    Field converted = makeBigQueryColumn(SPARK_NESTED_STRUCT_FIELD, 0);
-
-    assertThat(converted).isEqualTo(expected);
-  }
-
+  /*
+  Spark -> BigQuery conversion tests:
+   */
   @Test
   public void testSparkToBQSchema() throws Exception {
     logger.setLevel(Level.DEBUG);
@@ -194,12 +187,12 @@ public void testMaximumNestingDepthError() throws Exception {
       true, Metadata.empty());
   public final StructField SPARK_BOOLEAN_FIELD = new StructField("Boolean", DataTypes.BooleanType,
       true, Metadata.empty());
-  public final StructField SPARK_NUMERIC_FIELD = new StructField("Numeric", NUMERIC_SPARK_TYPE,
-      true, Metadata.empty());
   public final StructField SPARK_BINARY_FIELD = new StructField("Binary", DataTypes.BinaryType,
       true, Metadata.empty());
   public final StructField SPARK_DATE_FIELD = new StructField("Date", DataTypes.DateType,
       true, Metadata.empty());
+  public final StructField SPARK_TIMESTAMP_FIELD = new StructField("TimeStamp", DataTypes.TimestampType,
+      true, Metadata.empty());
   public final StructField SPARK_MAP_FIELD = new StructField("Map",
       DataTypes.createMapType(DataTypes.IntegerType, DataTypes.StringType), true, Metadata.empty());
@@ -211,9 +204,38 @@ public void testMaximumNestingDepthError() throws Exception {
       .add(SPARK_NESTED_STRUCT_FIELD)
       .add(SPARK_DOUBLE_FIELD)
       .add(SPARK_BOOLEAN_FIELD)
-      .add(SPARK_NUMERIC_FIELD)
       .add(SPARK_BINARY_FIELD)
-      .add(SPARK_DATE_FIELD);
+      .add(SPARK_DATE_FIELD)
+      .add(SPARK_TIMESTAMP_FIELD);
+
+
+  public final Field BIGQUERY_INTEGER_FIELD = Field.newBuilder("Number", LegacySQLTypeName.INTEGER,
+      (FieldList)null).setMode(Field.Mode.NULLABLE).build();
+  public final Field BIGQUERY_STRING_FIELD = Field.newBuilder("String", LegacySQLTypeName.STRING, (FieldList) null)
+      .setMode(Field.Mode.REQUIRED).build();
+  public final Field BIGQUERY_NESTED_STRUCT_FIELD = Field.newBuilder("Struct", LegacySQLTypeName.RECORD,
+      Field.newBuilder("Number", LegacySQLTypeName.INTEGER, (FieldList) null)
+          .setMode(Field.Mode.NULLABLE).build(),
+      Field.newBuilder("String", LegacySQLTypeName.STRING, (FieldList) null)
+          .setMode(Field.Mode.NULLABLE).build())
+      .setMode(Field.Mode.NULLABLE).build();
+  public final Field BIGQUERY_ARRAY_FIELD = Field.newBuilder("Array", LegacySQLTypeName.INTEGER, (FieldList) null)
+      .setMode(Field.Mode.REPEATED).build();
+  public final Field BIGQUERY_FLOAT_FIELD = Field.newBuilder("Float", LegacySQLTypeName.FLOAT, (FieldList)null)
+      .setMode(Field.Mode.NULLABLE).build();
+  public final Field BIGQUERY_BOOLEAN_FIELD = Field.newBuilder("Boolean", LegacySQLTypeName.BOOLEAN, (FieldList)null)
+      .setMode(Field.Mode.NULLABLE).build();
+  public final Field BIGQUERY_BYTES_FIELD = Field.newBuilder("Binary", LegacySQLTypeName.BYTES, (FieldList)null)
+      .setMode(Field.Mode.NULLABLE).build();
+  public final Field BIGQUERY_DATE_FIELD = Field.newBuilder("Date", LegacySQLTypeName.DATE, (FieldList)null)
+      .setMode(Field.Mode.NULLABLE).build();
+  public final Field BIGQUERY_TIMESTAMP_FIELD = Field.newBuilder("TimeStamp", LegacySQLTypeName.TIMESTAMP, (FieldList)null)
+      .setMode(Field.Mode.NULLABLE).build();
+
+  public final Schema BIG_BIGQUERY_SCHEMA = Schema.of(BIGQUERY_INTEGER_FIELD, BIGQUERY_STRING_FIELD, BIGQUERY_ARRAY_FIELD,
+      BIGQUERY_NESTED_STRUCT_FIELD, BIGQUERY_FLOAT_FIELD, BIGQUERY_BOOLEAN_FIELD, BIGQUERY_BYTES_FIELD, BIGQUERY_DATE_FIELD,
+      BIGQUERY_TIMESTAMP_FIELD);
+
   public final StructType BIG_SPARK_SCHEMA2 = new StructType()
       .add(new StructField("foo", DataTypes.StringType,true, Metadata.empty()))
@@ -244,35 +266,7 @@ public void testMaximumNestingDepthError() throws Exception {
       Field.of("timestamp", LegacySQLTypeName.TIMESTAMP),
       Field.of("datetime", LegacySQLTypeName.DATETIME)));
-
-  public final Field BIGQUERY_INTEGER_FIELD = Field.newBuilder("Number", LegacySQLTypeName.INTEGER,
-      (FieldList)null).setMode(Field.Mode.NULLABLE).build();
-  public final Field BIGQUERY_STRING_FIELD = Field.newBuilder("String", LegacySQLTypeName.STRING, (FieldList) null)
-      .setMode(Field.Mode.REQUIRED).build();
-  public final Field BIGQUERY_NESTED_STRUCT_FIELD = Field.newBuilder("Struct", LegacySQLTypeName.RECORD,
-      Field.newBuilder("Number", LegacySQLTypeName.INTEGER, (FieldList) null)
-          .setMode(Field.Mode.NULLABLE).build(),
-      Field.newBuilder("String", LegacySQLTypeName.STRING, (FieldList) null)
-          .setMode(Field.Mode.NULLABLE).build())
-      .setMode(Field.Mode.NULLABLE).build();
-  public final Field BIGQUERY_ARRAY_FIELD = Field.newBuilder("Array", LegacySQLTypeName.INTEGER, (FieldList) null)
-      .setMode(Field.Mode.REPEATED).build();
-  public final Field BIGQUERY_FLOAT_FIELD = Field.newBuilder("Float", LegacySQLTypeName.FLOAT, (FieldList)null)
-      .setMode(Field.Mode.NULLABLE).build();
-  public final Field BIGQUERY_BOOLEAN_FIELD = Field.newBuilder("Boolean", LegacySQLTypeName.BOOLEAN, (FieldList)null)
-      .setMode(Field.Mode.NULLABLE).build();
-  public final Field BIGQUERY_NUMERIC_FIELD = Field.newBuilder("Numeric", LegacySQLTypeName.NUMERIC, (FieldList)null)
-      .setMode(Field.Mode.NULLABLE).build();
-  public final Field BIGQUERY_BYTES_FIELD = Field.newBuilder("Binary", LegacySQLTypeName.BYTES, (FieldList)null)
-      .setMode(Field.Mode.NULLABLE).build();
-  public final Field BIGQUERY_DATE_FIELD = Field.newBuilder("Date", LegacySQLTypeName.DATE, (FieldList)null)
-      .setMode(Field.Mode.NULLABLE).build();
-
-  public final Schema BIG_BIGQUERY_SCHEMA = Schema.of(BIGQUERY_INTEGER_FIELD, BIGQUERY_STRING_FIELD, BIGQUERY_ARRAY_FIELD,
-      BIGQUERY_NESTED_STRUCT_FIELD, BIGQUERY_FLOAT_FIELD, BIGQUERY_BOOLEAN_FIELD, BIGQUERY_NUMERIC_FIELD,
-      BIGQUERY_BYTES_FIELD, BIGQUERY_DATE_FIELD);
-
-  /* TODO: create SchemaConverters.convert() from BigQuery -> Spark test. Translate specific test from SchemaIteratorSuite.scala
+  /* TODO: translate BigQuery to Spark row conversion tests, from SchemaIteratorSuite.scala
   private final List BIG_SCHEMA_NAMES_INORDER = Arrays.asList(
       new String[]{"Number", "String", "Array", "Struct", "Float", "Boolean", "Numeric"});
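The constants reorganized above pair the Spark-side BIG_SPARK_SCHEMA with the BIG_BIGQUERY_SCHEMA it is expected to convert to. As a rough sketch of how such a pairing is typically asserted, a single-field Timestamp case might look like the following; the SchemaConverters.toBigQuerySchema entry point is an assumption for illustration, not taken from this patch.

import static com.google.common.truth.Truth.assertThat;

import com.google.cloud.bigquery.Field;
import com.google.cloud.bigquery.LegacySQLTypeName;
import com.google.cloud.bigquery.Schema;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructType;
import org.junit.Test;

public class SparkToBigQueryTimestampSketchTest {
  @Test
  public void timestampFieldConvertsToBigQueryTimestamp() {
    // Spark side: a single nullable TIMESTAMP column.
    StructType sparkSchema = new StructType()
        .add("TimeStamp", DataTypes.TimestampType, true, Metadata.empty());

    // Expected BigQuery side, matching the BIGQUERY_TIMESTAMP_FIELD constant above.
    Schema expected = Schema.of(
        Field.newBuilder("TimeStamp", LegacySQLTypeName.TIMESTAMP)
            .setMode(Field.Mode.NULLABLE)
            .build());

    // Hypothetical converter call; the method actually exercised by the tests in
    // this patch may have a different name or signature.
    Schema converted = SchemaConverters.toBigQuerySchema(sparkSchema);

    assertThat(converted).isEqualTo(expected);
  }
}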