diff --git a/core/trino-main/src/main/java/io/trino/connector/ConnectorManager.java b/core/trino-main/src/main/java/io/trino/connector/ConnectorManager.java index 9f43711dc5a2..ebb118ebf5c0 100644 --- a/core/trino-main/src/main/java/io/trino/connector/ConnectorManager.java +++ b/core/trino-main/src/main/java/io/trino/connector/ConnectorManager.java @@ -48,6 +48,7 @@ import io.trino.spi.connector.ConnectorRecordSetProvider; import io.trino.spi.connector.ConnectorSplitManager; import io.trino.spi.connector.SystemTable; +import io.trino.spi.connector.TableProcedureMetadata; import io.trino.spi.eventlistener.EventListener; import io.trino.spi.procedure.Procedure; import io.trino.spi.session.PropertyMetadata; @@ -300,6 +301,8 @@ private synchronized void addConnectorInternal(MaterializedConnector connector) .ifPresent(partitioningProvider -> nodePartitioningManager.addPartitioningProvider(catalogName, partitioningProvider)); metadataManager.getProcedureRegistry().addProcedures(catalogName, connector.getProcedures()); + Set<TableProcedureMetadata> tableProcedures = connector.getTableProcedures(); + metadataManager.getTableProcedureRegistry().addTableProcedures(catalogName, tableProcedures); connector.getAccessControl() .ifPresent(accessControl -> accessControlManager.addCatalogAccessControl(catalogName, accessControl)); @@ -309,6 +312,9 @@ private synchronized void addConnectorInternal(MaterializedConnector connector) metadataManager.getColumnPropertyManager().addProperties(catalogName, connector.getColumnProperties()); metadataManager.getSchemaPropertyManager().addProperties(catalogName, connector.getSchemaProperties()); metadataManager.getAnalyzePropertyManager().addProperties(catalogName, connector.getAnalyzeProperties()); + for (TableProcedureMetadata tableProcedure : tableProcedures) { + metadataManager.getTableProceduresPropertyManager().addProperties(catalogName, tableProcedure.getName(), tableProcedure.getProperties()); + } metadataManager.getSessionPropertyManager().addConnectorSessionProperties(catalogName, connector.getSessionProperties()); } @@ -333,12 +339,14 @@ private synchronized void removeConnectorInternal(CatalogName catalogName) indexManager.removeIndexProvider(catalogName); nodePartitioningManager.removePartitioningProvider(catalogName); metadataManager.getProcedureRegistry().removeProcedures(catalogName); + metadataManager.getTableProcedureRegistry().removeProcedures(catalogName); accessControlManager.removeCatalogAccessControl(catalogName); metadataManager.getTablePropertyManager().removeProperties(catalogName); metadataManager.getMaterializedViewPropertyManager().removeProperties(catalogName); metadataManager.getColumnPropertyManager().removeProperties(catalogName); metadataManager.getSchemaPropertyManager().removeProperties(catalogName); metadataManager.getAnalyzePropertyManager().removeProperties(catalogName); + metadataManager.getTableProceduresPropertyManager().removeProperties(catalogName); metadataManager.getSessionPropertyManager().removeConnectorSessionProperties(catalogName); MaterializedConnector materializedConnector = connectors.remove(catalogName); @@ -402,6 +410,7 @@ private static class MaterializedConnector private final Connector connector; private final Set<SystemTable> systemTables; private final Set<Procedure> procedures; + private final Set<TableProcedureMetadata> tableProcedures; private final Optional<ConnectorSplitManager> splitManager; private final Optional<ConnectorPageSourceProvider> pageSourceProvider; private final Optional<ConnectorPageSinkProvider> pageSinkProvider; @@ -429,6 +438,10 @@ public MaterializedConnector(CatalogName catalogName, Connector connector) requireNonNull(procedures, format("Connector '%s' returned a null procedures set", catalogName)); this.procedures = ImmutableSet.copyOf(procedures); + Set<TableProcedureMetadata> tableProcedures = connector.getTableProcedures(); + requireNonNull(procedures, format("Connector '%s' returned a null table procedures set", catalogName)); + this.tableProcedures = ImmutableSet.copyOf(tableProcedures); + ConnectorSplitManager splitManager = null; try { splitManager = connector.getSplitManager(); @@ -539,6 +552,11 @@ public Set<Procedure> getProcedures() return procedures; } + public Set<TableProcedureMetadata> getTableProcedures() + { + return tableProcedures; + } + public Optional<ConnectorSplitManager> getSplitManager() { return splitManager; diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java index 9f478b2fe4fe..331cdf4c3a59 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java @@ -132,6 +132,7 @@ public class SqlQueryExecution private final StatsCalculator statsCalculator; private final CostCalculator costCalculator; private final DynamicFilterService dynamicFilterService; + private final TableExecuteContextManager tableExecuteContextManager; private SqlQueryExecution( PreparedQuery preparedQuery, @@ -159,7 +160,8 @@ private SqlQueryExecution( StatsCalculator statsCalculator, CostCalculator costCalculator, DynamicFilterService dynamicFilterService, - WarningCollector warningCollector) + WarningCollector warningCollector, + TableExecuteContextManager tableExecuteContextManager) { try (SetThreadName ignored = new SetThreadName("Query-%s", stateMachine.getQueryId())) { this.slug = requireNonNull(slug, "slug is null"); @@ -180,6 +182,7 @@ private SqlQueryExecution( this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); checkArgument(scheduleSplitBatchSize > 0, "scheduleSplitBatchSize must be greater than 0"); this.scheduleSplitBatchSize = scheduleSplitBatchSize; @@ -195,6 +198,8 @@ private SqlQueryExecution( } unregisterDynamicFilteringQuery( dynamicFilterService.getDynamicFilteringStats(stateMachine.getQueryId(), stateMachine.getSession())); + + tableExecuteContextManager.unregisterTableExecuteContextForQuery(stateMachine.getQueryId()); }); // when the query finishes cache the final query info, and clear the reference to the output stage @@ -423,6 +428,8 @@ public void start() } } + tableExecuteContextManager.registerTableExecuteContextForQuery(getQueryId()); + if (!stateMachine.transitionToStarting()) { // query already started or finished return; @@ -544,7 +551,8 @@ private void planDistribution(PlanRoot plan) nodeTaskMap, executionPolicy, schedulerStats, - dynamicFilterService); + dynamicFilterService, + tableExecuteContextManager); queryScheduler.set(scheduler); @@ -741,6 +749,7 @@ public static class SqlQueryExecutionFactory private final StatsCalculator statsCalculator; private final CostCalculator costCalculator; private final DynamicFilterService dynamicFilterService; + private final TableExecuteContextManager tableExecuteContextManager; @Inject SqlQueryExecutionFactory( @@ -765,7 +774,8 @@ public static class SqlQueryExecutionFactory SplitSchedulerStats schedulerStats, StatsCalculator statsCalculator, CostCalculator costCalculator, - DynamicFilterService dynamicFilterService) + DynamicFilterService dynamicFilterService, + TableExecuteContextManager tableExecuteContextManager) { requireNonNull(config, "config is null"); this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); @@ -790,6 +800,7 @@ public static class SqlQueryExecutionFactory this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); } @Override @@ -829,7 +840,8 @@ public QueryExecution createQueryExecution( statsCalculator, costCalculator, dynamicFilterService, - warningCollector); + warningCollector, + tableExecuteContextManager); } } } diff --git a/core/trino-main/src/main/java/io/trino/execution/TableExecuteContext.java b/core/trino-main/src/main/java/io/trino/execution/TableExecuteContext.java new file mode 100644 index 000000000000..ff0b796c8018 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/TableExecuteContext.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution; + +import com.google.common.collect.ImmutableList; + +import javax.annotation.concurrent.GuardedBy; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class TableExecuteContext +{ + @GuardedBy("this") + private List<Object> splitsInfo; + + public synchronized void setSplitsInfo(List<Object> splitsInfo) + { + requireNonNull(splitsInfo, "splitsInfo is null"); + if (this.splitsInfo != null) { + throw new IllegalStateException("splitsInfo already set to " + this.splitsInfo); + } + this.splitsInfo = ImmutableList.copyOf(splitsInfo); + } + + public synchronized List<Object> getSplitsInfo() + { + if (splitsInfo == null) { + throw new IllegalStateException("splitsInfo not set yet"); + } + return splitsInfo; + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/TableExecuteContextManager.java b/core/trino-main/src/main/java/io/trino/execution/TableExecuteContextManager.java new file mode 100644 index 000000000000..aa85c44f52de --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/TableExecuteContextManager.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution; + +import io.trino.spi.QueryId; + +import javax.annotation.concurrent.ThreadSafe; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +@ThreadSafe +public class TableExecuteContextManager +{ + private final ConcurrentMap<QueryId, TableExecuteContext> contexts = new ConcurrentHashMap<>(); + + public void registerTableExecuteContextForQuery(QueryId queryId) + { + TableExecuteContext newContext = new TableExecuteContext(); + if (contexts.putIfAbsent(queryId, newContext) != null) { + throw new IllegalStateException("TableExecuteContext already registered for query " + queryId); + } + } + + public void unregisterTableExecuteContextForQuery(QueryId queryId) + { + contexts.remove(queryId); + } + + public TableExecuteContext getTableExecuteContextForQuery(QueryId queryId) + { + TableExecuteContext context = contexts.get(queryId); + if (context == null) { + throw new IllegalStateException("TableExecuteContext not registered for query " + queryId); + } + return context; + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java index 8424f7c9a44d..6cc76e0d83a6 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java @@ -21,6 +21,7 @@ import io.trino.execution.Lifespan; import io.trino.execution.RemoteTask; import io.trino.execution.SqlStageExecution; +import io.trino.execution.TableExecuteContextManager; import io.trino.execution.scheduler.ScheduleResult.BlockedReason; import io.trino.execution.scheduler.group.DynamicLifespanScheduler; import io.trino.execution.scheduler.group.FixedLifespanScheduler; @@ -75,13 +76,15 @@ public FixedSourcePartitionedScheduler( OptionalInt concurrentLifespansPerTask, NodeSelector nodeSelector, List<ConnectorPartitionHandle> partitionHandles, - DynamicFilterService dynamicFilterService) + DynamicFilterService dynamicFilterService, + TableExecuteContextManager tableExecuteContextManager) { requireNonNull(stage, "stage is null"); requireNonNull(splitSources, "splitSources is null"); requireNonNull(bucketNodeMap, "bucketNodeMap is null"); checkArgument(!requireNonNull(nodes, "nodes is null").isEmpty(), "nodes is empty"); requireNonNull(partitionHandles, "partitionHandles is null"); + requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); this.stage = stage; this.nodes = ImmutableList.copyOf(nodes); @@ -119,6 +122,7 @@ public FixedSourcePartitionedScheduler( Math.max(splitBatchSize / concurrentLifespans, 1), groupedExecutionForScanNode, dynamicFilterService, + tableExecuteContextManager, () -> true); if (stageExecutionDescriptor.isStageGroupedExecution() && !groupedExecutionForScanNode) { diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java index 1697ade417b4..45d628b4764d 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java @@ -23,6 +23,8 @@ import io.trino.execution.Lifespan; import io.trino.execution.RemoteTask; import io.trino.execution.SqlStageExecution; +import io.trino.execution.TableExecuteContext; +import io.trino.execution.TableExecuteContextManager; import io.trino.execution.scheduler.FixedSourcePartitionedScheduler.BucketedSplitPlacementPolicy; import io.trino.metadata.InternalNode; import io.trino.metadata.Split; @@ -40,6 +42,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ThreadLocalRandom; import java.util.function.BooleanSupplier; @@ -96,6 +99,7 @@ private enum State private final PlanNodeId partitionedNode; private final boolean groupedExecution; private final DynamicFilterService dynamicFilterService; + private final TableExecuteContextManager tableExecuteContextManager; private final BooleanSupplier anySourceTaskBlocked; private final Map<Lifespan, ScheduleGroup> scheduleGroups = new HashMap<>(); @@ -112,6 +116,7 @@ private SourcePartitionedScheduler( int splitBatchSize, boolean groupedExecution, DynamicFilterService dynamicFilterService, + TableExecuteContextManager tableExecuteContextManager, BooleanSupplier anySourceTaskBlocked) { this.stage = requireNonNull(stage, "stage is null"); @@ -119,6 +124,7 @@ private SourcePartitionedScheduler( this.splitSource = requireNonNull(splitSource, "splitSource is null"); this.splitPlacementPolicy = requireNonNull(splitPlacementPolicy, "splitPlacementPolicy is null"); this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); this.anySourceTaskBlocked = requireNonNull(anySourceTaskBlocked, "anySourceTaskBlocked is null"); checkArgument(splitBatchSize > 0, "splitBatchSize must be at least one"); @@ -146,6 +152,7 @@ public static StageScheduler newSourcePartitionedSchedulerAsStageScheduler( SplitPlacementPolicy splitPlacementPolicy, int splitBatchSize, DynamicFilterService dynamicFilterService, + TableExecuteContextManager tableExecuteContextManager, BooleanSupplier anySourceTaskBlocked) { SourcePartitionedScheduler sourcePartitionedScheduler = new SourcePartitionedScheduler( @@ -156,6 +163,7 @@ public static StageScheduler newSourcePartitionedSchedulerAsStageScheduler( splitBatchSize, false, dynamicFilterService, + tableExecuteContextManager, anySourceTaskBlocked); sourcePartitionedScheduler.startLifespan(Lifespan.taskWide(), NOT_PARTITIONED); sourcePartitionedScheduler.noMoreLifespans(); @@ -197,6 +205,7 @@ public static SourceScheduler newSourcePartitionedSchedulerAsSourceScheduler( int splitBatchSize, boolean groupedExecution, DynamicFilterService dynamicFilterService, + TableExecuteContextManager tableExecuteContextManager, BooleanSupplier anySourceTaskBlocked) { return new SourcePartitionedScheduler( @@ -207,6 +216,7 @@ public static SourceScheduler newSourcePartitionedSchedulerAsSourceScheduler( splitBatchSize, groupedExecution, dynamicFilterService, + tableExecuteContextManager, anySourceTaskBlocked); } @@ -357,6 +367,16 @@ else if (pendingSplits.isEmpty()) { throw new IllegalStateException("At least 1 split should have been scheduled for this plan node"); case SPLITS_ADDED: state = State.NO_MORE_SPLITS; + + Optional<List<Object>> tableExecuteSplitsInfo = splitSource.getTableExecuteSplitsInfo(); + + // Here we assume that we can get non-empty tableExecuteSplitsInfo only for queries which facilitate single split source. + // TODO support grouped execution + tableExecuteSplitsInfo.ifPresent(info -> { + TableExecuteContext tableExecuteContext = tableExecuteContextManager.getTableExecuteContextForQuery(stage.getStageId().getQueryId()); + tableExecuteContext.setSplitsInfo(info); + }); + splitSource.close(); // fall through case NO_MORE_SPLITS: diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java index d20407c34745..e800b98fbe73 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java @@ -36,6 +36,7 @@ import io.trino.execution.StageId; import io.trino.execution.StageInfo; import io.trino.execution.StageState; +import io.trino.execution.TableExecuteContextManager; import io.trino.execution.TaskStatus; import io.trino.execution.buffer.OutputBuffers; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; @@ -121,6 +122,7 @@ public class SqlQueryScheduler private final SplitSchedulerStats schedulerStats; private final boolean summarizeTaskInfo; private final DynamicFilterService dynamicFilterService; + private final TableExecuteContextManager tableExecuteContextManager; private final AtomicBoolean started = new AtomicBoolean(); public static SqlQueryScheduler createSqlQueryScheduler( @@ -139,7 +141,8 @@ public static SqlQueryScheduler createSqlQueryScheduler( NodeTaskMap nodeTaskMap, ExecutionPolicy executionPolicy, SplitSchedulerStats schedulerStats, - DynamicFilterService dynamicFilterService) + DynamicFilterService dynamicFilterService, + TableExecuteContextManager tableExecuteContextManager) { SqlQueryScheduler sqlQueryScheduler = new SqlQueryScheduler( queryStateMachine, @@ -157,7 +160,8 @@ public static SqlQueryScheduler createSqlQueryScheduler( nodeTaskMap, executionPolicy, schedulerStats, - dynamicFilterService); + dynamicFilterService, + tableExecuteContextManager); sqlQueryScheduler.initialize(); return sqlQueryScheduler; } @@ -178,13 +182,15 @@ private SqlQueryScheduler( NodeTaskMap nodeTaskMap, ExecutionPolicy executionPolicy, SplitSchedulerStats schedulerStats, - DynamicFilterService dynamicFilterService) + DynamicFilterService dynamicFilterService, + TableExecuteContextManager tableExecuteContextManager) { this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); this.executionPolicy = requireNonNull(executionPolicy, "executionPolicy is null"); this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); this.summarizeTaskInfo = summarizeTaskInfo; this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); // todo come up with a better way to build this, or eliminate this map ImmutableMap.Builder<StageId, StageScheduler> stageSchedulers = ImmutableMap.builder(); @@ -363,6 +369,7 @@ private List<SqlStageExecution> createStages( placementPolicy, splitBatchSize, dynamicFilterService, + tableExecuteContextManager, () -> childStages.stream().anyMatch(SqlStageExecution::isAnyTaskBlocked))); } else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { @@ -441,7 +448,8 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { getConcurrentLifespansPerNode(session), nodeScheduler.createNodeSelector(session, catalogName), connectorPartitionHandles, - dynamicFilterService)); + dynamicFilterService, + tableExecuteContextManager)); } else { // all sources are remote diff --git a/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java b/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java index 8ffd3f94a8e8..8e8c3b405b76 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java +++ b/core/trino-main/src/main/java/io/trino/metadata/HandleJsonModule.java @@ -23,6 +23,7 @@ import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorSplit; +import io.trino.spi.connector.ConnectorTableExecuteHandle; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayoutHandle; import io.trino.spi.connector.ConnectorTransactionHandle; @@ -72,6 +73,12 @@ public static com.fasterxml.jackson.databind.Module insertTableHandleModule(Hand return new AbstractTypedJacksonModule<>(ConnectorInsertTableHandle.class, resolver::getId, resolver::getInsertTableHandleClass) {}; } + @ProvidesIntoSet + public static com.fasterxml.jackson.databind.Module tableExecuteHandleModule(HandleResolver resolver) + { + return new AbstractTypedJacksonModule<>(ConnectorTableExecuteHandle.class, resolver::getId, resolver::getTableExecuteHandleClass) {}; + } + @ProvidesIntoSet public static com.fasterxml.jackson.databind.Module indexHandleModule(HandleResolver resolver) { diff --git a/core/trino-main/src/main/java/io/trino/metadata/HandleResolver.java b/core/trino-main/src/main/java/io/trino/metadata/HandleResolver.java index b2b61a372192..bd8e834f2a27 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/HandleResolver.java +++ b/core/trino-main/src/main/java/io/trino/metadata/HandleResolver.java @@ -22,6 +22,7 @@ import io.trino.spi.connector.ConnectorOutputTableHandle; import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorSplit; +import io.trino.spi.connector.ConnectorTableExecuteHandle; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayoutHandle; import io.trino.spi.connector.ConnectorTransactionHandle; @@ -103,6 +104,11 @@ public String getId(ConnectorInsertTableHandle insertHandle) return getId(insertHandle, MaterializedHandleResolver::getInsertTableHandleClass); } + public String getId(ConnectorTableExecuteHandle tableExecuteHandle) + { + return getId(tableExecuteHandle, MaterializedHandleResolver::getTableExecuteHandleClass); + } + public String getId(ConnectorPartitioningHandle partitioningHandle) { return getId(partitioningHandle, MaterializedHandleResolver::getPartitioningHandleClass); @@ -148,6 +154,11 @@ public Class<? extends ConnectorInsertTableHandle> getInsertTableHandleClass(Str return resolverFor(id).getInsertTableHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); } + public Class<? extends ConnectorTableExecuteHandle> getTableExecuteHandleClass(String id) + { + return resolverFor(id).getTableExecuteHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); + } + public Class<? extends ConnectorPartitioningHandle> getPartitioningHandleClass(String id) { return resolverFor(id).getPartitioningHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); @@ -188,6 +199,7 @@ private static class MaterializedHandleResolver private final Optional<Class<? extends ConnectorIndexHandle>> indexHandle; private final Optional<Class<? extends ConnectorOutputTableHandle>> outputTableHandle; private final Optional<Class<? extends ConnectorInsertTableHandle>> insertTableHandle; + private final Optional<Class<? extends ConnectorTableExecuteHandle>> tableExecuteHandle; private final Optional<Class<? extends ConnectorPartitioningHandle>> partitioningHandle; private final Optional<Class<? extends ConnectorTransactionHandle>> transactionHandle; @@ -200,6 +212,7 @@ public MaterializedHandleResolver(ConnectorHandleResolver resolver) indexHandle = getHandleClass(resolver::getIndexHandleClass); outputTableHandle = getHandleClass(resolver::getOutputTableHandleClass); insertTableHandle = getHandleClass(resolver::getInsertTableHandleClass); + tableExecuteHandle = getHandleClass(resolver::getTableExecuteHandleClass); partitioningHandle = getHandleClass(resolver::getPartitioningHandleClass); transactionHandle = getHandleClass(resolver::getTransactionHandleClass); } @@ -249,6 +262,11 @@ public Optional<Class<? extends ConnectorInsertTableHandle>> getInsertTableHandl return insertTableHandle; } + public Optional<Class<? extends ConnectorTableExecuteHandle>> getTableExecuteHandleClass() + { + return tableExecuteHandle; + } + public Optional<Class<? extends ConnectorPartitioningHandle>> getPartitioningHandleClass() { return partitioningHandle; @@ -276,6 +294,7 @@ public boolean equals(Object o) Objects.equals(indexHandle, that.indexHandle) && Objects.equals(outputTableHandle, that.outputTableHandle) && Objects.equals(insertTableHandle, that.insertTableHandle) && + Objects.equals(tableExecuteHandle, that.tableExecuteHandle) && Objects.equals(partitioningHandle, that.partitioningHandle) && Objects.equals(transactionHandle, that.transactionHandle); } @@ -283,7 +302,7 @@ public boolean equals(Object o) @Override public int hashCode() { - return Objects.hash(tableHandle, layoutHandle, columnHandle, split, indexHandle, outputTableHandle, insertTableHandle, partitioningHandle, transactionHandle); + return Objects.hash(tableHandle, layoutHandle, columnHandle, split, indexHandle, outputTableHandle, insertTableHandle, tableExecuteHandle, partitioningHandle, transactionHandle); } } } diff --git a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java index 6ddc476e1048..0bd6fe8fdcef 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java +++ b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java @@ -23,6 +23,7 @@ import io.trino.spi.block.BlockEncodingSerde; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.AggregationApplicationResult; +import io.trino.spi.connector.BeginTableExecuteResult; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaTableName; import io.trino.spi.connector.ColumnHandle; @@ -97,6 +98,18 @@ public interface Metadata Optional<TableHandle> getTableHandleForStatisticsCollection(Session session, QualifiedObjectName tableName, Map<String, Object> analyzeProperties); + Optional<TableExecuteHandle> getTableHandleForExecute( + Session session, + TableHandle tableHandle, + String procedureName, + Map<String, Object> executeProperties); + + Optional<NewTableLayout> getLayoutForTableExecute(Session session, TableExecuteHandle tableExecuteHandle); + + BeginTableExecuteResult<TableExecuteHandle, TableHandle> beginTableExecute(Session session, TableExecuteHandle handle, TableHandle updatedSourceTableHandle); + + void finishTableExecute(Session session, TableExecuteHandle handle, Collection<Slice> fragments, List<Object> tableExecuteState); + @Deprecated Optional<TableLayoutResult> getLayout(Session session, TableHandle tableHandle, Constraint constraint, Optional<Set<ColumnHandle>> desiredColumns); @@ -643,6 +656,8 @@ default ResolvedFunction getCoercion(Session session, Type fromType, Type toType ProcedureRegistry getProcedureRegistry(); + TableProceduresRegistry getTableProcedureRegistry(); + // // Blocks // @@ -665,6 +680,8 @@ default ResolvedFunction getCoercion(Session session, Type fromType, Type toType AnalyzePropertyManager getAnalyzePropertyManager(); + TableProceduresPropertyManager getTableProceduresPropertyManager(); + /** * Creates the specified materialized view with the specified view definition. */ diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 3602cbb199e1..f3f34f4d288a 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -57,6 +57,7 @@ import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.AggregationApplicationResult; import io.trino.spi.connector.Assignment; +import io.trino.spi.connector.BeginTableExecuteResult; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaTableName; import io.trino.spi.connector.ColumnHandle; @@ -70,6 +71,7 @@ import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorResolvedIndex; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorTableExecuteHandle; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableLayoutHandle; @@ -210,12 +212,14 @@ public final class MetadataManager private final TypeOperators typeOperators; private final FunctionResolver functionResolver; private final ProcedureRegistry procedures; + private final TableProceduresRegistry tableProcedures; private final SessionPropertyManager sessionPropertyManager; private final SchemaPropertyManager schemaPropertyManager; private final TablePropertyManager tablePropertyManager; private final MaterializedViewPropertyManager materializedViewPropertyManager; private final ColumnPropertyManager columnPropertyManager; private final AnalyzePropertyManager analyzePropertyManager; + private final TableProceduresPropertyManager tableProceduresPropertyManager; private final SystemSecurityMetadata systemSecurityMetadata; private final TransactionManager transactionManager; private final TypeRegistry typeRegistry; @@ -237,6 +241,7 @@ public MetadataManager( MaterializedViewPropertyManager materializedViewPropertyManager, ColumnPropertyManager columnPropertyManager, AnalyzePropertyManager analyzePropertyManager, + TableProceduresPropertyManager tableProceduresPropertyManager, SystemSecurityMetadata systemSecurityMetadata, TransactionManager transactionManager, TypeOperators typeOperators, @@ -249,12 +254,14 @@ public MetadataManager( functionResolver = new FunctionResolver(this); this.procedures = new ProcedureRegistry(); + this.tableProcedures = new TableProceduresRegistry(); this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null"); this.schemaPropertyManager = requireNonNull(schemaPropertyManager, "schemaPropertyManager is null"); this.tablePropertyManager = requireNonNull(tablePropertyManager, "tablePropertyManager is null"); this.materializedViewPropertyManager = requireNonNull(materializedViewPropertyManager, "materializedViewPropertyManager is null"); this.columnPropertyManager = requireNonNull(columnPropertyManager, "columnPropertyManager is null"); this.analyzePropertyManager = requireNonNull(analyzePropertyManager, "analyzePropertyManager is null"); + this.tableProceduresPropertyManager = requireNonNull(tableProceduresPropertyManager, "tableProceduresPropertyManager is null"); this.systemSecurityMetadata = requireNonNull(systemSecurityMetadata, "systemSecurityMetadata is null"); this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); this.typeOperators = requireNonNull(typeOperators, "typeOperators is null"); @@ -320,6 +327,7 @@ public static MetadataManager createTestMetadataManager(TransactionManager trans new MaterializedViewPropertyManager(), new ColumnPropertyManager(), new AnalyzePropertyManager(), + new TableProceduresPropertyManager(), new DisabledSystemSecurityMetadata(), transactionManager, typeOperators, @@ -420,6 +428,62 @@ public Optional<TableHandle> getTableHandleForStatisticsCollection(Session sessi return Optional.empty(); } + @Override + public Optional<TableExecuteHandle> getTableHandleForExecute(Session session, TableHandle tableHandle, String procedure, Map<String, Object> executeProperties) + { + requireNonNull(session, "session is null"); + requireNonNull(tableHandle, "tableHandle is null"); + requireNonNull(procedure, "procedure is null"); + requireNonNull(executeProperties, "executeProperties is null"); + + CatalogName catalogName = tableHandle.getCatalogName(); + CatalogMetadata catalogMetadata = getCatalogMetadata(session, catalogName); + ConnectorMetadata metadata = catalogMetadata.getMetadataFor(catalogName); + + Optional<ConnectorTableExecuteHandle> executeHandle = metadata.getTableHandleForExecute( + session.toConnectorSession(catalogName), + tableHandle.getConnectorHandle(), + procedure, + executeProperties); + + return executeHandle.map(handle -> new TableExecuteHandle( + catalogName, + tableHandle.getTransaction(), + handle)); + } + + @Override + public Optional<NewTableLayout> getLayoutForTableExecute(Session session, TableExecuteHandle tableExecuteHandle) + { + CatalogName catalogName = tableExecuteHandle.getCatalogName(); + CatalogMetadata catalogMetadata = getCatalogMetadataForWrite(session, catalogName); + ConnectorMetadata metadata = catalogMetadata.getMetadata(); + + return metadata.getLayoutForTableExecute(session.toConnectorSession(catalogName), tableExecuteHandle.getConnectorHandle()) + .map(layout -> new NewTableLayout(catalogName, catalogMetadata.getTransactionHandleFor(catalogName), layout)); + } + + @Override + public BeginTableExecuteResult<TableExecuteHandle, TableHandle> beginTableExecute(Session session, TableExecuteHandle tableExecuteHandle, TableHandle sourceHandle) + { + CatalogName catalogName = tableExecuteHandle.getCatalogName(); + CatalogMetadata catalogMetadata = getCatalogMetadataForWrite(session, catalogName); + ConnectorMetadata metadata = catalogMetadata.getMetadata(); + BeginTableExecuteResult<ConnectorTableExecuteHandle, ConnectorTableHandle> connectorBeginResult = metadata.beginTableExecute(session.toConnectorSession(), tableExecuteHandle.getConnectorHandle(), sourceHandle.getConnectorHandle()); + + return new BeginTableExecuteResult<>( + tableExecuteHandle.withConnectorHandle(connectorBeginResult.getTableExecuteHandle()), + sourceHandle.withConnectorHandle(connectorBeginResult.getSourceHandle())); + } + + @Override + public void finishTableExecute(Session session, TableExecuteHandle tableExecuteHandle, Collection<Slice> fragments, List<Object> tableExecuteState) + { + CatalogName catalogName = tableExecuteHandle.getCatalogName(); + ConnectorMetadata metadata = getMetadata(session, catalogName); + metadata.finishTableExecute(session.toConnectorSession(catalogName), tableExecuteHandle.getConnectorHandle(), fragments, tableExecuteState); + } + @Override public Optional<SystemTable> getSystemTable(Session session, QualifiedObjectName tableName) { @@ -2573,6 +2637,12 @@ public ProcedureRegistry getProcedureRegistry() return procedures; } + @Override + public TableProceduresRegistry getTableProcedureRegistry() + { + return tableProcedures; + } + // // Blocks // @@ -2637,6 +2707,12 @@ public AnalyzePropertyManager getAnalyzePropertyManager() return analyzePropertyManager; } + @Override + public TableProceduresPropertyManager getTableProceduresPropertyManager() + { + return tableProceduresPropertyManager; + } + // // Helpers // diff --git a/core/trino-main/src/main/java/io/trino/metadata/TableExecuteHandle.java b/core/trino-main/src/main/java/io/trino/metadata/TableExecuteHandle.java new file mode 100644 index 000000000000..bd3b30683b64 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/metadata/TableExecuteHandle.java @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.metadata; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import io.trino.connector.CatalogName; +import io.trino.spi.connector.ConnectorTableExecuteHandle; +import io.trino.spi.connector.ConnectorTransactionHandle; + +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +/** + * TableExecuteHandle wraps connectors ConnectorTableExecuteHandle which identifies instance of executing + * specific table procedure o specific table. See {#link {@link ConnectorTableExecuteHandle}} for more details. + */ +public final class TableExecuteHandle +{ + private final CatalogName catalogName; + private final ConnectorTransactionHandle transactionHandle; + private final ConnectorTableExecuteHandle connectorHandle; + + @JsonCreator + public TableExecuteHandle( + @JsonProperty("catalogName") CatalogName catalogName, + @JsonProperty("transactionHandle") ConnectorTransactionHandle transactionHandle, + @JsonProperty("connectorHandle") ConnectorTableExecuteHandle connectorHandle) + { + this.catalogName = requireNonNull(catalogName, "catalogName is null"); + this.transactionHandle = requireNonNull(transactionHandle, "transactionHandle is null"); + this.connectorHandle = requireNonNull(connectorHandle, "connectorHandle is null"); + } + + @JsonProperty + public CatalogName getCatalogName() + { + return catalogName; + } + + @JsonProperty + public ConnectorTransactionHandle getTransactionHandle() + { + return transactionHandle; + } + + @JsonProperty + public ConnectorTableExecuteHandle getConnectorHandle() + { + return connectorHandle; + } + + public TableExecuteHandle withConnectorHandle(ConnectorTableExecuteHandle connectorHandle) + { + return new TableExecuteHandle(catalogName, transactionHandle, connectorHandle); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + TableExecuteHandle o = (TableExecuteHandle) obj; + return Objects.equals(this.catalogName, o.catalogName) && + Objects.equals(this.transactionHandle, o.transactionHandle) && + Objects.equals(this.connectorHandle, o.connectorHandle); + } + + @Override + public int hashCode() + { + return Objects.hash(catalogName, transactionHandle, connectorHandle); + } + + @Override + public String toString() + { + return "Execute[" + catalogName + ":" + connectorHandle + "]"; + } +} diff --git a/core/trino-main/src/main/java/io/trino/metadata/TableHandle.java b/core/trino-main/src/main/java/io/trino/metadata/TableHandle.java index d6918c2e7c79..db7ce559dfff 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/TableHandle.java +++ b/core/trino-main/src/main/java/io/trino/metadata/TableHandle.java @@ -72,6 +72,15 @@ public ConnectorTransactionHandle getTransaction() return transaction; } + public TableHandle withConnectorHandle(ConnectorTableHandle connectorHandle) + { + return new TableHandle( + catalogName, + connectorHandle, + transaction, + layout); + } + @Override public String toString() { diff --git a/core/trino-main/src/main/java/io/trino/metadata/TableProceduresPropertyManager.java b/core/trino-main/src/main/java/io/trino/metadata/TableProceduresPropertyManager.java new file mode 100644 index 000000000000..43ccae91ec72 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/metadata/TableProceduresPropertyManager.java @@ -0,0 +1,137 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.metadata; + +import io.trino.Session; +import io.trino.connector.CatalogName; +import io.trino.security.AccessControl; +import io.trino.spi.session.PropertyMetadata; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.NodeRef; +import io.trino.sql.tree.Parameter; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.spi.StandardErrorCode.INVALID_PROCEDURE_ARGUMENT; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class TableProceduresPropertyManager + extends AbstractPropertyManager<TableProceduresPropertyManager.Key> +{ + public TableProceduresPropertyManager() + { + super("procedure", INVALID_PROCEDURE_ARGUMENT); + } + + public void addProperties(CatalogName catalogName, String procedureName, List<PropertyMetadata<?>> properties) + { + doAddProperties(new Key(catalogName, procedureName), properties); + } + + public void removeProperties(CatalogName catalogName) + { + Set<Key> keysToRemove = connectorProperties.keySet().stream() + .filter(key -> catalogName.equals(key.getCatalogName())) + .collect(toImmutableSet()); + for (Key key : keysToRemove) { + doRemoveProperties(key); + } + } + + public Map<String, Object> getProperties( + CatalogName catalog, + String procedureName, + String catalogNameForDiagnostics, + Map<String, Expression> sqlPropertyValues, + Session session, + Metadata metadata, + AccessControl accessControl, + Map<NodeRef<Parameter>, Expression> parameters, + boolean setDefaultProperties) + { + return doGetProperties( + new Key(catalog, procedureName), + catalogNameForDiagnostics, + sqlPropertyValues, + session, + metadata, + accessControl, + parameters, + setDefaultProperties); + } + + public Map<Key, Map<String, PropertyMetadata<?>>> getAllProperties() + { + return doGetAllProperties(); + } + + @Override + protected String formatPropertiesKeyForMessage(String catalogName, Key propertiesKey) + { + return format("Catalog %s table procedure %s", catalogName, propertiesKey.procedureName); + } + + static final class Key + { + private final CatalogName catalogName; + private final String procedureName; + + private Key(CatalogName catalogName, String procedureName) + { + this.catalogName = requireNonNull(catalogName, "catalogName is null"); + this.procedureName = requireNonNull(procedureName, "procedureName is null"); + } + + public CatalogName getCatalogName() + { + return catalogName; + } + + public String getProcedureName() + { + return procedureName; + } + + @Override + public String toString() + { + return catalogName + ":" + procedureName; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Key key = (Key) o; + return Objects.equals(catalogName, key.catalogName) + && Objects.equals(procedureName, key.procedureName); + } + + @Override + public int hashCode() + { + return Objects.hash(catalogName, procedureName); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/metadata/TableProceduresRegistry.java b/core/trino-main/src/main/java/io/trino/metadata/TableProceduresRegistry.java new file mode 100644 index 000000000000..a01d512d9792 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/metadata/TableProceduresRegistry.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.metadata; + +import com.google.common.collect.Maps; +import io.trino.connector.CatalogName; +import io.trino.spi.TrinoException; +import io.trino.spi.connector.TableProcedureMetadata; + +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static com.google.common.base.Preconditions.checkState; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.trino.spi.StandardErrorCode.PROCEDURE_NOT_FOUND; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class TableProceduresRegistry +{ + private final Map<CatalogName, Map<String, TableProcedureMetadata>> tableProcedures = new ConcurrentHashMap<>(); + + public TableProceduresRegistry() + { + } + + public void addTableProcedures(CatalogName catalogName, Collection<TableProcedureMetadata> procedures) + { + requireNonNull(catalogName, "catalogName is null"); + requireNonNull(procedures, "procedures is null"); + + Map<String, TableProcedureMetadata> proceduresByName = Maps.uniqueIndex(procedures, TableProcedureMetadata::getName); + + checkState(tableProcedures.putIfAbsent(catalogName, proceduresByName) == null, "Table procedures already registered for connector: %s", catalogName); + } + + public void removeProcedures(CatalogName catalogName) + { + tableProcedures.remove(catalogName); + } + + public TableProcedureMetadata resolve(CatalogName catalogName, String name) + { + Map<String, TableProcedureMetadata> procedures = tableProcedures.get(catalogName); + if (procedures == null) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Catalog %s not registered", catalogName)); + } + + TableProcedureMetadata procedure = procedures.get(name); + if (procedure == null) { + throw new TrinoException(PROCEDURE_NOT_FOUND, format("Procedure %s not registered for catalog %s", name, catalogName)); + } + return procedure; + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/TableFinishOperator.java b/core/trino-main/src/main/java/io/trino/operator/TableFinishOperator.java index 91fc5b539f8d..ab4a8a232670 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TableFinishOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/TableFinishOperator.java @@ -18,9 +18,12 @@ import io.airlift.slice.Slice; import io.airlift.units.Duration; import io.trino.Session; +import io.trino.execution.TableExecuteContext; +import io.trino.execution.TableExecuteContextManager; import io.trino.operator.OperationTimer.OperationTiming; import io.trino.spi.Page; import io.trino.spi.PageBuilder; +import io.trino.spi.QueryId; import io.trino.spi.block.Block; import io.trino.spi.connector.ConnectorOutputMetadata; import io.trino.spi.statistics.ComputedStatistics; @@ -57,6 +60,8 @@ public static class TableFinishOperatorFactory private final TableFinisher tableFinisher; private final OperatorFactory statisticsAggregationOperatorFactory; private final StatisticAggregationsDescriptor<Integer> descriptor; + private final TableExecuteContextManager tableExecuteContextManager; + private final boolean outputRowCount; private final Session session; private boolean closed; @@ -66,6 +71,8 @@ public TableFinishOperatorFactory( TableFinisher tableFinisher, OperatorFactory statisticsAggregationOperatorFactory, StatisticAggregationsDescriptor<Integer> descriptor, + TableExecuteContextManager tableExecuteContextManager, + boolean outputRowCount, Session session) { this.operatorId = operatorId; @@ -74,6 +81,8 @@ public TableFinishOperatorFactory( this.statisticsAggregationOperatorFactory = requireNonNull(statisticsAggregationOperatorFactory, "statisticsAggregationOperatorFactory is null"); this.descriptor = requireNonNull(descriptor, "descriptor is null"); this.session = requireNonNull(session, "session is null"); + this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); + this.outputRowCount = outputRowCount; } @Override @@ -83,7 +92,9 @@ public Operator createOperator(DriverContext driverContext) OperatorContext context = driverContext.addOperatorContext(operatorId, planNodeId, TableFinishOperator.class.getSimpleName()); Operator statisticsAggregationOperator = statisticsAggregationOperatorFactory.createOperator(driverContext); boolean statisticsCpuTimerEnabled = !(statisticsAggregationOperator instanceof DevNullOperator) && isStatisticsCpuTimerEnabled(session); - return new TableFinishOperator(context, tableFinisher, statisticsAggregationOperator, descriptor, statisticsCpuTimerEnabled); + QueryId queryId = driverContext.getPipelineContext().getTaskContext().getQueryContext().getQueryId(); + TableExecuteContext tableExecuteContext = tableExecuteContextManager.getTableExecuteContextForQuery(queryId); + return new TableFinishOperator(context, tableFinisher, statisticsAggregationOperator, descriptor, statisticsCpuTimerEnabled, tableExecuteContext, outputRowCount); } @Override @@ -95,7 +106,7 @@ public void noMoreOperators() @Override public OperatorFactory duplicate() { - return new TableFinishOperatorFactory(operatorId, planNodeId, tableFinisher, statisticsAggregationOperatorFactory, descriptor, session); + return new TableFinishOperatorFactory(operatorId, planNodeId, tableFinisher, statisticsAggregationOperatorFactory, descriptor, tableExecuteContextManager, outputRowCount, session); } } @@ -118,6 +129,9 @@ private enum State private final OperationTiming statisticsTiming = new OperationTiming(); private final boolean statisticsCpuTimerEnabled; + private final TableExecuteContext tableExecuteContext; + private final boolean outputRowCount; + private final Supplier<TableFinishInfo> tableFinishInfoSupplier; public TableFinishOperator( @@ -125,14 +139,18 @@ public TableFinishOperator( TableFinisher tableFinisher, Operator statisticsAggregationOperator, StatisticAggregationsDescriptor<Integer> descriptor, - boolean statisticsCpuTimerEnabled) + boolean statisticsCpuTimerEnabled, + TableExecuteContext tableExecuteContext, + boolean outputRowCount) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.tableFinisher = requireNonNull(tableFinisher, "tableFinisher is null"); this.statisticsAggregationOperator = requireNonNull(statisticsAggregationOperator, "statisticsAggregationOperator is null"); this.descriptor = requireNonNull(descriptor, "descriptor is null"); this.statisticsCpuTimerEnabled = statisticsCpuTimerEnabled; + this.tableExecuteContext = requireNonNull(tableExecuteContext, "tableExecuteContext is null"); this.tableFinishInfoSupplier = createTableFinishInfoSupplier(outputMetadata, statisticsTiming); + this.outputRowCount = outputRowCount; operatorContext.setInfoSupplier(tableFinishInfoSupplier); } @@ -297,13 +315,15 @@ public Page getOutput() } state = State.FINISHED; - this.outputMetadata.set(tableFinisher.finishTable(fragmentBuilder.build(), computedStatisticsBuilder.build())); + this.outputMetadata.set(tableFinisher.finishTable(fragmentBuilder.build(), computedStatisticsBuilder.build(), tableExecuteContext)); // output page will only be constructed once, // so a new PageBuilder is constructed (instead of using PageBuilder.reset) PageBuilder page = new PageBuilder(1, TYPES); - page.declarePosition(); - BIGINT.writeLong(page.getBlockBuilder(0), rowCount); + if (outputRowCount) { + page.declarePosition(); + BIGINT.writeLong(page.getBlockBuilder(0), rowCount); + } return page.build(); } @@ -345,6 +365,9 @@ public void close() public interface TableFinisher { - Optional<ConnectorOutputMetadata> finishTable(Collection<Slice> fragments, Collection<ComputedStatistics> computedStatistics); + Optional<ConnectorOutputMetadata> finishTable( + Collection<Slice> fragments, + Collection<ComputedStatistics> computedStatistics, + TableExecuteContext tableExecuteContext); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/TableWriterOperator.java b/core/trino-main/src/main/java/io/trino/operator/TableWriterOperator.java index b1034d8d6c2f..bf84936f95b4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TableWriterOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/TableWriterOperator.java @@ -98,8 +98,12 @@ public TableWriterOperatorFactory( this.columnChannels = requireNonNull(columnChannels, "columnChannels is null"); this.notNullChannelColumnNames = requireNonNull(notNullChannelColumnNames, "notNullChannelColumnNames is null"); this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null"); - checkArgument(writerTarget instanceof CreateTarget || writerTarget instanceof InsertTarget || writerTarget instanceof TableWriterNode.RefreshMaterializedViewTarget, - "writerTarget must be CreateTarget, InsertTarget or RefreshMaterializedViewTarget"); + checkArgument( + writerTarget instanceof CreateTarget + || writerTarget instanceof InsertTarget + || writerTarget instanceof TableWriterNode.RefreshMaterializedViewTarget + || writerTarget instanceof TableWriterNode.TableExecuteTarget, + "writerTarget must be CreateTarget, InsertTarget, RefreshMaterializedViewTarget or TableExecuteTarget"); this.target = requireNonNull(writerTarget, "writerTarget is null"); this.session = session; this.statisticsAggregationOperatorFactory = requireNonNull(statisticsAggregationOperatorFactory, "statisticsAggregationOperatorFactory is null"); @@ -127,6 +131,9 @@ private ConnectorPageSink createPageSink() if (target instanceof TableWriterNode.RefreshMaterializedViewTarget) { return pageSinkManager.createPageSink(session, ((TableWriterNode.RefreshMaterializedViewTarget) target).getInsertHandle()); } + if (target instanceof TableWriterNode.TableExecuteTarget) { + return pageSinkManager.createPageSink(session, ((TableWriterNode.TableExecuteTarget) target).getExecuteHandle()); + } throw new UnsupportedOperationException("Unhandled target type: " + target.getClass().getName()); } diff --git a/core/trino-main/src/main/java/io/trino/security/AccessControl.java b/core/trino-main/src/main/java/io/trino/security/AccessControl.java index a6b58d61c5c6..811baab41baa 100644 --- a/core/trino-main/src/main/java/io/trino/security/AccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/AccessControl.java @@ -505,6 +505,13 @@ void checkCanRevokeRoles(SecurityContext context, */ void checkCanExecuteFunction(SecurityContext context, String functionName); + /** + * Check if identity is allowed to execute given table procedure on given table + * + * @throws AccessDeniedException if not allowed + */ + void checkCanExecuteTableProcedure(SecurityContext context, QualifiedObjectName tableName, String procedureName); + default List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObjectName tableName) { return ImmutableList.of(); diff --git a/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java b/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java index 324447885e26..a134f5d1b687 100644 --- a/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java +++ b/core/trino-main/src/main/java/io/trino/security/AccessControlManager.java @@ -1078,6 +1078,27 @@ public void checkCanExecuteFunction(SecurityContext context, String functionName systemAuthorizationCheck(control -> control.checkCanExecuteFunction(context.toSystemSecurityContext(), functionName)); } + @Override + public void checkCanExecuteTableProcedure(SecurityContext securityContext, QualifiedObjectName tableName, String procedureName) + { + requireNonNull(securityContext, "securityContext is null"); + requireNonNull(procedureName, "procedureName is null"); + requireNonNull(tableName, "tableName is null"); + + systemAuthorizationCheck(control -> control.checkCanExecuteTableProcedure( + securityContext.toSystemSecurityContext(), + tableName.asCatalogSchemaTableName(), + procedureName)); + + catalogAuthorizationCheck( + tableName.getCatalogName(), + securityContext, + (control, context) -> control.checkCanExecuteTableProcedure( + context, + tableName.asSchemaTableName(), + procedureName)); + } + @Override public List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObjectName tableName) { diff --git a/core/trino-main/src/main/java/io/trino/security/AllowAllAccessControl.java b/core/trino-main/src/main/java/io/trino/security/AllowAllAccessControl.java index 5f0327c90b49..f7dde2c1c022 100644 --- a/core/trino-main/src/main/java/io/trino/security/AllowAllAccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/AllowAllAccessControl.java @@ -348,4 +348,9 @@ public void checkCanExecuteProcedure(SecurityContext context, QualifiedObjectNam public void checkCanExecuteFunction(SecurityContext context, String functionName) { } + + @Override + public void checkCanExecuteTableProcedure(SecurityContext context, QualifiedObjectName tableName, String procedureName) + { + } } diff --git a/core/trino-main/src/main/java/io/trino/security/DenyAllAccessControl.java b/core/trino-main/src/main/java/io/trino/security/DenyAllAccessControl.java index 1a1741e971f8..dd2a893343b6 100644 --- a/core/trino-main/src/main/java/io/trino/security/DenyAllAccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/DenyAllAccessControl.java @@ -46,6 +46,7 @@ import static io.trino.spi.security.AccessDeniedException.denyExecuteFunction; import static io.trino.spi.security.AccessDeniedException.denyExecuteProcedure; import static io.trino.spi.security.AccessDeniedException.denyExecuteQuery; +import static io.trino.spi.security.AccessDeniedException.denyExecuteTableProcedure; import static io.trino.spi.security.AccessDeniedException.denyGrantExecuteFunctionPrivilege; import static io.trino.spi.security.AccessDeniedException.denyGrantRoles; import static io.trino.spi.security.AccessDeniedException.denyGrantSchemaPrivilege; @@ -465,4 +466,10 @@ public void checkCanExecuteFunction(SecurityContext context, String functionName { denyExecuteFunction(functionName); } + + @Override + public void checkCanExecuteTableProcedure(SecurityContext context, QualifiedObjectName tableName, String procedureName) + { + denyExecuteTableProcedure(tableName.toString(), procedureName.toString()); + } } diff --git a/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java b/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java index 2d3d952888d0..216b63a6637d 100644 --- a/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/ForwardingAccessControl.java @@ -429,6 +429,12 @@ public void checkCanExecuteFunction(SecurityContext context, String functionName delegate().checkCanExecuteFunction(context, functionName); } + @Override + public void checkCanExecuteTableProcedure(SecurityContext context, QualifiedObjectName tableName, String procedureName) + { + delegate().checkCanExecuteTableProcedure(context, tableName, procedureName); + } + @Override public List<ViewExpression> getRowFilters(SecurityContext context, QualifiedObjectName tableName) { diff --git a/core/trino-main/src/main/java/io/trino/security/InjectedConnectorAccessControl.java b/core/trino-main/src/main/java/io/trino/security/InjectedConnectorAccessControl.java index 556f8bca4d51..85bd4405613a 100644 --- a/core/trino-main/src/main/java/io/trino/security/InjectedConnectorAccessControl.java +++ b/core/trino-main/src/main/java/io/trino/security/InjectedConnectorAccessControl.java @@ -413,6 +413,16 @@ public void checkCanExecuteProcedure(ConnectorSecurityContext context, SchemaRou accessControl.checkCanExecuteProcedure(securityContext, new QualifiedObjectName(catalogName, procedure.getSchemaName(), procedure.getRoutineName())); } + @Override + public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, SchemaTableName tableName, String procedure) + { + checkArgument(context == null, "context must be null"); + accessControl.checkCanExecuteTableProcedure( + securityContext, + getQualifiedObjectName(tableName), + procedure); + } + @Override public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName) { diff --git a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java index fd7e52580e81..b768c9c62e15 100644 --- a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java +++ b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java @@ -44,6 +44,7 @@ import io.trino.execution.NodeTaskMap; import io.trino.execution.QueryManagerConfig; import io.trino.execution.SqlTaskManager; +import io.trino.execution.TableExecuteContextManager; import io.trino.execution.TaskManagementExecutor; import io.trino.execution.TaskManager; import io.trino.execution.TaskManagerConfig; @@ -77,6 +78,7 @@ import io.trino.metadata.StaticCatalogStore; import io.trino.metadata.StaticCatalogStoreConfig; import io.trino.metadata.SystemSecurityMetadata; +import io.trino.metadata.TableProceduresPropertyManager; import io.trino.metadata.TablePropertyManager; import io.trino.operator.ExchangeClientConfig; import io.trino.operator.ExchangeClientFactory; @@ -233,6 +235,9 @@ protected void setup(Binder binder) // analyze properties binder.bind(AnalyzePropertyManager.class).in(Scopes.SINGLETON); + // table procedures properties + binder.bind(TableProceduresPropertyManager.class).in(Scopes.SINGLETON); + // node manager discoveryBinder(binder).bindSelector("trino"); binder.bind(DiscoveryNodeManager.class).in(Scopes.SINGLETON); @@ -271,6 +276,7 @@ protected void setup(Binder binder) binder.bind(TaskManagementExecutor.class).in(Scopes.SINGLETON); binder.bind(SqlTaskManager.class).in(Scopes.SINGLETON); binder.bind(TaskManager.class).to(Key.get(SqlTaskManager.class)); + binder.bind(TableExecuteContextManager.class).in(Scopes.SINGLETON); // memory revoking scheduler binder.bind(MemoryRevokingScheduler.class).in(Scopes.SINGLETON); diff --git a/core/trino-main/src/main/java/io/trino/split/BufferingSplitSource.java b/core/trino-main/src/main/java/io/trino/split/BufferingSplitSource.java index d1fa23836286..8bd175b23116 100644 --- a/core/trino-main/src/main/java/io/trino/split/BufferingSplitSource.java +++ b/core/trino-main/src/main/java/io/trino/split/BufferingSplitSource.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.util.concurrent.Futures.immediateVoidFuture; @@ -65,6 +66,12 @@ public boolean isFinished() return source.isFinished(); } + @Override + public Optional<List<Object>> getTableExecuteSplitsInfo() + { + return source.getTableExecuteSplitsInfo(); + } + private static class GetNextBatch { private final SplitSource splitSource; diff --git a/core/trino-main/src/main/java/io/trino/split/ConnectorAwareSplitSource.java b/core/trino-main/src/main/java/io/trino/split/ConnectorAwareSplitSource.java index 4bce504db55b..83792cbf5674 100644 --- a/core/trino-main/src/main/java/io/trino/split/ConnectorAwareSplitSource.java +++ b/core/trino-main/src/main/java/io/trino/split/ConnectorAwareSplitSource.java @@ -24,6 +24,9 @@ import io.trino.spi.connector.ConnectorSplitSource; import io.trino.spi.connector.ConnectorSplitSource.ConnectorSplitBatch; +import java.util.List; +import java.util.Optional; + import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.MoreFutures.toListenableFuture; import static java.util.Objects.requireNonNull; @@ -71,6 +74,12 @@ public boolean isFinished() return source.isFinished(); } + @Override + public Optional<List<Object>> getTableExecuteSplitsInfo() + { + return source.getTableExecuteSplitsInfo(); + } + @Override public String toString() { diff --git a/core/trino-main/src/main/java/io/trino/split/PageSinkManager.java b/core/trino-main/src/main/java/io/trino/split/PageSinkManager.java index 1c9f7d62f4e8..0e54c712f6de 100644 --- a/core/trino-main/src/main/java/io/trino/split/PageSinkManager.java +++ b/core/trino-main/src/main/java/io/trino/split/PageSinkManager.java @@ -17,6 +17,7 @@ import io.trino.connector.CatalogName; import io.trino.metadata.InsertTableHandle; import io.trino.metadata.OutputTableHandle; +import io.trino.metadata.TableExecuteHandle; import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkProvider; import io.trino.spi.connector.ConnectorSession; @@ -61,6 +62,14 @@ public ConnectorPageSink createPageSink(Session session, InsertTableHandle table return providerFor(tableHandle.getCatalogName()).createPageSink(tableHandle.getTransactionHandle(), connectorSession, tableHandle.getConnectorHandle()); } + @Override + public ConnectorPageSink createPageSink(Session session, TableExecuteHandle tableHandle) + { + // assumes connectorId and catalog are the same + ConnectorSession connectorSession = session.toConnectorSession(tableHandle.getCatalogName()); + return providerFor(tableHandle.getCatalogName()).createPageSink(tableHandle.getTransactionHandle(), connectorSession, tableHandle.getConnectorHandle()); + } + private ConnectorPageSinkProvider providerFor(CatalogName catalogName) { ConnectorPageSinkProvider provider = pageSinkProviders.get(catalogName); diff --git a/core/trino-main/src/main/java/io/trino/split/PageSinkProvider.java b/core/trino-main/src/main/java/io/trino/split/PageSinkProvider.java index 8cf84625635a..effcfaef23f3 100644 --- a/core/trino-main/src/main/java/io/trino/split/PageSinkProvider.java +++ b/core/trino-main/src/main/java/io/trino/split/PageSinkProvider.java @@ -16,6 +16,7 @@ import io.trino.Session; import io.trino.metadata.InsertTableHandle; import io.trino.metadata.OutputTableHandle; +import io.trino.metadata.TableExecuteHandle; import io.trino.spi.connector.ConnectorPageSink; public interface PageSinkProvider @@ -23,4 +24,6 @@ public interface PageSinkProvider ConnectorPageSink createPageSink(Session session, OutputTableHandle tableHandle); ConnectorPageSink createPageSink(Session session, InsertTableHandle tableHandle); + + ConnectorPageSink createPageSink(Session session, TableExecuteHandle tableHandle); } diff --git a/core/trino-main/src/main/java/io/trino/split/SampledSplitSource.java b/core/trino-main/src/main/java/io/trino/split/SampledSplitSource.java index 3130ebead480..120ff8d6a079 100644 --- a/core/trino-main/src/main/java/io/trino/split/SampledSplitSource.java +++ b/core/trino-main/src/main/java/io/trino/split/SampledSplitSource.java @@ -21,6 +21,8 @@ import javax.annotation.Nullable; +import java.util.List; +import java.util.Optional; import java.util.concurrent.ThreadLocalRandom; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -68,4 +70,13 @@ public boolean isFinished() { return splitSource.isFinished(); } + + @Override + public Optional<List<Object>> getTableExecuteSplitsInfo() + { + splitSource.getTableExecuteSplitsInfo().ifPresent(splitInfo -> { + throw new IllegalStateException("Cannot use SampledSplitSource with SplitSource which returns non-empty TableExecuteSplitsInfo=" + splitInfo); + }); + return Optional.empty(); + } } diff --git a/core/trino-main/src/main/java/io/trino/split/SplitSource.java b/core/trino-main/src/main/java/io/trino/split/SplitSource.java index fff277bec426..b2247b4f9440 100644 --- a/core/trino-main/src/main/java/io/trino/split/SplitSource.java +++ b/core/trino-main/src/main/java/io/trino/split/SplitSource.java @@ -21,6 +21,7 @@ import java.io.Closeable; import java.util.List; +import java.util.Optional; import static java.util.Objects.requireNonNull; @@ -36,6 +37,8 @@ public interface SplitSource boolean isFinished(); + Optional<List<Object>> getTableExecuteSplitsInfo(); + class SplitBatch { private final List<Split> splits; diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index d49c33d6097a..85e3331352e6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -26,6 +26,7 @@ import io.trino.metadata.NewTableLayout; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TableExecuteHandle; import io.trino.metadata.TableHandle; import io.trino.security.AccessControl; import io.trino.security.SecurityContext; @@ -215,6 +216,8 @@ public class Analysis private final Multimap<Field, SourceColumn> originColumnDetails = ArrayListMultimap.create(); private final Multimap<NodeRef<Expression>, Field> fieldLineage = ArrayListMultimap.create(); + private Optional<TableExecuteHandle> tableExecuteHandle = Optional.empty(); + public Analysis(@Nullable Statement root, Map<NodeRef<Parameter>, Expression> parameters, QueryType queryType) { this.root = root; @@ -1114,6 +1117,18 @@ public PredicateCoercions getPredicateCoercions(Expression expression) return predicateCoercions.get(NodeRef.of(expression)); } + public void setTableExecuteHandle(TableExecuteHandle tableExecuteHandle) + { + requireNonNull(tableExecuteHandle, "tableExecuteHandle is null"); + checkState(this.tableExecuteHandle.isEmpty(), "tableExecuteHandle already set"); + this.tableExecuteHandle = Optional.of(tableExecuteHandle); + } + + public Optional<TableExecuteHandle> getTableExecuteHandle() + { + return tableExecuteHandle; + } + @Immutable public static final class SelectExpression { diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index ecd832667db7..ccd4717c1dc8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -33,6 +33,7 @@ import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.RedirectionAwareTableHandle; import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TableExecuteHandle; import io.trino.metadata.TableHandle; import io.trino.metadata.TableMetadata; import io.trino.metadata.TableSchema; @@ -51,6 +52,7 @@ import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.ConnectorViewDefinition.ViewColumn; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.connector.TableProcedureMetadata; import io.trino.spi.function.OperatorType; import io.trino.spi.security.AccessDeniedException; import io.trino.spi.security.GroupProvider; @@ -86,6 +88,7 @@ import io.trino.sql.tree.Analyze; import io.trino.sql.tree.AstVisitor; import io.trino.sql.tree.Call; +import io.trino.sql.tree.CallArgument; import io.trino.sql.tree.Comment; import io.trino.sql.tree.Commit; import io.trino.sql.tree.CreateMaterializedView; @@ -171,6 +174,7 @@ import io.trino.sql.tree.SubqueryExpression; import io.trino.sql.tree.SubscriptExpression; import io.trino.sql.tree.Table; +import io.trino.sql.tree.TableExecute; import io.trino.sql.tree.TableSubquery; import io.trino.sql.tree.Union; import io.trino.sql.tree.Unnest; @@ -191,6 +195,7 @@ import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; @@ -942,6 +947,15 @@ protected Scope visitProperty(Property node, Optional<Scope> scope) return createAndAssignScope(node, scope); } + @Override + protected Scope visitCallArgument(CallArgument node, Optional<Scope> scope) + { + // CallArgument value expressions must be constant + createConstantAnalyzer(metadata, accessControl, session, analysis.getParameters(), WarningCollector.NOOP, analysis.isDescribe()) + .analyze(node.getValue(), createScope(scope)); + return createAndAssignScope(node, scope); + } + @Override protected Scope visitDropTable(DropTable node, Optional<Scope> scope) { @@ -984,6 +998,122 @@ protected Scope visitSetTableAuthorization(SetTableAuthorization node, Optional< return createAndAssignScope(node, scope); } + @Override + protected Scope visitTableExecute(TableExecute node, Optional<Scope> scope) + { + Table table = node.getTable(); + QualifiedObjectName originalName = createQualifiedObjectName(session, table, table.getName()); + String procedureName = node.getProcedureName().getCanonicalValue(); + + if (metadata.getMaterializedView(session, originalName).isPresent()) { + throw semanticException(NOT_SUPPORTED, node, "ALTER TABLE EXECUTE is not supported for materialized views"); + } + + if (metadata.getView(session, originalName).isPresent()) { + throw semanticException(NOT_SUPPORTED, node, "ALTER TABLE EXECUTE is not supported for views"); + } + + RedirectionAwareTableHandle redirection = metadata.getRedirectionAwareTableHandle(session, originalName); + QualifiedObjectName tableName = redirection.getRedirectedTableName().orElse(originalName); + TableHandle tableHandle = redirection.getTableHandle() + .orElseThrow(() -> semanticException(TABLE_NOT_FOUND, table, "Table '%s' does not exist", tableName)); + + accessControl.checkCanExecuteTableProcedure( + session.toSecurityContext(), + tableName, + procedureName); + + if (!accessControl.getRowFilters(session.toSecurityContext(), tableName).isEmpty()) { + throw semanticException(NOT_SUPPORTED, node, "ALTER TABLE EXECUTE is not supported for table with row filter"); + } + + TableMetadata tableMetadata = metadata.getTableMetadata(session, tableHandle); + for (ColumnMetadata tableColumn : tableMetadata.getColumns()) { + if (!accessControl.getColumnMasks(session.toSecurityContext(), tableName, tableColumn.getName(), tableColumn.getType()).isEmpty()) { + throw semanticException(NOT_SUPPORTED, node, "ALTER TABLE EXECUTE is not supported for table with column masks"); + } + } + + Scope tableScope = analyze(table, scope); + + CatalogName catalogName = getRequiredCatalogHandle(metadata, session, node, tableName.getCatalogName()); + TableProcedureMetadata procedureMetadata = metadata.getTableProcedureRegistry().resolve(catalogName, procedureName); + + // analyze WHERE + if (!procedureMetadata.getExecutionMode().supportsFilter() && node.getWhere().isPresent()) { + throw semanticException(NOT_SUPPORTED, node, "WHERE not supported for procedure " + procedureName); + } + node.getWhere().ifPresent(where -> analyzeWhere(node, tableScope, where)); + + // analyze arguments + + Map<String, Expression> propertiesMap = processTableExecuteArguments(node, procedureMetadata, scope); + Map<String, Object> tableProperties = metadata.getTableProceduresPropertyManager().getProperties( + catalogName, + procedureName, + catalogName.getCatalogName(), + propertiesMap, + session, + metadata, + accessControl, + analysis.getParameters(), + true); + + TableExecuteHandle executeHandle = + metadata.getTableHandleForExecute( + session, + tableHandle, + procedureName, + tableProperties) + .orElseThrow(() -> semanticException(NOT_SUPPORTED, node, "Procedure '%s' cannot be executed on table '%s'", procedureName, tableName)); + + analysis.setTableExecuteHandle(executeHandle); + + analysis.setUpdateType("ALTER TABLE EXECUTE"); + analysis.setUpdateTarget(tableName, Optional.of(table), Optional.empty()); + + return createAndAssignScope(node, scope, Field.newUnqualified("rows", BIGINT)); + } + + private Map<String, Expression> processTableExecuteArguments(TableExecute node, TableProcedureMetadata procedureMetadata, Optional<Scope> scope) + { + List<CallArgument> arguments = node.getArguments(); + Predicate<CallArgument> hasName = argument -> argument.getName().isPresent(); + boolean anyNamed = arguments.stream().anyMatch(hasName); + boolean allNamed = arguments.stream().allMatch(hasName); + if (anyNamed && !allNamed) { + throw semanticException(INVALID_ARGUMENTS, node, "Named and positional arguments cannot be mixed"); + } + + if (!anyNamed && arguments.size() > procedureMetadata.getProperties().size()) { + throw semanticException(INVALID_ARGUMENTS, node, "Too many positional arguments"); + } + + for (CallArgument argument : arguments) { + process(argument, scope); + } + + Map<String, Expression> argumentsMap = new HashMap<>(); + + if (anyNamed) { + // all properties named + for (CallArgument argument : arguments) { + if (argumentsMap.put(argument.getName().get(), argument.getValue()) != null) { + throw semanticException(DUPLICATE_PROPERTY, argument, "Duplicate named argument: %s", argument.getName()); + } + } + } + else { + // all properties unnamed + int pos = 0; + for (CallArgument argument : arguments) { + argumentsMap.put(procedureMetadata.getProperties().get(pos).getName(), argument.getValue()); + pos++; + } + } + return ImmutableMap.copyOf(argumentsMap); + } + @Override protected Scope visitRenameView(RenameView node, Optional<Scope> scope) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java index 820cd71edf3f..2519425634ac 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/DistributedExecutionPlanner.java @@ -56,6 +56,7 @@ import io.trino.sql.planner.plan.SpatialJoinNode; import io.trino.sql.planner.plan.StatisticsWriterNode; import io.trino.sql.planner.plan.TableDeleteNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; @@ -438,6 +439,12 @@ public Map<PlanNodeId, SplitSource> visitTableDelete(TableDeleteNode node, Void return ImmutableMap.of(); } + @Override + public Map<PlanNodeId, SplitSource> visitTableExecute(TableExecuteNode node, Void context) + { + return node.getSource().accept(this, context); + } + @Override public Map<PlanNodeId, SplitSource> visitUnion(UnionNode node, Void context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index fed698fc8f8e..ad78ff1d98d0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -34,6 +34,7 @@ import io.trino.execution.DynamicFilterConfig; import io.trino.execution.ExplainAnalyzeContext; import io.trino.execution.StageId; +import io.trino.execution.TableExecuteContextManager; import io.trino.execution.TaskId; import io.trino.execution.TaskManagerConfig; import io.trino.execution.buffer.OutputBuffer; @@ -41,6 +42,7 @@ import io.trino.index.IndexManager; import io.trino.metadata.Metadata; import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TableExecuteHandle; import io.trino.metadata.TableHandle; import io.trino.operator.AggregationOperator.AggregationOperatorFactory; import io.trino.operator.AssignUniqueIdOperator; @@ -194,10 +196,12 @@ import io.trino.sql.planner.plan.StatisticAggregationsDescriptor; import io.trino.sql.planner.plan.StatisticsWriterNode; import io.trino.sql.planner.plan.TableDeleteNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TableWriterNode.DeleteTarget; +import io.trino.sql.planner.plan.TableWriterNode.TableExecuteTarget; import io.trino.sql.planner.plan.TableWriterNode.UpdateTarget; import io.trino.sql.planner.plan.TopNNode; import io.trino.sql.planner.plan.TopNRankingNode; @@ -329,6 +333,7 @@ import static io.trino.util.SpatialJoinUtils.ST_WITHIN; import static io.trino.util.SpatialJoinUtils.extractSupportedSpatialComparisons; import static io.trino.util.SpatialJoinUtils.extractSupportedSpatialFunctions; +import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; import static java.util.stream.IntStream.range; @@ -362,6 +367,7 @@ public class LocalExecutionPlanner private final DynamicFilterConfig dynamicFilterConfig; private final TypeOperators typeOperators; private final BlockTypeOperators blockTypeOperators; + private final TableExecuteContextManager tableExecuteContextManager; @Inject public LocalExecutionPlanner( @@ -387,7 +393,8 @@ public LocalExecutionPlanner( OrderingCompiler orderingCompiler, DynamicFilterConfig dynamicFilterConfig, TypeOperators typeOperators, - BlockTypeOperators blockTypeOperators) + BlockTypeOperators blockTypeOperators, + TableExecuteContextManager tableExecuteContextManager) { this.explainAnalyzeContext = requireNonNull(explainAnalyzeContext, "explainAnalyzeContext is null"); this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null"); @@ -415,6 +422,7 @@ public LocalExecutionPlanner( this.dynamicFilterConfig = requireNonNull(dynamicFilterConfig, "dynamicFilterConfig is null"); this.typeOperators = requireNonNull(typeOperators, "typeOperators is null"); this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null"); + this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); } public LocalExecutionPlan plan( @@ -3019,6 +3027,8 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl createTableFinisher(session, node, metadata), statisticsAggregation, descriptor, + tableExecuteContextManager, + shouldOutputRowCount(node), session); Map<Symbol, Integer> layout = ImmutableMap.of(node.getOutputSymbols().get(0), 0); @@ -3068,6 +3078,36 @@ private List<Integer> createColumnValueAndRowIdChannels(List<Symbol> outputSymbo return Arrays.asList(columnValueAndRowIdChannels); } + @Override + public PhysicalOperation visitTableExecute(TableExecuteNode node, LocalExecutionPlanContext context) + { + // Set table writer count + context.setDriverInstanceCount(getTaskWriterCount(session)); + + PhysicalOperation source = node.getSource().accept(this, context); + + ImmutableMap.Builder<Symbol, Integer> outputMapping = ImmutableMap.builder(); + outputMapping.put(node.getOutputSymbols().get(0), ROW_COUNT_CHANNEL); + outputMapping.put(node.getOutputSymbols().get(1), FRAGMENT_CHANNEL); + + List<Integer> inputChannels = node.getColumns().stream() + .map(source::symbolToChannel) + .collect(toImmutableList()); + + OperatorFactory operatorFactory = new TableWriterOperatorFactory( + context.getNextOperatorId(), + node.getId(), + pageSinkManager, + node.getTarget(), + inputChannels, + nCopies(inputChannels.size(), null), // N x null means no not-null checking will be performed. This is ok as in TableExecute flow we are not changing any table data. + session, + new DevNullOperatorFactory(context.getNextOperatorId(), node.getId()), // statistics are not calculated + getSymbolTypes(node.getOutputSymbols(), context.getTypes())); + + return new PhysicalOperation(operatorFactory, outputMapping.build(), context, source); + } + @Override public PhysicalOperation visitTableDelete(TableDeleteNode node, LocalExecutionPlanContext context) { @@ -3586,7 +3626,7 @@ private static List<Type> getTypes(List<Expression> expressions, Map<NodeRef<Exp private static TableFinisher createTableFinisher(Session session, TableFinishNode node, Metadata metadata) { WriterTarget target = node.getTarget(); - return (fragments, statistics) -> { + return (fragments, statistics, tableExecuteContext) -> { if (target instanceof CreateTarget) { return metadata.finishCreateTable(session, ((CreateTarget) target).getHandle(), fragments, statistics); } @@ -3611,12 +3651,23 @@ else if (target instanceof UpdateTarget) { metadata.finishUpdate(session, ((UpdateTarget) target).getHandleOrElseThrow(), fragments); return Optional.empty(); } + else if (target instanceof TableExecuteTarget) { + TableExecuteHandle tableExecuteHandle = ((TableExecuteTarget) target).getExecuteHandle(); + metadata.finishTableExecute(session, tableExecuteHandle, fragments, tableExecuteContext.getSplitsInfo()); + return Optional.empty(); + } else { throw new AssertionError("Unhandled target type: " + target.getClass().getName()); } }; } + private static boolean shouldOutputRowCount(TableFinishNode node) + { + WriterTarget target = node.getTarget(); + return !(target instanceof TableExecuteTarget); + } + private static Function<Page, Page> enforceLoadedLayoutProcessor(List<Symbol> expectedLayout, Map<Symbol, Integer> inputLayout) { int[] channels = expectedLayout.stream() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index b2dc1125c3b2..4144fad4c2dc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -29,6 +29,7 @@ import io.trino.metadata.NewTableLayout; import io.trino.metadata.QualifiedObjectName; import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TableExecuteHandle; import io.trino.metadata.TableHandle; import io.trino.metadata.TableMetadata; import io.trino.spi.TrinoException; @@ -51,6 +52,7 @@ import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.ExplainAnalyzeNode; +import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.LimitNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.PlanNode; @@ -58,6 +60,7 @@ import io.trino.sql.planner.plan.RefreshMaterializedViewNode; import io.trino.sql.planner.plan.StatisticAggregations; import io.trino.sql.planner.plan.StatisticsWriterNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; @@ -86,6 +89,8 @@ import io.trino.sql.tree.Row; import io.trino.sql.tree.Statement; import io.trino.sql.tree.StringLiteral; +import io.trino.sql.tree.Table; +import io.trino.sql.tree.TableExecute; import io.trino.sql.tree.Update; import io.trino.type.TypeCoercion; import io.trino.type.UnknownType; @@ -108,6 +113,7 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Streams.zip; import static io.trino.SystemSessionProperties.isCollectPlanStatisticsForAllQueries; +import static io.trino.metadata.MetadataUtil.createQualifiedObjectName; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.statistics.TableStatisticType.ROW_COUNT; import static io.trino.spi.type.BigintType.BIGINT; @@ -117,6 +123,7 @@ import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED; import static io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED; +import static io.trino.sql.planner.PlanBuilder.newPlanBuilder; import static io.trino.sql.planner.QueryPlanner.visibleFields; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; @@ -280,6 +287,9 @@ private RelationPlan planStatementWithoutOutput(Analysis analysis, Statement sta if (statement instanceof ExplainAnalyze) { return createExplainAnalyzePlan(analysis, (ExplainAnalyze) statement); } + if (statement instanceof TableExecute) { + return createTableExecutePlan(analysis, (TableExecute) statement); + } throw new TrinoException(NOT_SUPPORTED, "Unsupported statement type " + statement.getClass().getSimpleName()); } @@ -718,8 +728,17 @@ private PlanNode createOutputPlan(RelationPlan plan, Analysis analysis) private RelationPlan createRelationPlan(Analysis analysis, Query query) { - return new RelationPlanner(analysis, symbolAllocator, idAllocator, buildLambdaDeclarationToSymbolMap(analysis, symbolAllocator), metadata, Optional.empty(), session, ImmutableMap.of()) - .process(query, null); + return getRelationPlanner(analysis).process(query, null); + } + + private RelationPlan createRelationPlan(Analysis analysis, Table table) + { + return getRelationPlanner(analysis).process(table, null); + } + + private RelationPlanner getRelationPlanner(Analysis analysis) + { + return new RelationPlanner(analysis, symbolAllocator, idAllocator, buildLambdaDeclarationToSymbolMap(analysis, symbolAllocator), metadata, Optional.empty(), session, ImmutableMap.of()); } private static Map<NodeRef<LambdaArgumentDeclaration>, Symbol> buildLambdaDeclarationToSymbolMap(Analysis analysis, SymbolAllocator symbolAllocator) @@ -750,6 +769,83 @@ private static Map<NodeRef<LambdaArgumentDeclaration>, Symbol> buildLambdaDeclar return result; } + private RelationPlan createTableExecutePlan(Analysis analysis, TableExecute statement) + { + Table table = statement.getTable(); + TableHandle tableHandle = analysis.getTableHandle(table); + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, table.getName()); + TableExecuteHandle executeHandle = analysis.getTableExecuteHandle().orElseThrow(); + + RelationPlan tableScanPlan = createRelationPlan(analysis, table); + PlanBuilder sourcePlanBuilder = newPlanBuilder(tableScanPlan, analysis, ImmutableMap.of(), ImmutableMap.of()); + if (statement.getWhere().isPresent()) { + SubqueryPlanner subqueryPlanner = new SubqueryPlanner(analysis, symbolAllocator, idAllocator, buildLambdaDeclarationToSymbolMap(analysis, symbolAllocator), metadata, typeCoercion, Optional.empty(), session, ImmutableMap.of()); + Expression whereExpression = statement.getWhere().get(); + sourcePlanBuilder = subqueryPlanner.handleSubqueries(sourcePlanBuilder, whereExpression, analysis.getSubqueries(statement)); + sourcePlanBuilder = sourcePlanBuilder.withNewRoot(new FilterNode(idAllocator.getNextId(), sourcePlanBuilder.getRoot(), sourcePlanBuilder.rewrite(whereExpression))); + } + + PlanNode sourcePlanRoot = sourcePlanBuilder.getRoot(); + + TableMetadata tableMetadata = metadata.getTableMetadata(session, tableHandle); + List<String> columnNames = tableMetadata.getColumns().stream() + .filter(column -> !column.isHidden()) // todo this filter is redundant + .map(ColumnMetadata::getName) + .collect(toImmutableList()); + + TableWriterNode.TableExecuteTarget tableExecuteTarget = new TableWriterNode.TableExecuteTarget(executeHandle, Optional.empty(), tableName.asSchemaTableName()); + + Optional<NewTableLayout> layout = metadata.getLayoutForTableExecute(session, executeHandle); + + List<Symbol> symbols = visibleFields(tableScanPlan); + + // todo extract common method to be used here and in createTableWriterPlan() + Optional<PartitioningScheme> partitioningScheme = Optional.empty(); + Optional<PartitioningScheme> preferredPartitioningScheme = Optional.empty(); + if (layout.isPresent()) { + List<Symbol> partitionFunctionArguments = new ArrayList<>(); + layout.get().getPartitionColumns().stream() + .mapToInt(columnNames::indexOf) + .mapToObj(symbols::get) + .forEach(partitionFunctionArguments::add); + + List<Symbol> outputLayout = new ArrayList<>(symbols); + + Optional<PartitioningHandle> partitioningHandle = layout.get().getPartitioning(); + if (partitioningHandle.isPresent()) { + partitioningScheme = Optional.of(new PartitioningScheme( + Partitioning.create(partitioningHandle.get(), partitionFunctionArguments), + outputLayout)); + } + else { + // empty connector partitioning handle means evenly partitioning on partitioning columns + preferredPartitioningScheme = Optional.of(new PartitioningScheme( + Partitioning.create(FIXED_HASH_DISTRIBUTION, partitionFunctionArguments), + outputLayout)); + } + } + + verify(columnNames.size() == symbols.size(), "columnNames.size() != symbols.size(): %s and %s", columnNames, symbols); + TableFinishNode commitNode = new TableFinishNode( + idAllocator.getNextId(), + new TableExecuteNode( + idAllocator.getNextId(), + sourcePlanRoot, + tableExecuteTarget, + symbolAllocator.newSymbol("partialrows", BIGINT), + symbolAllocator.newSymbol("fragment", VARBINARY), + symbols, + columnNames, + partitioningScheme, + preferredPartitioningScheme), + tableExecuteTarget, + symbolAllocator.newSymbol("rows", BIGINT), + Optional.empty(), + Optional.empty()); + + return new RelationPlan(commitNode, analysis.getRootScope(), commitNode.getOutputSymbols(), Optional.empty()); + } + private static class Key { private final LambdaArgumentDeclaration argument; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index b149e1c26e7b..9872876cb75b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -48,6 +48,7 @@ import io.trino.sql.planner.iterative.rule.DesugarLike; import io.trino.sql.planner.iterative.rule.DesugarTryExpression; import io.trino.sql.planner.iterative.rule.DetermineJoinDistributionType; +import io.trino.sql.planner.iterative.rule.DeterminePreferredTableExecutePartitioning; import io.trino.sql.planner.iterative.rule.DeterminePreferredWritePartitioning; import io.trino.sql.planner.iterative.rule.DetermineSemiJoinDistributionType; import io.trino.sql.planner.iterative.rule.DetermineTableScanNodePartitioning; @@ -786,7 +787,9 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new DeterminePreferredWritePartitioning())), + ImmutableSet.of( + new DeterminePreferredWritePartitioning(), + new DeterminePreferredTableExecutePartitioning())), // Because ReorderJoins runs only once, // PredicatePushDown, columnPruningOptimizer and RemoveRedundantIdentityProjections // need to run beforehand in order to produce an optimal join order diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DeterminePreferredTableExecutePartitioning.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DeterminePreferredTableExecutePartitioning.java new file mode 100644 index 000000000000..1023aa4a4f8d --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DeterminePreferredTableExecutePartitioning.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.iterative.rule; + +import io.trino.Session; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.planner.plan.TableExecuteNode; + +import java.util.Optional; + +import static io.trino.SystemSessionProperties.getPreferredWritePartitioningMinNumberOfPartitions; +import static io.trino.SystemSessionProperties.isUsePreferredWritePartitioning; +import static io.trino.cost.AggregationStatsRule.getRowsCount; +import static io.trino.sql.planner.plan.Patterns.tableExecute; +import static java.lang.Double.isNaN; + +/** + * Replaces {@link TableExecuteNode} with {@link TableExecuteNode#getPreferredPartitioningScheme()} + * with a {@link TableExecuteNode} with {@link TableExecuteNode#getPartitioningScheme()} set. + */ +public class DeterminePreferredTableExecutePartitioning + implements Rule<TableExecuteNode> +{ + public static final Pattern<TableExecuteNode> TABLE_EXECUTE_NODE_WITH_PREFERRED_PARTITIONING = tableExecute() + .matching(node -> node.getPreferredPartitioningScheme().isPresent()); + + @Override + public Pattern<TableExecuteNode> getPattern() + { + return TABLE_EXECUTE_NODE_WITH_PREFERRED_PARTITIONING; + } + + @Override + public boolean isEnabled(Session session) + { + return isUsePreferredWritePartitioning(session); + } + + @Override + public Result apply(TableExecuteNode node, Captures captures, Context context) + { + int minimumNumberOfPartitions = getPreferredWritePartitioningMinNumberOfPartitions(context.getSession()); + if (minimumNumberOfPartitions <= 1) { + // Force 'preferred write partitioning' even if stats are missing or broken + return enable(node); + } + + double expectedNumberOfPartitions = getRowsCount( + context.getStatsProvider().getStats(node.getSource()), + node.getPreferredPartitioningScheme().get().getPartitioning().getColumns()); + + if (isNaN(expectedNumberOfPartitions) || expectedNumberOfPartitions < minimumNumberOfPartitions) { + return Result.empty(); + } + + return enable(node); + } + + private static Result enable(TableExecuteNode node) + { + return Result.ofPlanNode(new TableExecuteNode( + node.getId(), + node.getSource(), + node.getTarget(), + node.getRowCountSymbol(), + node.getFragmentSymbol(), + node.getColumns(), + node.getColumnNames(), + node.getPreferredPartitioningScheme(), + Optional.empty())); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java index 90ca926d8dd8..1673388f0066 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java @@ -66,6 +66,7 @@ import io.trino.sql.planner.plan.SpatialJoinNode; import io.trino.sql.planner.plan.StatisticsWriterNode; import io.trino.sql.planner.plan.TableDeleteNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; @@ -594,28 +595,38 @@ public PlanWithProperties visitRefreshMaterializedView(RefreshMaterializedViewNo @Override public PlanWithProperties visitTableWriter(TableWriterNode node, PreferredProperties preferredProperties) { - PlanWithProperties source = node.getSource().accept(this, preferredProperties); + return visitTableWriter(node, node.getPartitioningScheme(), node.getSource(), preferredProperties); + } + + @Override + public PlanWithProperties visitTableExecute(TableExecuteNode node, PreferredProperties preferredProperties) + { + return visitTableWriter(node, node.getPartitioningScheme(), node.getSource(), preferredProperties); + } + + private PlanWithProperties visitTableWriter(PlanNode node, Optional<PartitioningScheme> partitioningScheme, PlanNode source, PreferredProperties preferredProperties) + { + PlanWithProperties newSource = source.accept(this, preferredProperties); - Optional<PartitioningScheme> partitioningScheme = node.getPartitioningScheme(); if (partitioningScheme.isEmpty()) { if (scaleWriters) { - partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), source.getNode().getOutputSymbols())); + partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(SCALED_WRITER_DISTRIBUTION, ImmutableList.of()), newSource.getNode().getOutputSymbols())); } else if (redistributeWrites) { - partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), source.getNode().getOutputSymbols())); + partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), newSource.getNode().getOutputSymbols())); } } - if (partitioningScheme.isPresent() && !source.getProperties().isCompatibleTablePartitioningWith(partitioningScheme.get().getPartitioning(), false, metadata, session)) { - source = withDerivedProperties( + if (partitioningScheme.isPresent() && !newSource.getProperties().isCompatibleTablePartitioningWith(partitioningScheme.get().getPartitioning(), false, metadata, session)) { + newSource = withDerivedProperties( partitionedExchange( idAllocator.getNextId(), REMOTE, - source.getNode(), + newSource.getNode(), partitioningScheme.get()), - source.getProperties()); + newSource.getProperties()); } - return rebaseAndDeriveProperties(node, source); + return rebaseAndDeriveProperties(node, newSource); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java index b138fe75303a..bf29e5df87ad 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java @@ -53,6 +53,7 @@ import io.trino.sql.planner.plan.SortNode; import io.trino.sql.planner.plan.SpatialJoinNode; import io.trino.sql.planner.plan.StatisticsWriterNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -583,20 +584,32 @@ public PlanWithProperties visitTopNRanking(TopNRankingNode node, StreamPreferred } // - // Table Writer + // Table Writer and Table Execute // @Override public PlanWithProperties visitTableWriter(TableWriterNode node, StreamPreferredProperties parentPreferences) + { + return visitTableWriter(node, node.getPartitioningScheme(), node.getSource(), parentPreferences); + } + + @Override + public PlanWithProperties visitTableExecute(TableExecuteNode node, StreamPreferredProperties parentPreferences) + { + return visitTableWriter(node, node.getPartitioningScheme(), node.getSource(), parentPreferences); + } + + private PlanWithProperties visitTableWriter(PlanNode node, Optional<PartitioningScheme> partitioningSchemeOptional, PlanNode source, StreamPreferredProperties parentPreferences) { if (getTaskWriterCount(session) == 1) { return planAndEnforceChildren(node, singleStream(), defaultParallelism(session)); } - if (node.getPartitioningScheme().isEmpty()) { + if (partitioningSchemeOptional.isEmpty()) { return planAndEnforceChildren(node, fixedParallelism(), fixedParallelism()); } - PartitioningScheme partitioningScheme = node.getPartitioningScheme().get(); + PartitioningScheme partitioningScheme = partitioningSchemeOptional.get(); + if (partitioningScheme.getPartitioning().getHandle().equals(FIXED_HASH_DISTRIBUTION)) { // arbitrary hash function on predefined set of partition columns StreamPreferredProperties preference = partitionedOn(partitioningScheme.getPartitioning().getColumns()); @@ -608,14 +621,14 @@ public PlanWithProperties visitTableWriter(TableWriterNode node, StreamPreferred verify( partitioningScheme.getPartitioning().getArguments().stream().noneMatch(Partitioning.ArgumentBinding::isConstant), "Table writer partitioning has constant arguments"); - PlanWithProperties source = node.getSource().accept(this, parentPreferences); + PlanWithProperties newSource = source.accept(this, parentPreferences); PlanWithProperties exchange = deriveProperties( partitionedExchange( idAllocator.getNextId(), LOCAL, - source.getNode(), - node.getPartitioningScheme().get()), - source.getProperties()); + newSource.getNode(), + partitioningScheme), + newSource.getProperties()); return rebaseAndDeriveProperties(node, ImmutableList.of(exchange)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java index e0dfa7a820a5..1efaac34fedc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java @@ -17,7 +17,9 @@ import io.trino.cost.StatsAndCosts; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; +import io.trino.metadata.TableExecuteHandle; import io.trino.metadata.TableHandle; +import io.trino.spi.connector.BeginTableExecuteResult; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.TypeProvider; @@ -33,6 +35,7 @@ import io.trino.sql.planner.plan.SimplePlanRewriter; import io.trino.sql.planner.plan.SimplePlanRewriter.RewriteContext; import io.trino.sql.planner.plan.StatisticsWriterNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; @@ -41,11 +44,13 @@ import io.trino.sql.planner.plan.TableWriterNode.DeleteTarget; import io.trino.sql.planner.plan.TableWriterNode.InsertReference; import io.trino.sql.planner.plan.TableWriterNode.InsertTarget; +import io.trino.sql.planner.plan.TableWriterNode.TableExecuteTarget; import io.trino.sql.planner.plan.TableWriterNode.UpdateTarget; import io.trino.sql.planner.plan.TableWriterNode.WriterTarget; import io.trino.sql.planner.plan.UnionNode; import io.trino.sql.planner.plan.UpdateNode; +import java.util.List; import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; @@ -149,6 +154,22 @@ public PlanNode visitUpdate(UpdateNode node, RewriteContext<Optional<WriterTarge node.getOutputSymbols()); } + @Override + public PlanNode visitTableExecute(TableExecuteNode node, RewriteContext<Optional<WriterTarget>> context) + { + TableExecuteTarget tableExecuteTarget = (TableExecuteTarget) getContextTarget(context); + return new TableExecuteNode( + node.getId(), + rewriteModifyTableScan(node.getSource(), tableExecuteTarget.getSourceHandle().orElseThrow()), + tableExecuteTarget, + node.getRowCountSymbol(), + node.getFragmentSymbol(), + node.getColumns(), + node.getColumnNames(), + node.getPartitioningScheme(), + node.getPreferredPartitioningScheme()); + } + @Override public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteContext<Optional<WriterTarget>> context) { @@ -207,6 +228,13 @@ public WriterTarget getWriterTarget(PlanNode node) update.getUpdatedColumns(), update.getUpdatedColumnHandles()); } + if (node instanceof TableExecuteNode) { + TableExecuteTarget target = ((TableExecuteNode) node).getTarget(); + return new TableExecuteTarget( + target.getExecuteHandle(), + findTableScanHandleForTableExecute(((TableExecuteNode) node).getSource()), + target.getSchemaTableName()); + } if (node instanceof ExchangeNode || node instanceof UnionNode) { Set<WriterTarget> writerTargets = node.getSources().stream() .map(this::getWriterTarget) @@ -250,6 +278,12 @@ private WriterTarget createWriterTarget(WriterTarget target) metadata.getTableMetadata(session, refreshMV.getStorageTableHandle()).getTable(), refreshMV.getSourceTableHandles()); } + if (target instanceof TableExecuteTarget) { + TableExecuteTarget tableExecute = (TableExecuteTarget) target; + BeginTableExecuteResult<TableExecuteHandle, TableHandle> result = metadata.beginTableExecute(session, tableExecute.getExecuteHandle(), tableExecute.getMandatorySourceHandle()); + + return new TableExecuteTarget(result.getTableExecuteHandle(), Optional.of(result.getSourceHandle()), tableExecute.getSchemaTableName()); + } throw new IllegalArgumentException("Unhandled target type: " + target.getClass().getSimpleName()); } @@ -282,6 +316,18 @@ private TableHandle findTableScanHandleForDeleteOrUpdate(PlanNode node) throw new IllegalArgumentException("Invalid descendant for DeleteNode or UpdateNode: " + node.getClass().getName()); } + private Optional<TableHandle> findTableScanHandleForTableExecute(PlanNode startNode) + { + List<PlanNode> tableScanNodes = PlanNodeSearcher.searchFrom(startNode) + .where(node -> node instanceof TableScanNode && ((TableScanNode) node).isUpdateTarget()) + .findAll(); + + if (tableScanNodes.size() == 1) { + return Optional.of(((TableScanNode) tableScanNodes.get(0)).getTable()); + } + throw new IllegalArgumentException("Expected to find exactly one update target TableScanNode in plan but found: " + tableScanNodes); + } + private PlanNode rewriteModifyTableScan(PlanNode node, TableHandle handle) { AtomicInteger modifyCount = new AtomicInteger(0); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java index 1ee622230347..5501ecfbe3a2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java @@ -68,6 +68,7 @@ import io.trino.sql.planner.plan.SpatialJoinNode; import io.trino.sql.planner.plan.StatisticsWriterNode; import io.trino.sql.planner.plan.TableDeleteNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; @@ -471,6 +472,21 @@ public ActualProperties visitUpdate(UpdateNode node, List<ActualProperties> inpu return Iterables.getOnlyElement(inputProperties).translate(symbol -> Optional.empty()); } + @Override + public ActualProperties visitTableExecute(TableExecuteNode node, List<ActualProperties> inputProperties) + { + ActualProperties properties = Iterables.getOnlyElement(inputProperties); + + if (properties.isCoordinatorOnly()) { + return ActualProperties.builder() + .global(coordinatorSingleStreamPartition()) + .build(); + } + return ActualProperties.builder() + .global(properties.isSingleNode() ? singleStreamPartition() : arbitraryPartition()) + .build(); + } + @Override public ActualProperties visitJoin(JoinNode node, List<ActualProperties> inputProperties) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PruneUnreferencedOutputs.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PruneUnreferencedOutputs.java index 8436bcaaa9da..4942bd2afc48 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -63,6 +63,7 @@ import io.trino.sql.planner.plan.SpatialJoinNode; import io.trino.sql.planner.plan.StatisticAggregations; import io.trino.sql.planner.plan.StatisticsWriterNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; @@ -727,6 +728,29 @@ public PlanNode visitUpdate(UpdateNode node, RewriteContext<Set<Symbol>> context return new UpdateNode(node.getId(), source, node.getTarget(), node.getRowId(), node.getColumnValueAndRowIdSymbols(), node.getOutputSymbols()); } + @Override + public PlanNode visitTableExecute(TableExecuteNode node, RewriteContext<Set<Symbol>> context) + { + ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder() + .addAll(node.getColumns()); + if (node.getPartitioningScheme().isPresent()) { + PartitioningScheme partitioningScheme = node.getPartitioningScheme().get(); + partitioningScheme.getPartitioning().getColumns().forEach(expectedInputs::add); + partitioningScheme.getHashColumn().ifPresent(expectedInputs::add); + } + PlanNode source = context.rewrite(node.getSource(), expectedInputs.build()); + return new TableExecuteNode( + node.getId(), + source, + node.getTarget(), + node.getRowCountSymbol(), + node.getFragmentSymbol(), + node.getColumns(), + node.getColumnNames(), + node.getPartitioningScheme(), + node.getPreferredPartitioningScheme()); + } + @Override public PlanNode visitUnion(UnionNode node, RewriteContext<Set<Symbol>> context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java index 630a4f2216c2..fb9c0393d686 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java @@ -58,6 +58,7 @@ import io.trino.sql.planner.plan.SpatialJoinNode; import io.trino.sql.planner.plan.StatisticsWriterNode; import io.trino.sql.planner.plan.TableDeleteNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; @@ -446,6 +447,14 @@ public StreamProperties visitUpdate(UpdateNode node, List<StreamProperties> inpu return properties.withUnspecifiedPartitioning(); } + @Override + public StreamProperties visitTableExecute(TableExecuteNode node, List<StreamProperties> inputProperties) + { + StreamProperties properties = Iterables.getOnlyElement(inputProperties); + // table execute only outputs the row count and fragments + return properties.withUnspecifiedPartitioning(); + } + @Override public StreamProperties visitRefreshMaterializedView(RefreshMaterializedViewNode node, List<StreamProperties> inputProperties) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java index 3c0e7cd3e159..49c36b145def 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java @@ -32,6 +32,7 @@ import io.trino.sql.planner.plan.RowNumberNode; import io.trino.sql.planner.plan.StatisticAggregations; import io.trino.sql.planner.plan.StatisticsWriterNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -382,6 +383,26 @@ public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId new node.getStatisticsAggregationDescriptor().map(descriptor -> descriptor.map(this::map))); } + public TableExecuteNode map(TableExecuteNode node, PlanNode source) + { + return map(node, source, node.getId()); + } + + public TableExecuteNode map(TableExecuteNode node, PlanNode source, PlanNodeId newId) + { + // Intentionally does not use mapAndDistinct on columns as that would remove columns + return new TableExecuteNode( + newId, + source, + node.getTarget(), + map(node.getRowCountSymbol()), + map(node.getFragmentSymbol()), + map(node.getColumns()), + node.getColumnNames(), + node.getPartitioningScheme().map(partitioningScheme -> map(partitioningScheme, source.getOutputSymbols())), + node.getPreferredPartitioningScheme().map(partitioningScheme -> map(partitioningScheme, source.getOutputSymbols()))); + } + public PartitioningScheme map(PartitioningScheme scheme, List<Symbol> sourceLayout) { return new PartitioningScheme( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index 90c5ea8edb6f..e34fb34263fb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -67,6 +67,7 @@ import io.trino.sql.planner.plan.SpatialJoinNode; import io.trino.sql.planner.plan.StatisticsWriterNode; import io.trino.sql.planner.plan.TableDeleteNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; @@ -608,6 +609,18 @@ public PlanAndMappings visitUpdate(UpdateNode node, UnaliasContext context) mapping); } + @Override + public PlanAndMappings visitTableExecute(TableExecuteNode node, UnaliasContext context) + { + PlanAndMappings rewrittenSource = node.getSource().accept(this, context); + Map<Symbol, Symbol> mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); + + TableExecuteNode rewrittenTableExecute = mapper.map(node, rewrittenSource.getRoot()); + + return new PlanAndMappings(rewrittenTableExecute, mapping); + } + @Override public PlanAndMappings visitStatisticsWriterNode(StatisticsWriterNode node, UnaliasContext context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java index 119e7ca930aa..a1a2af14f6e9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/Patterns.java @@ -64,6 +64,11 @@ public static Pattern<UpdateNode> update() return typeOf(UpdateNode.class); } + public static Pattern<TableExecuteNode> tableExecute() + { + return typeOf(TableExecuteNode.class); + } + public static Pattern<ExchangeNode> exchange() { return typeOf(ExchangeNode.class); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java index cc9fb46fd44f..779ced97f7b9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java @@ -51,6 +51,7 @@ @JsonSubTypes.Type(value = TableWriterNode.class, name = "tablewriter"), @JsonSubTypes.Type(value = DeleteNode.class, name = "delete"), @JsonSubTypes.Type(value = UpdateNode.class, name = "update"), + @JsonSubTypes.Type(value = TableExecuteNode.class, name = "tableExecute"), @JsonSubTypes.Type(value = TableDeleteNode.class, name = "tableDelete"), @JsonSubTypes.Type(value = TableFinishNode.class, name = "tablecommit"), @JsonSubTypes.Type(value = UnnestNode.class, name = "unnest"), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java index 277cdc5d9ce6..742f8eee4b7f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java @@ -139,6 +139,11 @@ public R visitUpdate(UpdateNode node, C context) return visitPlan(node, context); } + public R visitTableExecute(TableExecuteNode node, C context) + { + return visitPlan(node, context); + } + public R visitTableDelete(TableDeleteNode node, C context) { return visitPlan(node, context); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableExecuteNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableExecuteNode.java new file mode 100644 index 000000000000..f6b4877deb17 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableExecuteNode.java @@ -0,0 +1,160 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.plan; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import io.trino.sql.planner.PartitioningScheme; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.TableWriterNode.TableExecuteTarget; + +import javax.annotation.concurrent.Immutable; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +@Immutable +public class TableExecuteNode + extends PlanNode +{ + private final PlanNode source; + private final TableExecuteTarget target; + private final Symbol rowCountSymbol; + private final Symbol fragmentSymbol; + private final List<Symbol> columns; + private final List<String> columnNames; + private final Optional<PartitioningScheme> partitioningScheme; + private final Optional<PartitioningScheme> preferredPartitioningScheme; + private final List<Symbol> outputs; + + @JsonCreator + public TableExecuteNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("source") PlanNode source, + @JsonProperty("target") TableExecuteTarget target, + @JsonProperty("rowCountSymbol") Symbol rowCountSymbol, + @JsonProperty("fragmentSymbol") Symbol fragmentSymbol, + @JsonProperty("columns") List<Symbol> columns, + @JsonProperty("columnNames") List<String> columnNames, + @JsonProperty("partitioningScheme") Optional<PartitioningScheme> partitioningScheme, + @JsonProperty("preferredPartitioningScheme") Optional<PartitioningScheme> preferredPartitioningScheme) + { + super(id); + + requireNonNull(columns, "columns is null"); + requireNonNull(columnNames, "columnNames is null"); + checkArgument(columns.size() == columnNames.size(), "columns and columnNames sizes don't match"); + + this.source = requireNonNull(source, "source is null"); + this.target = requireNonNull(target, "target is null"); + this.rowCountSymbol = requireNonNull(rowCountSymbol, "rowCountSymbol is null"); + this.fragmentSymbol = requireNonNull(fragmentSymbol, "fragmentSymbol is null"); + this.columns = ImmutableList.copyOf(columns); + this.columnNames = ImmutableList.copyOf(columnNames); + this.partitioningScheme = requireNonNull(partitioningScheme, "partitioningScheme is null"); + this.preferredPartitioningScheme = requireNonNull(preferredPartitioningScheme, "preferredPartitioningScheme is null"); + checkArgument(partitioningScheme.isEmpty() || preferredPartitioningScheme.isEmpty(), "Both partitioningScheme and preferredPartitioningScheme cannot be present"); + + ImmutableList.Builder<Symbol> outputs = ImmutableList.<Symbol>builder() + .add(rowCountSymbol) + .add(fragmentSymbol); + this.outputs = outputs.build(); + } + + @JsonProperty + public PlanNode getSource() + { + return source; + } + + @JsonProperty + public TableExecuteTarget getTarget() + { + return target; + } + + @JsonProperty + public Symbol getRowCountSymbol() + { + return rowCountSymbol; + } + + @JsonProperty + public Symbol getFragmentSymbol() + { + return fragmentSymbol; + } + + @JsonProperty + public List<Symbol> getColumns() + { + return columns; + } + + @JsonProperty + public List<String> getColumnNames() + { + return columnNames; + } + + @JsonProperty + public Optional<PartitioningScheme> getPartitioningScheme() + { + return partitioningScheme; + } + + @JsonProperty + public Optional<PartitioningScheme> getPreferredPartitioningScheme() + { + return preferredPartitioningScheme; + } + + @Override + public List<PlanNode> getSources() + { + return ImmutableList.of(source); + } + + @Override + public List<Symbol> getOutputSymbols() + { + return outputs; + } + + @Override + public <R, C> R accept(PlanVisitor<R, C> visitor, C context) + { + return visitor.visitTableExecute(this, context); + } + + @Override + public PlanNode replaceChildren(List<PlanNode> newChildren) + { + return new TableExecuteNode( + getId(), + Iterables.getOnlyElement(newChildren), + target, + rowCountSymbol, + fragmentSymbol, + columns, + columnNames, + partitioningScheme, + preferredPartitioningScheme); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableScanNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableScanNode.java index 3717e8599e4c..cc6436e69140 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableScanNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableScanNode.java @@ -253,4 +253,17 @@ public TableScanNode withUseConnectorNodePartitioning(boolean useConnectorNodePa updateTarget, Optional.of(useConnectorNodePartitioning)); } + + public TableScanNode withTableHandle(TableHandle table) + { + return new TableScanNode( + getId(), + table, + outputSymbols, + assignments, + enforcedConstraint, + statistics, + updateTarget, + useConnectorNodePartitioning); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java index 5c3d0255d6c2..9eed057443b6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableWriterNode.java @@ -25,6 +25,7 @@ import io.trino.metadata.NewTableLayout; import io.trino.metadata.OutputTableHandle; import io.trino.metadata.QualifiedObjectName; +import io.trino.metadata.TableExecuteHandle; import io.trino.metadata.TableHandle; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorTableMetadata; @@ -199,7 +200,9 @@ public PlanNode replaceChildren(List<PlanNode> newChildren) @JsonSubTypes.Type(value = InsertTarget.class, name = "InsertTarget"), @JsonSubTypes.Type(value = DeleteTarget.class, name = "DeleteTarget"), @JsonSubTypes.Type(value = UpdateTarget.class, name = "UpdateTarget"), - @JsonSubTypes.Type(value = RefreshMaterializedViewTarget.class, name = "RefreshMaterializedViewTarget")}) + @JsonSubTypes.Type(value = RefreshMaterializedViewTarget.class, name = "RefreshMaterializedViewTarget"), + @JsonSubTypes.Type(value = TableExecuteTarget.class, name = "TableExecuteTarget"), + }) @SuppressWarnings({"EmptyClass", "ClassMayBeInterface"}) public abstract static class WriterTarget { @@ -528,4 +531,52 @@ public String toString() return handle.map(Object::toString).orElse("[]"); } } + + public static class TableExecuteTarget + extends WriterTarget + { + private final TableExecuteHandle executeHandle; + private final Optional<TableHandle> sourceHandle; + private final SchemaTableName schemaTableName; + + @JsonCreator + public TableExecuteTarget( + @JsonProperty("executeHandle") TableExecuteHandle executeHandle, + @JsonProperty("sourceHandle") Optional<TableHandle> sourceHandle, + @JsonProperty("schemaTableName") SchemaTableName schemaTableName) + { + this.executeHandle = requireNonNull(executeHandle, "handle is null"); + this.sourceHandle = requireNonNull(sourceHandle, "sourceHandle is null"); + this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); + } + + @JsonProperty + public TableExecuteHandle getExecuteHandle() + { + return executeHandle; + } + + @JsonProperty + public Optional<TableHandle> getSourceHandle() + { + return sourceHandle; + } + + public TableHandle getMandatorySourceHandle() + { + return sourceHandle.orElseThrow(); + } + + @JsonProperty + public SchemaTableName getSchemaTableName() + { + return schemaTableName; + } + + @Override + public String toString() + { + return executeHandle.toString(); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index acee7fa8053e..dd3b2bf83ba1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -91,6 +91,7 @@ import io.trino.sql.planner.plan.StatisticAggregationsDescriptor; import io.trino.sql.planner.plan.StatisticsWriterNode; import io.trino.sql.planner.plan.TableDeleteNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; @@ -1347,6 +1348,19 @@ public Void visitUpdate(UpdateNode node, Void context) return processChildren(node, context); } + @Override + public Void visitTableExecute(TableExecuteNode node, Void context) + { + NodeRepresentation nodeOutput = addNode(node, "TableExecute"); + for (int i = 0; i < node.getColumnNames().size(); i++) { + String name = node.getColumnNames().get(i); + Symbol symbol = node.getColumns().get(i); + nodeOutput.appendDetailsLine("%s := %s", name, symbol); + } + + return processChildren(node, context); + } + @Override public Void visitTableDelete(TableDeleteNode node, Void context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java index 756975ef603e..44b5cd00374b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/PlanSanityChecker.java @@ -60,7 +60,9 @@ public PlanSanityChecker(boolean forceSingleNode) new ValidateStreamingAggregations(), new ValidateLimitWithPresortedInput(), new DynamicFiltersChecker(), - new TableScanValidator()) + new TableScanValidator(), + new TableExecuteStructureValidator()) + .build(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TableExecuteStructureValidator.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TableExecuteStructureValidator.java new file mode 100644 index 000000000000..a7bc9213a3e9 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/TableExecuteStructureValidator.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.sanity; + +import io.trino.Session; +import io.trino.execution.warnings.WarningCollector; +import io.trino.metadata.Metadata; +import io.trino.spi.type.TypeOperators; +import io.trino.sql.planner.TypeAnalyzer; +import io.trino.sql.planner.TypeProvider; +import io.trino.sql.planner.plan.ExchangeNode; +import io.trino.sql.planner.plan.OutputNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.ProjectNode; +import io.trino.sql.planner.plan.TableExecuteNode; +import io.trino.sql.planner.plan.TableFinishNode; +import io.trino.sql.planner.plan.TableScanNode; + +import java.util.Optional; + +import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; + +public class TableExecuteStructureValidator + implements PlanSanityChecker.Checker +{ + @Override + public void validate(PlanNode planNode, Session session, Metadata metadata, TypeOperators typeOperators, TypeAnalyzer typeAnalyzer, TypeProvider types, WarningCollector warningCollector) + { + Optional<PlanNode> tableExecuteNode = searchFrom(planNode) + .where(node -> node instanceof TableExecuteNode) + .findFirst(); + + if (tableExecuteNode.isEmpty()) { + // we are good; not a TableExecute plan + return; + } + + searchFrom(planNode) + .findAll() + .forEach(node -> { + if (!isAllowedNode(node)) { + throw new IllegalStateException("Unexpected " + node.getClass().getSimpleName() + " found in plan; probably connector was not able to handle provided WHERE expression"); + } + }); + } + + private boolean isAllowedNode(PlanNode node) + { + return node instanceof TableScanNode + || node instanceof ProjectNode + || node instanceof TableExecuteNode + || node instanceof OutputNode + || node instanceof ExchangeNode + || node instanceof TableFinishNode; + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java index 978aa7e2cfe5..0a957d221032 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java @@ -61,6 +61,7 @@ import io.trino.sql.planner.plan.StatisticAggregationsDescriptor; import io.trino.sql.planner.plan.StatisticsWriterNode; import io.trino.sql.planner.plan.TableDeleteNode; +import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; @@ -649,6 +650,15 @@ public Void visitUpdate(UpdateNode node, Set<Symbol> boundSymbols) return null; } + @Override + public Void visitTableExecute(TableExecuteNode node, Set<Symbol> boundSymbols) + { + PlanNode source = node.getSource(); + source.accept(this, boundSymbols); // visit child + + return null; + } + @Override public Void visitTableDelete(TableDeleteNode node, Set<Symbol> boundSymbols) { diff --git a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java index 47f95425d275..cf94ac9cf41a 100644 --- a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java @@ -78,6 +78,7 @@ import io.trino.execution.SetSessionTask; import io.trino.execution.SetTimeZoneTask; import io.trino.execution.StartTransactionTask; +import io.trino.execution.TableExecuteContextManager; import io.trino.execution.TaskManagerConfig; import io.trino.execution.TaskSource; import io.trino.execution.resourcegroups.NoOpResourceGroupManager; @@ -105,6 +106,7 @@ import io.trino.metadata.Split; import io.trino.metadata.SqlFunction; import io.trino.metadata.TableHandle; +import io.trino.metadata.TableProceduresPropertyManager; import io.trino.metadata.TablePropertyManager; import io.trino.operator.Driver; import io.trino.operator.DriverContext; @@ -351,6 +353,7 @@ private LocalQueryRunner( new MaterializedViewPropertyManager(), new ColumnPropertyManager(), new AnalyzePropertyManager(), + new TableProceduresPropertyManager(), new DisabledSystemSecurityMetadata(), transactionManager, typeOperators, @@ -807,6 +810,8 @@ private List<Driver> createDrivers(Session session, Plan plan, OutputFactory out throw new AssertionError("Expected subplan to have no children"); } + TableExecuteContextManager tableExecuteContextManager = new TableExecuteContextManager(); + tableExecuteContextManager.registerTableExecuteContextForQuery(taskContext.getQueryContext().getQueryId()); LocalExecutionPlanner executionPlanner = new LocalExecutionPlanner( metadata, new TypeAnalyzer(sqlParser, metadata), @@ -830,7 +835,8 @@ private List<Driver> createDrivers(Session session, Plan plan, OutputFactory out new OrderingCompiler(typeOperators), new DynamicFilterConfig(), typeOperators, - blockTypeOperators); + blockTypeOperators, + tableExecuteContextManager); // plan query StageExecutionDescriptor stageExecutionDescriptor = subplan.getFragment().getStageExecutionDescriptor(); diff --git a/core/trino-main/src/main/java/io/trino/util/StatementUtils.java b/core/trino-main/src/main/java/io/trino/util/StatementUtils.java index 05ab819814ab..a9ed269ed004 100644 --- a/core/trino-main/src/main/java/io/trino/util/StatementUtils.java +++ b/core/trino-main/src/main/java/io/trino/util/StatementUtils.java @@ -74,6 +74,7 @@ import io.trino.sql.tree.ShowTables; import io.trino.sql.tree.StartTransaction; import io.trino.sql.tree.Statement; +import io.trino.sql.tree.TableExecute; import io.trino.sql.tree.Update; import io.trino.sql.tree.Use; @@ -83,6 +84,7 @@ import java.util.stream.Stream; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.spi.resourcegroups.QueryType.ALTER_TABLE_EXECUTE; import static io.trino.spi.resourcegroups.QueryType.ANALYZE; import static io.trino.spi.resourcegroups.QueryType.DATA_DEFINITION; import static io.trino.spi.resourcegroups.QueryType.DELETE; @@ -160,6 +162,7 @@ private StatementUtils() {} .put(SetTimeZone.class, DATA_DEFINITION) .put(SetViewAuthorization.class, DATA_DEFINITION) .put(StartTransaction.class, DATA_DEFINITION) + .put(TableExecute.class, ALTER_TABLE_EXECUTE) .put(Use.class, DATA_DEFINITION) .build(); diff --git a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java index cb2512656de7..5c698ee583cd 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java @@ -150,7 +150,8 @@ public static LocalExecutionPlanner createTestingPlanner() new OrderingCompiler(typeOperators), new DynamicFilterConfig(), typeOperators, - blockTypeOperators); + blockTypeOperators, + new TableExecuteContextManager()); } public static TaskInfo updateTask(SqlTask sqlTask, List<TaskSource> taskSources, OutputBuffers outputBuffers) diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java index a5d6aee9523c..da911dbd7a60 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java @@ -30,6 +30,7 @@ import io.trino.execution.RemoteTask; import io.trino.execution.SqlStageExecution; import io.trino.execution.StageId; +import io.trino.execution.TableExecuteContextManager; import io.trino.execution.TableInfo; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.failuredetector.NoOpFailureDetector; @@ -345,6 +346,7 @@ public void testNoNodes() new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks), 2, new DynamicFilterService(metadata, typeOperators, new DynamicFilterConfig()), + new TableExecuteContextManager(), () -> false); scheduler.schedule(); }).hasErrorCode(NO_NODES_AVAILABLE); @@ -420,6 +422,7 @@ public void testNewTaskScheduledWhenChildStageBufferIsUnderutilized() new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks), 500, new DynamicFilterService(metadata, typeOperators, new DynamicFilterConfig()), + new TableExecuteContextManager(), () -> false); // the queues of 3 running nodes should be full @@ -463,6 +466,7 @@ public void testNoNewTaskScheduledWhenChildStageBufferIsOverutilized() new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks), 400, new DynamicFilterService(metadata, typeOperators, new DynamicFilterConfig()), + new TableExecuteContextManager(), () -> true); // the queues of 3 running nodes should be full @@ -504,6 +508,7 @@ public void testDynamicFiltersUnblockedOnBlockedBuildSource() new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks), 2, dynamicFilterService, + new TableExecuteContextManager(), () -> true); SymbolAllocator symbolAllocator = new SymbolAllocator(); @@ -569,6 +574,7 @@ private StageScheduler getSourcePartitionedScheduler( placementPolicy, splitBatchSize, new DynamicFilterService(metadata, typeOperators, new DynamicFilterConfig()), + new TableExecuteContextManager(), () -> false); } diff --git a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java index 5238f7296b7d..aa3f0770cda5 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java +++ b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java @@ -28,6 +28,7 @@ import io.trino.spi.block.BlockEncodingSerde; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.AggregationApplicationResult; +import io.trino.spi.connector.BeginTableExecuteResult; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaTableName; import io.trino.spi.connector.ColumnHandle; @@ -132,6 +133,30 @@ public Optional<TableHandle> getTableHandleForStatisticsCollection(Session sessi throw new UnsupportedOperationException(); } + @Override + public Optional<TableExecuteHandle> getTableHandleForExecute(Session session, TableHandle tableHandle, String procedureName, Map<String, Object> executeProperties) + { + throw new UnsupportedOperationException(); + } + + @Override + public Optional<NewTableLayout> getLayoutForTableExecute(Session session, TableExecuteHandle tableExecuteHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public BeginTableExecuteResult<TableExecuteHandle, TableHandle> beginTableExecute(Session session, TableExecuteHandle tableExecuteHandle, TableHandle updatedSourceTableHandle) + { + throw new UnsupportedOperationException(); + } + + @Override + public void finishTableExecute(Session session, TableExecuteHandle handle, Collection<Slice> fragments, List<Object> tableExecuteState) + { + throw new UnsupportedOperationException(); + } + @Override public Optional<SystemTable> getSystemTable(Session session, QualifiedObjectName tableName) { @@ -816,6 +841,12 @@ public ProcedureRegistry getProcedureRegistry() throw new UnsupportedOperationException(); } + @Override + public TableProceduresRegistry getTableProcedureRegistry() + { + throw new UnsupportedOperationException(); + } + // // Blocks // @@ -866,6 +897,12 @@ public AnalyzePropertyManager getAnalyzePropertyManager() throw new UnsupportedOperationException(); } + @Override + public TableProceduresPropertyManager getTableProceduresPropertyManager() + { + throw new UnsupportedOperationException(); + } + @Override public Optional<ProjectionApplicationResult<TableHandle>> applyProjection(Session session, TableHandle table, List<ConnectorExpression> projections, Map<String, ColumnHandle> assignments) { diff --git a/core/trino-main/src/test/java/io/trino/operator/TestTableFinishOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestTableFinishOperator.java index 50b8f3d1cba9..725b15c74d64 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestTableFinishOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestTableFinishOperator.java @@ -18,6 +18,8 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.trino.Session; +import io.trino.execution.TableExecuteContext; +import io.trino.execution.TableExecuteContextManager; import io.trino.metadata.TestingFunctionResolution; import io.trino.operator.TableFinishOperator.TableFinishOperatorFactory; import io.trino.operator.TableFinishOperator.TableFinisher; @@ -92,6 +94,7 @@ public void testStatisticsAggregation() Session session = testSessionBuilder() .setSystemProperty("statistics_cpu_timer_enabled", "true") .build(); + TableExecuteContextManager tableExecuteContextManager = new TableExecuteContextManager(); TableFinishOperatorFactory operatorFactory = new TableFinishOperatorFactory( 0, new PlanNodeId("node"), @@ -103,10 +106,13 @@ public void testStatisticsAggregation() ImmutableList.of(LONG_MAX.bind(ImmutableList.of(2), Optional.empty())), true), descriptor, + tableExecuteContextManager, + true, session); DriverContext driverContext = createTaskContext(scheduledExecutor, scheduledExecutor, session) .addPipelineContext(0, true, true, false) .addDriverContext(); + tableExecuteContextManager.registerTableExecuteContextForQuery(driverContext.getPipelineContext().getTaskContext().getQueryContext().getQueryId()); TableFinishOperator operator = (TableFinishOperator) operatorFactory.createOperator(driverContext); List<Type> inputTypes = ImmutableList.of(BIGINT, VARBINARY, BIGINT); @@ -156,14 +162,16 @@ private static class TestTableFinisher private boolean finished; private Collection<Slice> fragments; private Collection<ComputedStatistics> computedStatistics; + private TableExecuteContext tableExecuteContext; @Override - public Optional<ConnectorOutputMetadata> finishTable(Collection<Slice> fragments, Collection<ComputedStatistics> computedStatistics) + public Optional<ConnectorOutputMetadata> finishTable(Collection<Slice> fragments, Collection<ComputedStatistics> computedStatistics, TableExecuteContext tableExecuteContext) { checkState(!finished, "already finished"); finished = true; this.fragments = fragments; this.computedStatistics = computedStatistics; + this.tableExecuteContext = tableExecuteContext; return Optional.empty(); } @@ -176,5 +184,10 @@ public Collection<ComputedStatistics> getComputedStatistics() { return computedStatistics; } + + public TableExecuteContext getTableExecuteContext() + { + return tableExecuteContext; + } } } diff --git a/core/trino-main/src/test/java/io/trino/split/MockSplitSource.java b/core/trino-main/src/test/java/io/trino/split/MockSplitSource.java index 83fb7587886f..b1424c612b84 100644 --- a/core/trino-main/src/test/java/io/trino/split/MockSplitSource.java +++ b/core/trino-main/src/test/java/io/trino/split/MockSplitSource.java @@ -28,6 +28,7 @@ import java.util.Collections; import java.util.List; +import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -141,6 +142,12 @@ public boolean isFinished() return splitsProduced == totalSplits && atSplitDepletion == FINISH; } + @Override + public Optional<List<Object>> getTableExecuteSplitsInfo() + { + return Optional.empty(); + } + public int getNextBatchInvocationCount() { return nextBatchInvocationCount; diff --git a/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 b/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 index 93ddf7fc7c82..d13b0e5e8ed8 100644 --- a/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 +++ b/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 @@ -71,6 +71,10 @@ statement DROP COLUMN (IF EXISTS)? column=qualifiedName #dropColumn | ALTER TABLE tableName=qualifiedName SET AUTHORIZATION principal #setTableAuthorization | ALTER TABLE tableName=qualifiedName SET PROPERTIES properties #setTableProperties + | ALTER TABLE tableName=qualifiedName + EXECUTE procedureName=identifier + ('(' (callArgument (',' callArgument)*)? ')')? + (WHERE where=booleanExpression)? #tableExecute | ANALYZE qualifiedName (WITH properties)? #analyze | CREATE (OR REPLACE)? MATERIALIZED VIEW (IF NOT EXISTS)? qualifiedName diff --git a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java index d6b2f61ea5ed..bd86f2254eab 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/SqlFormatter.java @@ -118,6 +118,7 @@ import io.trino.sql.tree.SingleColumn; import io.trino.sql.tree.StartTransaction; import io.trino.sql.tree.Table; +import io.trino.sql.tree.TableExecute; import io.trino.sql.tree.TableSubquery; import io.trino.sql.tree.TransactionAccessMode; import io.trino.sql.tree.TransactionMode; @@ -1397,6 +1398,25 @@ protected Void visitDropColumn(DropColumn node, Integer indent) return null; } + @Override + protected Void visitTableExecute(TableExecute node, Integer indent) + { + builder.append("ALTER TABLE "); + builder.append(formatName(node.getTable().getName())); + builder.append(" EXECUTE "); + builder.append(formatExpression(node.getProcedureName())); + if (!node.getArguments().isEmpty()) { + builder.append("("); + formatCallArguments(indent, node.getArguments()); + builder.append(")"); + } + node.getWhere().ifPresent(where -> + builder.append("\n") + .append(indentString(indent)) + .append("WHERE ").append(formatExpression(where))); + return null; + } + @Override protected Void visitAnalyze(Analyze node, Integer indent) { @@ -1517,18 +1537,21 @@ protected Void visitCall(Call node, Integer indent) builder.append("CALL ") .append(node.getName()) .append("("); + formatCallArguments(indent, node.getArguments()); + builder.append(")"); - Iterator<CallArgument> arguments = node.getArguments().iterator(); - while (arguments.hasNext()) { - process(arguments.next(), indent); - if (arguments.hasNext()) { + return null; + } + + private void formatCallArguments(Integer indent, List<CallArgument> arguments) + { + Iterator<CallArgument> iterator = arguments.iterator(); + while (iterator.hasNext()) { + process(iterator.next(), indent); + if (iterator.hasNext()) { builder.append(", "); } } - - builder.append(")"); - - return null; } @Override diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 41a4d6d826e4..104d50b52c17 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -207,6 +207,7 @@ import io.trino.sql.tree.SubsetDefinition; import io.trino.sql.tree.Table; import io.trino.sql.tree.TableElement; +import io.trino.sql.tree.TableExecute; import io.trino.sql.tree.TableSubquery; import io.trino.sql.tree.TimeLiteral; import io.trino.sql.tree.TimestampLiteral; @@ -655,6 +656,21 @@ public Node visitDropColumn(SqlBaseParser.DropColumnContext context) context.EXISTS().stream().anyMatch(node -> node.getSymbol().getTokenIndex() > context.COLUMN().getSymbol().getTokenIndex())); } + @Override + public Node visitTableExecute(SqlBaseParser.TableExecuteContext context) + { + List<CallArgument> arguments = ImmutableList.of(); + if (context.callArgument() != null) { + arguments = this.visit(context.callArgument(), CallArgument.class); + } + + return new TableExecute( + new Table(getLocation(context), getQualifiedName(context.tableName)), + (Identifier) visit(context.procedureName), + arguments, + visitIfPresent(context.booleanExpression(), Expression.class)); + } + @Override public Node visitCreateView(SqlBaseParser.CreateViewContext context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java index f5cd29efcd3c..4bc144bcaf73 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java @@ -667,6 +667,11 @@ protected R visitSetTableAuthorization(SetTableAuthorization node, C context) return visitStatement(node, context); } + protected R visitTableExecute(TableExecute node, C context) + { + return visitStatement(node, context); + } + protected R visitAnalyze(Analyze node, C context) { return visitStatement(node, context); diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/TableExecute.java b/core/trino-parser/src/main/java/io/trino/sql/tree/TableExecute.java new file mode 100644 index 000000000000..c776ebfd8987 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/TableExecute.java @@ -0,0 +1,134 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; +import org.checkerframework.checker.units.qual.C; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class TableExecute + extends Statement +{ + private final Table table; + private final Identifier procedureName; + private final List<CallArgument> arguments; + private final Optional<Expression> where; + + public TableExecute( + Table table, + Identifier procedureName, + List<CallArgument> properties, + Optional<Expression> where) + { + this(Optional.empty(), table, procedureName, properties, where); + } + + public TableExecute( + NodeLocation location, + Table table, + Identifier procedureName, + List<CallArgument> arguments, + Optional<Expression> where) + { + this(Optional.of(location), table, procedureName, arguments, where); + } + + private TableExecute( + Optional<NodeLocation> location, + Table table, + Identifier procedureName, + List<CallArgument> arguments, + Optional<Expression> where) + { + super(location); + this.table = requireNonNull(table, "table is null"); + this.procedureName = requireNonNull(procedureName, "procedureName is null"); + this.arguments = requireNonNull(arguments, "arguments is null"); + this.where = requireNonNull(where, "where is null"); + } + + public Table getTable() + { + return table; + } + + public Identifier getProcedureName() + { + return procedureName; + } + + public List<CallArgument> getArguments() + { + return arguments; + } + + public Optional<Expression> getWhere() + { + return where; + } + + @Override + public <R, C> R accept(AstVisitor<R, C> visitor, C context) + { + return visitor.visitTableExecute(this, context); + } + + @Override + public List<? extends Node> getChildren() + { + ImmutableList.Builder<Node> nodes = ImmutableList.builder(); + nodes.addAll(arguments); + where.ifPresent(nodes::add); + return nodes.build(); + } + + @Override + public int hashCode() + { + return Objects.hash(table, procedureName, arguments, where); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TableExecute that = (TableExecute) o; + return Objects.equals(table, that.table) && + Objects.equals(procedureName, that.procedureName) && + Objects.equals(arguments, that.arguments) && + Objects.equals(where, that.where); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("table", table) + .add("procedureNaem", procedureName) + .add("arguments", arguments) + .add("where", where) + .toString(); + } +} diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index f7aba4ecd714..c8a73bd636a7 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -172,6 +172,7 @@ import io.trino.sql.tree.SubscriptExpression; import io.trino.sql.tree.SubsetDefinition; import io.trino.sql.tree.Table; +import io.trino.sql.tree.TableExecute; import io.trino.sql.tree.TableSubquery; import io.trino.sql.tree.TimeLiteral; import io.trino.sql.tree.TimestampLiteral; @@ -1867,6 +1868,40 @@ public void testAlterViewSetAuthorization() new SetViewAuthorization(QualifiedName.of("foo", "bar", "baz"), new PrincipalSpecification(PrincipalSpecification.Type.ROLE, new Identifier("qux")))); } + @Test + public void testTableExecute() + { + Table table = new Table(QualifiedName.of("foo")); + Identifier procedure = new Identifier("bar"); + + assertStatement("ALTER TABLE foo EXECUTE bar", new TableExecute(table, procedure, ImmutableList.of(), Optional.empty())); + assertStatement( + "ALTER TABLE foo EXECUTE bar(bah => 1, wuh => 'clap') WHERE age > 17", + new TableExecute( + table, + procedure, + ImmutableList.of( + new CallArgument("bah", new LongLiteral("1")), + new CallArgument("wuh", new StringLiteral("clap"))), + Optional.of( + new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, + new Identifier("age"), + new LongLiteral("17"))))); + + assertStatement( + "ALTER TABLE foo EXECUTE bar(1, 'clap') WHERE age > 17", + new TableExecute( + table, + procedure, + ImmutableList.of( + new CallArgument(new LongLiteral("1")), + new CallArgument(new StringLiteral("clap"))), + Optional.of( + new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, + new Identifier("age"), + new LongLiteral("17"))))); + } + @Test public void testAnalyze() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/BeginTableExecuteResult.java b/core/trino-spi/src/main/java/io/trino/spi/connector/BeginTableExecuteResult.java new file mode 100644 index 000000000000..2a2fd34acc93 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/BeginTableExecuteResult.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.connector; + +import java.util.StringJoiner; + +import static java.util.Objects.requireNonNull; + +public final class BeginTableExecuteResult<E, T> +{ + /** + * Updated tableExecuteHandle + */ + private final E tableExecuteHandle; + + /** + * Updated sourceHandle + */ + private final T sourceHandle; + + public BeginTableExecuteResult(E tableExecuteHandle, T sourceHandle) + { + this.tableExecuteHandle = requireNonNull(tableExecuteHandle, "tableExecuteHandle is null"); + this.sourceHandle = requireNonNull(sourceHandle, "sourceHandle is null"); + } + + public E getTableExecuteHandle() + { + return tableExecuteHandle; + } + + public T getSourceHandle() + { + return sourceHandle; + } + + @Override + public String toString() + { + return new StringJoiner(", ", BeginTableExecuteResult.class.getSimpleName() + "[", "]") + .add("tableExecuteHandle=" + tableExecuteHandle) + .add("sourceHandle=" + sourceHandle) + .toString(); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java b/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java index 795f05f77ca5..b52b72c24f97 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/Connector.java @@ -108,6 +108,11 @@ default Set<Procedure> getProcedures() return emptySet(); } + default Set<TableProcedureMetadata> getTableProcedures() + { + return emptySet(); + } + /** * @return the system properties for this connector */ diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorAccessControl.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorAccessControl.java index 9c927826ab21..e4d07444c372 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorAccessControl.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorAccessControl.java @@ -39,6 +39,7 @@ import static io.trino.spi.security.AccessDeniedException.denyDropTable; import static io.trino.spi.security.AccessDeniedException.denyDropView; import static io.trino.spi.security.AccessDeniedException.denyExecuteProcedure; +import static io.trino.spi.security.AccessDeniedException.denyExecuteTableProcedure; import static io.trino.spi.security.AccessDeniedException.denyGrantRoles; import static io.trino.spi.security.AccessDeniedException.denyGrantSchemaPrivilege; import static io.trino.spi.security.AccessDeniedException.denyGrantTablePrivilege; @@ -556,6 +557,11 @@ default void checkCanExecuteProcedure(ConnectorSecurityContext context, SchemaRo denyExecuteProcedure(procedure.toString()); } + default void checkCanExecuteTableProcedure(ConnectorSecurityContext context, SchemaTableName tableName, String procedure) + { + denyExecuteTableProcedure(tableName.toString(), procedure); + } + /** * Get a row filter associated with the given table and identity. * <p> diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorHandleResolver.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorHandleResolver.java index 8e65748008e2..347e47a4a3ae 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorHandleResolver.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorHandleResolver.java @@ -50,6 +50,11 @@ default Class<? extends ConnectorInsertTableHandle> getInsertTableHandleClass() throw new UnsupportedOperationException(); } + default Class<? extends ConnectorTableExecuteHandle> getTableExecuteHandleClass() + { + throw new UnsupportedOperationException(); + } + default Class<? extends ConnectorPartitioningHandle> getPartitioningHandleClass() { throw new UnsupportedOperationException(); diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java index cf4463ff8b1d..16dbfd308c57 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java @@ -91,6 +91,42 @@ default ConnectorTableHandle getTableHandleForStatisticsCollection(ConnectorSess throw new TrinoException(NOT_SUPPORTED, "This connector does not support analyze"); } + /** + * Create initial handle for execution of table procedure. The handle will be used through planning process. It will be converted to final + * handle used for execution via @{link {@link ConnectorMetadata#beginTableExecute} + */ + default Optional<ConnectorTableExecuteHandle> getTableHandleForExecute( + ConnectorSession session, + ConnectorTableHandle tableHandle, + String procedureName, + Map<String, Object> executeProperties) + { + throw new TrinoException(NOT_SUPPORTED, "This connector does not support table procedures"); + } + + default Optional<ConnectorNewTableLayout> getLayoutForTableExecute(ConnectorSession session, ConnectorTableExecuteHandle tableExecuteHandle) + { + return Optional.empty(); + } + + /** + * Begin execution of table procedure + */ + default BeginTableExecuteResult<ConnectorTableExecuteHandle, ConnectorTableHandle> beginTableExecute(ConnectorSession session, ConnectorTableExecuteHandle tableExecuteHandle, ConnectorTableHandle updatedSourceTableHandle) + { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "ConnectorMetadata getTableHandleForExecute() is implemented without beginTableExecute()"); + } + + /** + * Finish table execute + * + * @param fragments all fragments returned by {@link ConnectorPageSink#finish()} + */ + default void finishTableExecute(ConnectorSession session, ConnectorTableExecuteHandle tableExecuteHandle, Collection<Slice> fragments, List<Object> tableExecuteState) + { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "ConnectorMetadata getTableHandleForExecute() is implemented without finishTableExecute()"); + } + /** * Returns the system table for the specified table name, if one exists. * The system tables handled via {@link #getSystemTable} differ form those returned by {@link Connector#getSystemTables()}. diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorNewTableLayout.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorNewTableLayout.java index 718e08ea4bcf..c57c428a7c0b 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorNewTableLayout.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorNewTableLayout.java @@ -18,6 +18,7 @@ import static java.util.Objects.requireNonNull; +// TODO ConnectorNewTableLayout is used not only for "new" tables. Rename to be less specific. Preferably to ConnectorTableLayout after https://github.com/trinodb/trino/issues/781 public class ConnectorNewTableLayout { private final Optional<ConnectorPartitioningHandle> partitioning; diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSinkProvider.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSinkProvider.java index 345b3d72cdeb..2b7363621f26 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSinkProvider.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSinkProvider.java @@ -18,4 +18,9 @@ public interface ConnectorPageSinkProvider ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorOutputTableHandle outputTableHandle); ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorInsertTableHandle insertTableHandle); + + default ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableExecuteHandle tableExecuteHandle) + { + throw new IllegalArgumentException("createPageSink not supported for tableExecuteHandle"); + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorSplitSource.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorSplitSource.java index c459e29b19fb..d585644d6420 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorSplitSource.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorSplitSource.java @@ -15,6 +15,7 @@ import java.io.Closeable; import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import static java.util.Objects.requireNonNull; @@ -37,6 +38,11 @@ public interface ConnectorSplitSource */ boolean isFinished(); + default Optional<List<Object>> getTableExecuteSplitsInfo() + { + return Optional.empty(); + } + class ConnectorSplitBatch { private final List<ConnectorSplit> splits; diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableExecuteHandle.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableExecuteHandle.java new file mode 100644 index 000000000000..19fa9043d694 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorTableExecuteHandle.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.connector; + +/** + * ConnectorTableExecuteHandle identifies instance of executing a connector provided table procedure o specific table. + * + * ConnectorTableExecuteHandle for planning is obtained by call to {@link ConnectorMetadata#getTableHandleForExecute} for give + * procedure name and table. + * + * Then after planning, just before execution start, ConnectorTableExecuteHandle is refreshed via call to + * {@link ConnectorMetadata#beginTableExecute(ConnectorSession, ConnectorTableExecuteHandle, ConnectorTableHandle)} + * The tableHandle passed to beginTableExecute is one obtained from matching TableScanNode at the of planning. + */ +public interface ConnectorTableExecuteHandle {} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/FixedSplitSource.java b/core/trino-spi/src/main/java/io/trino/spi/connector/FixedSplitSource.java index e6b6db94a404..0979a6319459 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/FixedSplitSource.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/FixedSplitSource.java @@ -14,6 +14,7 @@ package io.trino.spi.connector; import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import static io.trino.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; @@ -26,12 +27,25 @@ public class FixedSplitSource implements ConnectorSplitSource { private final List<ConnectorSplit> splits; + private final Optional<List<Object>> tableExecuteSplitsInfo; private int offset; public FixedSplitSource(Iterable<? extends ConnectorSplit> splits) + { + this(splits, Optional.empty()); + } + + public FixedSplitSource(Iterable<? extends ConnectorSplit> splits, List<Object> tableExecuteSplitsInfo) + { + this(splits, Optional.of(tableExecuteSplitsInfo)); + } + + private FixedSplitSource(Iterable<? extends ConnectorSplit> splits, Optional<List<Object>> tableExecuteSplitsInfo) { requireNonNull(splits, "splits is null"); + requireNonNull(tableExecuteSplitsInfo, "tableExecuteSplitsInfo is null"); this.splits = stream(splits.spliterator(), false).collect(toUnmodifiableList()); + this.tableExecuteSplitsInfo = requireNonNull(tableExecuteSplitsInfo, "tableExecuteSplitsInfo is null").map(List::copyOf); } @SuppressWarnings("ObjectEquality") @@ -56,6 +70,12 @@ public boolean isFinished() return offset >= splits.size(); } + @Override + public Optional<List<Object>> getTableExecuteSplitsInfo() + { + return tableExecuteSplitsInfo; + } + @Override public void close() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/TableProcedureExecutionMode.java b/core/trino-spi/src/main/java/io/trino/spi/connector/TableProcedureExecutionMode.java new file mode 100644 index 000000000000..29d4841792b7 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/TableProcedureExecutionMode.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.connector; + +public final class TableProcedureExecutionMode +{ + private final boolean readsData; + private final boolean supportsFilter; + + public TableProcedureExecutionMode(boolean readsData, boolean supportsFilter) + { + if (!readsData) { + // TODO currently only table procedures which process data are supported + // this is temporary check to be dropped when execution flow will be added for + // table procedures which do not read data + throw new IllegalArgumentException("procedures that do not read data are not supported yet"); + } + + if (!readsData) { + if (supportsFilter) { + throw new IllegalArgumentException("filtering not supported if table data is not processed"); + } + } + this.readsData = readsData; + this.supportsFilter = supportsFilter; + } + + public boolean isReadsData() + { + return readsData; + } + + public boolean supportsFilter() + { + return supportsFilter; + } + + public static TableProcedureExecutionMode coordinatorOnly() + { + return new TableProcedureExecutionMode(false, false); + } + + public static TableProcedureExecutionMode distributedWithFilteringAndRepartitioning() + { + return new TableProcedureExecutionMode(true, true); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/TableProcedureMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/TableProcedureMetadata.java new file mode 100644 index 000000000000..226c2b2ef3a6 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/TableProcedureMetadata.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.connector; + +import io.trino.spi.session.PropertyMetadata; + +import java.util.List; + +import static io.trino.spi.connector.SchemaUtil.checkNotEmpty; +import static java.util.Objects.requireNonNull; + +public class TableProcedureMetadata +{ + // Name must be uppercase if procedure is to be executed without delimitation via ALTER TABLE ... EXECUTE syntax + private final String name; + private final TableProcedureExecutionMode executionMode; + private final List<PropertyMetadata<?>> properties; + + public TableProcedureMetadata(String name, TableProcedureExecutionMode executionMode, List<PropertyMetadata<?>> properties) + { + this.name = checkNotEmpty(name, "name"); + this.executionMode = requireNonNull(executionMode, "executionMode is null"); + this.properties = List.copyOf(requireNonNull(properties, "properties is null")); + } + + public String getName() + { + return name; + } + + public TableProcedureExecutionMode getExecutionMode() + { + return executionMode; + } + + public List<PropertyMetadata<?>> getProperties() + { + return properties; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/resourcegroups/QueryType.java b/core/trino-spi/src/main/java/io/trino/spi/resourcegroups/QueryType.java index 62a2af2a6b42..7ee8ab593075 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/resourcegroups/QueryType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/resourcegroups/QueryType.java @@ -23,4 +23,5 @@ public enum QueryType DELETE, ANALYZE, DATA_DEFINITION, + ALTER_TABLE_EXECUTE, } diff --git a/core/trino-spi/src/main/java/io/trino/spi/security/AccessDeniedException.java b/core/trino-spi/src/main/java/io/trino/spi/security/AccessDeniedException.java index 2293eb46d69c..076ef6a83a13 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/security/AccessDeniedException.java +++ b/core/trino-spi/src/main/java/io/trino/spi/security/AccessDeniedException.java @@ -568,6 +568,11 @@ public static void denyExecuteFunction(String functionName) throw new AccessDeniedException(format("Cannot execute function %s", functionName)); } + public static void denyExecuteTableProcedure(String tableName, String procedureName) + { + throw new AccessDeniedException(format("Cannot execute table procedure %s on %s", procedureName, tableName)); + } + private static Object formatExtraInfo(String extraInfo) { if (extraInfo == null || extraInfo.isEmpty()) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java b/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java index 02160550a1d3..9619c0e9af06 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java +++ b/core/trino-spi/src/main/java/io/trino/spi/security/SystemAccessControl.java @@ -46,6 +46,7 @@ import static io.trino.spi.security.AccessDeniedException.denyExecuteFunction; import static io.trino.spi.security.AccessDeniedException.denyExecuteProcedure; import static io.trino.spi.security.AccessDeniedException.denyExecuteQuery; +import static io.trino.spi.security.AccessDeniedException.denyExecuteTableProcedure; import static io.trino.spi.security.AccessDeniedException.denyGrantExecuteFunctionPrivilege; import static io.trino.spi.security.AccessDeniedException.denyGrantRoles; import static io.trino.spi.security.AccessDeniedException.denyGrantSchemaPrivilege; @@ -728,6 +729,16 @@ default void checkCanExecuteFunction(SystemSecurityContext systemSecurityContext denyExecuteFunction(functionName); } + /** + * Check if identity is allowed to execute the specified table procedure on specified table + * + * @throws AccessDeniedException if not allowed + */ + default void checkCanExecuteTableProcedure(SystemSecurityContext systemSecurityContext, CatalogSchemaTableName table, String procedure) + { + denyExecuteTableProcedure(table.toString(), procedure); + } + /** * Get a row filter associated with the given table and identity. * <p> diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorAccessControl.java index 0265c3dc233e..588dd1f62264 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorAccessControl.java @@ -460,6 +460,14 @@ public void checkCanExecuteProcedure(ConnectorSecurityContext context, SchemaRou } } + @Override + public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, SchemaTableName tableName, String procedure) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + delegate.checkCanExecuteTableProcedure(context, tableName, procedure); + } + } + @Override public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java index a0c5f51786ff..63daf5333981 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java @@ -17,6 +17,7 @@ import io.trino.spi.classloader.ThreadContextClassLoader; import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.AggregationApplicationResult; +import io.trino.spi.connector.BeginTableExecuteResult; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaTableName; import io.trino.spi.connector.ColumnHandle; @@ -30,6 +31,7 @@ import io.trino.spi.connector.ConnectorPartitioningHandle; import io.trino.spi.connector.ConnectorResolvedIndex; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorTableExecuteHandle; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableLayoutHandle; @@ -216,6 +218,38 @@ public ConnectorTableHandle getTableHandleForStatisticsCollection(ConnectorSessi } } + @Override + public Optional<ConnectorTableExecuteHandle> getTableHandleForExecute(ConnectorSession session, ConnectorTableHandle tableHandle, String procedureName, Map<String, Object> executeProperties) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getTableHandleForExecute(session, tableHandle, procedureName, executeProperties); + } + } + + @Override + public Optional<ConnectorNewTableLayout> getLayoutForTableExecute(ConnectorSession session, ConnectorTableExecuteHandle tableExecuteHandle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getLayoutForTableExecute(session, tableExecuteHandle); + } + } + + @Override + public BeginTableExecuteResult<ConnectorTableExecuteHandle, ConnectorTableHandle> beginTableExecute(ConnectorSession session, ConnectorTableExecuteHandle tableExecuteHandle, ConnectorTableHandle updatedSourceTableHandle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.beginTableExecute(session, tableExecuteHandle, updatedSourceTableHandle); + } + } + + @Override + public void finishTableExecute(ConnectorSession session, ConnectorTableExecuteHandle tableExecuteHandle, Collection<Slice> fragments, List<Object> tableExecuteState) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + delegate.finishTableExecute(session, tableExecuteHandle, fragments, tableExecuteState); + } + } + @Override public Optional<SystemTable> getSystemTable(ConnectorSession session, SchemaTableName tableName) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSinkProvider.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSinkProvider.java index 4898cf8da4f1..88ab689cca07 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSinkProvider.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorPageSinkProvider.java @@ -19,6 +19,7 @@ import io.trino.spi.connector.ConnectorPageSink; import io.trino.spi.connector.ConnectorPageSinkProvider; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.ConnectorTableExecuteHandle; import io.trino.spi.connector.ConnectorTransactionHandle; import javax.inject.Inject; @@ -53,4 +54,12 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa return new ClassLoaderSafeConnectorPageSink(delegate.createPageSink(transactionHandle, session, insertTableHandle), classLoader); } } + + @Override + public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableExecuteHandle tableExecuteHandle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return new ClassLoaderSafeConnectorPageSink(delegate.createPageSink(transactionHandle, session, tableExecuteHandle), classLoader); + } + } } diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorSplitSource.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorSplitSource.java index 11479eb101e5..ca388c69ccfa 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorSplitSource.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorSplitSource.java @@ -19,6 +19,8 @@ import javax.inject.Inject; +import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import static java.util.Objects.requireNonNull; @@ -44,6 +46,14 @@ public CompletableFuture<ConnectorSplitBatch> getNextBatch(ConnectorPartitionHan } } + @Override + public Optional<List<Object>> getTableExecuteSplitsInfo() + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getTableExecuteSplitsInfo(); + } + } + @Override public void close() { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllAccessControl.java index 856ccf629315..2eac1e01c161 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllAccessControl.java @@ -295,6 +295,11 @@ public void checkCanExecuteProcedure(ConnectorSecurityContext context, SchemaRou { } + @Override + public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, SchemaTableName tableName, String procedure) + { + } + @Override public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java index d89e3c66cacc..dec2f12c14f3 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/AllowAllSystemAccessControl.java @@ -390,6 +390,11 @@ public void checkCanExecuteFunction(SystemSecurityContext systemSecurityContext, { } + @Override + public void checkCanExecuteTableProcedure(SystemSecurityContext systemSecurityContext, CatalogSchemaTableName table, String procedure) + { + } + @Override public Iterable<EventListener> getEventListeners() { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java index f1d8a46f7fc1..b8ce84b512f3 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedAccessControl.java @@ -559,6 +559,11 @@ public void checkCanExecuteProcedure(ConnectorSecurityContext context, SchemaRou { } + @Override + public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, SchemaTableName tableName, String procedure) + { + } + @Override public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java index c4d8cb760ae8..21f0b2de312a 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/FileBasedSystemAccessControl.java @@ -908,6 +908,11 @@ public void checkCanExecuteFunction(SystemSecurityContext systemSecurityContext, { } + @Override + public void checkCanExecuteTableProcedure(SystemSecurityContext systemSecurityContext, CatalogSchemaTableName table, String procedure) + { + } + @Override public Iterable<EventListener> getEventListeners() { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingConnectorAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingConnectorAccessControl.java index 680a3cce4d32..c5754a504044 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingConnectorAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingConnectorAccessControl.java @@ -361,6 +361,12 @@ public void checkCanExecuteProcedure(ConnectorSecurityContext context, SchemaRou delegate().checkCanExecuteProcedure(context, procedure); } + @Override + public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, SchemaTableName tableName, String procedure) + { + delegate().checkCanExecuteTableProcedure(context, tableName, procedure); + } + @Override public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName) { diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java index 1f2c6a229177..cb9d243b3e85 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/security/ForwardingSystemAccessControl.java @@ -429,6 +429,12 @@ public void checkCanExecuteFunction(SystemSecurityContext systemSecurityContext, delegate().checkCanExecuteFunction(systemSecurityContext, functionName); } + @Override + public void checkCanExecuteTableProcedure(SystemSecurityContext systemSecurityContext, CatalogSchemaTableName table, String procedure) + { + delegate().checkCanExecuteTableProcedure(systemSecurityContext, table, procedure); + } + @Override public Iterable<EventListener> getEventListeners() { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java index 09225332e952..ebbbc36993f1 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/LegacyAccessControl.java @@ -366,6 +366,11 @@ public void checkCanExecuteProcedure(ConnectorSecurityContext context, SchemaRou { } + @Override + public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, SchemaTableName tableName, String procedure) + { + } + @Override public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName) { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java index 4c834b5e900f..0ad97024e50a 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/security/SqlStandardAccessControl.java @@ -69,6 +69,7 @@ import static io.trino.spi.security.AccessDeniedException.denyDropSchema; import static io.trino.spi.security.AccessDeniedException.denyDropTable; import static io.trino.spi.security.AccessDeniedException.denyDropView; +import static io.trino.spi.security.AccessDeniedException.denyExecuteTableProcedure; import static io.trino.spi.security.AccessDeniedException.denyGrantRoles; import static io.trino.spi.security.AccessDeniedException.denyGrantTablePrivilege; import static io.trino.spi.security.AccessDeniedException.denyInsertTable; @@ -537,6 +538,14 @@ public void checkCanExecuteProcedure(ConnectorSecurityContext context, SchemaRou { } + @Override + public void checkCanExecuteTableProcedure(ConnectorSecurityContext context, SchemaTableName tableName, String procedure) + { + if (!isTableOwner(context, tableName)) { + denyExecuteTableProcedure(tableName.toString(), tableName.toString()); + } + } + @Override public Optional<ViewExpression> getRowFilter(ConnectorSecurityContext context, SchemaTableName tableName) {