Skip to content

Commit

Permalink
MapEntries IR; use in annotate. (hail-is#3176)
Browse files Browse the repository at this point in the history
* wip

* add MapEntries, use in annotate_entries

* fix test

* fix

* fix and disable

* use IR for small col annotations
  • Loading branch information
Amanda Wang authored and konradjk committed Jun 12, 2018
1 parent 7b8872b commit e11fde4
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 39 deletions.
70 changes: 70 additions & 0 deletions src/main/scala/is/hail/expr/Relational.scala
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,76 @@ case class FilterRows(
}
}

/**
  * Matrix IR node that replaces the entries array of every row of `child`
  * with the result of evaluating `newEntries` once per array element.
  *
  * Inside `newEntries` the name `g` is substituted with the i-th element of
  * the row's entries field and `sa` with the i-th element of the `sa` array;
  * `global` and `va` are also in scope when the expression is typed and
  * compiled.
  *
  * @param child      the matrix whose entries are rewritten
  * @param newEntries IR expression producing the new value of one entry
  */
case class MapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR {

  def children: IndexedSeq[BaseIR] = Array[BaseIR](child, newEntries)

  /** Rebuilds this node from exactly two children: the matrix, then the entry expression. */
  def copy(newChildren: IndexedSeq[BaseIR]): MapEntries = {
    assert(newChildren.length == 2)
    val replacedChild = newChildren(0).asInstanceOf[MatrixIR]
    val replacedEntries = newChildren(1).asInstanceOf[IR]
    MapEntries(replacedChild, replacedEntries)
  }

  /** Row-level IR: `va` with its entries field replaced by the mapped array. */
  val newRow = {
    // A fresh node per call, so no IR instance is shared between two places in the tree.
    def rowEntries = GetField(Ref("va"), MatrixType.entriesIdentifier)

    // Rewrite the per-entry names in terms of the loop index "i".
    val perIndexEnv = new Env[IR]()
      .bind("g", ArrayRef(rowEntries, Ref("i")))
      .bind("sa", ArrayRef(Ref("sa"), Ref("i")))
    val perEntryBody = Subst(newEntries, perIndexEnv)

    val indices = ArrayRange(I32(0), ArrayLen(rowEntries), I32(1))
    val mappedEntries = ArrayMap(indices, "i", perEntryBody)
    InsertFields(Ref("va"), Seq((MatrixType.entriesIdentifier, mappedEntries)))
  }

  /** The child's matrix type with the row type replaced by the inferred type of `newRow`. */
  val typ: MatrixType = {
    val childType = child.typ
    val typeEnv = new Env[Type]()
      .bind("global", childType.globalType)
      .bind("va", childType.rvRowType)
      .bind("sa", TArray(childType.colType))
    // Infer is run for its effect: newRow.typ is read on the next line.
    Infer(newRow, None, typeEnv)
    childType.copy(rvRowType = newRow.typ)
  }

  /**
    * Executes `child`, compiles `newRow` to a function of (global, va, sa) and
    * applies it to every row of the child's RVD.
    */
  def execute(hc: HailContext): MatrixValue = {
    val prev = child.execute(hc)

    // Plain locals for everything the partition function needs -- presumably so the
    // closure does not capture this IR node (NOTE(review): confirm).
    val localGlobalsType = typ.globalType
    val localColsType = TArray(typ.colType)
    val colValuesBc = prev.colValuesBc
    val globalsBc = prev.globals.broadcast

    val (compiledRowType, makeRowFunction) = ir.Compile[Long, Long, Long, Long](
      "global", localGlobalsType,
      "va", prev.typ.rvRowType,
      "sa", localColsType,
      newRow)
    assert(compiledRowType == typ.rvRowType)

    val mappedRVD = prev.rvd.mapPartitionsPreservesPartitioning(typ.orvdType) { partition =>
      val builder = new RegionValueBuilder()
      val outRV = RegionValue()
      val rowFunction = makeRowFunction()

      partition.map { inRV =>
        val region = inRV.region
        val rowOffset = inRV.offset

        // Write the globals and the column values into this row's region,
        // then hand their offsets to the compiled function.
        builder.set(region)
        builder.start(localGlobalsType)
        builder.addAnnotation(localGlobalsType, globalsBc.value)
        val globalsOffset = builder.end()

        builder.start(localColsType)
        builder.addAnnotation(localColsType, colValuesBc.value)
        val colsOffset = builder.end()

        // The boolean after each argument is passed as false for all three inputs.
        val newRowOffset = rowFunction(region, globalsOffset, false, rowOffset, false, colsOffset, false)

        // The same RegionValue instance is re-pointed and returned for every row.
        outRV.set(region, newRowOffset)
        outRV
      }
    }
    prev.copy(typ = typ, rvd = mappedRVD)
  }
}

