Skip to content

Commit

Permalink
add MapEntries, use in annotate_entries
Browse files Browse the repository at this point in the history
  • Loading branch information
wang committed Mar 19, 2018
1 parent cce0791 commit ed4b6d5
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 43 deletions.
46 changes: 41 additions & 5 deletions src/main/scala/is/hail/expr/Relational.scala
Original file line number Diff line number Diff line change
Expand Up @@ -384,17 +384,25 @@ case class MapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR {
def children: IndexedSeq[BaseIR] = Array(child, newEntries)

def copy(newChildren: IndexedSeq[BaseIR]): MapEntries = {
assert(newChildren.length == 1)
assert(newChildren.length == 2)
MapEntries(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[IR])
}

val newRow = {
val ArrayLen(GetField(Ref("va"), MatrixType.entriesIdentifier))
val arrayLength = ArrayLen(GetField(Ref("va"), MatrixType.entriesIdentifier))
val idxEnv = new Env[IR]()
.bind("g", ArrayRef(GetField(Ref("va"), MatrixType.entriesIdentifier), Ref("i")))
.bind("sa", ArrayRef(Ref("sa"), Ref("i")))
val entries = ArrayMap(ArrayRange(I32(0), arrayLength, I32(1)), "i", Subst(newEntries, idxEnv))
InsertFields(Ref("va"), Seq((MatrixType.entriesIdentifier, entries)))
}
InsertFields(Ref("va"), Seq((MatrixType.entriesIdentifier, ArrayMap(GetField(Ref("va"), MatrixType.entriesIdentifier), "g", newEntries))))

val typ: MatrixType = {
Infer(newRow, None, new Env[Type]().bind("va", child.typ.rvRowType))
Infer(newRow, None, new Env[Type]()
.bind("global", child.typ.globalType)
.bind("va", child.typ.rvRowType)
.bind("sa", TArray(child.typ.colType))
)
child.typ.copy(rvRowType=newRow.typ)
}

Expand All @@ -408,8 +416,36 @@ case class MapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR {
newRow)
assert(rTyp == typ.rvRowType)

prev.rvd.mapPartitionsPreservesPartitioning()
val localGlobalsType = typ.globalType
val localColsType = TArray(typ.colType)
val colValuesBc = prev.colValuesBc
val globalsBc = prev.globals.broadcast

val newRVD = prev.rvd.mapPartitionsPreservesPartitioning(typ.orvdType) { it =>
val rvb = new RegionValueBuilder()
val newRV = RegionValue()
val rowF = f()

it.map { rv =>
val region = rv.region
val oldRow = rv.offset

rvb.set(region)
rvb.start(localGlobalsType)
rvb.addAnnotation(localGlobalsType, globalsBc.value)
val globals = rvb.end()

rvb.start(localColsType)
rvb.addAnnotation(localColsType, colValuesBc.value)
val cols = rvb.end()

val off = rowF(region, globals, false, oldRow, false, cols, false)

newRV.set(region, off)
newRV
}
}
prev.copy(typ = typ, rvd = newRVD)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/is/hail/expr/ir/Compile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ object Compile {
val env = new Env[IR]()
.bind(name0, In(0, typ0))
.bind(name1, In(1, typ1))
.bind(name1, In(2, typ2))
.bind(name2, In(2, typ2))
e = Subst(e, env)
Infer(e)
assert(TypeToIRIntermediateClassTag(e.typ) == classTag[R])
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/is/hail/expr/ir/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ object Infer {
case x@GetField(o, name, _) =>
infer(o)
val t = coerce[TStruct](o.typ)
assert(t.index(name).nonEmpty)
assert(t.index(name).nonEmpty, s"$name not in $t")
x.typ = -t.field(name).typ
case GetFieldMissingness(o, name) =>
infer(o)
Expand Down
85 changes: 49 additions & 36 deletions src/main/scala/is/hail/variant/MatrixTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import is.hail.annotations._
import is.hail.check.Gen
import is.hail.linalg._
import is.hail.expr._
import is.hail.expr.ir
import is.hail.methods._
import is.hail.rvd._
import is.hail.table.{Table, TableSpec}
Expand Down Expand Up @@ -1578,49 +1579,61 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {

val globalsBc = globals.broadcast

val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, Some(Annotation.ENTRY_HEAD))
val asts = Parser.parseAnnotationExprsToAST(expr, ec, Some(Annotation.ENTRY_HEAD))

val inserterBuilder = new ArrayBuilder[Inserter]()
val newEntryType = (paths, types).zipped.foldLeft(entryType) { case (gsig, (ids, signature)) =>
val (s, i) = gsig.structInsert(signature, ids)
inserterBuilder += i
s
}
val inserters = inserterBuilder.result()
val irs = asts.flatMap { case (f, a) => a.toIR().map((f, _)) }

val localNSamples = numCols
val fullRowType = rvRowType
val localColValuesBc = colValuesBc
val localEntriesIndex = entriesIndex
if (irs.length == asts.length) {
val newEntries = ir.InsertFields(ir.Ref("g"), irs)

insertEntries(() => {
val fullRow = new UnsafeRow(fullRowType)
val row = fullRow.deleteField(localEntriesIndex)
(fullRow, row)
})(newEntryType, { case ((fullRow, row), rv, rvb) =>
fullRow.set(rv)
val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex)
new MatrixTable(hc, MapEntries(ast, newEntries))
} else {
info("No IR conversion found for annotate_entries. Falling back to AST.")

rvb.startArray(localNSamples)
val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, Some(Annotation.ENTRY_HEAD))

var i = 0
while (i < localNSamples) {
val entry = entries(i)
ec.setAll(row,
localColValuesBc.value(i),
entry,
globalsBc.value)
val inserterBuilder = new ArrayBuilder[Inserter]()
val newEntryType = (paths, types).zipped.foldLeft(entryType) { case (gsig, (ids, signature)) =>
val (s, i) = gsig.structInsert(signature, ids)
inserterBuilder += i
s
}
val inserters = inserterBuilder.result()

val newEntry = f().zip(inserters)
.foldLeft(entry) { case (ga, (a, inserter)) =>
inserter(ga, a)
}
rvb.addAnnotation(newEntryType, newEntry)
val localNSamples = numCols
val fullRowType = rvRowType
val localColValuesBc = colValuesBc
val localEntriesIndex = entriesIndex

i += 1
}
rvb.endArray()
})
insertEntries(() => {
val fullRow = new UnsafeRow(fullRowType)
val row = fullRow.deleteField(localEntriesIndex)
(fullRow, row)
})(newEntryType, { case ((fullRow, row), rv, rvb) =>
fullRow.set(rv)
val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex)

rvb.startArray(localNSamples)

var i = 0
while (i < localNSamples) {
val entry = entries(i)
ec.setAll(row,
localColValuesBc.value(i),
entry,
globalsBc.value)

val newEntry = f().zip(inserters)
.foldLeft(entry) { case (ga, (a, inserter)) =>
inserter(ga, a)
}
rvb.addAnnotation(newEntryType, newEntry)

i += 1
}
rvb.endArray()
})
}
}

def filterCols(p: (Annotation, Int) => Boolean): MatrixTable = {
Expand Down

0 comments on commit ed4b6d5

Please sign in to comment.