From ed4b6d51641b62310517f1e84272388bd4447531 Mon Sep 17 00:00:00 2001 From: wang Date: Mon, 19 Mar 2018 13:15:34 -0400 Subject: [PATCH] add MapEntries, use in annotate_entries --- src/main/scala/is/hail/expr/Relational.scala | 46 ++++++++-- src/main/scala/is/hail/expr/ir/Compile.scala | 2 +- src/main/scala/is/hail/expr/ir/Infer.scala | 2 +- .../scala/is/hail/variant/MatrixTable.scala | 85 +++++++++++-------- 4 files changed, 92 insertions(+), 43 deletions(-) diff --git a/src/main/scala/is/hail/expr/Relational.scala b/src/main/scala/is/hail/expr/Relational.scala index b1abb46b22b2..828d8691f808 100644 --- a/src/main/scala/is/hail/expr/Relational.scala +++ b/src/main/scala/is/hail/expr/Relational.scala @@ -384,17 +384,25 @@ case class MapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR { def children: IndexedSeq[BaseIR] = Array(child, newEntries) def copy(newChildren: IndexedSeq[BaseIR]): MapEntries = { - assert(newChildren.length == 1) + assert(newChildren.length == 2) MapEntries(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[IR]) } val newRow = { - val ArrayLen(GetField(Ref("va"), MatrixType.entriesIdentifier)) + val arrayLength = ArrayLen(GetField(Ref("va"), MatrixType.entriesIdentifier)) + val idxEnv = new Env[IR]() + .bind("g", ArrayRef(GetField(Ref("va"), MatrixType.entriesIdentifier), Ref("i"))) + .bind("sa", ArrayRef(Ref("sa"), Ref("i"))) + val entries = ArrayMap(ArrayRange(I32(0), arrayLength, I32(1)), "i", Subst(newEntries, idxEnv)) + InsertFields(Ref("va"), Seq((MatrixType.entriesIdentifier, entries))) } - InsertFields(Ref("va"), Seq((MatrixType.entriesIdentifier, ArrayMap(GetField(Ref("va"), MatrixType.entriesIdentifier), "g", newEntries)))) val typ: MatrixType = { - Infer(newRow, None, new Env[Type]().bind("va", child.typ.rvRowType)) + Infer(newRow, None, new Env[Type]() + .bind("global", child.typ.globalType) + .bind("va", child.typ.rvRowType) + .bind("sa", TArray(child.typ.colType)) + ) child.typ.copy(rvRowType=newRow.typ) } @@ -408,8 +416,36 @@ case class MapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR { newRow) assert(rTyp == typ.rvRowType) - prev.rvd.mapPartitionsPreservesPartitioning() + val localGlobalsType = typ.globalType + val localColsType = TArray(typ.colType) + val colValuesBc = prev.colValuesBc + val globalsBc = prev.globals.broadcast + + val newRVD = prev.rvd.mapPartitionsPreservesPartitioning(typ.orvdType) { it => + val rvb = new RegionValueBuilder() + val newRV = RegionValue() + val rowF = f() + + it.map { rv => + val region = rv.region + val oldRow = rv.offset + + rvb.set(region) + rvb.start(localGlobalsType) + rvb.addAnnotation(localGlobalsType, globalsBc.value) + val globals = rvb.end() + rvb.start(localColsType) + rvb.addAnnotation(localColsType, colValuesBc.value) + val cols = rvb.end() + + val off = rowF(region, globals, false, oldRow, false, cols, false) + + newRV.set(region, off) + newRV + } + } + prev.copy(typ = typ, rvd = newRVD) } } diff --git a/src/main/scala/is/hail/expr/ir/Compile.scala b/src/main/scala/is/hail/expr/ir/Compile.scala index 29f39a705924..881b3079b240 100644 --- a/src/main/scala/is/hail/expr/ir/Compile.scala +++ b/src/main/scala/is/hail/expr/ir/Compile.scala @@ -55,7 +55,7 @@ object Compile { val env = new Env[IR]() .bind(name0, In(0, typ0)) .bind(name1, In(1, typ1)) - .bind(name1, In(2, typ2)) + .bind(name2, In(2, typ2)) e = Subst(e, env) Infer(e) assert(TypeToIRIntermediateClassTag(e.typ) == classTag[R]) diff --git a/src/main/scala/is/hail/expr/ir/Infer.scala b/src/main/scala/is/hail/expr/ir/Infer.scala index 667af75c0f8c..f0cde8016f96 100644 --- a/src/main/scala/is/hail/expr/ir/Infer.scala +++ b/src/main/scala/is/hail/expr/ir/Infer.scala @@ -151,7 +151,7 @@ object Infer { case x@GetField(o, name, _) => infer(o) val t = coerce[TStruct](o.typ) - assert(t.index(name).nonEmpty) + assert(t.index(name).nonEmpty, s"$name not in $t") x.typ = -t.field(name).typ case GetFieldMissingness(o, name) => infer(o) diff --git a/src/main/scala/is/hail/variant/MatrixTable.scala b/src/main/scala/is/hail/variant/MatrixTable.scala index 0dcd2c33e4b5..f930324546ea 100644 --- a/src/main/scala/is/hail/variant/MatrixTable.scala +++ b/src/main/scala/is/hail/variant/MatrixTable.scala @@ -4,6 +4,7 @@ import is.hail.annotations._ import is.hail.check.Gen import is.hail.linalg._ import is.hail.expr._ +import is.hail.expr.ir import is.hail.methods._ import is.hail.rvd._ import is.hail.table.{Table, TableSpec} @@ -1578,49 +1579,61 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { val globalsBc = globals.broadcast - val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, Some(Annotation.ENTRY_HEAD)) + val asts = Parser.parseAnnotationExprsToAST(expr, ec, Some(Annotation.ENTRY_HEAD)) - val inserterBuilder = new ArrayBuilder[Inserter]() - val newEntryType = (paths, types).zipped.foldLeft(entryType) { case (gsig, (ids, signature)) => - val (s, i) = gsig.structInsert(signature, ids) - inserterBuilder += i - s - } - val inserters = inserterBuilder.result() + val irs = asts.flatMap { case (f, a) => a.toIR().map((f, _)) } - val localNSamples = numCols - val fullRowType = rvRowType - val localColValuesBc = colValuesBc - val localEntriesIndex = entriesIndex + if (irs.length == asts.length) { + val newEntries = ir.InsertFields(ir.Ref("g"), irs) - insertEntries(() => { - val fullRow = new UnsafeRow(fullRowType) - val row = fullRow.deleteField(localEntriesIndex) - (fullRow, row) - })(newEntryType, { case ((fullRow, row), rv, rvb) => - fullRow.set(rv) - val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex) + new MatrixTable(hc, MapEntries(ast, newEntries)) + } else { + info("No IR conversion found for annotate_entries. Falling back to AST.") - rvb.startArray(localNSamples) + val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, Some(Annotation.ENTRY_HEAD)) - var i = 0 - while (i < localNSamples) { - val entry = entries(i) - ec.setAll(row, - localColValuesBc.value(i), - entry, - globalsBc.value) + val inserterBuilder = new ArrayBuilder[Inserter]() + val newEntryType = (paths, types).zipped.foldLeft(entryType) { case (gsig, (ids, signature)) => + val (s, i) = gsig.structInsert(signature, ids) + inserterBuilder += i + s + } + val inserters = inserterBuilder.result() - val newEntry = f().zip(inserters) - .foldLeft(entry) { case (ga, (a, inserter)) => - inserter(ga, a) - } - rvb.addAnnotation(newEntryType, newEntry) + val localNSamples = numCols + val fullRowType = rvRowType + val localColValuesBc = colValuesBc + val localEntriesIndex = entriesIndex - i += 1 - } - rvb.endArray() - }) + insertEntries(() => { + val fullRow = new UnsafeRow(fullRowType) + val row = fullRow.deleteField(localEntriesIndex) + (fullRow, row) + })(newEntryType, { case ((fullRow, row), rv, rvb) => + fullRow.set(rv) + val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex) + + rvb.startArray(localNSamples) + + var i = 0 + while (i < localNSamples) { + val entry = entries(i) + ec.setAll(row, + localColValuesBc.value(i), + entry, + globalsBc.value) + + val newEntry = f().zip(inserters) + .foldLeft(entry) { case (ga, (a, inserter)) => + inserter(ga, a) + } + rvb.addAnnotation(newEntryType, newEntry) + + i += 1 + } + rvb.endArray() + }) + } } def filterCols(p: (Annotation, Int) => Boolean): MatrixTable = {