diff --git a/src/main/scala/is/hail/expr/AST.scala b/src/main/scala/is/hail/expr/AST.scala index 93719037e705..c101ec41633a 100644 --- a/src/main/scala/is/hail/expr/AST.scala +++ b/src/main/scala/is/hail/expr/AST.scala @@ -668,7 +668,7 @@ case class Apply(posn: Position, fn: String, args: Array[AST]) extends AST(posn, for { irArgs <- anyFailAllFail(args.map(_.toIR(agg))) ir <- tryPrimOpConversion(args.map(_.`type`).zip(irArgs)).orElse( - IRFunctionRegistry.lookupFunction(fn, args.map(_.`type`)) + IRFunctionRegistry.lookupConversion(fn, args.map(_.`type`)) .map { irf => irf(irArgs) }) } yield ir } @@ -781,7 +781,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST case _ => for { irs <- anyFailAllFail((lhs +: args).map(_.toIR(agg))) - ir <- IRFunctionRegistry.lookupFunction(method, (lhs +: args).map(_.`type`)) + ir <- IRFunctionRegistry.lookupConversion(method, (lhs +: args).map(_.`type`)) .map { irf => irf(irs) } } yield ir } diff --git a/src/main/scala/is/hail/expr/ir/Children.scala b/src/main/scala/is/hail/expr/ir/Children.scala index 77aa8c256a43..9cc17183d98c 100644 --- a/src/main/scala/is/hail/expr/ir/Children.scala +++ b/src/main/scala/is/hail/expr/ir/Children.scala @@ -76,7 +76,7 @@ object Children { none case Die(message) => none - case ApplyFunction(impl, args) => + case Apply(_, args, _) => args.toIndexedSeq } } diff --git a/src/main/scala/is/hail/expr/ir/Copy.scala b/src/main/scala/is/hail/expr/ir/Copy.scala index 413d12a06eb1..ee7eac163e38 100644 --- a/src/main/scala/is/hail/expr/ir/Copy.scala +++ b/src/main/scala/is/hail/expr/ir/Copy.scala @@ -104,8 +104,8 @@ object Copy { same case Die(message) => same - case ApplyFunction(impl, args) => - ApplyFunction(impl, children) + case Apply(fn, args, impl) => + Apply(fn, children, impl) } } } diff --git a/src/main/scala/is/hail/expr/ir/Emit.scala b/src/main/scala/is/hail/expr/ir/Emit.scala index 438b8ef06bac..c61c0d47d442 100644 --- a/src/main/scala/is/hail/expr/ir/Emit.scala +++ b/src/main/scala/is/hail/expr/ir/Emit.scala @@ -501,7 +501,7 @@ private class Emit( present(fb.getArg[Boolean](i * 2 + 3)) case Die(m) => present(Code._throw(Code.newInstance[RuntimeException, String](m))) - case ApplyFunction(impl, args) => + case Apply(fn, args, impl) => val meth = methods.getOrElseUpdate(impl, { impl.argTypes.foreach(_.clear()) (impl.argTypes, args.map(a => a.typ)).zipped.foreach(_.unify(_)) diff --git a/src/main/scala/is/hail/expr/ir/IR.scala b/src/main/scala/is/hail/expr/ir/IR.scala index 8cc39ea61025..961ff67ba624 100644 --- a/src/main/scala/is/hail/expr/ir/IR.scala +++ b/src/main/scala/is/hail/expr/ir/IR.scala @@ -84,4 +84,4 @@ final case class InMissingness(i: Int) extends IR { val typ: Type = TBoolean() } // FIXME: should be type any final case class Die(message: String) extends IR { val typ = TVoid } -final case class ApplyFunction(implementation: IRFunction, args: Seq[IR]) extends IR { val typ = implementation.returnType } +final case class Apply(function: String, args: Seq[IR], var implementation: IRFunction = null) extends IR { def typ = implementation.returnType } diff --git a/src/main/scala/is/hail/expr/ir/Infer.scala b/src/main/scala/is/hail/expr/ir/Infer.scala index 4465c087e2aa..811219d414b1 100644 --- a/src/main/scala/is/hail/expr/ir/Infer.scala +++ b/src/main/scala/is/hail/expr/ir/Infer.scala @@ -1,5 +1,6 @@ package is.hail.expr.ir +import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.types._ object Infer { @@ -174,9 +175,11 @@ object Infer { assert(typ != null) case InMissingness(i) => case Die(msg) => - case ApplyFunction(impl, args) => + case x@Apply(fn, args, impl) => args.foreach(infer(_)) - assert(args.map(_.typ).zip(impl.argTypes).forall {case (i, j) => j.unify(i)}) + if (impl == null) + x.implementation = IRFunctionRegistry.lookupFunction(fn, args.map(_.typ)).get + assert(args.map(_.typ).zip(x.implementation.argTypes).forall {case (i, j) => j.unify(i)}) } } diff --git a/src/main/scala/is/hail/expr/ir/Recur.scala b/src/main/scala/is/hail/expr/ir/Recur.scala index d93a9a9bf1b5..41af8a885f15 100644 --- a/src/main/scala/is/hail/expr/ir/Recur.scala +++ b/src/main/scala/is/hail/expr/ir/Recur.scala @@ -41,6 +41,6 @@ object Recur { case In(i, typ) => ir case InMissingness(i) => ir case Die(message) => ir - case ApplyFunction(impl, args) => ApplyFunction(impl, args.map(f)) + case Apply(fn, args, impl) => Apply(fn, args.map(f), impl) } } diff --git a/src/main/scala/is/hail/expr/ir/functions/Functions.scala b/src/main/scala/is/hail/expr/ir/functions/Functions.scala index 0ab43c664a7f..71b39d3e3265 100644 --- a/src/main/scala/is/hail/expr/ir/functions/Functions.scala +++ b/src/main/scala/is/hail/expr/ir/functions/Functions.scala @@ -7,38 +7,60 @@ import is.hail.utils._ import is.hail.asm4s.coerce import scala.collection.mutable -import scala.reflect.ClassTag object IRFunctionRegistry { - val registry: mutable.Map[String, Seq[(Seq[Type], Seq[IR] => IR)]] = mutable.Map().withDefaultValue(Seq.empty) + val irRegistry: mutable.Map[String, Seq[(Seq[Type], Seq[IR] => IR)]] = mutable.Map().withDefaultValue(Seq.empty) + + val codeRegistry: mutable.Map[String, Seq[(Seq[Type], IRFunction)]] = mutable.Map().withDefaultValue(Seq.empty) def addIRFunction(f: IRFunction) { - val l = registry(f.name) - registry.put(f.name, - l :+ (f.argTypes, { args: Seq[IR] => - ApplyFunction(f, args) - })) + val l = codeRegistry(f.name) + codeRegistry.put(f.name, + l :+ (f.argTypes, f)) } def addIR(name: String, types: Seq[Type], f: Seq[IR] => IR) { - val l = registry(name) - registry.put(name, l :+ ((types, f))) + val l = irRegistry(name) + irRegistry.put(name, l :+ ((types, f))) + } + + def lookupFunction(name: String, args: Seq[Type]): Option[IRFunction] = { + val validF = codeRegistry(name).flatMap { case (ts, f) => + if (ts.length == args.length) { + ts.foreach(_.clear()) + if ((ts, args).zipped.forall(_.unify(_))) + Some(f) + else + None + } else + None + } + + validF match { + case Seq() => None + case Seq(x) => Some(x) + case _ => fatal(s"Multiple IRFunctions found that satisfy $name$args.") + } } - def lookupFunction(name: String, args: Seq[Type]): Option[Seq[IR] => IR] = { + def lookupConversion(name: String, args: Seq[Type]): Option[Seq[IR] => IR] = { assert(args.forall(_ != null)) - val validMethods = registry(name).flatMap { case (ts, f) => + val validIR = irRegistry(name).flatMap { case (ts, f) => if (ts.length == args.length) { ts.foreach(_.clear()) if ((ts, args).zipped.forall(_.unify(_))) Some(f) else None - } else { + } else None - } } + + val validMethods = validIR ++ lookupFunction(name, args).map { f => + { args: Seq[IR] => Apply(name, args, f) } + } + validMethods match { case Seq() => None case Seq(x) => Some(x) diff --git a/src/test/scala/is/hail/expr/ir/CompileSuite.scala b/src/test/scala/is/hail/expr/ir/CompileSuite.scala index 8a76c7130d22..79ad29f764c7 100644 --- a/src/test/scala/is/hail/expr/ir/CompileSuite.scala +++ b/src/test/scala/is/hail/expr/ir/CompileSuite.scala @@ -494,7 +494,7 @@ class CompileSuite { val a2t = TArray(TString()) val a1 = In(0, TArray(TInt32())) val a2 = In(1, TArray(TString())) - val min = IRFunctionRegistry.lookupFunction("min", Seq(TArray(TInt32()))).get + val min = IRFunctionRegistry.lookupConversion("min", Seq(TArray(TInt32()))).get val range = ArrayRange(I32(0), min(Seq(MakeArray(Seq(ArrayLen(a1), ArrayLen(a2))))), I32(1)) val ir = ArrayMap(range, "i", MakeTuple(Seq(ArrayRef(a1, Ref("i")), ArrayRef(a2, Ref("i"))))) val region = Region() diff --git a/src/test/scala/is/hail/expr/ir/FunctionSuite.scala b/src/test/scala/is/hail/expr/ir/FunctionSuite.scala index 5e6302ae4119..f597ee23f9f1 100644 --- a/src/test/scala/is/hail/expr/ir/FunctionSuite.scala +++ b/src/test/scala/is/hail/expr/ir/FunctionSuite.scala @@ -58,10 +58,8 @@ class FunctionSuite { fb.result(Some(new PrintWriter(System.out)))() } - def lookup(meth: String, types: Type*)(irs: IR*): IR = { - val possible = IRFunctionRegistry.registry(meth) - IRFunctionRegistry.lookupFunction(meth, types).get(irs) - } + def lookup(meth: String, types: Type*)(irs: IR*): IR = + IRFunctionRegistry.lookupConversion(meth, types).get(irs) @Test def testCodeFunction() { @@ -110,10 +108,10 @@ class FunctionSuite { @Test def testVariableUnification() { - assert(IRFunctionRegistry.lookupFunction("testCodeUnification", Seq(TInt32(), TInt32())).isDefined) - assert(IRFunctionRegistry.lookupFunction("testCodeUnification", Seq(TInt64(), TInt32())).isEmpty) - assert(IRFunctionRegistry.lookupFunction("testCodeUnification", Seq(TInt64(), TInt64())).isEmpty) - assert(IRFunctionRegistry.lookupFunction("testCodeUnification2", Seq(TArray(TInt32()))).isDefined) + assert(IRFunctionRegistry.lookupConversion("testCodeUnification", Seq(TInt32(), TInt32())).isDefined) + assert(IRFunctionRegistry.lookupConversion("testCodeUnification", Seq(TInt64(), TInt32())).isEmpty) + assert(IRFunctionRegistry.lookupConversion("testCodeUnification", Seq(TInt64(), TInt64())).isEmpty) + assert(IRFunctionRegistry.lookupConversion("testCodeUnification2", Seq(TArray(TInt32()))).isDefined) } @Test