Skip to content

Commit

Permalink
IRFunctionRegistry changes (#3202)
Browse files Browse the repository at this point in the history
* separate IR conversions from IRFunction conversions in IRFunctionRegistry

* change ApplyFunction(impl, args) => Apply(fn, args, impl)

* cleanup
  • Loading branch information
Amanda Wang authored and tpoterba committed Mar 21, 2018
1 parent c1cc9ed commit 127a3de
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 32 deletions.
4 changes: 2 additions & 2 deletions src/main/scala/is/hail/expr/AST.scala
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ case class Apply(posn: Position, fn: String, args: Array[AST]) extends AST(posn,
for {
irArgs <- anyFailAllFail(args.map(_.toIR(agg)))
ir <- tryPrimOpConversion(args.map(_.`type`).zip(irArgs)).orElse(
IRFunctionRegistry.lookupFunction(fn, args.map(_.`type`))
IRFunctionRegistry.lookupConversion(fn, args.map(_.`type`))
.map { irf => irf(irArgs) })
} yield ir
}
Expand Down Expand Up @@ -781,7 +781,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
case _ =>
for {
irs <- anyFailAllFail((lhs +: args).map(_.toIR(agg)))
ir <- IRFunctionRegistry.lookupFunction(method, (lhs +: args).map(_.`type`))
ir <- IRFunctionRegistry.lookupConversion(method, (lhs +: args).map(_.`type`))
.map { irf => irf(irs) }
} yield ir
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/is/hail/expr/ir/Children.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ object Children {
none
case Die(message) =>
none
case ApplyFunction(impl, args) =>
case Apply(_, args, _) =>
args.toIndexedSeq
}
}
4 changes: 2 additions & 2 deletions src/main/scala/is/hail/expr/ir/Copy.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ object Copy {
same
case Die(message) =>
same
case ApplyFunction(impl, args) =>
ApplyFunction(impl, children)
case Apply(fn, args, impl) =>
Apply(fn, children, impl)
}
}
}
2 changes: 1 addition & 1 deletion src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ private class Emit(
present(fb.getArg[Boolean](i * 2 + 3))
case Die(m) =>
present(Code._throw(Code.newInstance[RuntimeException, String](m)))
case ApplyFunction(impl, args) =>
case Apply(fn, args, impl) =>
val meth = methods.getOrElseUpdate(impl, {
impl.argTypes.foreach(_.clear())
(impl.argTypes, args.map(a => a.typ)).zipped.foreach(_.unify(_))
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,4 @@ final case class InMissingness(i: Int) extends IR { val typ: Type = TBoolean() }
// FIXME: should be type any
final case class Die(message: String) extends IR { val typ = TVoid }

final case class ApplyFunction(implementation: IRFunction, args: Seq[IR]) extends IR { val typ = implementation.returnType }
final case class Apply(function: String, args: Seq[IR], var implementation: IRFunction = null) extends IR { def typ = implementation.returnType }
7 changes: 5 additions & 2 deletions src/main/scala/is/hail/expr/ir/Infer.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package is.hail.expr.ir

import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.types._

object Infer {
Expand Down Expand Up @@ -174,9 +175,11 @@ object Infer {
assert(typ != null)
case InMissingness(i) =>
case Die(msg) =>
case ApplyFunction(impl, args) =>
case x@Apply(fn, args, impl) =>
args.foreach(infer(_))
assert(args.map(_.typ).zip(impl.argTypes).forall {case (i, j) => j.unify(i)})
if (impl == null)
x.implementation = IRFunctionRegistry.lookupFunction(fn, args.map(_.typ)).get
assert(args.map(_.typ).zip(x.implementation.argTypes).forall {case (i, j) => j.unify(i)})
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/is/hail/expr/ir/Recur.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ object Recur {
case In(i, typ) => ir
case InMissingness(i) => ir
case Die(message) => ir
case ApplyFunction(impl, args) => ApplyFunction(impl, args.map(f))
case Apply(fn, args, impl) => Apply(fn, args.map(f), impl)
}
}
48 changes: 35 additions & 13 deletions src/main/scala/is/hail/expr/ir/functions/Functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,60 @@ import is.hail.utils._
import is.hail.asm4s.coerce

import scala.collection.mutable
import scala.reflect.ClassTag

object IRFunctionRegistry {

val registry: mutable.Map[String, Seq[(Seq[Type], Seq[IR] => IR)]] = mutable.Map().withDefaultValue(Seq.empty)
val irRegistry: mutable.Map[String, Seq[(Seq[Type], Seq[IR] => IR)]] = mutable.Map().withDefaultValue(Seq.empty)

val codeRegistry: mutable.Map[String, Seq[(Seq[Type], IRFunction)]] = mutable.Map().withDefaultValue(Seq.empty)

def addIRFunction(f: IRFunction) {
val l = registry(f.name)
registry.put(f.name,
l :+ (f.argTypes, { args: Seq[IR] =>
ApplyFunction(f, args)
}))
val l = codeRegistry(f.name)
codeRegistry.put(f.name,
l :+ (f.argTypes, f))
}

def addIR(name: String, types: Seq[Type], f: Seq[IR] => IR) {
val l = registry(name)
registry.put(name, l :+ ((types, f)))
val l = irRegistry(name)
irRegistry.put(name, l :+ ((types, f)))
}

def lookupFunction(name: String, args: Seq[Type]): Option[IRFunction] = {
val validF = codeRegistry(name).flatMap { case (ts, f) =>
if (ts.length == args.length) {
ts.foreach(_.clear())
if ((ts, args).zipped.forall(_.unify(_)))
Some(f)
else
None
} else
None
}

validF match {
case Seq() => None
case Seq(x) => Some(x)
case _ => fatal(s"Multiple IRFunctions found that satisfy $name$args.")
}
}

def lookupFunction(name: String, args: Seq[Type]): Option[Seq[IR] => IR] = {
def lookupConversion(name: String, args: Seq[Type]): Option[Seq[IR] => IR] = {
assert(args.forall(_ != null))
val validMethods = registry(name).flatMap { case (ts, f) =>
val validIR = irRegistry(name).flatMap { case (ts, f) =>
if (ts.length == args.length) {
ts.foreach(_.clear())
if ((ts, args).zipped.forall(_.unify(_)))
Some(f)
else
None
} else {
} else
None
}
}

val validMethods = validIR ++ lookupFunction(name, args).map { f =>
{ args: Seq[IR] => Apply(name, args, f) }
}

validMethods match {
case Seq() => None
case Seq(x) => Some(x)
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/is/hail/expr/ir/CompileSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ class CompileSuite {
val a2t = TArray(TString())
val a1 = In(0, TArray(TInt32()))
val a2 = In(1, TArray(TString()))
val min = IRFunctionRegistry.lookupFunction("min", Seq(TArray(TInt32()))).get
val min = IRFunctionRegistry.lookupConversion("min", Seq(TArray(TInt32()))).get
val range = ArrayRange(I32(0), min(Seq(MakeArray(Seq(ArrayLen(a1), ArrayLen(a2))))), I32(1))
val ir = ArrayMap(range, "i", MakeTuple(Seq(ArrayRef(a1, Ref("i")), ArrayRef(a2, Ref("i")))))
val region = Region()
Expand Down
14 changes: 6 additions & 8 deletions src/test/scala/is/hail/expr/ir/FunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,8 @@ class FunctionSuite {
fb.result(Some(new PrintWriter(System.out)))()
}

def lookup(meth: String, types: Type*)(irs: IR*): IR = {
val possible = IRFunctionRegistry.registry(meth)
IRFunctionRegistry.lookupFunction(meth, types).get(irs)
}
def lookup(meth: String, types: Type*)(irs: IR*): IR =
IRFunctionRegistry.lookupConversion(meth, types).get(irs)

@Test
def testCodeFunction() {
Expand Down Expand Up @@ -110,10 +108,10 @@ class FunctionSuite {

@Test
def testVariableUnification() {
assert(IRFunctionRegistry.lookupFunction("testCodeUnification", Seq(TInt32(), TInt32())).isDefined)
assert(IRFunctionRegistry.lookupFunction("testCodeUnification", Seq(TInt64(), TInt32())).isEmpty)
assert(IRFunctionRegistry.lookupFunction("testCodeUnification", Seq(TInt64(), TInt64())).isEmpty)
assert(IRFunctionRegistry.lookupFunction("testCodeUnification2", Seq(TArray(TInt32()))).isDefined)
assert(IRFunctionRegistry.lookupConversion("testCodeUnification", Seq(TInt32(), TInt32())).isDefined)
assert(IRFunctionRegistry.lookupConversion("testCodeUnification", Seq(TInt64(), TInt32())).isEmpty)
assert(IRFunctionRegistry.lookupConversion("testCodeUnification", Seq(TInt64(), TInt64())).isEmpty)
assert(IRFunctionRegistry.lookupConversion("testCodeUnification2", Seq(TArray(TInt32()))).isDefined)
}

@Test
Expand Down

0 comments on commit 127a3de

Please sign in to comment.