From e11fde4399b56c29cf6a4a4c54d2fd4c82e1e047 Mon Sep 17 00:00:00 2001
From: Amanda Wang
Date: Tue, 20 Mar 2018 17:37:32 -0400
Subject: [PATCH] MapEntries IR; use in annotate. (#3176)

* wip

* add MapEntries, use in annotate_entries

* fix test

* fix

* fix and disable

* use IR for smol col annotations
---
 src/main/scala/is/hail/expr/Relational.scala  | 70 +++++++++++++++
 src/main/scala/is/hail/expr/ir/Compile.scala  | 24 ++++++
 src/main/scala/is/hail/expr/ir/Infer.scala    |  2 +-
 .../scala/is/hail/variant/MatrixTable.scala   | 85 +++++++++++--------
 .../scala/is/hail/io/ExportVCFSuite.scala     |  4 +-
 5 files changed, 146 insertions(+), 39 deletions(-)

diff --git a/src/main/scala/is/hail/expr/Relational.scala b/src/main/scala/is/hail/expr/Relational.scala
index 9bc86a4fdd21..a77847f01aa3 100644
--- a/src/main/scala/is/hail/expr/Relational.scala
+++ b/src/main/scala/is/hail/expr/Relational.scala
@@ -379,6 +379,76 @@ case class FilterRows(
   }
 }
 
