diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashBuildAndJoinBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashBuildAndJoinBenchmark.java index fa8ea47651ba..9ccf72b42014 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashBuildAndJoinBenchmark.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashBuildAndJoinBenchmark.java @@ -77,7 +77,7 @@ protected List createDrivers(TaskContext taskContext) } // hash build - HashBuilderOperatorFactory hashBuilder = new HashBuilderOperatorFactory(2, new PlanNodeId("test"), source.getTypes(), ImmutableList.of(0, 1), ImmutableMap.of(), Ints.asList(0), hashChannel, false, Optional.empty(), 1_500_000, 1, new PagesIndex.TestingFactory()); + HashBuilderOperatorFactory hashBuilder = new HashBuilderOperatorFactory(2, new PlanNodeId("test"), source.getTypes(), ImmutableList.of(0, 1), ImmutableMap.of(), Ints.asList(0), hashChannel, false, Optional.empty(), Optional.empty(), ImmutableList.of(), 1_500_000, 1, new PagesIndex.TestingFactory()); driversBuilder.add(hashBuilder); DriverFactory hashBuildDriverFactory = new DriverFactory(0, true, false, driversBuilder.build(), OptionalInt.empty()); Driver hashBuildDriver = hashBuildDriverFactory.createDriver(taskContext.addPipelineContext(0, true, false).addDriverContext()); diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashBuildBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashBuildBenchmark.java index 42e2f68a0978..ff5173765303 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashBuildBenchmark.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashBuildBenchmark.java @@ -61,6 +61,8 @@ protected List createDrivers(TaskContext taskContext) Optional.empty(), false, Optional.empty(), + Optional.empty(), + ImmutableList.of(), 1_500_000, 1, new PagesIndex.TestingFactory()); diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashJoinBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashJoinBenchmark.java index dedbf215eda3..80c2b071fd9b 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashJoinBenchmark.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashJoinBenchmark.java @@ -67,6 +67,8 @@ protected List createDrivers(TaskContext taskContext) Optional.empty(), false, Optional.empty(), + Optional.empty(), + ImmutableList.of(), 1_500_000, 1, new PagesIndex.TestingFactory()); diff --git a/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkInequalityJoin.java b/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkInequalityJoin.java index ea57481cacb2..f299f048c282 100644 --- a/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkInequalityJoin.java +++ b/presto-benchmark/src/test/java/com/facebook/presto/benchmark/BenchmarkInequalityJoin.java @@ -107,6 +107,27 @@ public List benchmarkJoin(Context context) .execute("SELECT count(*) FROM t1 JOIN t2 on (t1.bucket = t2.bucket) WHERE t1.val1 < t2.val2"); } + @Benchmark + public List benchmarkJoinWithArithmeticInPredicate(Context context) + { + return context.getQueryRunner() + .execute("SELECT count(*) FROM t1 JOIN t2 on (t1.bucket = t2.bucket) AND t1.val1 < t2.val2 + 10"); + } + + @Benchmark + public List benchmarkJoinWithFunctionPredicate(Context context) + { + return context.getQueryRunner() + .execute("SELECT count(*) FROM t1 JOIN t2 on (t1.bucket = t2.bucket) AND t1.val1 < sin(t2.val2)"); + } + + @Benchmark + public List benchmarkRangePredicateJoin(Context context) + { + return context.getQueryRunner() + .execute("SELECT count(*) FROM t1 JOIN t2 on (t1.bucket = t2.bucket) AND t1.val1 + 1 < t2.val2 AND t2.val2 < t1.val1 + 5 "); + } + public static void main(String[] args) throws RunnerException { diff --git a/presto-benchmark/src/test/java/com/facebook/presto/benchmark/MemoryLocalQueryRunner.java b/presto-benchmark/src/test/java/com/facebook/presto/benchmark/MemoryLocalQueryRunner.java index ef308a3f8a61..4c597bee7c64 100644 --- a/presto-benchmark/src/test/java/com/facebook/presto/benchmark/MemoryLocalQueryRunner.java +++ b/presto-benchmark/src/test/java/com/facebook/presto/benchmark/MemoryLocalQueryRunner.java @@ -42,7 +42,6 @@ public class MemoryLocalQueryRunner { protected final LocalQueryRunner localQueryRunner; - protected final Session session; public MemoryLocalQueryRunner() { @@ -56,8 +55,7 @@ public MemoryLocalQueryRunner(Map properties) .setSchema("default"); properties.forEach(sessionBuilder::setSystemProperty); - session = sessionBuilder.build(); - localQueryRunner = createMemoryLocalQueryRunner(session); + localQueryRunner = createMemoryLocalQueryRunner(sessionBuilder.build()); } public List execute(@Language("SQL") String query) @@ -68,7 +66,7 @@ public List execute(@Language("SQL") String query) SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(new DataSize(1, GIGABYTE)); TaskContext taskContext = new QueryContext(new QueryId("test"), new DataSize(1, GIGABYTE), memoryPool, systemMemoryPool, localQueryRunner.getExecutor(), localQueryRunner.getScheduler(), new DataSize(4, GIGABYTE), spillSpaceTracker) .addTaskContext(new TaskStateMachine(new TaskId("query", 0, 0), localQueryRunner.getExecutor()), - session, + localQueryRunner.getDefaultSession(), false, false); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/ArrayPositionLinks.java b/presto-main/src/main/java/com/facebook/presto/operator/ArrayPositionLinks.java index 085d2181e745..23ef7e16c50a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/ArrayPositionLinks.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/ArrayPositionLinks.java @@ -49,7 +49,7 @@ public int link(int left, int right) @Override public Factory build() { - return filterFunction -> new ArrayPositionLinks(positionLinks); + return searchFunctions -> new ArrayPositionLinks(positionLinks); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/operator/HashBuilderOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/HashBuilderOperator.java index dcdcae319263..62aa7086b6e9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/HashBuilderOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/HashBuilderOperator.java @@ -46,6 +46,8 @@ public static class HashBuilderOperatorFactory private final List hashChannels; private final Optional preComputedHashChannel; private final Optional filterFunctionFactory; + private final Optional sortChannel; + private final List searchFunctionFactories; private final PagesIndex.Factory pagesIndexFactory; private final int expectedPositions; @@ -63,13 +65,17 @@ public HashBuilderOperatorFactory( Optional preComputedHashChannel, boolean outer, Optional filterFunctionFactory, + Optional sortChannel, + List searchFunctionFactories, int expectedPositions, int partitionCount, PagesIndex.Factory pagesIndexFactory) { this.operatorId = operatorId; this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); - + requireNonNull(sortChannel, "sortChannel can not be null"); + requireNonNull(searchFunctionFactories, "searchFunctionFactories is null"); + checkArgument(sortChannel.isPresent() != searchFunctionFactories.isEmpty(), "both or none sortChannel and searchFunctionFactories must be set"); checkArgument(Integer.bitCount(partitionCount) == 1, "partitionCount must be a power of 2"); lookupSourceFactory = new PartitionedLookupSourceFactory( types, @@ -85,6 +91,8 @@ public HashBuilderOperatorFactory( this.hashChannels = ImmutableList.copyOf(requireNonNull(hashChannels, "hashChannels is null")); this.preComputedHashChannel = requireNonNull(preComputedHashChannel, "preComputedHashChannel is null"); this.filterFunctionFactory = requireNonNull(filterFunctionFactory, "filterFunctionFactory is null"); + this.sortChannel = sortChannel; + this.searchFunctionFactories = ImmutableList.copyOf(searchFunctionFactories); this.pagesIndexFactory = requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); this.expectedPositions = expectedPositions; @@ -114,6 +122,8 @@ public Operator createOperator(DriverContext driverContext) hashChannels, preComputedHashChannel, filterFunctionFactory, + sortChannel, + searchFunctionFactories, expectedPositions, pagesIndexFactory); @@ -142,6 +152,8 @@ public OperatorFactory duplicate() private final List hashChannels; private final Optional preComputedHashChannel; private final Optional filterFunctionFactory; + private final Optional sortChannel; + private final List searchFunctionFactories; private final PagesIndex index; @@ -156,6 +168,8 @@ public HashBuilderOperator( List hashChannels, Optional preComputedHashChannel, Optional filterFunctionFactory, + Optional sortChannel, + List searchFunctionFactories, int expectedPositions, PagesIndex.Factory pagesIndexFactory) { @@ -164,6 +178,8 @@ public HashBuilderOperator( this.operatorContext = operatorContext; this.partitionIndex = partitionIndex; this.filterFunctionFactory = filterFunctionFactory; + this.sortChannel = sortChannel; + this.searchFunctionFactories = searchFunctionFactories; this.index = pagesIndexFactory.newPagesIndex(lookupSourceFactory.getTypes(), expectedPositions); this.lookupSourceFactory = lookupSourceFactory; @@ -196,7 +212,7 @@ public void finish() } finishing = true; - LookupSourceSupplier partition = index.createLookupSourceSupplier(operatorContext.getSession(), hashChannels, preComputedHashChannel, filterFunctionFactory, Optional.of(outputChannels)); + LookupSourceSupplier partition = index.createLookupSourceSupplier(operatorContext.getSession(), hashChannels, preComputedHashChannel, filterFunctionFactory, sortChannel, searchFunctionFactories, Optional.of(outputChannels)); lookupSourceFactory.setPartitionLookupSourceSupplier(partitionIndex, partition); operatorContext.setMemoryReservation(partition.get().getInMemorySizeInBytes()); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/JoinFilterFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/JoinFilterFunction.java index bfa2f0230b73..1c48cd633004 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/JoinFilterFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/JoinFilterFunction.java @@ -17,12 +17,8 @@ import javax.annotation.concurrent.NotThreadSafe; -import java.util.Optional; - @NotThreadSafe public interface JoinFilterFunction { boolean filter(int leftAddress, int rightPosition, Page rightPage); - - Optional getSortChannel(); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/JoinHashSupplier.java b/presto-main/src/main/java/com/facebook/presto/operator/JoinHashSupplier.java index 286da876d8af..ec2c180a0cc0 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/JoinHashSupplier.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/JoinHashSupplier.java @@ -16,12 +16,15 @@ import com.facebook.presto.Session; import com.facebook.presto.spi.block.Block; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory; +import com.google.common.collect.ImmutableList; import it.unimi.dsi.fastutil.longs.LongArrayList; import java.util.List; import java.util.Optional; import static com.facebook.presto.SystemSessionProperties.isFastInequalityJoin; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public class JoinHashSupplier @@ -33,24 +36,28 @@ public class JoinHashSupplier private final List> channels; private final Optional positionLinks; private final Optional filterFunctionFactory; + private final List searchFunctionFactories; public JoinHashSupplier( Session session, PagesHashStrategy pagesHashStrategy, LongArrayList addresses, List> channels, - Optional filterFunctionFactory) + Optional filterFunctionFactory, + Optional sortChannel, + List searchFunctionFactories) { this.session = requireNonNull(session, "session is null"); this.addresses = requireNonNull(addresses, "addresses is null"); this.channels = requireNonNull(channels, "channels is null"); this.filterFunctionFactory = requireNonNull(filterFunctionFactory, "filterFunctionFactory is null"); + this.searchFunctionFactories = ImmutableList.copyOf(searchFunctionFactories); requireNonNull(pagesHashStrategy, "pagesHashStrategy is null"); PositionLinks.FactoryBuilder positionLinksFactoryBuilder; - if (filterFunctionFactory.isPresent() && - filterFunctionFactory.get().getSortChannel().isPresent() && + if (sortChannel.isPresent() && isFastInequalityJoin(session)) { + checkArgument(filterFunctionFactory.isPresent(), "filterFunctionFactory not set while sortChannel set"); positionLinksFactoryBuilder = SortedPositionLinks.builder( addresses.size(), pagesHashStrategy, @@ -86,6 +93,11 @@ public JoinHash get() return new JoinHash( pagesHash, filterFunction, - positionLinks.map(links -> links.create(filterFunction))); + positionLinks.map(links -> { + List searchFunctions = searchFunctionFactories.stream() + .map(factory -> factory.create(session.toConnectorSession(), addresses, channels)) + .collect(toImmutableList()); + return links.create(searchFunctions); + })); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java b/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java index 31d3ee6b6406..010df1609e1c 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PagesIndex.java @@ -24,7 +24,6 @@ import com.facebook.presto.sql.gen.JoinCompiler.LookupSourceSupplierFactory; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory; import com.facebook.presto.sql.gen.OrderingCompiler; -import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.google.common.collect.ImmutableList; import io.airlift.log.Logger; import io.airlift.slice.Slice; @@ -373,7 +372,7 @@ private PagesIndexOrdering createPagesIndexComparator(List sortChannels public Supplier createLookupSourceSupplier(Session session, List joinChannels) { - return createLookupSourceSupplier(session, joinChannels, Optional.empty(), Optional.empty(), Optional.empty()); + return createLookupSourceSupplier(session, joinChannels, Optional.empty(), Optional.empty(), Optional.empty(), ImmutableList.of()); } public PagesHashStrategy createPagesHashStrategy(List joinChannels, Optional hashChannel) @@ -405,9 +404,11 @@ public LookupSourceSupplier createLookupSourceSupplier( Session session, List joinChannels, Optional hashChannel, - Optional filterFunctionFactory) + Optional filterFunctionFactory, + Optional sortChannel, + List searchFunctionFactories) { - return createLookupSourceSupplier(session, joinChannels, hashChannel, filterFunctionFactory, Optional.empty()); + return createLookupSourceSupplier(session, joinChannels, hashChannel, filterFunctionFactory, sortChannel, searchFunctionFactories, Optional.empty()); } public LookupSourceSupplier createLookupSourceSupplier( @@ -415,6 +416,8 @@ public LookupSourceSupplier createLookupSourceSupplier( List joinChannels, Optional hashChannel, Optional filterFunctionFactory, + Optional sortChannel, + List searchFunctionFactories, Optional> outputChannels) { List> channels = ImmutableList.copyOf(this.channels); @@ -424,17 +427,15 @@ public LookupSourceSupplier createLookupSourceSupplier( // OUTER joins into NestedLoopsJoin and remove "type == INNER" condition in LocalExecutionPlanner.visitJoin() try { - Optional sortChannel = Optional.empty(); - if (filterFunctionFactory.isPresent()) { - sortChannel = filterFunctionFactory.get().getSortChannel(); - } LookupSourceSupplierFactory lookupSourceFactory = joinCompiler.compileLookupSourceFactory(types, joinChannels, sortChannel, outputChannels); return lookupSourceFactory.createLookupSourceSupplier( session, valueAddresses, channels, hashChannel, - filterFunctionFactory); + filterFunctionFactory, + sortChannel, + searchFunctionFactories); } catch (Exception e) { log.error(e, "Lookup source compile failed for types=%s error=%s", types, e); @@ -448,14 +449,16 @@ public LookupSourceSupplier createLookupSourceSupplier( channels, joinChannels, hashChannel, - filterFunctionFactory.map(JoinFilterFunctionFactory::getSortChannel).orElse(Optional.empty())); + sortChannel); return new JoinHashSupplier( session, hashStrategy, valueAddresses, channels, - filterFunctionFactory); + filterFunctionFactory, + sortChannel, + searchFunctionFactories); } private List rangeList(int endExclusive) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/PositionLinks.java b/presto-main/src/main/java/com/facebook/presto/operator/PositionLinks.java index 961ae7164e1d..f91175139287 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/PositionLinks.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/PositionLinks.java @@ -15,7 +15,7 @@ import com.facebook.presto.spi.Page; -import java.util.Optional; +import java.util.List; /** * This class is responsible for iterating over build rows, which have @@ -46,10 +46,6 @@ interface FactoryBuilder */ int link(int left, int right); - /** - * JoinFilterFunction has to be created and supplied for each thread using PositionLinks - * since JoinFilterFunction is not thread safe... - */ Factory build(); /** @@ -66,9 +62,9 @@ default boolean isEmpty() interface Factory { /** - * JoinFilterFunction has to be created and supplied for each thread using PositionLinks + * Separate JoinFilterFunctions have to be created and supplied for each thread using PositionLinks * since JoinFilterFunction is not thread safe... */ - PositionLinks create(Optional joinFilterFunction); + PositionLinks create(List searchFunctions); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java b/presto-main/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java index ae53063acf53..4ca535183cb4 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/SimplePagesHashStrategy.java @@ -17,7 +17,6 @@ import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.facebook.presto.type.TypeUtils; import com.google.common.collect.ImmutableList; import org.openjdk.jol.info.ClassLayout; @@ -38,7 +37,7 @@ public class SimplePagesHashStrategy private final List> channels; private final List hashChannels; private final List precomputedHashChannel; - private final Optional sortChannel; + private final Optional sortChannel; public SimplePagesHashStrategy( List types, @@ -46,7 +45,7 @@ public SimplePagesHashStrategy( List> channels, List hashChannels, Optional precomputedHashChannel, - Optional sortChannel) + Optional sortChannel) { this.types = ImmutableList.copyOf(requireNonNull(types, "types is null")); this.outputChannels = ImmutableList.copyOf(requireNonNull(outputChannels, "outputChannels is null")); @@ -245,9 +244,6 @@ private boolean isChannelPositionNull(int channelIndex, int blockIndex, int bloc private int getSortChannel() { - if (!sortChannel.isPresent()) { - throw new UnsupportedOperationException(); - } - return sortChannel.get().getChannel(); + return sortChannel.get(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/SortedPositionLinks.java b/presto-main/src/main/java/com/facebook/presto/operator/SortedPositionLinks.java index 287eb1f6f287..2319b1de8d5c 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/SortedPositionLinks.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/SortedPositionLinks.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator; import com.facebook.presto.spi.Page; +import com.google.common.collect.ImmutableList; import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import it.unimi.dsi.fastutil.ints.IntArrayList; @@ -22,7 +23,6 @@ import org.openjdk.jol.info.ClassLayout; import java.util.List; -import java.util.Optional; import static com.facebook.presto.operator.SyntheticAddress.decodePosition; import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex; @@ -31,26 +31,10 @@ import static java.util.Objects.requireNonNull; /** - * This class assumes that lessThanFunction is a superset of the whole filtering - * condition used in a join. In other words, we can use SortedPositionLinks - * with following join condition: - *

