diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java b/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java index 4aa2eda71495..6e5004994c95 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java @@ -15,12 +15,20 @@ import okhttp3.OkHttpClient; +import java.util.Optional; +import java.util.Set; + public final class StatementClientFactory { private StatementClientFactory() {} public static StatementClient newStatementClient(OkHttpClient httpClient, ClientSession session, String query) { - return new StatementClientV1(httpClient, session, query); + return new StatementClientV1(httpClient, session, query, Optional.empty()); + } + + public static StatementClient newStatementClient(OkHttpClient httpClient, ClientSession session, String query, Optional> clientCapabilities) + { + return new StatementClientV1(httpClient, session, query, clientCapabilities); } } diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java index c2acbaa3c7ff..c231679c32fc 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java @@ -48,6 +48,7 @@ import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.net.HttpHeaders.ACCEPT_ENCODING; import static com.google.common.net.HttpHeaders.USER_AGENT; import static io.trino.client.JsonCodec.jsonCodec; @@ -56,6 +57,7 @@ import static java.net.HttpURLConnection.HTTP_OK; import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; import static java.net.HttpURLConnection.HTTP_UNAVAILABLE; +import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -93,7 +95,7 @@ class StatementClientV1 private final AtomicReference state = new AtomicReference<>(State.RUNNING); - public StatementClientV1(OkHttpClient httpClient, ClientSession session, String query) + public StatementClientV1(OkHttpClient httpClient, ClientSession session, String query, Optional> clientCapabilities) { requireNonNull(httpClient, "httpClient is null"); requireNonNull(session, "session is null"); @@ -107,7 +109,9 @@ public StatementClientV1(OkHttpClient httpClient, ClientSession session, String .filter(Optional::isPresent) .map(Optional::get) .findFirst(); - this.clientCapabilities = Joiner.on(",").join(ClientCapabilities.values()); + this.clientCapabilities = Joiner.on(",").join(clientCapabilities.orElseGet(() -> stream(ClientCapabilities.values()) + .map(Enum::name) + .collect(toImmutableSet()))); this.compressionDisabled = session.isCompressionDisabled(); Request request = buildQueryRequest(session, query); diff --git a/core/docker/default/etc/jvm.config b/core/docker/default/etc/jvm.config index 68b4f129d830..47e9e3176ac7 100644 --- a/core/docker/default/etc/jvm.config +++ b/core/docker/default/etc/jvm.config @@ -15,3 +15,5 @@ # Improve AES performance for S3, etc. 
on ARM64 (JDK-8271567) -XX:+UnlockDiagnosticVMOptions -XX:+UseAESCTRIntrinsics +# Disable Preventive GC for performance reasons (JDK-8293861) +-XX:-G1UsePreventiveGC diff --git a/core/trino-main/src/main/java/io/trino/Session.java b/core/trino-main/src/main/java/io/trino/Session.java index 2b48d5d224ea..b30ec04bdf93 100644 --- a/core/trino-main/src/main/java/io/trino/Session.java +++ b/core/trino-main/src/main/java/io/trino/Session.java @@ -607,6 +607,7 @@ private SessionBuilder(Session session) this.remoteUserAddress = session.remoteUserAddress.orElse(null); this.userAgent = session.userAgent.orElse(null); this.clientInfo = session.clientInfo.orElse(null); + this.clientCapabilities = ImmutableSet.copyOf(session.clientCapabilities); this.clientTags = ImmutableSet.copyOf(session.clientTags); this.start = session.start; this.systemProperties.putAll(session.systemProperties); diff --git a/core/trino-main/src/main/java/io/trino/connector/system/SystemTablesMetadata.java b/core/trino-main/src/main/java/io/trino/connector/system/SystemTablesMetadata.java index 21b75218f657..98228ac46f32 100644 --- a/core/trino-main/src/main/java/io/trino/connector/system/SystemTablesMetadata.java +++ b/core/trino-main/src/main/java/io/trino/connector/system/SystemTablesMetadata.java @@ -23,7 +23,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -139,12 +138,6 @@ public Map> listTableColumns(ConnectorSess return builder.buildOrThrow(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle handle, Constraint constraint) { diff --git a/core/trino-main/src/main/java/io/trino/execution/CreateMaterializedViewTask.java b/core/trino-main/src/main/java/io/trino/execution/CreateMaterializedViewTask.java index 7514bc9be6b9..58235f1a4152 100644 --- a/core/trino-main/src/main/java/io/trino/execution/CreateMaterializedViewTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/CreateMaterializedViewTask.java @@ -115,6 +115,10 @@ public ListenableFuture execute( parameterLookup, true); + if (statement.getGracePeriod().isPresent()) { + // Should fail in analysis + throw new UnsupportedOperationException(); + } MaterializedViewDefinition definition = new MaterializedViewDefinition( sql, session.getCatalog(), diff --git a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java index f5e461cc3e5d..5eed6d8fe2d6 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java +++ b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java @@ -13,6 +13,7 @@ */ package io.trino.operator; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.http.client.HttpClient; @@ -36,9 +37,9 @@ import java.util.Deque; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import 
java.util.concurrent.ConcurrentMap; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.locks.Lock; @@ -68,7 +69,7 @@ public class DirectExchangeClient @GuardedBy("this") private boolean noMoreLocations; - private final ConcurrentMap allClients = new ConcurrentHashMap<>(); + private final Map allClients = new ConcurrentHashMap<>(); @GuardedBy("this") private final Deque queuedClients = new LinkedList<>(); @@ -260,31 +261,37 @@ public synchronized void close() } } - private synchronized void scheduleRequestIfNecessary() + @VisibleForTesting + synchronized int scheduleRequestIfNecessary() { if ((buffer.isFinished() || buffer.isFailed()) && completedClients.size() == allClients.size()) { - return; + return 0; } long neededBytes = buffer.getRemainingCapacityInBytes(); if (neededBytes <= 0) { - return; + return 0; } - int clientCount = (int) ((1.0 * neededBytes / averageBytesPerRequest) * concurrentRequestMultiplier); - clientCount = Math.max(clientCount, 1); - - int pendingClients = allClients.size() - queuedClients.size() - completedClients.size(); - clientCount -= pendingClients; + long reservedBytesForScheduledClients = allClients.values().stream() + .filter(client -> !queuedClients.contains(client) && !completedClients.contains(client)) + .mapToLong(HttpPageBufferClient::getAverageRequestSizeInBytes) + .sum(); + long projectedBytesToBeRequested = 0; + int clientCount = 0; + for (HttpPageBufferClient client : queuedClients) { + if (projectedBytesToBeRequested >= neededBytes * concurrentRequestMultiplier - reservedBytesForScheduledClients) { + break; + } + projectedBytesToBeRequested += client.getAverageRequestSizeInBytes(); + clientCount++; + } for (int i = 0; i < clientCount; i++) { HttpPageBufferClient client = queuedClients.poll(); - if (client == null) { - // no more clients available - return; - } client.scheduleRequest(); } + return clientCount; } public ListenableFuture isBlocked() @@ -292,6 +299,18 @@ public ListenableFuture isBlocked() return buffer.isBlocked(); } + @VisibleForTesting + Deque getQueuedClients() + { + return queuedClients; + } + + @VisibleForTesting + Map getAllClients() + { + return allClients; + } + private boolean addPages(HttpPageBufferClient client, List pages) { checkState(!completedClients.contains(client), "client is already marked as completed"); diff --git a/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java b/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java index 388883f38dee..3e9b0dcfd72c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java +++ b/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java @@ -13,6 +13,7 @@ */ package io.trino.operator; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Ticker; import com.google.common.collect.ImmutableList; import com.google.common.io.LittleEndianDataInputStream; @@ -145,6 +146,9 @@ public interface ClientCallback @GuardedBy("this") private String taskInstanceId; + // it is synchronized on `this` for update + private volatile long averageRequestSizeInBytes; + private final AtomicLong rowsReceived = new AtomicLong(); private final AtomicInteger pagesReceived = new AtomicInteger(); @@ -153,6 +157,7 @@ public interface ClientCallback private final AtomicInteger requestsScheduled = new AtomicInteger(); private final AtomicInteger requestsCompleted = new AtomicInteger(); + private final AtomicInteger 
requestsSucceeded = new AtomicInteger(); private final AtomicInteger requestsFailed = new AtomicInteger(); private final Executor pageBufferClientCallbackExecutor; @@ -251,6 +256,7 @@ else if (completed) { requestsScheduled.get(), requestsCompleted.get(), requestsFailed.get(), + requestsSucceeded.get(), httpRequestState); } @@ -259,6 +265,11 @@ public TaskId getRemoteTaskId() return remoteTaskId; } + public long getAverageRequestSizeInBytes() + { + return averageRequestSizeInBytes; + } + public synchronized boolean isRunning() { return future != null; @@ -434,6 +445,8 @@ public Void handle(Request request, Response response) } } requestsCompleted.incrementAndGet(); + long responseSize = pages.stream().mapToLong(Slice::length).sum(); + requestSucceeded(responseSize); synchronized (HttpPageBufferClient.this) { // client is complete, acknowledge it by sending it a delete in the next request @@ -485,6 +498,14 @@ public void onFailure(Throwable t) }, pageBufferClientCallbackExecutor); } + @VisibleForTesting + synchronized void requestSucceeded(long responseSize) + { + int successfulRequests = requestsSucceeded.incrementAndGet(); + // AVG_n = AVG_(n-1) * (n-1)/n + VALUE_n / n + averageRequestSizeInBytes = (long) ((1.0 * averageRequestSizeInBytes * (successfulRequests - 1)) + responseSize) / successfulRequests; + } + private synchronized void destroyTaskResults() { HttpResponseFuture resultFuture = httpClient.executeAsync(prepareDelete().setUri(location).build(), createStatusResponseHandler()); diff --git a/core/trino-main/src/main/java/io/trino/operator/PageBufferClientStatus.java b/core/trino-main/src/main/java/io/trino/operator/PageBufferClientStatus.java index a1a2dac3b707..f8584b97f210 100644 --- a/core/trino-main/src/main/java/io/trino/operator/PageBufferClientStatus.java +++ b/core/trino-main/src/main/java/io/trino/operator/PageBufferClientStatus.java @@ -37,6 +37,7 @@ public class PageBufferClientStatus private final int requestsScheduled; private final int requestsCompleted; private final int requestsFailed; + private final int requestsSucceeded; private final String httpRequestState; @JsonCreator @@ -50,6 +51,7 @@ public PageBufferClientStatus(@JsonProperty("uri") URI uri, @JsonProperty("requestsScheduled") int requestsScheduled, @JsonProperty("requestsCompleted") int requestsCompleted, @JsonProperty("requestsFailed") int requestsFailed, + @JsonProperty("requestsSucceeded") int requestsSucceeded, @JsonProperty("httpRequestState") String httpRequestState) { this.uri = uri; @@ -62,6 +64,7 @@ public PageBufferClientStatus(@JsonProperty("uri") URI uri, this.requestsScheduled = requestsScheduled; this.requestsCompleted = requestsCompleted; this.requestsFailed = requestsFailed; + this.requestsSucceeded = requestsSucceeded; this.httpRequestState = httpRequestState; } @@ -125,6 +128,12 @@ public int getRequestsFailed() return requestsFailed; } + @JsonProperty + public int getRequestsSucceeded() + { + return requestsSucceeded; + } + @JsonProperty public String getHttpRequestState() { diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java b/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java index 62cee860defe..808d875ea431 100644 --- a/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java +++ b/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java @@ -19,7 +19,6 @@ import io.trino.spi.Page; import it.unimi.dsi.fastutil.ints.IntArrayList; -import 
javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; import java.util.List; @@ -58,18 +57,7 @@ public PartitioningExchanger( @Override public void accept(Page page) { - Consumer wholePagePartition = partitionPageOrFindWholePagePartition(page, partitionedPagePreparer.apply(page)); - if (wholePagePartition != null) { - // whole input page will go to this partition, compact the input page avoid over-retaining memory and to - // match the behavior of sub-partitioned pages that copy positions out - page.compact(); - sendPageToPartition(wholePagePartition, page); - } - } - - @Nullable - private Consumer partitionPageOrFindWholePagePartition(Page page, Page partitionPage) - { + Page partitionPage = partitionedPagePreparer.apply(page); // assign each row to a partition. The assignments lists are all expected to cleared by the previous iterations for (int position = 0; position < partitionPage.getPositionCount(); position++) { int partition = partitionFunction.getPartition(partitionPage, position); @@ -89,22 +77,19 @@ private Consumer partitionPageOrFindWholePagePartition(Page page, Page par int[] positions = positionsList.elements(); positionsList.clear(); + Page pageSplit; if (partitionSize == page.getPositionCount()) { - // entire page will be sent to this partition, compact and send the page after releasing the lock - return buffers.get(partition); + // whole input page will go to this partition, compact the input page avoid over-retaining memory and to + // match the behavior of sub-partitioned pages that copy positions out + page.compact(); + pageSplit = page; } - Page pageSplit = page.copyPositions(positions, 0, partitionSize); - sendPageToPartition(buffers.get(partition), pageSplit); + else { + pageSplit = page.copyPositions(positions, 0, partitionSize); + } + memoryManager.updateMemoryUsage(pageSplit.getRetainedSizeInBytes()); + buffers.get(partition).accept(pageSplit); } - // No single partition receives the entire input page - return null; - } - - // This is safe to call without synchronizing because the partition buffers are themselves threadsafe - private void sendPageToPartition(Consumer buffer, Page pageSplit) - { - memoryManager.updateMemoryUsage(pageSplit.getRetainedSizeInBytes()); - buffer.accept(pageSplit); } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java b/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java index a08faec243ae..d21267eec5cc 100644 --- a/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java +++ b/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java @@ -342,13 +342,15 @@ private ProcessBatchResult processBatch(int batchSize) } else { if (pageProjectWork == null) { - Page inputPage = projection.getInputChannels().getInputChannels(page); - expressionProfiler.start(); - pageProjectWork = projection.project(session, yieldSignal, inputPage, positionsBatch); - long projectionTimeNanos = expressionProfiler.stop(positionsBatch.size()); - metrics.recordProjectionTime(projectionTimeNanos); + pageProjectWork = projection.project(session, yieldSignal, projection.getInputChannels().getInputChannels(page), positionsBatch); } - if (!pageProjectWork.process()) { + + expressionProfiler.start(); + boolean finished = pageProjectWork.process(); + long projectionTimeNanos = expressionProfiler.stop(positionsBatch.size()); + metrics.recordProjectionTime(projectionTimeNanos); + + if (!finished) { return ProcessBatchResult.processBatchYield(); } 
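A short aside on the PageProcessor hunk above: the profiler now wraps every call to pageProjectWork.process(), so projection work that resumes after a yield is still counted, whereas previously only the initial projection.project(...) call was timed. Below is a minimal sketch of that timing pattern under assumed names; Work, TimedWorkRunner and the LongConsumer callback are hypothetical stand-ins for Trino's projection work object and operator metrics, not actual Trino APIs.

import java.util.function.LongConsumer;

// Hypothetical stand-in for the projection work object; only the timing pattern matters here.
interface Work<T>
{
    boolean process(); // returns false when the work yields and must be resumed later

    T getResult();
}

final class TimedWorkRunner
{
    // Every process() call is timed, so time spent after a yield is still attributed to projection time.
    static <T> T run(Work<T> work, LongConsumer recordProjectionNanos)
    {
        boolean finished;
        do {
            long start = System.nanoTime();
            finished = work.process();
            recordProjectionNanos.accept(System.nanoTime() - start);
        }
        while (!finished); // the real operator returns a yield result to its caller instead of looping
        return work.getResult();
    }
}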
previouslyComputedResults[i] = pageProjectWork.getResult(); diff --git a/core/trino-main/src/main/java/io/trino/server/security/oauth2/NimbusOAuth2Client.java b/core/trino-main/src/main/java/io/trino/server/security/oauth2/NimbusOAuth2Client.java index 019ceb7affbd..9860703bdf59 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/oauth2/NimbusOAuth2Client.java +++ b/core/trino-main/src/main/java/io/trino/server/security/oauth2/NimbusOAuth2Client.java @@ -30,7 +30,6 @@ import com.nimbusds.oauth2.sdk.AccessTokenResponse; import com.nimbusds.oauth2.sdk.AuthorizationCode; import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant; -import com.nimbusds.oauth2.sdk.AuthorizationGrant; import com.nimbusds.oauth2.sdk.AuthorizationRequest; import com.nimbusds.oauth2.sdk.ParseException; import com.nimbusds.oauth2.sdk.RefreshTokenGrant; @@ -347,19 +346,19 @@ private String hashNonce(String nonce) private T getTokenResponse(String code, URI callbackUri, NimbusAirliftHttpClient.Parser parser) throws ChallengeFailedException { - return getTokenResponse(new AuthorizationCodeGrant(new AuthorizationCode(code), callbackUri), parser); + return getTokenResponse(new TokenRequest(tokenUrl, clientAuth, new AuthorizationCodeGrant(new AuthorizationCode(code), callbackUri)), parser); } private T getTokenResponse(String refreshToken, NimbusAirliftHttpClient.Parser parser) throws ChallengeFailedException { - return getTokenResponse(new RefreshTokenGrant(new RefreshToken(refreshToken)), parser); + return getTokenResponse(new TokenRequest(tokenUrl, clientAuth, new RefreshTokenGrant(new RefreshToken(refreshToken)), scope), parser); } - private T getTokenResponse(AuthorizationGrant authorizationGrant, NimbusAirliftHttpClient.Parser parser) + private T getTokenResponse(TokenRequest tokenRequest, NimbusAirliftHttpClient.Parser parser) throws ChallengeFailedException { - T tokenResponse = httpClient.execute(new TokenRequest(tokenUrl, clientAuth, authorizationGrant, scope), parser); + T tokenResponse = httpClient.execute(tokenRequest, parser); if (!tokenResponse.indicatesSuccess()) { throw new ChallengeFailedException("Error while fetching access token: " + tokenResponse.toErrorResponse().toJSONObject()); } diff --git a/core/trino-main/src/main/java/io/trino/server/ui/OAuth2WebUiAuthenticationFilter.java b/core/trino-main/src/main/java/io/trino/server/ui/OAuth2WebUiAuthenticationFilter.java index 253b456be457..24b66c563f24 100644 --- a/core/trino-main/src/main/java/io/trino/server/ui/OAuth2WebUiAuthenticationFilter.java +++ b/core/trino-main/src/main/java/io/trino/server/ui/OAuth2WebUiAuthenticationFilter.java @@ -164,7 +164,7 @@ private void redirectForNewToken(ContainerRequestContext request, String refresh { OAuth2Client.Response response = client.refreshTokens(refreshToken); String serializedToken = tokenPairSerializer.serialize(TokenPair.fromOAuth2Response(response)); - request.abortWith(Response.seeOther(request.getUriInfo().getRequestUri()) + request.abortWith(Response.temporaryRedirect(request.getUriInfo().getRequestUri()) .cookie(OAuthWebUiCookie.create(serializedToken, tokenExpiration.map(expiration -> Instant.now().plus(expiration)).orElse(response.getExpiration()))) .build()); } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index d5f2b90e1a43..2b83247f7ad8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java 
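A note on the OAuth2WebUiAuthenticationFilter hunk above: Response.seeOther produces a 303, which makes the client repeat the request as a GET, while Response.temporaryRedirect produces a 307, which replays the original method and body against the same URI, now carrying the refreshed-token cookie. A minimal sketch of that pattern with plain JAX-RS types; the cookie name and helper class below are illustrative, not the actual OAuthWebUiCookie helper.

import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.core.NewCookie;
import javax.ws.rs.core.Response;

final class RefreshRedirectSketch
{
    // Abort the in-flight request with a 307 so the client retries it with the same method,
    // this time presenting the freshly serialized token pair in the UI cookie.
    static void redirectWithRefreshedToken(ContainerRequestContext request, String serializedTokenPair)
    {
        NewCookie cookie = new NewCookie("Trino-OAuth2-Token", serializedTokenPair, "/ui/", null, null, NewCookie.DEFAULT_MAX_AGE, true);
        request.abortWith(Response.temporaryRedirect(request.getUriInfo().getRequestUri())
                .cookie(cookie)
                .build());
    }
}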
@@ -2222,6 +2222,7 @@ public static class TableFunctionInvocationAnalysis private final String functionName; private final Map arguments; private final List tableArgumentAnalyses; + private final Map> requiredColumns; private final List> copartitioningLists; private final int properColumnsCount; private final ConnectorTableFunctionHandle connectorTableFunctionHandle; @@ -2232,6 +2233,7 @@ public TableFunctionInvocationAnalysis( String functionName, Map arguments, List tableArgumentAnalyses, + Map> requiredColumns, List> copartitioningLists, int properColumnsCount, ConnectorTableFunctionHandle connectorTableFunctionHandle, @@ -2241,6 +2243,8 @@ public TableFunctionInvocationAnalysis( this.functionName = requireNonNull(functionName, "functionName is null"); this.arguments = ImmutableMap.copyOf(arguments); this.tableArgumentAnalyses = ImmutableList.copyOf(tableArgumentAnalyses); + this.requiredColumns = requiredColumns.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> ImmutableList.copyOf(entry.getValue()))); this.copartitioningLists = ImmutableList.copyOf(copartitioningLists); this.properColumnsCount = properColumnsCount; this.connectorTableFunctionHandle = requireNonNull(connectorTableFunctionHandle, "connectorTableFunctionHandle is null"); @@ -2267,6 +2271,11 @@ public List getTableArgumentAnalyses() return tableArgumentAnalyses; } + public Map> getRequiredColumns() + { + return requiredColumns; + } + public List> getCopartitioningLists() { return copartitioningLists; diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index d76da4c290bb..481fb58f064c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -292,6 +292,7 @@ import static io.trino.spi.StandardErrorCode.DUPLICATE_WINDOW_NAME; import static io.trino.spi.StandardErrorCode.EXPRESSION_NOT_CONSTANT; import static io.trino.spi.StandardErrorCode.EXPRESSION_NOT_IN_DISTINCT; +import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_WINDOW; import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; @@ -1311,6 +1312,9 @@ protected Scope visitCreateMaterializedView(CreateMaterializedView node, Optiona if (node.isReplace() && node.isNotExists()) { throw semanticException(NOT_SUPPORTED, node, "'CREATE OR REPLACE' and 'IF NOT EXISTS' clauses can not be used together"); } + if (node.getGracePeriod().isPresent()) { + throw new TrinoException(NOT_SUPPORTED, "GRACE PERIOD is not supported yet"); + } // analyze the query that creates the view StatementAnalyzer analyzer = statementAnalyzerFactory.createStatementAnalyzer(analysis, session, warningCollector, CorrelationSupport.ALLOWED); @@ -1570,6 +1574,35 @@ else if (returnTypeSpecification == GENERIC_TABLE) { properColumnsDescriptor = ((DescribedTable) returnTypeSpecification).getDescriptor(); } + // validate the required input columns + Map> requiredColumns = functionAnalysis.getRequiredColumns(); + Map tableArgumentsByName = argumentsAnalysis.getTableArgumentAnalyses().stream() + .collect(toImmutableMap(TableArgumentAnalysis::getArgumentName, Function.identity())); + Set allInputs = ImmutableSet.copyOf(tableArgumentsByName.keySet()); + requiredColumns.forEach((name, columns) -> { + if 
(!allInputs.contains(name)) { + throw new TrinoException(FUNCTION_IMPLEMENTATION_ERROR, format("Table function %s specifies required columns from table argument %s which cannot be found", node.getName(), name)); + } + if (columns.isEmpty()) { + throw new TrinoException(FUNCTION_IMPLEMENTATION_ERROR, format("Table function %s specifies empty list of required columns from table argument %s", node.getName(), name)); + } + // the scope is recorded, because table arguments are already analyzed + Scope inputScope = analysis.getScope(tableArgumentsByName.get(name).getRelation()); + columns.stream() + .filter(column -> column < 0 || column >= inputScope.getRelationType().getAllFieldCount()) // hidden columns can be required as well as visible columns + .findFirst() + .ifPresent(column -> { + throw new TrinoException(FUNCTION_IMPLEMENTATION_ERROR, format("Invalid index: %s of required column from table argument %s", column, name)); + }); + }); + Set requiredInputs = ImmutableSet.copyOf(requiredColumns.keySet()); + allInputs.stream() + .filter(input -> !requiredInputs.contains(input)) + .findFirst() + .ifPresent(input -> { + throw new TrinoException(FUNCTION_IMPLEMENTATION_ERROR, format("Table function %s does not specify required input columns from table argument %s", node.getName(), input)); + }); + // The result relation type of a table function consists of: // 1. columns created by the table function, called the proper columns. // 2. passed columns from input tables: @@ -1590,8 +1623,6 @@ else if (returnTypeSpecification == GENERIC_TABLE) { .filter(argumentSpecification -> argumentSpecification instanceof TableArgumentSpecification) .map(ArgumentSpecification::getName) .collect(toImmutableList()); - Map tableArgumentsByName = argumentsAnalysis.getTableArgumentAnalyses().stream() - .collect(toImmutableMap(TableArgumentAnalysis::getArgumentName, Function.identity())); // table arguments in order of argument declarations ImmutableList.Builder orderedTableArguments = ImmutableList.builder(); @@ -1616,6 +1647,7 @@ else if (argument.getPartitionBy().isPresent()) { function.getName(), argumentsAnalysis.getPassedArguments(), orderedTableArguments.build(), + functionAnalysis.getRequiredColumns(), copartitioningLists, properColumnsDescriptor == null ? 
0 : properColumnsDescriptor.getFields().size(), functionAnalysis.getHandle(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index f66a8b3f96cf..2b0e420e27d8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import io.trino.Session; @@ -347,23 +346,14 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node outputSymbols.addAll(properOutputs); - // process sources in order of argument declarations + // process sources in order of argument declarations for (TableArgumentAnalysis tableArgument : functionAnalysis.getTableArgumentAnalyses()) { RelationPlan sourcePlan = process(tableArgument.getRelation(), context); PlanBuilder sourcePlanBuilder = newPlanBuilder(sourcePlan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext); - // map column names to symbols - // note: hidden columns are included in the mapping. They are present both in sourceDescriptor.allFields, and in sourcePlan.fieldMappings - // note: for an aliased relation or a CTE, the field names in the relation type are in the same case as specified in the alias. - // quotes and canonicalization rules are not applied. - ImmutableMultimap.Builder columnMapping = ImmutableMultimap.builder(); - RelationType sourceDescriptor = sourcePlan.getDescriptor(); - for (int i = 0; i < sourceDescriptor.getAllFieldCount(); i++) { - Optional name = sourceDescriptor.getFieldByIndex(i).getName(); - if (name.isPresent()) { - columnMapping.put(name.get(), sourcePlan.getSymbol(i)); - } - } + List requiredColumns = functionAnalysis.getRequiredColumns().get(tableArgument.getArgumentName()).stream() + .map(sourcePlan::getSymbol) + .collect(toImmutableList()); Optional specification = Optional.empty(); @@ -394,10 +384,10 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node sources.add(sourcePlanBuilder.getRoot()); sourceProperties.add(new TableArgumentProperties( tableArgument.getArgumentName(), - columnMapping.build(), tableArgument.isRowSemantics(), tableArgument.isPruneWhenEmpty(), tableArgument.isPassThroughColumns(), + requiredColumns, specification)); // add output symbols passed from the table argument diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index ac302ca75719..d08a671b2f46 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import com.google.common.collect.Sets; @@ -338,16 +337,13 @@ public PlanAndMappings 
visitTableFunction(TableFunctionNode node, UnaliasContext SymbolMapper inputMapper = symbolMapper(new HashMap<>(newSource.getMappings())); TableArgumentProperties properties = node.getTableArgumentProperties().get(i); - ImmutableMultimap.Builder newColumnMapping = ImmutableMultimap.builder(); - properties.getColumnMapping().entries().stream() - .forEach(entry -> newColumnMapping.put(entry.getKey(), inputMapper.map(entry.getValue()))); Optional newSpecification = properties.getSpecification().map(inputMapper::mapAndDistinct); newTableArgumentProperties.add(new TableArgumentProperties( properties.getArgumentName(), - newColumnMapping.build(), properties.isRowSemantics(), properties.isPruneWhenEmpty(), properties.isPassThroughColumns(), + inputMapper.map(properties.getRequiredColumns()), newSpecification)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java index 924d88960693..3386a2280334 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java @@ -17,8 +17,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; -import com.google.common.collect.Multimap; import io.trino.metadata.TableFunctionHandle; import io.trino.spi.ptf.Argument; import io.trino.sql.planner.Symbol; @@ -149,26 +147,26 @@ public PlanNode replaceChildren(List newSources) public static class TableArgumentProperties { private final String argumentName; - private final Multimap columnMapping; private final boolean rowSemantics; private final boolean pruneWhenEmpty; private final boolean passThroughColumns; + private final List requiredColumns; private final Optional specification; @JsonCreator public TableArgumentProperties( @JsonProperty("argumentName") String argumentName, - @JsonProperty("columnMapping") Multimap columnMapping, @JsonProperty("rowSemantics") boolean rowSemantics, @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, @JsonProperty("passThroughColumns") boolean passThroughColumns, + @JsonProperty("requiredColumns") List requiredColumns, @JsonProperty("specification") Optional specification) { this.argumentName = requireNonNull(argumentName, "argumentName is null"); - this.columnMapping = ImmutableMultimap.copyOf(columnMapping); this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; this.passThroughColumns = passThroughColumns; + this.requiredColumns = ImmutableList.copyOf(requiredColumns); this.specification = requireNonNull(specification, "specification is null"); } @@ -178,12 +176,6 @@ public String getArgumentName() return argumentName; } - @JsonProperty - public Multimap getColumnMapping() - { - return columnMapping; - } - @JsonProperty public boolean isRowSemantics() { @@ -202,6 +194,12 @@ public boolean isPassThroughColumns() return passThroughColumns; } + @JsonProperty + public List getRequiredColumns() + { + return requiredColumns; + } + @JsonProperty public Optional getSpecification() { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index e03e392f882a..5623de419042 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ 
b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -21,6 +21,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; import io.airlift.json.JsonCodec; +import io.airlift.stats.TDigest; import io.airlift.units.Duration; import io.trino.Session; import io.trino.client.NodeVersion; @@ -153,6 +154,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.airlift.json.JsonCodec.mapJsonCodec; +import static io.airlift.units.DataSize.succinctBytes; import static io.airlift.units.Duration.succinctNanos; import static io.trino.execution.StageInfo.getAllStages; import static io.trino.metadata.ResolvedFunction.extractFunctionName; @@ -530,6 +532,23 @@ private static String formatFragment( formatDouble(outputBufferUtilization.get().getP99() * 100), formatDouble(outputBufferUtilization.get().getMax() * 100))); } + + TDigest taskOutputDistribution = new TDigest(); + stageInfo.get().getTasks().forEach(task -> taskOutputDistribution.add(task.getStats().getOutputDataSize().toBytes())); + TDigest taskInputDistribution = new TDigest(); + stageInfo.get().getTasks().forEach(task -> taskInputDistribution.add(task.getStats().getProcessedInputDataSize().toBytes())); + + if (verbose) { + builder.append(indentString(1)) + .append(format("Task output distribution: %s\n", formatSizeDistribution(taskOutputDistribution))); + builder.append(indentString(1)) + .append(format("Task input distribution: %s\n", formatSizeDistribution(taskInputDistribution))); + } + + if (taskInputDistribution.valueAt(0.99) > taskInputDistribution.valueAt(0.49) * 2) { + builder.append(indentString(1)) + .append("Amount of input data processed by the workers for this stage might be skewed\n"); + } } PartitioningScheme partitioningScheme = fragment.getPartitioningScheme(); @@ -581,6 +600,22 @@ private static String formatFragment( return builder.toString(); } + private static String formatSizeDistribution(TDigest digest) + { + return format("{count=%s, p01=%s, p05=%s, p10=%s, p25=%s, p50=%s, p75=%s, p90=%s, p95=%s, p99=%s, max=%s}", + formatDouble(digest.getCount()), + succinctBytes((long) digest.valueAt(0.01)), + succinctBytes((long) digest.valueAt(0.05)), + succinctBytes((long) digest.valueAt(0.10)), + succinctBytes((long) digest.valueAt(0.25)), + succinctBytes((long) digest.valueAt(0.50)), + succinctBytes((long) digest.valueAt(0.75)), + succinctBytes((long) digest.valueAt(0.90)), + succinctBytes((long) digest.valueAt(0.95)), + succinctBytes((long) digest.valueAt(0.99)), + succinctBytes((long) digest.getMax())); + } + private static TypeProvider getTypeProvider(List fragments) { return TypeProvider.copyOf(fragments.stream() @@ -1821,6 +1856,9 @@ private String formatArgument(String argumentName, Argument argument, Map boundSymbols) checkDependencies( inputs, - argumentProperties.getColumnMapping().values(), - "Invalid node. Input symbols from source %s (%s) not in source plan output (%s)", + argumentProperties.getRequiredColumns(), + "Invalid node. 
Required input symbols from source %s (%s) not in source plan output (%s)", argumentProperties.getArgumentName(), - argumentProperties.getColumnMapping().values(), + argumentProperties.getRequiredColumns(), source.getOutputSymbols()); argumentProperties.getSpecification().ifPresent(specification -> { checkDependencies( diff --git a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java index 4799a848fd74..f999b8fddde0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowQueriesRewrite.java @@ -599,8 +599,15 @@ protected Node visitShowCreate(ShowCreate node, Void context) Collection> allMaterializedViewProperties = materializedViewPropertyManager.getAllProperties(catalogHandle); List propertyNodes = buildProperties(objectName, Optional.empty(), INVALID_MATERIALIZED_VIEW_PROPERTY, properties, allMaterializedViewProperties); - String sql = formatSql(new CreateMaterializedView(Optional.empty(), QualifiedName.of(ImmutableList.of(catalogName, schemaName, tableName)), - query, false, false, propertyNodes, viewDefinition.get().getComment())).trim(); + String sql = formatSql(new CreateMaterializedView( + Optional.empty(), + QualifiedName.of(ImmutableList.of(catalogName, schemaName, tableName)), + query, + false, + false, + Optional.empty(), // TODO support GRACE PERIOD + propertyNodes, + viewDefinition.get().getComment())).trim(); return singleValueQuery("Create Materialized View", sql); } diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java b/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java index 4a1d42115cec..587c6203bf2b 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java @@ -32,7 +32,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.MaterializedViewFreshness; import io.trino.spi.connector.MaterializedViewNotFoundException; @@ -331,12 +330,6 @@ public void grantTablePrivileges(ConnectorSession session, SchemaTableName table @Override public void revokeTablePrivileges(ConnectorSession session, SchemaTableName tableName, Set privileges, TrinoPrincipal grantee, boolean grantOption) {} - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - public void clear() { views.clear(); diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingSession.java b/core/trino-main/src/main/java/io/trino/testing/TestingSession.java index b02d27bd5907..9d4f2637a140 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingSession.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingSession.java @@ -15,11 +15,15 @@ import io.trino.Session; import io.trino.Session.SessionBuilder; +import io.trino.client.ClientCapabilities; import io.trino.execution.QueryIdGenerator; import io.trino.metadata.SessionPropertyManager; import io.trino.spi.security.Identity; import io.trino.spi.type.TimeZoneKey; +import java.util.Arrays; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; import 
static java.util.Locale.ENGLISH; public final class TestingSession @@ -54,6 +58,8 @@ public static SessionBuilder testSessionBuilder(SessionPropertyManager sessionPr .setSchema("schema") .setTimeZoneKey(DEFAULT_TIME_ZONE_KEY) .setLocale(ENGLISH) + .setClientCapabilities(Arrays.stream(ClientCapabilities.values()).map(Enum::name) + .collect(toImmutableSet())) .setRemoteUserAddress("address") .setUserAgent("agent"); } diff --git a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java index b5ba197c029c..4febbe36409a 100644 --- a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java +++ b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java @@ -54,6 +54,7 @@ public class TestingTableFunctions .build(); private static final TableFunctionAnalysis NO_DESCRIPTOR_ANALYSIS = TableFunctionAnalysis.builder() .handle(HANDLE) + .requiredColumns("INPUT", ImmutableList.of(0)) .build(); /** @@ -164,7 +165,11 @@ public TableArgumentFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return ANALYSIS; + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT", ImmutableList.of(0)) + .build(); } } @@ -187,7 +192,11 @@ public TableArgumentRowSemanticsFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return ANALYSIS; + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT", ImmutableList.of(0)) + .build(); } } @@ -235,7 +244,12 @@ public TwoTableArgumentsFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return ANALYSIS; + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT1", ImmutableList.of(0)) + .requiredColumns("INPUT2", ImmutableList.of(0)) + .build(); } } @@ -278,7 +292,9 @@ public MonomorphicStaticReturnTypeFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return NO_DESCRIPTOR_ANALYSIS; + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .build(); } } @@ -364,7 +380,39 @@ public DifferentArgumentTypesFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return ANALYSIS; + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT_1", ImmutableList.of(0)) + .requiredColumns("INPUT_2", ImmutableList.of(0)) + .requiredColumns("INPUT_3", ImmutableList.of(0)) + .build(); + } + } + + public static class RequiredColumnsFunction + extends AbstractConnectorTableFunction + { + public RequiredColumnsFunction() + { + super( + SCHEMA_NAME, + "required_columns_function", + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .build()), + GENERIC_TABLE); + } + + 
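    // Illustrative aside (not part of the patch): analyze() below declares columns 0 and 1 of the INPUT
    // argument as required via TableFunctionAnalysis.builder().requiredColumns(...). Per the
    // StatementAnalyzer hunk earlier in this diff, every table argument must be covered, the column list
    // must be non-empty, and each index must fall within [0, field count) of the argument's relation type
    // (hidden columns can be required as well); violations surface as FUNCTION_IMPLEMENTATION_ERROR, since
    // a bad declaration is a connector bug rather than a user error. A hypothetical sketch of those checks:
    private static void validateRequiredColumns(Map<String, List<Integer>> requiredColumns, Map<String, Integer> fieldCountByTableArgument)
    {
        fieldCountByTableArgument.forEach((argument, fieldCount) -> {
            List<Integer> columns = requiredColumns.get(argument);
            if (columns == null) {
                throw new IllegalArgumentException("no required columns specified for table argument: " + argument);
            }
            if (columns.isEmpty()) {
                throw new IllegalArgumentException("empty list of required columns for table argument: " + argument);
            }
            for (int column : columns) {
                if (column < 0 || column >= fieldCount) {
                    throw new IllegalArgumentException("invalid required column index " + column + " for table argument: " + argument);
                }
            }
        });
        requiredColumns.keySet().stream()
                .filter(argument -> !fieldCountByTableArgument.containsKey(argument))
                .findFirst()
                .ifPresent(argument -> {
                    throw new IllegalArgumentException("required columns reference unknown table argument: " + argument);
                });
    }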
@Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN))))) + .requiredColumns("INPUT", ImmutableList.of(0, 1)) + .build(); } } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestCreateMaterializedViewTask.java b/core/trino-main/src/test/java/io/trino/execution/TestCreateMaterializedViewTask.java index ac4088e9ba81..d82132a5b4e5 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestCreateMaterializedViewTask.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestCreateMaterializedViewTask.java @@ -176,6 +176,7 @@ public void testCreateMaterializedViewIfNotExists() simpleQuery(selectList(new AllColumns()), table(QualifiedName.of(TEST_CATALOG_NAME, "schema", "mock_table"))), false, true, + Optional.empty(), ImmutableList.of(), Optional.empty()); @@ -193,6 +194,7 @@ public void testCreateMaterializedViewWithExistingView() simpleQuery(selectList(new AllColumns()), table(QualifiedName.of(TEST_CATALOG_NAME, "schema", "mock_table"))), false, false, + Optional.empty(), ImmutableList.of(), Optional.empty()); @@ -213,6 +215,7 @@ public void testCreateMaterializedViewWithInvalidProperty() simpleQuery(selectList(new AllColumns()), table(QualifiedName.of(TEST_CATALOG_NAME, "schema", "mock_table"))), false, true, + Optional.empty(), ImmutableList.of(new Property(new Identifier("baz"), new StringLiteral("abc"))), Optional.empty()); @@ -234,6 +237,7 @@ public void testCreateMaterializedViewWithDefaultProperties() simpleQuery(selectList(new AllColumns()), table(QualifiedName.of(TEST_CATALOG_NAME, "schema", "mock_table"))), false, true, + Optional.empty(), ImmutableList.of( new Property(new Identifier("foo")), // set foo to DEFAULT new Property(new Identifier("bar"))), // set bar to DEFAULT @@ -258,6 +262,7 @@ public void testCreateDenyPermission() simpleQuery(selectList(new AllColumns()), table(QualifiedName.of(TEST_CATALOG_NAME, "schema", "mock_table"))), false, true, + Optional.empty(), ImmutableList.of(), Optional.empty()); TestingAccessControlManager accessControl = new TestingAccessControlManager(transactionManager, emptyEventListenerManager()); diff --git a/core/trino-main/src/test/java/io/trino/execution/TestDropTableTask.java b/core/trino-main/src/test/java/io/trino/execution/TestDropTableTask.java index 8d22ffd8d491..1a13e6add44c 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestDropTableTask.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestDropTableTask.java @@ -36,7 +36,7 @@ public class TestDropTableTask @Test public void testDropExistingTable() { - QualifiedObjectName tableName = qualifiedObjectName("not_existing_table"); + QualifiedObjectName tableName = qualifiedObjectName("existing_table"); metadata.createTable(testSession, TEST_CATALOG_NAME, someTable(tableName), false); assertThat(metadata.getTableHandle(testSession, tableName)).isPresent(); diff --git a/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java b/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java index 36cc9ae989b5..540528144f3b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestDirectExchangeClient.java @@ -70,6 +70,7 @@ import static 
io.airlift.concurrent.MoreFutures.tryGetFutureValue; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.testing.Assertions.assertLessThan; +import static io.trino.execution.TestSqlTaskExecution.TASK_ID; import static io.trino.execution.buffer.PagesSerdeUtil.getSerializedPagePositionCount; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; @@ -946,6 +947,128 @@ public void testStreamingClose() assertEquals(clientStatus.getHttpRequestState(), "not scheduled", "httpRequestState"); } + @Test + public void testScheduleWhenOneClientFilledBuffer() + { + DataSize maxResponseSize = DataSize.of(8, Unit.MEGABYTE); + + URI locationOne = URI.create("http://localhost:8080"); + URI locationTwo = URI.create("http://localhost:8081"); + MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); + + HttpPageBufferClient clientToBeUsed = createHttpPageBufferClient(processor, maxResponseSize, locationOne, new MockClientCallback()); + HttpPageBufferClient clientToBeSkipped = createHttpPageBufferClient(processor, maxResponseSize, locationTwo, new MockClientCallback()); + clientToBeUsed.requestSucceeded(DataSize.of(33, Unit.MEGABYTE).toBytes()); + clientToBeSkipped.requestSucceeded(DataSize.of(1, Unit.MEGABYTE).toBytes()); + + @SuppressWarnings("resource") + DirectExchangeClient exchangeClient = new DirectExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, + new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), + maxResponseSize, + 1, + new Duration(1, TimeUnit.MINUTES), + true, + new TestingHttpClient(processor, scheduler), + scheduler, + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); + exchangeClient.getAllClients().putAll(Map.of(locationOne, clientToBeUsed, locationTwo, clientToBeSkipped)); + exchangeClient.getQueuedClients().addAll(ImmutableList.of(clientToBeUsed, clientToBeSkipped)); + + int clientCount = exchangeClient.scheduleRequestIfNecessary(); + // The first client filled the buffer. 
There is no place for another one + assertEquals(clientCount, 1); + } + + @Test + public void testScheduleWhenAllClientsAreEmpty() + { + DataSize maxResponseSize = DataSize.of(8, Unit.MEGABYTE); + + URI locationOne = URI.create("http://localhost:8080"); + URI locationTwo = URI.create("http://localhost:8081"); + MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); + + HttpPageBufferClient firstClient = createHttpPageBufferClient(processor, maxResponseSize, locationOne, new MockClientCallback()); + HttpPageBufferClient secondClient = createHttpPageBufferClient(processor, maxResponseSize, locationTwo, new MockClientCallback()); + + @SuppressWarnings("resource") + DirectExchangeClient exchangeClient = new DirectExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, + new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), + maxResponseSize, + 1, + new Duration(1, TimeUnit.MINUTES), + true, + new TestingHttpClient(processor, scheduler), + scheduler, + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); + exchangeClient.getAllClients().putAll(Map.of(locationOne, firstClient, locationTwo, secondClient)); + exchangeClient.getQueuedClients().addAll(ImmutableList.of(firstClient, secondClient)); + + int clientCount = exchangeClient.scheduleRequestIfNecessary(); + assertEquals(clientCount, 2); + } + + @Test + public void testScheduleWhenThereIsPendingClient() + { + DataSize maxResponseSize = DataSize.of(8, Unit.MEGABYTE); + + URI locationOne = URI.create("http://localhost:8080"); + URI locationTwo = URI.create("http://localhost:8081"); + + MockExchangeRequestProcessor processor = new MockExchangeRequestProcessor(maxResponseSize); + + HttpPageBufferClient pendingClient = createHttpPageBufferClient(processor, maxResponseSize, locationOne, new MockClientCallback()); + HttpPageBufferClient clientToBeSkipped = createHttpPageBufferClient(processor, maxResponseSize, locationTwo, new MockClientCallback()); + + pendingClient.requestSucceeded(DataSize.of(33, Unit.MEGABYTE).toBytes()); + + @SuppressWarnings("resource") + DirectExchangeClient exchangeClient = new DirectExchangeClient( + "localhost", + DataIntegrityVerification.ABORT, + new StreamingDirectExchangeBuffer(scheduler, DataSize.of(32, Unit.MEGABYTE)), + maxResponseSize, + 1, + new Duration(1, TimeUnit.MINUTES), + true, + new TestingHttpClient(processor, scheduler), + scheduler, + new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), + pageBufferClientCallbackExecutor, + (taskId, failure) -> {}); + exchangeClient.getAllClients().putAll(Map.of(locationOne, pendingClient, locationTwo, clientToBeSkipped)); + exchangeClient.getQueuedClients().add(clientToBeSkipped); + + int clientCount = exchangeClient.scheduleRequestIfNecessary(); + // The first client is pending and it reserved the space in the buffer. 
There is no place for another one + assertEquals(clientCount, 0); + } + + private HttpPageBufferClient createHttpPageBufferClient(TestingHttpClient.Processor processor, DataSize expectedMaxSize, URI location, HttpPageBufferClient.ClientCallback callback) + { + return new HttpPageBufferClient( + "localhost", + new TestingHttpClient(processor, scheduler), + DataIntegrityVerification.ABORT, + expectedMaxSize, + new Duration(1, TimeUnit.MINUTES), + true, + TASK_ID, + location, + callback, + scheduler, + pageBufferClientCallbackExecutor); + } + private static Page createPage(int size) { return new Page(BlockAssertions.createLongSequenceBlock(0, size)); @@ -985,4 +1108,29 @@ private static void assertStatus( assertEquals(clientStatus.getRequestsCompleted(), requestsCompleted, "requestsCompleted"); assertEquals(clientStatus.getHttpRequestState(), httpRequestState, "httpRequestState"); } + + private static class MockClientCallback + implements HttpPageBufferClient.ClientCallback + { + @Override + public boolean addPages(HttpPageBufferClient client, List pages) + { + return false; + } + + @Override + public void requestComplete(HttpPageBufferClient client) + { + } + + @Override + public void clientFinished(HttpPageBufferClient client) + { + } + + @Override + public void clientFailed(HttpPageBufferClient client, Throwable cause) + { + } + } } diff --git a/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java b/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java index 77c6293e47c2..47ab79c4a4c0 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestHttpPageBufferClient.java @@ -438,6 +438,33 @@ public void testErrorCodes() assertEquals(new PageTransportTimeoutException(HostAddress.fromParts("127.0.0.1", 8080), "", null).getErrorCode(), PAGE_TRANSPORT_TIMEOUT.toErrorCode()); } + @Test + public void testAverageSizeOfRequest() + { + HttpPageBufferClient client = new HttpPageBufferClient( + "localhost", + new TestingHttpClient(new MockExchangeRequestProcessor(DataSize.of(10, MEGABYTE)), scheduler), + DataIntegrityVerification.ABORT, + DataSize.of(10, MEGABYTE), + new Duration(30, TimeUnit.SECONDS), + true, + TASK_ID, + URI.create("http://localhost:8080"), + new TestingClientCallback(new CyclicBarrier(1)), + scheduler, + new TestingTicker(), + pageBufferClientCallbackExecutor); + + assertEquals(client.getAverageRequestSizeInBytes(), 0); + + client.requestSucceeded(0); + assertEquals(client.getAverageRequestSizeInBytes(), 0); + + client.requestSucceeded(1000); + client.requestSucceeded(800); + assertEquals(client.getAverageRequestSizeInBytes(), 600); + } + @Test public void testMemoryExceededInAddPages() throws Exception diff --git a/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java b/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java index cf56a8fbeaa6..d164a6c64245 100644 --- a/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java +++ b/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java @@ -35,6 +35,8 @@ import io.trino.server.security.ResourceSecurity; import io.trino.server.security.oauth2.ChallengeFailedException; import io.trino.server.security.oauth2.OAuth2Client; +import io.trino.server.security.oauth2.TokenPairSerializer; +import io.trino.server.security.oauth2.TokenPairSerializer.TokenPair; import io.trino.server.testing.TestingTrinoServer; import io.trino.spi.security.AccessDeniedException; import 
io.trino.spi.security.BasicPrincipal; @@ -104,12 +106,14 @@ import static io.trino.server.ui.FormWebUiAuthenticationFilter.UI_LOGIN; import static io.trino.server.ui.FormWebUiAuthenticationFilter.UI_LOGOUT; import static io.trino.testing.assertions.Assert.assertEquals; +import static io.trino.testing.assertions.Assert.assertEventually; import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.Objects.requireNonNull; import static java.util.function.Predicate.not; import static javax.servlet.http.HttpServletResponse.SC_NOT_FOUND; import static javax.servlet.http.HttpServletResponse.SC_OK; import static javax.servlet.http.HttpServletResponse.SC_SEE_OTHER; +import static javax.servlet.http.HttpServletResponse.SC_TEMPORARY_REDIRECT; import static javax.servlet.http.HttpServletResponse.SC_UNAUTHORIZED; import static javax.ws.rs.core.Response.Status.UNAUTHORIZED; import static org.assertj.core.api.Assertions.assertThat; @@ -148,6 +152,8 @@ public class TestWebUi private static final String TEST_PASSWORD2 = "test-password2"; private static final String HMAC_KEY = Resources.getResource("hmac_key.txt").getPath(); private static final PrivateKey JWK_PRIVATE_KEY; + private static final String REFRESH_TOKEN = "REFRESH_TOKEN"; + private static final Duration REFRESH_TOKEN_TIMEOUT = Duration.ofMinutes(1); static { try { @@ -652,8 +658,7 @@ public void testOAuth2Authenticator() .setBinding() .toInstance(oauthClient)) .build()) { - HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); - assertAuth2Authentication(httpServerInfo, oauthClient.getAccessToken()); + assertAuth2Authentication(server, oauthClient.getAccessToken(), false); } finally { jwkServer.stop(); @@ -664,7 +669,7 @@ public void testOAuth2Authenticator() public void testOAuth2AuthenticatorWithoutOpenIdScope() throws Exception { - OAuth2ClientStub oauthClient = new OAuth2ClientStub(false); + OAuth2ClientStub oauthClient = new OAuth2ClientStub(false, Duration.ofSeconds(5)); TestingHttpServer jwkServer = createTestingJwkServer(); jwkServer.start(); try (TestingTrinoServer server = TestingTrinoServer.builder() @@ -677,8 +682,116 @@ public void testOAuth2AuthenticatorWithoutOpenIdScope() .setBinding() .toInstance(oauthClient)) .build()) { + assertAuth2Authentication(server, oauthClient.getAccessToken(), false); + } + finally { + jwkServer.stop(); + } + } + + @Test + public void testOAuth2AuthenticatorWithRefreshToken() + throws Exception + { + OAuth2ClientStub oauthClient = new OAuth2ClientStub(false, Duration.ofSeconds(5)); + TestingHttpServer jwkServer = createTestingJwkServer(); + jwkServer.start(); + try (TestingTrinoServer server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .putAll(OAUTH2_PROPERTIES) + .put("http-server.authentication.oauth2.jwks-url", jwkServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.refresh-tokens", "true") + .put("http-server.authentication.oauth2.refresh-tokens.issued-token.timeout", REFRESH_TOKEN_TIMEOUT.getSeconds() + "s") + .buildOrThrow()) + .setAdditionalModule(binder -> newOptionalBinder(binder, OAuth2Client.class) + .setBinding() + .toInstance(oauthClient)) + .build()) { + assertAuth2Authentication(server, oauthClient.getAccessToken(), true); + } + finally { + jwkServer.stop(); + } + } + + @Test + public void testOAuth2AuthenticatorRedirectAfterAuthTokenRefresh() + throws Exception + { + // the first issued authorization token will be expired + 
OAuth2ClientStub oauthClient = new OAuth2ClientStub(false, Duration.ZERO); + TestingHttpServer jwkServer = createTestingJwkServer(); + jwkServer.start(); + try (TestingTrinoServer server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .putAll(OAUTH2_PROPERTIES) + .put("http-server.authentication.oauth2.jwks-url", jwkServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.refresh-tokens", "true") + .put("http-server.authentication.oauth2.refresh-tokens.issued-token.timeout", REFRESH_TOKEN_TIMEOUT.getSeconds() + "s") + .buildOrThrow()) + .setAdditionalModule(binder -> newOptionalBinder(binder, OAuth2Client.class) + .setBinding() + .toInstance(oauthClient)) + .build()) { + CookieManager cookieManager = new CookieManager(); + OkHttpClient client = this.client.newBuilder() + .cookieJar(new JavaNetCookieJar(cookieManager)) + .build(); + + HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); + URI baseUri = httpServerInfo.getHttpsUri(); + + loginWithCallbackEndpoint(client, baseUri); + HttpCookie cookie = getOnlyElement(cookieManager.getCookieStore().getCookies()); + assertCookieWithRefreshToken(server, cookie, oauthClient.getAccessToken()); + + assertResponseCode(client, getValidApiLocation(baseUri), SC_TEMPORARY_REDIRECT); + assertOk(client, getValidApiLocation(baseUri)); + } + finally { + jwkServer.stop(); + } + } + + @Test + public void testOAuth2AuthenticatorRefreshTokenExpiration() + throws Exception + { + OAuth2ClientStub oauthClient = new OAuth2ClientStub(false, Duration.ofSeconds(5)); + TestingHttpServer jwkServer = createTestingJwkServer(); + jwkServer.start(); + try (TestingTrinoServer server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .putAll(OAUTH2_PROPERTIES) + .put("http-server.authentication.oauth2.jwks-url", jwkServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.refresh-tokens", "true") + .put("http-server.authentication.oauth2.refresh-tokens.issued-token.timeout", "10s") + .buildOrThrow()) + .setAdditionalModule(binder -> newOptionalBinder(binder, OAuth2Client.class) + .setBinding() + .toInstance(oauthClient)) + .build()) { + CookieManager cookieManager = new CookieManager(); + OkHttpClient client = this.client.newBuilder() + .cookieJar(new JavaNetCookieJar(cookieManager)) + .build(); + HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); - assertAuth2Authentication(httpServerInfo, oauthClient.getAccessToken()); + URI baseUri = httpServerInfo.getHttpsUri(); + + loginWithCallbackEndpoint(client, baseUri); + HttpCookie cookie = getOnlyElement(cookieManager.getCookieStore().getCookies()); + assertOk(client, getValidApiLocation(baseUri)); + + // wait for the cookie to expire + assertEventually(() -> assertThat(cookieManager.getCookieStore().getCookies()).isEmpty()); + assertResponseCode(client, getValidApiLocation(baseUri), UNAUTHORIZED.getStatusCode()); + + // create fake cookie with previous cookie value to check token validity + HttpCookie biscuit = new HttpCookie(cookie.getName(), cookie.getValue()); + biscuit.setPath(cookie.getPath()); + cookieManager.getCookieStore().add(baseUri, biscuit); + assertResponseCode(client, getValidApiLocation(baseUri), UNAUTHORIZED.getStatusCode()); } finally { jwkServer.stop(); @@ -694,6 +807,7 @@ public void testCustomPrincipalField() .put(SUBJECT, "unknown") .put("preferred_username", "test-user@email.com") .buildOrThrow(), + Duration.ofSeconds(5), true); TestingHttpServer jwkServer = 
createTestingJwkServer(); jwkServer.start(); @@ -711,8 +825,7 @@ public void testCustomPrincipalField() jaxrsBinder(binder).bind(AuthenticatedIdentityCapturingFilter.class); }) .build()) { - HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); - assertAuth2Authentication(httpServerInfo, oauthClient.getAccessToken()); + assertAuth2Authentication(server, oauthClient.getAccessToken(), false); Identity identity = server.getInstance(Key.get(AuthenticatedIdentityCapturingFilter.class)).getAuthenticatedIdentity(); assertThat(identity.getUser()).isEqualTo("test-user"); assertThat(identity.getPrincipal()).isEqualTo(Optional.of(new BasicPrincipal("test-user@email.com"))); @@ -722,20 +835,15 @@ public void testCustomPrincipalField() } } - private void assertAuth2Authentication(HttpServerInfo httpServerInfo, String accessToken) + private void assertAuth2Authentication(TestingTrinoServer server, String accessToken, boolean refreshTokensEnabled) throws Exception { - String state = newJwtBuilder() - .signWith(hmacShaKeyFor(Hashing.sha256().hashString(STATE_KEY, UTF_8).asBytes())) - .setAudience("trino_oauth_ui") - .setExpiration(Date.from(ZonedDateTime.now().plusMinutes(10).toInstant())) - .compact(); - CookieManager cookieManager = new CookieManager(); OkHttpClient client = this.client.newBuilder() .cookieJar(new JavaNetCookieJar(cookieManager)) .build(); + HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); // HTTP is not allowed for OAuth testDisabled(httpServerInfo.getHttpUri()); @@ -747,21 +855,17 @@ private void assertAuth2Authentication(HttpServerInfo httpServerInfo, String acc assertRedirect(client, getLocation(baseUri, "/ui/unknown"), "http://example.com/authorize", false); assertResponseCode(client, getLocation(baseUri, "/ui/api/unknown"), UNAUTHORIZED.getStatusCode()); - // login with the callback endpoint - assertRedirect( - client, - uriBuilderFrom(baseUri) - .replacePath(CALLBACK_ENDPOINT) - .addParameter("code", "TEST_CODE") - .addParameter("state", state) - .toString(), - getUiLocation(baseUri), - false); + loginWithCallbackEndpoint(client, baseUri); HttpCookie cookie = getOnlyElement(cookieManager.getCookieStore().getCookies()); - assertEquals(cookie.getValue(), accessToken); + if (refreshTokensEnabled) { + assertCookieWithRefreshToken(server, cookie, accessToken); + } + else { + assertEquals(cookie.getValue(), accessToken); + assertThat(cookie.getMaxAge()).isGreaterThan(0).isLessThan(30); + } assertEquals(cookie.getPath(), "/ui/"); assertEquals(cookie.getDomain(), baseUri.getHost()); - assertTrue(cookie.getMaxAge() > 0 && cookie.getMaxAge() < MINUTES.toSeconds(5)); assertTrue(cookie.isHttpOnly()); // authentication cookie is now set, so UI should work @@ -778,6 +882,34 @@ private void assertAuth2Authentication(HttpServerInfo httpServerInfo, String acc assertRedirect(client, getUiLocation(baseUri), "http://example.com/authorize", false); } + private static void loginWithCallbackEndpoint(OkHttpClient client, URI baseUri) + throws IOException + { + String state = newJwtBuilder() + .signWith(hmacShaKeyFor(Hashing.sha256().hashString(STATE_KEY, UTF_8).asBytes())) + .setAudience("trino_oauth_ui") + .setExpiration(Date.from(ZonedDateTime.now().plusMinutes(10).toInstant())) + .compact(); + assertRedirect( + client, + uriBuilderFrom(baseUri) + .replacePath(CALLBACK_ENDPOINT) + .addParameter("code", "TEST_CODE") + .addParameter("state", state) + .toString(), + getUiLocation(baseUri), + false); + } + + private static void 
assertCookieWithRefreshToken(TestingTrinoServer server, HttpCookie authCookie, String accessToken) + { + TokenPairSerializer tokenPairSerializer = server.getInstance(Key.get(TokenPairSerializer.class)); + TokenPair deserialize = tokenPairSerializer.deserialize(authCookie.getValue()); + assertEquals(deserialize.getAccessToken(), accessToken); + assertEquals(deserialize.getRefreshToken(), Optional.of(REFRESH_TOKEN)); + assertThat(authCookie.getMaxAge()).isGreaterThan(0).isLessThan(REFRESH_TOKEN_TIMEOUT.getSeconds()); + } + private static void testAlwaysAuthorized(URI baseUri, OkHttpClient authorizedClient, String nodeId) throws IOException { @@ -1078,23 +1210,25 @@ private static class OAuth2ClientStub private static final SecureRandom secureRandom = new SecureRandom(); private final Claims claims; private final String accessToken; + private final Duration accessTokenValidity; private final Optional nonce; private final Optional idToken; public OAuth2ClientStub() { - this(true); + this(true, Duration.ofSeconds(5)); } - public OAuth2ClientStub(boolean issueIdToken) + public OAuth2ClientStub(boolean issueIdToken, Duration accessTokenValidity) { - this(ImmutableMap.of(), issueIdToken); + this(ImmutableMap.of(), accessTokenValidity, issueIdToken); } - public OAuth2ClientStub(Map additionalClaims, boolean issueIdToken) + public OAuth2ClientStub(Map additionalClaims, Duration accessTokenValidity, boolean issueIdToken) { claims = new DefaultClaims(createClaims()); - claims.putAll(additionalClaims); + claims.putAll(requireNonNull(additionalClaims, "additionalClaims is null")); + this.accessTokenValidity = requireNonNull(accessTokenValidity, "accessTokenValidity is null"); accessToken = issueToken(claims); if (issueIdToken) { nonce = Optional.of(randomNonce()); @@ -1127,7 +1261,7 @@ public Response getOAuth2Response(String code, URI callbackUri, Optional if (!"TEST_CODE".equals(code)) { throw new IllegalArgumentException("Expected TEST_CODE"); } - return new Response(accessToken, Instant.now().plusSeconds(5), idToken, Optional.empty()); + return new Response(accessToken, Instant.now().plus(accessTokenValidity), idToken, Optional.of(REFRESH_TOKEN)); } @Override @@ -1140,7 +1274,10 @@ public Optional> getClaims(String accessToken) public Response refreshTokens(String refreshToken) throws ChallengeFailedException { - throw new UnsupportedOperationException("Refresh tokens are not supported"); + if (refreshToken.equals(REFRESH_TOKEN)) { + return new Response(issueToken(claims), Instant.now().plusSeconds(30), idToken, Optional.of(REFRESH_TOKEN)); + } + throw new ChallengeFailedException("invalid refresh token"); } public String getAccessToken() diff --git a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java index 40e0b1725512..86f95fc1a775 100644 --- a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java +++ b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java @@ -30,6 +30,7 @@ import io.trino.connector.TestingTableFunctions.OnlyPassThroughFunction; import io.trino.connector.TestingTableFunctions.PassThroughFunction; import io.trino.connector.TestingTableFunctions.PolymorphicStaticReturnTypeFunction; +import io.trino.connector.TestingTableFunctions.RequiredColumnsFunction; import io.trino.connector.TestingTableFunctions.TableArgumentFunction; import io.trino.connector.TestingTableFunctions.TableArgumentRowSemanticsFunction; import 
io.trino.connector.TestingTableFunctions.TwoScalarArgumentsFunction; @@ -120,6 +121,7 @@ import static io.trino.spi.StandardErrorCode.EXPRESSION_NOT_CONSTANT; import static io.trino.spi.StandardErrorCode.EXPRESSION_NOT_IN_DISTINCT; import static io.trino.spi.StandardErrorCode.EXPRESSION_NOT_SCALAR; +import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_AGGREGATE; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; @@ -6622,6 +6624,34 @@ public void testTableFunctionAliasing() .hasMessage("line 1:23: Column 'table_alias.a' cannot be resolved"); } + @Test + public void testTableFunctionRequiredColumns() + { + // the function required_column_function specifies columns 0 and 1 from table argument "INPUT" as required. + analyze(""" + SELECT * FROM TABLE(system.required_columns_function( + input => TABLE(t1))) + """); + + analyze(""" + SELECT * FROM TABLE(system.required_columns_function( + input => TABLE(SELECT 1, 2, 3))) + """); + + assertFails(""" + SELECT * FROM TABLE(system.required_columns_function( + input => TABLE(SELECT 1))) + """) + .hasErrorCode(FUNCTION_IMPLEMENTATION_ERROR) + .hasMessage("Invalid index: 1 of required column from table argument INPUT"); + + // table s1.t5 has two columns. The second column is hidden. Table function can require a hidden column. + analyze(""" + SELECT * FROM TABLE(system.required_columns_function( + input => TABLE(s1.t5))) + """); + } + @BeforeClass public void setup() { @@ -7011,7 +7041,8 @@ public ConnectorTransactionHandle getConnectorTransaction(TransactionId transact new OnlyPassThroughFunction(), new MonomorphicStaticReturnTypeFunction(), new PolymorphicStaticReturnTypeFunction(), - new PassThroughFunction()))), + new PassThroughFunction(), + new RequiredColumnsFunction()))), new SessionPropertyManager(), tablePropertyManager, analyzePropertyManager, diff --git a/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 b/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 index c2d2006859ed..6e1699354b58 100644 --- a/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 +++ b/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 @@ -81,6 +81,7 @@ statement | ANALYZE qualifiedName (WITH properties)? #analyze | CREATE (OR REPLACE)? MATERIALIZED VIEW (IF NOT EXISTS)? qualifiedName + (GRACE PERIOD interval)? (COMMENT string)? (WITH properties)? AS query #createMaterializedView | CREATE (OR REPLACE)? 
VIEW qualifiedName @@ -834,7 +835,7 @@ nonReserved | DATA | DATE | DAY | DEFAULT | DEFINE | DEFINER | DENY | DESC | DESCRIPTOR | DISTRIBUTED | DOUBLE | EMPTY | ENCODING | ERROR | EXCLUDING | EXPLAIN | FETCH | FILTER | FINAL | FIRST | FOLLOWING | FORMAT | FUNCTIONS - | GRANT | GRANTED | GRANTS | GRAPHVIZ | GROUPS + | GRACE | GRANT | GRANTED | GRANTS | GRAPHVIZ | GROUPS | HOUR | IF | IGNORE | INCLUDING | INITIAL | INPUT | INTERVAL | INVOKER | IO | ISOLATION | JSON @@ -843,7 +844,7 @@ nonReserved | MAP | MATCH | MATCHED | MATCHES | MATCH_RECOGNIZE | MATERIALIZED | MEASURES | MERGE | MINUTE | MONTH | NEXT | NFC | NFD | NFKC | NFKD | NO | NONE | NULLIF | NULLS | OBJECT | OF | OFFSET | OMIT | ONE | ONLY | OPTION | ORDINALITY | OUTPUT | OVER | OVERFLOW - | PARTITION | PARTITIONS | PASSING | PAST | PATH | PATTERN | PER | PERMUTE | POSITION | PRECEDING | PRECISION | PRIVILEGES | PROPERTIES | PRUNE + | PARTITION | PARTITIONS | PASSING | PAST | PATH | PATTERN | PER | PERIOD | PERMUTE | POSITION | PRECEDING | PRECISION | PRIVILEGES | PROPERTIES | PRUNE | QUOTES | RANGE | READ | REFRESH | RENAME | REPEATABLE | REPLACE | RESET | RESPECT | RESTRICT | RETURNING | REVOKE | ROLE | ROLES | ROLLBACK | ROW | ROWS | RUNNING | SCALAR | SCHEMA | SCHEMAS | SECOND | SECURITY | SEEK | SERIALIZABLE | SESSION | SET | SETS @@ -939,6 +940,7 @@ FORMAT: 'FORMAT'; FROM: 'FROM'; FULL: 'FULL'; FUNCTIONS: 'FUNCTIONS'; +GRACE: 'GRACE'; GRANT: 'GRANT'; GRANTED: 'GRANTED'; GRANTS: 'GRANTS'; @@ -1030,6 +1032,7 @@ PAST: 'PAST'; PATH: 'PATH'; PATTERN: 'PATTERN'; PER: 'PER'; +PERIOD: 'PERIOD'; PERMUTE: 'PERMUTE'; POSITION: 'POSITION'; PRECEDING: 'PRECEDING'; diff --git a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java index 18670f0e8064..e5501172139e 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java @@ -1009,6 +1009,8 @@ protected Void visitCreateMaterializedView(CreateMaterializedView node, Integer } builder.append(formatName(node.getName())); + node.getGracePeriod().ifPresent(interval -> + builder.append("\nGRACE PERIOD ").append(formatExpression(interval))); node.getComment().ifPresent(comment -> builder .append("\nCOMMENT ") .append(formatStringLiteral(comment))); diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 0723b8fc4af7..41d8bcff7823 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -461,6 +461,11 @@ public Node visitCreateTable(SqlBaseParser.CreateTableContext context) @Override public Node visitCreateMaterializedView(SqlBaseParser.CreateMaterializedViewContext context) { + Optional gracePeriod = Optional.empty(); + if (context.GRACE() != null) { + gracePeriod = Optional.of((IntervalLiteral) visit(context.interval())); + } + Optional comment = Optional.empty(); if (context.COMMENT() != null) { comment = Optional.of(((StringLiteral) visit(context.string())).getValue()); @@ -477,6 +482,7 @@ public Node visitCreateMaterializedView(SqlBaseParser.CreateMaterializedViewCont (Query) visit(context.query()), context.REPLACE() != null, context.EXISTS() != null, + gracePeriod, properties, comment); } @@ -3745,18 +3751,4 @@ private static QueryPeriod.RangeType getRangeType(Token token) } throw new IllegalArgumentException("Unsupported query period 
range type: " + token.getText()); } - - private static Trim.Specification toTrimSpecification(String functionName) - { - requireNonNull(functionName, "functionName is null"); - switch (functionName) { - case "trim": - return Trim.Specification.BOTH; - case "ltrim": - return Trim.Specification.LEADING; - case "rtrim": - return Trim.Specification.TRAILING; - } - throw new IllegalArgumentException("Unsupported trim specification: " + functionName); - } } diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/CreateMaterializedView.java b/core/trino-parser/src/main/java/io/trino/sql/tree/CreateMaterializedView.java index e4a2ca034359..f2766eacebf1 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/CreateMaterializedView.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/CreateMaterializedView.java @@ -29,14 +29,17 @@ public class CreateMaterializedView private final Query query; private final boolean replace; private final boolean notExists; + private final Optional gracePeriod; private final List properties; private final Optional comment; - public CreateMaterializedView(Optional location, + public CreateMaterializedView( + Optional location, QualifiedName name, Query query, boolean replace, boolean notExists, + Optional gracePeriod, List properties, Optional comment) { @@ -45,6 +48,7 @@ public CreateMaterializedView(Optional location, this.query = requireNonNull(query, "query is null"); this.replace = replace; this.notExists = notExists; + this.gracePeriod = requireNonNull(gracePeriod, "gracePeriod is null"); this.properties = ImmutableList.copyOf(requireNonNull(properties, "properties is null")); this.comment = requireNonNull(comment, "comment is null"); } @@ -69,6 +73,11 @@ public boolean isNotExists() return notExists; } + public Optional getGracePeriod() + { + return gracePeriod; + } + public List getProperties() { return properties; @@ -88,13 +97,16 @@ public R accept(AstVisitor visitor, C context) @Override public List getChildren() { - return ImmutableList.of(query); + ImmutableList.Builder children = ImmutableList.builder(); + children.add(query); + gracePeriod.ifPresent(children::add); + return children.build(); } @Override public int hashCode() { - return Objects.hash(name, query, replace, notExists, properties, comment); + return Objects.hash(name, query, replace, notExists, gracePeriod, properties, comment); } @Override @@ -111,6 +123,7 @@ public boolean equals(Object obj) && Objects.equals(query, o.query) && Objects.equals(replace, o.replace) && Objects.equals(notExists, o.notExists) + && Objects.equals(gracePeriod, o.gracePeriod) && Objects.equals(properties, o.properties) && Objects.equals(comment, o.comment); } @@ -123,6 +136,7 @@ public String toString() .add("query", query) .add("replace", replace) .add("notExists", notExists) + .add("gracePeriod", gracePeriod) .add("properties", properties) .add("comment", comment) .toString(); diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index b0fae9aefcee..8b4989a93a32 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -3425,6 +3425,7 @@ public void testCreateMaterializedView() Optional.empty()), false, false, + Optional.empty(), ImmutableList.of(), Optional.empty())); @@ -3464,9 +3465,43 @@ public void testCreateMaterializedView() Optional.empty()), true, false, + Optional.empty(), 
ImmutableList.of(), Optional.of("A simple materialized view"))); + // GRACE PERIOD + assertThat(statement("CREATE MATERIALIZED VIEW a GRACE PERIOD INTERVAL '2' DAY AS SELECT * FROM t")) + .isEqualTo(new CreateMaterializedView( + Optional.of(new NodeLocation(1, 1)), + QualifiedName.of(ImmutableList.of(new Identifier(new NodeLocation(1, 26), "a", false))), + new Query( + new NodeLocation(1, 61), + Optional.empty(), + new QuerySpecification( + new NodeLocation(1, 61), + new Select( + new NodeLocation(1, 61), + false, + ImmutableList.of(new AllColumns(new NodeLocation(1, 68), Optional.empty(), ImmutableList.of()))), + Optional.of(new Table( + new NodeLocation(1, 75), + QualifiedName.of(ImmutableList.of(new Identifier(new NodeLocation(1, 75), "t", false))))), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + Optional.empty()), + Optional.empty(), + Optional.empty(), + Optional.empty()), + false, + false, + Optional.of(new IntervalLiteral(new NodeLocation(1, 41), "2", Sign.POSITIVE, IntervalField.DAY, Optional.empty())), + ImmutableList.of(), + Optional.empty())); + // OR REPLACE, COMMENT, WITH properties assertThat(statement("CREATE OR REPLACE MATERIALIZED VIEW catalog.schema.matview COMMENT 'A simple materialized view'" + "WITH (partitioned_by = ARRAY ['dateint'])" + @@ -3504,6 +3539,7 @@ public void testCreateMaterializedView() Optional.empty()), true, false, + Optional.empty(), ImmutableList.of(new Property( new NodeLocation(1, 102), new Identifier(new NodeLocation(1, 102), "partitioned_by", false), @@ -3590,6 +3626,7 @@ public void testCreateMaterializedView() Optional.empty()), true, false, + Optional.empty(), ImmutableList.of(new Property( new NodeLocation(1, 108), new Identifier(new NodeLocation(1, 108), "partitioned_by", false), diff --git a/core/trino-server-rpm/src/main/resources/dist/config/jvm.config b/core/trino-server-rpm/src/main/resources/dist/config/jvm.config index e609c0abc206..bd4958ddf7c8 100644 --- a/core/trino-server-rpm/src/main/resources/dist/config/jvm.config +++ b/core/trino-server-rpm/src/main/resources/dist/config/jvm.config @@ -14,3 +14,5 @@ # Improve AES performance for S3, etc. on ARM64 (JDK-8271567) -XX:+UnlockDiagnosticVMOptions -XX:+UseAESCTRIntrinsics +# Disable Preventive GC for performance reasons (JDK-8293861) +-XX:-G1UsePreventiveGC diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java index dddd662f86ef..a5a0d5af1946 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java @@ -34,7 +34,7 @@ public interface ConnectorTableFunction /** * This method is called by the Analyzer. Its main purposes are to: * 1. Determine the resulting relation type of the Table Function in case when the declared return type is GENERIC_TABLE. - * 2. Declare the dependencies between the input descriptors and the input tables. + * 2. Declare the required columns from the input tables. * 3. Perform function-specific validation and pre-processing of the input arguments. 
* As part of function-specific validation, the Table Function's author might want to: * - check if the descriptors which reference input tables contain a correct number of column references diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/TableFunctionAnalysis.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/TableFunctionAnalysis.java index a415c54d61f6..7c6709d70b08 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/ptf/TableFunctionAnalysis.java +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/TableFunctionAnalysis.java @@ -15,10 +15,14 @@ import io.trino.spi.Experimental; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Optional; import static io.trino.spi.ptf.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toMap; /** * An object of this class is produced by the `analyze()` method of a `ConnectorTableFunction` @@ -28,6 +32,9 @@ * Function, that is, the columns produced by the function, as opposed to the columns passed from the * input tables. The `returnedType` should only be set if the declared returned type is GENERIC_TABLE. *

+ * The `requiredColumns` field is used to inform the Analyzer of the columns from the table arguments + * that are necessary to execute the table function. + *

* The `handle` field can be used to carry all information necessary to execute the table function, * gathered at analysis time. Typically, these are the values of the constant arguments, and results * of pre-processing arguments. @@ -36,12 +43,17 @@ public final class TableFunctionAnalysis { private final Optional returnedType; + + // a map from table argument name to list of column indexes for all columns required from the table argument + private final Map> requiredColumns; private final ConnectorTableFunctionHandle handle; - private TableFunctionAnalysis(Optional returnedType, ConnectorTableFunctionHandle handle) + private TableFunctionAnalysis(Optional returnedType, Map> requiredColumns, ConnectorTableFunctionHandle handle) { this.returnedType = requireNonNull(returnedType, "returnedType is null"); returnedType.ifPresent(descriptor -> checkArgument(descriptor.isTyped(), "field types not specified")); + this.requiredColumns = Map.copyOf(requiredColumns.entrySet().stream() + .collect(toMap(Map.Entry::getKey, entry -> List.copyOf(entry.getValue())))); this.handle = requireNonNull(handle, "handle is null"); } @@ -50,6 +62,11 @@ public Optional getReturnedType() return returnedType; } + public Map> getRequiredColumns() + { + return requiredColumns; + } + public ConnectorTableFunctionHandle getHandle() { return handle; @@ -63,6 +80,7 @@ public static Builder builder() public static final class Builder { private Descriptor returnedType; + private final Map> requiredColumns = new HashMap<>(); private ConnectorTableFunctionHandle handle = new ConnectorTableFunctionHandle() {}; private Builder() {} @@ -73,6 +91,12 @@ public Builder returnedType(Descriptor returnedType) return this; } + public Builder requiredColumns(String tableArgument, List columns) + { + this.requiredColumns.put(tableArgument, columns); + return this; + } + public Builder handle(ConnectorTableFunctionHandle handle) { this.handle = handle; @@ -81,7 +105,7 @@ public Builder handle(ConnectorTableFunctionHandle handle) public TableFunctionAnalysis build() { - return new TableFunctionAnalysis(Optional.ofNullable(returnedType), handle); + return new TableFunctionAnalysis(Optional.ofNullable(returnedType), requiredColumns, handle); } } } diff --git a/docs/src/main/sphinx/admin/resource-groups.rst b/docs/src/main/sphinx/admin/resource-groups.rst index d58836749e2b..aa5ee4f14857 100644 --- a/docs/src/main/sphinx/admin/resource-groups.rst +++ b/docs/src/main/sphinx/admin/resource-groups.rst @@ -72,9 +72,11 @@ Property name Description ``resource-groups.config-db-password`` Password for database user to connect with. ``none`` -``resource-groups.max-refresh-interval`` Time period for which the cluster will continue to accept ``1h`` - queries after refresh failures cause configuration to - become stale. +``resource-groups.max-refresh-interval`` The maximum time period for which the cluster will ``1h`` + continue to accept queries after refresh failures, + causing configuration to become stale. + +``resource-groups.refresh-interval`` How often the cluster reloads from the database ``1s`` ``resource-groups.exact-match-selector-enabled`` Setting this flag enables usage of an additional ``false`` ``exact_match_source_selectors`` table to configure diff --git a/docs/src/main/sphinx/connector/delta-lake.rst b/docs/src/main/sphinx/connector/delta-lake.rst index 6fbb946ce60a..76aec22883eb 100644 --- a/docs/src/main/sphinx/connector/delta-lake.rst +++ b/docs/src/main/sphinx/connector/delta-lake.rst @@ -418,6 +418,7 @@ Delta Lake. 
In addition to the :ref:`globally available statements, the connector supports the following features: * :ref:`sql-data-management`, see also :ref:`delta-lake-write-support` +* :ref:`sql-view-management` * :doc:`/sql/create-schema`, see also :ref:`delta-lake-create-schema` * :doc:`/sql/create-table`, see also :ref:`delta-lake-create-table` * :doc:`/sql/create-table-as` diff --git a/docs/src/main/sphinx/connector/sqlserver.rst b/docs/src/main/sphinx/connector/sqlserver.rst index ca5f51f9fc9e..28e53c28278c 100644 --- a/docs/src/main/sphinx/connector/sqlserver.rst +++ b/docs/src/main/sphinx/connector/sqlserver.rst @@ -357,7 +357,7 @@ For example, select the top 10 percent of nations by population:: TABLE( sqlserver.system.query( query => 'SELECT - TOP(10) PERCENT + TOP(10) PERCENT * FROM tpch.nation ORDER BY diff --git a/docs/src/main/sphinx/develop/table-functions.rst b/docs/src/main/sphinx/develop/table-functions.rst index 1fe1eaf9b13a..817e0aa3aaeb 100644 --- a/docs/src/main/sphinx/develop/table-functions.rst +++ b/docs/src/main/sphinx/develop/table-functions.rst @@ -1,3 +1,4 @@ + =============== Table functions =============== @@ -136,8 +137,8 @@ execute the table function invocation: - The returned row type, specified as an optional ``Descriptor``. It should be passed if and only if the table function is declared with the ``GENERIC_TABLE`` returned type. -- Dependencies between descriptor arguments and table arguments. It defaults to - ``EMPTY_MAPPING``. +- Required columns from the table arguments, specified as a map of table + argument names to lists of column indexes. - Any information gathered during analysis that is useful during planning or execution, in the form of a ``ConnectorTableFunctionHandle``. ``ConnectorTableFunctionHandle`` is a marker interface intended to carry diff --git a/docs/src/main/sphinx/functions/datetime.rst b/docs/src/main/sphinx/functions/datetime.rst index 27eb678a83fc..7f16a2465141 100644 --- a/docs/src/main/sphinx/functions/datetime.rst +++ b/docs/src/main/sphinx/functions/datetime.rst @@ -228,7 +228,16 @@ The above examples use the timestamp ``2001-08-22 03:04:05.321`` as the input. .. function:: date_trunc(unit, x) -> [same as input] - Returns ``x`` truncated to ``unit``. + Returns ``x`` truncated to ``unit``:: + + SELECT date_trunc('day' , TIMESTAMP '2022-10-20 05:10:00'); + -- 2022-10-20 00:00:00.000 + + SELECT date_trunc('month' , TIMESTAMP '2022-10-20 05:10:00'); + -- 2022-10-01 00:00:00.000 + + SELECT date_trunc('year', TIMESTAMP '2022-10-20 05:10:00'); + -- 2022-01-01 00:00:00.000 .. _datetime-interval-functions: @@ -383,11 +392,17 @@ Specifier Description .. function:: date_format(timestamp, format) -> varchar - Formats ``timestamp`` as a string using ``format``. + Formats ``timestamp`` as a string using ``format``:: + + SELECT date_format(TIMESTAMP '2022-10-20 05:10:00', '%m-%d-%Y %H'); + -- 10-20-2022 05 .. function:: date_parse(string, format) -> timestamp(3) - Parses ``string`` into a timestamp using ``format``. + Parses ``string`` into a timestamp using ``format``:: + + SELECT date_parse('2022/10/20/05', '%Y/%m/%d/%H'); + -- 2022-10-20 05:00:00.000 Java date functions ------------------- @@ -437,7 +452,10 @@ field to be extracted. Most fields support all date and time types. .. function:: extract(field FROM x) -> bigint - Returns ``field`` from ``x``. + Returns ``field`` from ``x``:: + + SELECT extract(YEAR FROM TIMESTAMP '2022-10-20 05:10:00'); + -- 2022 .. note:: This SQL-standard function uses special syntax for specifying the arguments. 
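The table functions documentation above describes required columns as a map from table argument name to a list of column indexes. The following is a minimal sketch, based only on the TableFunctionAnalysis.Builder API added in this diff, of how a connector could declare them from its analyze() implementation; the class name RequiredColumnsSketch is hypothetical, and the "INPUT" argument with indexes 0 and 1 mirrors the RequiredColumnsFunction test fixture:

    import io.trino.spi.ptf.TableFunctionAnalysis;

    import java.util.List;

    // Sketch only: declare that the function needs columns 0 and 1 of the
    // table argument named "INPUT"; the analyzer may prune all other columns.
    class RequiredColumnsSketch
    {
        static TableFunctionAnalysis analyzeSketch()
        {
            return TableFunctionAnalysis.builder()
                    .requiredColumns("INPUT", List.of(0, 1))
                    // handle() omitted: the builder defaults to an anonymous ConnectorTableFunctionHandle
                    .build();
        }
    }

As the new testTableFunctionRequiredColumns case shows, each index must refer to an existing (possibly hidden) column of that table argument; an out-of-range index fails analysis with FUNCTION_IMPLEMENTATION_ERROR.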
diff --git a/docs/src/main/sphinx/installation/deployment.rst b/docs/src/main/sphinx/installation/deployment.rst index eb3861256395..5cce16570520 100644 --- a/docs/src/main/sphinx/installation/deployment.rst +++ b/docs/src/main/sphinx/installation/deployment.rst @@ -143,6 +143,8 @@ The following provides a good starting point for creating ``etc/jvm.config``: -Djdk.nio.maxCachedBufferSize=2000000 -XX:+UnlockDiagnosticVMOptions -XX:+UseAESCTRIntrinsics + # Disable Preventive GC for performance reasons (JDK-8293861) + -XX:-G1UsePreventiveGC Because an ``OutOfMemoryError`` typically leaves the JVM in an inconsistent state, we write a heap dump, for debugging, and forcibly @@ -156,6 +158,7 @@ temporary directory by adding ``-Djava.io.tmpdir=/path/to/other/tmpdir`` to the list of JVM options. We enable ``-XX:+UnlockDiagnosticVMOptions`` and ``-XX:+UseAESCTRIntrinsics`` to improve AES performance for S3, etc. on ARM64 (`JDK-8271567 `_) +We disable Preventive GC (``-XX:-G1UsePreventiveGC``) for performance reasons (see `JDK-8293861 `_) .. _config_properties: diff --git a/docs/src/main/sphinx/optimizer/pushdown.rst b/docs/src/main/sphinx/optimizer/pushdown.rst index 447281295a3d..4d40b7dca267 100644 --- a/docs/src/main/sphinx/optimizer/pushdown.rst +++ b/docs/src/main/sphinx/optimizer/pushdown.rst @@ -271,3 +271,90 @@ FETCH FIRST N ROWS``. Implementation and support is connector-specific since different data sources support different SQL syntax and processing. + +For example, you can find two queries to learn how to identify Top-N pushdown behavior in the following section. + +First, a concrete example of a Top-N pushdown query on top of a PostgreSQL database:: + + SELECT id, name + FROM postgresql.public.company + ORDER BY id + LIMIT 5; + +You can get the explain plan by prepending the above query with ``EXPLAIN``:: + + EXPLAIN SELECT id, name + FROM postgresql.public.company + ORDER BY id + LIMIT 5; + +.. code-block:: text + + Fragment 0 [SINGLE] + Output layout: [id, name] + Output partitioning: SINGLE [] + Stage Execution Strategy: UNGROUPED_EXECUTION + Output[id, name] + │ Layout: [id:integer, name:varchar] + │ Estimates: {rows: ? (?), cpu: ?, memory: 0B, network: ?} + └─ RemoteSource[1] + Layout: [id:integer, name:varchar] + + Fragment 1 [SOURCE] + Output layout: [id, name] + Output partitioning: SINGLE [] + Stage Execution Strategy: UNGROUPED_EXECUTION + TableScan[postgresql:public.company public.company sortOrder=[id:integer:int4 ASC NULLS LAST] limit=5, grouped = false] + Layout: [id:integer, name:varchar] + Estimates: {rows: ? (?), cpu: ?, memory: 0B, network: 0B} + name := name:varchar:text + id := id:integer:int4 + +Second, an example of a Top-N query on the ``tpch`` connector which does not support +Top-N pushdown functionality:: + + SELECT custkey, name + FROM tpch.sf1.customer + ORDER BY custkey + LIMIT 5; + +The related query plan: + +.. code-block:: text + + Fragment 0 [SINGLE] + Output layout: [custkey, name] + Output partitioning: SINGLE [] + Stage Execution Strategy: UNGROUPED_EXECUTION + Output[custkey, name] + │ Layout: [custkey:bigint, name:varchar(25)] + │ Estimates: {rows: ? (?), cpu: ?, memory: ?, network: ?} + └─ TopN[5 by (custkey ASC NULLS LAST)] + │ Layout: [custkey:bigint, name:varchar(25)] + └─ LocalExchange[SINGLE] () + │ Layout: [custkey:bigint, name:varchar(25)] + │ Estimates: {rows: ? 
(?), cpu: ?, memory: ?, network: ?} + └─ RemoteSource[1] + Layout: [custkey:bigint, name:varchar(25)] + + Fragment 1 [SOURCE] + Output layout: [custkey, name] + Output partitioning: SINGLE [] + Stage Execution Strategy: UNGROUPED_EXECUTION + TopNPartial[5 by (custkey ASC NULLS LAST)] + │ Layout: [custkey:bigint, name:varchar(25)] + └─ TableScan[tpch:customer:sf1.0, grouped = false] + Layout: [custkey:bigint, name:varchar(25)] + Estimates: {rows: 150000 (4.58MB), cpu: 4.58M, memory: 0B, network: 0B} + custkey := tpch:custkey + name := tpch:name + +In the preceding query plan, the Top-N operation ``TopN[5 by (custkey ASC NULLS LAST)]`` +is being applied in the ``Fragment 0`` by Trino and not by the source database. + +Note that, compared to the query executed on top of the ``tpch`` connector, +the explain plan of the query applied on top of the ``postgresql`` connector +is missing the reference to the operation ``TopN[5 by (id ASC NULLS LAST)]`` +in the ``Fragment 0``. +The absence of the ``TopN`` Trino operator in the ``Fragment 0`` from the query plan +demonstrates that the query benefits from the Top-N pushdown optimization. diff --git a/lib/trino-geospatial-toolkit/src/main/java/io/trino/geospatial/serde/GeometrySerde.java b/lib/trino-geospatial-toolkit/src/main/java/io/trino/geospatial/serde/GeometrySerde.java index 8a1ca430dd3a..329344c301d8 100644 --- a/lib/trino-geospatial-toolkit/src/main/java/io/trino/geospatial/serde/GeometrySerde.java +++ b/lib/trino-geospatial-toolkit/src/main/java/io/trino/geospatial/serde/GeometrySerde.java @@ -82,31 +82,18 @@ private static void writeGeometry(DynamicSliceOutput output, OGCGeometry geometr { GeometryType type = GeometryType.getForEsriGeometryType(geometry.geometryType()); switch (type) { - case POINT: - writePoint(output, geometry); - return; - case MULTI_POINT: - writeSimpleGeometry(output, GeometrySerializationType.MULTI_POINT, geometry); - return; - case LINE_STRING: - writeSimpleGeometry(output, GeometrySerializationType.LINE_STRING, geometry); - return; - case MULTI_LINE_STRING: - writeSimpleGeometry(output, GeometrySerializationType.MULTI_LINE_STRING, geometry); - return; - case POLYGON: - writeSimpleGeometry(output, GeometrySerializationType.POLYGON, geometry); - return; - case MULTI_POLYGON: - writeSimpleGeometry(output, GeometrySerializationType.MULTI_POLYGON, geometry); - return; - case GEOMETRY_COLLECTION: { + case POINT -> writePoint(output, geometry); + case MULTI_POINT -> writeSimpleGeometry(output, GeometrySerializationType.MULTI_POINT, geometry); + case LINE_STRING -> writeSimpleGeometry(output, GeometrySerializationType.LINE_STRING, geometry); + case MULTI_LINE_STRING -> writeSimpleGeometry(output, GeometrySerializationType.MULTI_LINE_STRING, geometry); + case POLYGON -> writeSimpleGeometry(output, GeometrySerializationType.POLYGON, geometry); + case MULTI_POLYGON -> writeSimpleGeometry(output, GeometrySerializationType.MULTI_POLYGON, geometry); + case GEOMETRY_COLLECTION -> { verify(geometry instanceof OGCConcreteGeometryCollection); writeGeometryCollection(output, (OGCConcreteGeometryCollection) geometry); - return; } + default -> throw new IllegalArgumentException("Unexpected type: " + type); } - throw new IllegalArgumentException("Unexpected type: " + type); } private static void writeGeometryCollection(DynamicSliceOutput output, OGCGeometryCollection collection) @@ -175,21 +162,12 @@ public static OGCGeometry deserialize(Slice shape) private static OGCGeometry readGeometry(BasicSliceInput input, Slice inputSlice, 
GeometrySerializationType type, int length) { - switch (type) { - case POINT: - return readPoint(input); - case MULTI_POINT: - case LINE_STRING: - case MULTI_LINE_STRING: - case POLYGON: - case MULTI_POLYGON: - return readSimpleGeometry(input, inputSlice, type, length); - case GEOMETRY_COLLECTION: - return readGeometryCollection(input, inputSlice); - case ENVELOPE: - return createFromEsriGeometry(readEnvelope(input), false); - } - throw new IllegalArgumentException("Unexpected type: " + type); + return switch (type) { + case POINT -> readPoint(input); + case MULTI_POINT, LINE_STRING, MULTI_LINE_STRING, POLYGON, MULTI_POLYGON -> readSimpleGeometry(input, inputSlice, type, length); + case GEOMETRY_COLLECTION -> readGeometryCollection(input, inputSlice); + case ENVELOPE -> createFromEsriGeometry(readEnvelope(input), false); + }; } private static OGCConcreteGeometryCollection readGeometryCollection(BasicSliceInput input, Slice inputSlice) @@ -286,21 +264,12 @@ public static Envelope deserializeEnvelope(Slice shape) private static Envelope getEnvelope(BasicSliceInput input, GeometrySerializationType type, int length) { - switch (type) { - case POINT: - return getPointEnvelope(input); - case MULTI_POINT: - case LINE_STRING: - case MULTI_LINE_STRING: - case POLYGON: - case MULTI_POLYGON: - return getSimpleGeometryEnvelope(input, length); - case GEOMETRY_COLLECTION: - return getGeometryCollectionOverallEnvelope(input); - case ENVELOPE: - return readEnvelope(input); - } - throw new IllegalArgumentException("Unexpected type: " + type); + return switch (type) { + case POINT -> getPointEnvelope(input); + case MULTI_POINT, LINE_STRING, MULTI_LINE_STRING, POLYGON, MULTI_POLYGON -> getSimpleGeometryEnvelope(input, length); + case GEOMETRY_COLLECTION -> getGeometryCollectionOverallEnvelope(input); + case ENVELOPE -> readEnvelope(input); + }; } private static Envelope getGeometryCollectionOverallEnvelope(BasicSliceInput input) diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloMetadata.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloMetadata.java index 15e91744961d..003f8a0c755c 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloMetadata.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloMetadata.java @@ -34,7 +34,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; @@ -389,12 +388,6 @@ public Optional> applyFilter(C return Optional.of(new ConstraintApplicationResult<>(handle, constraint.getSummary(), false)); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle handle) - { - return new ConnectorTableProperties(); - } - private void checkNoRollback() { checkState(rollbackAction.get() == null, "Cannot begin a new write while in an existing one"); diff --git a/plugin/trino-atop/src/main/java/io/trino/plugin/atop/AtopMetadata.java b/plugin/trino-atop/src/main/java/io/trino/plugin/atop/AtopMetadata.java index 9ea2908f7b05..1a91bcc162c1 100644 --- a/plugin/trino-atop/src/main/java/io/trino/plugin/atop/AtopMetadata.java +++ b/plugin/trino-atop/src/main/java/io/trino/plugin/atop/AtopMetadata.java @@ -23,7 +23,6 @@ 
import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -149,12 +148,6 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable throw new ColumnNotFoundException(tableName, columnName); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint) { diff --git a/plugin/trino-atop/src/test/java/io/trino/plugin/atop/TestAtopPlugin.java b/plugin/trino-atop/src/test/java/io/trino/plugin/atop/TestAtopPlugin.java index c8845d034e37..67a3f52011d5 100644 --- a/plugin/trino-atop/src/test/java/io/trino/plugin/atop/TestAtopPlugin.java +++ b/plugin/trino-atop/src/test/java/io/trino/plugin/atop/TestAtopPlugin.java @@ -13,19 +13,30 @@ */ package io.trino.plugin.atop; +import com.google.common.collect.ImmutableMap; +import io.trino.spi.Plugin; +import io.trino.spi.connector.ConnectorFactory; +import io.trino.testing.TestingConnectorContext; import org.testng.annotations.Test; +import java.nio.file.Files; +import java.nio.file.Path; + import static com.google.common.collect.Iterables.getOnlyElement; import static io.airlift.testing.Assertions.assertInstanceOf; public class TestAtopPlugin { @Test - public void testGetConnectorFactory() + public void testCreateConnector() + throws Exception { - AtopPlugin plugin = new AtopPlugin(); + Plugin plugin = new AtopPlugin(); assertInstanceOf(getOnlyElement(plugin.getConnectorFactories()), AtopConnectorFactory.class); - } - // TODO test factory + ConnectorFactory factory = getOnlyElement(plugin.getConnectorFactories()); + + Path atopExecutable = Files.createTempFile(null, null); + factory.create("test", ImmutableMap.of("atop.executable-path", atopExecutable.toString()), new TestingConnectorContext()).shutdown(); + } } diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java index 677edecb4cea..7f7cd97ccc3e 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java @@ -33,7 +33,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorTableSchema; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; @@ -625,12 +624,6 @@ public Optional applyTableScanRedirect(Conne return jdbcClient.getTableScanRedirection(session, tableHandle); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public ConnectorTableSchema getTableSchema(ConnectorSession session, ConnectorTableHandle table) { diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java 
b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java index ef1dfef9faf3..d009b88f3505 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java @@ -40,6 +40,7 @@ import io.airlift.units.Duration; import io.trino.collect.cache.EvictableCacheBuilder; import io.trino.spi.TrinoException; +import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.TableNotFoundException; import java.util.Collections; @@ -348,19 +349,21 @@ private static String fullTableName(TableId remoteTableId) public List getColumns(BigQueryTableHandle tableHandle) { if (tableHandle.getProjectedColumns().isPresent()) { - return tableHandle.getProjectedColumns().get().stream() - .map(column -> (BigQueryColumnHandle) column) - .collect(toImmutableList()); + return tableHandle.getProjectedColumns().get(); } checkArgument(tableHandle.isNamedRelation(), "Cannot get columns for %s", tableHandle); TableInfo tableInfo = getTable(tableHandle.asPlainTable().getRemoteTableName().toTableId()) .orElseThrow(() -> new TableNotFoundException(tableHandle.asPlainTable().getSchemaTableName())); + return buildColumnHandles(tableInfo); + } + + public static List buildColumnHandles(TableInfo tableInfo) + { Schema schema = tableInfo.getDefinition().getSchema(); if (schema == null) { - throw new TableNotFoundException( - tableHandle.asPlainTable().getSchemaTableName(), - format("Table '%s' has no schema", tableHandle.asPlainTable().getSchemaTableName())); + SchemaTableName schemaTableName = new SchemaTableName(tableInfo.getTableId().getDataset(), tableInfo.getTableId().getTable()); + throw new TableNotFoundException(schemaTableName, format("Table '%s' has no schema", schemaTableName)); } return schema.getFields() .stream() diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java index 5708ad5e3632..6d9251c49751 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java @@ -33,6 +33,7 @@ import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.trino.plugin.bigquery.BigQueryClient.RemoteDatabaseObject; +import io.trino.plugin.bigquery.BigQueryTableHandle.BigQueryPartitionType; import io.trino.plugin.bigquery.ptf.Query.QueryHandle; import io.trino.spi.TrinoException; import io.trino.spi.connector.Assignment; @@ -47,7 +48,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorTableSchema; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.connector.Constraint; @@ -87,6 +87,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.plugin.bigquery.BigQueryClient.buildColumnHandles; import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_LISTING_DATASET_ERROR; import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_UNSUPPORTED_OPERATION; import static 
io.trino.plugin.bigquery.BigQueryPseudoColumn.PARTITION_DATE; @@ -230,12 +231,20 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable return null; } + ImmutableList.Builder columns = ImmutableList.builder(); + columns.addAll(buildColumnHandles(tableInfo.get())); + Optional partitionType = getPartitionType(tableInfo.get().getDefinition()); + if (partitionType.isPresent() && partitionType.get() == INGESTION) { + columns.add(PARTITION_DATE.getColumnHandle()); + columns.add(PARTITION_TIME.getColumnHandle()); + } return new BigQueryTableHandle(new BigQueryNamedRelationHandle( schemaTableName, new RemoteTableName(tableInfo.get().getTableId()), tableInfo.get().getDefinition().getType().toString(), - getPartitionType(tableInfo.get().getDefinition()), - Optional.ofNullable(tableInfo.get().getDescription()))); + partitionType, + Optional.ofNullable(tableInfo.get().getDescription()))) + .withProjectedColumns(columns.build()); } private ConnectorTableHandle getTableHandleIgnoringConflicts(ConnectorSession session, SchemaTableName schemaTableName) @@ -269,17 +278,10 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect log.debug("getTableMetadata(session=%s, tableHandle=%s)", session, tableHandle); BigQueryTableHandle handle = ((BigQueryTableHandle) tableHandle); - ImmutableList.Builder columnMetadata = ImmutableList.builder(); - for (BigQueryColumnHandle column : client.getColumns(handle)) { - columnMetadata.add(column.getColumnMetadata()); - } - if (handle.isNamedRelation()) { - if (handle.asPlainTable().getPartitionType().isPresent() && handle.asPlainTable().getPartitionType().get() == INGESTION) { - columnMetadata.add(PARTITION_DATE.getColumnMetadata()); - columnMetadata.add(PARTITION_TIME.getColumnMetadata()); - } - } - return new ConnectorTableMetadata(getSchemaTableName(handle), columnMetadata.build(), ImmutableMap.of(), getTableComment(handle)); + List columns = client.getColumns(handle).stream() + .map(BigQueryColumnHandle::getColumnMetadata) + .collect(toImmutableList()); + return new ConnectorTableMetadata(getSchemaTableName(handle), columns, ImmutableMap.of(), getTableComment(handle)); } @Override @@ -326,7 +328,7 @@ public Map getColumnHandles(ConnectorSession session, Conn BigQueryTableHandle table = (BigQueryTableHandle) tableHandle; if (table.getProjectedColumns().isPresent()) { return table.getProjectedColumns().get().stream() - .collect(toImmutableMap(columnHandle -> ((BigQueryColumnHandle) columnHandle).getName(), identity())); + .collect(toImmutableMap(BigQueryColumnHandle::getName, identity())); } checkArgument(table.isNamedRelation(), "Cannot get columns for %s", tableHandle); @@ -371,13 +373,6 @@ public Map> listTableColumns(ConnectorSess return columns.buildOrThrow(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - log.debug("getTableProperties(session=%s, prefix=%s)", session, table); - return new ConnectorTableProperties(); - } - @Override public void createSchema(ConnectorSession session, String schemaName, Map properties, TrinoPrincipal owner) { @@ -567,11 +562,12 @@ public Optional> applyProjecti return Optional.empty(); } - ImmutableList.Builder projectedColumns = ImmutableList.builder(); + ImmutableList.Builder projectedColumns = ImmutableList.builder(); ImmutableList.Builder assignmentList = ImmutableList.builder(); assignments.forEach((name, column) -> { - projectedColumns.add(column); - assignmentList.add(new Assignment(name, 
column, ((BigQueryColumnHandle) column).getTrinoType())); + BigQueryColumnHandle columnHandle = (BigQueryColumnHandle) column; + projectedColumns.add(columnHandle); + assignmentList.add(new Assignment(name, column, columnHandle.getTrinoType())); }); bigQueryTableHandle = bigQueryTableHandle.withProjectedColumns(projectedColumns.build()); diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplit.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplit.java index 6ba9daddae17..7b53e6971ce2 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplit.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplit.java @@ -17,7 +17,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import io.trino.spi.HostAddress; -import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSplit; import org.openjdk.jol.info.ClassLayout; @@ -43,7 +42,7 @@ public class BigQuerySplit private final Mode mode; private final String streamName; private final String avroSchema; - private final List columns; + private final List columns; private final long emptyRowsToGenerate; private final Optional filter; private final OptionalInt dataSize; @@ -54,7 +53,7 @@ public BigQuerySplit( @JsonProperty("mode") Mode mode, @JsonProperty("streamName") String streamName, @JsonProperty("avroSchema") String avroSchema, - @JsonProperty("columns") List columns, + @JsonProperty("columns") List columns, @JsonProperty("emptyRowsToGenerate") long emptyRowsToGenerate, @JsonProperty("filter") Optional filter, @JsonProperty("dataSize") OptionalInt dataSize) @@ -68,12 +67,12 @@ public BigQuerySplit( this.dataSize = requireNonNull(dataSize, "dataSize is null"); } - static BigQuerySplit forStream(String streamName, String avroSchema, List columns, OptionalInt dataSize) + static BigQuerySplit forStream(String streamName, String avroSchema, List columns, OptionalInt dataSize) { return new BigQuerySplit(STORAGE, streamName, avroSchema, columns, NO_ROWS_TO_GENERATE, Optional.empty(), dataSize); } - static BigQuerySplit forViewStream(List columns, Optional filter) + static BigQuerySplit forViewStream(List columns, Optional filter) { return new BigQuerySplit(QUERY, "", "", columns, NO_ROWS_TO_GENERATE, filter, OptionalInt.empty()); } @@ -102,7 +101,7 @@ public String getAvroSchema() } @JsonProperty - public List getColumns() + public List getColumns() { return columns; } @@ -149,7 +148,7 @@ public long getRetainedSizeInBytes() return INSTANCE_SIZE + estimatedSizeOf(streamName) + estimatedSizeOf(avroSchema) - + estimatedSizeOf(columns, column -> ((BigQueryColumnHandle) column).getRetainedSizeInBytes()); + + estimatedSizeOf(columns, BigQueryColumnHandle::getRetainedSizeInBytes); } @Override diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java index dbd10de79003..9b78d0727cf1 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java @@ -47,6 +47,7 @@ import static com.google.cloud.bigquery.TableDefinition.Type.MATERIALIZED_VIEW; import static com.google.cloud.bigquery.TableDefinition.Type.TABLE; import static com.google.cloud.bigquery.TableDefinition.Type.VIEW; +import static 
com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_FAILED_TO_EXECUTE_QUERY; import static io.trino.plugin.bigquery.BigQuerySessionProperties.createDisposition; @@ -103,7 +104,7 @@ public ConnectorSplitSource getSplits( Optional filter = BigQueryFilterQueryBuilder.buildFilter(tableConstraint); if (!bigQueryTableHandle.isNamedRelation()) { - List columns = bigQueryTableHandle.getProjectedColumns().orElse(ImmutableList.of()); + List columns = bigQueryTableHandle.getProjectedColumns().orElse(ImmutableList.of()); return new FixedSplitSource(ImmutableList.of(BigQuerySplit.forViewStream(columns, filter))); } @@ -114,17 +115,19 @@ public ConnectorSplitSource getSplits( return new FixedSplitSource(splits); } - private static boolean emptyProjectionIsRequired(Optional> projectedColumns) + private static boolean emptyProjectionIsRequired(Optional> projectedColumns) { return projectedColumns.isPresent() && projectedColumns.get().isEmpty(); } - private List readFromBigQuery(ConnectorSession session, TableDefinition.Type type, TableId remoteTableId, Optional> projectedColumns, int actualParallelism, Optional filter) + private List readFromBigQuery(ConnectorSession session, TableDefinition.Type type, TableId remoteTableId, Optional> projectedColumns, int actualParallelism, Optional filter) { + checkArgument(projectedColumns.isPresent() && projectedColumns.get().size() > 0, "Projected column is empty"); + log.debug("readFromBigQuery(tableId=%s, projectedColumns=%s, actualParallelism=%s, filter=[%s])", remoteTableId, projectedColumns, actualParallelism, filter); - List columns = projectedColumns.orElse(ImmutableList.of()); + List columns = projectedColumns.get(); List projectedColumnsNames = columns.stream() - .map(column -> ((BigQueryColumnHandle) column).getName()) + .map(BigQueryColumnHandle::getName) .collect(toImmutableList()); if (isWildcardTable(type, remoteTableId.getTable())) { diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTableHandle.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTableHandle.java index cdd3f1988780..6618f6b4d14b 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTableHandle.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTableHandle.java @@ -39,13 +39,13 @@ public class BigQueryTableHandle { private final BigQueryRelationHandle relationHandle; private final TupleDomain constraint; - private final Optional> projectedColumns; + private final Optional> projectedColumns; @JsonCreator public BigQueryTableHandle( @JsonProperty("relationHandle") BigQueryRelationHandle relationHandle, @JsonProperty("constraint") TupleDomain constraint, - @JsonProperty("projectedColumns") Optional> projectedColumns) + @JsonProperty("projectedColumns") Optional> projectedColumns) { this.relationHandle = requireNonNull(relationHandle, "relationHandle is null"); this.constraint = requireNonNull(constraint, "constraint is null"); @@ -79,7 +79,7 @@ public TupleDomain getConstraint() } @JsonProperty - public Optional> getProjectedColumns() + public Optional> getProjectedColumns() { return projectedColumns; } @@ -145,7 +145,7 @@ BigQueryTableHandle withConstraint(TupleDomain newConstraint) return new BigQueryTableHandle(relationHandle, newConstraint, projectedColumns); } - public BigQueryTableHandle withProjectedColumns(List newProjectedColumns) + 
public BigQueryTableHandle withProjectedColumns(List newProjectedColumns) { return new BigQueryTableHandle(relationHandle, constraint, Optional.of(newProjectedColumns)); } diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java index 6b9f8c03aaad..bfd8ef34e5aa 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java @@ -24,7 +24,6 @@ import io.trino.plugin.bigquery.BigQueryColumnHandle; import io.trino.plugin.bigquery.BigQueryQueryRelationHandle; import io.trino.plugin.bigquery.BigQueryTableHandle; -import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTransactionHandle; @@ -110,13 +109,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact } columnsBuilder.add(toColumnHandle(field)); } - List columns = columnsBuilder.build(); Descriptor returnedType = new Descriptor(columnsBuilder.build().stream() .map(column -> new Field(column.getName(), Optional.of(column.getTrinoType()))) .collect(toList())); - QueryHandle handle = new QueryHandle(tableHandle.withProjectedColumns(columns.stream().map(column -> (ColumnHandle) column).collect(toList()))); + QueryHandle handle = new QueryHandle(tableHandle.withProjectedColumns(columnsBuilder.build())); return TableFunctionAnalysis.builder() .returnedType(returnedType) diff --git a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java index 2c63744aab9a..6d611dac602a 100644 --- a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java +++ b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java @@ -32,7 +32,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.RetryMode; import io.trino.spi.connector.RowChangeParadigm; @@ -387,12 +386,6 @@ public Optional getView(ConnectorSession session, Schem return Optional.ofNullable(views.get(viewName)); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - private void checkSchemaExists(String schemaName) { if (!schemas.contains(schemaName)) { diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java index 9b40fd8fbced..913040043100 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java @@ -30,7 +30,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import 
io.trino.spi.connector.NotFoundException; @@ -251,12 +250,6 @@ public Optional> applyFilter(C false)); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public void createTable(ConnectorSession session, ConnectorTableMetadata tableMetadata, boolean ignoreExisting) { diff --git a/plugin/trino-delta-lake/pom.xml b/plugin/trino-delta-lake/pom.xml index 84646b248e97..674f1f2b60eb 100644 --- a/plugin/trino-delta-lake/pom.xml +++ b/plugin/trino-delta-lake/pom.xml @@ -335,7 +335,7 @@ com.azure azure-core - 1.25.0 + 1.34.0 test diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java index 081651d19be2..f322356b383a 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java @@ -1255,7 +1255,7 @@ private static void appendAddFileEntries(TransactionLogWriter transactionLogWrit transactionLogWriter.appendAddFileEntry( new AddFileEntry( - toUriFormat(info.getPath()), // Databricks and OSS Delta Lake both expect path to be url-encoded, even though the procotol specification doesn't mention that + toUriFormat(info.getPath()), // Paths are RFC 2396 URI encoded https://github.com/delta-io/delta/blob/master/PROTOCOL.md#add-file-and-remove-file partitionValues, info.getSize(), info.getCreationTime(), @@ -1683,12 +1683,9 @@ public Optional getTableHandleForExecute( throw new IllegalArgumentException("Unknown procedure '" + procedureName + "'"); } - switch (procedureId) { - case OPTIMIZE: - return getTableHandleForOptimize(tableHandle, executeProperties, retryMode); - } - - throw new IllegalArgumentException("Unknown procedure: " + procedureId); + return switch (procedureId) { + case OPTIMIZE -> getTableHandleForOptimize(tableHandle, executeProperties, retryMode); + }; } private Optional getTableHandleForOptimize(DeltaLakeTableHandle tableHandle, Map executeProperties, RetryMode retryMode) diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java index 7ceac4d7909f..ea296a05a8f2 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java @@ -36,7 +36,7 @@ import javax.inject.Inject; -import java.net.URLDecoder; +import java.net.URI; import java.time.Instant; import java.util.List; import java.util.Map; @@ -58,7 +58,6 @@ import static io.trino.plugin.deltalake.DeltaLakeSessionProperties.getMaxSplitSize; import static io.trino.plugin.deltalake.transactionlog.DeltaLakeSchemaSupport.extractSchema; import static io.trino.plugin.deltalake.transactionlog.TransactionLogParser.deserializePartitionValue; -import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; @@ -290,8 +289,9 @@ private List splitsForFile( private static String buildSplitPath(String tableLocation, AddFileEntry addAction) { - // paths are relative to the table location and URL encoded - String path = URLDecoder.decode(addAction.getPath(), UTF_8); + // 
paths are relative to the table location and are RFC 2396 URIs + // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#add-file-and-remove-file + String path = URI.create(addAction.getPath()).getPath(); if (tableLocation.endsWith("/")) { return tableLocation + path; } diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java index 94290ae288cd..a7175dce7cf7 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java @@ -282,6 +282,19 @@ public void testCreatePartitionedTable() assertUpdate("DROP TABLE " + tableName); } + @Test + public void testPathUriDecoding() + { + String tableName = "test_uri_table_" + randomNameSuffix(); + registerTableFromResources(tableName, "databricks/uri", getQueryRunner()); + + assertQuery("SELECT * FROM " + tableName, "VALUES ('a=equal', 1), ('a:colon', 2), ('a+plus', 3)"); + String firstFilePath = (String) computeScalar("SELECT \"$path\" FROM " + tableName + " WHERE y = 1"); + assertQuery("SELECT * FROM " + tableName + " WHERE \"$path\" = '" + firstFilePath + "'", "VALUES ('a=equal', 1)"); + + assertUpdate("DROP TABLE " + tableName); + } + @Test public void testCreateTablePartitionValidation() { diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsConnectorSmokeTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsConnectorSmokeTest.java index 2fe5b4aef91e..82358243fe27 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsConnectorSmokeTest.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsConnectorSmokeTest.java @@ -50,6 +50,7 @@ import static java.util.Objects.requireNonNull; import static java.util.regex.Matcher.quoteReplacement; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestDeltaLakeAdlsConnectorSmokeTest extends BaseDeltaLakeConnectorSmokeTest @@ -113,6 +114,14 @@ public void removeTestData() assertThat(azureContainerClient.listBlobsByHierarchy(bucketName + "/").stream()).hasSize(0); } + @Override + public void testPathUriDecoding() + { + // TODO https://github.com/trinodb/trino/issues/15376 AzureBlobFileSystem doesn't expect URI as the path argument + assertThatThrownBy(super::testPathUriDecoding) + .hasStackTraceContaining("The specified path does not exist"); + } + @Override protected void registerTableFromResources(String table, String resourcePath, QueryRunner queryRunner) { diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakePageSink.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakePageSink.java index 3f8a5d84d917..bc7f22229b23 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakePageSink.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakePageSink.java @@ -138,23 +138,12 @@ public void testPageSinkStats() private void writeToBlock(BlockBuilder blockBuilder, LineItemColumn column, LineItem lineItem) { switch (column.getType().getBase()) { - case IDENTIFIER: - BIGINT.writeLong(blockBuilder, column.getIdentifier(lineItem)); - break; 
- case INTEGER: - INTEGER.writeLong(blockBuilder, column.getInteger(lineItem)); - break; - case DATE: - DATE.writeLong(blockBuilder, column.getDate(lineItem)); - break; - case DOUBLE: - DOUBLE.writeDouble(blockBuilder, column.getDouble(lineItem)); - break; - case VARCHAR: - createUnboundedVarcharType().writeSlice(blockBuilder, Slices.utf8Slice(column.getString(lineItem))); - break; - default: - throw new IllegalArgumentException("Unsupported type " + column.getType()); + case IDENTIFIER -> BIGINT.writeLong(blockBuilder, column.getIdentifier(lineItem)); + case INTEGER -> INTEGER.writeLong(blockBuilder, column.getInteger(lineItem)); + case DATE -> DATE.writeLong(blockBuilder, column.getDate(lineItem)); + case DOUBLE -> DOUBLE.writeDouble(blockBuilder, column.getDouble(lineItem)); + case VARCHAR -> createUnboundedVarcharType().writeSlice(blockBuilder, Slices.utf8Slice(column.getString(lineItem))); + default -> throw new IllegalArgumentException("Unsupported type " + column.getType()); } } @@ -203,19 +192,12 @@ private static List getColumnHandles() private static Type getTrinoType(TpchColumnType type) { - switch (type.getBase()) { - case IDENTIFIER: - return BIGINT; - case INTEGER: - return INTEGER; - case DATE: - return DATE; - case DOUBLE: - return DOUBLE; - case VARCHAR: - return createUnboundedVarcharType(); - default: - throw new UnsupportedOperationException(); - } + return switch (type.getBase()) { + case IDENTIFIER -> BIGINT; + case INTEGER -> INTEGER; + case DATE -> DATE; + case DOUBLE -> DOUBLE; + case VARCHAR -> createUnboundedVarcharType(); + }; } } diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/README.md b/plugin/trino-delta-lake/src/test/resources/databricks/uri/README.md new file mode 100644 index 000000000000..52a0fe3e758f --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/README.md @@ -0,0 +1,12 @@ +Data generated using OSS DELTA 2.0.0: + +```sql +CREATE TABLE default.uri_test (part string, y long) +USING delta +PARTITIONED BY (part) +LOCATION '/home/username/trino/plugin/trino-delta-lake/src/test/resources/databricks/uri'; + +INSERT INTO default.uri_test VALUES ('a=equal', 1); +INSERT INTO default.uri_test VALUES ('a:colon', 2); +INSERT INTO default.uri_test VALUES ('a+plus', 3); +``` diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000000.json b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000000.json new file mode 100644 index 000000000000..6ce3c0ab0934 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000000.json @@ -0,0 +1,3 @@ +{"protocol":{"minReaderVersion":1,"minWriterVersion":2}} +{"metaData":{"id":"ae596d2a-868c-4480-9dba-ecb1366eef15","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"part\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"y\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":["part"],"configuration":{},"createdTime":1670672197310}} +{"commitInfo":{"timestamp":1670672197479,"operation":"CREATE TABLE","operationParameters":{"isManaged":"false","description":null,"partitionBy":"[\"part\"]","properties":"{}"},"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{},"engineInfo":"Apache-Spark/3.2.2 Delta-Lake/2.0.0","txnId":"ec6fd978-2b74-4482-a908-02528f96cff5"}} diff --git 
a/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000001.json b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000001.json new file mode 100644 index 000000000000..2a4c5a4b763f --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000001.json @@ -0,0 +1,2 @@ +{"add":{"path":"part=a%253Dequal/part-00000-d6a0dd7d-8416-436d-bb00-245452904a81.c000.snappy.parquet","partitionValues":{"part":"a=equal"},"size":475,"modificationTime":1670672201990,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"y\":1},\"maxValues\":{\"y\":1},\"nullCount\":{\"y\":0}}"}} +{"commitInfo":{"timestamp":1670672202024,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":0,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"475"},"engineInfo":"Apache-Spark/3.2.2 Delta-Lake/2.0.0","txnId":"bd99e504-0c28-4661-af9c-854016b09a7e"}} diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000002.json b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000002.json new file mode 100644 index 000000000000..f939248ad264 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000002.json @@ -0,0 +1,2 @@ +{"add":{"path":"part=a%253Acolon/part-00000-ab613d48-052b-4de3-916a-ee0d89139446.c000.snappy.parquet","partitionValues":{"part":"a:colon"},"size":475,"modificationTime":1670672202850,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"y\":2},\"maxValues\":{\"y\":2},\"nullCount\":{\"y\":0}}"}} +{"commitInfo":{"timestamp":1670672202858,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":1,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"475"},"engineInfo":"Apache-Spark/3.2.2 Delta-Lake/2.0.0","txnId":"fd9e716f-3072-4ecf-ac2b-e52d3a66e3b8"}} diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000003.json b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000003.json new file mode 100644 index 000000000000..e0b9d17bd36e --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000003.json @@ -0,0 +1,2 @@ +{"add":{"path":"part=a+plus/part-00000-4aecc0d9-dfbf-4180-91f1-4fd762dbc279.c000.snappy.parquet","partitionValues":{"part":"a+plus"},"size":475,"modificationTime":1670672203670,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"y\":3},\"maxValues\":{\"y\":3},\"nullCount\":{\"y\":0}}"}} +{"commitInfo":{"timestamp":1670672203680,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":2,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"475"},"engineInfo":"Apache-Spark/3.2.2 Delta-Lake/2.0.0","txnId":"3adbdcc8-1a1d-4457-b168-0292d9d34dac"}} diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Acolon/part-00000-ab613d48-052b-4de3-916a-ee0d89139446.c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Acolon/part-00000-ab613d48-052b-4de3-916a-ee0d89139446.c000.snappy.parquet new file mode 100644 index 000000000000..88a74a3a51a2 Binary files /dev/null and 
b/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Acolon/part-00000-ab613d48-052b-4de3-916a-ee0d89139446.c000.snappy.parquet differ diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Dequal/part-00000-d6a0dd7d-8416-436d-bb00-245452904a81.c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Dequal/part-00000-d6a0dd7d-8416-436d-bb00-245452904a81.c000.snappy.parquet new file mode 100644 index 000000000000..3c3dab40dd86 Binary files /dev/null and b/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Dequal/part-00000-d6a0dd7d-8416-436d-bb00-245452904a81.c000.snappy.parquet differ diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a+plus/part-00000-4aecc0d9-dfbf-4180-91f1-4fd762dbc279.c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a+plus/part-00000-4aecc0d9-dfbf-4180-91f1-4fd762dbc279.c000.snappy.parquet new file mode 100644 index 000000000000..2291dfae5a98 Binary files /dev/null and b/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a+plus/part-00000-4aecc0d9-dfbf-4180-91f1-4fd762dbc279.c000.snappy.parquet differ diff --git a/plugin/trino-example-http/src/main/java/io/trino/plugin/example/ExampleMetadata.java b/plugin/trino-example-http/src/main/java/io/trino/plugin/example/ExampleMetadata.java index f1d705957d94..b964390904d2 100644 --- a/plugin/trino-example-http/src/main/java/io/trino/plugin/example/ExampleMetadata.java +++ b/plugin/trino-example-http/src/main/java/io/trino/plugin/example/ExampleMetadata.java @@ -22,7 +22,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.connector.TableNotFoundException; @@ -155,10 +154,4 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable { return ((ExampleColumnHandle) columnHandle).getColumnMetadata(); } - - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } } diff --git a/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java b/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java index efac03397d77..763070d3a053 100644 --- a/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java +++ b/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java @@ -22,7 +22,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.connector.TableNotFoundException; @@ -146,10 +145,4 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable { return ((SheetsColumnHandle) columnHandle).getColumnMetadata(); } - - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } } diff --git 
a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java index 9230d9e26013..45afb7bb299a 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java @@ -32,10 +32,14 @@ import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.lenientFormat; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.hive.HiveStorageFormat.AVRO; import static io.trino.plugin.hive.HiveStorageFormat.ORC; import static io.trino.plugin.hive.HiveTimestampPrecision.DEFAULT_PRECISION; +import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_FIELD_PREFIX; +import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_TAG_NAME; +import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_TAG_TYPE; import static io.trino.plugin.hive.util.HiveTypeTranslator.fromPrimitiveType; import static io.trino.plugin.hive.util.HiveTypeTranslator.toTypeInfo; import static io.trino.plugin.hive.util.HiveTypeTranslator.toTypeSignature; @@ -219,13 +223,32 @@ public Optional getHiveTypeForDereferences(List dereferences) { TypeInfo typeInfo = getTypeInfo(); for (int fieldIndex : dereferences) { - checkArgument(typeInfo instanceof StructTypeInfo, "typeInfo should be struct type", typeInfo); - StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; - try { - typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + if (typeInfo instanceof StructTypeInfo structTypeInfo) { + try { + typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + } + catch (RuntimeException e) { + // return empty when failed to dereference, this could happen when partition and table schema mismatch + return Optional.empty(); + } } - catch (RuntimeException e) { - return Optional.empty(); + else if (typeInfo instanceof UnionTypeInfo unionTypeInfo) { + try { + if (fieldIndex == 0) { + // union's tag field, defined in {@link io.trino.plugin.hive.util.HiveTypeTranslator#toTypeSignature} + return Optional.of(HiveType.toHiveType(UNION_FIELD_TAG_TYPE)); + } + else { + typeInfo = unionTypeInfo.getAllUnionObjectTypeInfos().get(fieldIndex - 1); + } + } + catch (RuntimeException e) { + // return empty when failed to dereference, this could happen when partition and table schema mismatch + return Optional.empty(); + } + } + else { + throw new IllegalArgumentException(lenientFormat("typeInfo: %s should be struct or union type", typeInfo)); } } return Optional.of(toHiveType(typeInfo)); @@ -235,16 +258,35 @@ public List getHiveDereferenceNames(List dereferences) { ImmutableList.Builder dereferenceNames = ImmutableList.builder(); TypeInfo typeInfo = getTypeInfo(); - for (int fieldIndex : dereferences) { - checkArgument(typeInfo instanceof StructTypeInfo, "typeInfo should be struct type", typeInfo); - StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; - + for (int i = 0; i < dereferences.size(); i++) { + int fieldIndex = dereferences.get(i); checkArgument(fieldIndex >= 0, "fieldIndex cannot be negative"); - checkArgument(fieldIndex < structTypeInfo.getAllStructFieldNames().size(), - "fieldIndex should be less than the number of fields in the struct"); - String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex); - dereferenceNames.add(fieldName); - typeInfo = 
structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + + if (typeInfo instanceof StructTypeInfo structTypeInfo) { + checkArgument(fieldIndex < structTypeInfo.getAllStructFieldNames().size(), + "fieldIndex should be less than the number of fields in the struct"); + + String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex); + dereferenceNames.add(fieldName); + typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + } + else if (typeInfo instanceof UnionTypeInfo unionTypeInfo) { + checkArgument((fieldIndex - 1) < unionTypeInfo.getAllUnionObjectTypeInfos().size(), + "fieldIndex should be less than the number of fields in the union plus tag field"); + + if (fieldIndex == 0) { + checkArgument(i == (dereferences.size() - 1), "Union's tag field should not have more subfields"); + dereferenceNames.add(UNION_FIELD_TAG_NAME); + break; + } + else { + typeInfo = unionTypeInfo.getAllUnionObjectTypeInfos().get(fieldIndex - 1); + dereferenceNames.add(UNION_FIELD_FIELD_PREFIX + (fieldIndex - 1)); + } + } + else { + throw new IllegalArgumentException(lenientFormat("typeInfo: %s should be struct or union type", typeInfo)); + } } return dereferenceNames.build(); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java index 22d1e5f4ce2d..4d165511130a 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java @@ -91,6 +91,10 @@ public final class HiveTypeTranslator { private HiveTypeTranslator() {} + public static final String UNION_FIELD_TAG_NAME = "tag"; + public static final String UNION_FIELD_FIELD_PREFIX = "field"; + public static final Type UNION_FIELD_TAG_TYPE = TINYINT; + public static TypeInfo toTypeInfo(Type type) { requireNonNull(type, "type is null"); @@ -213,10 +217,10 @@ public static TypeSignature toTypeSignature(TypeInfo typeInfo, HiveTimestampPrec UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; List unionObjectTypes = unionTypeInfo.getAllUnionObjectTypeInfos(); ImmutableList.Builder typeSignatures = ImmutableList.builder(); - typeSignatures.add(namedField("tag", TINYINT.getTypeSignature())); + typeSignatures.add(namedField(UNION_FIELD_TAG_NAME, UNION_FIELD_TAG_TYPE.getTypeSignature())); for (int i = 0; i < unionObjectTypes.size(); i++) { TypeInfo unionObjectType = unionObjectTypes.get(i); - typeSignatures.add(namedField("field" + i, toTypeSignature(unionObjectType, timestampPrecision))); + typeSignatures.add(namedField(UNION_FIELD_FIELD_PREFIX + i, toTypeSignature(unionObjectType, timestampPrecision))); } return rowType(typeSignatures.build()); } diff --git a/plugin/trino-iceberg/pom.xml b/plugin/trino-iceberg/pom.xml index 825187bc5c81..48db13121363 100644 --- a/plugin/trino-iceberg/pom.xml +++ b/plugin/trino-iceberg/pom.xml @@ -406,6 +406,8 @@ about.html iceberg-build.properties mozilla/public-suffix-list.txt + + google/protobuf/.*\.proto$ diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java index 3add0c0ceef6..32f721035158 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java @@ -91,6 +91,8 @@ public IcebergAvroPageSource( 
.collect(toImmutableMap(Types.NestedField::name, Types.NestedField::type)); pageBuilder = new PageBuilder(columnTypes); recordIterator = avroReader.iterator(); + // TODO: Remove when NPE check has been released: https://github.com/trinodb/trino/issues/15372 + isFinished(); } private boolean isIndexColumn(int column) diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java index e12f937af007..8d131369a9d2 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java @@ -392,30 +392,33 @@ public IcebergTableHandle getTableHandle( private static long getSnapshotIdFromVersion(Table table, ConnectorTableVersion version) { io.trino.spi.type.Type versionType = version.getVersionType(); - switch (version.getPointerType()) { - case TEMPORAL: - long epochMillis; - if (versionType instanceof TimestampWithTimeZoneType) { - epochMillis = ((TimestampWithTimeZoneType) versionType).isShort() - ? unpackMillisUtc((long) version.getVersion()) - : ((LongTimestampWithTimeZone) version.getVersion()).getEpochMillis(); - } - else { - throw new TrinoException(NOT_SUPPORTED, "Unsupported type for temporal table version: " + versionType.getDisplayName()); - } - return getSnapshotIdAsOfTime(table, epochMillis); + return switch (version.getPointerType()) { + case TEMPORAL -> getTemporalSnapshotIdFromVersion(table, version, versionType); + case TARGET_ID -> getTargetSnapshotIdFromVersion(table, version, versionType); + }; + } - case TARGET_ID: - if (versionType != BIGINT) { - throw new TrinoException(NOT_SUPPORTED, "Unsupported type for table version: " + versionType.getDisplayName()); - } - long snapshotId = (long) version.getVersion(); - if (table.snapshot(snapshotId) == null) { - throw new TrinoException(INVALID_ARGUMENTS, "Iceberg snapshot ID does not exists: " + snapshotId); - } - return snapshotId; + private static long getTargetSnapshotIdFromVersion(Table table, ConnectorTableVersion version, io.trino.spi.type.Type versionType) + { + if (versionType != BIGINT) { + throw new TrinoException(NOT_SUPPORTED, "Unsupported type for table version: " + versionType.getDisplayName()); + } + long snapshotId = (long) version.getVersion(); + if (table.snapshot(snapshotId) == null) { + throw new TrinoException(INVALID_ARGUMENTS, "Iceberg snapshot ID does not exists: " + snapshotId); + } + return snapshotId; + } + + private static long getTemporalSnapshotIdFromVersion(Table table, ConnectorTableVersion version, io.trino.spi.type.Type versionType) + { + if (versionType instanceof TimestampWithTimeZoneType timeZonedVersionType) { + long epochMillis = timeZonedVersionType.isShort() + ? 
unpackMillisUtc((long) version.getVersion()) + : ((LongTimestampWithTimeZone) version.getVersion()).getEpochMillis(); + return getSnapshotIdAsOfTime(table, epochMillis); } - throw new TrinoException(NOT_SUPPORTED, "Version pointer type is not supported: " + version.getPointerType()); + throw new TrinoException(NOT_SUPPORTED, "Unsupported type for temporal table version: " + versionType.getDisplayName()); } @Override @@ -947,18 +950,12 @@ public Optional getTableHandleForExecute( throw new IllegalArgumentException("Unknown procedure '" + procedureName + "'"); } - switch (procedureId) { - case OPTIMIZE: - return getTableHandleForOptimize(tableHandle, executeProperties, retryMode); - case DROP_EXTENDED_STATS: - return getTableHandleForDropExtendedStats(session, tableHandle); - case EXPIRE_SNAPSHOTS: - return getTableHandleForExpireSnapshots(session, tableHandle, executeProperties); - case REMOVE_ORPHAN_FILES: - return getTableHandleForRemoveOrphanFiles(session, tableHandle, executeProperties); - } - - throw new IllegalArgumentException("Unknown procedure: " + procedureId); + return switch (procedureId) { + case OPTIMIZE -> getTableHandleForOptimize(tableHandle, executeProperties, retryMode); + case DROP_EXTENDED_STATS -> getTableHandleForDropExtendedStats(session, tableHandle); + case EXPIRE_SNAPSHOTS -> getTableHandleForExpireSnapshots(session, tableHandle, executeProperties); + case REMOVE_ORPHAN_FILES -> getTableHandleForRemoveOrphanFiles(session, tableHandle, executeProperties); + }; } private Optional getTableHandleForOptimize(IcebergTableHandle tableHandle, Map executeProperties, RetryMode retryMode) diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java index 28d48977547c..754e6df7fded 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java @@ -1354,7 +1354,7 @@ private static TupleDomain getParquetTupleDomain(Map { String baseType = columnHandle.getType().getTypeSignature().getBase(); // skip looking up predicates for complex types as Parquet only stores stats for primitives - if (!baseType.equals(StandardTypes.MAP) && !baseType.equals(StandardTypes.ARRAY) && !baseType.equals(StandardTypes.ROW)) { + if (columnHandle.isBaseColumn() && (!baseType.equals(StandardTypes.MAP) && !baseType.equals(StandardTypes.ARRAY) && !baseType.equals(StandardTypes.ROW))) { ColumnDescriptor descriptor = descriptorsByPath.get(ImmutableList.of(columnHandle.getName())); if (descriptor != null) { predicate.put(descriptor, domain); diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitSource.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitSource.java index 028278d43152..5302bf325e67 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitSource.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitSource.java @@ -194,6 +194,8 @@ public CompletableFuture getNextBatch(int maxSize) closer.register(fileScanTaskIterable); this.fileScanTaskIterator = fileScanTaskIterable.iterator(); closer.register(fileScanTaskIterator); + // TODO: Remove when NPE check has been released: https://github.com/trinodb/trino/issues/15372 + isFinished(); } TupleDomain dynamicFilterPredicate = dynamicFilter.getCurrentPredicate() diff --git 
a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java index baa84cde6423..ed5d84fe5cdd 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java @@ -133,8 +133,8 @@ public void testSelect() assertUpdate("CREATE TABLE test_select AS SELECT 1 col_name", 1); assertFileSystemAccesses("SELECT * FROM test_select", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -160,26 +160,26 @@ public void testSelectFromVersionedTable() assertFileSystemAccesses("SELECT * FROM " + tableName + " FOR VERSION AS OF " + v2SnapshotId, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .build()); assertFileSystemAccesses("SELECT * FROM " + tableName + " FOR VERSION AS OF " + v3SnapshotId, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .build()); assertFileSystemAccesses("SELECT * FROM " + tableName, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .build()); } @@ -203,26 +203,26 @@ public void testSelectFromVersionedTableWithSchemaEvolution() assertFileSystemAccesses("SELECT * FROM " + tableName + " FOR VERSION AS OF " + v2SnapshotId, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, 
INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .build()); assertFileSystemAccesses("SELECT * FROM " + tableName + " FOR VERSION AS OF " + v3SnapshotId, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .build()); assertFileSystemAccesses("SELECT * FROM " + tableName, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .build()); } @@ -232,8 +232,8 @@ public void testSelectWithFilter() assertUpdate("CREATE TABLE test_select_with_filter AS SELECT 1 col_name", 1); assertFileSystemAccesses("SELECT * FROM test_select_with_filter WHERE col_name = 1", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -248,8 +248,8 @@ public void testJoin() assertFileSystemAccesses("SELECT name, age FROM test_join_t1 JOIN test_join_t2 ON test_join_t2.id = test_join_t1.id", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 8) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 8) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 2) @@ -266,8 +266,8 @@ public void testJoinWithPartitionedTable() assertFileSystemAccesses("SELECT count(*) FROM test_join_partitioned_t1 t1 join test_join_partitioned_t2 t2 on t1.a = t2.foo", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 8) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 8) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 2) @@ -281,8 +281,8 @@ public void testExplainSelect() assertFileSystemAccesses("EXPLAIN SELECT * FROM test_explain", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, 
INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -296,8 +296,8 @@ public void testShowStatsForTable() assertFileSystemAccesses("SHOW STATS FOR test_show_stats", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -313,8 +313,8 @@ public void testShowStatsForPartitionedTable() assertFileSystemAccesses("SHOW STATS FOR test_show_stats_partitioned", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -328,8 +328,8 @@ public void testShowStatsForTableWithFilter() assertFileSystemAccesses("SHOW STATS FOR (SELECT * FROM test_show_stats_with_filter WHERE age >= 2)", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -345,8 +345,8 @@ public void testPredicateWithVarcharCastToDate() assertFileSystemAccesses("SELECT * FROM test_varchar_as_date_predicate", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -355,8 +355,8 @@ public void testPredicateWithVarcharCastToDate() // CAST to date and comparison assertFileSystemAccesses("SELECT * FROM test_varchar_as_date_predicate WHERE CAST(a AS date) >= DATE '2005-01-01'", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) // fewer than without filter - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) // fewer than without filter + 
.addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) // fewer than without filter .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -365,8 +365,8 @@ public void testPredicateWithVarcharCastToDate() // CAST to date and BETWEEN assertFileSystemAccesses("SELECT * FROM test_varchar_as_date_predicate WHERE CAST(a AS date) BETWEEN DATE '2005-01-01' AND DATE '2005-12-31'", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) // fewer than without filter - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) // fewer than without filter .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -375,8 +375,8 @@ public void testPredicateWithVarcharCastToDate() // conversion to date as a date function assertFileSystemAccesses("SELECT * FROM test_varchar_as_date_predicate WHERE date(a) >= DATE '2005-01-01'", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) // fewer than without filter - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) // fewer than without filter .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -404,8 +404,8 @@ public void testRemoveOrphanFiles() .add(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM)) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 4) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 4) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 6) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 6) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 5) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 5) .build()); assertUpdate("DROP TABLE " + tableName); diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java index 64f74e0053d0..074ccd439107 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java @@ -51,6 +51,7 @@ import org.apache.iceberg.data.Record; import org.apache.iceberg.data.parquet.GenericParquetWriter; import org.apache.iceberg.deletes.EqualityDeleteWriter; +import org.apache.iceberg.deletes.PositionDelete; import org.apache.iceberg.deletes.PositionDeleteWriter; import org.apache.iceberg.hadoop.HadoopOutputFile; import org.apache.iceberg.parquet.Parquet; @@ -171,8 +172,10 @@ public void testV2TableWithPositionDelete() .withSpec(PartitionSpec.unpartitioned()) .buildPositionWriter(); + PositionDelete positionDelete = PositionDelete.create(); + PositionDelete record = 
positionDelete.set(dataFilePath, 0, GenericRecord.create(icebergTable.schema())); try (Closeable ignored = writer) { - writer.delete(dataFilePath, 0, GenericRecord.create(icebergTable.schema())); + writer.write(record); } icebergTable.newRowDelta().addDeletes(writer.toDeleteFile()).commit(); @@ -521,7 +524,7 @@ private void writeEqualityDeleteToNationTable(Table icebergTable, Optional> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint) { diff --git a/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisMetadata.java b/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisMetadata.java index f4bdef956de0..6f29ad35d03d 100644 --- a/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisMetadata.java +++ b/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisMetadata.java @@ -23,7 +23,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.connector.TableNotFoundException; @@ -85,12 +84,6 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession connectorSession return getTableMetadata(((KinesisTableHandle) tableHandle).toSchemaTableName()); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public List listTables(ConnectorSession session, Optional schemaName) { diff --git a/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileMetadata.java b/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileMetadata.java index 7d757e3eb2c2..032fbeae7621 100644 --- a/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileMetadata.java +++ b/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileMetadata.java @@ -21,7 +21,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -133,12 +132,6 @@ private List listTables(ConnectorSession session, SchemaTablePr return ImmutableList.of(prefix.toSchemaTableName()); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle tableHandle) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint) { diff --git a/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java b/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java index ab19f25f75f5..0e167cac5494 100644 --- a/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java +++ b/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java @@ -32,7 +32,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import 
io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.RetryMode; @@ -421,12 +420,6 @@ private void updateRowsOnHosts(long tableId, Collection fragments) tables.put(tableId, new TableInfo(tableId, info.getSchemaName(), info.getTableName(), info.getColumns(), dataFragments, info.getComment())); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - public List getDataFragments(long tableId) { return ImmutableList.copyOf(tables.get(tableId).getDataFragments().values()); diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java index 0e4a574aeb87..6710fbfca8e1 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java @@ -43,7 +43,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.LimitApplicationResult; @@ -234,12 +233,6 @@ public Optional getInfo(ConnectorTableHandle table) return Optional.empty(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyLimit(ConnectorSession session, ConnectorTableHandle table, long limit) { diff --git a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusMetadata.java b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusMetadata.java index 46cafc5c5d8c..4e5e79f4703e 100644 --- a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusMetadata.java +++ b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusMetadata.java @@ -22,7 +22,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -155,12 +154,6 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable return ((PrometheusColumnHandle) columnHandle).getColumnMetadata(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle handle, Constraint constraint) { diff --git a/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisMetadata.java b/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisMetadata.java index cff52758efa4..e2505e399d5e 100644 --- a/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisMetadata.java +++ b/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisMetadata.java @@ -25,7 +25,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import 
io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -276,12 +275,6 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable return ((RedisColumnHandle) columnHandle).getColumnMetadata(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle tableHandle) - { - return new ConnectorTableProperties(); - } - @VisibleForTesting Map getDefinedTables() { diff --git a/plugin/trino-redshift/README.md b/plugin/trino-redshift/README.md new file mode 100644 index 000000000000..16229b145da1 --- /dev/null +++ b/plugin/trino-redshift/README.md @@ -0,0 +1,20 @@ +# Redshift Connector + +To run the Redshift tests you will need to provision a Redshift cluster. The +tests are designed to run on the smallest possible Redshift cluster consisting +of a single dc2.large instance. Additionally, you will need an S3 bucket +containing TPCH tiny data in Parquet format. The files should be named: + +``` +s3://<bucket>/tpch/tiny/<table>.parquet +``` + +To run the tests, set the following system properties: + +``` +test.redshift.jdbc.endpoint=<endpoint>.<region>.redshift.amazonaws.com:5439/<database> +test.redshift.jdbc.user=<user> +test.redshift.jdbc.password=<password> +test.redshift.s3.tpch.tables.root=<bucket> +test.redshift.iam.role=<iam_role> +``` diff --git a/plugin/trino-redshift/pom.xml b/plugin/trino-redshift/pom.xml index a6270049b7c2..367ae423d4a4 100644 --- a/plugin/trino-redshift/pom.xml +++ b/plugin/trino-redshift/pom.xml @@ -23,12 +23,32 @@ trino-base-jdbc + + io.trino + trino-matching + + + + io.trino + trino-plugin-toolkit + + + + io.airlift + configuration + + com.amazon.redshift redshift-jdbc42 2.1.0.9 + + com.google.guava + guava + + com.google.inject guice @@ -39,10 +59,27 @@ javax.inject + + org.jdbi + jdbi3-core + + - com.google.guava - guava + io.airlift + log + runtime + + + + io.airlift + log-manager + runtime + + + + net.jodah + failsafe runtime @@ -72,16 +109,91 @@ + + io.trino + trino-base-jdbc + test-jar + test + + io.trino trino-main test + + io.trino + trino-main + test-jar + test + + + + io.trino + trino-testing + test + + + + io.trino + trino-testing-services + test + + + + io.trino + trino-tpch + test + + + + io.trino.tpch + tpch + test + + + + io.airlift + testing + test + + + + org.assertj + assertj-core + test + + org.testng testng test + + + + default + + true + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + **/TestRedshiftAutomaticJoinPushdown.java + **/TestRedshiftConnectorTest.java + **/TestRedshiftTableStatisticsReader.java + **/TestRedshiftTypeMapping.java + + + + + + + diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgBigint.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgBigint.java new file mode 100644 index 000000000000..f9c546105546 --- /dev/null +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgBigint.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import io.trino.plugin.jdbc.aggregation.BaseImplementAvgBigint; + +public class ImplementRedshiftAvgBigint + extends BaseImplementAvgBigint +{ + @Override + protected String getRewriteFormatExpression() + { + return "avg(CAST(%s AS double precision))"; + } +} diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java new file mode 100644 index 000000000000..103258db12b7 --- /dev/null +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.aggregation.AggregateFunctionRule; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.DecimalType; + +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.variable; +import static io.trino.plugin.redshift.RedshiftClient.REDSHIFT_MAX_DECIMAL_PRECISION; +import static java.lang.String.format; + +public class ImplementRedshiftAvgDecimal + implements AggregateFunctionRule +{ + private static final Capture INPUT = newCapture(); + + @Override + public Pattern getPattern() + { + return basicAggregation() + .with(functionName().equalTo("avg")) + .with(singleArgument().matching( + variable() + .with(type().matching(DecimalType.class::isInstance)) + .capturedAs(INPUT))); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + { + Variable input = captures.get(INPUT); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + DecimalType type = (DecimalType) columnHandle.getColumnType(); + 
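// Illustrative note, not part of the change: for a hypothetical decimal(10, 2) column "price",
// the branch below that handles precisions under the Redshift maximum would push down
//   round(avg(CAST("price" AS decimal(11, 3))), 2)
// (scale widened by one, then rounded back), while a decimal(38, 2) column is already at
// REDSHIFT_MAX_DECIMAL_PRECISION and would push down avg(CAST("price" AS decimal(38, 2))).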
verify(aggregateFunction.getOutputType().equals(type)); + + // When decimal type has maximum precision we can get result that is not matching Presto avg semantics. + if (type.getPrecision() == REDSHIFT_MAX_DECIMAL_PRECISION) { + return Optional.of(new JdbcExpression( + format("avg(CAST(%s AS decimal(%s, %s)))", context.rewriteExpression(input).orElseThrow(), type.getPrecision(), type.getScale()), + columnHandle.getJdbcTypeHandle())); + } + + // Redshift avg function rounds down resulting decimal. + // To match Presto avg semantics, we extend scale by 1 and round result to target scale. + return Optional.of(new JdbcExpression( + format("round(avg(CAST(%s AS decimal(%s, %s))), %s)", context.rewriteExpression(input).orElseThrow(), type.getPrecision() + 1, type.getScale() + 1, type.getScale()), + columnHandle.getJdbcTypeHandle())); + } +} diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index af0078309e9f..21ee2eabcf02 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -13,39 +13,105 @@ */ package io.trino.plugin.redshift; +import com.amazon.redshift.jdbc.RedshiftPreparedStatement; +import com.amazon.redshift.util.RedshiftObject; +import com.google.common.base.CharMatcher; +import com.google.common.collect.ImmutableSet; +import io.airlift.slice.Slice; +import io.trino.plugin.base.aggregation.AggregateFunctionRewriter; +import io.trino.plugin.base.aggregation.AggregateFunctionRule; +import io.trino.plugin.base.expression.ConnectorExpressionRewriter; import io.trino.plugin.jdbc.BaseJdbcClient; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ColumnMapping; import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcJoinCondition; +import io.trino.plugin.jdbc.JdbcSortItem; +import io.trino.plugin.jdbc.JdbcSplit; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.LongWriteFunction; +import io.trino.plugin.jdbc.ObjectReadFunction; +import io.trino.plugin.jdbc.ObjectWriteFunction; +import io.trino.plugin.jdbc.PreparedQuery; import io.trino.plugin.jdbc.QueryBuilder; +import io.trino.plugin.jdbc.SliceWriteFunction; +import io.trino.plugin.jdbc.StandardColumnMappings; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.aggregation.ImplementAvgFloatingPoint; +import io.trino.plugin.jdbc.aggregation.ImplementCount; +import io.trino.plugin.jdbc.aggregation.ImplementCountAll; +import io.trino.plugin.jdbc.aggregation.ImplementCountDistinct; +import io.trino.plugin.jdbc.aggregation.ImplementMinMax; +import io.trino.plugin.jdbc.aggregation.ImplementStddevPop; +import io.trino.plugin.jdbc.aggregation.ImplementStddevSamp; +import io.trino.plugin.jdbc.aggregation.ImplementSum; +import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; +import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; +import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; +import io.trino.spi.connector.AggregateFunction; +import 
io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.JoinCondition; +import io.trino.spi.connector.JoinStatistics; +import io.trino.spi.connector.JoinType; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.statistics.TableStatistics; import io.trino.spi.type.CharType; +import io.trino.spi.type.Chars; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; +import io.trino.spi.type.Int128; +import io.trino.spi.type.LongTimestamp; +import io.trino.spi.type.LongTimestampWithTimeZone; +import io.trino.spi.type.TimeType; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.TimestampWithTimeZoneType; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import javax.inject.Inject; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.math.MathContext; import java.sql.Connection; +import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Types; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeFormatterBuilder; +import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.OptionalLong; import java.util.function.BiFunction; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Throwables.throwIfInstanceOf; +import static com.google.common.base.Verify.verify; +import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; +import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; +import static io.trino.plugin.jdbc.JdbcJoinPushdownUtil.implementJoinCostAware; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.charReadFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.charWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.dateColumnMappingUsingSqlDate; import static io.trino.plugin.jdbc.StandardColumnMappings.dateWriteFunctionUsingSqlDate; @@ -56,6 +122,7 @@ import static io.trino.plugin.jdbc.StandardColumnMappings.doubleWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.integerColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.integerWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.longDecimalReadFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.longDecimalWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.realColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.realWriteFunction; @@ -67,33 +134,168 @@ import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryColumnMapping; -import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryWriteFunction; +import static 
io.trino.plugin.jdbc.StandardColumnMappings.varbinaryReadFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.varcharColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.varcharWriteFunction; import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.getUnsupportedTypeHandling; import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR; +import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.CharType.createCharType; +import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc; import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.LongTimestampWithTimeZone.fromEpochSecondsAndFraction; import static io.trino.spi.type.RealType.REAL; import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TimeType.TIME_MICROS; +import static io.trino.spi.type.TimeZoneKey.UTC_KEY; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS; import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS; +import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_DAY; +import static io.trino.spi.type.Timestamps.MILLISECONDS_PER_SECOND; +import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND; +import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MILLISECOND; +import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_MICROSECOND; +import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND; +import static io.trino.spi.type.Timestamps.roundDiv; import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static java.lang.Math.floorDiv; +import static java.lang.Math.floorMod; import static java.lang.Math.max; +import static java.lang.Math.min; import static java.lang.String.format; +import static java.math.RoundingMode.UNNECESSARY; +import static java.time.temporal.ChronoField.NANO_OF_SECOND; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; public class RedshiftClient extends BaseJdbcClient { + /** + * Redshift does not handle values larger than 64 bits for + * {@code DECIMAL(19, s)}. It supports the full range of values for all + * other precisions. + * + * @see + * Redshift documentation + */ + private static final int REDSHIFT_DECIMAL_CUTOFF_PRECISION = 19; + + static final int REDSHIFT_MAX_DECIMAL_PRECISION = 38; + + /** + * Maximum size of a {@link BigInteger} storing a Redshift {@code DECIMAL} + * with precision {@link #REDSHIFT_DECIMAL_CUTOFF_PRECISION}. + */ + // actual value is 63 + private static final int REDSHIFT_DECIMAL_CUTOFF_BITS = BigInteger.valueOf(Long.MAX_VALUE).bitLength(); + + /** + * Maximum size of a Redshift CHAR column. + * + * @see + * Redshift documentation + */ + private static final int REDSHIFT_MAX_CHAR = 4096; + + /** + * Maximum size of a Redshift VARCHAR column. 
+ * + * @see + * Redshift documentation + */ + static final int REDSHIFT_MAX_VARCHAR = 65535; + + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormatter.ofPattern("yyy-MM-dd[ G]"); + private static final DateTimeFormatter DATE_TIME_FORMATTER = new DateTimeFormatterBuilder() + .appendPattern("yyy-MM-dd HH:mm:ss") + .optionalStart() + .appendFraction(NANO_OF_SECOND, 0, 6, true) + .optionalEnd() + .appendPattern("[ G]") + .toFormatter(); + private static final OffsetDateTime REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ = OffsetDateTime.of(-4712, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC); + + private final AggregateFunctionRewriter aggregateFunctionRewriter; + private final boolean statisticsEnabled; + private final RedshiftTableStatisticsReader statisticsReader; + @Inject - public RedshiftClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier) + public RedshiftClient( + BaseJdbcConfig config, + ConnectionFactory connectionFactory, + JdbcStatisticsConfig statisticsConfig, + QueryBuilder queryBuilder, + IdentifierMapping identifierMapping, + RemoteQueryModifier queryModifier) { super(config, "\"", connectionFactory, queryBuilder, identifierMapping, queryModifier); + ConnectorExpressionRewriter connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() + .addStandardRules(this::quoted) + .build(); + + JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + + aggregateFunctionRewriter = new AggregateFunctionRewriter<>( + connectorExpressionRewriter, + ImmutableSet.>builder() + .add(new ImplementCountAll(bigintTypeHandle)) + .add(new ImplementCount(bigintTypeHandle)) + .add(new ImplementCountDistinct(bigintTypeHandle, true)) + .add(new ImplementMinMax(true)) + .add(new ImplementSum(RedshiftClient::toTypeHandle)) + .add(new ImplementAvgFloatingPoint()) + .add(new ImplementRedshiftAvgDecimal()) + .add(new ImplementRedshiftAvgBigint()) + .add(new ImplementStddevSamp()) + .add(new ImplementStddevPop()) + .add(new ImplementVarianceSamp()) + .add(new ImplementVariancePop()) + .build()); + + this.statisticsEnabled = requireNonNull(statisticsConfig, "statisticsConfig is null").isEnabled(); + this.statisticsReader = new RedshiftTableStatisticsReader(connectionFactory); + } + + private static Optional toTypeHandle(DecimalType decimalType) + { + return Optional.of( + new JdbcTypeHandle( + Types.NUMERIC, + Optional.of("decimal"), + Optional.of(decimalType.getPrecision()), + Optional.of(decimalType.getScale()), + Optional.empty(), + Optional.empty())); + } + + @Override + public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcTableHandle tableHandle) + throws SQLException + { + Connection connection = super.getConnection(session, split, tableHandle); + try { + // super.getConnection sets read-only, since the connection is going to be used only for reads. + // However, for a complex query, Redshift may decide to create some temporary tables behind + // the scenes, and this requires the connection not to be read-only, otherwise Redshift + // may fail with "ERROR: transaction is read-only". 
+ connection.setReadOnly(false); + } + catch (SQLException e) { + connection.close(); + throw e; + } + return connection; } @Override @@ -103,6 +305,87 @@ public Optional getTableComment(ResultSet resultSet) return Optional.empty(); } + @Override + public Optional implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map assignments) + { + return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); + } + + @Override + public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, TupleDomain tupleDomain) + { + if (!statisticsEnabled) { + return TableStatistics.empty(); + } + if (!handle.isNamedRelation()) { + return TableStatistics.empty(); + } + try { + return statisticsReader.readTableStatistics(session, handle, () -> this.getColumns(session, handle)); + } + catch (SQLException | RuntimeException e) { + throwIfInstanceOf(e, TrinoException.class); + throw new TrinoException(JDBC_ERROR, "Failed fetching statistics for table: " + handle, e); + } + } + + @Override + public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List sortOrder) + { + return true; + } + + @Override + protected Optional topNFunction() + { + return Optional.of((query, sortItems, limit) -> { + String orderBy = sortItems.stream() + .map(sortItem -> { + String ordering = sortItem.getSortOrder().isAscending() ? "ASC" : "DESC"; + String nullsHandling = sortItem.getSortOrder().isNullsFirst() ? "NULLS FIRST" : "NULLS LAST"; + return format("%s %s %s", quoted(sortItem.getColumn().getColumnName()), ordering, nullsHandling); + }) + .collect(joining(", ")); + + return format("%s ORDER BY %s LIMIT %d", query, orderBy, limit); + }); + } + + @Override + public boolean isTopNGuaranteed(ConnectorSession session) + { + return true; + } + + @Override + protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCondition joinCondition) + { + return joinCondition.getOperator() != JoinCondition.Operator.IS_DISTINCT_FROM; + } + + @Override + public Optional implementJoin(ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + if (joinType == JoinType.FULL_OUTER) { + // FULL JOIN is only supported with merge-joinable or hash-joinable join conditions + return Optional.empty(); + } + return implementJoinCostAware( + session, + joinType, + leftSource, + rightSource, + statistics, + () -> super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + } + @Override protected void renameTable(ConnectorSession session, Connection connection, String catalogName, String remoteSchemaName, String remoteTableName, String newRemoteSchemaName, String newRemoteTableName) throws SQLException @@ -131,7 +414,147 @@ public PreparedStatement getPreparedStatement(Connection connection, String sql) } @Override - public Optional toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle typeHandle) + public OptionalLong delete(ConnectorSession session, JdbcTableHandle handle) + { + checkArgument(handle.isNamedRelation(), "Unable to delete from synthetic table: %s", handle); + checkArgument(handle.getLimit().isEmpty(), "Unable to delete when limit is set: %s", handle); + checkArgument(handle.getSortOrder().isEmpty(), "Unable to delete when sort order is set: %s", handle); + try (Connection connection = 
connectionFactory.openConnection(session)) { + verify(connection.getAutoCommit()); + PreparedQuery preparedQuery = queryBuilder.prepareDeleteQuery(this, session, connection, handle.getRequiredNamedRelation(), handle.getConstraint(), Optional.empty()); + try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(this, session, connection, preparedQuery)) { + int affectedRowsCount = preparedStatement.executeUpdate(); + // connection.getAutoCommit() == true is not enough to make DELETE effective and explicit commit is required + connection.commit(); + return OptionalLong.of(affectedRowsCount); + } + } + catch (SQLException e) { + throw new TrinoException(JDBC_ERROR, e); + } + } + + @Override + protected void verifySchemaName(DatabaseMetaData databaseMetadata, String schemaName) + throws SQLException + { + // Redshift truncates schema name to 127 chars silently + if (schemaName.length() > databaseMetadata.getMaxSchemaNameLength()) { + throw new TrinoException(NOT_SUPPORTED, "Schema name must be shorter than or equal to '%d' characters but got '%d'".formatted(databaseMetadata.getMaxSchemaNameLength(), schemaName.length())); + } + } + + @Override + protected void verifyTableName(DatabaseMetaData databaseMetadata, String tableName) + throws SQLException + { + // Redshift truncates table name to 127 chars silently + if (tableName.length() > databaseMetadata.getMaxTableNameLength()) { + throw new TrinoException(NOT_SUPPORTED, "Table name must be shorter than or equal to '%d' characters but got '%d'".formatted(databaseMetadata.getMaxTableNameLength(), tableName.length())); + } + } + + @Override + protected void verifyColumnName(DatabaseMetaData databaseMetadata, String columnName) + throws SQLException + { + // Redshift truncates table name to 127 chars silently + if (columnName.length() > databaseMetadata.getMaxColumnNameLength()) { + throw new TrinoException(NOT_SUPPORTED, "Column name must be shorter than or equal to '%d' characters but got '%d'".formatted(databaseMetadata.getMaxColumnNameLength(), columnName.length())); + } + } + + @Override + public Optional toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle type) + { + Optional mapping = getForcedMappingToVarchar(type); + if (mapping.isPresent()) { + return mapping; + } + + if ("time".equals(type.getJdbcTypeName().orElse(""))) { + return Optional.of(ColumnMapping.longMapping( + TIME_MICROS, + RedshiftClient::readTime, + RedshiftClient::writeTime)); + } + + switch (type.getJdbcType()) { + case Types.BIT: // Redshift uses this for booleans + return Optional.of(booleanColumnMapping()); + + // case Types.TINYINT: -- Redshift doesn't support tinyint + case Types.SMALLINT: + return Optional.of(smallintColumnMapping()); + case Types.INTEGER: + return Optional.of(integerColumnMapping()); + case Types.BIGINT: + return Optional.of(bigintColumnMapping()); + + case Types.REAL: + return Optional.of(realColumnMapping()); + case Types.DOUBLE: + return Optional.of(doubleColumnMapping()); + + case Types.NUMERIC: { + int precision = type.getRequiredColumnSize(); + int scale = type.getRequiredDecimalDigits(); + DecimalType decimalType = createDecimalType(precision, scale); + if (precision == REDSHIFT_DECIMAL_CUTOFF_PRECISION) { + return Optional.of(ColumnMapping.objectMapping( + decimalType, + longDecimalReadFunction(decimalType), + writeDecimalAtRedshiftCutoff(scale))); + } + return Optional.of(decimalColumnMapping(decimalType, UNNECESSARY)); + } + + case Types.CHAR: + CharType charType = 
createCharType(type.getRequiredColumnSize()); + return Optional.of(ColumnMapping.sliceMapping( + charType, + charReadFunction(charType), + RedshiftClient::writeChar)); + + case Types.VARCHAR: { + int length = type.getRequiredColumnSize(); + return Optional.of(varcharColumnMapping( + length < VarcharType.MAX_LENGTH + ? createVarcharType(length) + : createUnboundedVarcharType(), + true)); + } + + case Types.LONGVARBINARY: + return Optional.of(ColumnMapping.sliceMapping( + VARBINARY, + varbinaryReadFunction(), + varbinaryWriteFunction())); + + case Types.DATE: + return Optional.of(ColumnMapping.longMapping( + DATE, + RedshiftClient::readDate, + RedshiftClient::writeDate)); + + case Types.TIMESTAMP: + return Optional.of(ColumnMapping.longMapping( + TIMESTAMP_MICROS, + RedshiftClient::readTimestamp, + RedshiftClient::writeShortTimestamp)); + + case Types.TIMESTAMP_WITH_TIMEZONE: + return Optional.of(ColumnMapping.objectMapping( + TIMESTAMP_TZ_MICROS, + longTimestampWithTimeZoneReadFunction(), + longTimestampWithTimeZoneWriteFunction())); + } + + // Fall back to default behavior + return legacyToColumnMapping(session, type); + } + + private Optional legacyToColumnMapping(ConnectorSession session, JdbcTypeHandle typeHandle) { Optional mapping = getForcedMappingToVarchar(typeHandle); if (mapping.isPresent()) { @@ -150,6 +573,99 @@ public Optional toColumnMapping(ConnectorSession session, Connect @Override public WriteMapping toWriteMapping(ConnectorSession session, Type type) { + if (BOOLEAN.equals(type)) { + return WriteMapping.booleanMapping("boolean", booleanWriteFunction()); + } + if (TINYINT.equals(type)) { + // Redshift doesn't have tinyint + return WriteMapping.longMapping("smallint", tinyintWriteFunction()); + } + if (SMALLINT.equals(type)) { + return WriteMapping.longMapping("smallint", smallintWriteFunction()); + } + if (INTEGER.equals(type)) { + return WriteMapping.longMapping("integer", integerWriteFunction()); + } + if (BIGINT.equals(type)) { + return WriteMapping.longMapping("bigint", bigintWriteFunction()); + } + if (REAL.equals(type)) { + return WriteMapping.longMapping("real", realWriteFunction()); + } + if (DOUBLE.equals(type)) { + return WriteMapping.doubleMapping("double precision", doubleWriteFunction()); + } + + if (type instanceof DecimalType decimal) { + if (decimal.getPrecision() == REDSHIFT_DECIMAL_CUTOFF_PRECISION) { + // See doc for REDSHIFT_DECIMAL_CUTOFF_PRECISION + return WriteMapping.objectMapping( + format("decimal(%s, %s)", decimal.getPrecision(), decimal.getScale()), + writeDecimalAtRedshiftCutoff(decimal.getScale())); + } + String name = format("decimal(%s, %s)", decimal.getPrecision(), decimal.getScale()); + return decimal.isShort() + ? WriteMapping.longMapping(name, shortDecimalWriteFunction(decimal)) + : WriteMapping.objectMapping(name, longDecimalWriteFunction(decimal)); + } + + if (type instanceof CharType) { + // Redshift has no unbounded text/binary types, so if a CHAR is too + // large for Redshift, we write as VARCHAR. If too large for that, + // we use the largest VARCHAR Redshift supports. 
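// For example (illustrative): char(3000) maps to Redshift char(3000); char(5000) exceeds
// REDSHIFT_MAX_CHAR (4096) and maps to varchar(5000); char(65536) is capped at varchar(65535).
// In the varchar cases the value is space-padded to the declared length (see writeCharAsVarchar).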
+ int size = ((CharType) type).getLength(); + if (size <= REDSHIFT_MAX_CHAR) { + return WriteMapping.sliceMapping( + format("char(%d)", size), + RedshiftClient::writeChar); + } + int redshiftVarcharWidth = min(size, REDSHIFT_MAX_VARCHAR); + return WriteMapping.sliceMapping( + format("varchar(%d)", redshiftVarcharWidth), + (statement, index, value) -> writeCharAsVarchar(statement, index, value, redshiftVarcharWidth)); + } + + if (type instanceof VarcharType) { + // Redshift has no unbounded text/binary types, so if a VARCHAR is + // larger than Redshift's limit, we make it that big instead. + int size = ((VarcharType) type).getLength() + .filter(n -> n <= REDSHIFT_MAX_VARCHAR) + .orElse(REDSHIFT_MAX_VARCHAR); + return WriteMapping.sliceMapping(format("varchar(%d)", size), varcharWriteFunction()); + } + + if (VARBINARY.equals(type)) { + return WriteMapping.sliceMapping("varbyte", varbinaryWriteFunction()); + } + + if (DATE.equals(type)) { + return WriteMapping.longMapping("date", RedshiftClient::writeDate); + } + + if (type instanceof TimeType) { + return WriteMapping.longMapping("time", RedshiftClient::writeTime); + } + + if (type instanceof TimestampType) { + if (((TimestampType) type).isShort()) { + return WriteMapping.longMapping( + "timestamp", + RedshiftClient::writeShortTimestamp); + } + return WriteMapping.objectMapping( + "timestamp", + LongTimestamp.class, + RedshiftClient::writeLongTimestamp); + } + + if (type instanceof TimestampWithTimeZoneType timestampWithTimeZoneType) { + if (timestampWithTimeZoneType.getPrecision() <= TimestampWithTimeZoneType.MAX_SHORT_PRECISION) { + return WriteMapping.longMapping("timestamptz", shortTimestampWithTimeZoneWriteFunction()); + } + return WriteMapping.objectMapping("timestamptz", longTimestampWithTimeZoneWriteFunction()); + } + + // Fall back to legacy behavior return legacyToWriteMapping(type); } @@ -183,9 +699,168 @@ private static String redshiftVarcharLiteral(String value) return "'" + value.replace("'", "''").replace("\\", "\\\\") + "'"; } + private static ObjectReadFunction longTimestampWithTimeZoneReadFunction() + { + return ObjectReadFunction.of( + LongTimestampWithTimeZone.class, + (resultSet, columnIndex) -> { + // Redshift does not store zone information in "timestamp with time zone" data type + OffsetDateTime offsetDateTime = resultSet.getObject(columnIndex, OffsetDateTime.class); + return fromEpochSecondsAndFraction( + offsetDateTime.toEpochSecond(), + (long) offsetDateTime.getNano() * PICOSECONDS_PER_NANOSECOND, + UTC_KEY); + }); + } + + private static LongWriteFunction shortTimestampWithTimeZoneWriteFunction() + { + return (statement, index, value) -> { + // Redshift does not store zone information in "timestamp with time zone" data type + long millisUtc = unpackMillisUtc(value); + long epochSeconds = floorDiv(millisUtc, MILLISECONDS_PER_SECOND); + int nanosOfSecond = floorMod(millisUtc, MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND; + OffsetDateTime offsetDateTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(epochSeconds, nanosOfSecond), UTC_KEY.getZoneId()); + verifySupportedTimestampWithTimeZone(offsetDateTime); + statement.setObject(index, offsetDateTime); + }; + } + + private static ObjectWriteFunction longTimestampWithTimeZoneWriteFunction() + { + return ObjectWriteFunction.of( + LongTimestampWithTimeZone.class, + (statement, index, value) -> { + // Redshift does not store zone information in "timestamp with time zone" data type + long epochSeconds = floorDiv(value.getEpochMillis(), MILLISECONDS_PER_SECOND); + 
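// Worked example (illustrative): epochMillis = 1500 and picosOfMilli = 250_000 give
// epochSeconds = 1 here and, below, nanosOfSecond = 500 * 1_000_000 + 250_000 / 1_000 = 500_000_250.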
long nanosOfSecond = ((long) floorMod(value.getEpochMillis(), MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND) + + (value.getPicosOfMilli() / PICOSECONDS_PER_NANOSECOND); + OffsetDateTime offsetDateTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(epochSeconds, nanosOfSecond), UTC_KEY.getZoneId()); + verifySupportedTimestampWithTimeZone(offsetDateTime); + statement.setObject(index, offsetDateTime); + }); + } + + private static void verifySupportedTimestampWithTimeZone(OffsetDateTime value) + { + if (value.isBefore(REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ)) { + DateTimeFormatter format = DateTimeFormatter.ofPattern("uuuu-MM-dd HH:mm:ss.SSSSSS"); + throw new TrinoException( + INVALID_ARGUMENTS, + format("Minimum timestamp with time zone in Redshift is %s: %s", REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ.format(format), value.format(format))); + } + } + + /** + * Decimal write function for precision {@link #REDSHIFT_DECIMAL_CUTOFF_PRECISION}. + * Ensures that values fit in 8 bytes. + */ + private static ObjectWriteFunction writeDecimalAtRedshiftCutoff(int scale) + { + return ObjectWriteFunction.of( + Int128.class, + (statement, index, decimal) -> { + BigInteger unscaled = decimal.toBigInteger(); + if (unscaled.bitLength() > REDSHIFT_DECIMAL_CUTOFF_BITS) { + throw new TrinoException(JDBC_NON_TRANSIENT_ERROR, format( + "Value out of range for Redshift DECIMAL(%d, %d)", + REDSHIFT_DECIMAL_CUTOFF_PRECISION, + scale)); + } + MathContext precision = new MathContext(REDSHIFT_DECIMAL_CUTOFF_PRECISION); + statement.setBigDecimal(index, new BigDecimal(unscaled, scale, precision)); + }); + } + + /** + * Like {@link StandardColumnMappings#charWriteFunction}, but restrict to + * ASCII because Redshift only allows ASCII in {@code CHAR} values. + */ + private static void writeChar(PreparedStatement statement, int index, Slice slice) + throws SQLException + { + String value = slice.toStringUtf8(); + if (!CharMatcher.ascii().matchesAllOf(value)) { + throw new TrinoException( + JDBC_NON_TRANSIENT_ERROR, + format("Value for Redshift CHAR must be ASCII, but found '%s'", value)); + } + statement.setString(index, slice.toStringAscii()); + } + + /** + * Like {@link StandardColumnMappings#charWriteFunction}, but pads + * the value with spaces to simulate {@code CHAR} semantics. + */ + private static void writeCharAsVarchar(PreparedStatement statement, int index, Slice slice, int columnLength) + throws SQLException + { + // Redshift counts varchar size limits in UTF-8 bytes, so this may make the string longer than + // the limit, but Redshift also truncates extra trailing spaces, so that doesn't cause any problems. 
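// For instance (illustrative): a Trino char(5) value "abc" is sent to Redshift as "abc  "
// (padded to 5 characters), so comparisons behave like those against a native CHAR column.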
+ statement.setString(index, Chars.padSpaces(slice, columnLength).toStringUtf8()); + } + + private static void writeDate(PreparedStatement statement, int index, long day) + throws SQLException + { + statement.setObject(index, new RedshiftObject("date", DATE_FORMATTER.format(LocalDate.ofEpochDay(day)))); + } + + private static long readDate(ResultSet results, int index) + throws SQLException + { + // Reading date as string to workaround issues around julian->gregorian calendar switch + return LocalDate.parse(results.getString(index), DATE_FORMATTER).toEpochDay(); + } + + /** + * Write time with microsecond precision + */ + private static void writeTime(PreparedStatement statement, int index, long picos) + throws SQLException + { + statement.setObject(index, LocalTime.ofNanoOfDay((roundDiv(picos, PICOSECONDS_PER_MICROSECOND) % MICROSECONDS_PER_DAY) * NANOSECONDS_PER_MICROSECOND)); + } + + /** + * Read a time value with microsecond precision + */ + private static long readTime(ResultSet results, int index) + throws SQLException + { + return results.getObject(index, LocalTime.class).toNanoOfDay() * PICOSECONDS_PER_NANOSECOND; + } + + private static void writeShortTimestamp(PreparedStatement statement, int index, long epochMicros) + throws SQLException + { + statement.setObject(index, new RedshiftObject("timestamp", DATE_TIME_FORMATTER.format(StandardColumnMappings.fromTrinoTimestamp(epochMicros)))); + } + + private static void writeLongTimestamp(PreparedStatement statement, int index, Object value) + throws SQLException + { + LongTimestamp timestamp = (LongTimestamp) value; + long epochMicros = timestamp.getEpochMicros(); + if (timestamp.getPicosOfMicro() >= PICOSECONDS_PER_MICROSECOND / 2) { + epochMicros += 1; // Add one micro if picos round up + } + statement.setObject(index, new RedshiftObject("timestamp", DATE_TIME_FORMATTER.format(StandardColumnMappings.fromTrinoTimestamp(epochMicros)))); + } + + private static long readTimestamp(ResultSet results, int index) + throws SQLException + { + return StandardColumnMappings.toTrinoTimestamp(TIMESTAMP_MICROS, results.getObject(index, LocalDateTime.class)); + } + + private static SliceWriteFunction varbinaryWriteFunction() + { + return (statement, index, value) -> statement.unwrap(RedshiftPreparedStatement.class).setVarbyte(index, value.getBytes()); + } + private static Optional legacyDefaultColumnMapping(JdbcTypeHandle typeHandle) { - // TODO (https://github.com/trinodb/trino/issues/497) Implement proper type mapping and add test // This method is copied from deprecated StandardColumnMappings.legacyDefaultColumnMapping() switch (typeHandle.getJdbcType()) { case Types.BIT: @@ -251,7 +926,6 @@ private static Optional legacyDefaultColumnMapping(JdbcTypeHandle private static WriteMapping legacyToWriteMapping(Type type) { - // TODO (https://github.com/trinodb/trino/issues/497) Implement proper type mapping and add test // This method is copied from deprecated BaseJdbcClient.legacyToWriteMapping() if (type == BOOLEAN) { return WriteMapping.booleanMapping("boolean", booleanWriteFunction()); diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java index 53e1aee6ac29..aeffaac16ff7 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java @@ -15,29 +15,39 @@ import com.amazon.redshift.Driver; 
import com.google.inject.Binder; -import com.google.inject.Module; import com.google.inject.Provides; -import com.google.inject.Scopes; import com.google.inject.Singleton; +import io.airlift.configuration.AbstractConfigurationAwareModule; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ConnectionFactory; +import io.trino.plugin.jdbc.DecimalModule; import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.JdbcJoinPushdownSupportModule; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.credential.CredentialProvider; import io.trino.plugin.jdbc.ptf.Query; import io.trino.spi.ptf.ConnectorTableFunction; +import java.util.Properties; + +import static com.google.inject.Scopes.SINGLETON; import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static io.airlift.configuration.ConfigBinder.configBinder; public class RedshiftClientModule - implements Module + extends AbstractConfigurationAwareModule { @Override - public void configure(Binder binder) + public void setup(Binder binder) { - binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(RedshiftClient.class).in(Scopes.SINGLETON); - newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON); + binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(RedshiftClient.class).in(SINGLETON); + newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(SINGLETON); + configBinder(binder).bindConfig(JdbcStatisticsConfig.class); + + install(new DecimalModule()); + install(new JdbcJoinPushdownSupportModule()); } @Singleton @@ -45,6 +55,14 @@ public void configure(Binder binder) @ForBaseJdbc public static ConnectionFactory getConnectionFactory(BaseJdbcConfig config, CredentialProvider credentialProvider) { - return new DriverConnectionFactory(new Driver(), config, credentialProvider); + return new DriverConnectionFactory(new Driver(), config.getConnectionUrl(), getDriverProperties(), credentialProvider); + } + + private static Properties getDriverProperties() + { + Properties properties = new Properties(); + properties.put("reWriteBatchedInserts", "true"); + properties.put("reWriteBatchedInsertsSize", "512"); + return properties; } } diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftTableStatisticsReader.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftTableStatisticsReader.java new file mode 100644 index 000000000000..c576abdd109d --- /dev/null +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftTableStatisticsReader.java @@ -0,0 +1,176 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.redshift; + +import io.trino.plugin.jdbc.ConnectionFactory; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcTableHandle; +import io.trino.plugin.jdbc.RemoteTableName; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.statistics.ColumnStatistics; +import io.trino.spi.statistics.Estimate; +import io.trino.spi.statistics.TableStatistics; +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.Jdbi; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +public class RedshiftTableStatisticsReader +{ + private final ConnectionFactory connectionFactory; + + public RedshiftTableStatisticsReader(ConnectionFactory connectionFactory) + { + this.connectionFactory = requireNonNull(connectionFactory, "connectionFactory is null"); + } + + public TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table, Supplier> columnSupplier) + throws SQLException + { + checkArgument(table.isNamedRelation(), "Relation is not a table: %s", table); + + try (Connection connection = connectionFactory.openConnection(session); + Handle handle = Jdbi.open(connection)) { + StatisticsDao statisticsDao = new StatisticsDao(handle); + + RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName(); + Optional optionalRowCount = readRowCountTableStat(statisticsDao, table); + if (optionalRowCount.isEmpty()) { + // Table not found + return TableStatistics.empty(); + } + long rowCount = optionalRowCount.get(); + + TableStatistics.Builder tableStatistics = TableStatistics.builder() + .setRowCount(Estimate.of(rowCount)); + + if (rowCount == 0) { + return tableStatistics.build(); + } + + Map columnStatistics = statisticsDao.getColumnStatistics(remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()).stream() + .collect(toImmutableMap(ColumnStatisticsResult::columnName, identity())); + + for (JdbcColumnHandle column : columnSupplier.get()) { + ColumnStatisticsResult result = columnStatistics.get(column.getColumnName()); + if (result == null) { + continue; + } + + ColumnStatistics statistics = ColumnStatistics.builder() + .setNullsFraction(result.nullsFraction() + .map(Estimate::of) + .orElseGet(Estimate::unknown)) + .setDistinctValuesCount(result.distinctValuesIndicator() + .map(distinctValuesIndicator -> { + // If the distinct value count is an estimate Redshift uses "the negative of the number of distinct values divided by the number of rows + // For example, -1 indicates a unique column in which the number of distinct values is the same as the number of rows." 
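// Concretely (illustrative): with rowCount = 1000, an n_distinct of -0.5 yields an estimate of
// 500 distinct values and -1 yields 1000; the estimate is capped at the row count.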
+ // https://www.postgresql.org/docs/9.3/view-pg-stats.html + if (distinctValuesIndicator < 0.0) { + return Math.min(-distinctValuesIndicator * rowCount, rowCount); + } + return distinctValuesIndicator; + }) + .map(Estimate::of) + .orElseGet(Estimate::unknown)) + .setDataSize(result.averageColumnLength() + .flatMap(averageColumnLength -> + result.nullsFraction() + .map(nullsFraction -> 1.0 * averageColumnLength * rowCount * (1 - nullsFraction)) + .map(Estimate::of)) + .orElseGet(Estimate::unknown)) + .build(); + + tableStatistics.setColumnStatistics(column, statistics); + } + + return tableStatistics.build(); + } + } + + private static Optional readRowCountTableStat(StatisticsDao statisticsDao, JdbcTableHandle table) + { + RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName(); + Optional rowCount = statisticsDao.getRowCountFromPgClass(remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()); + if (rowCount.isEmpty()) { + // Table not found + return Optional.empty(); + } + + if (rowCount.get() == 0) { + // `pg_class.reltuples = 0` may mean an empty table or a recently populated table (CTAS, LOAD or INSERT) + // The `pg_stat_all_tables` view can be way off, so we use it only as a fallback + rowCount = statisticsDao.getRowCountFromPgStat(remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()); + } + + return rowCount; + } + + private static class StatisticsDao + { + private final Handle handle; + + public StatisticsDao(Handle handle) + { + this.handle = requireNonNull(handle, "handle is null"); + } + + Optional getRowCountFromPgClass(String schema, String tableName) + { + return handle.createQuery("SELECT reltuples FROM pg_class WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) AND relname = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .mapTo(Long.class) + .findOne(); + } + + Optional getRowCountFromPgStat(String schema, String tableName) + { + // Redshift does not have the Postgres `n_live_tup`, so estimate from `inserts - deletes` + return handle.createQuery("SELECT n_tup_ins - n_tup_del FROM pg_stat_all_tables WHERE schemaname = :schema AND relname = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .mapTo(Long.class) + .findOne(); + } + + List getColumnStatistics(String schema, String tableName) + { + return handle.createQuery("SELECT attname, null_frac, n_distinct, avg_width FROM pg_stats WHERE schemaname = :schema AND tablename = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .map((rs, ctx) -> + new ColumnStatisticsResult( + requireNonNull(rs.getString("attname"), "attname is null"), + Optional.of(rs.getFloat("null_frac")), + Optional.of(rs.getFloat("n_distinct")), + Optional.of(rs.getInt("avg_width")))) + .list(); + } + } + + // TODO remove when error prone is updated for Java 17 records + @SuppressWarnings("unused") + private record ColumnStatisticsResult(String columnName, Optional nullsFraction, Optional distinctValuesIndicator, Optional averageColumnLength) {} +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java new file mode 100644 index 000000000000..3e96738e7ba1 --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java @@ -0,0 +1,271 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you 
may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Streams; +import io.airlift.log.Logger; +import io.airlift.log.Logging; +import io.trino.Session; +import io.trino.metadata.QualifiedObjectName; +import io.trino.plugin.tpch.TpchPlugin; +import io.trino.spi.security.Identity; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.MaterializedResult; +import io.trino.testing.QueryRunner; +import io.trino.tpch.TpchTable; +import net.jodah.failsafe.Failsafe; +import net.jodah.failsafe.RetryPolicy; +import org.jdbi.v3.core.HandleCallback; +import org.jdbi.v3.core.HandleConsumer; +import org.jdbi.v3.core.Jdbi; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import static io.airlift.testing.Closeables.closeAllSuppress; +import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static io.trino.testing.QueryAssertions.copyTable; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.testing.assertions.Assert.assertEquals; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toUnmodifiableSet; + +public final class RedshiftQueryRunner +{ + private static final Logger log = Logger.get(RedshiftQueryRunner.class); + private static final String JDBC_ENDPOINT = requireSystemProperty("test.redshift.jdbc.endpoint"); + static final String JDBC_USER = requireSystemProperty("test.redshift.jdbc.user"); + static final String JDBC_PASSWORD = requireSystemProperty("test.redshift.jdbc.password"); + private static final String S3_TPCH_TABLES_ROOT = requireSystemProperty("test.redshift.s3.tpch.tables.root"); + private static final String IAM_ROLE = requireSystemProperty("test.redshift.iam.role"); + + private static final String TEST_DATABASE = "testdb"; + private static final String TEST_CATALOG = "redshift"; + static final String TEST_SCHEMA = "test_schema"; + + static final String JDBC_URL = "jdbc:redshift://" + JDBC_ENDPOINT + TEST_DATABASE; + + private static final String CONNECTOR_NAME = "redshift"; + private static final String TPCH_CATALOG = "tpch"; + + private static final String GRANTED_USER = "alice"; + private static final String NON_GRANTED_USER = "bob"; + + private RedshiftQueryRunner() {} + + public static DistributedQueryRunner createRedshiftQueryRunner( + Map extraProperties, + Map connectorProperties, + Iterable> tables) + throws Exception + { + return createRedshiftQueryRunner( + createSession(), + extraProperties, + connectorProperties, + tables); + } + + public static DistributedQueryRunner createRedshiftQueryRunner( + Session session, + Map extraProperties, + Map connectorProperties, + Iterable> tables) + throws Exception + { + DistributedQueryRunner.Builder builder = DistributedQueryRunner.builder(session); + 
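// (Illustrative) extraProperties typically carries coordinator settings such as
// "http-server.http.port", while connectorProperties may override the connection-url,
// connection-user, and connection-password defaults applied below.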
extraProperties.forEach(builder::addExtraProperty); + DistributedQueryRunner runner = builder.build(); + try { + runner.installPlugin(new TpchPlugin()); + runner.createCatalog(TPCH_CATALOG, "tpch", Map.of()); + + Map properties = new HashMap<>(connectorProperties); + properties.putIfAbsent("connection-url", JDBC_URL); + properties.putIfAbsent("connection-user", JDBC_USER); + properties.putIfAbsent("connection-password", JDBC_PASSWORD); + + runner.installPlugin(new RedshiftPlugin()); + runner.createCatalog(TEST_CATALOG, CONNECTOR_NAME, properties); + + executeInRedshift("CREATE SCHEMA IF NOT EXISTS " + TEST_SCHEMA); + createUserIfNotExists(NON_GRANTED_USER, JDBC_PASSWORD); + createUserIfNotExists(GRANTED_USER, JDBC_PASSWORD); + + executeInRedshiftWithRetry(format("GRANT ALL PRIVILEGES ON DATABASE %s TO %s", TEST_DATABASE, GRANTED_USER)); + executeInRedshiftWithRetry(format("GRANT ALL PRIVILEGES ON SCHEMA %s TO %s", TEST_SCHEMA, GRANTED_USER)); + + provisionTables(session, runner, tables); + + // This step is necessary for product tests + executeInRedshiftWithRetry(format("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s TO %s", TEST_SCHEMA, GRANTED_USER)); + } + catch (Throwable e) { + closeAllSuppress(e, runner); + throw e; + } + return runner; + } + + private static Session createSession() + { + return createSession(GRANTED_USER); + } + + private static Session createSession(String user) + { + return testSessionBuilder() + .setCatalog(TEST_CATALOG) + .setSchema(TEST_SCHEMA) + .setIdentity(Identity.ofUser(user)) + .build(); + } + + private static void createUserIfNotExists(String user, String password) + { + try { + executeInRedshift("CREATE USER " + user + " PASSWORD " + "'" + password + "'"); + } + catch (Exception e) { + // if user already exists, swallow the exception + if (!e.getMessage().matches(".*user \"" + user + "\" already exists.*")) { + throw e; + } + } + } + + private static void executeInRedshiftWithRetry(String sql) + { + Failsafe.with(new RetryPolicy<>() + .handleIf(e -> e.getMessage().matches(".* concurrent transaction .*")) + .withDelay(Duration.ofSeconds(10)) + .withMaxRetries(3)) + .run(() -> executeInRedshift(sql)); + } + + public static void executeInRedshift(String sql, Object... 
parameters) + { + executeInRedshift(handle -> handle.execute(sql, parameters)); + } + + public static void executeInRedshift(HandleConsumer consumer) + throws E + { + executeWithRedshift(consumer.asCallback()); + } + + public static T executeWithRedshift(HandleCallback callback) + throws E + { + return Jdbi.create(JDBC_URL, JDBC_USER, JDBC_PASSWORD).withHandle(callback); + } + + private static synchronized void provisionTables(Session session, QueryRunner queryRunner, Iterable> tables) + { + Set existingTables = queryRunner.listTables(session, session.getCatalog().orElseThrow(), session.getSchema().orElseThrow()) + .stream() + .map(QualifiedObjectName::getObjectName) + .collect(toUnmodifiableSet()); + + Streams.stream(tables) + .map(table -> table.getTableName().toLowerCase(ENGLISH)) + .filter(name -> !existingTables.contains(name)) + .forEach(name -> copyFromS3(queryRunner, session, name)); + + for (TpchTable tpchTable : tables) { + verifyLoadedDataHasSameSchema(session, queryRunner, tpchTable); + } + } + + private static void copyFromS3(QueryRunner queryRunner, Session session, String name) + { + String s3Path = format("%s/%s/%s.parquet", S3_TPCH_TABLES_ROOT, TPCH_CATALOG, name); + log.info("Creating table %s in Redshift copying from %s", name, s3Path); + + // Create table in ephemeral Redshift cluster with no data + String createTableSql = format("CREATE TABLE %s.%s.%s AS ", session.getCatalog().orElseThrow(), session.getSchema().orElseThrow(), name) + + format("SELECT * FROM %s.%s.%s WITH NO DATA", TPCH_CATALOG, TINY_SCHEMA_NAME, name); + queryRunner.execute(session, createTableSql); + + // Copy data from S3 bucket to ephemeral Redshift + String copySql = "COPY " + TEST_SCHEMA + "." + name + + " FROM '" + s3Path + "'" + + " IAM_ROLE '" + IAM_ROLE + "'" + + " FORMAT PARQUET"; + executeInRedshiftWithRetry(copySql); + } + + private static void copyFromTpchCatalog(QueryRunner queryRunner, Session session, String name) + { + // This function exists in case we need to copy data from the TPCH catalog rather than S3, + // such as moving to a new AWS account or if the schema changes. We can swap this method out for + // copyFromS3 in provisionTables and then export the data again to S3. + copyTable(queryRunner, TPCH_CATALOG, TINY_SCHEMA_NAME, name, session); + } + + private static void verifyLoadedDataHasSameSchema(Session session, QueryRunner queryRunner, TpchTable tpchTable) + { + // We want to verify that the loaded data has the same schema as if we created a fresh table from the TPC-H catalog + // If this assertion fails, we may need to recreate the Redshift tables from the TPC-H catalog and unload the data to S3 + try { + long expectedCount = (long) queryRunner.execute("SELECT count(*) FROM " + format("%s.%s.%s", TPCH_CATALOG, TINY_SCHEMA_NAME, tpchTable.getTableName())).getOnlyValue(); + long actualCount = (long) queryRunner.execute( + "SELECT count(*) FROM " + format( + "%s.%s.%s", + session.getCatalog().orElseThrow(), + session.getSchema().orElseThrow(), + tpchTable.getTableName())).getOnlyValue(); + + if (expectedCount != actualCount) { + throw new RuntimeException(format("Table %s is not loaded correctly. 
Expected %s rows got %s", tpchTable.getTableName(), expectedCount, actualCount)); + } + + log.info("Checking column types on table %s", tpchTable.getTableName()); + MaterializedResult expectedColumns = queryRunner.execute(format("DESCRIBE %s.%s.%s", TPCH_CATALOG, TINY_SCHEMA_NAME, tpchTable.getTableName())); + MaterializedResult actualColumns = queryRunner.execute("DESCRIBE " + tpchTable.getTableName()); + assertEquals(actualColumns, expectedColumns); + } + catch (Exception e) { + throw new RuntimeException("Failed to assert columns for TPC-H table " + tpchTable.getTableName(), e); + } + } + + /** + * Get the named system property, throwing an exception if it is not set. + */ + private static String requireSystemProperty(String property) + { + return requireNonNull(System.getProperty(property), property + " is not set"); + } + + public static void main(String[] args) + throws Exception + { + Logging.initialize(); + + DistributedQueryRunner queryRunner = createRedshiftQueryRunner( + ImmutableMap.of("http-server.http.port", "8080"), + ImmutableMap.of(), + ImmutableList.of()); + + log.info("======== SERVER STARTED ========"); + log.info("\n====\n%s\n====", queryRunner.getCoordinator().getBaseUrl()); + } +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftAutomaticJoinPushdown.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftAutomaticJoinPushdown.java new file mode 100644 index 000000000000..3509f8dd8b9c --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftAutomaticJoinPushdown.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.redshift; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.jdbc.BaseAutomaticJoinPushdownTest; +import io.trino.testing.QueryRunner; +import org.testng.SkipException; + +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; + +public class TestRedshiftAutomaticJoinPushdown + extends BaseAutomaticJoinPushdownTest +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner( + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableList.of()); + } + + @Override + public void testJoinPushdownWithEmptyStatsInitially() + { + throw new SkipException("Redshift table statistics are automatically populated"); + } + + @Override + protected void gatherStats(String tableName) + { + executeInRedshift(handle -> { + handle.execute(format("ANALYZE VERBOSE %s.%s", TEST_SCHEMA, tableName)); + for (int i = 0; i < 5; i++) { + long actualCount = handle.createQuery(format("SELECT count(*) FROM %s.%s", TEST_SCHEMA, tableName)) + .mapTo(Long.class) + .one(); + long estimatedCount = handle.createQuery( + "SELECT reltuples FROM pg_class " + + "WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) " + + "AND relname = :table_name") + .bind("schema", TEST_SCHEMA) + .bind("table_name", tableName.toLowerCase(ENGLISH).replace("\"", "")) + .mapTo(Long.class) + .one(); + if (actualCount == estimatedCount) { + return; + } + handle.execute(format("ANALYZE VERBOSE %s.%s", TEST_SCHEMA, tableName)); + } + throw new IllegalStateException("Stats not gathered"); + }); + } +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java new file mode 100644 index 000000000000..1b16e335bd82 --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -0,0 +1,639 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.redshift; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.plugin.jdbc.BaseJdbcConnectorTest; +import io.trino.testing.QueryRunner; +import io.trino.testing.TestingConnectorBehavior; +import io.trino.testing.sql.SqlExecutor; +import io.trino.testing.sql.TestTable; +import io.trino.tpch.TpchTable; +import org.testng.SkipException; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeWithRedshift; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestRedshiftConnectorTest + extends BaseJdbcConnectorTest +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner( + ImmutableMap.of(), + ImmutableMap.of(), + // NOTE this can cause tests to time-out if larger tables like + // lineitem and orders need to be re-created. + TpchTable.getTables()); + } + + @Override + @SuppressWarnings("DuplicateBranchesInSwitch") // options here are grouped per-feature + protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) + { + switch (connectorBehavior) { + case SUPPORTS_COMMENT_ON_TABLE: + case SUPPORTS_ADD_COLUMN_WITH_COMMENT: + case SUPPORTS_CREATE_TABLE_WITH_TABLE_COMMENT: + case SUPPORTS_CREATE_TABLE_WITH_COLUMN_COMMENT: + return false; + + case SUPPORTS_ARRAY: + case SUPPORTS_ROW_TYPE: + return false; + + case SUPPORTS_RENAME_TABLE_ACROSS_SCHEMAS: + return false; + + case SUPPORTS_AGGREGATION_PUSHDOWN_STDDEV: + case SUPPORTS_AGGREGATION_PUSHDOWN_VARIANCE: + case SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT: + return true; + + case SUPPORTS_JOIN_PUSHDOWN: + case SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY: + return true; + case SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM: + case SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN: + return false; + + default: + return super.hasBehavior(connectorBehavior); + } + } + + @Override + protected TestTable createTableWithDefaultColumns() + { + return new TestTable( + onRemoteDatabase(), + format("%s.test_table_with_default_columns", TEST_SCHEMA), + "(col_required BIGINT NOT NULL," + + "col_nullable BIGINT," + + "col_default BIGINT DEFAULT 43," + + "col_nonnull_default BIGINT NOT NULL DEFAULT 42," + + "col_required2 BIGINT NOT NULL)"); + } + + @Override + protected Optional filterDataMappingSmokeTestData(DataMappingTestSetup dataMappingTestSetup) + { + String typeName = dataMappingTestSetup.getTrinoTypeName(); + if ("date".equals(typeName)) { + if (dataMappingTestSetup.getSampleValueLiteral().equals("DATE '1582-10-05'")) { + return Optional.empty(); + } + } + return Optional.of(dataMappingTestSetup); + } + + /** + * Overridden due to Redshift not supporting 
non-ASCII characters in CHAR. + */ + @Override + public void testCreateTableAsSelectWithUnicode() + { + assertThatThrownBy(super::testCreateTableAsSelectWithUnicode) + .hasStackTraceContaining("Value too long for character type"); + // NOTE we add a copy of the above using VARCHAR which supports non-ASCII characters + assertCreateTableAsSelect( + "SELECT CAST('\u2603' AS VARCHAR) unicode", + "SELECT 1"); + } + + @Test(dataProvider = "redshiftTypeToTrinoTypes") + public void testReadFromLateBindingView(String redshiftType, String trinoType) + { + try (TestView view = new TestView(onRemoteDatabase(), TEST_SCHEMA + ".late_schema_binding", "SELECT CAST(NULL AS %s) AS value WITH NO SCHEMA BINDING".formatted(redshiftType))) { + assertThat(query("SELECT value, true FROM %s WHERE value IS NULL".formatted(view.getName()))) + .projected(1) + .containsAll("VALUES (true)"); + + assertThat(query("SHOW COLUMNS FROM %s LIKE 'value'".formatted(view.getName()))) + .projected(1) + .skippingTypesCheck() + .containsAll("VALUES ('%s')".formatted(trinoType)); + } + } + + @DataProvider + public Object[][] redshiftTypeToTrinoTypes() + { + return new Object[][] { + {"SMALLINT", "smallint"}, + {"INTEGER", "integer"}, + {"BIGINT", "bigint"}, + {"DECIMAL", "decimal(18,0)"}, + {"REAL", "real"}, + {"DOUBLE PRECISION", "double"}, + {"BOOLEAN", "boolean"}, + {"CHAR(1)", "char(1)"}, + {"VARCHAR(1)", "varchar(1)"}, + {"TIME", "time(6)"}, + {"TIMESTAMP", "timestamp(6)"}, + {"TIMESTAMPTZ", "timestamp(6) with time zone"}}; + } + + @Override + public void testDelete() + { + // The base tests is very slow because Redshift CTAS is really slow, so use a smaller test + try (TestTable table = new TestTable(getQueryRunner()::execute, "test_delete_", "AS SELECT * FROM nation")) { + // delete without matching any rows + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey < 0", 0); + + // delete with a predicate that optimizes to false + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey > 5 AND nationkey < 4", 0); + + // delete successive parts of the table + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey <= 5", "SELECT count(*) FROM nation WHERE nationkey <= 5"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE nationkey > 5"); + + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey <= 10", "SELECT count(*) FROM nation WHERE nationkey > 5 AND nationkey <= 10"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE nationkey > 10"); + + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey <= 15", "SELECT count(*) FROM nation WHERE nationkey > 10 AND nationkey <= 15"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE nationkey > 15"); + + // delete remaining + assertUpdate("DELETE FROM " + table.getName(), "SELECT count(*) FROM nation WHERE nationkey > 15"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE false"); + } + } + + @Test(dataProvider = "testCaseColumnNamesDataProvider") + public void testCaseColumnNames(String tableName) + { + try { + assertUpdate( + "CREATE TABLE " + TEST_SCHEMA + "." 
+ tableName + + " AS SELECT " + + " custkey AS CASE_UNQUOTED_UPPER, " + + " name AS case_unquoted_lower, " + + " address AS cASe_uNQuoTeD_miXED, " + + " nationkey AS \"CASE_QUOTED_UPPER\", " + + " phone AS \"case_quoted_lower\"," + + " acctbal AS \"CasE_QuoTeD_miXED\" " + + "FROM customer", + 1500); + gatherStats(tableName); + assertQuery( + "SHOW STATS FOR " + TEST_SCHEMA + "." + tableName, + "VALUES " + + "('case_unquoted_upper', NULL, 1485, 0, null, null, null)," + + "('case_unquoted_lower', 33000, 1470, 0, null, null, null)," + + "('case_unquoted_mixed', 42000, 1500, 0, null, null, null)," + + "('case_quoted_upper', NULL, 25, 0, null, null, null)," + + "('case_quoted_lower', 28500, 1483, 0, null, null, null)," + + "('case_quoted_mixed', NULL, 1483, 0, null, null, null)," + + "(null, null, null, null, 1500, null, null)"); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + + /** + * Tries to create situation where Redshift would decide to materialize a temporary table for query sent to it by us. + * Such temporary table requires that our Connection is not read-only. + */ + @Test + public void testComplexPushdownThatMayElicitTemporaryTable() + { + int subqueries = 10; + String subquery = "SELECT custkey, count(*) c FROM orders GROUP BY custkey"; + StringBuilder sql = new StringBuilder(); + sql.append(format( + "SELECT t0.custkey, %s c_sum ", + IntStream.range(0, subqueries) + .mapToObj(i -> format("t%s.c", i)) + .collect(Collectors.joining("+")))); + sql.append(format("FROM (%s) t0 ", subquery)); + for (int i = 1; i < subqueries; i++) { + sql.append(format("JOIN (%s) t%s ON t0.custkey = t%s.custkey ", subquery, i, i)); + } + sql.append("WHERE t0.custkey = 1045 OR rand() = 42"); + + Session forceJoinPushdown = Session.builder(getSession()) + .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), "join_pushdown_strategy", "EAGER") + .build(); + + assertThat(query(forceJoinPushdown, sql.toString())) + .matches(format("SELECT max(custkey), count(*) * %s FROM tpch.tiny.orders WHERE custkey = 1045", subqueries)); + } + + private static void gatherStats(String tableName) + { + executeInRedshift(handle -> { + handle.execute("ANALYZE VERBOSE " + TEST_SCHEMA + "." + tableName); + for (int i = 0; i < 5; i++) { + long actualCount = handle.createQuery("SELECT count(*) FROM " + TEST_SCHEMA + "." + tableName) + .mapTo(Long.class) + .one(); + long estimatedCount = handle.createQuery(""" + SELECT reltuples FROM pg_class + WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) + AND relname = :table_name + """) + .bind("schema", TEST_SCHEMA) + .bind("table_name", tableName.toLowerCase(ENGLISH).replace("\"", "")) + .mapTo(Long.class) + .one(); + if (actualCount == estimatedCount) { + return; + } + handle.execute("ANALYZE VERBOSE " + TEST_SCHEMA + "." 
+ tableName); + } + throw new IllegalStateException("Stats not gathered"); // for small test tables reltuples should be exact + }); + } + + @DataProvider + public Object[][] testCaseColumnNamesDataProvider() + { + return new Object[][] { + {"TEST_STATS_MIXED_UNQUOTED_UPPER_" + randomNameSuffix()}, + {"test_stats_mixed_unquoted_lower_" + randomNameSuffix()}, + {"test_stats_mixed_uNQuoTeD_miXED_" + randomNameSuffix()}, + {"\"TEST_STATS_MIXED_QUOTED_UPPER_" + randomNameSuffix() + "\""}, + {"\"test_stats_mixed_quoted_lower_" + randomNameSuffix() + "\""}, + {"\"test_stats_mixed_QuoTeD_miXED_" + randomNameSuffix() + "\""} + }; + } + + @Override + public void testCountDistinctWithStringTypes() + { + // cannot test using generic method as Redshift does not allow non-ASCII characters in CHAR values. + assertThatThrownBy(super::testCountDistinctWithStringTypes).hasMessageContaining("Value for Redshift CHAR must be ASCII, but found 'ą'"); + + List rows = Stream.of("a", "b", "A", "B", " a ", "a", "b", " b ") + .map(value -> format("'%1$s', '%1$s'", value)) + .collect(toImmutableList()); + String tableName = "distinct_strings" + randomNameSuffix(); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, tableName, "(t_char CHAR(5), t_varchar VARCHAR(5))", rows)) { + // Single count(DISTINCT ...) can be pushed even down even if SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT == false as GROUP BY + assertThat(query("SELECT count(DISTINCT t_varchar) FROM " + testTable.getName())) + .matches("VALUES BIGINT '6'") + .isFullyPushedDown(); + + // Single count(DISTINCT ...) can be pushed down even if SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT == false as GROUP BY + assertThat(query("SELECT count(DISTINCT t_char) FROM " + testTable.getName())) + .matches("VALUES BIGINT '6'") + .isFullyPushedDown(); + + assertThat(query("SELECT count(DISTINCT t_char), count(DISTINCT t_varchar) FROM " + testTable.getName())) + .matches("VALUES (BIGINT '6', BIGINT '6')") + .isFullyPushedDown(); + } + } + + @Override + public void testAggregationPushdown() + { + throw new SkipException("tested in testAggregationPushdown(String)"); + } + + @Test(dataProvider = "testAggregationPushdownDistStylesDataProvider") + public void testAggregationPushdown(String distStyle) + { + String nation = format("%s.nation_%s_%s", TEST_SCHEMA, distStyle, randomNameSuffix()); + String customer = format("%s.customer_%s_%s", TEST_SCHEMA, distStyle, randomNameSuffix()); + try { + copyWithDistStyle(TEST_SCHEMA + ".nation", nation, distStyle, Optional.of("regionkey")); + copyWithDistStyle(TEST_SCHEMA + ".customer", customer, distStyle, Optional.of("nationkey")); + + // TODO support aggregation pushdown with GROUPING SETS + // TODO support aggregation over expressions + + // count() + assertThat(query("SELECT count(*) FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT count(nationkey) FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT count(1) FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT count() FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT regionkey, count(1) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + try (TestTable emptyTable = createAggregationTestTable(getSession().getSchema().orElseThrow() + ".empty_table", ImmutableList.of())) { + String emptyTableName = emptyTable.getName() + "_" + distStyle; + copyWithDistStyle(emptyTable.getName(), emptyTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT count(*) FROM " + 
emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT count(a_bigint) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT count(1) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT count() FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT a_bigint, count(1) FROM " + emptyTableName + " GROUP BY a_bigint")).isFullyPushedDown(); + } + + // GROUP BY + assertThat(query("SELECT regionkey, min(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, max(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, sum(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, avg(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + try (TestTable emptyTable = createAggregationTestTable(getSession().getSchema().orElseThrow() + ".empty_table", ImmutableList.of())) { + String emptyTableName = emptyTable.getName() + "_" + distStyle; + copyWithDistStyle(emptyTable.getName(), emptyTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT t_double, min(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + assertThat(query("SELECT t_double, max(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + assertThat(query("SELECT t_double, sum(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + assertThat(query("SELECT t_double, avg(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + } + + // GROUP BY and WHERE on bigint column + // GROUP BY and WHERE on aggregation key + assertThat(query("SELECT regionkey, sum(nationkey) FROM " + nation + " WHERE regionkey < 4 GROUP BY regionkey")).isFullyPushedDown(); + + // GROUP BY and WHERE on varchar column + // GROUP BY and WHERE on "other" (not aggregation key, not aggregation input) + assertThat(query("SELECT regionkey, sum(nationkey) FROM " + nation + " WHERE regionkey < 4 AND name > 'AAA' GROUP BY regionkey")).isFullyPushedDown(); + // GROUP BY above WHERE and LIMIT + assertThat(query("SELECT regionkey, sum(nationkey) FROM (SELECT * FROM " + nation + " WHERE regionkey < 2 LIMIT 11) GROUP BY regionkey")).isFullyPushedDown(); + // GROUP BY above TopN + assertThat(query("SELECT regionkey, sum(nationkey) FROM (SELECT regionkey, nationkey FROM " + nation + " ORDER BY nationkey ASC LIMIT 10) GROUP BY regionkey")).isFullyPushedDown(); + // GROUP BY with JOIN + assertThat(query( + joinPushdownEnabled(getSession()), + "SELECT n.regionkey, sum(c.acctbal) acctbals FROM " + nation + " n LEFT JOIN " + customer + " c USING (nationkey) GROUP BY 1")) + .isFullyPushedDown(); + // GROUP BY with WHERE on neither grouping nor aggregation column + assertThat(query("SELECT nationkey, min(regionkey) FROM " + nation + " WHERE name = 'ARGENTINA' GROUP BY nationkey")).isFullyPushedDown(); + // aggregation on varchar column + assertThat(query("SELECT count(name) FROM " + nation)).isFullyPushedDown(); + // aggregation on varchar column with GROUPING + assertThat(query("SELECT nationkey, count(name) FROM " + nation + " GROUP BY nationkey")).isFullyPushedDown(); + // aggregation on varchar column with WHERE + assertThat(query("SELECT count(name) FROM " + nation + " WHERE name = 'ARGENTINA'")).isFullyPushedDown(); + } + finally { + executeInRedshift("DROP TABLE IF EXISTS " + nation); + executeInRedshift("DROP 
TABLE IF EXISTS " + customer); + } + } + + @Override + public void testNumericAggregationPushdown() + { + throw new SkipException("tested in testNumericAggregationPushdown(String)"); + } + + @Test(dataProvider = "testAggregationPushdownDistStylesDataProvider") + public void testNumericAggregationPushdown(String distStyle) + { + String schemaName = getSession().getSchema().orElseThrow(); + // empty table + try (TestTable emptyTable = createAggregationTestTable(schemaName + ".test_aggregation_pushdown", ImmutableList.of())) { + String emptyTableName = emptyTable.getName() + "_" + distStyle; + copyWithDistStyle(emptyTable.getName(), emptyTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + } + + try (TestTable testTable = createAggregationTestTable(schemaName + ".test_aggregation_pushdown", + ImmutableList.of("100.000, 100000000.000000000, 100.000, 100000000", "123.321, 123456789.987654321, 123.321, 123456789"))) { + String testTableName = testTable.getName() + "_" + distStyle; + copyWithDistStyle(testTable.getName(), testTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + testTableName)).isFullyPushedDown(); + assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + testTableName)).isFullyPushedDown(); + assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + testTableName)).isFullyPushedDown(); + assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + testTableName)).isFullyPushedDown(); + + // smoke testing of more complex cases + // WHERE on aggregation column + assertThat(query("SELECT min(short_decimal), min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110 AND long_decimal < 124")).isFullyPushedDown(); + // WHERE on non-aggregation column + assertThat(query("SELECT min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110")).isFullyPushedDown(); + // GROUP BY + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " GROUP BY short_decimal")).isFullyPushedDown(); + // GROUP BY with WHERE on both grouping and aggregation column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110 AND long_decimal < 124 GROUP BY short_decimal")).isFullyPushedDown(); + // GROUP BY with WHERE on grouping column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110 GROUP BY short_decimal")).isFullyPushedDown(); + // GROUP BY with WHERE on aggregation column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " WHERE long_decimal < 124 GROUP BY short_decimal")).isFullyPushedDown(); + } + } + + private static void copyWithDistStyle(String sourceTableName, String destTableName, String distStyle, Optional distKey) 
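+    // A rough sketch of what this helper issues for the non-AUTO styles (table and key names here
+    // are only illustrative), e.g. for DISTSTYLE KEY distributed on regionkey:
+    //   CREATE TABLE test_schema.nation_key_xyz DISTSTYLE KEY DISTKEY(regionkey) AS SELECT * FROM test_schema.nation
+    // For AUTO, Redshift does not accept DISTSTYLE in CTAS, so the copy is created with a plain CTAS
+    // and then switched via ALTER TABLE ... ALTER DISTSTYLE AUTO, skipped when pg_class_info already
+    // reports releffectivediststyle 10 or 11 (AUTO(ALL) / AUTO(EVEN)).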
+ { + if (distStyle.equals("AUTO")) { + // NOTE: Redshift doesn't support setting diststyle AUTO in CTAS statements + executeInRedshift("CREATE TABLE " + destTableName + " AS SELECT * FROM " + sourceTableName); + // Redshift doesn't allow ALTER DISTSTYLE if original and new style are same, so we need to check current diststyle of table + boolean isDistStyleAuto = executeWithRedshift(handle -> { + Optional currentDistStyle = handle.createQuery("" + + "SELECT releffectivediststyle " + + "FROM pg_class_info AS a LEFT JOIN pg_namespace AS b ON a.relnamespace = b.oid " + + "WHERE lower(nspname) = lower(:schema_name) AND lower(relname) = lower(:table_name)") + .bind("schema_name", TEST_SCHEMA) + // destTableName = TEST_SCHEMA + "." + tableName + .bind("table_name", destTableName.substring(destTableName.indexOf(".") + 1)) + .mapTo(Long.class) + .findOne(); + + // 10 means AUTO(ALL) and 11 means AUTO(EVEN). See https://docs.aws.amazon.com/redshift/latest/dg/r_PG_CLASS_INFO.html. + return currentDistStyle.isPresent() && (currentDistStyle.get() == 10 || currentDistStyle.get() == 11); + }); + if (!isDistStyleAuto) { + executeInRedshift("ALTER TABLE " + destTableName + " ALTER DISTSTYLE " + distStyle); + } + } + else { + String copyWithDistStyleSql = "CREATE TABLE " + destTableName + " DISTSTYLE " + distStyle; + if (distStyle.equals("KEY")) { + copyWithDistStyleSql += format(" DISTKEY(%s)", distKey.orElseThrow()); + } + copyWithDistStyleSql += " AS SELECT * FROM " + sourceTableName; + executeInRedshift(copyWithDistStyleSql); + } + } + + @DataProvider + public Object[][] testAggregationPushdownDistStylesDataProvider() + { + return new Object[][] { + {"EVEN"}, + {"KEY"}, + {"ALL"}, + {"AUTO"}, + }; + } + + @Test + public void testDecimalAvgPushdownForMaximumDecimalScale() + { + List rows = ImmutableList.of( + "12345789.9876543210", + format("%s.%s", "1".repeat(28), "9".repeat(10))); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, TEST_SCHEMA + ".test_agg_pushdown_avg_max_decimal", + "(t_decimal DECIMAL(38, 10))", rows)) { + // Redshift avg rounds down decimal result which doesn't match Presto semantics + assertThatThrownBy(() -> assertThat(query("SELECT avg(t_decimal) FROM " + testTable.getName())).isFullyPushedDown()) + .isInstanceOf(AssertionError.class) + .hasMessageContaining(""" + elements not found: + <(555555555555555555561728450.9938271605)> + and elements not expected: + <(555555555555555555561728450.9938271604)> + """); + } + } + + @Test + public void testDecimalAvgPushdownFoShortDecimalScale() + { + List rows = ImmutableList.of( + "0.987654321234567890", + format("0.%s", "1".repeat(18))); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, TEST_SCHEMA + ".test_agg_pushdown_avg_max_decimal", + "(t_decimal DECIMAL(18, 18))", rows)) { + assertThat(query("SELECT avg(t_decimal) FROM " + testTable.getName())).isFullyPushedDown(); + } + } + + @Override + @Test + public void testReadMetadataWithRelationsConcurrentModifications() + { + throw new SkipException("Test fails with a timeout sometimes and is flaky"); + } + + @Override + public void testInsertRowConcurrently() + { + throw new SkipException("Test fails with a timeout sometimes and is flaky"); + } + + @Override + protected Session joinPushdownEnabled(Session session) + { + return Session.builder(super.joinPushdownEnabled(session)) + // strategy is AUTOMATIC by default and would not work for certain test cases (even if statistics are collected) + 
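+                // EAGER pushes the join down whenever the connector is able to, without the cost-based
+                // decision AUTOMATIC makes, so these tests do not depend on Redshift table statistics.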
.setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_strategy", "EAGER") + .build(); + } + + @Override + protected String errorMessageForInsertIntoNotNullColumn(String columnName) + { + return format("(?s).*Cannot insert a NULL value into column %s.*", columnName); + } + + @Override + protected OptionalInt maxSchemaNameLength() + { + return OptionalInt.of(127); + } + + @Override + protected void verifySchemaNameLengthFailurePermissible(Throwable e) + { + assertThat(e).hasMessage("Schema name must be shorter than or equal to '127' characters but got '128'"); + } + + @Override + protected OptionalInt maxTableNameLength() + { + return OptionalInt.of(127); + } + + @Override + protected void verifyTableNameLengthFailurePermissible(Throwable e) + { + assertThat(e).hasMessage("Table name must be shorter than or equal to '127' characters but got '128'"); + } + + @Override + protected OptionalInt maxColumnNameLength() + { + return OptionalInt.of(127); + } + + @Override + protected void verifyColumnNameLengthFailurePermissible(Throwable e) + { + assertThat(e).hasMessage("Column name must be shorter than or equal to '127' characters but got '128'"); + } + + @Override + protected SqlExecutor onRemoteDatabase() + { + return RedshiftQueryRunner::executeInRedshift; + } + + @Override + public void testDeleteWithLike() + { + assertThatThrownBy(super::testDeleteWithLike) + .hasStackTraceContaining("TrinoException: This connector does not support modifying table rows"); + } + + @Test + @Override + public void testAddNotNullColumnToNonEmptyTable() + { + throw new SkipException("Redshift ALTER TABLE ADD COLUMN defined as NOT NULL must have a non-null default expression"); + } + + private static class TestView + implements AutoCloseable + { + private final String name; + private final SqlExecutor executor; + + public TestView(SqlExecutor executor, String namePrefix, String viewDefinition) + { + this.executor = executor; + this.name = namePrefix + "_" + randomNameSuffix(); + executor.execute("CREATE OR REPLACE VIEW " + name + " AS " + viewDefinition); + } + + @Override + public void close() + { + executor.execute("DROP VIEW " + name); + } + + public String getName() + { + return name; + } + } +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTableStatisticsReader.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTableStatisticsReader.java new file mode 100644 index 000000000000..ff713337ea53 --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTableStatisticsReader.java @@ -0,0 +1,349 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.redshift; + +import com.amazon.redshift.Driver; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.jdbc.BaseJdbcConfig; +import io.trino.plugin.jdbc.DriverConnectionFactory; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcTableHandle; +import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.RemoteTableName; +import io.trino.plugin.jdbc.credential.StaticCredentialProvider; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.statistics.ColumnStatistics; +import io.trino.spi.statistics.Estimate; +import io.trino.spi.statistics.TableStatistics; +import io.trino.spi.type.VarcharType; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.QueryRunner; +import io.trino.testing.sql.TestTable; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.assertj.core.api.SoftAssertions; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.sql.Types; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_PASSWORD; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_URL; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_USER; +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.testing.TestingConnectorSession.SESSION; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static io.trino.testing.sql.TestTable.fromColumns; +import static io.trino.tpch.TpchTable.CUSTOMER; +import static java.util.Collections.emptyMap; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.from; +import static org.assertj.core.api.Assertions.withinPercentage; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; + +public class TestRedshiftTableStatisticsReader + extends AbstractTestQueryFramework +{ + private static final JdbcTypeHandle BIGINT_TYPE_HANDLE = new JdbcTypeHandle(Types.BIGINT, Optional.of("int8"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + private static final JdbcTypeHandle DOUBLE_TYPE_HANDLE = new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + + private static final List CUSTOMER_COLUMNS = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("name", 25), + createVarcharJdbcColumnHandle("address", 48), + new JdbcColumnHandle("nationkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("phone", 15), + new JdbcColumnHandle("acctbal", DOUBLE_TYPE_HANDLE, DOUBLE), + createVarcharJdbcColumnHandle("mktsegment", 10), + createVarcharJdbcColumnHandle("comment", 117)); + + private RedshiftTableStatisticsReader statsReader; + + @BeforeClass + public void setup() + { + DriverConnectionFactory connectionFactory = new DriverConnectionFactory( + new Driver(), + new 
BaseJdbcConfig().setConnectionUrl(JDBC_URL), + new StaticCredentialProvider(Optional.of(JDBC_USER), Optional.of(JDBC_PASSWORD))); + statsReader = new RedshiftTableStatisticsReader(connectionFactory); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner(Map.of(), Map.of(), ImmutableList.of(CUSTOMER)); + } + + @Test + public void testCustomerTable() + throws Exception + { + assertThat(collectStats("SELECT * FROM " + TEST_SCHEMA + ".customer", CUSTOMER_COLUMNS)) + .returns(Estimate.of(1500), from(TableStatistics::getRowCount)) + .extracting(TableStatistics::getColumnStatistics, InstanceOfAssertFactories.map(ColumnHandle.class, ColumnStatistics.class)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(0), statsCloseTo(1500.0, 0.0, 8.0 * 1500)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(1), statsCloseTo(1500.0, 0.0, 33000.0)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(3), statsCloseTo(25.000, 0.0, 8.0 * 1500)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(5), statsCloseTo(1499.0, 0.0, 8.0 * 1500)); + } + + @Test + public void testEmptyTable() + throws Exception + { + TableStatistics tableStatistics = collectStats("SELECT * FROM " + TEST_SCHEMA + ".customer WHERE false", CUSTOMER_COLUMNS); + assertThat(tableStatistics) + .returns(Estimate.of(0.0), from(TableStatistics::getRowCount)) + .returns(emptyMap(), from(TableStatistics::getColumnStatistics)); + } + + @Test + public void testAllNulls() + throws Exception + { + String tableName = "testallnulls_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." + tableName; + try { + executeInRedshift("CREATE TABLE " + schemaAndTable + " (i BIGINT)"); + executeInRedshift("INSERT INTO " + schemaAndTable + " (i) VALUES (NULL)"); + executeInRedshift("ANALYZE VERBOSE " + schemaAndTable); + + TableStatistics stats = statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> ImmutableList.of(new JdbcColumnHandle("i", BIGINT_TYPE_HANDLE, BIGINT))); + assertThat(stats) + .returns(Estimate.of(1.0), from(TableStatistics::getRowCount)) + .returns(emptyMap(), from(TableStatistics::getColumnStatistics)); + } + finally { + executeInRedshift("DROP TABLE IF EXISTS " + schemaAndTable); + } + } + + @Test + public void testNullsFraction() + throws Exception + { + JdbcColumnHandle custkeyColumnHandle = CUSTOMER_COLUMNS.get(0); + TableStatistics stats = collectStats( + "SELECT CASE custkey % 3 WHEN 0 THEN NULL ELSE custkey END FROM " + TEST_SCHEMA + ".customer", + ImmutableList.of(custkeyColumnHandle)); + assertEquals(stats.getRowCount(), Estimate.of(1500)); + + ColumnStatistics columnStatistics = stats.getColumnStatistics().get(custkeyColumnHandle); + assertThat(columnStatistics.getNullsFraction().getValue()).isCloseTo(1.0 / 3, withinPercentage(1)); + } + + @Test + public void testAverageColumnLength() + throws Exception + { + List columns = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("v3_in_3", 3), + createVarcharJdbcColumnHandle("v3_in_42", 42), + createVarcharJdbcColumnHandle("single_10v_value", 10), + createVarcharJdbcColumnHandle("half_10v_value", 10), + createVarcharJdbcColumnHandle("half_distinct_20v_value", 20), + createVarcharJdbcColumnHandle("all_nulls", 10)); + + assertThat( + collectStats( + "SELECT " + + " custkey, " + + " 'abc' v3_in_3, " + + " CAST('abc' AS 
varchar(42)) v3_in_42, " + + " CASE custkey WHEN 1 THEN '0123456789' ELSE NULL END single_10v_value, " + + " CASE custkey % 2 WHEN 0 THEN '0123456789' ELSE NULL END half_10v_value, " + + " CASE custkey % 2 WHEN 0 THEN CAST((1000000 - custkey) * (1000000 - custkey) AS varchar(20)) ELSE NULL END half_distinct_20v_value, " + // 12 chars each + " CAST(NULL AS varchar(10)) all_nulls " + + "FROM " + TEST_SCHEMA + ".customer " + + "ORDER BY custkey LIMIT 100", + columns)) + .returns(Estimate.of(100), from(TableStatistics::getRowCount)) + .extracting(TableStatistics::getColumnStatistics, InstanceOfAssertFactories.map(ColumnHandle.class, ColumnStatistics.class)) + .hasEntrySatisfying(columns.get(0), statsCloseTo(100.0, 0.0, 800)) + .hasEntrySatisfying(columns.get(1), statsCloseTo(1.0, 0.0, 700.0)) + .hasEntrySatisfying(columns.get(2), statsCloseTo(1.0, 0.0, 700)) + .hasEntrySatisfying(columns.get(3), statsCloseTo(1.0, 0.99, 14)) + .hasEntrySatisfying(columns.get(4), statsCloseTo(1.0, 0.5, 700)) + .hasEntrySatisfying(columns.get(5), statsCloseTo(51, 0.5, 800)) + .satisfies(stats -> assertNull(stats.get(columns.get(6)))); + } + + @Test + public void testView() + throws Exception + { + String tableName = "test_stats_view_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." + tableName; + List columns = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("mktsegment", 10), + createVarcharJdbcColumnHandle("comment", 117)); + + try { + executeInRedshift("CREATE OR REPLACE VIEW " + schemaAndTable + " AS SELECT custkey, mktsegment, comment FROM " + TEST_SCHEMA + ".customer"); + TableStatistics tableStatistics = statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> columns); + assertThat(tableStatistics).isEqualTo(TableStatistics.empty()); + } + finally { + executeInRedshift("DROP VIEW IF EXISTS " + schemaAndTable); + } + } + + @Test + public void testMaterializedView() + throws Exception + { + String tableName = "test_stats_materialized_view_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." 
+ tableName; + List columns = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("mktsegment", 10), + createVarcharJdbcColumnHandle("comment", 117)); + + try { + executeInRedshift("CREATE MATERIALIZED VIEW " + schemaAndTable + + " AS SELECT custkey, mktsegment, comment FROM " + TEST_SCHEMA + ".customer"); + executeInRedshift("REFRESH MATERIALIZED VIEW " + schemaAndTable); + executeInRedshift("ANALYZE VERBOSE " + schemaAndTable); + TableStatistics tableStatistics = statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> columns); + assertThat(tableStatistics).isEqualTo(TableStatistics.empty()); + } + finally { + executeInRedshift("DROP MATERIALIZED VIEW " + schemaAndTable); + } + } + + @Test + public void testNumericCornerCases() + { + try (TestTable table = fromColumns( + getQueryRunner()::execute, + "test_numeric_corner_cases_", + ImmutableMap.>builder() + .put("only_negative_infinity double", List.of("-infinity()", "-infinity()", "-infinity()", "-infinity()")) + .put("only_positive_infinity double", List.of("infinity()", "infinity()", "infinity()", "infinity()")) + .put("mixed_infinities double", List.of("-infinity()", "infinity()", "-infinity()", "infinity()")) + .put("mixed_infinities_and_numbers double", List.of("-infinity()", "infinity()", "-5.0", "7.0")) + .put("nans_only double", List.of("nan()", "nan()")) + .put("nans_and_numbers double", List.of("nan()", "nan()", "-5.0", "7.0")) + .put("large_doubles double", List.of("CAST(-50371909150609548946090.0 AS DOUBLE)", "CAST(50371909150609548946090.0 AS DOUBLE)")) // 2^77 DIV 3 + .put("short_decimals_big_fraction decimal(16,15)", List.of("-1.234567890123456", "1.234567890123456")) + .put("short_decimals_big_integral decimal(16,1)", List.of("-123456789012345.6", "123456789012345.6")) + .put("long_decimals_big_fraction decimal(38,37)", List.of("-1.2345678901234567890123456789012345678", "1.2345678901234567890123456789012345678")) + .put("long_decimals_middle decimal(38,16)", List.of("-1234567890123456.7890123456789012345678", "1234567890123456.7890123456789012345678")) + .put("long_decimals_big_integral decimal(38,1)", List.of("-1234567890123456789012345678901234567.8", "1234567890123456789012345678901234567.8")) + .buildOrThrow(), + "null")) { + executeInRedshift("ANALYZE VERBOSE " + TEST_SCHEMA + "." 
+ table.getName()); + assertQuery( + "SHOW STATS FOR " + table.getName(), + "VALUES " + + "('only_negative_infinity', null, 1, 0, null, null, null)," + + "('only_positive_infinity', null, 1, 0, null, null, null)," + + "('mixed_infinities', null, 2, 0, null, null, null)," + + "('mixed_infinities_and_numbers', null, 4.0, 0.0, null, null, null)," + + "('nans_only', null, 1.0, 0.5, null, null, null)," + + "('nans_and_numbers', null, 3.0, 0.0, null, null, null)," + + "('large_doubles', null, 2.0, 0.5, null, null, null)," + + "('short_decimals_big_fraction', null, 2.0, 0.5, null, null, null)," + + "('short_decimals_big_integral', null, 2.0, 0.5, null, null, null)," + + "('long_decimals_big_fraction', null, 2.0, 0.5, null, null, null)," + + "('long_decimals_middle', null, 2.0, 0.5, null, null, null)," + + "('long_decimals_big_integral', null, 2.0, 0.5, null, null, null)," + + "(null, null, null, null, 4, null, null)"); + } + } + + /** + * Assert that the given column is within 5% of each statistic in the parameters, and that it has no range + */ + private static Consumer statsCloseTo(double distinctValues, double nullsFraction, double dataSize) + { + return stats -> { + SoftAssertions softly = new SoftAssertions(); + + softly.assertThat(stats.getDistinctValuesCount().getValue()) + .isCloseTo(distinctValues, withinPercentage(5.0)); + + softly.assertThat(stats.getNullsFraction().getValue()) + .isCloseTo(nullsFraction, withinPercentage(5.0)); + + softly.assertThat(stats.getDataSize().getValue()) + .isCloseTo(dataSize, withinPercentage(5.0)); + + softly.assertThat(stats.getRange()).isEmpty(); + softly.assertAll(); + }; + } + + private TableStatistics collectStats(String values, List columnHandles) + throws Exception + { + String tableName = "testredshiftstatisticsreader_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." + tableName; + try { + executeInRedshift("CREATE TABLE " + schemaAndTable + " AS " + values); + executeInRedshift("ANALYZE VERBOSE " + schemaAndTable); + return statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> columnHandles); + } + finally { + executeInRedshift("DROP TABLE IF EXISTS " + schemaAndTable); + } + } + + private static JdbcColumnHandle createVarcharJdbcColumnHandle(String name, int length) + { + return new JdbcColumnHandle( + name, + new JdbcTypeHandle(Types.VARCHAR, Optional.of("varchar"), Optional.of(length), Optional.empty(), Optional.empty(), Optional.empty()), + VarcharType.createVarcharType(length)); + } +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTypeMapping.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTypeMapping.java new file mode 100644 index 000000000000..26938c3b6532 --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTypeMapping.java @@ -0,0 +1,994 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import com.google.common.base.Utf8; +import com.google.common.collect.ImmutableList; +import io.trino.Session; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.QueryRunner; +import io.trino.testing.TestingSession; +import io.trino.testing.datatype.CreateAndInsertDataSetup; +import io.trino.testing.datatype.CreateAsSelectDataSetup; +import io.trino.testing.datatype.DataSetup; +import io.trino.testing.datatype.SqlDataTypeTest; +import io.trino.testing.sql.JdbcSqlExecutor; +import io.trino.testing.sql.SqlExecutor; +import io.trino.testing.sql.TestTable; +import io.trino.testing.sql.TrinoSqlExecutor; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.sql.SQLException; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneId; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; + +import static com.google.common.base.Verify.verify; +import static com.google.common.io.BaseEncoding.base16; +import static io.trino.plugin.redshift.RedshiftClient.REDSHIFT_MAX_VARCHAR; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_PASSWORD; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_URL; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_USER; +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.CharType.createCharType; +import static io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.DecimalType.createDecimalType; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TimeType.createTimeType; +import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey; +import static io.trino.spi.type.TimestampType.createTimestampType; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static java.lang.String.format; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.time.ZoneOffset.UTC; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.joining; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestRedshiftTypeMapping + extends AbstractTestQueryFramework +{ + private static final ZoneId testZone = TestingSession.DEFAULT_TIME_ZONE_KEY.getZoneId(); + + private final ZoneId 
jvmZone = ZoneId.systemDefault(); + private final LocalDateTime timeGapInJvmZone = LocalDate.EPOCH.atStartOfDay(); + private final LocalDateTime timeDoubledInJvmZone = LocalDateTime.of(2018, 10, 28, 1, 33, 17, 456_789_000); + + // using two non-JVM zones so that we don't need to worry what the backend's system zone is + + // no DST in 1970, but has DST in later years (e.g. 2018) + private final ZoneId vilnius = ZoneId.of("Europe/Vilnius"); + private final LocalDateTime timeGapInVilnius = LocalDateTime.of(2018, 3, 25, 3, 17, 17); + private final LocalDateTime timeDoubledInVilnius = LocalDateTime.of(2018, 10, 28, 3, 33, 33, 333_333_000); + + // Size of offset changed since 1970-01-01, no DST + private final ZoneId kathmandu = ZoneId.of("Asia/Kathmandu"); + private final LocalDateTime timeGapInKathmandu = LocalDateTime.of(1986, 1, 1, 0, 13, 7); + + private final LocalDate dayOfMidnightGapInJvmZone = LocalDate.EPOCH; + private final LocalDate dayOfMidnightGapInVilnius = LocalDate.of(1983, 4, 1); + private final LocalDate dayAfterMidnightSetBackInVilnius = LocalDate.of(1983, 10, 1); + + @BeforeClass + public void checkRanges() + { + // Timestamps + checkIsGap(jvmZone, timeGapInJvmZone); + checkIsDoubled(jvmZone, timeDoubledInJvmZone); + checkIsGap(vilnius, timeGapInVilnius); + checkIsDoubled(vilnius, timeDoubledInVilnius); + checkIsGap(kathmandu, timeGapInKathmandu); + + // Times + checkIsGap(jvmZone, LocalTime.of(0, 0, 0).atDate(LocalDate.EPOCH)); + + // Dates + checkIsGap(jvmZone, dayOfMidnightGapInJvmZone.atStartOfDay()); + checkIsGap(vilnius, dayOfMidnightGapInVilnius.atStartOfDay()); + checkIsDoubled(vilnius, dayAfterMidnightSetBackInVilnius.atStartOfDay().minusNanos(1)); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner(Map.of(), Map.of(), List.of()); + } + + @Test + public void testBasicTypes() + { + // Assume that if these types work at all, they have standard semantics. + SqlDataTypeTest.create() + .addRoundTrip("boolean", "true", BOOLEAN, "true") + .addRoundTrip("boolean", "false", BOOLEAN, "false") + .addRoundTrip("bigint", "123456789012", BIGINT, "123456789012") + .addRoundTrip("integer", "1234567890", INTEGER, "1234567890") + .addRoundTrip("smallint", "32456", SMALLINT, "SMALLINT '32456'") + .addRoundTrip("double", "123.45", DOUBLE, "DOUBLE '123.45'") + .addRoundTrip("real", "123.45", REAL, "REAL '123.45'") + // If we map tinyint to smallint: + .addRoundTrip("tinyint", "5", SMALLINT, "SMALLINT '5'") + .execute(getQueryRunner(), trinoCreateAsSelect("test_basic_types")); + } + + @Test + public void testVarchar() + { + SqlDataTypeTest.create() + .addRoundTrip("varchar(65535)", "'varchar max'", createVarcharType(65535), "CAST('varchar max' AS varchar(65535))") + .addRoundTrip("varchar(40)", "'攻殻機動隊'", createVarcharType(40), "CAST('攻殻機動隊' AS varchar(40))") + .addRoundTrip("varchar(8)", "'隊'", createVarcharType(8), "CAST('隊' AS varchar(8))") + .addRoundTrip("varchar(16)", "'😂'", createVarcharType(16), "CAST('😂' AS varchar(16))") + .addRoundTrip("varchar(88)", "'Ну, погоди!'", createVarcharType(88), "CAST('Ну, погоди!' 
AS varchar(88))") + .addRoundTrip("varchar(10)", "'text_a'", createVarcharType(10), "CAST('text_a' AS varchar(10))") + .addRoundTrip("varchar(255)", "'text_b'", createVarcharType(255), "CAST('text_b' AS varchar(255))") + .addRoundTrip("varchar(4096)", "'char max'", createVarcharType(4096), "CAST('char max' AS varchar(4096))") + .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_varchar")) + .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_varchar")); + } + + @Test + public void testChar() + { + SqlDataTypeTest.create() + .addRoundTrip("char(10)", "'text_a'", createCharType(10), "CAST('text_a' AS char(10))") + .addRoundTrip("char(255)", "'text_b'", createCharType(255), "CAST('text_b' AS char(255))") + .addRoundTrip("char(4096)", "'char max'", createCharType(4096), "CAST('char max' AS char(4096))") + .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_char")) + .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_char")); + + // Test with types larger than Redshift's char(max) + SqlDataTypeTest.create() + .addRoundTrip("char(65535)", "'varchar max'", createVarcharType(65535), format("CAST('varchar max%s' AS varchar(65535))", " ".repeat(65535 - "varchar max".length()))) + .addRoundTrip("char(4136)", "'攻殻機動隊'", createVarcharType(4136), format("CAST('%s' AS varchar(4136))", padVarchar(4136).apply("攻殻機動隊"))) + .addRoundTrip("char(4104)", "'隊'", createVarcharType(4104), format("CAST('%s' AS varchar(4104))", padVarchar(4104).apply("隊"))) + .addRoundTrip("char(4112)", "'😂'", createVarcharType(4112), format("CAST('%s' AS varchar(4112))", padVarchar(4112).apply("😂"))) + .addRoundTrip("varchar(88)", "'Ну, погоди!'", createVarcharType(88), "CAST('Ну, погоди!' AS varchar(88))") + .addRoundTrip("char(4106)", "'text_a'", createVarcharType(4106), format("CAST('%s' AS varchar(4106))", padVarchar(4106).apply("text_a"))) + .addRoundTrip("char(4351)", "'text_b'", createVarcharType(4351), format("CAST('%s' AS varchar(4351))", padVarchar(4351).apply("text_b"))) + .addRoundTrip("char(8192)", "'char max'", createVarcharType(8192), format("CAST('%s' AS varchar(8192))", padVarchar(8192).apply("char max"))) + .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_large_char")); + } + + /** + * Test handling of data outside Redshift's normal bounds. + * + *
<p>
Redshift sometimes returns unbounded {@code VARCHAR} data, apparently + * when it returns directly from a Postgres function. + */ + @Test + public void testPostgresText() + { + try (TestView view1 = new TestView("postgres_text_view", "SELECT lpad('x', 1)"); + TestView view2 = new TestView("pg_catalog_view", "SELECT relname FROM pg_class")) { + // Test data and type from a function + assertThat(query(format("SELECT * FROM %s", view1.name))) + .matches("VALUES CAST('x' AS varchar)"); + + // Test the type of an internal table + assertThat(query(format("SELECT * FROM %s LIMIT 1", view2.name))) + .hasOutputTypes(List.of(createUnboundedVarcharType())); + } + } + + // Make sure that Redshift still maps NCHAR and NVARCHAR to CHAR and VARCHAR. + @Test + public void checkNCharAndNVarchar() + { + SqlDataTypeTest.create() + .addRoundTrip("nvarchar(65535)", "'varchar max'", createVarcharType(65535), "CAST('varchar max' AS varchar(65535))") + .addRoundTrip("nvarchar(40)", "'攻殻機動隊'", createVarcharType(40), "CAST('攻殻機動隊' AS varchar(40))") + .addRoundTrip("nvarchar(8)", "'隊'", createVarcharType(8), "CAST('隊' AS varchar(8))") + .addRoundTrip("nvarchar(16)", "'😂'", createVarcharType(16), "CAST('😂' AS varchar(16))") + .addRoundTrip("nvarchar(88)", "'Ну, погоди!'", createVarcharType(88), "CAST('Ну, погоди!' AS varchar(88))") + .addRoundTrip("nvarchar(10)", "'text_a'", createVarcharType(10), "CAST('text_a' AS varchar(10))") + .addRoundTrip("nvarchar(255)", "'text_b'", createVarcharType(255), "CAST('text_b' AS varchar(255))") + .addRoundTrip("nvarchar(4096)", "'char max'", createVarcharType(4096), "CAST('char max' AS varchar(4096))") + .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_nvarchar")); + + SqlDataTypeTest.create() + .addRoundTrip("nchar(10)", "'text_a'", createCharType(10), "CAST('text_a' AS char(10))") + .addRoundTrip("nchar(255)", "'text_b'", createCharType(255), "CAST('text_b' AS char(255))") + .addRoundTrip("nchar(4096)", "'char max'", createCharType(4096), "CAST('char max' AS char(4096))") + .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_nchar")); + } + + @Test + public void testUnicodeChar() // Redshift doesn't allow multibyte chars in CHAR + { + try (TestTable table = testTable("test_multibyte_char", "(c char(32))")) { + assertQueryFails( + format("INSERT INTO %s VALUES ('\u968A')", table.getName()), + "^Value for Redshift CHAR must be ASCII, but found '\u968A'$"); + } + + assertCreateFails( + "test_multibyte_char_ctas", + "AS SELECT CAST('\u968A' AS char(32)) c", + "^Value for Redshift CHAR must be ASCII, but found '\u968A'$"); + } + + // Make sure Redshift really doesn't allow multibyte characters in CHAR + @Test + public void checkUnicodeCharInRedshift() + { + try (TestTable table = testTable("check_multibyte_char", "(c char(32))")) { + assertThatThrownBy(() -> getRedshiftExecutor() + .execute(format("INSERT INTO %s VALUES ('\u968a')", table.getName()))) + .getCause() + .isInstanceOf(SQLException.class) + .hasMessageContaining("CHAR string contains invalid ASCII character"); + } + } + + @Test + public void testOversizedCharacterTypes() + { + // Test that character types too large for Redshift map to the maximum size + SqlDataTypeTest.create() + .addRoundTrip("varchar", "'unbounded'", createVarcharType(65535), "CAST('unbounded' AS varchar(65535))") + .addRoundTrip(format("varchar(%s)", REDSHIFT_MAX_VARCHAR + 1), "'oversized varchar'", createVarcharType(65535), "CAST('oversized varchar' AS varchar(65535))") + .addRoundTrip(format("char(%s)", REDSHIFT_MAX_VARCHAR + 
1), "'oversized char'", createVarcharType(65535), format("CAST('%s' AS varchar(65535))", padVarchar(65535).apply("oversized char"))) + .execute(getQueryRunner(), trinoCreateAsSelect("oversized_character_types")); + } + + @Test + public void testVarbinary() + { + // Redshift's VARBYTE is mapped to Trino VARBINARY. Redshift does not have VARBINARY type. + SqlDataTypeTest.create() + // varbyte + .addRoundTrip("varbyte", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("varbyte", "to_varbyte('', 'hex')", VARBINARY, "X''") + .addRoundTrip("varbyte", utf8VarbyteLiteral("hello"), VARBINARY, "to_utf8('hello')") + .addRoundTrip("varbyte", utf8VarbyteLiteral("Piękna łąka w 東京都"), VARBINARY, "to_utf8('Piękna łąka w 東京都')") + .addRoundTrip("varbyte", utf8VarbyteLiteral("Bag full of 💰"), VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("varbyte", "to_varbyte('0001020304050607080DF9367AA7000000', 'hex')", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("varbyte", "to_varbyte('000000000000', 'hex')", VARBINARY, "X'000000000000'") + .addRoundTrip("varbyte(1)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // minimum length + .addRoundTrip("varbyte(1024000)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // maximum length + // varbinary + .addRoundTrip("varbinary", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("varbinary", utf8VarbyteLiteral("Bag full of 💰"), VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("varbinary", "to_varbyte('0001020304050607080DF9367AA7000000', 'hex')", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("varbinary(1)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // minimum length + .addRoundTrip("varbinary(1024000)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // maximum length + // binary varying + .addRoundTrip("binary varying", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("binary varying", utf8VarbyteLiteral("Bag full of 💰"), VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("binary varying", "to_varbyte('0001020304050607080DF9367AA7000000', 'hex')", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("binary varying(1)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // minimum length + .addRoundTrip("binary varying(1024000)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // maximum length + .execute(getQueryRunner(), redshiftCreateAndInsert("test_varbinary")); + + SqlDataTypeTest.create() + .addRoundTrip("varbinary", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("varbinary", "X''", VARBINARY, "X''") + .addRoundTrip("varbinary", "X'68656C6C6F'", VARBINARY, "to_utf8('hello')") + .addRoundTrip("varbinary", "X'5069C4996B6E6120C582C4856B61207720E69DB1E4BAACE983BD'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") + .addRoundTrip("varbinary", "X'4261672066756C6C206F6620F09F92B0'", VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("varbinary", "X'0001020304050607080DF9367AA7000000'", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("varbinary", "X'000000000000'", VARBINARY, "X'000000000000'") + .execute(getQueryRunner(), trinoCreateAsSelect("test_varbinary")); + } + + private static String utf8VarbyteLiteral(String string) + { + return format("to_varbyte('%s', 'hex')", base16().encode(string.getBytes(UTF_8))); + } + + @Test + public void testDecimal() + { + SqlDataTypeTest.create() + .addRoundTrip("decimal(3, 0)", "CAST('193' AS decimal(3, 0))", createDecimalType(3, 0), 
"CAST('193' AS decimal(3, 0))") + .addRoundTrip("decimal(3, 0)", "CAST('19' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('19' AS decimal(3, 0))") + .addRoundTrip("decimal(3, 0)", "CAST('-193' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('-193' AS decimal(3, 0))") + .addRoundTrip("decimal(3, 1)", "CAST('10.0' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('10.0' AS decimal(3, 1))") + .addRoundTrip("decimal(3, 1)", "CAST('10.1' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('10.1' AS decimal(3, 1))") + .addRoundTrip("decimal(3, 1)", "CAST('-10.1' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('-10.1' AS decimal(3, 1))") + .addRoundTrip("decimal(4, 2)", "CAST('2' AS decimal(4, 2))", createDecimalType(4, 2), "CAST('2' AS decimal(4, 2))") + .addRoundTrip("decimal(4, 2)", "CAST('2.3' AS decimal(4, 2))", createDecimalType(4, 2), "CAST('2.3' AS decimal(4, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('2' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('2' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('2.3' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('2.3' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('123456789.3' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('123456789.3' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 4)", "CAST('12345678901234567890.31' AS decimal(24, 4))", createDecimalType(24, 4), "CAST('12345678901234567890.31' AS decimal(24, 4))") + .addRoundTrip("decimal(30, 5)", "CAST('3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('3141592653589793238462643.38327' AS decimal(30, 5))") + .addRoundTrip("decimal(30, 5)", "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))") + .addRoundTrip("decimal(31, 0)", "CAST('2718281828459045235360287471352' AS decimal(31, 0))", createDecimalType(31, 0), "CAST('2718281828459045235360287471352' AS decimal(31, 0))") + .addRoundTrip("decimal(31, 0)", "CAST('-2718281828459045235360287471352' AS decimal(31, 0))", createDecimalType(31, 0), "CAST('-2718281828459045235360287471352' AS decimal(31, 0))") + .addRoundTrip("decimal(3, 0)", "NULL", createDecimalType(3, 0), "CAST(NULL AS decimal(3, 0))") + .addRoundTrip("decimal(31, 0)", "NULL", createDecimalType(31, 0), "CAST(NULL AS decimal(31, 0))") + .execute(getQueryRunner(), redshiftCreateAndInsert("test_decimal")) + .execute(getQueryRunner(), trinoCreateAsSelect("test_decimal")); + } + + @Test + public void testRedshiftDecimalCutoff() + { + String columns = "(d19 decimal(19, 0), d18 decimal(19, 18), d0 decimal(19, 19))"; + try (TestTable table = testTable("test_decimal_range", columns)) { + assertQueryFails( + format("INSERT INTO %s (d19) VALUES (DECIMAL'9991999999999999999')", table.getName()), + "^Value out of range for Redshift DECIMAL\\(19, 0\\)$"); + assertQueryFails( + format("INSERT INTO %s (d18) VALUES (DECIMAL'9.991999999999999999')", table.getName()), + "^Value out of range for Redshift DECIMAL\\(19, 18\\)$"); + assertQueryFails( + format("INSERT INTO %s (d0) VALUES (DECIMAL'.9991999999999999999')", table.getName()), + "^Value out of range for Redshift DECIMAL\\(19, 19\\)$"); + } + } + + @Test + public void testRedshiftDecimalScaleLimit() + { + assertCreateFails( + "test_overlarge_decimal_scale", + "(d DECIMAL(38, 38))", + "^ERROR: DECIMAL scale 38 must be between 0 and 37$"); + } + + @Test + public void testUnsupportedTrinoDataTypes() + { + assertCreateFails( + "test_unsupported_type", + 
"(col json)", + "Unsupported column type: json"); + } + + @Test(dataProvider = "datetime_test_parameters") + public void testDate(ZoneId sessionZone) + { + Session session = Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey(sessionZone.getId())) + .build(); + SqlDataTypeTest.create() + .addRoundTrip("date", "DATE '0001-01-01'", DATE, "DATE '0001-01-01'") // first day of AD + .addRoundTrip("date", "DATE '1500-01-01'", DATE, "DATE '1500-01-01'") // sometime before julian->gregorian switch + .addRoundTrip("date", "DATE '1600-01-01'", DATE, "DATE '1600-01-01'") // long ago but after julian->gregorian switch + .addRoundTrip("date", "DATE '1952-04-03'", DATE, "DATE '1952-04-03'") // before epoch + .addRoundTrip("date", "DATE '1970-01-01'", DATE, "DATE '1970-01-01'") + .addRoundTrip("date", "DATE '1970-02-03'", DATE, "DATE '1970-02-03'") // after epoch + .addRoundTrip("date", "DATE '2017-07-01'", DATE, "DATE '2017-07-01'") // summer in northern hemisphere (possible DST) + .addRoundTrip("date", "DATE '2017-01-01'", DATE, "DATE '2017-01-01'") // winter in northern hemisphere (possible DST in southern hemisphere) + .addRoundTrip("date", "DATE '1970-01-01'", DATE, "DATE '1970-01-01'") // day of midnight gap in JVM + .addRoundTrip("date", "DATE '1983-04-01'", DATE, "DATE '1983-04-01'") // day of midnight gap in Vilnius + .addRoundTrip("date", "DATE '1983-10-01'", DATE, "DATE '1983-10-01'") // day after midnight setback in Vilnius + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")) + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_date")); + + // some time BC + SqlDataTypeTest.create() + .addRoundTrip("date", "DATE '-0100-01-01'", DATE, "DATE '-0100-01-01'") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")); + SqlDataTypeTest.create() + .addRoundTrip("date", "DATE '0101-01-01 BC'", DATE, "DATE '-0100-01-01'") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_date")); + } + + @Test(dataProvider = "datetime_test_parameters") + public void testTime(ZoneId sessionZone) + { + // Redshift gets bizarre errors if you try to insert after + // specifying precision for a time column. 
+ Session session = Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey(sessionZone.getId())) + .build(); + timeTypeTests("time(6)") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "time_from_trino")); + timeTypeTests("time") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("time_from_jdbc")); + } + + private static SqlDataTypeTest timeTypeTests(String inputType) + { + return SqlDataTypeTest.create() + .addRoundTrip(inputType, "TIME '00:00:00.000000'", createTimeType(6), "TIME '00:00:00.000000'") // gap in JVM zone on Epoch day + .addRoundTrip(inputType, "TIME '00:13:42.000000'", createTimeType(6), "TIME '00:13:42.000000'") // gap in JVM zone on Epoch day + .addRoundTrip(inputType, "TIME '01:33:17.000000'", createTimeType(6), "TIME '01:33:17.000000'") + .addRoundTrip(inputType, "TIME '03:17:17.000000'", createTimeType(6), "TIME '03:17:17.000000'") + .addRoundTrip(inputType, "TIME '10:01:17.100000'", createTimeType(6), "TIME '10:01:17.100000'") + .addRoundTrip(inputType, "TIME '13:18:03.000000'", createTimeType(6), "TIME '13:18:03.000000'") + .addRoundTrip(inputType, "TIME '14:18:03.000000'", createTimeType(6), "TIME '14:18:03.000000'") + .addRoundTrip(inputType, "TIME '15:18:03.000000'", createTimeType(6), "TIME '15:18:03.000000'") + .addRoundTrip(inputType, "TIME '16:18:03.123456'", createTimeType(6), "TIME '16:18:03.123456'") + .addRoundTrip(inputType, "TIME '19:01:17.000000'", createTimeType(6), "TIME '19:01:17.000000'") + .addRoundTrip(inputType, "TIME '20:01:17.000000'", createTimeType(6), "TIME '20:01:17.000000'") + .addRoundTrip(inputType, "TIME '21:01:17.000001'", createTimeType(6), "TIME '21:01:17.000001'") + .addRoundTrip(inputType, "TIME '22:59:59.000000'", createTimeType(6), "TIME '22:59:59.000000'") + .addRoundTrip(inputType, "TIME '23:59:59.000000'", createTimeType(6), "TIME '23:59:59.000000'") + .addRoundTrip(inputType, "TIME '23:59:59.999999'", createTimeType(6), "TIME '23:59:59.999999'"); + } + + @Test(dataProvider = "datetime_test_parameters") + public void testTimestamp(ZoneId sessionZone) + { + Session session = Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey(sessionZone.getId())) + .build(); + // Redshift doesn't allow timestamp precision to be specified + timestampTypeTests("timestamp(6)") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "timestamp_from_trino")); + timestampTypeTests("timestamp") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("timestamp_from_jdbc")); + + // some time BC + SqlDataTypeTest.create() + .addRoundTrip("timestamp(6)", "TIMESTAMP '-0100-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '-0100-01-01 00:00:00.000000'") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")); + SqlDataTypeTest.create() + .addRoundTrip("timestamp", "TIMESTAMP '0101-01-01 00:00:00 BC'", createTimestampType(6), "TIMESTAMP '-0100-01-01 00:00:00.000000'") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_date")); + } + + private static SqlDataTypeTest timestampTypeTests(String inputType) + { + return SqlDataTypeTest.create() + .addRoundTrip(inputType, "TIMESTAMP '0001-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '0001-01-01 00:00:00.000000'") // first day of AD + .addRoundTrip(inputType, "TIMESTAMP '1500-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '1500-01-01 00:00:00.000000'") // sometime before julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1600-01-01 00:00:00.000000'", 
createTimestampType(6), "TIMESTAMP '1600-01-01 00:00:00.000000'") // long ago but after julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1958-01-01 13:18:03.123456'", createTimestampType(6), "TIMESTAMP '1958-01-01 13:18:03.123456'") // before epoch + .addRoundTrip(inputType, "TIMESTAMP '2019-03-18 10:09:17.987654'", createTimestampType(6), "TIMESTAMP '2019-03-18 10:09:17.987654'") // after epoch + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 01:33:17.456789'", createTimestampType(6), "TIMESTAMP '2018-10-28 01:33:17.456789'") // time doubled in JVM + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 03:33:33.333333'", createTimestampType(6), "TIMESTAMP '2018-10-28 03:33:33.333333'") // time doubled in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '1970-01-01 00:00:00.000000'") // time gap in JVM + .addRoundTrip(inputType, "TIMESTAMP '2018-03-25 03:17:17.000000'", createTimestampType(6), "TIMESTAMP '2018-03-25 03:17:17.000000'") // time gap in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '1986-01-01 00:13:07.000000'", createTimestampType(6), "TIMESTAMP '1986-01-01 00:13:07.000000'") // time gap in Kathmandu + // Full time precision + .addRoundTrip(inputType, "TIMESTAMP '1969-12-31 23:59:59.999999'", createTimestampType(6), "TIMESTAMP '1969-12-31 23:59:59.999999'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.999999'", createTimestampType(6), "TIMESTAMP '1970-01-01 00:00:00.999999'"); + } + + @Test(dataProvider = "datetime_test_parameters") + public void testTimestampWithTimeZone(ZoneId sessionZone) + { + Session session = Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey(sessionZone.getId())) + .build(); + + SqlDataTypeTest.create() + // test arbitrary time for all supported precisions + .addRoundTrip("timestamp(0) with time zone", "TIMESTAMP '2022-09-27 12:34:56 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.000000 UTC'") + .addRoundTrip("timestamp(1) with time zone", "TIMESTAMP '2022-09-27 12:34:56.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.100000 UTC'") + .addRoundTrip("timestamp(2) with time zone", "TIMESTAMP '2022-09-27 12:34:56.12 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.120000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2022-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(4) with time zone", "TIMESTAMP '2022-09-27 12:34:56.1234 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123400 UTC'") + .addRoundTrip("timestamp(5) with time zone", "TIMESTAMP '2022-09-27 12:34:56.12345 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123450 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2022-09-27 12:34:56.123456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123456 UTC'") + + // short timestamp with time zone + // .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '-4712-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '-4712-01-01 00:00:00.000000 UTC'") // min value in Redshift + // .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '0001-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '0001-01-01 00:00:00.000000 UTC'") // first day of AD + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-04 23:59:59.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-04 23:59:59.999000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-05 
00:00:00.000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'") // begin julian->gregorian switch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-14 23:59:59.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-14 23:59:59.999000 UTC'") // end julian->gregorian switch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-15 00:00:00.000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.100000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.900000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.999000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1986-01-01 00:13:07 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'") // time gap in Kathmandu + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2018-10-28 01:33:17.456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 01:33:17.456000 UTC'") // time doubled in JVM + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2018-10-28 03:33:33.333 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 03:33:33.333000 UTC'") // time doubled in Vilnius + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2018-03-25 03:17:17.000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'") // time gap in Vilnius + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.100000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.900000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.999000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '73326-09-11 20:14:45.247 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '73326-09-11 20:14:45.247000 UTC'") // max value in Trino + .addRoundTrip("timestamp(3) with time zone", "NULL", TIMESTAMP_TZ_MICROS, "CAST(NULL AS timestamp(6) with time zone)") + + // long timestamp with time zone + // .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '0001-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '0001-01-01 00:00:00.000000 UTC'") // first day of AD + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-04 23:59:59.999999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-04 
23:59:59.999999 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'") // begin julian->gregorian switch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-14 23:59:59.999999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-14 23:59:59.999999 UTC'") // end julian->gregorian switch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.100000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.900000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.999000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'") // time gap in Kathmandu (long timestamp_tz) + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2018-10-28 01:33:17.456789 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 01:33:17.456789 UTC'") // time doubled in JVM + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2018-10-28 03:33:33.333333 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 03:33:33.333333 UTC'") // time doubled in Vilnius + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'") // time gap in Vilnius + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.100000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.900000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.999000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123456 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '73326-09-11 20:14:45.247999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '73326-09-11 20:14:45.247999 UTC'") // max value in Trino + 
.addRoundTrip("timestamp(6) with time zone", "NULL", TIMESTAMP_TZ_MICROS, "CAST(NULL AS timestamp(6) with time zone)") + .execute(getQueryRunner(), session, trinoCreateAsSelect(getSession(), "test_timestamp_tz")); + + redshiftTimestampWithTimeZoneTests("timestamptz") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_timestamp_tz")); + redshiftTimestampWithTimeZoneTests("timestamp with time zone") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_timestamp_tz")); + } + + private static SqlDataTypeTest redshiftTimestampWithTimeZoneTests(String inputType) + { + return SqlDataTypeTest.create() + // .addRoundTrip(inputType, "TIMESTAMP '4713-01-01 00:00:00 BC' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '-4712-01-01 00:00:00.000000 UTC'") // min value in Redshift + // .addRoundTrip(inputType, "TIMESTAMP '0001-01-01 00:00:00' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '0001-01-01 00:00:00.000000 UTC'") // first day of AD + .addRoundTrip(inputType, "TIMESTAMP '1582-10-04 23:59:59.999999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-04 23:59:59.999999 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1582-10-05 00:00:00.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'") // begin julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1582-10-14 23:59:59.999999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-14 23:59:59.999999 UTC'") // end julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1582-10-15 00:00:00.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.1' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.100000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.9' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.900000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.123' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.123000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.999000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.123456' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1986-01-01 00:13:07.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'") // time gap in Kathmandu + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 01:33:17.456789' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 01:33:17.456789 UTC'") // time doubled in JVM + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 03:33:33.333333' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 03:33:33.333333 UTC'") // time doubled in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '2018-03-25 03:17:17.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'") // time gap in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.1' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, 
"TIMESTAMP '2020-09-27 12:34:56.100000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.9' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.900000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.123' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.123000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.999000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.123456' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123456 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '73326-09-11 20:14:45.247999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '73326-09-11 20:14:45.247999 UTC'"); // max value in Trino + } + + @Test + public void testTimestampWithTimeZoneCoercion() + { + SqlDataTypeTest.create() + // short timestamp with time zone + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.12341 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") // round down + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123499 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") // round up, end result rounds down + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1235 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.124000 UTC'") // round up + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.111222333444 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.111000 UTC'") // max precision + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:01.000000 UTC'") // round up to next second + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-02 00:00:00.000000 UTC'") // round up to next day + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") // negative epoch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1969-12-31 23:59:59.999499999999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999000 UTC'") // negative epoch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9994 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999000 UTC'") // negative epoch + + // long timestamp with time zone + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1234561 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") // round down + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123456499 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") // nanoc round up, end result rounds down + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1234565 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123457 UTC'") // round up + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.111222333444 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.111222 UTC'") // max precision + 
.addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:01.000000 UTC'") // round up to next second + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-02 00:00:00.000000 UTC'") // round up to next day + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") // negative epoch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1969-12-31 23:59:59.999999499999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999999 UTC'") // negative epoch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9999994 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999999 UTC'") // negative epoch + .execute(getQueryRunner(), trinoCreateAsSelect(getSession(), "test_timestamp_tz")); + } + + @Test + public void testTimestampWithTimeZoneOverflow() + { + // The min timestamp with time zone value in Trino is smaller than Redshift + try (TestTable table = new TestTable(getTrinoExecutor(), "timestamp_tz_min", "(ts timestamp(3) with time zone)")) { + assertQueryFails( + format("INSERT INTO %s VALUES (TIMESTAMP '-69387-04-22 03:45:14.752 UTC')", table.getName()), + "\\QMinimum timestamp with time zone in Redshift is -4712-01-01 00:00:00.000000: -69387-04-22 03:45:14.752000"); + } + try (TestTable table = new TestTable(getTrinoExecutor(), "timestamp_tz_min", "(ts timestamp(6) with time zone)")) { + assertQueryFails( + format("INSERT INTO %s VALUES (TIMESTAMP '-69387-04-22 03:45:14.752000 UTC')", table.getName()), + "\\QMinimum timestamp with time zone in Redshift is -4712-01-01 00:00:00.000000: -69387-04-22 03:45:14.752000"); + } + + // The max timestamp with time zone value in Redshift is larger than Trino + try (TestTable table = new TestTable(getRedshiftExecutor(), TEST_SCHEMA + ".timestamp_tz_max", "(ts timestamptz)", ImmutableList.of("TIMESTAMP '294276-12-31 23:59:59' AT TIME ZONE 'UTC'"))) { + assertThatThrownBy(() -> query("SELECT * FROM " + table.getName())) + .hasMessage("Millis overflow: 9224318015999000"); + } + } + + @DataProvider(name = "datetime_test_parameters") + public Object[][] dataProviderForDatetimeTests() + { + return new Object[][] { + {UTC}, + {jvmZone}, + {vilnius}, + {kathmandu}, + {testZone}, + }; + } + + @Test + public void testUnsupportedDateTimeTypes() + { + assertCreateFails( + "test_time_with_time_zone", + "(value TIME WITH TIME ZONE)", + "Unsupported column type: (?i)time.* with time zone"); + } + + @Test + public void testDateLimits() + { + // We can't test the exact date limits because Redshift doesn't say + // what they are, so we test one date on either side. 
+ try (TestTable table = testTable("test_date_limits", "(d date)")) { + // First day of smallest year that Redshift supports (based on its documentation) + assertUpdate(format("INSERT INTO %s VALUES (DATE '-4712-01-01')", table.getName()), 1); + // Small date observed to not work + assertThatThrownBy(() -> computeActual(format("INSERT INTO %s VALUES (DATE '-4713-06-01')", table.getName()))) + .hasStackTraceContaining("ERROR: date out of range: \"4714-06-01 BC\""); + + // Last day of the largest year that Redshift supports (based on in its documentation) + assertUpdate(format("INSERT INTO %s VALUES (DATE '294275-12-31')", table.getName()), 1); + // Large date observed to not work + assertThatThrownBy(() -> computeActual(format("INSERT INTO %s VALUES (DATE '5875000-01-01')", table.getName()))) + .hasStackTraceContaining("ERROR: date out of range: \"5875000-01-01 AD\""); + } + } + + @Test + public void testLimitedTimePrecision() + { + Map> testCasesByPrecision = groupTestCasesByInput( + "TIME '\\d{2}:\\d{2}:\\d{2}(\\.\\d{1,12})?'", + input -> input.length() - "TIME '00:00:00'".length() - (input.contains(".") ? 1 : 0), + List.of( + // No rounding + new TestCase("TIME '00:00:00'", "TIME '00:00:00'"), + new TestCase("TIME '00:00:00.000000'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.123456'", "TIME '00:00:00.123456'"), + new TestCase("TIME '12:34:56'", "TIME '12:34:56'"), + new TestCase("TIME '12:34:56.123456'", "TIME '12:34:56.123456'"), + new TestCase("TIME '23:59:59'", "TIME '23:59:59'"), + new TestCase("TIME '23:59:59.9'", "TIME '23:59:59.9'"), + new TestCase("TIME '23:59:59.999'", "TIME '23:59:59.999'"), + new TestCase("TIME '23:59:59.999999'", "TIME '23:59:59.999999'"), + // round down + new TestCase("TIME '00:00:00.0000001'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.000000000001'", "TIME '00:00:00.000000'"), + new TestCase("TIME '12:34:56.1234561'", "TIME '12:34:56.123456'"), + // round down, maximal value + new TestCase("TIME '00:00:00.0000004'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.000000449'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.000000444449'", "TIME '00:00:00.000000'"), + // round up, minimal value + new TestCase("TIME '00:00:00.0000005'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000500'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000500000'", "TIME '00:00:00.000001'"), + // round up, maximal value + new TestCase("TIME '00:00:00.0000009'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000999'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000999999'", "TIME '00:00:00.000001'"), + // round up to next day, minimal value + new TestCase("TIME '23:59:59.9999995'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999500'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999500000'", "TIME '00:00:00.000000'"), + // round up to next day, maximal value + new TestCase("TIME '23:59:59.9999999'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999999'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999999999'", "TIME '00:00:00.000000'"), + // don't round to next day (round down near upper bound) + new TestCase("TIME '23:59:59.9999994'", "TIME '23:59:59.999999'"), + new TestCase("TIME '23:59:59.999999499'", "TIME '23:59:59.999999'"), + new TestCase("TIME '23:59:59.999999499999'", "TIME '23:59:59.999999'"))); + + for (Entry> entry : testCasesByPrecision.entrySet()) { + String tableName = 
format("test_time_precision_%d_%s", entry.getKey(), randomNameSuffix()); + runTestCases(tableName, entry.getValue()); + } + } + + @Test + public void testLimitedTimestampPrecision() + { + Map> testCasesByPrecision = groupTestCasesByInput( + "TIMESTAMP '\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}(\\.\\d{1,12})?'", + input -> input.length() - "TIMESTAMP '0000-00-00 00:00:00'".length() - (input.contains(".") ? 1 : 0), + // No rounding + new TestCase("TIMESTAMP '1970-01-01 00:00:00'", "TIMESTAMP '1970-01-01 00:00:00'"), + new TestCase("TIMESTAMP '2020-11-03 12:34:56'", "TIMESTAMP '2020-11-03 12:34:56'"), + new TestCase("TIMESTAMP '1969-12-31 00:00:00.000000'", "TIMESTAMP '1969-12-31 00:00:00.000000'"), + + new TestCase("TIMESTAMP '1970-01-01 00:00:00.123456'", "TIMESTAMP '1970-01-01 00:00:00.123456'"), + new TestCase("TIMESTAMP '2020-11-03 12:34:56.123456'", "TIMESTAMP '2020-11-03 12:34:56.123456'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59'", "TIMESTAMP '1969-12-31 23:59:59'"), + + new TestCase("TIMESTAMP '1970-01-01 23:59:59.9'", "TIMESTAMP '1970-01-01 23:59:59.9'"), + new TestCase("TIMESTAMP '2020-11-03 23:59:59.999'", "TIMESTAMP '2020-11-03 23:59:59.999'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59.999999'", "TIMESTAMP '1969-12-31 23:59:59.999999'"), + // round down + new TestCase("TIMESTAMP '1969-12-31 00:00:00.0000001'", "TIMESTAMP '1969-12-31 00:00:00.000000'"), + new TestCase("TIMESTAMP '1970-01-01 00:00:00.000000000001'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '2020-11-03 12:34:56.1234561'", "TIMESTAMP '2020-11-03 12:34:56.123456'"), + // round down, maximal value + new TestCase("TIMESTAMP '2020-11-03 00:00:00.0000004'", "TIMESTAMP '2020-11-03 00:00:00.000000'"), + new TestCase("TIMESTAMP '1969-12-31 00:00:00.000000449'", "TIMESTAMP '1969-12-31 00:00:00.000000'"), + new TestCase("TIMESTAMP '1970-01-01 00:00:00.000000444449'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + // round up, minimal value + new TestCase("TIMESTAMP '1970-01-01 00:00:00.0000005'", "TIMESTAMP '1970-01-01 00:00:00.000001'"), + new TestCase("TIMESTAMP '2020-11-03 00:00:00.000000500'", "TIMESTAMP '2020-11-03 00:00:00.000001'"), + new TestCase("TIMESTAMP '1969-12-31 00:00:00.000000500000'", "TIMESTAMP '1969-12-31 00:00:00.000001'"), + // round up, maximal value + new TestCase("TIMESTAMP '1969-12-31 00:00:00.0000009'", "TIMESTAMP '1969-12-31 00:00:00.000001'"), + new TestCase("TIMESTAMP '1970-01-01 00:00:00.000000999'", "TIMESTAMP '1970-01-01 00:00:00.000001'"), + new TestCase("TIMESTAMP '2020-11-03 00:00:00.000000999999'", "TIMESTAMP '2020-11-03 00:00:00.000001'"), + // round up to next year, minimal value + new TestCase("TIMESTAMP '2020-12-31 23:59:59.9999995'", "TIMESTAMP '2021-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59.999999500'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '1970-01-01 23:59:59.999999500000'", "TIMESTAMP '1970-01-02 00:00:00.000000'"), + // round up to next day/year, maximal value + new TestCase("TIMESTAMP '1970-01-01 23:59:59.9999999'", "TIMESTAMP '1970-01-02 00:00:00.000000'"), + new TestCase("TIMESTAMP '2020-12-31 23:59:59.999999999'", "TIMESTAMP '2021-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59.999999999999'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + // don't round to next year (round down near upper bound) + new TestCase("TIMESTAMP '1969-12-31 23:59:59.9999994'", "TIMESTAMP '1969-12-31 23:59:59.999999'"), + new TestCase("TIMESTAMP '1970-01-01 23:59:59.999999499'", 
"TIMESTAMP '1970-01-01 23:59:59.999999'"), + new TestCase("TIMESTAMP '2020-12-31 23:59:59.999999499999'", "TIMESTAMP '2020-12-31 23:59:59.999999'")); + + for (Entry> entry : testCasesByPrecision.entrySet()) { + String tableName = format("test_timestamp_precision_%d_%s", entry.getKey(), randomNameSuffix()); + runTestCases(tableName, entry.getValue()); + } + } + + private static Map> groupTestCasesByInput(String inputRegex, Function classifier, TestCase... testCases) + { + return groupTestCasesByInput(inputRegex, classifier, Arrays.asList(testCases)); + } + + private static Map> groupTestCasesByInput(String inputRegex, Function classifier, List testCases) + { + return testCases.stream() + .peek(test -> { + if (!test.input().matches(inputRegex)) { + throw new RuntimeException("Bad test case input format: " + test.input()); + } + }) + .collect(groupingBy(classifier.compose(TestCase::input))); + } + + private void runTestCases(String tableName, List testCases) + { + // Must use CTAS instead of TestTable because if the table is created before the insert, + // the type mapping will treat it as TIME(6) no matter what it was created as. + getTrinoExecutor().execute(format( + "CREATE TABLE %s AS SELECT * FROM (VALUES %s) AS t (id, value)", + tableName, + testCases.stream() + .map(testCase -> format("(%d, %s)", testCase.id(), testCase.input())) + .collect(joining("), (", "(", ")")))); + try { + assertQuery( + format("SELECT value FROM %s ORDER BY id", tableName), + testCases.stream() + .map(TestCase::expected) + .collect(joining("), (", "VALUES (", ")"))); + } + finally { + getTrinoExecutor().execute("DROP TABLE " + tableName); + } + } + + @Test + public static void checkIllegalRedshiftTimePrecision() + { + assertRedshiftCreateFails( + "check_redshift_time_precision_error", + "(t TIME(6))", + "ERROR: time column does not support precision."); + } + + @Test + public static void checkIllegalRedshiftTimestampPrecision() + { + assertRedshiftCreateFails( + "check_redshift_timestamp_precision_error", + "(t TIMESTAMP(6))", + "ERROR: timestamp column does not support precision."); + } + + /** + * Assert that a {@code CREATE TABLE} statement made from Redshift fails, + * and drop the table if it doesn't fail. + */ + private static void assertRedshiftCreateFails(String tableNamePrefix, String tableBody, String message) + { + String tableName = tableNamePrefix + "_" + randomNameSuffix(); + try { + assertThatThrownBy(() -> getRedshiftExecutor() + .execute(format("CREATE TABLE %s %s", tableName, tableBody))) + .getCause() + .as("Redshift create fails for %s %s", tableName, tableBody) + .isInstanceOf(SQLException.class) + .hasMessage(message); + } + catch (AssertionError failure) { + // If the table was created, clean it up because the tests run on a shared Redshift instance + try { + getRedshiftExecutor().execute("DROP TABLE IF EXISTS " + tableName); + } + catch (Throwable e) { + failure.addSuppressed(e); + } + throw failure; + } + } + + /** + * Assert that a {@code CREATE TABLE} statement fails, and drop the table + * if it doesn't fail. 
+ */ + private void assertCreateFails(String tableNamePrefix, String tableBody, String expectedMessageRegExp) + { + String tableName = tableNamePrefix + "_" + randomNameSuffix(); + try { + assertQueryFails(format("CREATE TABLE %s %s", tableName, tableBody), expectedMessageRegExp); + } + catch (AssertionError failure) { + // If the table was created, clean it up because the tests run on a shared Redshift instance + try { + getRedshiftExecutor().execute("DROP TABLE " + tableName); + } + catch (Throwable e) { + failure.addSuppressed(e); + } + throw failure; + } + } + + private DataSetup trinoCreateAsSelect(String tableNamePrefix) + { + return trinoCreateAsSelect(getQueryRunner().getDefaultSession(), tableNamePrefix); + } + + private DataSetup trinoCreateAsSelect(Session session, String tableNamePrefix) + { + return new CreateAsSelectDataSetup(new TrinoSqlExecutor(getQueryRunner(), session), tableNamePrefix); + } + + private static DataSetup redshiftCreateAndInsert(String tableNamePrefix) + { + return new CreateAndInsertDataSetup(getRedshiftExecutor(), TEST_SCHEMA + "." + tableNamePrefix); + } + + /** + * Create a table in the test schema using the JDBC. + * + *

Creating a test table normally doesn't use the correct schema. + */ + private static TestTable testTable(String namePrefix, String body) + { + return new TestTable(getRedshiftExecutor(), TEST_SCHEMA + "." + namePrefix, body); + } + + private SqlExecutor getTrinoExecutor() + { + return new TrinoSqlExecutor(getQueryRunner()); + } + + private static SqlExecutor getRedshiftExecutor() + { + Properties properties = new Properties(); + properties.setProperty("user", JDBC_USER); + properties.setProperty("password", JDBC_PASSWORD); + return new JdbcSqlExecutor(JDBC_URL, properties); + } + + private static void checkIsGap(ZoneId zone, LocalDateTime dateTime) + { + verify( + zone.getRules().getValidOffsets(dateTime).isEmpty(), + "Expected %s to be a gap in %s", dateTime, zone); + } + + private static void checkIsDoubled(ZoneId zone, LocalDateTime dateTime) + { + verify( + zone.getRules().getValidOffsets(dateTime).size() == 2, + "Expected %s to be doubled in %s", dateTime, zone); + } + + private static Function padVarchar(int length) + { + // Add the same padding as RedshiftClient.writeCharAsVarchar, but start from String, not Slice + return (input) -> input + " ".repeat(length - Utf8.encodedLength(input)); + } + + /** + * A pair of input and expected output from a test. + * Each instance has a unique ID. + */ + private static class TestCase + { + private static final AtomicInteger LAST_ID = new AtomicInteger(); + + private final int id; + private final String input; + private final String expected; + + private TestCase(String input, String expected) + { + this.id = LAST_ID.incrementAndGet(); + this.input = input; + this.expected = expected; + } + + public int id() + { + return this.id; + } + + public String input() + { + return this.input; + } + + public String expected() + { + return this.expected; + } + } + + private static class TestView + implements AutoCloseable + { + final String name; + + TestView(String namePrefix, String definition) + { + name = requireNonNull(namePrefix) + "_" + randomNameSuffix(); + executeInRedshift(format("CREATE VIEW %s.%s AS %s", TEST_SCHEMA, name, definition)); + } + + @Override + public void close() + { + executeInRedshift(format("DROP VIEW IF EXISTS %s.%s", TEST_SCHEMA, name)); + } + } +} diff --git a/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/db/DbResourceGroupConfig.java b/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/db/DbResourceGroupConfig.java index 1ed9c6ce5c02..965424bb9c2e 100644 --- a/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/db/DbResourceGroupConfig.java +++ b/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/db/DbResourceGroupConfig.java @@ -19,7 +19,10 @@ import io.airlift.units.Duration; import io.airlift.units.MinDuration; +import javax.validation.constraints.AssertTrue; + import static java.util.concurrent.TimeUnit.HOURS; +import static java.util.concurrent.TimeUnit.SECONDS; public class DbResourceGroupConfig { @@ -28,6 +31,7 @@ public class DbResourceGroupConfig private String password; private boolean exactMatchSelectorEnabled; private Duration maxRefreshInterval = new Duration(1, HOURS); + private Duration refreshInterval = new Duration(1, SECONDS); public String getConfigDbUrl() { @@ -82,6 +86,20 @@ public DbResourceGroupConfig setMaxRefreshInterval(Duration maxRefreshInterval) return this; } + @MinDuration("1s") + public Duration getRefreshInterval() + { + return refreshInterval; + } + + 
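+ // refresh-interval is how often DbResourceGroupConfigurationManager reloads resource groups from the database; it must stay strictly below max-refresh-interval (see isRefreshIntervalValid below). With the defaults (1s versus 1h) the validation passes.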
@Config("resource-groups.refresh-interval") + @ConfigDescription("How often the cluster reloads from the database") + public DbResourceGroupConfig setRefreshInterval(Duration refreshInterval) + { + this.refreshInterval = refreshInterval; + return this; + } + public boolean getExactMatchSelectorEnabled() { return exactMatchSelectorEnabled; @@ -93,4 +111,10 @@ public DbResourceGroupConfig setExactMatchSelectorEnabled(boolean exactMatchSele this.exactMatchSelectorEnabled = exactMatchSelectorEnabled; return this; } + + @AssertTrue(message = "maxRefreshInterval must be greater than refreshInterval") + public boolean isRefreshIntervalValid() + { + return maxRefreshInterval.compareTo(refreshInterval) > 0; + } } diff --git a/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/db/DbResourceGroupConfigurationManager.java b/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/db/DbResourceGroupConfigurationManager.java index cd53571cd40d..a13ebe668ff5 100644 --- a/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/db/DbResourceGroupConfigurationManager.java +++ b/plugin/trino-resource-group-managers/src/main/java/io/trino/plugin/resourcegroups/db/DbResourceGroupConfigurationManager.java @@ -84,6 +84,7 @@ public class DbResourceGroupConfigurationManager private final AtomicLong lastRefresh = new AtomicLong(); private final String environment; private final Duration maxRefreshInterval; + private final Duration refreshInterval; private final boolean exactMatchSelectorEnabled; private final CounterStat refreshFailures = new CounterStat(); @@ -95,6 +96,7 @@ public DbResourceGroupConfigurationManager(ClusterMemoryPoolManager memoryPoolMa requireNonNull(dao, "daoProvider is null"); this.environment = requireNonNull(environment, "environment is null"); this.maxRefreshInterval = config.getMaxRefreshInterval(); + this.refreshInterval = config.getRefreshInterval(); this.exactMatchSelectorEnabled = config.getExactMatchSelectorEnabled(); this.dao = dao; load(); @@ -129,7 +131,7 @@ public void destroy() public void start() { if (started.compareAndSet(false, true)) { - configExecutor.scheduleWithFixedDelay(this::load, 1, 1, TimeUnit.SECONDS); + configExecutor.scheduleWithFixedDelay(this::load, 1000, refreshInterval.toMillis(), TimeUnit.MILLISECONDS); } } diff --git a/plugin/trino-resource-group-managers/src/test/java/io/trino/plugin/resourcegroups/db/TestDbResourceGroupConfig.java b/plugin/trino-resource-group-managers/src/test/java/io/trino/plugin/resourcegroups/db/TestDbResourceGroupConfig.java index de86e8e07d79..6bf518e8ecb3 100644 --- a/plugin/trino-resource-group-managers/src/test/java/io/trino/plugin/resourcegroups/db/TestDbResourceGroupConfig.java +++ b/plugin/trino-resource-group-managers/src/test/java/io/trino/plugin/resourcegroups/db/TestDbResourceGroupConfig.java @@ -17,13 +17,18 @@ import io.airlift.units.Duration; import org.testng.annotations.Test; +import javax.validation.constraints.AssertTrue; + import java.util.Map; import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static io.airlift.testing.ValidationAssertions.assertFailsValidation; import static java.util.concurrent.TimeUnit.HOURS; import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.SECONDS; +import static 
org.testng.Assert.assertTrue; public class TestDbResourceGroupConfig { @@ -35,6 +40,7 @@ public void testDefaults() .setConfigDbUser(null) .setConfigDbPassword(null) .setMaxRefreshInterval(new Duration(1, HOURS)) + .setRefreshInterval(new Duration(1, SECONDS)) .setExactMatchSelectorEnabled(false)); } @@ -46,6 +52,7 @@ public void testExplicitPropertyMappings() .put("resource-groups.config-db-user", "trino_admin") .put("resource-groups.config-db-password", "trino_admin_pass") .put("resource-groups.max-refresh-interval", "1m") + .put("resource-groups.refresh-interval", "2s") .put("resource-groups.exact-match-selector-enabled", "true") .buildOrThrow(); DbResourceGroupConfig expected = new DbResourceGroupConfig() @@ -53,8 +60,20 @@ public void testExplicitPropertyMappings() .setConfigDbUser("trino_admin") .setConfigDbPassword("trino_admin_pass") .setMaxRefreshInterval(new Duration(1, MINUTES)) + .setRefreshInterval(new Duration(2, SECONDS)) .setExactMatchSelectorEnabled(true); assertFullMapping(properties, expected); + assertTrue(expected.isRefreshIntervalValid()); + } + + @Test + public void testValidation() + { + assertFailsValidation( + new DbResourceGroupConfig().setRefreshInterval(new Duration(2, HOURS)), + "refreshIntervalValid", + "maxRefreshInterval must be greater than refreshInterval", + AssertTrue.class); } } diff --git a/plugin/trino-resource-group-managers/src/test/java/io/trino/plugin/resourcegroups/db/TestDbResourceGroupConfigurationManager.java b/plugin/trino-resource-group-managers/src/test/java/io/trino/plugin/resourcegroups/db/TestDbResourceGroupConfigurationManager.java index 6d35b321794f..ece552542212 100644 --- a/plugin/trino-resource-group-managers/src/test/java/io/trino/plugin/resourcegroups/db/TestDbResourceGroupConfigurationManager.java +++ b/plugin/trino-resource-group-managers/src/test/java/io/trino/plugin/resourcegroups/db/TestDbResourceGroupConfigurationManager.java @@ -296,7 +296,7 @@ public void testInvalidConfiguration() DbResourceGroupConfigurationManager manager = new DbResourceGroupConfigurationManager( listener -> {}, - new DbResourceGroupConfig().setMaxRefreshInterval(new io.airlift.units.Duration(1, MILLISECONDS)), + new DbResourceGroupConfig().setMaxRefreshInterval(new io.airlift.units.Duration(2, MILLISECONDS)).setRefreshInterval(new io.airlift.units.Duration(1, MILLISECONDS)), daoProvider.get(), ENVIRONMENT); @@ -316,7 +316,7 @@ public void testRefreshInterval() DbResourceGroupConfigurationManager manager = new DbResourceGroupConfigurationManager( listener -> {}, - new DbResourceGroupConfig().setMaxRefreshInterval(new io.airlift.units.Duration(1, MILLISECONDS)), + new DbResourceGroupConfig().setMaxRefreshInterval(new io.airlift.units.Duration(2, MILLISECONDS)).setRefreshInterval(new io.airlift.units.Duration(1, MILLISECONDS)), daoProvider.get(), ENVIRONMENT); @@ -345,7 +345,7 @@ public void testMatchByUserGroups() DbResourceGroupConfigurationManager manager = new DbResourceGroupConfigurationManager( listener -> {}, - new DbResourceGroupConfig().setMaxRefreshInterval(new io.airlift.units.Duration(1, MILLISECONDS)), + new DbResourceGroupConfig().setMaxRefreshInterval(new io.airlift.units.Duration(2, MILLISECONDS)).setRefreshInterval(new io.airlift.units.Duration(1, MILLISECONDS)), daoProvider.get(), ENVIRONMENT); @@ -368,7 +368,7 @@ public void testMatchByUsersAndGroups() DbResourceGroupConfigurationManager manager = new DbResourceGroupConfigurationManager( listener -> {}, - new DbResourceGroupConfig().setMaxRefreshInterval(new 
io.airlift.units.Duration(1, MILLISECONDS)), + new DbResourceGroupConfig().setMaxRefreshInterval(new io.airlift.units.Duration(2, MILLISECONDS)).setRefreshInterval(new io.airlift.units.Duration(1, MILLISECONDS)), daoProvider.get(), ENVIRONMENT); diff --git a/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java b/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java index 56f96a5e9cd5..1ab22188272e 100644 --- a/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java +++ b/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java @@ -36,7 +36,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.ProjectionApplicationResult; @@ -164,12 +163,6 @@ public Optional resolveIndex(ConnectorSession session, C return Optional.empty(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint) { diff --git a/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsMetadata.java b/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsMetadata.java index a917bad9b8ff..51c5fb6ca083 100644 --- a/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsMetadata.java +++ b/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsMetadata.java @@ -23,7 +23,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.statistics.TableStatistics; @@ -98,12 +97,6 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable return new TpcdsTableHandle(tableName.getTableName(), scaleFactor); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) { diff --git a/pom.xml b/pom.xml index c9079b034c2b..96d1b57f3052 100644 --- a/pom.xml +++ b/pom.xml @@ -66,7 +66,7 @@ 5.5.2 4.14.0 7.1.4 - 1.0.0 + 1.1.0 4.7.2 3.21.6 3.2.2 diff --git a/testing/trino-faulttolerant-tests/pom.xml b/testing/trino-faulttolerant-tests/pom.xml index fcba727b5010..95e8c9167785 100644 --- a/testing/trino-faulttolerant-tests/pom.xml +++ b/testing/trino-faulttolerant-tests/pom.xml @@ -465,6 +465,8 @@ about.html iceberg-build.properties mozilla/public-suffix-list.txt + + google/protobuf/.*\.proto$ diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeSelectCompatibility.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeSelectCompatibility.java new file mode 100644 index 000000000000..c44a9ad1d07c --- /dev/null +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeSelectCompatibility.java @@ 
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.tests.product.deltalake;
+
+import com.google.common.collect.ImmutableList;
+import io.trino.tempto.assertions.QueryAssert.Row;
+import io.trino.testng.services.Flaky;
+import org.testng.annotations.Test;
+
+import java.net.URI;
+import java.util.List;
+
+import static io.trino.tempto.assertions.QueryAssert.Row.row;
+import static io.trino.tempto.assertions.QueryAssert.assertThat;
+import static io.trino.testing.TestingNames.randomNameSuffix;
+import static io.trino.tests.product.TestGroups.DELTA_LAKE_DATABRICKS;
+import static io.trino.tests.product.TestGroups.DELTA_LAKE_OSS;
+import static io.trino.tests.product.TestGroups.PROFILE_SPECIFIC_TESTS;
+import static io.trino.tests.product.deltalake.util.DeltaLakeTestUtils.DATABRICKS_COMMUNICATION_FAILURE_ISSUE;
+import static io.trino.tests.product.deltalake.util.DeltaLakeTestUtils.DATABRICKS_COMMUNICATION_FAILURE_MATCH;
+import static io.trino.tests.product.utils.QueryExecutors.onDelta;
+import static io.trino.tests.product.utils.QueryExecutors.onTrino;
+import static java.lang.String.format;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotEquals;
+
+public class TestDeltaLakeSelectCompatibility
+        extends BaseTestDeltaLakeS3Storage
+{
+    @Test(groups = {DELTA_LAKE_DATABRICKS, DELTA_LAKE_OSS, PROFILE_SPECIFIC_TESTS})
+    @Flaky(issue = DATABRICKS_COMMUNICATION_FAILURE_ISSUE, match = DATABRICKS_COMMUNICATION_FAILURE_MATCH)
+    public void testPartitionedSelectSpecialCharacters()
+    {
+        String tableName = "test_dl_partitioned_select_special" + randomNameSuffix();
+
+        onDelta().executeQuery("" +
+                "CREATE TABLE default." + tableName +
+                " (a_number INT, a_string STRING)" +
+                " USING delta " +
+                " PARTITIONED BY (a_string)" +
+                " LOCATION 's3://" + bucketName + "/databricks-compatibility-test-" + tableName + "'");
+
+        try {
+            onDelta().executeQuery("INSERT INTO default." + tableName + " VALUES (1,'spark=equal'), (2, 'spark+plus'), (3, 'spark space')");
+            onTrino().executeQuery("INSERT INTO delta.default." + tableName + " VALUES (10, 'trino=equal'), (20, 'trino+plus'), (30, 'trino space')");
+
+            List<Row> expectedRows = ImmutableList.of(
+                    row(1, "spark=equal"),
+                    row(2, "spark+plus"),
+                    row(3, "spark space"),
+                    row(10, "trino=equal"),
+                    row(20, "trino+plus"),
+                    row(30, "trino space"));
+
+            assertThat(onDelta().executeQuery("SELECT * FROM default." + tableName))
+                    .containsOnly(expectedRows);
+            assertThat(onTrino().executeQuery("SELECT * FROM delta.default." + tableName))
+                    .containsOnly(expectedRows);
+
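+            // Compare the physical location of the same row as reported by Spark's input_file_name() and by Trino's $path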
+            String deltaFilePath = (String) onDelta().executeQuery("SELECT input_file_name() FROM default." + tableName + " WHERE a_number = 1").getOnlyValue();
+            String trinoFilePath = (String) onTrino().executeQuery("SELECT \"$path\" FROM delta.default." + tableName + " WHERE a_number = 1").getOnlyValue();
+            // File paths returned by the input_file_name function are URI encoded https://github.com/delta-io/delta/issues/1517 while the $path of Trino is not
+            assertNotEquals(deltaFilePath, trinoFilePath);
+            assertEquals(format("s3://%s%s", bucketName, URI.create(deltaFilePath).getPath()), trinoFilePath);
+
+            assertThat(onTrino().executeQuery("SELECT * FROM delta.default." + tableName + " WHERE \"$path\" = '" + trinoFilePath + "'"))
+                    .containsOnly(row(1, "spark=equal"));
+        }
+        finally {
+            onDelta().executeQuery("DROP TABLE default." + tableName);
+        }
+    }
+}
diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java
index a7ec94a51eb3..062ae91d9c3b 100644
--- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java
+++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java
@@ -24,6 +24,7 @@
 import java.util.Arrays;
 import java.util.List;
 
+import static io.trino.testing.TestingNames.randomNameSuffix;
 import static io.trino.tests.product.TestGroups.SMOKE;
 import static io.trino.tests.product.utils.QueryExecutors.onHive;
 import static io.trino.tests.product.utils.QueryExecutors.onTrino;
@@ -51,6 +52,87 @@ public static Object[][] storageFormats()
         return new String[][] {{"ORC"}, {"AVRO"}};
     }
 
+    @DataProvider(name = "union_dereference_test_cases")
+    public static Object[][] unionDereferenceTestCases()
+    {
+        String tableUnionDereference = "test_union_dereference" + randomNameSuffix();
+        // Hive insertion for union type in AVRO format has bugs, so we test on different table schemas for AVRO than ORC.
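+        // Each case provides: CREATE TABLE DDL, INSERT DML, dereference SELECT, expected dereference values, tag SELECT, expected tag values, and DROP DDL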
+        return new Object[][] {{
+                format(
+                        "CREATE TABLE %s (unionLevel0 UNIONTYPE<" +
+                                "INT, STRING>)" +
+                                "STORED AS %s",
+                        tableUnionDereference,
+                        "AVRO"),
+                format(
+                        "INSERT INTO TABLE %s " +
+                                "SELECT create_union(0, 321, 'row1') " +
+                                "UNION ALL " +
+                                "SELECT create_union(1, 55, 'row2') ",
+                        tableUnionDereference),
+                format("SELECT unionLevel0.field0 FROM %s WHERE unionLevel0.field0 IS NOT NULL", tableUnionDereference),
+                Arrays.asList(321),
+                format("SELECT unionLevel0.tag FROM %s", tableUnionDereference),
+                Arrays.asList((byte) 0, (byte) 1),
+                "DROP TABLE IF EXISTS " + tableUnionDereference},
+                // there is an internal issue in Hive 1.2:
+                // unionLevel1 is declared as unionType, but has to be inserted by create_union(tagId, Int, String)
+                {
+                format(
+                        "CREATE TABLE %s (unionLevel0 UNIONTYPE>>, intLevel0 INT )" +
+                                "STORED AS %s",
+                        tableUnionDereference,
+                        "AVRO"),
+                format(
+                        "INSERT INTO TABLE %s " +
+                                "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(0, 5, 'testString'))), 8 " +
+                                "UNION ALL " +
+                                "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(1, 5, 'testString'))), 8 ",
+                        tableUnionDereference),
+                format("SELECT unionLevel0.field2.unionLevel1.field1 FROM %s WHERE unionLevel0.field2.unionLevel1.field1 IS NOT NULL", tableUnionDereference),
+                Arrays.asList(5),
+                format("SELECT unionLevel0.field2.unionLevel1.tag FROM %s", tableUnionDereference),
+                Arrays.asList((byte) 0, (byte) 1),
+                "DROP TABLE IF EXISTS " + tableUnionDereference},
+                {
+                format(
+                        "CREATE TABLE %s (unionLevel0 UNIONTYPE<" +
+                                "STRUCT>>)" +
+                                "STORED AS %s",
+                        tableUnionDereference,
+                        "ORC"),
+                format(
+                        "INSERT INTO TABLE %s " +
+                                "SELECT create_union(0, named_struct('unionLevel1', create_union(0, 'testString1', 23))) " +
+                                "UNION ALL " +
+                                "SELECT create_union(0, named_struct('unionLevel1', create_union(1, 'testString2', 45))) ",
+                        tableUnionDereference),
+                format("SELECT unionLevel0.field0.unionLevel1.field0 FROM %s WHERE unionLevel0.field0.unionLevel1.field0 IS NOT NULL", tableUnionDereference),
+                Arrays.asList("testString1"),
+                format("SELECT unionLevel0.field0.unionLevel1.tag FROM %s", tableUnionDereference),
+                Arrays.asList((byte) 0, (byte) 1),
+                "DROP TABLE IF EXISTS " + tableUnionDereference},
+                {
+                format(
+                        "CREATE TABLE %s (unionLevel0 UNIONTYPE>>, intLevel0 INT )" +
+                                "STORED AS %s",
+                        tableUnionDereference,
+                        "ORC"),
+                format(
+                        "INSERT INTO TABLE %s " +
+                                "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(0, 'testString', 5))), 8 " +
+                                "UNION ALL " +
+                                "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(1, 'testString', 5))), 8 ",
+                        tableUnionDereference),
+                format("SELECT unionLevel0.field2.unionLevel1.field0 FROM %s WHERE unionLevel0.field2.unionLevel1.field0 IS NOT NULL", tableUnionDereference),
+                Arrays.asList("testString"),
+                format("SELECT unionLevel0.field2.unionLevel1.tag FROM %s", tableUnionDereference),
+                Arrays.asList((byte) 0, (byte) 1),
+                "DROP TABLE IF EXISTS " + tableUnionDereference}};
+    }
+
     @Test(dataProvider = "storage_formats", groups = SMOKE)
     public void testReadUniontype(String storageFormat)
     {
@@ -137,6 +219,25 @@ public void testReadUniontype(String storageFormat)
         }
     }
 
+    @Test(dataProvider = "union_dereference_test_cases", groups = SMOKE)
+    public void testReadUniontypeWithDereference(String createTableSql, String insertSql, String selectSql, List expectedResult, String selectTagSql, List expectedTagResult, String dropTableSql)
+    {
+        // According to testing results, the Hive INSERT queries here only work in Hive 1.2
+        if (getHiveVersionMajor() != 1 || getHiveVersionMinor() != 2) {
+            throw new SkipException("This test can only be run with Hive 1.2 (default config)");
+        }
+
+        onHive().executeQuery(createTableSql);
+        onHive().executeQuery(insertSql);
+
+        QueryResult result = onTrino().executeQuery(selectSql);
+        assertThat(result.column(1)).containsExactlyInAnyOrderElementsOf(expectedResult);
+        result = onTrino().executeQuery(selectTagSql);
+        assertThat(result.column(1)).containsExactlyInAnyOrderElementsOf(expectedTagResult);
+
+        onTrino().executeQuery(dropTableSql);
+    }
+
     @Test(dataProvider = "storage_formats", groups = SMOKE)
     public void testUnionTypeSchemaEvolution(String storageFormat)
     {
diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractDistributedEngineOnlyQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractDistributedEngineOnlyQueries.java
index 49e891aa2150..41fcf48f62c5 100644
--- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractDistributedEngineOnlyQueries.java
+++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractDistributedEngineOnlyQueries.java
@@ -233,6 +233,8 @@ public void testExplainAnalyzeVerbose()
                 "'CPU time distribution \\(s\\)' = \\{count=.*, p01=.*, p05=.*, p10=.*, p25=.*, p50=.*, p75=.*, p90=.*, p95=.*, p99=.*, min=.*, max=.*}",
                 "'Scheduled time distribution \\(s\\)' = \\{count=.*, p01=.*, p05=.*, p10=.*, p25=.*, p50=.*, p75=.*, p90=.*, p95=.*, p99=.*, min=.*, max=.*}",
                 "Output buffer active time: .*, buffer utilization distribution \\(%\\): \\{p01=.*, p05=.*, p10=.*, p25=.*, p50=.*, p75=.*, p90=.*, p95=.*, p99=.*, max=.*}",
+                "Task output distribution: \\{count=.*, p01=.*, p05=.*, p10=.*, p25=.*, p50=.*, p75=.*, p90=.*, p95=.*, p99=.*, max=.*}",
+                "Task input distribution: \\{count=.*, p01=.*, p05=.*, p10=.*, p25=.*, p50=.*, p75=.*, p90=.*, p95=.*, p99=.*, max=.*}",
                 "Trino version: .*");
     }
diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java
index 0224103e7530..2679256700cf 100644
--- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java
+++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java
@@ -92,7 +92,7 @@ public ResultWithQueryId execute(Session session, @Language("SQL") String sql
         ClientSession clientSession = toClientSession(session, trinoServer.getBaseUrl(), new Duration(2, TimeUnit.MINUTES));
 
-        try (StatementClient client = newStatementClient(httpClient, clientSession, sql)) {
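+        // Forward the test session's client capabilities so server-side capability checks (e.g. SET PATH) can be exercised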
+        try (StatementClient client = newStatementClient(httpClient, clientSession, sql, Optional.of(session.getClientCapabilities()))) {
             while (client.isRunning()) {
                 resultsSession.addResults(client.currentStatusInfo(), client.currentData());
                 client.advance();
diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java
index 602dbae260b7..40cf6af6dcbb 100644
--- a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java
+++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java
@@ -3827,6 +3827,18 @@ public void testUpdateRowType()
         }
     }
 
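+    // Verify that a predicate on a field of a row-typed column filters rows correctly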
+    @Test
+    public void testPredicateOnRowTypeField()
+    {
+        skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_INSERT) && hasBehavior(SUPPORTS_ROW_TYPE));
+
+        try (TestTable table = new TestTable(getQueryRunner()::execute, "test_predicate_on_row_type_field", "(int_t INT, row_t row(varchar_t VARCHAR, int_t INT))")) {
+            assertUpdate("INSERT INTO " + table.getName() + " VALUES (2, row('first', 1)), (20, row('second', 10)), (200, row('third', 100))", 3);
+            assertQuery("SELECT int_t FROM " + table.getName() + " WHERE row_t.int_t = 1", "VALUES 2");
+            assertQuery("SELECT int_t FROM " + table.getName() + " WHERE row_t.int_t > 1", "VALUES 20, 200");
+        }
+    }
+
     @Test
     public void testUpdateAllValues()
     {
diff --git a/testing/trino-tests/pom.xml b/testing/trino-tests/pom.xml
index 02d6ab48ffaf..5eef67d95db0 100644
--- a/testing/trino-tests/pom.xml
+++ b/testing/trino-tests/pom.xml
@@ -391,6 +391,8 @@
                             about.html
                             iceberg-build.properties
                             mozilla/public-suffix-list.txt
+
+                            google/protobuf/.*\.proto$
diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java b/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java
index 561c18e474bb..99b162657aae 100644
--- a/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java
+++ b/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java
@@ -42,6 +42,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
+import java.util.Set;
 import java.util.function.Function;
 import java.util.stream.Collector;
 import java.util.stream.Collectors;
@@ -66,6 +67,7 @@
 import static io.trino.SystemSessionProperties.HASH_PARTITION_COUNT;
 import static io.trino.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE;
 import static io.trino.SystemSessionProperties.QUERY_MAX_MEMORY;
+import static io.trino.client.ClientCapabilities.PATH;
 import static io.trino.client.ProtocolHeaders.TRINO_HEADERS;
 import static io.trino.spi.StandardErrorCode.INCOMPATIBLE_CLIENT;
 import static io.trino.testing.TestingSession.testSessionBuilder;
@@ -77,6 +79,7 @@
 import static javax.ws.rs.core.Response.Status.OK;
 import static javax.ws.rs.core.Response.Status.SEE_OTHER;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertNotNull;
 import static org.testng.Assert.assertNull;
@@ -275,6 +278,19 @@ public void testVersionOnCompilerFailedError()
         }
     }
 
+    @Test
+    public void testSetPathSupportByClient()
+    {
+        try (TestingTrinoClient testingClient = new TestingTrinoClient(server, testSessionBuilder().setClientCapabilities(Set.of()).build())) {
+            assertThatThrownBy(() -> testingClient.execute("SET PATH foo"))
+                    .hasMessage("SET PATH not supported by client");
+        }
+
+        try (TestingTrinoClient testingClient = new TestingTrinoClient(server, testSessionBuilder().setClientCapabilities(Set.of(PATH.name())).build())) {
+            testingClient.execute("SET PATH foo");
+        }
+    }
+
     private void checkVersionOnError(String query, @Language("RegExp") String proofOfOrigin)
     {
         QueryResults queryResults = postQuery(request -> request