Skip to content

Commit

Permalink
First draft of partially aggregated and code generated count distinct / max
Browse files Browse the repository at this point in the history
  • Loading branch information
marmbrus committed Aug 18, 2014
1 parent 73ab7f1 commit 213ada8
Show file tree
Hide file tree
Showing 9 changed files with 428 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))

protected val exprArray = expressions.toArray
// null check is required for when Kryo invokes the no-arg constructor.
protected val exprArray = if (expressions != null) expressions.toArray else null

def apply(input: Row): Row = {
val outputArray = new Array[Any](exprArray.length)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.util.collection.OpenHashSet

abstract class AggregateExpression extends Expression {
self: Product =>
Expand Down Expand Up @@ -161,13 +162,96 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
override def newInstance() = new CountFunction(child, this)
}

/**
 * Counts the number of distinct value combinations produced by `expressions`.
 *
 * Supports partial aggregation: each partition gathers the distinct values it
 * sees into a hash set ([[CollectHashSet]]), and the merge phase unions those
 * sets and reports the union's size ([[CombineSetsAndCount]]).
 */
case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
  // No-arg constructor, required for serialization (e.g. when Kryo
  // instantiates the class reflectively).
  def this() = this(null)

  override def children = expressions
  override def references = expressions.flatMap(_.references).toSet
  override def nullable = false
  override def dataType = LongType
  override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
  override def newInstance() = new CountDistinctFunction(expressions, this)

  override def asPartial = {
    val collectedSets = Alias(CollectHashSet(expressions), "partialSets")()
    SplitEvaluation(
      CombineSetsAndCount(collectedSets.toAttribute),
      Seq(collectedSets))
  }
}

/**
 * Collects the distinct value combinations of `expressions` into a hash set.
 * Serves as the per-partition (partial) step of [[CountDistinct]].
 */
case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
  // No-arg constructor, required for serialization (e.g. Kryo).
  def this() = this(null)

  override def children = expressions
  override def references = children.flatMap(_.references).toSet
  override def nullable = false
  // NOTE(review): assumes `expressions` is non-empty; `head` throws otherwise.
  override def dataType = ArrayType(expressions.head.dataType)
  override def toString = s"AddToHashSet(${expressions.mkString(",")})"
  override def newInstance() = new CollectHashSetFunction(expressions, this)
}

/**
 * Accumulates the distinct, fully non-null value combinations seen in a
 * partition into an [[OpenHashSet]].  `eval` returns the raw set so the merge
 * side ([[CombineSetsAndCountFunction]]) can union the partial results.
 */
case class CollectHashSetFunction(
    @transient expr: Seq[Expression],
    @transient base: AggregateExpression)
  extends MergableAggregateFunction {

  def this() = this(null, null) // Required for serialization.

  val seen = new OpenHashSet[Any]()

  @transient
  val distinctValue = new InterpretedProjection(expr)

  override def merge(other: MergableAggregateFunction): MergableAggregateFunction = {
    // BUG FIX: the partial results merged here are CollectHashSetFunctions,
    // not CountDistinctFunctions; the previous cast to CountDistinctFunction
    // would fail with a ClassCastException at runtime.
    val otherSetIterator = other.asInstanceOf[CollectHashSetFunction].seen.iterator
    while (otherSetIterator.hasNext) {
      seen.add(otherSetIterator.next())
    }
    this
  }

  override def update(input: Row): Unit = {
    // Project the row down to the distinct-counted expressions; rows with any
    // null component are skipped, matching SQL COUNT(DISTINCT) semantics.
    val evaluatedExpr = distinctValue(input)
    if (!evaluatedExpr.anyNull) {
      seen.add(evaluatedExpr)
    }
  }

  override def eval(input: Row): Any = seen
}

/**
 * Final phase of [[CountDistinct]]: merges the partial hash sets produced by
 * [[CollectHashSet]] and returns the size of the merged set as a Long.
 */
case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression {
  // No-arg constructor, required for serialization (e.g. Kryo).
  def this() = this(null)

  override def children = Seq(inputSet)
  override def references = inputSet.references
  override def nullable = false
  override def dataType = LongType
  override def toString = s"CombineAndCount($inputSet)"
  override def newInstance() = new CombineSetsAndCountFunction(inputSet, this)
}

/**
 * Unions the [[OpenHashSet]]s produced by [[CollectHashSetFunction]] and, on
 * `eval`, reports the cardinality of the union as a Long.
 */
case class CombineSetsAndCountFunction(
    @transient inputSet: Expression,
    @transient base: AggregateExpression)
  extends AggregateFunction {

  def this() = this(null, null) // Required for serialization.

  val seen = new OpenHashSet[Any]()

  override def update(input: Row): Unit = {
    // Each input row carries one partial set; fold all of its elements in.
    val partialSet = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
    val elements = partialSet.iterator
    while (elements.hasNext) {
      seen.add(elements.next())
    }
  }

  override def eval(input: Row): Any = seen.size.toLong
}

case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
Expand Down Expand Up @@ -379,17 +463,22 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
}

