Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IRFunctionRegistry changes #3202

Merged
merged 3 commits into from
Mar 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/main/scala/is/hail/expr/AST.scala
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ case class Apply(posn: Position, fn: String, args: Array[AST]) extends AST(posn,
for {
irArgs <- anyFailAllFail(args.map(_.toIR(agg)))
ir <- tryPrimOpConversion(args.map(_.`type`).zip(irArgs)).orElse(
IRFunctionRegistry.lookupFunction(fn, args.map(_.`type`))
IRFunctionRegistry.lookupConversion(fn, args.map(_.`type`))
.map { irf => irf(irArgs) })
} yield ir
}
Expand Down Expand Up @@ -781,7 +781,7 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
case _ =>
for {
irs <- anyFailAllFail((lhs +: args).map(_.toIR(agg)))
ir <- IRFunctionRegistry.lookupFunction(method, (lhs +: args).map(_.`type`))
ir <- IRFunctionRegistry.lookupConversion(method, (lhs +: args).map(_.`type`))
.map { irf => irf(irs) }
} yield ir
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/is/hail/expr/ir/Children.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ object Children {
none
case Die(message) =>
none
case ApplyFunction(impl, args) =>
case Apply(_, args, _) =>
args.toIndexedSeq
}
}
4 changes: 2 additions & 2 deletions src/main/scala/is/hail/expr/ir/Copy.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ object Copy {
same
case Die(message) =>
same
case ApplyFunction(impl, args) =>
ApplyFunction(impl, children)
case Apply(fn, args, impl) =>
Apply(fn, children, impl)
}
}
}
2 changes: 1 addition & 1 deletion src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ private class Emit(
present(fb.getArg[Boolean](i * 2 + 3))
case Die(m) =>
present(Code._throw(Code.newInstance[RuntimeException, String](m)))
case ApplyFunction(impl, args) =>
case Apply(fn, args, impl) =>
val meth = methods.getOrElseUpdate(impl, {
impl.argTypes.foreach(_.clear())
(impl.argTypes, args.map(a => a.typ)).zipped.foreach(_.unify(_))
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,4 @@ final case class InMissingness(i: Int) extends IR { val typ: Type = TBoolean() }
// FIXME: should be type any
final case class Die(message: String) extends IR { val typ = TVoid }

final case class ApplyFunction(implementation: IRFunction, args: Seq[IR]) extends IR { val typ = implementation.returnType }
final case class Apply(function: String, args: Seq[IR], var implementation: IRFunction = null) extends IR { def typ = implementation.returnType }
7 changes: 5 additions & 2 deletions src/main/scala/is/hail/expr/ir/Infer.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package is.hail.expr.ir

import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.types._

object Infer {
Expand Down Expand Up @@ -174,9 +175,11 @@ object Infer {
assert(typ != null)
case InMissingness(i) =>
case Die(msg) =>
case ApplyFunction(impl, args) =>
case x@Apply(fn, args, impl) =>
args.foreach(infer(_))
assert(args.map(_.typ).zip(impl.argTypes).forall {case (i, j) => j.unify(i)})
if (impl == null)
x.implementation = IRFunctionRegistry.lookupFunction(fn, args.map(_.typ)).get
assert(args.map(_.typ).zip(x.implementation.argTypes).forall {case (i, j) => j.unify(i)})
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/is/hail/expr/ir/Recur.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ object Recur {
case In(i, typ) => ir
case InMissingness(i) => ir
case Die(message) => ir
case ApplyFunction(impl, args) => ApplyFunction(impl, args.map(f))
case Apply(fn, args, impl) => Apply(fn, args.map(f), impl)
}
}
48 changes: 35 additions & 13 deletions src/main/scala/is/hail/expr/ir/functions/Functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,60 @@ import is.hail.utils._
import is.hail.asm4s.coerce

import scala.collection.mutable
import scala.reflect.ClassTag

object IRFunctionRegistry {

val registry: mutable.Map[String, Seq[(Seq[Type], Seq[IR] => IR)]] = mutable.Map().withDefaultValue(Seq.empty)
val irRegistry: mutable.Map[String, Seq[(Seq[Type], Seq[IR] => IR)]] = mutable.Map().withDefaultValue(Seq.empty)

val codeRegistry: mutable.Map[String, Seq[(Seq[Type], IRFunction)]] = mutable.Map().withDefaultValue(Seq.empty)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@catoverdrive Sorry I'm late to the game here. We should use mutable.MultiMap when we have Map[K, Set[V]]. Unless there's a compelling case for Seq here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That also makes addIRFunction and addIR just calls to irRegistry.addBinding and codeRegistry.addBinding

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

womp, I didn't know this was a thing. I'll change.


def addIRFunction(f: IRFunction) {
val l = registry(f.name)
registry.put(f.name,
l :+ (f.argTypes, { args: Seq[IR] =>
ApplyFunction(f, args)
}))
val l = codeRegistry(f.name)
codeRegistry.put(f.name,
l :+ (f.argTypes, f))
}

def addIR(name: String, types: Seq[Type], f: Seq[IR] => IR) {
val l = registry(name)
registry.put(name, l :+ ((types, f)))
val l = irRegistry(name)
irRegistry.put(name, l :+ ((types, f)))
}

def lookupFunction(name: String, args: Seq[Type]): Option[IRFunction] = {
val validF = codeRegistry(name).flatMap { case (ts, f) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use `anyFailAllFail` here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None of the things (except the return) are Options? I don't see how it would work here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh, got it. nevermind.

if (ts.length == args.length) {
ts.foreach(_.clear())
if ((ts, args).zipped.forall(_.unify(_)))
Some(f)
else
None
} else
None
}

validF match {
case Seq() => None
case Seq(x) => Some(x)
case _ => fatal(s"Multiple IRFunctions found that satisfy $name$args.")
}
}

def lookupFunction(name: String, args: Seq[Type]): Option[Seq[IR] => IR] = {
def lookupConversion(name: String, args: Seq[Type]): Option[Seq[IR] => IR] = {
Copy link
Contributor

@danking danking Mar 26, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lookupConversion and lookupFunction are suspiciously similar. I'm not sure I see the compelling reason to keep them separate.

Would the code look simpler if we had one MultiMap[String, IRFun] where IRFun was something like:

/** Proposed common interface for every entry in a single function registry. */
trait IRFun {
  // Presumably the declared argument types used for lookup — confirm against the registry.
  def args: Seq[Type]
  // Presumably the declared return type — confirm against the registry.
  def ret: Type
  /**
   * Produces the code triple for a call to this function.
   *
   * @param emitter turns an IR node into a (Code[Unit], Code[Boolean], Code[_]) triple
   * @param mb      the method builder handed to the implementation
   * @param args    the IR arguments of the call
   * @return a triple of the same shape as `emitter` returns
   */
  def emit(
    emitter: IR => (Code[Unit], Code[Boolean], Code[_]),
    mb: MethodBuilder,
    args: IR*
  ): (Code[Unit], Code[Boolean], Code[_])
}

/**
 * An [[IRFun]] implemented purely in terms of other IR: the call is rewritten
 * to the IR produced by `impl`, and that IR is handed to the emitter.
 */
class IRMacro(
  val args: Seq[Type],
  val ret: Type,
  val impl: Array[IR] => IR
) extends IRFun {
  def emit(
    emitter: IR => (Code[Unit], Code[Boolean], Code[_]),
    mb: MethodBuilder,
    args: IR*
  ): (Code[Unit], Code[Boolean], Code[_]) = {
    // Expand the macro first, then emit the expansion; `mb` is not used here.
    val expanded: IR = impl(args.toArray)
    emitter(expanded)
  }
}

/**
 * An [[IRFun]] whose implementation is a function from a method builder and
 * the emitted argument values to the result code.
 *
 * @param args    declared argument types
 * @param retType declared return type, exposed through `ret`
 * @param impl    builds the result code from the method builder and argument values
 */
class JVMFun(
  val args: Seq[Type],
  retType: Type,
  impl: (MethodBuilder, Array[Code[_]]) => Code[_]
) extends IRFun {
  // IRFun declares `ret`; the constructor parameter is named `retType`, so forward it.
  def ret: Type = retType

  def emit(
    emitter: IR => (Code[Unit], Code[Boolean], Code[_]),
    mb: MethodBuilder,
    args: IR*
  ): (Code[Unit], Code[Boolean], Code[_]) = {
    // Each emitted argument is a triple, so split with unzip3 (unzip only handles pairs).
    val (setups, ms, vs) = args.map(emitter).unzip3
    // A fold needs an explicit zero; with no arguments nothing can be missing.
    // NOTE(review): relies on an implicit Boolean => Code[Boolean] conversion being in scope — confirm.
    val missing = ms.foldLeft[Code[Boolean]](false)(_ || _)
    // `impl` takes an Array, while unzip3 yields a Seq.
    (Code(setups: _*), missing, impl(mb, vs.toArray))
  }
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe let JVMFun also control missingness?

/**
 * Variant of the JVM-implemented function in which `impl` returns its own
 * (setup, missingness, value) triple instead of only a value.
 *
 * @param args    declared argument types
 * @param retType declared return type, exposed through `ret`
 * @param impl    builds the result triple from the method builder and argument values
 */
class MissingAwareJVMFun(
  val args: Seq[Type],
  retType: Type,
  impl: (MethodBuilder, Array[Code[_]]) => (Code[Unit], Code[Boolean], Code[_])
) extends IRFun {
  // IRFun declares `ret`; the constructor parameter is named `retType`, so forward it.
  def ret: Type = retType

  def emit(
    emitter: IR => (Code[Unit], Code[Boolean], Code[_]),
    mb: MethodBuilder,
    args: IR*
  ): (Code[Unit], Code[Boolean], Code[_]) = {
    // Each emitted argument is a triple, so split with unzip3 (unzip only handles pairs).
    val (setups, ms, vs) = args.map(emitter).unzip3
    // `impl` takes an Array, while unzip3 yields a Seq.
    val (setup, m, v) = impl(mb, vs.toArray)
    // `setup` is a single Code[Unit], so append it with :+ rather than ++.
    // (ms :+ m) always has at least one element, so reduce is safe and needs no zero.
    (Code((setups :+ setup): _*), (ms :+ m).reduce(_ || _), v)
  }
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really want to entangle them because of two reasons:

  • Things that can be translated directly (Seq[IR] => IR) are maybe things that we don't want to keep in the function registry once we can directly pass AST nodes from Python. I kind of want all of them to go away.
  • Even if we do want to keep them in the function registry, I want to be able to look up the non-IR conversions post optimization, when we go to compile the code. Once we're no longer translating from AST, we don't need to store the implementation as part of the IR node and can look it up after we're done optimizing the overall IR. I feel like the things that are implemented in terms of existing IR nodes should just go directly into the IR tree and be optimized away with the rest of it. (This is kind of what currently happens; Infer on Apply nodes does a lookup on the IRFunction registry to be able to properly infer the return type.)

assert(args.forall(_ != null))
val validMethods = registry(name).flatMap { case (ts, f) =>
val validIR = irRegistry(name).flatMap { case (ts, f) =>
if (ts.length == args.length) {
ts.foreach(_.clear())
if ((ts, args).zipped.forall(_.unify(_)))
Some(f)
else
None
} else {
} else
None
}
}

val validMethods = validIR ++ lookupFunction(name, args).map { f =>
{ args: Seq[IR] => Apply(name, args, f) }
}

validMethods match {
case Seq() => None
case Seq(x) => Some(x)
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/is/hail/expr/ir/CompileSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ class CompileSuite {
val a2t = TArray(TString())
val a1 = In(0, TArray(TInt32()))
val a2 = In(1, TArray(TString()))
val min = IRFunctionRegistry.lookupFunction("min", Seq(TArray(TInt32()))).get
val min = IRFunctionRegistry.lookupConversion("min", Seq(TArray(TInt32()))).get
val range = ArrayRange(I32(0), min(Seq(MakeArray(Seq(ArrayLen(a1), ArrayLen(a2))))), I32(1))
val ir = ArrayMap(range, "i", MakeTuple(Seq(ArrayRef(a1, Ref("i")), ArrayRef(a2, Ref("i")))))
val region = Region()
Expand Down
14 changes: 6 additions & 8 deletions src/test/scala/is/hail/expr/ir/FunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,8 @@ class FunctionSuite {
fb.result(Some(new PrintWriter(System.out)))()
}

def lookup(meth: String, types: Type*)(irs: IR*): IR = {
val possible = IRFunctionRegistry.registry(meth)
IRFunctionRegistry.lookupFunction(meth, types).get(irs)
}
def lookup(meth: String, types: Type*)(irs: IR*): IR =
IRFunctionRegistry.lookupConversion(meth, types).get(irs)

@Test
def testCodeFunction() {
Expand Down Expand Up @@ -110,10 +108,10 @@ class FunctionSuite {

@Test
def testVariableUnification() {
assert(IRFunctionRegistry.lookupFunction("testCodeUnification", Seq(TInt32(), TInt32())).isDefined)
assert(IRFunctionRegistry.lookupFunction("testCodeUnification", Seq(TInt64(), TInt32())).isEmpty)
assert(IRFunctionRegistry.lookupFunction("testCodeUnification", Seq(TInt64(), TInt64())).isEmpty)
assert(IRFunctionRegistry.lookupFunction("testCodeUnification2", Seq(TArray(TInt32()))).isDefined)
assert(IRFunctionRegistry.lookupConversion("testCodeUnification", Seq(TInt32(), TInt32())).isDefined)
assert(IRFunctionRegistry.lookupConversion("testCodeUnification", Seq(TInt64(), TInt32())).isEmpty)
assert(IRFunctionRegistry.lookupConversion("testCodeUnification", Seq(TInt64(), TInt64())).isEmpty)
assert(IRFunctionRegistry.lookupConversion("testCodeUnification2", Seq(TArray(TInt32()))).isDefined)
}

@Test
Expand Down