diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/BasicFlightSqlProducer.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/BasicFlightSqlProducer.java new file mode 100644 index 0000000000000..ea99191f28e13 --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/BasicFlightSqlProducer.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql; + +import java.util.List; + +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.google.protobuf.Message; + +/** + * A {@link FlightSqlProducer} that implements getting FlightInfo for each metadata request. + */ +public abstract class BasicFlightSqlProducer extends NoOpFlightSqlProducer { + + @Override + public FlightInfo getFlightInfoSqlInfo(FlightSql.CommandGetSqlInfo request, CallContext context, + FlightDescriptor descriptor) { + return generateFlightInfo(request, descriptor, Schemas.GET_SQL_INFO_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoTypeInfo(FlightSql.CommandGetXdbcTypeInfo request, CallContext context, + FlightDescriptor descriptor) { + return generateFlightInfo(request, descriptor, Schemas.GET_TYPE_INFO_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoCatalogs(FlightSql.CommandGetCatalogs request, CallContext context, + FlightDescriptor descriptor) { + return generateFlightInfo(request, descriptor, Schemas.GET_CATALOGS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoSchemas(FlightSql.CommandGetDbSchemas request, CallContext context, + FlightDescriptor descriptor) { + return generateFlightInfo(request, descriptor, Schemas.GET_SCHEMAS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoTables(FlightSql.CommandGetTables request, CallContext context, + FlightDescriptor descriptor) { + if (request.getIncludeSchema()) { + return generateFlightInfo(request, descriptor, Schemas.GET_TABLES_SCHEMA); + } + return generateFlightInfo(request, descriptor, Schemas.GET_TABLES_SCHEMA_NO_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoTableTypes(FlightSql.CommandGetTableTypes request, CallContext context, + FlightDescriptor descriptor) { + return generateFlightInfo(request, descriptor, Schemas.GET_TABLE_TYPES_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoPrimaryKeys(FlightSql.CommandGetPrimaryKeys request, CallContext context, + FlightDescriptor descriptor) { + return generateFlightInfo(request, descriptor, Schemas.GET_PRIMARY_KEYS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoExportedKeys(FlightSql.CommandGetExportedKeys request, CallContext context, + FlightDescriptor descriptor) { + return generateFlightInfo(request, descriptor, Schemas.GET_EXPORTED_KEYS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoImportedKeys(FlightSql.CommandGetImportedKeys request, CallContext context, + FlightDescriptor descriptor) { + return generateFlightInfo(request, descriptor, Schemas.GET_IMPORTED_KEYS_SCHEMA); + } + + @Override + public FlightInfo getFlightInfoCrossReference(FlightSql.CommandGetCrossReference request, CallContext context, + FlightDescriptor descriptor) { + return generateFlightInfo(request, descriptor, Schemas.GET_CROSS_REFERENCE_SCHEMA); + } + + /** + * Return a list of FlightEndpoints for the given request and FlightDescriptor. This method should validate that + * the request is supported by this FlightSqlProducer. + */ + protected abstract + List determineEndpoints(T request, FlightDescriptor flightDescriptor, Schema schema); + + protected FlightInfo generateFlightInfo(T request, FlightDescriptor descriptor, Schema schema) { + final List endpoints = determineEndpoints(request, descriptor, schema); + return new FlightInfo(schema, descriptor, endpoints, -1, -1); + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoOpFlightSqlProducer.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoOpFlightSqlProducer.java new file mode 100644 index 0000000000000..a02cee64bd855 --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoOpFlightSqlProducer.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.sql.impl.FlightSql; + +/** + * A {@link FlightSqlProducer} that throws on all FlightSql-specific operations. + */ +public class NoOpFlightSqlProducer implements FlightSqlProducer { + @Override + public void createPreparedStatement(FlightSql.ActionCreatePreparedStatementRequest request, + CallContext context, StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public void closePreparedStatement(FlightSql.ActionClosePreparedStatementRequest request, + CallContext context, StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public FlightInfo getFlightInfoStatement(FlightSql.CommandStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoPreparedStatement(FlightSql.CommandPreparedStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public SchemaResult getSchemaStatement(FlightSql.CommandStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public void getStreamStatement(FlightSql.TicketStatementQuery ticket, + CallContext context, ServerStreamListener listener) { + listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public void getStreamPreparedStatement(FlightSql.CommandPreparedStatementQuery command, + CallContext context, ServerStreamListener listener) { + listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public Runnable acceptPutStatement(FlightSql.CommandStatementUpdate command, CallContext context, + FlightStream flightStream, StreamListener ackStream) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public Runnable acceptPutPreparedStatementUpdate(FlightSql.CommandPreparedStatementUpdate command, + CallContext context, FlightStream flightStream, + StreamListener ackStream) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public Runnable acceptPutPreparedStatementQuery(FlightSql.CommandPreparedStatementQuery command, CallContext context, + FlightStream flightStream, StreamListener ackStream) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoSqlInfo(FlightSql.CommandGetSqlInfo request, CallContext context, + FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public void getStreamSqlInfo(FlightSql.CommandGetSqlInfo command, CallContext context, + ServerStreamListener listener) { + listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public FlightInfo getFlightInfoTypeInfo(FlightSql.CommandGetXdbcTypeInfo request, + CallContext context, FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public void getStreamTypeInfo(FlightSql.CommandGetXdbcTypeInfo request, + CallContext context, ServerStreamListener listener) { + listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public FlightInfo getFlightInfoCatalogs(FlightSql.CommandGetCatalogs request, + CallContext context, FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public void getStreamCatalogs(CallContext context, ServerStreamListener listener) { + listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public FlightInfo getFlightInfoSchemas(FlightSql.CommandGetDbSchemas request, + CallContext context, FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public void getStreamSchemas(FlightSql.CommandGetDbSchemas command, + CallContext context, ServerStreamListener listener) { + listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public FlightInfo getFlightInfoTables(FlightSql.CommandGetTables request, + CallContext context, FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public void getStreamTables(FlightSql.CommandGetTables command, CallContext context, ServerStreamListener listener) { + listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public FlightInfo getFlightInfoTableTypes(FlightSql.CommandGetTableTypes request, CallContext context, + FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public void getStreamTableTypes(CallContext context, ServerStreamListener listener) { + listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public FlightInfo getFlightInfoPrimaryKeys(FlightSql.CommandGetPrimaryKeys request, + CallContext context, FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public void getStreamPrimaryKeys(FlightSql.CommandGetPrimaryKeys command, + CallContext context, ServerStreamListener listener) { + listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public FlightInfo getFlightInfoExportedKeys(FlightSql.CommandGetExportedKeys request, + CallContext context, FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoImportedKeys(FlightSql.CommandGetImportedKeys request, + CallContext context, FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoCrossReference(FlightSql.CommandGetCrossReference request, + CallContext context, FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public void getStreamExportedKeys(FlightSql.CommandGetExportedKeys command, + CallContext context, ServerStreamListener listener) { + listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public void getStreamImportedKeys(FlightSql.CommandGetImportedKeys command, CallContext context, + ServerStreamListener listener) { + listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public void getStreamCrossReference(FlightSql.CommandGetCrossReference command, CallContext context, + ServerStreamListener listener) { + listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public void close() throws Exception { + + } + + @Override + public void listFlights(CallContext context, Criteria criteria, StreamListener listener) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } +} diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java index 6da915a8ffb14..7635b80ecd0fd 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java @@ -20,7 +20,7 @@ import static java.util.Arrays.asList; import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; -import static java.util.Objects.isNull; +import static org.apache.arrow.flight.sql.util.FlightStreamUtils.getResults; import static org.apache.arrow.util.AutoCloseables.close; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.is; @@ -29,16 +29,12 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.nio.channels.Channels; import java.sql.SQLException; import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.stream.IntStream; @@ -52,18 +48,9 @@ import org.apache.arrow.flight.sql.util.TableRef; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.BitVector; -import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.UInt1Vector; -import org.apache.arrow.vector.UInt4Vector; -import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.complex.DenseUnionVector; -import org.apache.arrow.vector.complex.ListVector; -import org.apache.arrow.vector.ipc.ReadChannel; -import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; @@ -657,197 +644,202 @@ public void testGetSqlInfoResultsWithThreeArgs() throws Exception { } @Test - public void testGetCommandExportedKeys() { - final FlightStream stream = + public void testGetCommandExportedKeys() throws Exception { + try (final FlightStream stream = sqlClient.getStream( sqlClient.getExportedKeys(TableRef.of(null, null, "FOREIGNTABLE")) - .getEndpoints().get(0).getTicket()); - - final List> results = getResults(stream); - - final List> matchers = asList( - nullValue(String.class), // pk_catalog_name - is("APP"), // pk_schema_name - is("FOREIGNTABLE"), // pk_table_name - is("ID"), // pk_column_name - nullValue(String.class), // fk_catalog_name - is("APP"), // fk_schema_name - is("INTTABLE"), // fk_table_name - is("FOREIGNID"), // fk_column_name - is("1"), // key_sequence - containsString("SQL"), // fk_key_name - containsString("SQL"), // pk_key_name - is("3"), // update_rule - is("3")); // delete_rule - - final List assertions = new ArrayList<>(); - Assertions.assertEquals(1, results.size()); - for (int i = 0; i < matchers.size(); i++) { - final String actual = results.get(0).get(i); - final Matcher expected = matchers.get(i); - assertions.add(() -> MatcherAssert.assertThat(actual, expected)); + .getEndpoints().get(0).getTicket())) { + + final List> results = getResults(stream); + + final List> matchers = asList( + nullValue(String.class), // pk_catalog_name + is("APP"), // pk_schema_name + is("FOREIGNTABLE"), // pk_table_name + is("ID"), // pk_column_name + nullValue(String.class), // fk_catalog_name + is("APP"), // fk_schema_name + is("INTTABLE"), // fk_table_name + is("FOREIGNID"), // fk_column_name + is("1"), // key_sequence + containsString("SQL"), // fk_key_name + containsString("SQL"), // pk_key_name + is("3"), // update_rule + is("3")); // delete_rule + + final List assertions = new ArrayList<>(); + Assertions.assertEquals(1, results.size()); + for (int i = 0; i < matchers.size(); i++) { + final String actual = results.get(0).get(i); + final Matcher expected = matchers.get(i); + assertions.add(() -> MatcherAssert.assertThat(actual, expected)); + } + Assertions.assertAll(assertions); } - Assertions.assertAll(assertions); } @Test - public void testGetCommandImportedKeys() { - final FlightStream stream = + public void testGetCommandImportedKeys() throws Exception { + try (final FlightStream stream = sqlClient.getStream( sqlClient.getImportedKeys(TableRef.of(null, null, "INTTABLE")) - .getEndpoints().get(0).getTicket()); - - final List> results = getResults(stream); - - final List> matchers = asList( - nullValue(String.class), // pk_catalog_name - is("APP"), // pk_schema_name - is("FOREIGNTABLE"), // pk_table_name - is("ID"), // pk_column_name - nullValue(String.class), // fk_catalog_name - is("APP"), // fk_schema_name - is("INTTABLE"), // fk_table_name - is("FOREIGNID"), // fk_column_name - is("1"), // key_sequence - containsString("SQL"), // fk_key_name - containsString("SQL"), // pk_key_name - is("3"), // update_rule - is("3")); // delete_rule - - Assertions.assertEquals(1, results.size()); - final List assertions = new ArrayList<>(); - for (int i = 0; i < matchers.size(); i++) { - final String actual = results.get(0).get(i); - final Matcher expected = matchers.get(i); - assertions.add(() -> MatcherAssert.assertThat(actual, expected)); + .getEndpoints().get(0).getTicket())) { + + final List> results = getResults(stream); + + final List> matchers = asList( + nullValue(String.class), // pk_catalog_name + is("APP"), // pk_schema_name + is("FOREIGNTABLE"), // pk_table_name + is("ID"), // pk_column_name + nullValue(String.class), // fk_catalog_name + is("APP"), // fk_schema_name + is("INTTABLE"), // fk_table_name + is("FOREIGNID"), // fk_column_name + is("1"), // key_sequence + containsString("SQL"), // fk_key_name + containsString("SQL"), // pk_key_name + is("3"), // update_rule + is("3")); // delete_rule + + Assertions.assertEquals(1, results.size()); + final List assertions = new ArrayList<>(); + for (int i = 0; i < matchers.size(); i++) { + final String actual = results.get(0).get(i); + final Matcher expected = matchers.get(i); + assertions.add(() -> MatcherAssert.assertThat(actual, expected)); + } + Assertions.assertAll(assertions); } - Assertions.assertAll(assertions); } @Test - public void testGetTypeInfo() { + public void testGetTypeInfo() throws Exception { FlightInfo flightInfo = sqlClient.getXdbcTypeInfo(); - FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket()); - - final List> results = getResults(stream); - - final List> matchers = ImmutableList.of( - asList("BIGINT", "-5", "19", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "true", - "BIGINT", "0", "0", - null, null, "10", null), - asList("LONG VARCHAR FOR BIT DATA", "-4", "32700", "X'", "'", emptyList().toString(), "1", "false", "0", "true", - "false", "false", - "LONG VARCHAR FOR BIT DATA", null, null, null, null, null, null), - asList("VARCHAR () FOR BIT DATA", "-3", "32672", "X'", "'", singletonList("length").toString(), "1", "false", - "2", "true", "false", - "false", "VARCHAR () FOR BIT DATA", null, null, null, null, null, null), - asList("CHAR () FOR BIT DATA", "-2", "254", "X'", "'", singletonList("length").toString(), "1", "false", "2", - "true", "false", "false", - "CHAR () FOR BIT DATA", null, null, null, null, null, null), - asList("LONG VARCHAR", "-1", "32700", "'", "'", emptyList().toString(), "1", "true", "1", "true", "false", - "false", - "LONG VARCHAR", null, null, null, null, null, null), - asList("CHAR", "1", "254", "'", "'", singletonList("length").toString(), "1", "true", "3", "true", "false", - "false", "CHAR", null, null, - null, null, null, null), - asList("NUMERIC", "2", "31", null, null, Arrays.asList("precision", "scale").toString(), "1", "false", "2", - "false", "true", "false", - "NUMERIC", "0", "31", null, null, "10", null), - asList("DECIMAL", "3", "31", null, null, Arrays.asList("precision", "scale").toString(), "1", "false", "2", - "false", "true", "false", - "DECIMAL", "0", "31", null, null, "10", null), - asList("INTEGER", "4", "10", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "true", - "INTEGER", "0", "0", - null, null, "10", null), - asList("SMALLINT", "5", "5", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "true", - "SMALLINT", "0", - "0", null, null, "10", null), - asList("FLOAT", "6", "52", null, null, singletonList("precision").toString(), "1", "false", "2", "false", - "false", "false", "FLOAT", null, - null, null, null, "2", null), - asList("REAL", "7", "23", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "false", - "REAL", null, null, - null, null, "2", null), - asList("DOUBLE", "8", "52", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "false", - "DOUBLE", null, - null, null, null, "2", null), - asList("VARCHAR", "12", "32672", "'", "'", singletonList("length").toString(), "1", "true", "3", "true", - "false", "false", "VARCHAR", - null, null, null, null, null, null), - asList("BOOLEAN", "16", "1", null, null, emptyList().toString(), "1", "false", "2", "true", "false", "false", - "BOOLEAN", null, - null, null, null, null, null), - asList("DATE", "91", "10", "DATE'", "'", emptyList().toString(), "1", "false", "2", "true", "false", "false", - "DATE", "0", "0", - null, null, "10", null), - asList("TIME", "92", "8", "TIME'", "'", emptyList().toString(), "1", "false", "2", "true", "false", "false", - "TIME", "0", "0", - null, null, "10", null), - asList("TIMESTAMP", "93", "29", "TIMESTAMP'", "'", emptyList().toString(), "1", "false", "2", "true", "false", - "false", - "TIMESTAMP", "0", "9", null, null, "10", null), - asList("OBJECT", "2000", null, null, null, emptyList().toString(), "1", "false", "2", "true", "false", "false", - "OBJECT", null, - null, null, null, null, null), - asList("BLOB", "2004", "2147483647", null, null, singletonList("length").toString(), "1", "false", "0", null, - "false", null, "BLOB", null, - null, null, null, null, null), - asList("CLOB", "2005", "2147483647", "'", "'", singletonList("length").toString(), "1", "true", "1", null, - "false", null, "CLOB", null, - null, null, null, null, null), - asList("XML", "2009", null, null, null, emptyList().toString(), "1", "true", "0", "false", "false", "false", - "XML", null, null, - null, null, null, null)); - MatcherAssert.assertThat(results, is(matchers)); + try (FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket())) { + + final List> results = getResults(stream); + + final List> matchers = ImmutableList.of( + asList("BIGINT", "-5", "19", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "true", + "BIGINT", "0", "0", + null, null, "10", null), + asList("LONG VARCHAR FOR BIT DATA", "-4", "32700", "X'", "'", emptyList().toString(), "1", "false", "0", + "true", "false", "false", + "LONG VARCHAR FOR BIT DATA", null, null, null, null, null, null), + asList("VARCHAR () FOR BIT DATA", "-3", "32672", "X'", "'", singletonList("length").toString(), "1", "false", + "2", "true", "false", + "false", "VARCHAR () FOR BIT DATA", null, null, null, null, null, null), + asList("CHAR () FOR BIT DATA", "-2", "254", "X'", "'", singletonList("length").toString(), "1", "false", "2", + "true", "false", "false", + "CHAR () FOR BIT DATA", null, null, null, null, null, null), + asList("LONG VARCHAR", "-1", "32700", "'", "'", emptyList().toString(), "1", "true", "1", "true", "false", + "false", + "LONG VARCHAR", null, null, null, null, null, null), + asList("CHAR", "1", "254", "'", "'", singletonList("length").toString(), "1", "true", "3", "true", "false", + "false", "CHAR", null, null, + null, null, null, null), + asList("NUMERIC", "2", "31", null, null, Arrays.asList("precision", "scale").toString(), "1", "false", "2", + "false", "true", "false", + "NUMERIC", "0", "31", null, null, "10", null), + asList("DECIMAL", "3", "31", null, null, Arrays.asList("precision", "scale").toString(), "1", "false", "2", + "false", "true", "false", + "DECIMAL", "0", "31", null, null, "10", null), + asList("INTEGER", "4", "10", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "true", + "INTEGER", "0", "0", + null, null, "10", null), + asList("SMALLINT", "5", "5", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "true", + "SMALLINT", "0", + "0", null, null, "10", null), + asList("FLOAT", "6", "52", null, null, singletonList("precision").toString(), "1", "false", "2", "false", + "false", "false", "FLOAT", null, + null, null, null, "2", null), + asList("REAL", "7", "23", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "false", + "REAL", null, null, + null, null, "2", null), + asList("DOUBLE", "8", "52", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "false", + "DOUBLE", null, + null, null, null, "2", null), + asList("VARCHAR", "12", "32672", "'", "'", singletonList("length").toString(), "1", "true", "3", "true", + "false", "false", "VARCHAR", + null, null, null, null, null, null), + asList("BOOLEAN", "16", "1", null, null, emptyList().toString(), "1", "false", "2", "true", "false", "false", + "BOOLEAN", null, + null, null, null, null, null), + asList("DATE", "91", "10", "DATE'", "'", emptyList().toString(), "1", "false", "2", "true", "false", "false", + "DATE", "0", "0", + null, null, "10", null), + asList("TIME", "92", "8", "TIME'", "'", emptyList().toString(), "1", "false", "2", "true", "false", "false", + "TIME", "0", "0", + null, null, "10", null), + asList("TIMESTAMP", "93", "29", "TIMESTAMP'", "'", emptyList().toString(), "1", "false", "2", "true", "false", + "false", + "TIMESTAMP", "0", "9", null, null, "10", null), + asList("OBJECT", "2000", null, null, null, emptyList().toString(), "1", "false", "2", "true", "false", + "false", "OBJECT", null, + null, null, null, null, null), + asList("BLOB", "2004", "2147483647", null, null, singletonList("length").toString(), "1", "false", "0", null, + "false", null, "BLOB", null, + null, null, null, null, null), + asList("CLOB", "2005", "2147483647", "'", "'", singletonList("length").toString(), "1", "true", "1", null, + "false", null, "CLOB", null, + null, null, null, null, null), + asList("XML", "2009", null, null, null, emptyList().toString(), "1", "true", "0", "false", "false", "false", + "XML", null, null, + null, null, null, null)); + MatcherAssert.assertThat(results, is(matchers)); + } } @Test - public void testGetTypeInfoWithFiltering() { + public void testGetTypeInfoWithFiltering() throws Exception { FlightInfo flightInfo = sqlClient.getXdbcTypeInfo(-5); - FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket()); + try (FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket())) { - final List> results = getResults(stream); + final List> results = getResults(stream); - final List> matchers = ImmutableList.of( - asList("BIGINT", "-5", "19", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "true", - "BIGINT", "0", "0", - null, null, "10", null)); - MatcherAssert.assertThat(results, is(matchers)); + final List> matchers = ImmutableList.of( + asList("BIGINT", "-5", "19", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "true", + "BIGINT", "0", "0", + null, null, "10", null)); + MatcherAssert.assertThat(results, is(matchers)); + } } @Test - public void testGetCommandCrossReference() { + public void testGetCommandCrossReference() throws Exception { final FlightInfo flightInfo = sqlClient.getCrossReference(TableRef.of(null, null, "FOREIGNTABLE"), TableRef.of(null, null, "INTTABLE")); - final FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket()); - - final List> results = getResults(stream); - - final List> matchers = asList( - nullValue(String.class), // pk_catalog_name - is("APP"), // pk_schema_name - is("FOREIGNTABLE"), // pk_table_name - is("ID"), // pk_column_name - nullValue(String.class), // fk_catalog_name - is("APP"), // fk_schema_name - is("INTTABLE"), // fk_table_name - is("FOREIGNID"), // fk_column_name - is("1"), // key_sequence - containsString("SQL"), // fk_key_name - containsString("SQL"), // pk_key_name - is("3"), // update_rule - is("3")); // delete_rule - - Assertions.assertEquals(1, results.size()); - final List assertions = new ArrayList<>(); - for (int i = 0; i < matchers.size(); i++) { - final String actual = results.get(0).get(i); - final Matcher expected = matchers.get(i); - assertions.add(() -> MatcherAssert.assertThat(actual, expected)); + try (final FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket())) { + + final List> results = getResults(stream); + + final List> matchers = asList( + nullValue(String.class), // pk_catalog_name + is("APP"), // pk_schema_name + is("FOREIGNTABLE"), // pk_table_name + is("ID"), // pk_column_name + nullValue(String.class), // fk_catalog_name + is("APP"), // fk_schema_name + is("INTTABLE"), // fk_table_name + is("FOREIGNID"), // fk_column_name + is("1"), // key_sequence + containsString("SQL"), // fk_key_name + containsString("SQL"), // pk_key_name + is("3"), // update_rule + is("3")); // delete_rule + + Assertions.assertEquals(1, results.size()); + final List assertions = new ArrayList<>(); + for (int i = 0; i < matchers.size(); i++) { + final String actual = results.get(0).get(i); + final Matcher expected = matchers.get(i); + assertions.add(() -> MatcherAssert.assertThat(actual, expected)); + } + Assertions.assertAll(assertions); } - Assertions.assertAll(assertions); } @Test @@ -878,90 +870,6 @@ public void testCreateStatementResults() throws Exception { } } - List> getResults(FlightStream stream) { - final List> results = new ArrayList<>(); - while (stream.next()) { - try (final VectorSchemaRoot root = stream.getRoot()) { - final long rowCount = root.getRowCount(); - for (int i = 0; i < rowCount; ++i) { - results.add(new ArrayList<>()); - } - - root.getSchema().getFields().forEach(field -> { - try (final FieldVector fieldVector = root.getVector(field.getName())) { - if (fieldVector instanceof VarCharVector) { - final VarCharVector varcharVector = (VarCharVector) fieldVector; - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - final Text data = varcharVector.getObject(rowIndex); - results.get(rowIndex).add(isNull(data) ? null : data.toString()); - } - } else if (fieldVector instanceof IntVector) { - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - Object data = fieldVector.getObject(rowIndex); - results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); - } - } else if (fieldVector instanceof VarBinaryVector) { - final VarBinaryVector varbinaryVector = (VarBinaryVector) fieldVector; - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - final byte[] data = varbinaryVector.getObject(rowIndex); - final String output; - try { - output = isNull(data) ? - null : - MessageSerializer.deserializeSchema( - new ReadChannel(Channels.newChannel(new ByteArrayInputStream(data)))).toJson(); - } catch (final IOException e) { - throw new RuntimeException("Failed to deserialize schema", e); - } - results.get(rowIndex).add(output); - } - } else if (fieldVector instanceof DenseUnionVector) { - final DenseUnionVector denseUnionVector = (DenseUnionVector) fieldVector; - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - final Object data = denseUnionVector.getObject(rowIndex); - results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); - } - } else if (fieldVector instanceof ListVector) { - for (int i = 0; i < fieldVector.getValueCount(); i++) { - if (!fieldVector.isNull(i)) { - List elements = (List) ((ListVector) fieldVector).getObject(i); - List values = new ArrayList<>(); - - for (Text element : elements) { - values.add(element.toString()); - } - results.get(i).add(values.toString()); - } - } - - } else if (fieldVector instanceof UInt4Vector) { - final UInt4Vector uInt4Vector = (UInt4Vector) fieldVector; - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - final Object data = uInt4Vector.getObject(rowIndex); - results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); - } - } else if (fieldVector instanceof UInt1Vector) { - final UInt1Vector uInt1Vector = (UInt1Vector) fieldVector; - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - final Object data = uInt1Vector.getObject(rowIndex); - results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); - } - } else if (fieldVector instanceof BitVector) { - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - Object data = fieldVector.getObject(rowIndex); - results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); - } - } else { - throw new UnsupportedOperationException("Not yet implemented"); - } - } - }); - } - } - - return results; - } - @Test public void testExecuteUpdate() { Assertions.assertAll( diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSqlStreams.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSqlStreams.java new file mode 100644 index 0000000000000..4672e0a141832 --- /dev/null +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSqlStreams.java @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static org.apache.arrow.flight.sql.util.FlightStreamUtils.getResults; +import static org.apache.arrow.util.AutoCloseables.close; +import static org.apache.arrow.vector.types.Types.MinorType.INT; +import static org.hamcrest.CoreMatchers.is; + +import java.util.Collections; +import java.util.List; + +import org.apache.arrow.flight.sql.BasicFlightSqlProducer; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.hamcrest.MatcherAssert; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Any; +import com.google.protobuf.Message; + +public class TestFlightSqlStreams { + + /** + * A limited {@link FlightSqlProducer} for testing GetTables, GetTableTypes, GetSqlInfo, and limited SQL commands. + */ + private static class FlightSqlTestProducer extends BasicFlightSqlProducer { + + // Note that for simplicity the getStream* implementations are blocking, but a proper FlightSqlProducer should + // have non-blocking implementations of getStream*. + + private static final String FIXED_QUERY = "SELECT 1 AS c1 FROM test_table"; + private static final Schema FIXED_SCHEMA = new Schema(asList( + Field.nullable("c1", Types.MinorType.INT.getType()))); + + private BufferAllocator allocator; + + FlightSqlTestProducer(BufferAllocator allocator) { + this.allocator = allocator; + } + + @Override + protected List determineEndpoints(T request, FlightDescriptor flightDescriptor, + Schema schema) { + if (request instanceof FlightSql.CommandGetTables || + request instanceof FlightSql.CommandGetTableTypes || + request instanceof FlightSql.CommandGetXdbcTypeInfo || + request instanceof FlightSql.CommandGetSqlInfo) { + return Collections.singletonList(new FlightEndpoint(new Ticket(Any.pack(request).toByteArray()))); + } else if (request instanceof FlightSql.CommandStatementQuery && + ((FlightSql.CommandStatementQuery) request).getQuery().equals(FIXED_QUERY)) { + + // Tickets from CommandStatementQuery requests should be built using TicketStatementQuery then packed() into + // a ticket. The content of the statement handle is specific to the FlightSqlProducer. It does not need to + // be the query. It can be a query ID for example. + FlightSql.TicketStatementQuery ticketStatementQuery = FlightSql.TicketStatementQuery.newBuilder() + .setStatementHandle(((FlightSql.CommandStatementQuery) request).getQueryBytes()) + .build(); + return Collections.singletonList(new FlightEndpoint(new Ticket(Any.pack(ticketStatementQuery).toByteArray()))); + } + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoStatement(FlightSql.CommandStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + return generateFlightInfo(command, descriptor, FIXED_SCHEMA); + } + + @Override + public void getStreamStatement(FlightSql.TicketStatementQuery ticket, + CallContext context, ServerStreamListener listener) { + final String query = ticket.getStatementHandle().toStringUtf8(); + if (!query.equals(FIXED_QUERY)) { + listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + try (VectorSchemaRoot root = VectorSchemaRoot.create(FIXED_SCHEMA, allocator)) { + root.setRowCount(1); + ((IntVector) root.getVector("c1")).setSafe(0, 1); + listener.start(root); + listener.putNext(); + listener.completed(); + } + } + + @Override + public void getStreamSqlInfo(FlightSql.CommandGetSqlInfo command, CallContext context, + ServerStreamListener listener) { + try (VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_SQL_INFO_SCHEMA, allocator)) { + root.setRowCount(0); + listener.start(root); + listener.putNext(); + listener.completed(); + } + } + + @Override + public void getStreamTypeInfo(FlightSql.CommandGetXdbcTypeInfo request, + CallContext context, ServerStreamListener listener) { + try (VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_TYPE_INFO_SCHEMA, allocator)) { + root.setRowCount(1); + ((VarCharVector) root.getVector("type_name")).setSafe(0, new Text("Integer")); + ((IntVector) root.getVector("data_type")).setSafe(0, INT.ordinal()); + ((IntVector) root.getVector("column_size")).setSafe(0, 400); + root.getVector("literal_prefix").setNull(0); + root.getVector("literal_suffix").setNull(0); + root.getVector("create_params").setNull(0); + ((IntVector) root.getVector("nullable")).setSafe(0, FlightSql.Nullable.NULLABILITY_NULLABLE.getNumber()); + ((BitVector) root.getVector("case_sensitive")).setSafe(0, 1); + ((IntVector) root.getVector("nullable")).setSafe(0, FlightSql.Searchable.SEARCHABLE_FULL.getNumber()); + ((BitVector) root.getVector("unsigned_attribute")).setSafe(0, 1); + root.getVector("fixed_prec_scale").setNull(0); + ((BitVector) root.getVector("auto_increment")).setSafe(0, 1); + ((VarCharVector) root.getVector("local_type_name")).setSafe(0, new Text("Integer")); + root.getVector("minimum_scale").setNull(0); + root.getVector("maximum_scale").setNull(0); + ((IntVector) root.getVector("sql_data_type")).setSafe(0, INT.ordinal()); + root.getVector("datetime_subcode").setNull(0); + ((IntVector) root.getVector("num_prec_radix")).setSafe(0, 10); + root.getVector("interval_precision").setNull(0); + + listener.start(root); + listener.putNext(); + listener.completed(); + } + } + + @Override + public void getStreamTables(FlightSql.CommandGetTables command, CallContext context, + ServerStreamListener listener) { + try (VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_TABLES_SCHEMA_NO_SCHEMA, allocator)) { + root.setRowCount(1); + root.getVector("catalog_name").setNull(0); + root.getVector("db_schema_name").setNull(0); + ((VarCharVector) root.getVector("table_name")).setSafe(0, new Text("test_table")); + ((VarCharVector) root.getVector("table_type")).setSafe(0, new Text("TABLE")); + + listener.start(root); + listener.putNext(); + listener.completed(); + } + } + + @Override + public void getStreamTableTypes(CallContext context, ServerStreamListener listener) { + try (VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_TABLE_TYPES_SCHEMA, allocator)) { + root.setRowCount(1); + ((VarCharVector) root.getVector("table_type")).setSafe(0, new Text("TABLE")); + + listener.start(root); + listener.putNext(); + listener.completed(); + } + } + } + + private static BufferAllocator allocator; + + private static FlightServer server; + private static FlightSqlClient sqlClient; + + @BeforeAll + public static void setUp() throws Exception { + allocator = new RootAllocator(Integer.MAX_VALUE); + + final Location serverLocation = Location.forGrpcInsecure("localhost", 0); + server = FlightServer.builder(allocator, serverLocation, new FlightSqlTestProducer(allocator)) + .build() + .start(); + + final Location clientLocation = Location.forGrpcInsecure("localhost", server.getPort()); + sqlClient = new FlightSqlClient(FlightClient.builder(allocator, clientLocation).build()); + } + + @AfterAll + public static void tearDown() throws Exception { + close(sqlClient, server, allocator); + } + + @Test + public void testGetTablesResultNoSchema() throws Exception { + try (final FlightStream stream = + sqlClient.getStream( + sqlClient.getTables(null, null, null, null, false) + .getEndpoints().get(0).getTicket())) { + Assertions.assertAll( + () -> MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)), + () -> { + final List> results = getResults(stream); + final List> expectedResults = ImmutableList.of( + // catalog_name | schema_name | table_name | table_type | table_schema + asList(null, null, "test_table", "TABLE")); + MatcherAssert.assertThat(results, is(expectedResults)); + } + ); + } + } + + @Test + public void testGetTableTypesResult() throws Exception { + try (final FlightStream stream = + sqlClient.getStream(sqlClient.getTableTypes().getEndpoints().get(0).getTicket())) { + Assertions.assertAll( + () -> MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA)), + () -> { + final List> tableTypes = getResults(stream); + final List> expectedTableTypes = ImmutableList.of( + // table_type + singletonList("TABLE") + ); + MatcherAssert.assertThat(tableTypes, is(expectedTableTypes)); + } + ); + } + } + + @Test + public void testGetSqlInfoResults() throws Exception { + final FlightInfo info = sqlClient.getSqlInfo(); + try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { + Assertions.assertAll( + () -> MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)), + () -> MatcherAssert.assertThat(getResults(stream), is(emptyList())) + ); + } + } + + @Test + public void testGetTypeInfo() throws Exception { + FlightInfo flightInfo = sqlClient.getXdbcTypeInfo(); + + try (FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket())) { + + final List> results = getResults(stream); + + final List> matchers = ImmutableList.of( + asList("Integer", "4", "400", null, null, "3", "true", null, "true", null, "true", + "Integer", null, null, "4", null, "10", null)); + + MatcherAssert.assertThat(results, is(matchers)); + } + } + + @Test + public void testExecuteQuery() throws Exception { + try (final FlightStream stream = sqlClient + .getStream(sqlClient.execute(FlightSqlTestProducer.FIXED_QUERY).getEndpoints().get(0).getTicket())) { + Assertions.assertAll( + () -> MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlTestProducer.FIXED_SCHEMA)), + () -> MatcherAssert.assertThat(getResults(stream), is(singletonList(singletonList("1")))) + ); + } + } +} diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/FlightStreamUtils.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/FlightStreamUtils.java new file mode 100644 index 0000000000000..fbbe9ef01816e --- /dev/null +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/FlightStreamUtils.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql.util; + +import static java.util.Objects.isNull; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.ipc.ReadChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.util.Text; + +public class FlightStreamUtils { + + public static List> getResults(FlightStream stream) { + final List> results = new ArrayList<>(); + while (stream.next()) { + try (final VectorSchemaRoot root = stream.getRoot()) { + final long rowCount = root.getRowCount(); + for (int i = 0; i < rowCount; ++i) { + results.add(new ArrayList<>()); + } + + root.getSchema().getFields().forEach(field -> { + try (final FieldVector fieldVector = root.getVector(field.getName())) { + if (fieldVector instanceof VarCharVector) { + final VarCharVector varcharVector = (VarCharVector) fieldVector; + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + final Text data = varcharVector.getObject(rowIndex); + results.get(rowIndex).add(isNull(data) ? null : data.toString()); + } + } else if (fieldVector instanceof IntVector) { + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + Object data = fieldVector.getObject(rowIndex); + results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); + } + } else if (fieldVector instanceof VarBinaryVector) { + final VarBinaryVector varbinaryVector = (VarBinaryVector) fieldVector; + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + final byte[] data = varbinaryVector.getObject(rowIndex); + final String output; + try { + output = isNull(data) ? + null : + MessageSerializer.deserializeSchema( + new ReadChannel(Channels.newChannel(new ByteArrayInputStream(data)))).toJson(); + } catch (final IOException e) { + throw new RuntimeException("Failed to deserialize schema", e); + } + results.get(rowIndex).add(output); + } + } else if (fieldVector instanceof DenseUnionVector) { + final DenseUnionVector denseUnionVector = (DenseUnionVector) fieldVector; + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + final Object data = denseUnionVector.getObject(rowIndex); + results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); + } + } else if (fieldVector instanceof ListVector) { + for (int i = 0; i < fieldVector.getValueCount(); i++) { + if (!fieldVector.isNull(i)) { + List elements = (List) ((ListVector) fieldVector).getObject(i); + List values = new ArrayList<>(); + + for (Text element : elements) { + values.add(element.toString()); + } + results.get(i).add(values.toString()); + } + } + + } else if (fieldVector instanceof UInt4Vector) { + final UInt4Vector uInt4Vector = (UInt4Vector) fieldVector; + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + final Object data = uInt4Vector.getObject(rowIndex); + results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); + } + } else if (fieldVector instanceof UInt1Vector) { + final UInt1Vector uInt1Vector = (UInt1Vector) fieldVector; + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + final Object data = uInt1Vector.getObject(rowIndex); + results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); + } + } else if (fieldVector instanceof BitVector) { + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + Object data = fieldVector.getObject(rowIndex); + results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); + } + } else { + throw new UnsupportedOperationException("Not yet implemented"); + } + } + }); + } + } + + return results; + } +}