- * {@code filterFunction_1(...) AND filterFunction_2(....) AND ... AND filterFunction_n(...)} - *

- * by passing any of the filterFunction_i to the SortedPositionLinks. We could not - * do that for join condition like: - *

- * {@code filterFunction_1(...) OR filterFunction_2(....) OR ... OR filterFunction_n(...)} - *

- * To use lessThanFunction in this class, it must be an expression in form of: - *

- * {@code f(probeColumn1, probeColumn2, ..., probeColumnN) COMPARE g(buildColumn1, ..., buildColumnN)} - *

- * where {@code COMPARE} is one of: {@code < <= > >=} - *

- * That allows us to define an order of the elements in positionLinks (this defining which - * element is smaller) using {@code g(...)} function and to perform a binary search using - * {@code f(probePosition)} value. + * Maintains position links in sorted order by build side expression. + * Then iteration over position links uses set of @{code searchFunctions} which needs to be compatible + * with expression used for sorting. + * The binary search is used to quickly skip positions which would not match filter function from join condition. */ public final class SortedPositionLinks implements PositionLinks @@ -147,13 +131,10 @@ public Factory build() } } - return lessThanFunction -> { - checkState(lessThanFunction.isPresent(), "Using SortedPositionLinks without lessThanFunction"); - return new SortedPositionLinks( - arrayPositionLinksFactoryBuilder.build().create(Optional.empty()), - sortedPositionLinks, - lessThanFunction.get()); - }; + return searchFunctions -> new SortedPositionLinks( + arrayPositionLinksFactoryBuilder.build().create(ImmutableList.of()), + sortedPositionLinks, + searchFunctions); } @Override @@ -165,15 +146,17 @@ public int size() private final PositionLinks positionLinks; private final int[][] sortedPositionLinks; - private final JoinFilterFunction lessThanFunction; private final long sizeInBytes; + private final JoinFilterFunction[] searchFunctions; - private SortedPositionLinks(PositionLinks positionLinks, int[][] sortedPositionLinks, JoinFilterFunction lessThanFunction) + private SortedPositionLinks(PositionLinks positionLinks, int[][] sortedPositionLinks, List searchFunctions) { this.positionLinks = requireNonNull(positionLinks, "positionLinks is null"); this.sortedPositionLinks = requireNonNull(sortedPositionLinks, "sortedPositionLinks is null"); - this.lessThanFunction = requireNonNull(lessThanFunction, "lessThanFunction is null"); this.sizeInBytes = INSTANCE_SIZE + positionLinks.getSizeInBytes() + sizeOfPositionLinks(sortedPositionLinks); + requireNonNull(searchFunctions, "searchFunctions is null"); + checkState(!searchFunctions.isEmpty(), "Using sortedPositionLinks with no search functions"); + this.searchFunctions = searchFunctions.stream().toArray(JoinFilterFunction[]::new); } private long sizeOfPositionLinks(int[][] sortedPositionLinks) @@ -197,43 +180,67 @@ public int next(int position, int probePosition, Page allProbeChannelsPage) if (nextPosition < 0) { return -1; } - // break a position links chain if next position should be filtered out - if (applyLessThanFunction(nextPosition, probePosition, allProbeChannelsPage)) { - return nextPosition; + if (!applyAllSearchFunctions(nextPosition, probePosition, allProbeChannelsPage)) { + // break a position links chain if next position should be filtered out + return -1; } - return -1; + return nextPosition; } @Override public int start(int startingPosition, int probePosition, Page allProbeChannelsPage) { - // check if filtering function to startingPosition - if (applyLessThanFunction(startingPosition, probePosition, allProbeChannelsPage)) { + if (applyAllSearchFunctions(startingPosition, probePosition, allProbeChannelsPage)) { return startingPosition; } - - if (sortedPositionLinks[startingPosition] == null) { + int[] links = sortedPositionLinks[startingPosition]; + if (links == null) { return -1; } + int currentStartOffset = 0; + for (JoinFilterFunction searchFunction : searchFunctions) { + currentStartOffset = findStartPositionForFunction(searchFunction, links, currentStartOffset, probePosition, allProbeChannelsPage); + // return as soon as a mismatch is found, since we are handling only AND predicates (conjuncts) + if (currentStartOffset == -1) { + return -1; + } + } + return links[currentStartOffset]; + } - int left = 0; - int right = sortedPositionLinks[startingPosition].length - 1; + private boolean applyAllSearchFunctions(int buildPosition, int probePosition, Page allProbeChannelsPage) + { + for (JoinFilterFunction searchFunction : searchFunctions) { + if (!applySearchFunction(searchFunction, buildPosition, probePosition, allProbeChannelsPage)) { + return false; + } + } + return true; + } - // do a binary search for the first position for which filter function applies - int offset = lowerBound(startingPosition, left, right, probePosition, allProbeChannelsPage); - if (offset < 0) { - return -1; + private int findStartPositionForFunction(JoinFilterFunction searchFunction, int[] links, int startOffset, int probePosition, Page allProbeChannelsPage) + { + if (applySearchFunction(searchFunction, links, startOffset, probePosition, allProbeChannelsPage)) { + // MAJOR HACK: if searchFunction is of shape `f(probe) > build_symbol` it is not fit for binary search below, + // but it does not imply extra constraints on start position; so we just ignore it. + // It does not break logic for `f(probe) < build_symbol` as the binary search below would return same value. + + // todo: Explicitly handle less-than and greater-than functions separately. + return startOffset; } - if (!applyLessThanFunction(startingPosition, offset, probePosition, allProbeChannelsPage)) { + + // do a binary search for the first position for which filter function applies + int offset = lowerBound(searchFunction, links, startOffset, links.length - 1, probePosition, allProbeChannelsPage); + if (!applySearchFunction(searchFunction, links, offset, probePosition, allProbeChannelsPage)) { return -1; } - return sortedPositionLinks[startingPosition][offset]; + return offset; } /** * Find the first element in position links that is NOT smaller than probePosition */ - private int lowerBound(int startingPosition, int first, int last, int probePosition, Page allProbeChannelsPage) + private int lowerBound(JoinFilterFunction searchFunction, int[] links, int first, int last, int probePosition, Page allProbeChannelsPage) { int middle; int step; @@ -241,7 +248,7 @@ private int lowerBound(int startingPosition, int first, int last, int probePosit while (count > 0) { step = count / 2; middle = first + step; - if (!applyLessThanFunction(startingPosition, middle, probePosition, allProbeChannelsPage)) { + if (!applySearchFunction(searchFunction, links, middle, probePosition, allProbeChannelsPage)) { first = ++middle; count -= step + 1; } @@ -258,14 +265,14 @@ public long getSizeInBytes() return sizeInBytes; } - private boolean applyLessThanFunction(int leftPosition, int leftOffset, int rightPosition, Page rightPage) + private boolean applySearchFunction(JoinFilterFunction searchFunction, int[] links, int linkOffset, int probePosition, Page allProbeChannelsPage) { - return applyLessThanFunction(sortedPositionLinks[leftPosition][leftOffset], rightPosition, rightPage); + return applySearchFunction(searchFunction, links[linkOffset], probePosition, allProbeChannelsPage); } - private boolean applyLessThanFunction(long leftPosition, int rightPosition, Page rightPage) + private boolean applySearchFunction(JoinFilterFunction searchFunction, long buildPosition, int probePosition, Page allProbeChannelsPage) { - return lessThanFunction.filter((int) leftPosition, rightPosition, rightPage); + return searchFunction.filter((int) buildPosition, probePosition, allProbeChannelsPage); } private static class PositionComparator diff --git a/presto-main/src/main/java/com/facebook/presto/operator/StandardJoinFilterFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/StandardJoinFilterFunction.java index f4cbf407c4c3..6fddc032671e 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/StandardJoinFilterFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/StandardJoinFilterFunction.java @@ -19,7 +19,6 @@ import it.unimi.dsi.fastutil.longs.LongArrayList; import java.util.List; -import java.util.Optional; import static com.facebook.presto.operator.SyntheticAddress.decodePosition; import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex; @@ -33,13 +32,11 @@ public class StandardJoinFilterFunction private final InternalJoinFilterFunction filterFunction; private final LongArrayList addresses; private final List pages; - private final Optional sortChannel; - public StandardJoinFilterFunction(InternalJoinFilterFunction filterFunction, LongArrayList addresses, List> channels, Optional sortChannel) + public StandardJoinFilterFunction(InternalJoinFilterFunction filterFunction, LongArrayList addresses, List> channels) { this.filterFunction = requireNonNull(filterFunction, "filterFunction can not be null"); this.addresses = requireNonNull(addresses, "addresses is null"); - this.sortChannel = requireNonNull(sortChannel, "sortChannel is null"); requireNonNull(channels, "channels can not be null"); ImmutableList.Builder pagesBuilder = ImmutableList.builder(); @@ -66,12 +63,6 @@ public boolean filter(int leftAddress, int rightPosition, Page rightPage) return filterFunction.filter(blockPosition, getLeftBlocks(blockIndex), rightPosition, rightPage.getBlocks()); } - @Override - public Optional getSortChannel() - { - return sortChannel; - } - private Block[] getLeftBlocks(int leftBlockIndex) { if (pages.isEmpty()) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/index/IndexSnapshotBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/index/IndexSnapshotBuilder.java index 125274bb43c6..55d43c2d298f 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/index/IndexSnapshotBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/index/IndexSnapshotBuilder.java @@ -125,7 +125,7 @@ public IndexSnapshot createIndexSnapshot(UnloadedIndexKeyRecordSet indexKeysReco } pages.clear(); - LookupSource lookupSource = outputPagesIndex.createLookupSourceSupplier(session, keyOutputChannels, keyOutputHashChannel, Optional.empty()).get(); + LookupSource lookupSource = outputPagesIndex.createLookupSourceSupplier(session, keyOutputChannels, keyOutputHashChannel, Optional.empty(), Optional.empty(), ImmutableList.of()).get(); // Build a page containing the keys that produced no output rows, so in future requests can skip these keys PageBuilder missingKeysPageBuilder = new PageBuilder(missingKeysIndex.getTypes()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java index 556c17ddca29..47928ff9e55e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java @@ -39,7 +39,6 @@ import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory; -import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.google.common.base.Throwables; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; @@ -109,7 +108,7 @@ public Class load(CacheKey key) } }); - public LookupSourceSupplierFactory compileLookupSourceFactory(List types, List joinChannels, Optional sortChannel) + public LookupSourceSupplierFactory compileLookupSourceFactory(List types, List joinChannels, Optional sortChannel) { return compileLookupSourceFactory(types, joinChannels, sortChannel, Optional.empty()); } @@ -128,7 +127,7 @@ public CacheStatsMBean getHashStrategiesStats() return new CacheStatsMBean(hashStrategies); } - public LookupSourceSupplierFactory compileLookupSourceFactory(List types, List joinChannels, Optional sortChannel, Optional> outputChannels) + public LookupSourceSupplierFactory compileLookupSourceFactory(List types, List joinChannels, Optional sortChannel, Optional> outputChannels) { try { return lookupSourceFactories.get(new CacheKey( @@ -172,7 +171,7 @@ private List rangeList(int endExclusive) .collect(toImmutableList()); } - private LookupSourceSupplierFactory internalCompileLookupSourceFactory(List types, List outputChannels, List joinChannels, Optional sortChannel) + private LookupSourceSupplierFactory internalCompileLookupSourceFactory(List types, List outputChannels, List joinChannels, Optional sortChannel) { Class pagesHashStrategyClass = internalCompileHashStrategy(types, outputChannels, joinChannels, sortChannel); @@ -201,7 +200,7 @@ private static FieldDefinition generateInstanceSize(ClassDefinition definition) return instanceSize; } - private Class internalCompileHashStrategy(List types, List outputChannels, List joinChannels, Optional sortChannel) + private Class internalCompileHashStrategy(List types, List outputChannels, List joinChannels, Optional sortChannel) { CallSiteBinder callSiteBinder = new CallSiteBinder(); @@ -712,7 +711,7 @@ private static void generateCompareSortChannelPositionsMethod( CallSiteBinder callSiteBinder, List types, List channelFields, - Optional sortChannel) + Optional sortChannel) { Parameter leftBlockIndex = arg("leftBlockIndex", int.class); Parameter leftBlockPosition = arg("leftBlockPosition", int.class); @@ -736,7 +735,7 @@ private static void generateCompareSortChannelPositionsMethod( Variable thisVariable = compareMethod.getThis(); - int index = sortChannel.get().getChannel(); + int index = sortChannel.get(); BytecodeExpression type = constantType(callSiteBinder, types.get(index)); BytecodeExpression leftBlock = thisVariable @@ -759,7 +758,7 @@ private static void generateCompareSortChannelPositionsMethod( private static void generateIsSortChannelPositionNull( ClassDefinition classDefinition, List channelFields, - Optional sortChannel) + Optional sortChannel) { Parameter blockIndex = arg("blockIndex", int.class); Parameter blockPosition = arg("blockPosition", int.class); @@ -779,7 +778,7 @@ private static void generateIsSortChannelPositionNull( Variable thisVariable = isSortChannelPositionNullMethod.getThis(); - int index = sortChannel.get().getChannel(); + int index = sortChannel.get(); BytecodeExpression block = thisVariable .getField(channelFields.get(index)) @@ -835,7 +834,7 @@ public LookupSourceSupplierFactory(Class joinHas { this.pagesHashStrategyFactory = pagesHashStrategyFactory; try { - constructor = joinHashSupplierClass.getConstructor(Session.class, PagesHashStrategy.class, LongArrayList.class, List.class, Optional.class); + constructor = joinHashSupplierClass.getConstructor(Session.class, PagesHashStrategy.class, LongArrayList.class, List.class, Optional.class, Optional.class, List.class); } catch (NoSuchMethodException e) { throw Throwables.propagate(e); @@ -847,11 +846,13 @@ public LookupSourceSupplier createLookupSourceSupplier( LongArrayList addresses, List> channels, Optional hashChannel, - Optional filterFunctionFactory) + Optional filterFunctionFactory, + Optional sortChannel, + List searchFunctionFactories) { PagesHashStrategy pagesHashStrategy = pagesHashStrategyFactory.createPagesHashStrategy(channels, hashChannel); try { - return constructor.newInstance(session, pagesHashStrategy, addresses, channels, filterFunctionFactory); + return constructor.newInstance(session, pagesHashStrategy, addresses, channels, filterFunctionFactory, sortChannel, searchFunctionFactories); } catch (Exception e) { throw Throwables.propagate(e); @@ -889,9 +890,9 @@ private static final class CacheKey private final List types; private final List outputChannels; private final List joinChannels; - private final Optional sortChannel; + private final Optional sortChannel; - private CacheKey(List types, List outputChannels, List joinChannels, Optional sortChannel) + private CacheKey(List types, List outputChannels, List joinChannels, Optional sortChannel) { this.types = ImmutableList.copyOf(requireNonNull(types, "types is null")); this.outputChannels = ImmutableList.copyOf(requireNonNull(outputChannels, "outputChannels is null")); @@ -914,7 +915,7 @@ private List getJoinChannels() return joinChannels; } - public Optional getSortChannel() + private Optional getSortChannel() { return sortChannel; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java index ecff33aba446..a89cfc8dcb1d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java @@ -30,7 +30,6 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.block.Block; import com.facebook.presto.sql.gen.LambdaBytecodeGenerator.CompiledLambda; -import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.LambdaDefinitionExpression; import com.facebook.presto.sql.relational.RowExpression; @@ -54,7 +53,6 @@ import java.lang.reflect.Constructor; import java.util.List; import java.util.Objects; -import java.util.Optional; import java.util.Set; import static com.facebook.presto.bytecode.Access.FINAL; @@ -93,7 +91,7 @@ public JoinFilterFunctionCompiler(Metadata metadata) public JoinFilterFunctionFactory load(JoinFilterCacheKey key) throws Exception { - return internalCompileFilterFunctionFactory(key.getFilter(), key.getLeftBlocksSize(), key.getSortChannel()); + return internalCompileFilterFunctionFactory(key.getFilter(), key.getLeftBlocksSize()); } }); @@ -104,15 +102,15 @@ public CacheStatsMBean getJoinFilterFunctionFactoryStats() return new CacheStatsMBean(joinFilterFunctionFactories); } - public JoinFilterFunctionFactory compileJoinFilterFunction(RowExpression filter, int leftBlocksSize, Optional sortChannel) + public JoinFilterFunctionFactory compileJoinFilterFunction(RowExpression filter, int leftBlocksSize) { - return joinFilterFunctionFactories.getUnchecked(new JoinFilterCacheKey(filter, leftBlocksSize, sortChannel)); + return joinFilterFunctionFactories.getUnchecked(new JoinFilterCacheKey(filter, leftBlocksSize)); } - private JoinFilterFunctionFactory internalCompileFilterFunctionFactory(RowExpression filterExpression, int leftBlocksSize, Optional sortChannel) + private JoinFilterFunctionFactory internalCompileFilterFunctionFactory(RowExpression filterExpression, int leftBlocksSize) { Class internalJoinFilterFunction = compileInternalJoinFilterFunction(filterExpression, leftBlocksSize); - return new IsolatedJoinFilterFunctionFactory(internalJoinFilterFunction, sortChannel); + return new IsolatedJoinFilterFunctionFactory(internalJoinFilterFunction); } private Class compileInternalJoinFilterFunction(RowExpression filterExpression, int leftBlocksSize) @@ -311,11 +309,6 @@ private static void generateToString(ClassDefinition classDefinition, CallSiteBi public interface JoinFilterFunctionFactory { JoinFilterFunction create(ConnectorSession session, LongArrayList addresses, List> channels); - - default Optional getSortChannel() - { - return Optional.empty(); - } } private static RowExpressionVisitor fieldReferenceCompiler( @@ -336,13 +329,11 @@ private static final class JoinFilterCacheKey { private final RowExpression filter; private final int leftBlocksSize; - private final Optional sortChannel; - public JoinFilterCacheKey(RowExpression filter, int leftBlocksSize, Optional sortChannel) + public JoinFilterCacheKey(RowExpression filter, int leftBlocksSize) { this.filter = requireNonNull(filter, "filter can not be null"); this.leftBlocksSize = leftBlocksSize; - this.sortChannel = requireNonNull(sortChannel, "sortChannel can not be null"); } public RowExpression getFilter() @@ -355,11 +346,6 @@ public int getLeftBlocksSize() return leftBlocksSize; } - public Optional getSortChannel() - { - return sortChannel; - } - @Override public boolean equals(Object o) { @@ -377,7 +363,7 @@ public boolean equals(Object o) @Override public int hashCode() { - return Objects.hash(filter, leftBlocksSize); + return Objects.hash(leftBlocksSize, filter); } @Override @@ -395,11 +381,9 @@ private static class IsolatedJoinFilterFunctionFactory { private final Constructor internalJoinFilterFunctionConstructor; private final Constructor isolatedJoinFilterFunctionConstructor; - private final Optional sortChannel; - public IsolatedJoinFilterFunctionFactory(Class internalJoinFilterFunction, Optional sortChannel) + public IsolatedJoinFilterFunctionFactory(Class internalJoinFilterFunction) { - this.sortChannel = sortChannel; try { internalJoinFilterFunctionConstructor = internalJoinFilterFunction .getConstructor(ConnectorSession.class); @@ -408,7 +392,7 @@ public IsolatedJoinFilterFunctionFactory(Class getSortChannel() - { - return sortChannel; - } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index b9b8698505a0..37ff87a38b86 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -105,7 +105,6 @@ import com.facebook.presto.sql.gen.PageFunctionCompiler; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; -import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; @@ -148,6 +147,7 @@ import com.facebook.presto.sql.relational.RowExpression; import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.FieldReference; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.NodeRef; import com.google.common.collect.HashMultimap; @@ -1586,12 +1586,32 @@ private LookupSourceFactory createLookupSourceFactory( Optional filterFunctionFactory = node.getFilter() .map(filterExpression -> compileJoinFilterFunction( filterExpression, - node.getSortExpression(), probeLayout, buildSource.getLayout(), context.getTypes(), context.getSession())); + Optional sortExpressionContext = node.getSortExpressionContext(); + + Optional sortChannel = sortExpressionContext + .map(SortExpressionContext::getSortExpression) + .map(sortExpression -> sortExpressionAsSortChannel( + sortExpression, + probeLayout, + buildSource.getLayout())); + + List searchFunctionFactories = sortExpressionContext + .map(SortExpressionContext::getSearchExpressions) + .map(searchExpressions -> searchExpressions.stream() + .map(searchExpression -> compileJoinFilterFunction( + searchExpression, + probeLayout, + buildSource.getLayout(), + context.getTypes(), + context.getSession())) + .collect(toImmutableList())) + .orElse(ImmutableList.of()); + HashBuilderOperatorFactory hashBuilderOperatorFactory = new HashBuilderOperatorFactory( buildContext.getNextOperatorId(), node.getId(), @@ -1602,6 +1622,8 @@ private LookupSourceFactory createLookupSourceFactory( buildHashChannel, node.getType() == RIGHT || node.getType() == FULL, filterFunctionFactory, + sortChannel, + searchFunctionFactories, 10_000, buildContext.getDriverInstanceCount().orElse(1), pagesIndexFactory); @@ -1620,7 +1642,6 @@ private LookupSourceFactory createLookupSourceFactory( private JoinFilterFunctionFactory compileJoinFilterFunction( Expression filterExpression, - Optional sortExpression, Map probeLayout, Map buildLayout, Map types, @@ -1632,11 +1653,6 @@ private JoinFilterFunctionFactory compileJoinFilterFunction( .collect(toImmutableMap(Map.Entry::getValue, entry -> types.get(entry.getKey()))); Expression rewrittenFilter = new SymbolToInputRewriter(joinSourcesLayout).rewrite(filterExpression); - Optional rewrittenSortExpression = sortExpression.map( - expression -> new SymbolToInputRewriter(buildLayout).rewrite(expression)); - - Optional sortChannel = rewrittenSortExpression.map(SortExpression::fromExpression); - Map, Type> expressionTypes = getExpressionTypesFromInput( session, metadata, @@ -1646,7 +1662,18 @@ private JoinFilterFunctionFactory compileJoinFilterFunction( emptyList() /* parameters have already been replaced */); RowExpression translatedFilter = toRowExpression(rewrittenFilter, expressionTypes); - return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size(), sortChannel); + return joinFilterFunctionCompiler.compileJoinFilterFunction(translatedFilter, buildLayout.size()); + } + + private int sortExpressionAsSortChannel( + Expression sortExpression, + Map probeLayout, + Map buildLayout) + { + Map joinSourcesLayout = createJoinSourcesLayout(buildLayout, probeLayout); + Expression rewrittenSortExpression = new SymbolToInputRewriter(joinSourcesLayout).rewrite(sortExpression); + checkArgument(rewrittenSortExpression instanceof FieldReference, "Unsupported expression type [%s]", rewrittenSortExpression); + return ((FieldReference) rewrittenSortExpression).getFieldIndex(); } private OperatorFactory createLookupJoin( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionContext.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionContext.java new file mode 100644 index 000000000000..1b07e900d344 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionContext.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.sql.tree.Expression; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class SortExpressionContext +{ + private final Expression sortExpression; + private final List searchExpressions; + + public SortExpressionContext(Expression sortExpression, List searchExpressions) + { + this.sortExpression = requireNonNull(sortExpression, "sortExpression can not be null"); + this.searchExpressions = ImmutableList.copyOf(searchExpressions); + } + + public Expression getSortExpression() + { + return sortExpression; + } + + public List getSearchExpressions() + { + return searchExpressions; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SortExpressionContext that = (SortExpressionContext) o; + return Objects.equals(sortExpression, that.sortExpression) && + Objects.equals(searchExpressions, that.searchExpressions); + } + + @Override + public int hashCode() + { + return Objects.hash(sortExpression, searchExpressions); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("sortExpression", sortExpression) + .add("searchExpressions", searchExpressions) + .toString(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java index 95a76ce5f581..6933ef6969f1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SortExpressionExtractor.java @@ -13,51 +13,100 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.tree.AstVisitor; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.FieldReference; import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; -import java.util.Objects; +import java.util.List; import java.util.Optional; import java.util.Set; -import static com.google.common.base.MoreObjects.toStringHelper; -import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Collections.singletonList; +import static java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toMap; /** - * Currently this class handles only simple expressions like: + * Extracts sort expression to be used for creating {@link com.facebook.presto.operator.SortedPositionLinks} from join filter expression. + * Currently this class can extract sort and search expressions from filter function conjuncts of shape: *

