Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend fast inequality join #8614

Merged
merged 20 commits into from
Sep 11, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
e07fffc
Use proper session in MemoryLocalQueryRunner
losipiuk Sep 11, 2017
0d34272
Remove unused field from inequality join optimization
Sep 11, 2017
6db0b36
Simplify javadoc for SortedPositionLinks and SortExpressionExtractor
losipiuk Sep 11, 2017
2a10ce1
Update JoinFilterCacheKey equals and hashCode
Sep 11, 2017
db2e844
Add tests for non-equi join condition optimization
Sep 11, 2017
16dab33
Rename SortExpression to RowSortExpressionContext
losipiuk Sep 11, 2017
b1911bd
Move RowSortExpressionContext to top level
losipiuk Sep 11, 2017
a50fb91
Add SortExpressionContext
losipiuk Sep 11, 2017
adb2cc6
Explicitly pass sortChannel where needed and drop RowSortExpressionCo…
losipiuk Sep 11, 2017
d6858e8
Use explicit searchExpression in SortExpressionContext
losipiuk Sep 11, 2017
8a2a194
Rename lessThanFunction to searchFunction in SortedPositionLinks
losipiuk Sep 11, 2017
081f6c4
Refactor unit tests to use expression utility
Sep 11, 2017
2719246
Extend non-equi join optimization to support range predicates
Sep 11, 2017
6a64e5d
Remove extranous Javadoc
losipiuk Sep 11, 2017
26b7c9a
Add comment to TestPositionLinks
losipiuk Sep 11, 2017
e567484
Replace anonymous classess with lambdas
losipiuk Sep 11, 2017
8cc7cd6
Use TEST_PAGE instead rightPage in TestPositionLinks
losipiuk Sep 11, 2017
34c51de
Extend inequality testcases in TestPositionLinks
losipiuk Sep 11, 2017
0f74747
Extract sort expressions from complex join filters
losipiuk Sep 11, 2017
318332a
Allow some conjuncts to be nondeterministic in SortExpressionExtractor
losipiuk Sep 11, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ protected List<Driver> createDrivers(TaskContext taskContext)
}

// hash build
HashBuilderOperatorFactory hashBuilder = new HashBuilderOperatorFactory(2, new PlanNodeId("test"), source.getTypes(), ImmutableList.of(0, 1), ImmutableMap.of(), Ints.asList(0), hashChannel, false, Optional.empty(), 1_500_000, 1, new PagesIndex.TestingFactory());
HashBuilderOperatorFactory hashBuilder = new HashBuilderOperatorFactory(2, new PlanNodeId("test"), source.getTypes(), ImmutableList.of(0, 1), ImmutableMap.of(), Ints.asList(0), hashChannel, false, Optional.empty(), Optional.empty(), ImmutableList.of(), 1_500_000, 1, new PagesIndex.TestingFactory());
driversBuilder.add(hashBuilder);
DriverFactory hashBuildDriverFactory = new DriverFactory(0, true, false, driversBuilder.build(), OptionalInt.empty());
Driver hashBuildDriver = hashBuildDriverFactory.createDriver(taskContext.addPipelineContext(0, true, false).addDriverContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ protected List<Driver> createDrivers(TaskContext taskContext)
Optional.empty(),
false,
Optional.empty(),
Optional.empty(),
ImmutableList.of(),
1_500_000,
1,
new PagesIndex.TestingFactory());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ protected List<Driver> createDrivers(TaskContext taskContext)
Optional.empty(),
false,
Optional.empty(),
Optional.empty(),
ImmutableList.of(),
1_500_000,
1,
new PagesIndex.TestingFactory());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,27 @@ public List<Page> benchmarkJoin(Context context)
.execute("SELECT count(*) FROM t1 JOIN t2 on (t1.bucket = t2.bucket) WHERE t1.val1 < t2.val2");
}

@Benchmark
public List<Page> benchmarkJoinWithArithmeticInPredicate(Context context)
{
return context.getQueryRunner()
.execute("SELECT count(*) FROM t1 JOIN t2 on (t1.bucket = t2.bucket) AND t1.val1 < t2.val2 + 10");
}

@Benchmark
public List<Page> benchmarkJoinWithFunctionPredicate(Context context)
{
return context.getQueryRunner()
.execute("SELECT count(*) FROM t1 JOIN t2 on (t1.bucket = t2.bucket) AND t1.val1 < sin(t2.val2)");
}

@Benchmark
public List<Page> benchmarkRangePredicateJoin(Context context)
{
return context.getQueryRunner()
.execute("SELECT count(*) FROM t1 JOIN t2 on (t1.bucket = t2.bucket) AND t1.val1 + 1 < t2.val2 AND t2.val2 < t1.val1 + 5 ");
}