+case class MapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR {
+
+  def children: IndexedSeq[BaseIR] = Array(child, newEntries)
+
+  def copy(newChildren: IndexedSeq[BaseIR]): MapEntries = {
+    assert(newChildren.length == 2)
+    MapEntries(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[IR])
+  }
+
+  val newRow = {
+    val arrayLength = ArrayLen(GetField(Ref("va"), MatrixType.entriesIdentifier))
+    val idxEnv = new Env[IR]()
+      .bind("g", ArrayRef(GetField(Ref("va"), MatrixType.entriesIdentifier), Ref("i")))
+      .bind("sa", ArrayRef(Ref("sa"), Ref("i")))
+    val entries = ArrayMap(ArrayRange(I32(0), arrayLength, I32(1)), "i", Subst(newEntries, idxEnv))
+    InsertFields(Ref("va"), Seq((MatrixType.entriesIdentifier, entries)))
+  }
+
+  val typ: MatrixType = {
+    Infer(newRow, None, new Env[Type]()
+      .bind("global", child.typ.globalType)
+      .bind("va", child.typ.rvRowType)
+      .bind("sa", TArray(child.typ.colType))
+    )
+    child.typ.copy(rvRowType = newRow.typ)
+  }
+
+  def execute(hc: HailContext): MatrixValue = {
+    val prev = child.execute(hc)
+
+    val localGlobalsType = typ.globalType
+    val localColsType = TArray(typ.colType)
+    val colValuesBc = prev.colValuesBc
+    val globalsBc = prev.globals.broadcast
+
+    val (rTyp, f) = ir.Compile[Long, Long, Long, Long](
+      "global", localGlobalsType,
+      "va", prev.typ.rvRowType,
+      "sa", localColsType,
+      newRow)
+    assert(rTyp == typ.rvRowType)
+
+    val newRVD = prev.rvd.mapPartitionsPreservesPartitioning(typ.orvdType) { it =>
+      val rvb = new RegionValueBuilder()
+      val newRV = RegionValue()
+      val rowF = f()
+
+      it.map { rv =>
+        val region = rv.region
+        val oldRow = rv.offset
+
+        rvb.set(region)
+        rvb.start(localGlobalsType)
+        rvb.addAnnotation(localGlobalsType, globalsBc.value)
+        val globals = rvb.end()
+
+        rvb.start(localColsType)
+        rvb.addAnnotation(localColsType, colValuesBc.value)
+        val cols = rvb.end()
+
+        val off = rowF(region, globals, false, oldRow, false, cols, false)
+
+        newRV.set(region, off)
+        newRV
+      }
+    }
+    prev.copy(typ = typ, rvd = newRVD)
+  }
+}
+
 case class TableValue(typ: TableType, globals: BroadcastValue, rvd: RVD) {
   def rdd: RDD[Row] = {
     val localRowType = typ.rowType
diff --git a/src/main/scala/is/hail/expr/ir/Compile.scala b/src/main/scala/is/hail/expr/ir/Compile.scala
index 673a7429758e..0fd1405166a9 100644
--- a/src/main/scala/is/hail/expr/ir/Compile.scala
+++ b/src/main/scala/is/hail/expr/ir/Compile.scala
@@ -45,6 +45,30 @@ object Compile {
     apply[AsmFunction5[Region, T0, Boolean, T1, Boolean, R], R](Seq((name0, typ0, classTag[T0]), (name1, typ1, classTag[T1])), body)
   }
 
+  def apply[T0: TypeInfo : ClassTag, T1: TypeInfo : ClassTag, T2: TypeInfo : ClassTag, R: TypeInfo : ClassTag](
+    name0: String,
+    typ0: Type,
+    name1: String,
+    typ1: Type,
+    name2: String,
+    typ2: Type,
+    body: IR): (Type, () => AsmFunction7[Region, T0, Boolean, T1, Boolean, T2, Boolean, R]) = {
+    assert(TypeToIRIntermediateClassTag(typ0) == classTag[T0])
+    assert(TypeToIRIntermediateClassTag(typ1) == classTag[T1])
+    assert(TypeToIRIntermediateClassTag(typ2) == classTag[T2])
+    val fb = FunctionBuilder.functionBuilder[Region, T0, Boolean, T1, Boolean, T2, Boolean, R]
+    var e = body
+    val env = new Env[IR]()
+      .bind(name0, In(0, typ0))
+      .bind(name1, In(1, typ1))
+      .bind(name2, In(2, typ2))
+    e = Subst(e, env)
+    Infer(e)
+    assert(TypeToIRIntermediateClassTag(e.typ) == classTag[R])
+    Emit(e, fb)
+    (e.typ, fb.result())
+  }
+
   def apply[T0: TypeInfo : ClassTag, T1: TypeInfo : ClassTag, T2: TypeInfo : ClassTag,
     T3: TypeInfo : ClassTag, T4: TypeInfo : ClassTag, T5: TypeInfo : ClassTag,
     R: TypeInfo : ClassTag](
diff --git a/src/main/scala/is/hail/expr/ir/Infer.scala b/src/main/scala/is/hail/expr/ir/Infer.scala
index ed248a1f418a..4465c087e2aa 100644
--- a/src/main/scala/is/hail/expr/ir/Infer.scala
+++ b/src/main/scala/is/hail/expr/ir/Infer.scala
@@ -156,7 +156,7 @@ object Infer {
       case x@GetField(o, name, _) =>
         infer(o)
         val t = coerce[TStruct](o.typ)
-        assert(t.index(name).nonEmpty)
+        assert(t.index(name).nonEmpty, s"$name not in $t")
         x.typ = -t.field(name).typ
       case GetFieldMissingness(o, name) =>
         infer(o)
diff --git a/src/main/scala/is/hail/variant/MatrixTable.scala b/src/main/scala/is/hail/variant/MatrixTable.scala
index 56d894198a2d..479d2d3361bc 100644
--- a/src/main/scala/is/hail/variant/MatrixTable.scala
+++ b/src/main/scala/is/hail/variant/MatrixTable.scala
@@ -4,6 +4,7 @@ import is.hail.annotations._
 import is.hail.check.Gen
 import is.hail.linalg._
 import is.hail.expr._
+import is.hail.expr.ir
 import is.hail.methods._
 import is.hail.rvd._
 import is.hail.table.{Table, TableSpec}
@@ -1578,49 +1579,61 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
 
     val globalsBc = globals.broadcast
 
-    val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, Some(Annotation.ENTRY_HEAD))
+    val asts = Parser.parseAnnotationExprsToAST(expr, ec, Some(Annotation.ENTRY_HEAD))
 
-    val inserterBuilder = new ArrayBuilder[Inserter]()
-    val newEntryType = (paths, types).zipped.foldLeft(entryType) { case (gsig, (ids, signature)) =>
-      val (s, i) = gsig.structInsert(signature, ids)
-      inserterBuilder += i
-      s
-    }
-    val inserters = inserterBuilder.result()
+    val irs = asts.flatMap { case (f, a) => a.toIR().map((f, _)) }
 
-    val localNSamples = numCols
-    val fullRowType = rvRowType
-    val localColValuesBc = colValuesBc
-    val localEntriesIndex = entriesIndex
+    val colValuesIsSmall = colType.size == 1 && colType.types.head.isOfType(TString())
+    if (irs.length == asts.length && colValuesIsSmall) {
+      val newEntries = ir.InsertFields(ir.Ref("g"), irs)
 
-    insertEntries(() => {
-      val fullRow = new UnsafeRow(fullRowType)
-      val row = fullRow.deleteField(localEntriesIndex)
-      (fullRow, row)
-    })(newEntryType, { case ((fullRow, row), rv, rvb) =>
-      fullRow.set(rv)
-      val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex)
+      new MatrixTable(hc, MapEntries(ast, newEntries))
+    } else {
 
-      rvb.startArray(localNSamples)
+      val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, Some(Annotation.ENTRY_HEAD))
 
-      var i = 0
-      while (i < localNSamples) {
-        val entry = entries(i)
-        ec.setAll(row,
-          localColValuesBc.value(i),
-          entry,
-          globalsBc.value)
+      val inserterBuilder = new ArrayBuilder[Inserter]()
+      val newEntryType = (paths, types).zipped.foldLeft(entryType) { case (gsig, (ids, signature)) =>
+        val (s, i) = gsig.structInsert(signature, ids)
+        inserterBuilder += i
+        s
+      }
+      val inserters = inserterBuilder.result()
 
-        val newEntry = f().zip(inserters)
-          .foldLeft(entry) { case (ga, (a, inserter)) =>
-            inserter(ga, a)
-          }
-        rvb.addAnnotation(newEntryType, newEntry)
+      val localNSamples = numCols
+      val fullRowType = rvRowType
+      val localColValuesBc = colValuesBc
+      val localEntriesIndex = entriesIndex
 
-        i += 1
-      }
-      rvb.endArray()
-    })
+      insertEntries(() => {
+        val fullRow = new UnsafeRow(fullRowType)
+        val row = fullRow.deleteField(localEntriesIndex)
+        (fullRow, row)
+      })(newEntryType, { case ((fullRow, row), rv, rvb) =>
+        fullRow.set(rv)
+        val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex)
+
+        rvb.startArray(localNSamples)
+
+        var i = 0
+        while (i < localNSamples) {
+          val entry = entries(i)
+          ec.setAll(row,
+            localColValuesBc.value(i),
+            entry,
+            globalsBc.value)
+
+          val newEntry = f().zip(inserters)
+            .foldLeft(entry) { case (ga, (a, inserter)) =>
+              inserter(ga, a)
+            }
+          rvb.addAnnotation(newEntryType, newEntry)
+
+          i += 1
+        }
+        rvb.endArray()
+      })
+    }
   }
 
   def filterCols(p: (Annotation, Int) => Boolean): MatrixTable = {
diff --git a/src/test/scala/is/hail/io/ExportVCFSuite.scala b/src/test/scala/is/hail/io/ExportVCFSuite.scala
index b1da4c2235df..0cd0ad9bac33 100644
--- a/src/test/scala/is/hail/io/ExportVCFSuite.scala
+++ b/src/test/scala/is/hail/io/ExportVCFSuite.scala
@@ -210,13 +210,13 @@ class ExportVCFSuite extends SparkSuite {
 
     TestUtils.interceptFatal("Invalid type for format field 'BOOL'. Found 'bool'.") {
       ExportVCF(vds
-        .annotateEntriesExpr("g = {BOOL: true}"),
+        .annotateEntriesExpr("g.BOOL = true"),
         out)
     }
 
     TestUtils.interceptFatal("Invalid type for format field 'AA'.") {
       ExportVCF(vds
-        .annotateEntriesExpr("g = {AA: [[0]]}"),
+        .annotateEntriesExpr("g.AA = [[0]]"),
         out)
     }
   }