case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpression)
extends AggregateFunction {
case class CountDistinctFunction(
@transient expr: Seq[Expression],
@transient base: AggregateExpression)
extends MergableAggregateFunction {

def this() = this(null, null) // Required for serialization.

val seen = new scala.collection.mutable.HashSet[Any]()
val seen = new OpenHashSet[Any]()

@transient
val distinctValue = new InterpretedProjection(expr)

override def update(input: Row): Unit = {
val evaluatedExpr = expr.map(_.eval(input))
if (evaluatedExpr.map(_ != null).reduceLeft(_ && _)) {
seen += evaluatedExpr
val evaluatedExpr = distinctValue(input)
if (!evaluatedExpr.anyNull) {
seen.add(evaluatedExpr)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,17 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet

override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _))
}

/**
 * Returns the greater of `left` and `right`.  If exactly one side evaluates
 * to null, the other side is returned; the result is null only when both are.
 *
 * NOTE(review): assumes both children share the same (numeric) data type;
 * `dataType` is taken from the left child without checking the right.
 */
case class MaxOf(left: Expression, right: Expression) extends Expression {
  type EvaluatedType = Any

  // Null only when BOTH inputs can be null, because a single null input is
  // replaced by the other side's value (matching the code-generated version).
  override def nullable = left.nullable && right.nullable

  override def children = left :: right :: Nil

  override def references = (left.flatMap(_.references) ++ right.flatMap(_.references)).toSet

  override def dataType = left.dataType

  // Interpreted fallback.  Previously unimplemented (???), so any evaluation
  // of MaxOf outside the code-generated path threw NotImplementedError.
  // Null handling mirrors the codegen rule for MaxOf: a null operand yields
  // the other operand unchanged.
  override def eval(input: Row): Any = {
    val evalLeft = left.eval(input)
    val evalRight = right.eval(input)
    if (evalLeft == null) {
      evalRight
    } else if (evalRight == null) {
      evalLeft
    } else {
      val numeric = dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
      if (numeric.gteq(evalLeft, evalRight)) evalLeft else evalRight
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._

// Concrete, monomorphic subclasses of OpenHashSet.  Generated code refers to
// these by name (via hashSetForType) so it can instantiate a specialized set
// for primitive Int/Long elements without spelling out the generic type.
class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int]
class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]

/**
* A base class for generators of byte code to perform expression evaluation. Includes a set of
* helpers for referring to Catalyst types and building trees that perform evaluation of individual
Expand Down Expand Up @@ -71,7 +74,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
* From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most
* fundamental difference is that a ConcurrentMap persists all elements that are added to it until
* they are explicitly removed. A Cache on the other hand is generally configured to evict entries
* automatically, in order to constrain its memory footprint
* automatically, in order to constrain its memory footprint. Note that this cache does not use
* weak keys/values and thus does not respond to memory pressure.
*/
protected val cache = CacheBuilder.newBuilder()
.maximumSize(1000)
Expand Down Expand Up @@ -398,6 +402,75 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
$primitiveTerm = ${falseEval.primitiveTerm}
}
""".children

case NewSet(elementType) =>
q"""
val $nullTerm = false
val $primitiveTerm = new ${hashSetForType(elementType)}()
""".children

case AddItemToSet(item, set) =>
val itemEval = expressionEvaluator(item)
val setEval = expressionEvaluator(set)

val ArrayType(elementType, _) = set.dataType

itemEval.code ++ setEval.code ++
q"""
if (!${itemEval.nullTerm}) {
${setEval.primitiveTerm}
.asInstanceOf[${hashSetForType(elementType)}]
.add(${itemEval.primitiveTerm})
}

val $nullTerm = false
val $primitiveTerm = ${setEval.primitiveTerm}
""".children

case CombineSets(left, right) =>
val leftEval = expressionEvaluator(left)
val rightEval = expressionEvaluator(right)

val ArrayType(elementType, _) = left.dataType

leftEval.code ++ rightEval.code ++
q"""
val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
val iterator = rightSet.iterator
while (iterator.hasNext) {
leftSet.add(iterator.next())
}

val $nullTerm = false
val $primitiveTerm = leftSet
""".children

case MaxOf(e1, e2) =>
val eval1 = expressionEvaluator(e1)
val eval2 = expressionEvaluator(e2)

eval1.code ++ eval2.code ++
q"""
var $nullTerm = false
var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)}

if (${eval1.nullTerm}) {
$nullTerm = ${eval2.nullTerm}
$primitiveTerm = ${eval2.primitiveTerm}
} else if (${eval2.nullTerm}) {
$nullTerm = ${eval1.nullTerm}
$primitiveTerm = ${eval1.primitiveTerm}
} else {
$nullTerm = false
if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) {
$primitiveTerm = ${eval1.primitiveTerm}
} else {
$primitiveTerm = ${eval2.primitiveTerm}
}
}
""".children

}

// If there was no match in the partial function above, we fall back on calling the interpreted
Expand Down Expand Up @@ -437,6 +510,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}")
protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}")

// Maps a catalyst data type to the concrete specialized hash-set class that
// generated code instantiates.  Non-exhaustive by design: only IntegerType
// and LongType have specialized set classes; any other type throws a
// MatchError at code-generation time.
protected def hashSetForType(dt: DataType) = dt match {
  case IntegerType => typeOf[IntegerHashSet]
  case LongType => typeOf[LongHashSet]
}

protected def primitiveForType(dt: DataType) = dt match {
case IntegerType => "Int"
case LongType => "Long"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
..${evaluatedExpression.code}
if(${evaluatedExpression.nullTerm})
setNullAt($iLit)
else
else {
nullBits($iLit) = false
$elementName = ${evaluatedExpression.primitiveTerm}
}
}
""".children : Seq[Tree]
}
Expand Down Expand Up @@ -106,9 +108,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
if(value == null) {
setNullAt(i)
} else {
nullBits(i) = false
$elementName = value.asInstanceOf[${termForType(e.dataType)}]
return
}
return
}"""
}
q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }"
Expand Down Expand Up @@ -137,7 +140,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val elementName = newTermName(s"c$i")
// TODO: The string of ifs gets pretty inefficient as the row grows in size.
// TODO: Optional null checks?
q"if(i == $i) { $elementName = value; return }" :: Nil
q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil
case _ => Nil
}

Expand Down
Loading

0 comments on commit 213ada8

Please sign in to comment.