Migrate JDBC to page source API
raunaqmorarka committed Feb 26, 2025
1 parent 40f61e7 commit bfbff88
Showing 11 changed files with 215 additions and 422 deletions.
@@ -27,10 +27,10 @@
* Attaches dynamic filter to {@link JdbcSplit} after {@link JdbcDynamicFilteringSplitManager}
* has waited for the collection of dynamic filters.
* This allows JDBC based connectors to avoid waiting for dynamic filters again on the worker node
* in {@link JdbcRecordSetProvider}. The number of splits generated by JDBC based connectors are
* in {@link JdbcPageSourceProvider}. The number of splits generated by JDBC based connectors are
* typically small, therefore attaching dynamic filter here does not add significant overhead.
* Waiting for dynamic filters in {@link JdbcDynamicFilteringSplitManager} is preferred over waiting
* for them on the worker node in {@link JdbcRecordSetProvider} to allow connectors to take advantage of
* for them on the worker node in {@link JdbcPageSourceProvider} to allow connectors to take advantage of
* dynamic filters during the splits generation phase.
*/
public class DynamicFilteringJdbcSplitSource
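The javadoc above states the design rationale this commit keeps intact: dynamic filters are awaited once, during split generation, and the collected predicate is attached to each JdbcSplit, so the worker-side JdbcPageSourceProvider never waits for them again. Below is an illustrative, hedged sketch of that pattern only, not code from this commit; DynamicFilter, Split, and DynamicFilterAttachingSplitSource are simplified stand-ins for the real Trino types.

import java.util.List;
import java.util.concurrent.CompletableFuture;

import static java.util.stream.Collectors.toList;

// Simplified stand-ins for io.trino.spi.connector.DynamicFilter and JdbcSplit.
interface DynamicFilter
{
    CompletableFuture<Void> isBlocked();   // completes once filters are collected (or a wait timeout fires)
    String getCurrentPredicate();          // the collected predicate, e.g. a pushed-down filter fragment
}

record Split(String additionalPredicate) {}

final class DynamicFilterAttachingSplitSource
{
    private final DynamicFilter dynamicFilter;
    private final List<Split> splits;

    DynamicFilterAttachingSplitSource(DynamicFilter dynamicFilter, List<Split> splits)
    {
        this.dynamicFilter = dynamicFilter;
        this.splits = splits;
    }

    // Wait for dynamic filters here, in the split manager, and bake the result into every
    // split. A page source built from such a split on a worker can use the predicate
    // immediately instead of waiting for the filters a second time.
    CompletableFuture<List<Split>> getSplits()
    {
        return dynamicFilter.isBlocked().thenApply(ignored ->
                splits.stream()
                        .map(split -> new Split(dynamicFilter.getCurrentPredicate()))
                        .collect(toList()));
    }
}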
@@ -15,14 +15,10 @@

import com.google.common.util.concurrent.MoreExecutors;
import com.google.inject.Binder;
import com.google.inject.Inject;
import com.google.inject.Key;
import com.google.inject.Provider;
import com.google.inject.Scopes;
import com.google.inject.Singleton;
import com.google.inject.multibindings.Multibinder;
import com.google.inject.multibindings.ProvidesIntoOptional;
import dev.failsafe.RetryPolicy;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.trino.plugin.base.mapping.IdentifierMappingModule;
import io.trino.plugin.base.session.SessionPropertiesProvider;
@@ -33,7 +29,6 @@
import io.trino.spi.connector.ConnectorAccessControl;
import io.trino.spi.connector.ConnectorPageSinkProvider;
import io.trino.spi.connector.ConnectorPageSourceProvider;
import io.trino.spi.connector.ConnectorRecordSetProvider;
import io.trino.spi.connector.ConnectorSplitManager;
import io.trino.spi.function.table.ConnectorTableFunction;
import io.trino.spi.procedure.Procedure;
@@ -42,7 +37,6 @@

import static com.google.inject.multibindings.Multibinder.newSetBinder;
import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder;
import static com.google.inject.multibindings.ProvidesIntoOptional.Type.DEFAULT;
import static io.airlift.configuration.ConditionalModule.conditionalModule;
import static io.airlift.configuration.ConfigBinder.configBinder;
import static io.trino.plugin.base.ClosingBinder.closingBinder;
@@ -149,12 +143,4 @@ public static void bindTablePropertiesProvider(Binder binder, Class<? extends Ta
{
tablePropertiesProviderBinder(binder).addBinding().to(type).in(Scopes.SINGLETON);
}

@ProvidesIntoOptional(DEFAULT)
@Inject
@Singleton
ConnectorRecordSetProvider recordSetProvider(JdbcClient jdbcClient, @ForRecordCursor ExecutorService executor, RetryPolicy<Object> policy)
{
return new JdbcRecordSetProvider(jdbcClient, executor, policy);
}
}
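The hunk above shows only the removal of the record-set binding; the lines that register the page source provider in its place are collapsed in this view. As an assumption about what that replacement could look like, not the commit's actual code, a default binding can be declared on the optional binder this module already uses (JdbcPageSourceProvider's constructor dependencies are likewise not shown here):

// Hypothetical sketch, not from this commit: register JdbcPageSourceProvider as the
// default ConnectorPageSourceProvider, reusing newOptionalBinder and Scopes, which the
// module still imports after this change.
newOptionalBinder(binder, ConnectorPageSourceProvider.class)
        .setDefault()
        .to(JdbcPageSourceProvider.class)
        .in(Scopes.SINGLETON);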
@@ -14,11 +14,15 @@
package io.trino.plugin.jdbc;

import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.TrinoException;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.connector.ConnectorPageSource;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.RecordCursor;
import io.trino.spi.type.Type;
import jakarta.annotation.Nullable;

@@ -28,25 +32,25 @@
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.OptionalLong;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.concurrent.MoreFutures.getFutureValue;
import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR;
import static java.lang.System.nanoTime;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.supplyAsync;

public class JdbcRecordCursor
implements RecordCursor
public class JdbcPageSource
implements ConnectorPageSource
{
private static final Logger log = Logger.get(JdbcRecordCursor.class);
private static final Logger log = Logger.get(JdbcPageSource.class);

private final ExecutorService executor;

private final JdbcColumnHandle[] columnHandles;
private final List<JdbcColumnHandle> columnHandles;
private final ReadFunction[] readFunctions;
private final BooleanReadFunction[] booleanReadFunctions;
private final DoubleReadFunction[] doubleReadFunctions;
@@ -58,16 +62,18 @@ public class JdbcRecordCursor
private final Connection connection;
private final PreparedStatement statement;
private final AtomicLong readTimeNanos = new AtomicLong(0);
private final PageBuilder pageBuilder;
private final CompletableFuture<ResultSet> resultSetFuture;
@Nullable
private ResultSet resultSet;
private boolean finished;
private boolean closed;
private long completedPositions;

public JdbcRecordCursor(JdbcClient jdbcClient, ExecutorService executor, ConnectorSession session, JdbcSplit split, BaseJdbcConnectorTableHandle table, List<JdbcColumnHandle> columnHandles)
public JdbcPageSource(JdbcClient jdbcClient, ExecutorService executor, ConnectorSession session, JdbcSplit split, BaseJdbcConnectorTableHandle table, List<JdbcColumnHandle> columnHandles)
{
this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null");
this.executor = requireNonNull(executor, "executor is null");

this.columnHandles = columnHandles.toArray(new JdbcColumnHandle[0]);
this.columnHandles = ImmutableList.copyOf(columnHandles);

readFunctions = new ReadFunction[columnHandles.size()];
booleanReadFunctions = new BooleanReadFunction[columnHandles.size()];
@@ -84,7 +90,7 @@ public JdbcRecordCursor(JdbcClient jdbcClient, ExecutorService executor, Connect
connection = jdbcClient.getConnection(session, split, (JdbcTableHandle) table);
}

for (int i = 0; i < this.columnHandles.length; i++) {
for (int i = 0; i < this.columnHandles.size(); i++) {
JdbcColumnHandle columnHandle = columnHandles.get(i);
ColumnMapping columnMapping = jdbcClient.toColumnMapping(session, connection, columnHandle.getJdbcTypeHandle())
.orElseThrow(() -> new VerifyException("Column %s has unsupported type %s".formatted(columnHandle.getColumnName(), columnHandle.getJdbcTypeHandle())));
@@ -119,6 +125,22 @@ else if (javaType == Slice.class) {
else {
statement = jdbcClient.buildSql(session, connection, split, (JdbcTableHandle) table, columnHandles);
}
pageBuilder = new PageBuilder(columnHandles.stream()
.map(JdbcColumnHandle::getColumnType)
.collect(toImmutableList()));
resultSetFuture = supplyAsync(() -> {
long start = nanoTime();
try {
log.debug("Executing: %s", statement);
return statement.executeQuery();
}
catch (SQLException e) {
throw handleSqlException(e);
}
finally {
readTimeNanos.addAndGet(nanoTime() - start);
}
}, executor);
}
catch (SQLException | RuntimeException e) {
throw handleSqlException(e);
@@ -132,140 +154,82 @@ public long getReadTimeNanos()
}

@Override
public long getCompletedBytes()
public boolean isFinished()
{
return 0;
return finished;
}

@Override
public Type getType(int field)
public Page getNextPage()
{
return columnHandles[field].getColumnType();
}

@Override
public boolean advanceNextPosition()
{
if (closed) {
return false;
}

verify(pageBuilder.isEmpty(), "Expected pageBuilder to be empty");
try {
if (resultSet == null) {
long start = System.nanoTime();
Future<ResultSet> resultSetFuture = executor.submit(() -> {
log.debug("Executing: %s", statement);
return statement.executeQuery();
});
try {
// statement.executeQuery() may block uninterruptedly, using async way so we are able to cancel remote query
// See javadoc of java.sql.Connection.setNetworkTimeout
resultSet = resultSetFuture.get();
}
catch (ExecutionException e) {
if (e.getCause() instanceof SQLException cause) {
SQLException sqlException = new SQLException(cause.getMessage(), cause.getSQLState(), cause.getErrorCode(), e);
if (cause.getNextException() != null) {
sqlException.setNextException(cause.getNextException());
}
throw sqlException;
resultSet = requireNonNull(getFutureValue(resultSetFuture), "resultSet is null");
}

while (!pageBuilder.isFull() && resultSet.next()) {
pageBuilder.declarePosition();
completedPositions++;
for (int i = 0; i < columnHandles.size(); i++) {
BlockBuilder output = pageBuilder.getBlockBuilder(i);
Type type = columnHandles.get(i).getColumnType();
if (readFunctions[i].isNull(resultSet, i + 1)) {
output.appendNull();
}
else if (booleanReadFunctions[i] != null) {
type.writeBoolean(output, booleanReadFunctions[i].readBoolean(resultSet, i + 1));
}
else if (doubleReadFunctions[i] != null) {
type.writeDouble(output, doubleReadFunctions[i].readDouble(resultSet, i + 1));
}
else if (longReadFunctions[i] != null) {
type.writeLong(output, longReadFunctions[i].readLong(resultSet, i + 1));
}
else if (sliceReadFunctions[i] != null) {
type.writeSlice(output, sliceReadFunctions[i].readSlice(resultSet, i + 1));
}
else {
type.writeObject(output, objectReadFunctions[i].readObject(resultSet, i + 1));
}
throw new RuntimeException(e);
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
resultSetFuture.cancel(true);
throw new RuntimeException(e);
}
finally {
readTimeNanos.addAndGet(System.nanoTime() - start);
}
}
return resultSet.next();
}
catch (SQLException | RuntimeException e) {
throw handleSqlException(e);
}
}

@Override
public boolean getBoolean(int field)
{
checkState(!closed, "cursor is closed");
requireNonNull(resultSet, "resultSet is null");
try {
return booleanReadFunctions[field].readBoolean(resultSet, field + 1);
if (!pageBuilder.isFull()) {
finished = true;
}
}
catch (SQLException | RuntimeException e) {
throw handleSqlException(e);
}
}

@Override
public long getLong(int field)
{
checkState(!closed, "cursor is closed");
requireNonNull(resultSet, "resultSet is null");
try {
return longReadFunctions[field].readLong(resultSet, field + 1);
}
catch (SQLException | RuntimeException e) {
throw handleSqlException(e);
}
Page page = pageBuilder.build();
pageBuilder.reset();
return page;
}

@Override
public double getDouble(int field)
public long getMemoryUsage()
{
checkState(!closed, "cursor is closed");
requireNonNull(resultSet, "resultSet is null");
try {
return doubleReadFunctions[field].readDouble(resultSet, field + 1);
}
catch (SQLException | RuntimeException e) {
throw handleSqlException(e);
}
return pageBuilder.getRetainedSizeInBytes();
}

@Override
public Slice getSlice(int field)
public long getCompletedBytes()
{
checkState(!closed, "cursor is closed");
requireNonNull(resultSet, "resultSet is null");
try {
return sliceReadFunctions[field].readSlice(resultSet, field + 1);
}
catch (SQLException | RuntimeException e) {
throw handleSqlException(e);
}
return 0;
}

@Override
public Object getObject(int field)
public OptionalLong getCompletedPositions()
{
checkState(!closed, "cursor is closed");
requireNonNull(resultSet, "resultSet is null");
try {
return objectReadFunctions[field].readObject(resultSet, field + 1);
}
catch (SQLException | RuntimeException e) {
throw handleSqlException(e);
}
return OptionalLong.of(completedPositions);
}

@Override
public boolean isNull(int field)
public CompletableFuture<?> isBlocked()
{
checkState(!closed, "cursor is closed");
checkArgument(field < columnHandles.length, "Invalid field index");
requireNonNull(resultSet, "resultSet is null");

try {
return readFunctions[field].isNull(resultSet, field + 1);
}
catch (SQLException | RuntimeException e) {
throw handleSqlException(e);
}
return resultSetFuture;
}

@Override
@@ -275,6 +239,7 @@ public void close()
return;
}
closed = true;
finished = true;

// use try with resources to close everything properly
try (Connection connection = this.connection;
@@ -291,11 +256,13 @@ public void close()
}
if (connection != null && resultSet != null) {
jdbcClient.abortReadConnection(connection, resultSet);
resultSetFuture.cancel(true);
}
}
catch (SQLException | RuntimeException e) {
// ignore exception from close
}
resultSet = null;
}

private RuntimeException handleSqlException(Exception e)
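In the removed JdbcRecordCursor, advanceNextPosition() itself waited on the future returned by the asynchronously executed statement.executeQuery() (kept asynchronous so the remote query stays cancellable); in the new JdbcPageSource that future is exposed through isBlocked(), and the caller decides when to wait. The following is an illustrative sketch of how a ConnectorPageSource is typically driven, not engine code and not part of this commit; PageSourceDriver is a hypothetical helper.

// Illustrative only: the isBlocked()/getNextPage()/isFinished() protocol that
// JdbcPageSource now participates in. Hypothetical helper, not part of this commit.
import io.trino.spi.Page;
import io.trino.spi.connector.ConnectorPageSource;

import static io.airlift.concurrent.MoreFutures.getFutureValue;

final class PageSourceDriver
{
    private PageSourceDriver() {}

    static long drain(ConnectorPageSource pageSource)
    {
        long positions = 0;
        while (!pageSource.isFinished()) {
            // Wait until the asynchronous executeQuery() has produced a ResultSet.
            getFutureValue(pageSource.isBlocked());
            Page page = pageSource.getNextPage();
            if (page != null) {
                positions += page.getPositionCount();
            }
        }
        return positions;
    }
}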