- * A.a < B.x + * {@code A.a < f(B.x, B.y, B.z)} or {@code f(B.x, B.y, B.z) < A.a} *

- * It could be extended to handle any expressions like: - *

- * A.a * sin(A.b) / log(B.x) < cos(B.z) - *

- * by transforming it to: - *

- * f(A.a, A.b) < g(B.x, B.z) - *

- * Where f(...) and g(...) would be some functions/expressions. That - * would allow us to perform binary search on arbitrary complex expressions - * by sorting position links according to the result of f(...) function. + * where {@code a} is the build side symbol reference and {@code x,y,z} are probe + * side symbol references. Any of inequality operators ({@code <,<=,>,>=}) can be used. + * Same build side symbol need to be used in all conjuncts. */ public final class SortExpressionExtractor { + /* TODO: + This class could be extended to handle any expressions like: + A.a * sin(A.b) / log(B.x) < cos(B.z) + by transforming it to: + f(A.a, A.b) < g(B.x, B.z) + Where f(...) and g(...) would be some functions/expressions. That + would allow us to perform binary search on arbitrary complex expressions + by sorting position links according to the result of f(...) function. + */ private SortExpressionExtractor() {} - public static Optional extractSortExpression(Set buildSymbols, Expression filter) + public static Optional extractSortExpression(Set buildSymbols, Expression filter) + { + List filterConjuncts = ExpressionUtils.extractConjuncts(filter); + SortExpressionVisitor visitor = new SortExpressionVisitor(buildSymbols); + + List sortExpressionCandidates = filterConjuncts.stream() + .filter(DeterminismEvaluator::isDeterministic) + .map(visitor::process) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(toMap(SortExpressionContext::getSortExpression, identity(), SortExpressionExtractor::merge)) + .values() + .stream() + .collect(toImmutableList()); + + // For now heuristically pick sort expression which has most search expressions assigned to it. + // TODO: make it cost based decision based on symbol statistics + return sortExpressionCandidates.stream() + .sorted(comparing(context -> -1 * context.getSearchExpressions().size())) + .findFirst(); + } + + private static SortExpressionContext merge(SortExpressionContext left, SortExpressionContext right) + { + checkArgument(left.getSortExpression().equals(right.getSortExpression())); + ImmutableList.Builder searchExpressions = ImmutableList.builder(); + searchExpressions.addAll(left.getSearchExpressions()); + searchExpressions.addAll(right.getSearchExpressions()); + return new SortExpressionContext(left.getSortExpression(), searchExpressions.build()); + } + + private static class SortExpressionVisitor + extends AstVisitor, Void> { - if (!DeterminismEvaluator.isDeterministic(filter)) { + private final Set buildSymbols; + + public SortExpressionVisitor(Set buildSymbols) + { + this.buildSymbols = buildSymbols; + } + + @Override + protected Optional visitExpression(Expression expression, Void context) + { return Optional.empty(); } - if (filter instanceof ComparisonExpression) { - ComparisonExpression comparison = (ComparisonExpression) filter; + @Override + protected Optional visitComparisonExpression(ComparisonExpression comparison, Void context) + { switch (comparison.getType()) { case GREATER_THAN: case GREATER_THAN_OR_EQUAL: @@ -70,19 +119,18 @@ public static Optional extractSortExpression(Set buildSymbol hasBuildReferencesOnOtherSide = hasBuildSymbolReference(buildSymbols, comparison.getRight()); } if (sortChannel.isPresent() && !hasBuildReferencesOnOtherSide) { - return sortChannel.map(symbolReference -> (Expression) symbolReference); + return sortChannel.map(symbolReference -> new SortExpressionContext(symbolReference, singletonList(comparison))); } return Optional.empty(); default: return Optional.empty(); } } - - return Optional.empty(); } private static Optional asBuildSymbolReference(Set buildLayout, Expression expression) { + // Currently only we support only symbol as sort expression on build side if (expression instanceof SymbolReference) { SymbolReference symbolReference = (SymbolReference) expression; if (buildLayout.contains(new Symbol(symbolReference.getName()))) { @@ -126,51 +174,4 @@ protected Boolean visitSymbolReference(SymbolReference symbolReference, Void con return buildSymbols.contains(symbolReference.getName()); } } - - public static class SortExpression - { - private final int channel; - - public SortExpression(int channel) - { - this.channel = channel; - } - - public int getChannel() - { - return channel; - } - - @Override - public boolean equals(Object obj) - { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - SortExpression other = (SortExpression) obj; - return Objects.equals(this.channel, other.channel); - } - - @Override - public int hashCode() - { - return Objects.hash(channel); - } - - public String toString() - { - return toStringHelper(this) - .add("channel", channel) - .toString(); - } - - public static SortExpression fromExpression(Expression expression) - { - checkState(expression instanceof FieldReference, "Unsupported expression type [%s]", expression); - return new SortExpression(((FieldReference) expression).getFieldIndex()); - } - } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java index 13916084b047..2e1b8a52ed25 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.plan; +import com.facebook.presto.sql.planner.SortExpressionContext; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.ComparisonExpressionType; @@ -176,9 +177,10 @@ public Optional getFilter() return filter; } - public Optional getSortExpression() + public Optional getSortExpressionContext() { - return filter.map(filter -> extractSortExpression(ImmutableSet.copyOf(right.getOutputSymbols()), filter).orElse(null)); + return filter + .flatMap(filter -> extractSortExpression(ImmutableSet.copyOf(right.getOutputSymbols()), filter)); } @JsonProperty("leftHashSymbol") diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 9e9e9b94d664..79072e50a2e7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -554,7 +554,7 @@ public Void visitJoin(JoinNode node, Integer indent) formatOutputs(node.getOutputSymbols())); } - node.getSortExpression().ifPresent(expression -> print(indent + 2, "SortExpression[%s]", expression)); + node.getSortExpressionContext().ifPresent(context -> print(indent + 2, "SortExpression[%s]", context.getSortExpression())); printCost(indent + 2, node); printStats(indent + 2, node.getId()); node.getLeft().accept(this, indent + 1); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java index 237b3ea66955..8cdcaa904d98 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java @@ -286,6 +286,8 @@ private LookupSourceFactory benchmarkBuildHash(BuildContext buildContext, List getHashChannels(RowPagesBuilder probe, RowPagesBuil private static LookupSourceFactory buildHash(boolean parallelBuild, TaskContext taskContext, List hashChannels, RowPagesBuilder buildPages, Optional filterFunction) { Optional filterFunctionFactory = filterFunction - .map(function -> (session, addresses, channels) -> new StandardJoinFilterFunction(function, addresses, channels, Optional.empty())); + .map(function -> (session, addresses, channels) -> new StandardJoinFilterFunction(function, addresses, channels)); int partitionCount = parallelBuild ? PARTITION_COUNT : 1; LocalExchange localExchange = new LocalExchange(FIXED_HASH_DISTRIBUTION, partitionCount, buildPages.getTypes(), hashChannels, buildPages.getHashChannel()); @@ -793,6 +793,8 @@ private static LookupSourceFactory buildHash(boolean parallelBuild, TaskContext buildPages.getHashChannel(), false, filterFunctionFactory, + Optional.empty(), + ImmutableList.of(), 100, partitionCount, new PagesIndex.TestingFactory()); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestPositionLinks.java b/presto-main/src/test/java/com/facebook/presto/operator/TestPositionLinks.java index 22160774b0e2..3b39855020ac 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestPositionLinks.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestPositionLinks.java @@ -15,7 +15,6 @@ import com.facebook.presto.RowPagesBuilder; import com.facebook.presto.spi.Page; -import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression; import com.google.common.collect.ImmutableList; import it.unimi.dsi.fastutil.longs.LongArrayList; import org.testng.annotations.Test; @@ -43,7 +42,7 @@ public void testArrayPositionLinks() assertEquals(factoryBuilder.link(11, 10), 11); assertEquals(factoryBuilder.link(12, 11), 12); - PositionLinks positionLinks = factoryBuilder.build().create(Optional.empty()); + PositionLinks positionLinks = factoryBuilder.build().create(ImmutableList.of()); assertEquals(positionLinks.start(3, 0, TEST_PAGE), 3); assertEquals(positionLinks.next(3, 0, TEST_PAGE), 2); @@ -61,54 +60,177 @@ public void testArrayPositionLinks() @Test public void testSortedPositionLinks() { - JoinFilterFunction filterFunction = new JoinFilterFunction() - { - @Override - public boolean filter(int leftAddress, int rightPosition, Page rightPage) - { - return BIGINT.getLong(rightPage.getBlock(0), leftAddress) > 4; - } - - @Override - public Optional getSortChannel() - { - throw new UnsupportedOperationException(); - } - }; + JoinFilterFunction filterFunction = (leftAddress, rightPosition, rightPage) -> + BIGINT.getLong(TEST_PAGE.getBlock(0), leftAddress) > 4; PositionLinks.FactoryBuilder factoryBuilder = buildSortedPositionLinks(); - PositionLinks positionLinks = factoryBuilder.build().create(Optional.of(filterFunction)); + PositionLinks positionLinks = factoryBuilder.build().create(ImmutableList.of(filterFunction)); assertEquals(positionLinks.start(0, 0, TEST_PAGE), 5); assertEquals(positionLinks.next(5, 0, TEST_PAGE), 6); assertEquals(positionLinks.next(6, 0, TEST_PAGE), -1); + assertEquals(positionLinks.start(7, 0, TEST_PAGE), 7); + assertEquals(positionLinks.next(7, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(8, 0, TEST_PAGE), 8); + assertEquals(positionLinks.next(8, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(9, 0, TEST_PAGE), 9); + assertEquals(positionLinks.next(9, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(10, 0, TEST_PAGE), 10); + assertEquals(positionLinks.next(10, 0, TEST_PAGE), 11); + assertEquals(positionLinks.next(11, 0, TEST_PAGE), 12); + assertEquals(positionLinks.next(12, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(13, 0, TEST_PAGE), 13); + assertEquals(positionLinks.next(13, 0, TEST_PAGE), -1); + } + + @Test + public void testSortedPositionLinksAllMatch() + { + JoinFilterFunction filterFunction = (leftAddress, rightPosition, rightPage) -> + BIGINT.getLong(rightPage.getBlock(0), leftAddress) >= 0; + + PositionLinks.FactoryBuilder factoryBuilder = buildSortedPositionLinks(); + PositionLinks positionLinks = factoryBuilder.build().create(ImmutableList.of(filterFunction)); + + assertEquals(positionLinks.start(0, 0, TEST_PAGE), 0); + assertEquals(positionLinks.next(0, 0, TEST_PAGE), 1); + assertEquals(positionLinks.next(1, 0, TEST_PAGE), 2); + assertEquals(positionLinks.next(2, 0, TEST_PAGE), 3); + assertEquals(positionLinks.next(3, 0, TEST_PAGE), 4); + assertEquals(positionLinks.next(4, 0, TEST_PAGE), 5); + assertEquals(positionLinks.next(5, 0, TEST_PAGE), 6); + assertEquals(positionLinks.next(6, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(7, 0, TEST_PAGE), 7); + assertEquals(positionLinks.next(7, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(8, 0, TEST_PAGE), 8); + assertEquals(positionLinks.next(8, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(9, 0, TEST_PAGE), 9); + assertEquals(positionLinks.next(9, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(10, 0, TEST_PAGE), 10); + assertEquals(positionLinks.next(10, 0, TEST_PAGE), 11); + assertEquals(positionLinks.next(11, 0, TEST_PAGE), 12); + assertEquals(positionLinks.next(12, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(13, 0, TEST_PAGE), 13); + assertEquals(positionLinks.next(13, 0, TEST_PAGE), -1); + } + + @Test + public void testSortedPositionLinksForRangePredicates() + { + JoinFilterFunction filterFunctionOne = (leftAddress, rightPosition, rightPage) -> BIGINT.getLong(TEST_PAGE.getBlock(0), leftAddress) > 4; + + JoinFilterFunction filterFunctionTwo = (leftAddress, rightPosition, rightPage) -> BIGINT.getLong(TEST_PAGE.getBlock(0), leftAddress) <= 11; + + PositionLinks.FactoryBuilder factoryBuilder = buildSortedPositionLinks(); + PositionLinks positionLinks = factoryBuilder.build().create(ImmutableList.of(filterFunctionOne, filterFunctionTwo)); + + assertEquals(positionLinks.start(0, 0, TEST_PAGE), 5); + assertEquals(positionLinks.next(4, 0, TEST_PAGE), 5); + assertEquals(positionLinks.next(5, 0, TEST_PAGE), 6); + assertEquals(positionLinks.next(6, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(7, 0, TEST_PAGE), 7); + assertEquals(positionLinks.next(7, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(8, 0, TEST_PAGE), 8); + assertEquals(positionLinks.next(8, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(9, 0, TEST_PAGE), 9); + assertEquals(positionLinks.next(9, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(10, 0, TEST_PAGE), 10); + assertEquals(positionLinks.next(10, 0, TEST_PAGE), 11); + assertEquals(positionLinks.next(11, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(13, 0, TEST_PAGE), -1); + } + + @Test + public void testSortedPositionLinksForRangePredicatesPrefixMatch() + { + JoinFilterFunction filterFunctionOne = (leftAddress, rightPosition, rightPage) -> BIGINT.getLong(rightPage.getBlock(0), leftAddress) >= 0; + + JoinFilterFunction filterFunctionTwo = (leftAddress, rightPosition, rightPage) -> BIGINT.getLong(rightPage.getBlock(0), leftAddress) <= 11; + + PositionLinks.FactoryBuilder factoryBuilder = buildSortedPositionLinks(); + PositionLinks positionLinks = factoryBuilder.build().create(ImmutableList.of(filterFunctionOne, filterFunctionTwo)); + + assertEquals(positionLinks.start(0, 0, TEST_PAGE), 0); + assertEquals(positionLinks.next(0, 0, TEST_PAGE), 1); + assertEquals(positionLinks.next(1, 0, TEST_PAGE), 2); + assertEquals(positionLinks.next(2, 0, TEST_PAGE), 3); + assertEquals(positionLinks.next(3, 0, TEST_PAGE), 4); + assertEquals(positionLinks.next(4, 0, TEST_PAGE), 5); + assertEquals(positionLinks.next(5, 0, TEST_PAGE), 6); + assertEquals(positionLinks.next(6, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(7, 0, TEST_PAGE), 7); + assertEquals(positionLinks.next(7, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(8, 0, TEST_PAGE), 8); + assertEquals(positionLinks.next(8, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(9, 0, TEST_PAGE), 9); + assertEquals(positionLinks.next(9, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(10, 0, TEST_PAGE), 10); + assertEquals(positionLinks.next(10, 0, TEST_PAGE), 11); + assertEquals(positionLinks.next(11, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(13, 0, TEST_PAGE), -1); + } + + @Test + public void testSortedPositionLinksForRangePredicatesSuffixMatch() + { + JoinFilterFunction filterFunctionOne = (leftAddress, rightPosition, rightPage) -> BIGINT.getLong(rightPage.getBlock(0), leftAddress) > 4; + + JoinFilterFunction filterFunctionTwo = (leftAddress, rightPosition, rightPage) -> BIGINT.getLong(rightPage.getBlock(0), leftAddress) < 100; + + PositionLinks.FactoryBuilder factoryBuilder = buildSortedPositionLinks(); + PositionLinks positionLinks = factoryBuilder.build().create(ImmutableList.of(filterFunctionOne, filterFunctionTwo)); + + assertEquals(positionLinks.start(0, 0, TEST_PAGE), 5); + assertEquals(positionLinks.next(4, 0, TEST_PAGE), 5); + assertEquals(positionLinks.next(5, 0, TEST_PAGE), 6); + assertEquals(positionLinks.next(6, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(7, 0, TEST_PAGE), 7); + assertEquals(positionLinks.next(7, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(8, 0, TEST_PAGE), 8); + assertEquals(positionLinks.next(8, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(9, 0, TEST_PAGE), 9); + assertEquals(positionLinks.next(9, 0, TEST_PAGE), -1); + assertEquals(positionLinks.start(10, 0, TEST_PAGE), 10); assertEquals(positionLinks.next(10, 0, TEST_PAGE), 11); assertEquals(positionLinks.next(11, 0, TEST_PAGE), 12); assertEquals(positionLinks.next(12, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(13, 0, TEST_PAGE), 13); + assertEquals(positionLinks.next(13, 0, TEST_PAGE), -1); } @Test public void testReverseSortedPositionLinks() { - JoinFilterFunction filterFunction = new JoinFilterFunction() - { - @Override - public boolean filter(int leftAddress, int rightPosition, Page rightPage) - { - return BIGINT.getLong(rightPage.getBlock(0), leftAddress) < 4; - } - - @Override - public Optional getSortChannel() - { - throw new UnsupportedOperationException(); - } - }; + JoinFilterFunction filterFunction = (leftAddress, rightPosition, rightPage) -> + BIGINT.getLong(TEST_PAGE.getBlock(0), leftAddress) < 4; PositionLinks.FactoryBuilder factoryBuilder = buildSortedPositionLinks(); - PositionLinks positionLinks = factoryBuilder.build().create(Optional.of(filterFunction)); + PositionLinks positionLinks = factoryBuilder.build().create(ImmutableList.of(filterFunction)); assertEquals(positionLinks.start(0, 0, TEST_PAGE), 0); assertEquals(positionLinks.next(0, 0, TEST_PAGE), 1); @@ -119,6 +241,41 @@ public Optional getSortChannel() assertEquals(positionLinks.start(10, 0, TEST_PAGE), -1); } + @Test + public void testReverseSortedPositionLinksAllMatch() + { + JoinFilterFunction filterFunction = (leftAddress, rightPosition, rightPage) -> + BIGINT.getLong(rightPage.getBlock(0), leftAddress) < 13; + + PositionLinks.FactoryBuilder factoryBuilder = buildSortedPositionLinks(); + PositionLinks positionLinks = factoryBuilder.build().create(ImmutableList.of(filterFunction)); + + assertEquals(positionLinks.start(0, 0, TEST_PAGE), 0); + assertEquals(positionLinks.next(0, 0, TEST_PAGE), 1); + assertEquals(positionLinks.next(1, 0, TEST_PAGE), 2); + assertEquals(positionLinks.next(2, 0, TEST_PAGE), 3); + assertEquals(positionLinks.next(3, 0, TEST_PAGE), 4); + assertEquals(positionLinks.next(4, 0, TEST_PAGE), 5); + assertEquals(positionLinks.next(5, 0, TEST_PAGE), 6); + assertEquals(positionLinks.next(6, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(7, 0, TEST_PAGE), 7); + assertEquals(positionLinks.next(7, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(8, 0, TEST_PAGE), 8); + assertEquals(positionLinks.next(8, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(9, 0, TEST_PAGE), 9); + assertEquals(positionLinks.next(9, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(10, 0, TEST_PAGE), 10); + assertEquals(positionLinks.next(10, 0, TEST_PAGE), 11); + assertEquals(positionLinks.next(11, 0, TEST_PAGE), 12); + assertEquals(positionLinks.next(12, 0, TEST_PAGE), -1); + + assertEquals(positionLinks.start(13, 0, TEST_PAGE), -1); + } + private static PositionLinks.FactoryBuilder buildSortedPositionLinks() { SortedPositionLinks.FactoryBuilder builder = SortedPositionLinks.builder( @@ -126,6 +283,13 @@ private static PositionLinks.FactoryBuilder buildSortedPositionLinks() pagesHashStrategy(), addresses()); + /* + * Built sorted positions links + * + * [0] -> [1,2,3,4,5,6] + * [10] -> [11,12] + */ + assertEquals(builder.link(4, 5), 4); assertEquals(builder.link(6, 4), 4); assertEquals(builder.link(2, 4), 2); @@ -147,7 +311,7 @@ private static PagesHashStrategy pagesHashStrategy() ImmutableList.of(ImmutableList.of(TEST_PAGE.getBlock(0))), ImmutableList.of(), Optional.empty(), - Optional.of(new SortExpression(0))); + Optional.of(0)); } private static LongArrayList addresses() diff --git a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinProbeCompiler.java b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinProbeCompiler.java index f65a11c50e95..ded2dfddac65 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinProbeCompiler.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/gen/TestJoinProbeCompiler.java @@ -127,7 +127,9 @@ public void testSingleChannel(boolean hashEnabled) addresses, channels, hashChannel, - Optional.empty()) + Optional.empty(), + Optional.empty(), + ImmutableList.of()) .get(); JoinProbeCompiler joinProbeCompiler = new JoinProbeCompiler(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java index e5ee91cf90db..5fbacc55b0e4 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestSortExpressionExtractor.java @@ -13,20 +13,20 @@ */ package com.facebook.presto.sql.planner; -import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; -import com.facebook.presto.sql.tree.ComparisonExpression; -import com.facebook.presto.sql.tree.ComparisonExpressionType; +import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.FunctionCall; -import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.SymbolReference; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import org.testng.annotations.Test; +import java.util.Arrays; +import java.util.List; import java.util.Optional; import java.util.Set; +import static com.facebook.presto.sql.ExpressionUtils.extractConjuncts; +import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; +import static com.google.common.collect.ImmutableList.toImmutableList; import static org.testng.Assert.assertEquals; public class TestSortExpressionExtractor @@ -36,63 +36,81 @@ public class TestSortExpressionExtractor @Test public void testGetSortExpression() { - assertGetSortExpression( - new ComparisonExpression( - ComparisonExpressionType.GREATER_THAN, - new SymbolReference("p1"), - new SymbolReference("b1")), - "b1"); - - assertGetSortExpression( - new ComparisonExpression( - ComparisonExpressionType.LESS_THAN_OR_EQUAL, - new SymbolReference("b2"), - new SymbolReference("p1")), - "b2"); - - assertGetSortExpression( - new ComparisonExpression( - ComparisonExpressionType.GREATER_THAN, - new SymbolReference("b2"), - new SymbolReference("p1")), - "b2"); - - assertGetSortExpression( - new ComparisonExpression( - ComparisonExpressionType.GREATER_THAN, - new SymbolReference("b2"), - new FunctionCall(QualifiedName.of("sin"), ImmutableList.of(new SymbolReference("p1")))), - "b2"); - - assertGetSortExpression( - new ComparisonExpression( - ComparisonExpressionType.GREATER_THAN, - new SymbolReference("b2"), - new FunctionCall(QualifiedName.of("random"), ImmutableList.of(new SymbolReference("p1"))))); - - assertGetSortExpression( - new ComparisonExpression( - ComparisonExpressionType.GREATER_THAN, - new SymbolReference("b1"), - new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Type.ADD, new SymbolReference("b2"), new SymbolReference("p1")))); - - assertGetSortExpression( - new ComparisonExpression( - ComparisonExpressionType.GREATER_THAN, - new FunctionCall(QualifiedName.of("sin"), ImmutableList.of(new SymbolReference("b1"))), - new SymbolReference("p1"))); + assertGetSortExpression("p1 > b1", "b1"); + + assertGetSortExpression("b2 <= p1", "b2"); + + assertGetSortExpression("b2 > p1", "b2"); + + assertGetSortExpression("b2 > sin(p1)", "b2"); + + assertNoSortExpression("b2 > random(p1)"); + + assertGetSortExpression("b2 > random(p1) AND b2 > p1", "b2", "b2 > p1"); + + assertGetSortExpression("b2 > random(p1) AND b1 > p1", "b1", "b1 > p1"); + + assertNoSortExpression("b1 > p1 + b2"); + + assertNoSortExpression("sin(b1) > p1"); + + assertNoSortExpression("b1 <= p1 OR b2 <= p1"); + + assertNoSortExpression("sin(b2) > p1 AND (b2 <= p1 OR b2 <= p1 + 10)"); + + assertGetSortExpression("sin(b2) > p1 AND (b2 <= p1 AND b2 <= p1 + 10)", "b2", "b2 <= p1", "b2 <= p1 + 10"); + + assertGetSortExpression("b1 > p1 AND b1 <= p1", "b1"); + + assertGetSortExpression("b1 > p1 AND b1 <= p1 AND b2 > p1", "b1", "b1 > p1", "b1 <= p1"); + + assertGetSortExpression("b1 > p1 AND b1 <= p1 AND b2 > p1 AND b2 < p1 + 10 AND b2 > p2", "b2", "b2 > p1", "b2 < p1 + 10", "b2 > p2"); + } + + private Expression expression(String sql) + { + return rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(sql)); + } + + private void assertNoSortExpression(String expression) + { + assertNoSortExpression(expression(expression)); } - private static void assertGetSortExpression(Expression expression) + private void assertNoSortExpression(Expression expression) { - Optional actual = SortExpressionExtractor.extractSortExpression(BUILD_SYMBOLS, expression); + Optional actual = SortExpressionExtractor.extractSortExpression(BUILD_SYMBOLS, expression); assertEquals(actual, Optional.empty()); } - private static void assertGetSortExpression(Expression expression, String expectedSymbol) + private void assertGetSortExpression(String expression, String expectedSymbol) + { + assertGetSortExpression(expression(expression), expectedSymbol); + } + + private void assertGetSortExpression(Expression expression, String expectedSymbol) + { + // for now we expect that search expressions contain all the conjuncts from filterExpression as more complex cases are not supported yet. + assertGetSortExpression(expression, expectedSymbol, extractConjuncts(expression)); + } + + private void assertGetSortExpression(String expression, String expectedSymbol, String... searchExpressions) + { + assertGetSortExpression(expression(expression), expectedSymbol, searchExpressions); + } + + private void assertGetSortExpression(Expression expression, String expectedSymbol, String... searchExpressions) + { + List searchExpressionList = Arrays.stream(searchExpressions) + .map(this::expression) + .collect(toImmutableList()); + assertGetSortExpression(expression, expectedSymbol, searchExpressionList); + } + + private static void assertGetSortExpression(Expression expression, String expectedSymbol, List searchExpressions) { - Optional expected = Optional.of(new SymbolReference(expectedSymbol)); - Optional actual = SortExpressionExtractor.extractSortExpression(BUILD_SYMBOLS, expression); + Optional expected = Optional.of(new SortExpressionContext(new SymbolReference(expectedSymbol), searchExpressions)); + Optional actual = SortExpressionExtractor.extractSortExpression(BUILD_SYMBOLS, expression); assertEquals(actual, expected); } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 28961c2e46b0..0ab70b674e88 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -2312,6 +2312,20 @@ public void testJoinWithLessThanInJoinClause() "VALUES -1"); // test with only null value in build side assertQuery("SELECT b FROM nation n, (VALUES (0, NULL)) t(a, b) WHERE n.regionkey - 100 < t.b AND n.nationkey = t.a", "SELECT 1 WHERE FALSE"); + // test with function predicate in ON clause + assertQuery("SELECT n.nationkey, r.regionkey FROM nation n JOIN region r ON n.regionkey = r.regionkey AND length(n.name) < length(substr(r.name, 5))"); + + assertQuery("SELECT * FROM " + + "(VALUES (1,1),(2,1)) t1(a,b), " + + "(VALUES (1,1),(1,2),(2,1)) t2(x,y) " + + "WHERE a=x and b<=y", + "VALUES (1,1,1,1), (1,1,1,2), (2,1,2,1)"); + + assertQuery("SELECT * FROM " + + "(VALUES (1,1),(2,1)) t1(a,b), " + + "(VALUES (1,1),(1,2),(2,1)) t2(x,y) " + + "WHERE a=x and b t.b AND n.nationkey = t.a", "SELECT 1 WHERE FALSE"); + /// test with function predicate in ON clause + assertQuery("SELECT n.nationkey, r.regionkey FROM nation n JOIN region r ON n.regionkey = r.regionkey AND length(n.name) > length(substr(r.name, 5))"); + + assertQuery("SELECT * FROM " + + "(VALUES (1,1),(2,1)) t1(a,b), " + + "(VALUES (1,1),(1,2),(2,1)) t2(x,y) " + + "WHERE a=x and b>=y", + "VALUES (1,1,1,1), (2,1,2,1)"); + + assertQuery("SELECT * FROM " + + "(VALUES (1,1),(2,1)) t1(a,b), " + + "(VALUES (1,1),(1,2),(2,1)) t2(x,y) " + + "WHERE a=x and b>y", + "SELECT 1 WHERE FALSE"); + } + + @Test + public void testJoinWithRangePredicatesinJoinClause() + { + assertQuery("SELECT COUNT(*) " + + "FROM (SELECT * FROM lineitem WHERE orderkey % 16 = 0 AND partkey % 2 = 0) lineitem " + + "JOIN (SELECT * FROM orders WHERE orderkey % 16 = 0 AND custkey % 2 = 0) orders " + + "ON lineitem.orderkey % 8 = orders.orderkey % 8 AND lineitem.linenumber % 2 = 0 " + + "AND orders.custkey % 8 < 7 AND lineitem.suppkey % 10 < orders.custkey % 7 AND lineitem.suppkey % 7 > orders.custkey % 7"); + + assertQuery("SELECT COUNT(*) " + + "FROM (SELECT * FROM lineitem WHERE orderkey % 16 = 0 AND partkey % 2 = 0) lineitem " + + "JOIN (SELECT * FROM orders WHERE orderkey % 16 = 0 AND custkey % 2 = 0) orders " + + "ON lineitem.orderkey % 8 = orders.orderkey % 8 AND lineitem.linenumber % 2 = 0 " + + "AND orders.custkey % 8 < lineitem.linenumber % 2 AND lineitem.suppkey % 10 < orders.custkey % 7 AND lineitem.suppkey % 7 > orders.custkey % 7"); + } + + @Test + public void testJoinWithMultipleLessThanPredicatesDifferentOrders() + { + // test that fast inequality join is not sensitive to order of search conjuncts. + assertQuery("SELECT count(*) FROM lineitem l JOIN nation n ON l.suppkey % 5 = n.nationkey % 5 AND l.partkey % 3 < n.regionkey AND l.partkey % 3 + 1 < n.regionkey AND l.partkey % 3 + 2 < n.regionkey"); + assertQuery("SELECT count(*) FROM lineitem l JOIN nation n ON l.suppkey % 5 = n.nationkey % 5 AND l.partkey % 3 + 2 < n.regionkey AND l.partkey % 3 + 1 < n.regionkey AND l.partkey % 3 < n.regionkey"); + assertQuery("SELECT count(*) FROM lineitem l JOIN nation n ON l.suppkey % 5 = n.nationkey % 5 AND l.partkey % 3 > n.regionkey AND l.partkey % 3 + 1 > n.regionkey AND l.partkey % 3 + 2 > n.regionkey"); + assertQuery("SELECT count(*) FROM lineitem l JOIN nation n ON l.suppkey % 5 = n.nationkey % 5 AND l.partkey % 3 + 2 > n.regionkey AND l.partkey % 3 + 1 > n.regionkey AND l.partkey % 3 > n.regionkey"); } @Test @@ -2340,6 +2394,12 @@ public void testJoinWithLessThanOnDatesInJoinClause() assertQuery( "SELECT o.orderkey, o.orderdate, l.shipdate FROM orders o JOIN lineitem l ON l.orderkey = o.orderkey AND l.shipdate < o.orderdate + INTERVAL '10' DAY", "SELECT o.orderkey, o.orderdate, l.shipdate FROM orders o JOIN lineitem l ON l.orderkey = o.orderkey AND l.shipdate < DATEADD('DAY', 10, o.orderdate)"); + assertQuery( + "SELECT o.orderkey, o.orderdate, l.shipdate FROM lineitem l JOIN orders o ON l.orderkey = o.orderkey AND l.shipdate < DATE_ADD('DAY', 10, o.orderdate)", + "SELECT o.orderkey, o.orderdate, l.shipdate FROM orders o JOIN lineitem l ON l.orderkey = o.orderkey AND l.shipdate < DATEADD('DAY', 10, o.orderdate)"); + assertQuery( + "SELECT o.orderkey, o.orderdate, l.shipdate FROM orders o JOIN lineitem l ON o.orderkey=l.orderkey AND o.orderdate + INTERVAL '2' DAY <= l.shipdate AND l.shipdate < o.orderdate + INTERVAL '7' DAY", + "SELECT o.orderkey, o.orderdate, l.shipdate FROM orders o JOIN lineitem l ON o.orderkey=l.orderkey AND DATEADD('DAY', 2, o.orderdate) <= l.shipdate AND l.shipdate < DATEADD('DAY', 7, o.orderdate)"); } @Test