Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MapEntries IR; use in annotate. #3176

Merged
merged 6 commits into from
Mar 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions src/main/scala/is/hail/expr/Relational.scala
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,76 @@ case class FilterRows(
}
}

case class MapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR {

/** Child nodes of this IR node: the input matrix IR first, then the per-entry expression. */
def children: IndexedSeq[BaseIR] = Array[BaseIR](child, newEntries)

/**
  * Rebuilds this node from a replacement sequence of children.
  *
  * Exactly two children are expected: element 0 is cast to `MatrixIR` and element 1 to `IR`,
  * mirroring the order produced by `children`.
  */
def copy(newChildren: IndexedSeq[BaseIR]): MapEntries = {
  assert(newChildren.length == 2)
  val matrixChild = newChildren(0).asInstanceOf[MatrixIR]
  val entryExpr = newChildren(1).asInstanceOf[IR]
  MapEntries(matrixChild, entryExpr)
}

// Row-level IR derived from `newEntries`: the entries field of "va" is replaced by an array
// produced by evaluating `newEntries` once for each index "i" of the existing entries array,
// with "g" bound to the i-th entry and "sa" bound to the i-th element of the outer "sa".
val newRow = {
  // A def rather than a val: every use builds its own GetField node, as the two separate
  // constructions in the straight-line form did, so no node instance is shared.
  def entriesField = GetField(Ref("va"), MatrixType.entriesIdentifier)

  val nEntries = ArrayLen(entriesField)
  val perIndexEnv = new Env[IR]()
    .bind("g", ArrayRef(entriesField, Ref("i")))
    .bind("sa", ArrayRef(Ref("sa"), Ref("i")))
  val indices = ArrayRange(I32(0), nEntries, I32(1))
  val mappedEntries = ArrayMap(indices, "i", Subst(newEntries, perIndexEnv))
  InsertFields(Ref("va"), Seq(MatrixType.entriesIdentifier -> mappedEntries))
}

// Matrix type of the result: the child's type with its row type replaced by the type of
// `newRow`. Reads `newRow`, so it relies on `newRow` being initialised before this val.
val typ: MatrixType = {
  val typeEnv = new Env[Type]()
    .bind("global", child.typ.globalType)
    .bind("va", child.typ.rvRowType)
    .bind("sa", TArray(child.typ.colType))
  // The return value is discarded; the call is made so that `newRow.typ` can be read below.
  Infer(newRow, None, typeEnv)
  child.typ.copy(rvRowType = newRow.typ)
}

def execute(hc: HailContext): MatrixValue = {
val prev = child.execute(hc)

val localGlobalsType = typ.globalType
val localColsType = TArray(typ.colType)
val colValuesBc = prev.colValuesBc
val globalsBc = prev.globals.broadcast

val (rTyp, f) = ir.Compile[Long, Long, Long, Long](
"global", localGlobalsType,
"va", prev.typ.rvRowType,
"sa", localColsType,
newRow)
assert(rTyp == typ.rvRowType)

val newRVD = prev.rvd.mapPartitionsPreservesPartitioning(typ.orvdType) { it =>
val rvb = new RegionValueBuilder()
val newRV = RegionValue()
val rowF = f()

it.map { rv =>
val region = rv.region
val oldRow = rv.offset

rvb.set(region)
rvb.start(localGlobalsType)
rvb.addAnnotation(localGlobalsType, globalsBc.value)
val globals = rvb.end()

rvb.start(localColsType)
rvb.addAnnotation(localColsType, colValuesBc.value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefer addUnsafeArray here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and :( this is going to be so slow, the off-heap changes can't come fast enough!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having a bit of trouble reasoning about this. rv.region is usually the same region for every record, right? Is this going to mean that the sample annotations are added to the region once for every record in a partition?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:( yes. But we can't get around that right now since we're not creating the original region.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be untenable, even in the short term. Let's discuss at check-in.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re: addUnsafeArray — I kind of prefer the consistency of addAnnotation. It calls addUnsafeArray in the UnsafeIndexedSeq case anyway.

val cols = rvb.end()

val off = rowF(region, globals, false, oldRow, false, cols, false)

newRV.set(region, off)
newRV
}
}
prev.copy(typ = typ, rvd = newRVD)
}
}

case class TableValue(typ: TableType, globals: BroadcastValue, rvd: RVD) {
def rdd: RDD[Row] = {
val localRowType = typ.rowType
Expand Down
24 changes: 24 additions & 0 deletions src/main/scala/is/hail/expr/ir/Compile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ object Compile {
apply[AsmFunction5[Region, T0, Boolean, T1, Boolean, R], R](Seq((name0, typ0, classTag[T0]), (name1, typ1, classTag[T1])), body)
}

/**
  * Compiles `body` as a function of three named inputs.
  *
  * Each `nameN` occurring in `body` is substituted with `In(N, typN)`, types are inferred on the
  * substituted IR, and code is emitted into a function builder taking a `Region` followed by a
  * (value, Boolean) pair per input.
  * NOTE(review): the Boolean arguments look like per-input missingness flags — confirm.
  *
  * Asserts that the intermediate class tag of every input type matches the corresponding type
  * parameter, and that the inferred result type matches `R`.
  *
  * @return the inferred type of the substituted body, paired with `fb.result()`
  */
def apply[T0: TypeInfo : ClassTag, T1: TypeInfo : ClassTag, T2: TypeInfo : ClassTag, R: TypeInfo : ClassTag](
  name0: String,
  typ0: Type,
  name1: String,
  typ1: Type,
  name2: String,
  typ2: Type,
  body: IR): (Type, () => AsmFunction7[Region, T0, Boolean, T1, Boolean, T2, Boolean, R]) = {
  assert(TypeToIRIntermediateClassTag(typ0) == classTag[T0])
  assert(TypeToIRIntermediateClassTag(typ1) == classTag[T1])
  assert(TypeToIRIntermediateClassTag(typ2) == classTag[T2])
  val fb = FunctionBuilder.functionBuilder[Region, T0, Boolean, T1, Boolean, T2, Boolean, R]

  // Bind each input name to its positional In node, then rewrite the body against that binding.
  val inputEnv = new Env[IR]()
    .bind(name0, In(0, typ0))
    .bind(name1, In(1, typ1))
    .bind(name2, In(2, typ2))
  val boundBody = Subst(body, inputEnv)

  Infer(boundBody)
  assert(TypeToIRIntermediateClassTag(boundBody.typ) == classTag[R])
  Emit(boundBody, fb)
  (boundBody.typ, fb.result())
}

def apply[T0: TypeInfo : ClassTag, T1: TypeInfo : ClassTag, T2: TypeInfo : ClassTag,
T3: TypeInfo : ClassTag, T4: TypeInfo : ClassTag, T5: TypeInfo : ClassTag,
R: TypeInfo : ClassTag](
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/is/hail/expr/ir/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ object Infer {
case x@GetField(o, name, _) =>
infer(o)
val t = coerce[TStruct](o.typ)
assert(t.index(name).nonEmpty)
assert(t.index(name).nonEmpty, s"$name not in $t")
x.typ = -t.field(name).typ
case GetFieldMissingness(o, name) =>
infer(o)
Expand Down
85 changes: 49 additions & 36 deletions src/main/scala/is/hail/variant/MatrixTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import is.hail.annotations._
import is.hail.check.Gen
import is.hail.linalg._
import is.hail.expr._
import is.hail.expr.ir
import is.hail.methods._
import is.hail.rvd._
import is.hail.table.{Table, TableSpec}
Expand Down Expand Up @@ -1578,49 +1579,61 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {

val globalsBc = globals.broadcast

val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, Some(Annotation.ENTRY_HEAD))
val asts = Parser.parseAnnotationExprsToAST(expr, ec, Some(Annotation.ENTRY_HEAD))

val inserterBuilder = new ArrayBuilder[Inserter]()
val newEntryType = (paths, types).zipped.foldLeft(entryType) { case (gsig, (ids, signature)) =>
val (s, i) = gsig.structInsert(signature, ids)
inserterBuilder += i
s
}
val inserters = inserterBuilder.result()
val irs = asts.flatMap { case (f, a) => a.toIR().map((f, _)) }

val localNSamples = numCols
val fullRowType = rvRowType
val localColValuesBc = colValuesBc
val localEntriesIndex = entriesIndex
val colValuesIsSmall = colType.size == 1 && colType.types.head.isOfType(TString())
if (irs.length == asts.length && colValuesIsSmall) {
val newEntries = ir.InsertFields(ir.Ref("g"), irs)

insertEntries(() => {
val fullRow = new UnsafeRow(fullRowType)
val row = fullRow.deleteField(localEntriesIndex)
(fullRow, row)
})(newEntryType, { case ((fullRow, row), rv, rvb) =>
fullRow.set(rv)
val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex)
new MatrixTable(hc, MapEntries(ast, newEntries))
} else {

rvb.startArray(localNSamples)
val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, Some(Annotation.ENTRY_HEAD))

var i = 0
while (i < localNSamples) {
val entry = entries(i)
ec.setAll(row,
localColValuesBc.value(i),
entry,
globalsBc.value)
val inserterBuilder = new ArrayBuilder[Inserter]()
val newEntryType = (paths, types).zipped.foldLeft(entryType) { case (gsig, (ids, signature)) =>
val (s, i) = gsig.structInsert(signature, ids)
inserterBuilder += i
s
}
val inserters = inserterBuilder.result()

val newEntry = f().zip(inserters)
.foldLeft(entry) { case (ga, (a, inserter)) =>
inserter(ga, a)
}
rvb.addAnnotation(newEntryType, newEntry)
val localNSamples = numCols
val fullRowType = rvRowType
val localColValuesBc = colValuesBc
val localEntriesIndex = entriesIndex

i += 1
}
rvb.endArray()
})
insertEntries(() => {
val fullRow = new UnsafeRow(fullRowType)
val row = fullRow.deleteField(localEntriesIndex)
(fullRow, row)
})(newEntryType, { case ((fullRow, row), rv, rvb) =>
fullRow.set(rv)
val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex)

rvb.startArray(localNSamples)

var i = 0
while (i < localNSamples) {
val entry = entries(i)
ec.setAll(row,
localColValuesBc.value(i),
entry,
globalsBc.value)

val newEntry = f().zip(inserters)
.foldLeft(entry) { case (ga, (a, inserter)) =>
inserter(ga, a)
}
rvb.addAnnotation(newEntryType, newEntry)

i += 1
}
rvb.endArray()
})
}
}

def filterCols(p: (Annotation, Int) => Boolean): MatrixTable = {
Expand Down
4 changes: 2 additions & 2 deletions src/test/scala/is/hail/io/ExportVCFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,13 @@ class ExportVCFSuite extends SparkSuite {

TestUtils.interceptFatal("Invalid type for format field 'BOOL'. Found 'bool'.") {
ExportVCF(vds
.annotateEntriesExpr("g = {BOOL: true}"),
.annotateEntriesExpr("g.BOOL = true"),
out)
}

TestUtils.interceptFatal("Invalid type for format field 'AA'.") {
ExportVCF(vds
.annotateEntriesExpr("g = {AA: [[0]]}"),
.annotateEntriesExpr("g.AA = [[0]]"),
out)
}
}
Expand Down