diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DynamicFilteringJdbcSplitSource.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DynamicFilteringJdbcSplitSource.java
index 812979ed2bcc..47588ab3ccac 100644
--- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DynamicFilteringJdbcSplitSource.java
+++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DynamicFilteringJdbcSplitSource.java
@@ -27,10 +27,10 @@
  * Attaches dynamic filter to {@link JdbcSplit} after {@link JdbcDynamicFilteringSplitManager}
  * has waited for the collection of dynamic filters.
  * This allows JDBC based connectors to avoid waiting for dynamic filters again on the worker node
- * in {@link JdbcRecordSetProvider}. The number of splits generated by JDBC based connectors are
+ * in {@link JdbcPageSourceProvider}. The number of splits generated by JDBC based connectors is
  * typically small, therefore attaching dynamic filter here does not add significant overhead.
  * Waiting for dynamic filters in {@link JdbcDynamicFilteringSplitManager} is preferred over waiting
- * for them on the worker node in {@link JdbcRecordSetProvider} to allow connectors to take advantage of
+ * for them on the worker node in {@link JdbcPageSourceProvider} to allow connectors to take advantage of
  * dynamic filters during the splits generation phase.
  */
 public class DynamicFilteringJdbcSplitSource
diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java
index b16895f6e148..30ff42a0e8c6 100644
--- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java
+++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcModule.java
@@ -15,14 +15,10 @@
 
 import com.google.common.util.concurrent.MoreExecutors;
 import com.google.inject.Binder;
-import com.google.inject.Inject;
 import com.google.inject.Key;
 import com.google.inject.Provider;
 import com.google.inject.Scopes;
-import com.google.inject.Singleton;
 import com.google.inject.multibindings.Multibinder;
-import com.google.inject.multibindings.ProvidesIntoOptional;
-import dev.failsafe.RetryPolicy;
 import io.airlift.configuration.AbstractConfigurationAwareModule;
 import io.trino.plugin.base.mapping.IdentifierMappingModule;
 import io.trino.plugin.base.session.SessionPropertiesProvider;
@@ -33,7 +29,6 @@
 import io.trino.spi.connector.ConnectorAccessControl;
 import io.trino.spi.connector.ConnectorPageSinkProvider;
 import io.trino.spi.connector.ConnectorPageSourceProvider;
-import io.trino.spi.connector.ConnectorRecordSetProvider;
 import io.trino.spi.connector.ConnectorSplitManager;
 import io.trino.spi.function.table.ConnectorTableFunction;
 import io.trino.spi.procedure.Procedure;
@@ -42,7 +37,6 @@
 import static com.google.inject.multibindings.Multibinder.newSetBinder;
 import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder;
-import static com.google.inject.multibindings.ProvidesIntoOptional.Type.DEFAULT;
 import static io.airlift.configuration.ConditionalModule.conditionalModule;
 import static io.airlift.configuration.ConfigBinder.configBinder;
 import static io.trino.plugin.base.ClosingBinder.closingBinder;
@@ -149,12 +143,4 @@ public static void bindTablePropertiesProvider(Binder binder, Class
-
-    @ProvidesIntoOptional(DEFAULT)
-    @Singleton
-    public static ConnectorRecordSetProvider createRecordSetProvider(JdbcClient jdbcClient, @ForRecordCursor ExecutorService executor, RetryPolicy<Object> policy)
-    {
-        return new JdbcRecordSetProvider(jdbcClient, executor, policy);
-    }
 }
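
With the @ProvidesIntoOptional(DEFAULT) factory method removed from JdbcModule, the default ConnectorPageSourceProvider is presumably wired through the optional binder the module already imports. The binding itself is outside this diff, so the following is a hypothetical sketch of its shape, not the module's actual code:

    package io.trino.plugin.jdbc;

    import com.google.inject.Binder;
    import com.google.inject.Scopes;
    import io.trino.spi.connector.ConnectorPageSourceProvider;

    import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder;

    final class PageSourceBindingSketch
    {
        private PageSourceBindingSketch() {}

        // Hypothetical: bind JdbcPageSourceProvider as the overridable default,
        // mirroring the newOptionalBinder import retained in JdbcModule above.
        // setDefault() keeps the binding replaceable by connectors that need a
        // custom page source provider (see the Redshift change at the end).
        static void bindDefaultPageSourceProvider(Binder binder)
        {
            newOptionalBinder(binder, ConnectorPageSourceProvider.class)
                    .setDefault()
                    .to(JdbcPageSourceProvider.class)
                    .in(Scopes.SINGLETON);
        }
    }
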
diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordCursor.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSource.java
similarity index 60%
rename from plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordCursor.java
rename to plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSource.java
index d4230af6726d..f93a3af7925e 100644
--- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordCursor.java
+++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSource.java
@@ -14,11 +14,15 @@
 package io.trino.plugin.jdbc;
 
 import com.google.common.base.VerifyException;
+import com.google.common.collect.ImmutableList;
 import io.airlift.log.Logger;
 import io.airlift.slice.Slice;
+import io.trino.spi.Page;
+import io.trino.spi.PageBuilder;
 import io.trino.spi.TrinoException;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.connector.ConnectorPageSource;
 import io.trino.spi.connector.ConnectorSession;
-import io.trino.spi.connector.RecordCursor;
 import io.trino.spi.type.Type;
 import jakarta.annotation.Nullable;
 
@@ -28,25 +32,25 @@
 import java.sql.SQLException;
 import java.sql.Statement;
 import java.util.List;
-import java.util.concurrent.ExecutionException;
+import java.util.OptionalLong;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicLong;
 
-import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.base.Verify.verify;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static io.airlift.concurrent.MoreFutures.getFutureValue;
 import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR;
+import static java.lang.System.nanoTime;
 import static java.util.Objects.requireNonNull;
+import static java.util.concurrent.CompletableFuture.supplyAsync;
 
-public class JdbcRecordCursor
-        implements RecordCursor
+public class JdbcPageSource
+        implements ConnectorPageSource
 {
-    private static final Logger log = Logger.get(JdbcRecordCursor.class);
+    private static final Logger log = Logger.get(JdbcPageSource.class);
 
-    private final ExecutorService executor;
-
-    private final JdbcColumnHandle[] columnHandles;
+    private final List<JdbcColumnHandle> columnHandles;
     private final ReadFunction[] readFunctions;
     private final BooleanReadFunction[] booleanReadFunctions;
     private final DoubleReadFunction[] doubleReadFunctions;
@@ -58,16 +62,18 @@ public class JdbcRecordCursor
     private final Connection connection;
     private final PreparedStatement statement;
     private final AtomicLong readTimeNanos = new AtomicLong(0);
+    private final PageBuilder pageBuilder;
+    private final CompletableFuture<ResultSet> resultSetFuture;
     @Nullable
     private ResultSet resultSet;
+    private boolean finished;
     private boolean closed;
+    private long completedPositions;
 
-    public JdbcRecordCursor(JdbcClient jdbcClient, ExecutorService executor, ConnectorSession session, JdbcSplit split, BaseJdbcConnectorTableHandle table, List<JdbcColumnHandle> columnHandles)
+    public JdbcPageSource(JdbcClient jdbcClient, ExecutorService executor, ConnectorSession session, JdbcSplit split, BaseJdbcConnectorTableHandle table, List<JdbcColumnHandle> columnHandles)
     {
         this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null");
-        this.executor = requireNonNull(executor, "executor is null");
-
-        this.columnHandles = columnHandles.toArray(new JdbcColumnHandle[0]);
+        this.columnHandles = ImmutableList.copyOf(columnHandles);
 
         readFunctions = new ReadFunction[columnHandles.size()];
         booleanReadFunctions = new BooleanReadFunction[columnHandles.size()];
@@ -84,7 +90,7 @@ public JdbcRecordCursor(JdbcClient jdbcClient, ExecutorService executor, Connect
             connection = jdbcClient.getConnection(session, split, (JdbcTableHandle) table);
         }
 
-        for (int i = 0; i < this.columnHandles.length; i++) {
+        for (int i = 0; i < this.columnHandles.size(); i++) {
             JdbcColumnHandle columnHandle = columnHandles.get(i);
             ColumnMapping columnMapping = jdbcClient.toColumnMapping(session, connection, columnHandle.getJdbcTypeHandle())
                     .orElseThrow(() -> new VerifyException("Column %s has unsupported type %s".formatted(columnHandle.getColumnName(), columnHandle.getJdbcTypeHandle())));
@@ -119,6 +125,22 @@ else if (javaType == Slice.class) {
         else {
             statement = jdbcClient.buildSql(session, connection, split, (JdbcTableHandle) table, columnHandles);
         }
+        pageBuilder = new PageBuilder(columnHandles.stream()
+                .map(JdbcColumnHandle::getColumnType)
+                .collect(toImmutableList()));
+        resultSetFuture = supplyAsync(() -> {
+            long start = nanoTime();
+            try {
+                log.debug("Executing: %s", statement);
+                return statement.executeQuery();
+            }
+            catch (SQLException e) {
+                throw handleSqlException(e);
+            }
+            finally {
+                readTimeNanos.addAndGet(nanoTime() - start);
+            }
+        }, executor);
     }
     catch (SQLException | RuntimeException e) {
         throw handleSqlException(e);
@@ -132,140 +154,82 @@ public long getReadTimeNanos()
     }
 
     @Override
-    public long getCompletedBytes()
+    public boolean isFinished()
     {
-        return 0;
+        return finished;
     }
 
     @Override
-    public Type getType(int field)
+    public Page getNextPage()
     {
-        return columnHandles[field].getColumnType();
-    }
-
-    @Override
-    public boolean advanceNextPosition()
-    {
-        if (closed) {
-            return false;
-        }
-
+        verify(pageBuilder.isEmpty(), "Expected pageBuilder to be empty");
         try {
             if (resultSet == null) {
-                long start = System.nanoTime();
-                Future<ResultSet> resultSetFuture = executor.submit(() -> {
-                    log.debug("Executing: %s", statement);
-                    return statement.executeQuery();
-                });
-                try {
-                    // statement.executeQuery() may block uninterruptedly, using async way so we are able to cancel remote query
-                    // See javadoc of java.sql.Connection.setNetworkTimeout
-                    resultSet = resultSetFuture.get();
-                }
-                catch (ExecutionException e) {
-                    if (e.getCause() instanceof SQLException cause) {
-                        SQLException sqlException = new SQLException(cause.getMessage(), cause.getSQLState(), cause.getErrorCode(), e);
-                        if (cause.getNextException() != null) {
-                            sqlException.setNextException(cause.getNextException());
-                        }
-                        throw sqlException;
+                resultSet = requireNonNull(getFutureValue(resultSetFuture), "resultSet is null");
+            }
+
+            while (!pageBuilder.isFull() && resultSet.next()) {
+                pageBuilder.declarePosition();
+                completedPositions++;
+                for (int i = 0; i < columnHandles.size(); i++) {
+                    BlockBuilder output = pageBuilder.getBlockBuilder(i);
+                    Type type = columnHandles.get(i).getColumnType();
+                    if (readFunctions[i].isNull(resultSet, i + 1)) {
+                        output.appendNull();
+                    }
+                    else if (booleanReadFunctions[i] != null) {
+                        type.writeBoolean(output, booleanReadFunctions[i].readBoolean(resultSet, i + 1));
+                    }
+                    else if (doubleReadFunctions[i] != null) {
+                        type.writeDouble(output, doubleReadFunctions[i].readDouble(resultSet, i + 1));
+                    }
+                    else if (longReadFunctions[i] != null) {
+                        type.writeLong(output, longReadFunctions[i].readLong(resultSet, i + 1));
+                    }
+                    else if (sliceReadFunctions[i] != null) {
+                        type.writeSlice(output, sliceReadFunctions[i].readSlice(resultSet, i + 1));
+                    }
+                    else {
+                        type.writeObject(output, objectReadFunctions[i].readObject(resultSet, i + 1));
                     }
-                    throw new RuntimeException(e);
-                }
-                catch (InterruptedException e) {
-                    Thread.currentThread().interrupt();
-                    resultSetFuture.cancel(true);
-                    throw new RuntimeException(e);
-                }
-                finally {
-                    readTimeNanos.addAndGet(System.nanoTime() - start);
                 }
             }
-            return resultSet.next();
-        }
-        catch (SQLException | RuntimeException e) {
-            throw handleSqlException(e);
-        }
-    }
 
-    @Override
-    public boolean getBoolean(int field)
-    {
-        checkState(!closed, "cursor is closed");
-        requireNonNull(resultSet, "resultSet is null");
-        try {
-            return booleanReadFunctions[field].readBoolean(resultSet, field + 1);
+            if (!pageBuilder.isFull()) {
+                finished = true;
+            }
         }
         catch (SQLException | RuntimeException e) {
             throw handleSqlException(e);
         }
-    }
 
-    @Override
-    public long getLong(int field)
-    {
-        checkState(!closed, "cursor is closed");
-        requireNonNull(resultSet, "resultSet is null");
-        try {
-            return longReadFunctions[field].readLong(resultSet, field + 1);
-        }
-        catch (SQLException | RuntimeException e) {
-            throw handleSqlException(e);
-        }
+        Page page = pageBuilder.build();
+        pageBuilder.reset();
+        return page;
     }
 
     @Override
-    public double getDouble(int field)
+    public long getMemoryUsage()
     {
-        checkState(!closed, "cursor is closed");
-        requireNonNull(resultSet, "resultSet is null");
-        try {
-            return doubleReadFunctions[field].readDouble(resultSet, field + 1);
-        }
-        catch (SQLException | RuntimeException e) {
-            throw handleSqlException(e);
-        }
+        return pageBuilder.getRetainedSizeInBytes();
     }
 
     @Override
-    public Slice getSlice(int field)
+    public long getCompletedBytes()
     {
-        checkState(!closed, "cursor is closed");
-        requireNonNull(resultSet, "resultSet is null");
-        try {
-            return sliceReadFunctions[field].readSlice(resultSet, field + 1);
-        }
-        catch (SQLException | RuntimeException e) {
-            throw handleSqlException(e);
-        }
+        return 0;
     }
 
     @Override
-    public Object getObject(int field)
+    public OptionalLong getCompletedPositions()
     {
-        checkState(!closed, "cursor is closed");
-        requireNonNull(resultSet, "resultSet is null");
-        try {
-            return objectReadFunctions[field].readObject(resultSet, field + 1);
-        }
-        catch (SQLException | RuntimeException e) {
-            throw handleSqlException(e);
-        }
+        return OptionalLong.of(completedPositions);
     }
 
     @Override
-    public boolean isNull(int field)
+    public CompletableFuture<?> isBlocked()
     {
-        checkState(!closed, "cursor is closed");
-        checkArgument(field < columnHandles.length, "Invalid field index");
-        requireNonNull(resultSet, "resultSet is null");
-
-        try {
-            return readFunctions[field].isNull(resultSet, field + 1);
-        }
-        catch (SQLException | RuntimeException e) {
-            throw handleSqlException(e);
-        }
+        return resultSetFuture;
    }
 
     @Override
@@ -275,6 +239,7 @@ public void close()
             return;
         }
         closed = true;
+        finished = true;
 
         // use try with resources to close everything properly
         try (Connection connection = this.connection;
@@ -291,11 +256,13 @@ public void close()
             }
             if (connection != null && resultSet != null) {
                 jdbcClient.abortReadConnection(connection, resultSet);
+                resultSetFuture.cancel(true);
             }
         }
         catch (SQLException | RuntimeException e) {
             // ignore exception from close
         }
+        resultSet = null;
     }
 
     private RuntimeException handleSqlException(Exception e)
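
JdbcPageSource above swaps the RecordCursor contract (advanceNextPosition() plus per-field getters) for the ConnectorPageSource one: callers poll getNextPage() until isFinished() is true, and a null page only means no data is ready yet, for example while resultSetFuture is still executing the query. A minimal consumer sketch of that protocol (the Consumer callback is a placeholder):

    package io.trino.plugin.jdbc;

    import io.trino.spi.Page;
    import io.trino.spi.connector.ConnectorPageSource;

    import java.util.function.Consumer;

    final class PageSourceDrainSketch
    {
        private PageSourceDrainSketch() {}

        // Drain a page source to completion. getNextPage() may return null
        // before the source is finished (e.g. while isBlocked() is pending),
        // so the loop condition is isFinished(), not the page itself.
        static void drain(ConnectorPageSource pageSource, Consumer<Page> consumer)
        {
            while (!pageSource.isFinished()) {
                Page page = pageSource.getNextPage();
                if (page != null) {
                    consumer.accept(page);
                }
            }
        }
    }
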
diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSourceProvider.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSourceProvider.java
index 4264080ef017..52c0548637fc 100644
--- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSourceProvider.java
+++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcPageSourceProvider.java
@@ -15,40 +15,48 @@
 
 import com.google.common.collect.ImmutableList;
 import com.google.inject.Inject;
+import dev.failsafe.RetryPolicy;
+import io.trino.plugin.base.MappedPageSource;
 import io.trino.plugin.jdbc.MergeJdbcPageSource.ColumnAdaptation;
 import io.trino.spi.connector.ColumnHandle;
 import io.trino.spi.connector.ConnectorPageSource;
 import io.trino.spi.connector.ConnectorPageSourceProvider;
-import io.trino.spi.connector.ConnectorRecordSetProvider;
 import io.trino.spi.connector.ConnectorSession;
 import io.trino.spi.connector.ConnectorSplit;
 import io.trino.spi.connector.ConnectorTableHandle;
 import io.trino.spi.connector.ConnectorTransactionHandle;
 import io.trino.spi.connector.DynamicFilter;
-import io.trino.spi.connector.RecordPageSource;
 
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
+import java.util.concurrent.ExecutorService;
+import java.util.stream.IntStream;
 
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
 import static com.google.common.collect.MoreCollectors.toOptional;
 import static io.trino.plugin.jdbc.DefaultJdbcMetadata.MERGE_ROW_ID;
 import static io.trino.plugin.jdbc.MergeJdbcPageSource.MergedRowAdaptation;
 import static io.trino.plugin.jdbc.MergeJdbcPageSource.SourceColumn;
+import static io.trino.plugin.jdbc.RetryingModule.retry;
 import static java.util.Objects.requireNonNull;
+import static java.util.function.UnaryOperator.identity;
 
 public class JdbcPageSourceProvider
         implements ConnectorPageSourceProvider
 {
     private final JdbcClient jdbcClient;
-    private final ConnectorRecordSetProvider recordSetProvider;
+    private final ExecutorService executor;
+    private final RetryPolicy<Object> policy;
 
     @Inject
-    public JdbcPageSourceProvider(JdbcClient jdbcClient, ConnectorRecordSetProvider recordSetProvider)
+    public JdbcPageSourceProvider(JdbcClient jdbcClient, @ForRecordCursor ExecutorService executor, RetryPolicy<Object> policy)
     {
         this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null");
-        this.recordSetProvider = requireNonNull(recordSetProvider, "recordSetProvider is null");
+        this.executor = requireNonNull(executor, "executor is null");
+        this.policy = requireNonNull(policy, "policy is null");
     }
 
     @Override
@@ -60,27 +68,45 @@ public ConnectorPageSource createPageSource(
             List<ColumnHandle> columns,
             DynamicFilter dynamicFilter)
     {
+        JdbcSplit jdbcSplit = (JdbcSplit) split;
+        List<JdbcColumnHandle> jdbcColumns = columns.stream()
+                .map(JdbcColumnHandle.class::cast)
+                .collect(toImmutableList());
+
         if (table instanceof JdbcProcedureHandle procedureHandle) {
-            return new RecordPageSource(recordSetProvider.getRecordSet(transaction, session, split, procedureHandle, columns));
+            List<JdbcColumnHandle> sourceColumns = procedureHandle.getColumns().orElseThrow();
+            Map<JdbcColumnHandle, Integer> columnIndexMap = IntStream.range(0, sourceColumns.size())
+                    .boxed()
+                    .collect(toImmutableMap(sourceColumns::get, identity()));
+
+            return new MappedPageSource(
+                    createPageSource(session, jdbcSplit, procedureHandle, sourceColumns),
+                    jdbcColumns.stream()
+                            .map(columnIndexMap::get)
+                            .collect(toImmutableList()));
         }
 
         JdbcTableHandle tableHandle = (JdbcTableHandle) table;
-        Optional<JdbcColumnHandle> mergeRowId = columns.stream()
-                .map(JdbcColumnHandle.class::cast)
+        Optional<JdbcColumnHandle> mergeRowId = jdbcColumns.stream()
                 .filter(column -> column.getColumnName().equalsIgnoreCase(MERGE_ROW_ID))
                 .collect(toOptional());
         if (mergeRowId.isEmpty()) {
-            return new RecordPageSource(recordSetProvider.getRecordSet(transaction, session, split, tableHandle, columns));
+            return new JdbcPageSource(
+                    jdbcClient,
+                    executor,
+                    session,
+                    jdbcSplit,
+                    tableHandle.intersectedWithConstraint(jdbcSplit.getDynamicFilter().transformKeys(ColumnHandle.class::cast)),
+                    jdbcColumns);
         }
-        return createMergePageSource(transaction, session, split, columns, tableHandle, mergeRowId);
+        return createMergePageSource(session, jdbcSplit, jdbcColumns, tableHandle, mergeRowId);
     }
 
     private MergeJdbcPageSource createMergePageSource(
-            ConnectorTransactionHandle transaction,
             ConnectorSession session,
-            ConnectorSplit split,
-            List<ColumnHandle> columns,
+            JdbcSplit jdbcSplit,
+            List<JdbcColumnHandle> columns,
             JdbcTableHandle tableHandle,
             Optional<JdbcColumnHandle> mergeRowId)
     {
@@ -88,7 +114,7 @@ private MergeJdbcPageSource createMergePageSource(
         List<JdbcColumnHandle> scanColumns = getScanColumns(session, jdbcClient, tableHandle, primaryKeys);
 
         ImmutableList.Builder<ColumnAdaptation> columnAdaptationsBuilder = ImmutableList.builder();
-        for (ColumnHandle columnHandle : columns) {
+        for (JdbcColumnHandle columnHandle : columns) {
             if (columnHandle.equals(mergeRowId.get())) {
                 columnAdaptationsBuilder.add(buildMergeIdColumnAdaptation(scanColumns, primaryKeys));
             }
@@ -109,10 +135,19 @@ private MergeJdbcPageSource createMergePageSource(
                 tableHandle.getAuthorization(),
                 tableHandle.getUpdateAssignments());
         return new MergeJdbcPageSource(
-                new RecordPageSource(recordSetProvider.getRecordSet(transaction, session, split, newTableHandle, scanColumns)),
+                createPageSource(session, jdbcSplit, newTableHandle, scanColumns),
                 columnAdaptationsBuilder.build());
     }
 
+    private JdbcPageSource createPageSource(
+            ConnectorSession session,
+            JdbcSplit jdbcSplit,
+            BaseJdbcConnectorTableHandle table,
+            List<JdbcColumnHandle> columnHandles)
+    {
+        return retry(policy, () -> new JdbcPageSource(jdbcClient, executor, session, jdbcSplit, table, columnHandles));
+    }
+
     private static List<JdbcColumnHandle> getScanColumns(
             ConnectorSession session,
             JdbcClient jdbcClient,
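
The MappedPageSource branch in createPageSource() reorders procedure output by position: each source column is assigned its index, and the requested columns are translated into delegate indexes. The same computation in isolation, written with plain collections rather than the provider's IntStream/toImmutableMap pipeline (a sketch for illustration, not the provider's exact code):

    package io.trino.plugin.jdbc;

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    final class ColumnIndexSketch
    {
        private ColumnIndexSketch() {}

        // For each requested column, look up its position among the columns the
        // underlying page source actually produces; MappedPageSource then picks
        // blocks by these indexes when assembling each output page.
        static <T> List<Integer> delegateIndexes(List<T> sourceColumns, List<T> requestedColumns)
        {
            Map<T, Integer> indexByColumn = new HashMap<>();
            for (int i = 0; i < sourceColumns.size(); i++) {
                indexByColumn.putIfAbsent(sourceColumns.get(i), i);
            }
            List<Integer> indexes = new ArrayList<>(requestedColumns.size());
            for (T column : requestedColumns) {
                Integer index = indexByColumn.get(column);
                if (index == null) {
                    throw new IllegalArgumentException("Column not produced by source: " + column);
                }
                indexes.add(index);
            }
            return indexes;
        }
    }
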
diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordSet.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordSet.java
deleted file mode 100644
index 8b5b19fded03..000000000000
--- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordSet.java
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.trino.plugin.jdbc;
-
-import com.google.common.collect.ImmutableList;
-import dev.failsafe.RetryPolicy;
-import io.trino.spi.connector.ConnectorSession;
-import io.trino.spi.connector.RecordCursor;
-import io.trino.spi.connector.RecordSet;
-import io.trino.spi.type.Type;
-
-import java.util.List;
-import java.util.concurrent.ExecutorService;
-
-import static io.trino.plugin.jdbc.RetryingModule.retry;
-import static java.util.Objects.requireNonNull;
-
-public class JdbcRecordSet
-        implements RecordSet
-{
-    private final JdbcClient jdbcClient;
-    private final ExecutorService executor;
-    private final BaseJdbcConnectorTableHandle table;
-    private final List<JdbcColumnHandle> columnHandles;
-    private final List<Type> columnTypes;
-    private final JdbcSplit split;
-    private final ConnectorSession session;
-    private final RetryPolicy<Object> policy;
-
-    public JdbcRecordSet(
-            JdbcClient jdbcClient,
-            ExecutorService executor,
-            ConnectorSession session,
-            RetryPolicy<Object> policy,
-            JdbcSplit split,
-            BaseJdbcConnectorTableHandle table,
-            List<JdbcColumnHandle> columnHandles)
-    {
-        this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null");
-        this.executor = requireNonNull(executor, "executor is null");
-        this.split = requireNonNull(split, "split is null");
-
-        this.table = requireNonNull(table, "table is null");
-        this.columnHandles = requireNonNull(columnHandles, "columnHandles is null");
-        ImmutableList.Builder<Type> types = ImmutableList.builderWithExpectedSize(columnHandles.size());
-        for (JdbcColumnHandle column : columnHandles) {
-            types.add(column.getColumnType());
-        }
-        this.columnTypes = types.build();
-        this.session = requireNonNull(session, "session is null");
-        this.policy = requireNonNull(policy, "policy is null");
-    }
-
-    @Override
-    public List<Type> getColumnTypes()
-    {
-        return columnTypes;
-    }
-
-    @Override
-    public RecordCursor cursor()
-    {
-        return retry(policy, () -> new JdbcRecordCursor(jdbcClient, executor, session, split, table, columnHandles));
-    }
-}
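
Note that the deleted JdbcRecordSet.cursor() wrapped cursor creation in RetryingModule.retry(policy, ...), and JdbcPageSourceProvider.createPageSource() above keeps that guard around JdbcPageSource construction. RetryingModule itself is not shown in this diff; assuming it delegates to the Failsafe library that RetryPolicy comes from, its core would be roughly:

    package io.trino.plugin.jdbc;

    import dev.failsafe.Failsafe;
    import dev.failsafe.RetryPolicy;

    import java.util.function.Supplier;

    final class RetrySketch
    {
        private RetrySketch() {}

        // Hypothetical shape of RetryingModule.retry: re-invoke the supplier
        // under the given policy until it succeeds or retries are exhausted.
        @SuppressWarnings("unchecked")
        static <T> T retry(RetryPolicy<Object> policy, Supplier<T> supplier)
        {
            return (T) Failsafe.with(policy).get(supplier::get);
        }
    }
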
diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordSetProvider.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordSetProvider.java
deleted file mode 100644
index 0e09450e885f..000000000000
--- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcRecordSetProvider.java
+++ /dev/null
@@ -1,104 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.trino.plugin.jdbc;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
-import com.google.inject.Inject;
-import dev.failsafe.RetryPolicy;
-import io.trino.plugin.base.MappedRecordSet;
-import io.trino.spi.connector.ColumnHandle;
-import io.trino.spi.connector.ConnectorRecordSetProvider;
-import io.trino.spi.connector.ConnectorSession;
-import io.trino.spi.connector.ConnectorSplit;
-import io.trino.spi.connector.ConnectorTableHandle;
-import io.trino.spi.connector.ConnectorTransactionHandle;
-import io.trino.spi.connector.RecordSet;
-
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.ExecutorService;
-import java.util.stream.IntStream;
-
-import static com.google.common.base.Verify.verify;
-import static com.google.common.collect.ImmutableList.toImmutableList;
-import static com.google.common.collect.ImmutableMap.toImmutableMap;
-import static java.util.Objects.requireNonNull;
-import static java.util.function.UnaryOperator.identity;
-
-public class JdbcRecordSetProvider
-        implements ConnectorRecordSetProvider
-{
-    private final JdbcClient jdbcClient;
-    private final ExecutorService executor;
-    private final RetryPolicy<Object> policy;
-
-    @Inject
-    public JdbcRecordSetProvider(JdbcClient jdbcClient, @ForRecordCursor ExecutorService executor, RetryPolicy<Object> policy)
-    {
-        this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null");
-        this.executor = requireNonNull(executor, "executor is null");
-        this.policy = requireNonNull(policy, "policy is null");
-    }
-
-    @Override
-    public RecordSet getRecordSet(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorSplit split, ConnectorTableHandle table, List<? extends ColumnHandle> columns)
-    {
-        JdbcSplit jdbcSplit = (JdbcSplit) split;
-        BaseJdbcConnectorTableHandle jdbcTable = (BaseJdbcConnectorTableHandle) table;
-
-        // In the current API, the columns (and order) needed by the engine are provided via an argument to this method. Make sure we can
-        // satisfy the requirements using columns which were recorded in the table handle.
-        // If no columns are recorded, it means that applyProjection never got called (e.g., in the case all columns are being used) and all
-        // table columns should be returned. TODO: this is something that should be addressed once the getRecordSet API is revamped
-        jdbcTable.getColumns()
-                .ifPresent(tableColumns -> verify(ImmutableSet.copyOf(tableColumns).containsAll(columns)));
-
-        if (jdbcTable instanceof JdbcTableHandle jdbcTableHandle) {
-            ImmutableList.Builder<JdbcColumnHandle> handles = ImmutableList.builderWithExpectedSize(columns.size());
-            for (ColumnHandle handle : columns) {
-                handles.add((JdbcColumnHandle) handle);
-            }
-
-            return new JdbcRecordSet(
-                    jdbcClient,
-                    executor,
-                    session,
-                    policy,
-                    jdbcSplit,
-                    jdbcTableHandle.intersectedWithConstraint(jdbcSplit.getDynamicFilter().transformKeys(ColumnHandle.class::cast)),
-                    handles.build());
-        }
-        JdbcProcedureHandle procedureHandle = (JdbcProcedureHandle) jdbcTable;
-        List<JdbcColumnHandle> sourceColumns = procedureHandle.getColumns().orElseThrow();
-
-        Map<JdbcColumnHandle, Integer> columnIndexMap = IntStream.range(0, sourceColumns.size())
-                .boxed()
-                .collect(toImmutableMap(sourceColumns::get, identity()));
-
-        return new MappedRecordSet(
-                new JdbcRecordSet(
-                        jdbcClient,
-                        executor,
-                        session,
-                        policy,
-                        jdbcSplit,
-                        procedureHandle,
-                        sourceColumns),
-                columns.stream()
-                        .map(JdbcColumnHandle.class::cast)
-                        .map(columnIndexMap::get)
-                        .collect(toImmutableList()));
-    }
-}
diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcRecordSet.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcPageSource.java
similarity index 53%
rename from plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcRecordSet.java
rename to plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcPageSource.java
index 87aede67b24b..2d3b7793bcaa 100644
--- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcRecordSet.java
+++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcPageSource.java
@@ -15,9 +15,7 @@
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
-import dev.failsafe.RetryPolicy;
-import io.trino.spi.connector.RecordCursor;
-import io.trino.spi.connector.RecordSet;
+import io.trino.spi.Page;
 import io.trino.spi.connector.SchemaTableName;
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;
@@ -32,11 +30,8 @@
 
 import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService;
 import static io.airlift.testing.Closeables.closeAll;
-import static io.trino.plugin.jdbc.TestingJdbcTypeHandle.JDBC_BIGINT;
-import static io.trino.plugin.jdbc.TestingJdbcTypeHandle.JDBC_VARCHAR;
 import static io.trino.spi.type.BigintType.BIGINT;
 import static io.trino.spi.type.VarcharType.VARCHAR;
-import static io.trino.spi.type.VarcharType.createVarcharType;
 import static io.trino.testing.TestingConnectorSession.SESSION;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
@@ -44,7 +39,7 @@
 
 @TestInstance(PER_CLASS)
 @Execution(CONCURRENT)
-public class TestJdbcRecordSet
+public class TestJdbcPageSource
 {
     private TestingDatabase database;
     private JdbcClient jdbcClient;
@@ -76,50 +71,32 @@ public void tearDown()
         executor = null;
     }
 
-    @Test
-    public void testGetColumnTypes()
-    {
-        RecordSet recordSet = createRecordSet(ImmutableList.of(
-                new JdbcColumnHandle("text", JDBC_VARCHAR, VARCHAR),
-                new JdbcColumnHandle("text_short", JDBC_VARCHAR, createVarcharType(32)),
-                new JdbcColumnHandle("value", JDBC_BIGINT, BIGINT)));
-        assertThat(recordSet.getColumnTypes()).containsExactly(VARCHAR, createVarcharType(32), BIGINT);
-
-        recordSet = createRecordSet(ImmutableList.of(
-                new JdbcColumnHandle("value", JDBC_BIGINT, BIGINT),
-                new JdbcColumnHandle("text", JDBC_VARCHAR, VARCHAR)));
-        assertThat(recordSet.getColumnTypes()).containsExactly(BIGINT, VARCHAR);
-
-        recordSet = createRecordSet(ImmutableList.of(
-                new JdbcColumnHandle("value", JDBC_BIGINT, BIGINT),
-                new JdbcColumnHandle("value", JDBC_BIGINT, BIGINT),
-                new JdbcColumnHandle("text", JDBC_VARCHAR, VARCHAR)));
-        assertThat(recordSet.getColumnTypes()).containsExactly(BIGINT, BIGINT, VARCHAR);
-
-        recordSet = createRecordSet(ImmutableList.of());
-        assertThat(recordSet.getColumnTypes()).isEmpty();
-    }
-
     @Test
     public void testCursorSimple()
     {
-        RecordSet recordSet = createRecordSet(ImmutableList.of(
+        try (JdbcPageSource pageSource = createPageSource(ImmutableList.of(
                 columnHandles.get("text"),
                 columnHandles.get("text_short"),
-                columnHandles.get("value")));
-
-        try (RecordCursor cursor = recordSet.cursor()) {
-            assertThat(cursor.getType(0)).isEqualTo(VARCHAR);
-            assertThat(cursor.getType(1)).isEqualTo(createVarcharType(32));
-            assertThat(cursor.getType(2)).isEqualTo(BIGINT);
-
+                columnHandles.get("value")))) {
             Map<String, Long> data = new LinkedHashMap<>();
-            while (cursor.advanceNextPosition()) {
-                data.put(cursor.getSlice(0).toStringUtf8(), cursor.getLong(2));
-                assertThat(cursor.getSlice(0)).isEqualTo(cursor.getSlice(1));
-                assertThat(cursor.isNull(0)).isFalse();
-                assertThat(cursor.isNull(1)).isFalse();
-                assertThat(cursor.isNull(2)).isFalse();
+            for (Page page = pageSource.getNextPage(); ; page = pageSource.getNextPage()) {
+                if (page == null) {
+                    continue;
+                }
+                for (int position = 0; position < page.getPositionCount(); position++) {
+                    assertThat(page.getBlock(0).isNull(position)).isFalse();
+                    assertThat(page.getBlock(1).isNull(position)).isFalse();
+                    assertThat(page.getBlock(2).isNull(position)).isFalse();
+                    assertThat(VARCHAR.getSlice(page.getBlock(0), position))
+                            .isEqualTo(VARCHAR.getSlice(page.getBlock(1), position));
+                    data.put(
+                            VARCHAR.getSlice(page.getBlock(0), position).toStringUtf8(),
+                            BIGINT.getLong(page.getBlock(2), position));
+                }
+
+                if (pageSource.isFinished()) {
+                    break;
+                }
             }
 
             assertThat(data).isEqualTo(ImmutableMap.<String, Long>builder()
@@ -131,27 +108,33 @@ public void testCursorSimple()
                     .put("twelve", 12L)
                     .buildOrThrow());
 
-            assertThat(cursor.getReadTimeNanos()).isPositive();
+            assertThat(pageSource.getReadTimeNanos()).isPositive();
         }
     }
 
     @Test
     public void testCursorMixedOrder()
     {
-        RecordSet recordSet = createRecordSet(ImmutableList.of(
+        try (JdbcPageSource pageSource = createPageSource(ImmutableList.of(
                 columnHandles.get("value"),
                 columnHandles.get("value"),
-                columnHandles.get("text")));
-
-        try (RecordCursor cursor = recordSet.cursor()) {
-            assertThat(cursor.getType(0)).isEqualTo(BIGINT);
-            assertThat(cursor.getType(1)).isEqualTo(BIGINT);
-            assertThat(cursor.getType(2)).isEqualTo(VARCHAR);
-
+                columnHandles.get("text")))) {
             Map<String, Long> data = new LinkedHashMap<>();
-            while (cursor.advanceNextPosition()) {
-                assertThat(cursor.getLong(0)).isEqualTo(cursor.getLong(1));
-                data.put(cursor.getSlice(2).toStringUtf8(), cursor.getLong(0));
+            for (Page page = pageSource.getNextPage(); ; page = pageSource.getNextPage()) {
+                if (page == null) {
+                    continue;
+                }
+                for (int position = 0; position < page.getPositionCount(); position++) {
+                    assertThat(BIGINT.getLong(page.getBlock(0), position))
+                            .isEqualTo(BIGINT.getLong(page.getBlock(1), position));
+                    data.put(
+                            VARCHAR.getSlice(page.getBlock(2), position).toStringUtf8(),
+                            BIGINT.getLong(page.getBlock(0), position));
+                }
+
+                if (pageSource.isFinished()) {
+                    break;
+                }
             }
 
             assertThat(data).isEqualTo(ImmutableMap.<String, Long>builder()
@@ -163,25 +146,24 @@ public void testCursorMixedOrder()
                     .put("twelve", 12L)
                     .buildOrThrow());
 
-            assertThat(cursor.getReadTimeNanos()).isPositive();
+            assertThat(pageSource.getReadTimeNanos()).isPositive();
         }
     }
 
     @Test
     public void testIdempotentClose()
     {
-        RecordSet recordSet = createRecordSet(ImmutableList.of(
+        JdbcPageSource pageSource = createPageSource(ImmutableList.of(
                 columnHandles.get("value"),
                 columnHandles.get("value"),
                 columnHandles.get("text")));
 
-        RecordCursor cursor = recordSet.cursor();
-        cursor.close();
-        cursor.close();
+        pageSource.close();
+        pageSource.close();
    }
 
-    private JdbcRecordSet createRecordSet(List<JdbcColumnHandle> columnHandles)
+    private JdbcPageSource createPageSource(List<JdbcColumnHandle> columnHandles)
     {
-        return new JdbcRecordSet(jdbcClient, executor, SESSION, RetryPolicy.ofDefaults(), split, table, columnHandles);
+        return new JdbcPageSource(jdbcClient, executor, SESSION, split, table, columnHandles);
     }
 }
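
Both test classes read values back through the Type API (VARCHAR.getSlice, BIGINT.getLong) instead of cursor getters. Isolated, the access pattern looks like this for a hypothetical page laid out as (VARCHAR, BIGINT):

    package io.trino.plugin.jdbc;

    import io.trino.spi.Page;
    import io.trino.spi.block.Block;

    import static io.trino.spi.type.BigintType.BIGINT;
    import static io.trino.spi.type.VarcharType.VARCHAR;

    final class TypedPageReadSketch
    {
        private TypedPageReadSketch() {}

        // The Type object decodes the block's physical representation, so the
        // same Block API works for every type: check isNull first, then use
        // the typed accessor matching the channel's declared type.
        static void printRows(Page page)
        {
            Block text = page.getBlock(0);
            Block value = page.getBlock(1);
            for (int position = 0; position < page.getPositionCount(); position++) {
                if (text.isNull(position) || value.isNull(position)) {
                    continue;
                }
                System.out.printf("%s -> %d%n",
                        VARCHAR.getSlice(text, position).toStringUtf8(),
                        BIGINT.getLong(value, position));
            }
        }
    }
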
diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcRecordSetProvider.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcPageSourceProvider.java
similarity index 84%
rename from plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcRecordSetProvider.java
rename to plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcPageSourceProvider.java
index 5ff6896312e1..b6173769d07c 100644
--- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcRecordSetProvider.java
+++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcPageSourceProvider.java
@@ -16,12 +16,13 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import dev.failsafe.RetryPolicy;
+import io.trino.spi.Page;
 import io.trino.spi.connector.ColumnHandle;
+import io.trino.spi.connector.ConnectorPageSource;
 import io.trino.spi.connector.ConnectorSession;
 import io.trino.spi.connector.ConnectorSplitSource;
 import io.trino.spi.connector.ConnectorTransactionHandle;
-import io.trino.spi.connector.RecordCursor;
-import io.trino.spi.connector.RecordSet;
+import io.trino.spi.connector.DynamicFilter;
 import io.trino.spi.connector.SchemaTableName;
 import io.trino.spi.predicate.Domain;
 import io.trino.spi.predicate.Range;
@@ -55,7 +56,7 @@
 
 @TestInstance(PER_CLASS)
 @Execution(CONCURRENT)
-public class TestJdbcRecordSetProvider
+public class TestJdbcPageSourceProvider
 {
     private static final ConnectorSession SESSION = TestingConnectorSession.builder()
             .setPropertyMetadata(new JdbcMetadataSessionProperties(new JdbcMetadataConfig(), Optional.empty()).getSessionProperties())
@@ -101,20 +102,29 @@ public void tearDown()
     }
 
     @Test
-    public void testGetRecordSet()
+    public void testGetPageSource()
     {
         ConnectorTransactionHandle transaction = new JdbcTransactionHandle();
-        JdbcRecordSetProvider recordSetProvider = new JdbcRecordSetProvider(jdbcClient, executor, RetryPolicy.ofDefaults());
-        RecordSet recordSet = recordSetProvider.getRecordSet(transaction, SESSION, split, table, ImmutableList.of(textColumn, textShortColumn, valueColumn));
-        assertThat(recordSet).withFailMessage("recordSet is null").isNotNull();
-
-        RecordCursor cursor = recordSet.cursor();
-        assertThat(cursor).withFailMessage("cursor is null").isNotNull();
+        JdbcPageSourceProvider pageSourceProvider = new JdbcPageSourceProvider(jdbcClient, executor, RetryPolicy.ofDefaults());
+        ConnectorPageSource pageSource = pageSourceProvider.createPageSource(transaction, SESSION, split, table, ImmutableList.of(textColumn, textShortColumn, valueColumn), DynamicFilter.EMPTY);
+        assertThat(pageSource).withFailMessage("pageSource is null").isNotNull();
 
         Map<String, Long> data = new LinkedHashMap<>();
-        while (cursor.advanceNextPosition()) {
-            data.put(cursor.getSlice(0).toStringUtf8(), cursor.getLong(2));
-            assertThat(cursor.getSlice(0)).isEqualTo(cursor.getSlice(1));
+        for (Page page = pageSource.getNextPage(); ; page = pageSource.getNextPage()) {
+            if (page == null) {
+                continue;
+            }
+            for (int position = 0; position < page.getPositionCount(); position++) {
+                assertThat(VARCHAR.getSlice(page.getBlock(0), position))
+                        .isEqualTo(VARCHAR.getSlice(page.getBlock(1), position));
+                data.put(
+                        VARCHAR.getSlice(page.getBlock(0), position).toStringUtf8(),
+                        BIGINT.getLong(page.getBlock(2), position));
+            }
+
+            if (pageSource.isFinished()) {
+                break;
+            }
         }
 
         assertThat(data).isEqualTo(ImmutableMap.<String, Long>builder()
                 .put("one", 1L)
@@ -197,7 +207,7 @@ public void testTupleDomain()
                         true))));
     }
 
-    private RecordCursor getCursor(JdbcTableHandle jdbcTableHandle, List<JdbcColumnHandle> columns, TupleDomain<ColumnHandle> domain)
+    private ConnectorPageSource getCursor(JdbcTableHandle jdbcTableHandle, List<JdbcColumnHandle> columns, TupleDomain<ColumnHandle> domain)
     {
         jdbcTableHandle = new JdbcTableHandle(
                 jdbcTableHandle.getRelationHandle(),
@@ -215,9 +225,7 @@ private ConnectorPageSource getCursor(JdbcTableHandle jdbcTableHandle, List
diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftUnloadConnector.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftUnloadConnector.java
             accessControl,
             Set<Procedure> procedures,
@@ -89,7 +89,7 @@ public RedshiftUnloadConnector(
                 .flatMap(tablePropertiesProvider -> tablePropertiesProvider.getTableProperties().stream())
                 .collect(toImmutableList());
         this.transactionManager = requireNonNull(transactionManager, "transactionManager is null");
-        this.pageSourceProvider = new RedshiftPageSourceProvider(jdbcRecordSetProvider, fileSystemFactory, fileFormatDataSourceStats);
+        this.pageSourceProvider = new RedshiftPageSourceProvider(jdbcPageSourceProvider, fileSystemFactory, fileFormatDataSourceStats);
     }
 
     @Override