public static void main(String[] args)
throws RunnerException
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
public class MemoryLocalQueryRunner
{
protected final LocalQueryRunner localQueryRunner;
protected final Session session;

public MemoryLocalQueryRunner()
{
Expand All @@ -56,8 +55,7 @@ public MemoryLocalQueryRunner(Map<String, String> properties)
.setSchema("default");
properties.forEach(sessionBuilder::setSystemProperty);

session = sessionBuilder.build();
localQueryRunner = createMemoryLocalQueryRunner(session);
localQueryRunner = createMemoryLocalQueryRunner(sessionBuilder.build());
}

public List<Page> execute(@Language("SQL") String query)
Expand All @@ -68,7 +66,7 @@ public List<Page> execute(@Language("SQL") String query)
SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(new DataSize(1, GIGABYTE));
TaskContext taskContext = new QueryContext(new QueryId("test"), new DataSize(1, GIGABYTE), memoryPool, systemMemoryPool, localQueryRunner.getExecutor(), localQueryRunner.getScheduler(), new DataSize(4, GIGABYTE), spillSpaceTracker)
.addTaskContext(new TaskStateMachine(new TaskId("query", 0, 0), localQueryRunner.getExecutor()),
session,
localQueryRunner.getDefaultSession(),
false,
false);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public int link(int left, int right)
@Override
public Factory build()
{
return filterFunction -> new ArrayPositionLinks(positionLinks);
return searchFunctions -> new ArrayPositionLinks(positionLinks);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

belongs to "Refactor: rename lessThanFunction to searchFunction in SortedPosition…" commit

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ public static class HashBuilderOperatorFactory
private final List<Integer> hashChannels;
private final Optional<Integer> preComputedHashChannel;
private final Optional<JoinFilterFunctionFactory> filterFunctionFactory;
private final Optional<Integer> sortChannel;
private final List<JoinFilterFunctionFactory> searchFunctionFactories;
private final PagesIndex.Factory pagesIndexFactory;

private final int expectedPositions;
Expand All @@ -63,13 +65,17 @@ public HashBuilderOperatorFactory(
Optional<Integer> preComputedHashChannel,
boolean outer,
Optional<JoinFilterFunctionFactory> filterFunctionFactory,
Optional<Integer> sortChannel,
List<JoinFilterFunctionFactory> searchFunctionFactories,
int expectedPositions,
int partitionCount,
PagesIndex.Factory pagesIndexFactory)
{
this.operatorId = operatorId;
this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");

requireNonNull(sortChannel, "sortChannel can not be null");
requireNonNull(searchFunctionFactories, "searchFunctionFactories is null");
checkArgument(sortChannel.isPresent() != searchFunctionFactories.isEmpty(), "both or none sortChannel and searchFunctionFactories must be set");
checkArgument(Integer.bitCount(partitionCount) == 1, "partitionCount must be a power of 2");
lookupSourceFactory = new PartitionedLookupSourceFactory(
types,
Expand All @@ -85,6 +91,8 @@ public HashBuilderOperatorFactory(
this.hashChannels = ImmutableList.copyOf(requireNonNull(hashChannels, "hashChannels is null"));
this.preComputedHashChannel = requireNonNull(preComputedHashChannel, "preComputedHashChannel is null");
this.filterFunctionFactory = requireNonNull(filterFunctionFactory, "filterFunctionFactory is null");
this.sortChannel = sortChannel;
this.searchFunctionFactories = ImmutableList.copyOf(searchFunctionFactories);
this.pagesIndexFactory = requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");

this.expectedPositions = expectedPositions;
Expand Down Expand Up @@ -114,6 +122,8 @@ public Operator createOperator(DriverContext driverContext)
hashChannels,
preComputedHashChannel,
filterFunctionFactory,
sortChannel,
searchFunctionFactories,
expectedPositions,
pagesIndexFactory);

Expand Down Expand Up @@ -142,6 +152,8 @@ public OperatorFactory duplicate()
private final List<Integer> hashChannels;
private final Optional<Integer> preComputedHashChannel;
private final Optional<JoinFilterFunctionFactory> filterFunctionFactory;
private final Optional<Integer> sortChannel;
private final List<JoinFilterFunctionFactory> searchFunctionFactories;

private final PagesIndex index;

Expand All @@ -156,6 +168,8 @@ public HashBuilderOperator(
List<Integer> hashChannels,
Optional<Integer> preComputedHashChannel,
Optional<JoinFilterFunctionFactory> filterFunctionFactory,
Optional<Integer> sortChannel,
List<JoinFilterFunctionFactory> searchFunctionFactories,
int expectedPositions,
PagesIndex.Factory pagesIndexFactory)
{
Expand All @@ -164,6 +178,8 @@ public HashBuilderOperator(
this.operatorContext = operatorContext;
this.partitionIndex = partitionIndex;
this.filterFunctionFactory = filterFunctionFactory;
this.sortChannel = sortChannel;
this.searchFunctionFactories = searchFunctionFactories;

this.index = pagesIndexFactory.newPagesIndex(lookupSourceFactory.getTypes(), expectedPositions);
this.lookupSourceFactory = lookupSourceFactory;
Expand Down Expand Up @@ -196,7 +212,7 @@ public void finish()
}
finishing = true;

LookupSourceSupplier partition = index.createLookupSourceSupplier(operatorContext.getSession(), hashChannels, preComputedHashChannel, filterFunctionFactory, Optional.of(outputChannels));
LookupSourceSupplier partition = index.createLookupSourceSupplier(operatorContext.getSession(), hashChannels, preComputedHashChannel, filterFunctionFactory, sortChannel, searchFunctionFactories, Optional.of(outputChannels));
lookupSourceFactory.setPartitionLookupSourceSupplier(partitionIndex, partition);

operatorContext.setMemoryReservation(partition.get().getInMemorySizeInBytes());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,8 @@

import javax.annotation.concurrent.NotThreadSafe;

import java.util.Optional;

@NotThreadSafe
public interface JoinFilterFunction
{
boolean filter(int leftAddress, int rightPosition, Page rightPage);

Optional<Integer> getSortChannel();
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
import com.facebook.presto.Session;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory;
import com.google.common.collect.ImmutableList;
import it.unimi.dsi.fastutil.longs.LongArrayList;

import java.util.List;
import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.isFastInequalityJoin;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

public class JoinHashSupplier
Expand All @@ -33,24 +36,28 @@ public class JoinHashSupplier
private final List<List<Block>> channels;
private final Optional<PositionLinks.Factory> positionLinks;
private final Optional<JoinFilterFunctionFactory> filterFunctionFactory;
private final List<JoinFilterFunctionFactory> searchFunctionFactories;

public JoinHashSupplier(
Session session,
PagesHashStrategy pagesHashStrategy,
LongArrayList addresses,
List<List<Block>> channels,
Optional<JoinFilterFunctionFactory> filterFunctionFactory)
Optional<JoinFilterFunctionFactory> filterFunctionFactory,
Optional<Integer> sortChannel,
List<JoinFilterFunctionFactory> searchFunctionFactories)
{
this.session = requireNonNull(session, "session is null");
this.addresses = requireNonNull(addresses, "addresses is null");
this.channels = requireNonNull(channels, "channels is null");
this.filterFunctionFactory = requireNonNull(filterFunctionFactory, "filterFunctionFactory is null");
this.searchFunctionFactories = ImmutableList.copyOf(searchFunctionFactories);
requireNonNull(pagesHashStrategy, "pagesHashStrategy is null");

PositionLinks.FactoryBuilder positionLinksFactoryBuilder;
if (filterFunctionFactory.isPresent() &&
filterFunctionFactory.get().getSortChannel().isPresent() &&
if (sortChannel.isPresent() &&
isFastInequalityJoin(session)) {
checkArgument(filterFunctionFactory.isPresent(), "filterFunctionFactory not set while sortChannel set");
positionLinksFactoryBuilder = SortedPositionLinks.builder(
addresses.size(),
pagesHashStrategy,
Expand Down Expand Up @@ -86,6 +93,11 @@ public JoinHash get()
return new JoinHash(
pagesHash,
filterFunction,
positionLinks.map(links -> links.create(filterFunction)));
positionLinks.map(links -> {
List<JoinFilterFunction> searchFunctions = searchFunctionFactories.stream()
.map(factory -> factory.create(session.toConnectorSession(), addresses, channels))
.collect(toImmutableList());
return links.create(searchFunctions);
}));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import com.facebook.presto.sql.gen.JoinCompiler.LookupSourceSupplierFactory;
import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler.JoinFilterFunctionFactory;
import com.facebook.presto.sql.gen.OrderingCompiler;
import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression;
import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
Expand Down Expand Up @@ -373,7 +372,7 @@ private PagesIndexOrdering createPagesIndexComparator(List<Integer> sortChannels

public Supplier<LookupSource> createLookupSourceSupplier(Session session, List<Integer> joinChannels)
{
return createLookupSourceSupplier(session, joinChannels, Optional.empty(), Optional.empty(), Optional.empty());
return createLookupSourceSupplier(session, joinChannels, Optional.empty(), Optional.empty(), Optional.empty(), ImmutableList.of());
}

public PagesHashStrategy createPagesHashStrategy(List<Integer> joinChannels, Optional<Integer> hashChannel)
Expand Down Expand Up @@ -405,16 +404,20 @@ public LookupSourceSupplier createLookupSourceSupplier(
Session session,
List<Integer> joinChannels,
Optional<Integer> hashChannel,
Optional<JoinFilterFunctionFactory> filterFunctionFactory)
Optional<JoinFilterFunctionFactory> filterFunctionFactory,
Optional<Integer> sortChannel,
List<JoinFilterFunctionFactory> searchFunctionFactories)
{
return createLookupSourceSupplier(session, joinChannels, hashChannel, filterFunctionFactory, Optional.empty());
return createLookupSourceSupplier(session, joinChannels, hashChannel, filterFunctionFactory, sortChannel, searchFunctionFactories, Optional.empty());
}

public LookupSourceSupplier createLookupSourceSupplier(
Session session,
List<Integer> joinChannels,
Optional<Integer> hashChannel,
Optional<JoinFilterFunctionFactory> filterFunctionFactory,
Optional<Integer> sortChannel,
List<JoinFilterFunctionFactory> searchFunctionFactories,
Optional<List<Integer>> outputChannels)
{
List<List<Block>> channels = ImmutableList.copyOf(this.channels);
Expand All @@ -424,17 +427,15 @@ public LookupSourceSupplier createLookupSourceSupplier(
// OUTER joins into NestedLoopsJoin and remove "type == INNER" condition in LocalExecutionPlanner.visitJoin()

try {
Optional<SortExpression> sortChannel = Optional.empty();
if (filterFunctionFactory.isPresent()) {
sortChannel = filterFunctionFactory.get().getSortChannel();
}
LookupSourceSupplierFactory lookupSourceFactory = joinCompiler.compileLookupSourceFactory(types, joinChannels, sortChannel, outputChannels);
return lookupSourceFactory.createLookupSourceSupplier(
session,
valueAddresses,
channels,
hashChannel,
filterFunctionFactory);
filterFunctionFactory,
sortChannel,
searchFunctionFactories);
}
catch (Exception e) {
log.error(e, "Lookup source compile failed for types=%s error=%s", types, e);
Expand All @@ -448,14 +449,16 @@ public LookupSourceSupplier createLookupSourceSupplier(
channels,
joinChannels,
hashChannel,
filterFunctionFactory.map(JoinFilterFunctionFactory::getSortChannel).orElse(Optional.empty()));
sortChannel);

return new JoinHashSupplier(
session,
hashStrategy,
valueAddresses,
channels,
filterFunctionFactory);
filterFunctionFactory,
sortChannel,
searchFunctionFactories);
}

private List<Integer> rangeList(int endExclusive)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import com.facebook.presto.spi.Page;

import java.util.Optional;
import java.util.List;

/**
* This class is responsible for iterating over build rows, which have
Expand Down Expand Up @@ -46,10 +46,6 @@ interface FactoryBuilder
*/
int link(int left, int right);

/**
* JoinFilterFunction has to be created and supplied for each thread using PositionLinks
* since JoinFilterFunction is not thread safe...
*/
Factory build();

/**
Expand All @@ -66,9 +62,9 @@ default boolean isEmpty()
interface Factory
{
/**
* JoinFilterFunction has to be created and supplied for each thread using PositionLinks
* Separate JoinFilterFunctions have to be created and supplied for each thread using PositionLinks
* since JoinFilterFunction is not thread safe...
*/
PositionLinks create(Optional<JoinFilterFunction> joinFilterFunction);
PositionLinks create(List<JoinFilterFunction> searchFunctions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.planner.SortExpressionExtractor.SortExpression;
import com.facebook.presto.type.TypeUtils;
import com.google.common.collect.ImmutableList;
import org.openjdk.jol.info.ClassLayout;
Expand All @@ -38,15 +37,15 @@ public class SimplePagesHashStrategy
private final List<List<Block>> channels;
private final List<Integer> hashChannels;
private final List<Block> precomputedHashChannel;
private final Optional<SortExpression> sortChannel;
private final Optional<Integer> sortChannel;

public SimplePagesHashStrategy(
List<Type> types,
List<Integer> outputChannels,
List<List<Block>> channels,
List<Integer> hashChannels,
Optional<Integer> precomputedHashChannel,
Optional<SortExpression> sortChannel)
Optional<Integer> sortChannel)
{
this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
this.outputChannels = ImmutableList.copyOf(requireNonNull(outputChannels, "outputChannels is null"));
Expand Down Expand Up @@ -245,9 +244,6 @@ private boolean isChannelPositionNull(int channelIndex, int blockIndex, int bloc

private int getSortChannel()
{
if (!sortChannel.isPresent()) {
throw new UnsupportedOperationException();
}
return sortChannel.get().getChannel();
return sortChannel.get();
}
}
Loading