Skip to content

Commit

Permalink
[compiler] Iterative DistinctlyKeyed Analysis (#12696)
Browse files Browse the repository at this point in the history
* [compiler] Iterative `DistinctlyKeyed` Analysis
Use iterative tree traversals to prevent exceeding stack size for large IRs.

* traverse all ir nodes
  • Loading branch information
ehigham authored Feb 15, 2023
1 parent 0f85414 commit 1fe6f2a
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 55 deletions.
94 changes: 39 additions & 55 deletions hail/src/main/scala/is/hail/expr/ir/DistinctlyKeyed.scala
Original file line number Diff line number Diff line change
@@ -1,71 +1,55 @@
package is.hail.expr.ir

// Analysis that walks an IR tree and records, in a Memo[Unit], every node that
// one of the cases below binds; the result is wrapped in a DistinctKeyedAnalysis.
// NOTE(review): this span is a rendered diff without +/- markers (39 additions,
// 55 deletions). Lines of the removed recursive `analyze` implementation and of
// the added `IRTraversal.postOrder` implementation appear interleaved, so it is
// not compilable as shown -- confirm against the real file before relying on it.
object DistinctlyKeyed {

def apply(node: BaseIR): DistinctKeyedAnalysis = {
val distinctMemo = Memo.empty[Unit]
analyze(node, distinctMemo)
DistinctKeyedAnalysis(distinctMemo)
}
// Recursive variant: each case recurses into the children of `node` through
// `analyze` and binds `node` in `memo` when that case's condition holds.
private def analyze(node:BaseIR, memo: Memo[Unit]): Unit = {
// Analyze every child first, then bind `node` only if all children were bound.
def basicChildrenCheck(children: IndexedSeq[BaseIR]): Unit = {
children.foreach(child => analyze(child, memo))
if (children.forall(child => memo.contains(child)))
memo.bind(node, ())
}
node match {
case TableLiteral(_,_,_,_) =>
case x@TableRead(_, _, _) =>
if(x.isDistinctlyKeyed)
memo.bind(node, ())
case TableParallelize(_,_) =>
case TableKeyBy(child, keys, _) =>
analyze(child, memo)
// Bound only when every key of the child is among the new keys and the child is bound.
if (child.typ.key.forall(cKey => keys.contains(cKey)) && memo.contains(child))
memo.bind(node, ())
case TableRange(_,_) => memo.bind(node, ())
case TableFilter(child, _) => basicChildrenCheck(IndexedSeq(child))
case TableHead(child, _) => basicChildrenCheck(IndexedSeq(child))
case TableTail(child, _) => basicChildrenCheck(IndexedSeq(child))
case TableRepartition(child, _, _) => basicChildrenCheck(IndexedSeq(child))
case TableJoin(left, right, _, _) => basicChildrenCheck(IndexedSeq(left, right))
case TableIntervalJoin(left, right, _, _) => basicChildrenCheck(IndexedSeq(left, right))
case TableMultiWayZipJoin(children, _, _) => basicChildrenCheck(children)
case TableLeftJoinRightDistinct(left, right, _) => basicChildrenCheck(IndexedSeq(left, right))
case TableMapPartitions(child, _, _, _) => analyze(child, memo)
case TableMapRows(child, _) => basicChildrenCheck(IndexedSeq(child))
case TableMapGlobals(child, _) => basicChildrenCheck(IndexedSeq(child))
case TableExplode(child, _) => analyze(child, memo)
case TableUnion(children) => children.foreach(child => analyze(child, memo))
case TableDistinct(child) =>
memo.bind(node, ())
analyze(child, memo)
case TableKeyByAndAggregate(child, _, _, _, _) =>
memo.bind(node, ())
analyze(child, memo)
case TableAggregateByKey(child, _) =>
memo.bind(node, ())
analyze(child, memo)
case TableOrderBy(child, _) => analyze(child, memo)
case TableRename(child, _, _) => basicChildrenCheck(IndexedSeq(child))
case TableFilterIntervals(child, _, _) => basicChildrenCheck(IndexedSeq(child))
case TableToTableApply(child, _) => analyze(child, memo)
case BlockMatrixToTableApply(_, _, _) =>
case BlockMatrixToTable(_) =>
case RelationalLetTable(_, _, body) => basicChildrenCheck(IndexedSeq(body))
// Iterative variant: nodes are taken from IRTraversal.postOrder instead of recursing,
// and each case binds the visited node itself based on `memo.contains` of its children.
val memo = Memo.empty[Unit]
IRTraversal.postOrder(node).foreach {
case t: TableRead =>
memo.bindIf(t.isDistinctlyKeyed, t, ())

case t@TableKeyBy(child, keys, _) =>
memo.bindIf(child.typ.key.forall(keys.contains) && memo.contains(child), t, ())

// These node types are bound based on their first child only.
case t@(_: TableFilter
| _: TableIntervalJoin
| _: TableLeftJoinRightDistinct
| _: TableMapRows
| _: TableMapGlobals) =>
memo.bindIf(memo.contains(t.children.head), t, ())

case t@RelationalLetTable(_, _, body) =>
memo.bindIf(memo.contains(body), t, ())

// These node types are bound only when all of their children are bound.
case t@(_: TableHead
| _: TableTail
| _: TableRepartition
| _: TableJoin
| _: TableMultiWayZipJoin
| _: TableRename
| _: TableFilterIntervals) =>
memo.bindIf(t.children.forall(memo.contains), t, ())

// These node types are bound unconditionally.
case t@(_: TableRange
| _: TableDistinct
| _: TableKeyByAndAggregate
| _: TableAggregateByKey) =>
memo.bind(t, ())

case _: MatrixIR =>
throw new IllegalArgumentException("MatrixIR should be lowered when it reaches distinct analysis")
case _: BlockMatrixIR =>
case ir: IR =>
ir.children.foreach(child => analyze(child, memo))
}

// Any other node type is left unbound.
case _ =>
memo
}
DistinctKeyedAnalysis(memo)
}

}

/** Read-only view over the memo produced by the analysis.
  *
  * @param distinctMemo memo whose bound nodes are the ones the analysis marked
  */
case class DistinctKeyedAnalysis(distinctMemo: Memo[Unit]) {

  /** Reports whether `tableIR` is bound in the underlying memo. */
  def contains(tableIR: BaseIR): Boolean =
    distinctMemo.contains(tableIR)

  /** Delegates to the memo's own string rendering. */
  override def toString: String =
    distinctMemo.toString
}
16 changes: 16 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/IRTraversal.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package is.hail.expr.ir

import is.hail.utils.TreeTraversal

/** Traversal orders over IR trees, expressed as functions from a root node to an
  * iterator of the nodes reachable through `children`.
  */
object IRTraversal {

  // Adjacency function shared by all traversals: a node's neighbours are its children.
  private val childrenOf: BaseIR => Iterator[BaseIR] =
    ir => ir.children.iterator

  /** Function producing the nodes under a root in `TreeTraversal.postOrder` order. */
  def postOrder: BaseIR => Iterator[BaseIR] =
    (root: BaseIR) => TreeTraversal.postOrder[BaseIR](childrenOf)(root)

  /** Function producing the nodes under a root in `TreeTraversal.preOrder` order. */
  def preOrder: BaseIR => Iterator[BaseIR] =
    (root: BaseIR) => TreeTraversal.preOrder[BaseIR](childrenOf)(root)

  /** Function producing the nodes under a root in `TreeTraversal.levelOrder` order. */
  def levelOrder: BaseIR => Iterator[BaseIR] =
    (root: BaseIR) => TreeTraversal.levelOrder[BaseIR](childrenOf)(root)

}
3 changes: 3 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/RefEquality.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ class Memo[T] private(val m: mutable.HashMap[RefEquality[BaseIR], T]) {
this
}

/** Conditional form of `bind`.
  *
  * @param test when false, nothing is bound and `t` is never evaluated (it is by-name)
  * @param ir   node to bind when `test` holds
  * @param t    value handed to `bind` when `test` holds
  * @return the result of `bind(ir, t)` when `test` holds, otherwise this memo
  */
def bindIf(test: Boolean, ir: BaseIR, t: => T): Memo[T] = {
  if (!test) this
  else bind(ir, t)
}

// Membership test for a node: wraps `ir` in RefEquality and defers to the overload below.
def contains(ir: BaseIR): Boolean = contains(RefEquality(ir))
// Membership test against the backing map `m`, keyed by RefEquality[BaseIR].
def contains(ir: RefEquality[BaseIR]): Boolean = m.contains(ir)

Expand Down
67 changes: 67 additions & 0 deletions hail/src/main/scala/is/hail/utils/TreeTraversal.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package is.hail.utils

import scala.collection.mutable

// "Lightweight" (less-safe) implementations of tree traversal algorithms
// inspired by those in Guava
object TreeTraversal {

  /** Lazily yields the nodes reachable from `root`, each node after all of its children.
    *
    * @param adj  produces the children of a node; called once per node, for `root`
    *             as soon as the iterator is created
    * @param root node the traversal starts from (and the last node yielded)
    */
  def postOrder[A](adj: A => Iterator[A])(root: A): Iterator[A] =
    new Iterator[A] {
      // Each frame pairs a node with the iterator over its children that have not
      // been descended into yet; the node is emitted once that iterator is drained.
      private var frames: List[(A, Iterator[A])] = (root, adj(root)) :: Nil

      override def hasNext: Boolean = frames.nonEmpty

      override def next(): A = {
        // Walk down through first-unvisited children until reaching a frame with none left.
        var remaining = frames.head._2
        while (remaining.hasNext) {
          val child = remaining.next()
          remaining = adj(child)
          frames = (child, remaining) :: frames
        }

        val finished = frames.head._1
        frames = frames.tail
        finished
      }
    }

  /** Lazily yields the nodes reachable from `root`, each node before any of its children.
    *
    * @param adj  produces the children of a node; called when that node is yielded
    * @param root node the traversal starts from (and the first node yielded)
    */
  def preOrder[A](adj: A => Iterator[A])(root: A): Iterator[A] =
    new Iterator[A] {
      // Stack of sibling iterators that still have elements; the top one is consumed first.
      private var pending: List[Iterator[A]] = Iterator.single(root) :: Nil

      override def hasNext: Boolean = pending.nonEmpty

      override def next(): A = {
        val siblings = pending.head
        val current = siblings.next()
        // Drop drained iterators right away so `hasNext` stays a plain emptiness check.
        if (!siblings.hasNext)
          pending = pending.tail

        val below = adj(current)
        if (below.hasNext)
          pending = below :: pending

        current
      }
    }

  /** Lazily yields the nodes reachable from `root` level by level.
    *
    * @param adj  produces the children of a node; called when that node is yielded
    * @param root node the traversal starts from (and the first node yielded)
    */
  def levelOrder[A](adj: A => Iterator[A])(root: A): Iterator[A] =
    new Iterator[A] {
      // FIFO of sibling iterators that still have elements.
      private val levels = mutable.Queue(Iterator.single(root))

      override def hasNext: Boolean = levels.nonEmpty

      override def next(): A = {
        val siblings = levels.front
        val current = siblings.next()
        // Drop drained iterators right away so `hasNext` stays a plain emptiness check.
        if (!siblings.hasNext)
          levels.dequeue()

        val below = adj(current)
        if (below.hasNext)
          levels.enqueue(below)

        current
      }
    }
}
31 changes: 31 additions & 0 deletions hail/src/test/scala/is/hail/utils/TreeTraversalSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package is.hail.utils

import org.testng.Assert
import org.testng.annotations.Test
/** Checks the visiting order of each `TreeTraversal` algorithm on a seven-node binary tree. */
class TreeTraversalSuite {

  // Children of node `i` are 2i+1 and 2i+2; ids of 7 and above are dropped,
  // which leaves the complete binary tree over 0..6.
  def binaryTree(i: Int): Iterator[Int] =
    Iterator(2 * i + 1, 2 * i + 2).filter(_ < 7)

  @Test def testPostOrder: Unit = {
    val visited = TreeTraversal.postOrder(binaryTree)(0).toArray
    Assert.assertEquals(visited, Array(3, 4, 1, 5, 6, 2, 0), "")
  }

  @Test def testPreOrder: Unit = {
    val visited = TreeTraversal.preOrder(binaryTree)(0).toArray
    Assert.assertEquals(visited, Array(0, 1, 3, 4, 2, 5, 6), "")
  }

  @Test def levelOrder: Unit = {
    val visited = TreeTraversal.levelOrder(binaryTree)(0).toArray
    Assert.assertEquals(visited, (0 to 6).toArray, "")
  }

}

0 comments on commit 1fe6f2a

Please sign in to comment.