case class TableValue(typ: TableType, globals: BroadcastValue, rvd: RVD) {
def rdd: RDD[Row] = {
val localRowType = typ.rowType
Expand Down
24 changes: 24 additions & 0 deletions src/main/scala/is/hail/expr/ir/Compile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ object Compile {
apply[AsmFunction5[Region, T0, Boolean, T1, Boolean, R], R](Seq((name0, typ0, classTag[T0]), (name1, typ1, classTag[T1])), body)
}

/**
  * Compiles `body` into a generated function of three inputs.
  *
  * Each `nameK` occurring in `body` is substituted with input number K
  * (`In(K, typK)`); the substituted IR is then type-inferred and emitted.
  *
  * @param name0 name bound to input 0
  * @param typ0  type of input 0; its intermediate class tag must equal `T0`
  * @param name1 name bound to input 1
  * @param typ1  type of input 1; its intermediate class tag must equal `T1`
  * @param name2 name bound to input 2
  * @param typ2  type of input 2; its intermediate class tag must equal `T2`
  * @param body  the IR to compile
  * @return the inferred type of the substituted IR, paired with the value
  *         returned by the function builder's `result()`
  */
def apply[T0: TypeInfo : ClassTag, T1: TypeInfo : ClassTag, T2: TypeInfo : ClassTag, R: TypeInfo : ClassTag](
  name0: String,
  typ0: Type,
  name1: String,
  typ1: Type,
  name2: String,
  typ2: Type,
  body: IR): (Type, () => AsmFunction7[Region, T0, Boolean, T1, Boolean, T2, Boolean, R]) = {
  // The declared input types must agree with the type parameters chosen by the caller.
  assert(TypeToIRIntermediateClassTag(typ0) == classTag[T0])
  assert(TypeToIRIntermediateClassTag(typ1) == classTag[T1])
  assert(TypeToIRIntermediateClassTag(typ2) == classTag[T2])

  val builder = FunctionBuilder.functionBuilder[Region, T0, Boolean, T1, Boolean, T2, Boolean, R]

  val inputEnv = new Env[IR]()
    .bind(name0, In(0, typ0))
    .bind(name1, In(1, typ1))
    .bind(name2, In(2, typ2))
  val resolved: IR = Subst(body, inputEnv)

  // Inference runs before the result-type check because that check reads resolved.typ.
  Infer(resolved)
  assert(TypeToIRIntermediateClassTag(resolved.typ) == classTag[R])

  Emit(resolved, builder)
  (resolved.typ, builder.result())
}

def apply[T0: TypeInfo : ClassTag, T1: TypeInfo : ClassTag, T2: TypeInfo : ClassTag,
T3: TypeInfo : ClassTag, T4: TypeInfo : ClassTag, T5: TypeInfo : ClassTag,
R: TypeInfo : ClassTag](
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/is/hail/expr/ir/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ object Infer {
case x@GetField(o, name, _) =>
infer(o)
val t = coerce[TStruct](o.typ)
assert(t.index(name).nonEmpty)
assert(t.index(name).nonEmpty, s"$name not in $t")
x.typ = -t.field(name).typ
case GetFieldMissingness(o, name) =>
infer(o)
Expand Down
85 changes: 49 additions & 36 deletions src/main/scala/is/hail/variant/MatrixTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import is.hail.annotations._
import is.hail.check.Gen
import is.hail.linalg._
import is.hail.expr._
import is.hail.expr.ir
import is.hail.methods._
import is.hail.rvd._
import is.hail.table.{Table, TableSpec}
Expand Down Expand Up @@ -1578,49 +1579,61 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {

val globalsBc = globals.broadcast

val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, Some(Annotation.ENTRY_HEAD))
val asts = Parser.parseAnnotationExprsToAST(expr, ec, Some(Annotation.ENTRY_HEAD))

val inserterBuilder = new ArrayBuilder[Inserter]()
val newEntryType = (paths, types).zipped.foldLeft(entryType) { case (gsig, (ids, signature)) =>
val (s, i) = gsig.structInsert(signature, ids)
inserterBuilder += i
s
}
val inserters = inserterBuilder.result()
val irs = asts.flatMap { case (f, a) => a.toIR().map((f, _)) }

val localNSamples = numCols
val fullRowType = rvRowType
val localColValuesBc = colValuesBc
val localEntriesIndex = entriesIndex
val colValuesIsSmall = colType.size == 1 && colType.types.head.isOfType(TString())
if (irs.length == asts.length && colValuesIsSmall) {
val newEntries = ir.InsertFields(ir.Ref("g"), irs)

insertEntries(() => {
val fullRow = new UnsafeRow(fullRowType)
val row = fullRow.deleteField(localEntriesIndex)
(fullRow, row)
})(newEntryType, { case ((fullRow, row), rv, rvb) =>
fullRow.set(rv)
val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex)
new MatrixTable(hc, MapEntries(ast, newEntries))
} else {

rvb.startArray(localNSamples)
val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, Some(Annotation.ENTRY_HEAD))

var i = 0
while (i < localNSamples) {
val entry = entries(i)
ec.setAll(row,
localColValuesBc.value(i),
entry,
globalsBc.value)
val inserterBuilder = new ArrayBuilder[Inserter]()
val newEntryType = (paths, types).zipped.foldLeft(entryType) { case (gsig, (ids, signature)) =>
val (s, i) = gsig.structInsert(signature, ids)
inserterBuilder += i
s
}
val inserters = inserterBuilder.result()

val newEntry = f().zip(inserters)
.foldLeft(entry) { case (ga, (a, inserter)) =>
inserter(ga, a)
}
rvb.addAnnotation(newEntryType, newEntry)
val localNSamples = numCols
val fullRowType = rvRowType
val localColValuesBc = colValuesBc
val localEntriesIndex = entriesIndex

i += 1
}
rvb.endArray()
})
insertEntries(() => {
val fullRow = new UnsafeRow(fullRowType)
val row = fullRow.deleteField(localEntriesIndex)
(fullRow, row)
})(newEntryType, { case ((fullRow, row), rv, rvb) =>
fullRow.set(rv)
val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex)

rvb.startArray(localNSamples)

var i = 0
while (i < localNSamples) {
val entry = entries(i)
ec.setAll(row,
localColValuesBc.value(i),
entry,
globalsBc.value)

val newEntry = f().zip(inserters)
.foldLeft(entry) { case (ga, (a, inserter)) =>
inserter(ga, a)
}
rvb.addAnnotation(newEntryType, newEntry)

i += 1
}
rvb.endArray()
})
}
}

def filterCols(p: (Annotation, Int) => Boolean): MatrixTable = {
Expand Down
4 changes: 2 additions & 2 deletions src/test/scala/is/hail/io/ExportVCFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,13 @@ class ExportVCFSuite extends SparkSuite {

TestUtils.interceptFatal("Invalid type for format field 'BOOL'. Found 'bool'.") {
ExportVCF(vds
.annotateEntriesExpr("g = {BOOL: true}"),
.annotateEntriesExpr("g.BOOL = true"),
out)
}

TestUtils.interceptFatal("Invalid type for format field 'AA'.") {
ExportVCF(vds
.annotateEntriesExpr("g = {AA: [[0]]}"),
.annotateEntriesExpr("g.AA = [[0]]"),
out)
}
}
Expand Down

0 comments on commit e11fde4

Please sign in to comment.