From 2bd08ee4bd85ea65e6777c3cdcb4187e2ff89307 Mon Sep 17 00:00:00 2001 From: wang Date: Mon, 19 Mar 2018 19:35:02 -0400 Subject: [PATCH 1/9] rewrite MatrixTable.selectEntries to take single expr --- .../scala/is/hail/variant/MatrixTable.scala | 87 ++++++++----------- .../scala/is/hail/io/HardCallsSuite.scala | 2 +- .../scala/is/hail/variant/vsm/VSMSuite.scala | 2 +- 3 files changed, 38 insertions(+), 53 deletions(-) diff --git a/src/main/scala/is/hail/variant/MatrixTable.scala b/src/main/scala/is/hail/variant/MatrixTable.scala index 479d2d3361b..57748d19b9f 100644 --- a/src/main/scala/is/hail/variant/MatrixTable.scala +++ b/src/main/scala/is/hail/variant/MatrixTable.scala @@ -1337,60 +1337,45 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { rvd = rvd.mapPartitionsPreservesPartitioning(newMatrixType.orvdType)(mapPartitionsF)) } - def selectEntries(selectExprs: java.util.ArrayList[String]): MatrixTable = selectEntries(selectExprs.asScala.toArray: _*) - - def selectEntries(exprs: String*): MatrixTable = { + def selectEntries(expr: String): MatrixTable = { val ec = entryEC - val globalsBc = globals.broadcast - - val (paths, types, f) = Parser.parseSelectExprs(exprs.toArray, ec) - val topLevelFields = mutable.Set.empty[String] - val finalNames = paths.map { - // assignment - case Left(name) => name - case Right(path) => - assert(path.head == Annotation.ENTRY_HEAD) - path match { - case List(Annotation.ENTRY_HEAD, name) => topLevelFields += name - } - path.last + val entryAST = Parser.parseToAST(expr, ec) + assert(entryAST.`type`.isInstanceOf[TStruct]) + + entryAST.toIR() match { + case Some(ir) => + new MatrixTable(hc, MapEntries(ast, ir)) + case None => + val (t, f) = Parser.parseExpr(expr, ec) + val newEntryType = t.asInstanceOf[TStruct] + val globalsBc = globals.broadcast + val fullRowType = rvRowType + val localEntriesIndex = entriesIndex + val localNCols = numCols + val localColValuesBc = colValuesBc + + insertEntries(() => { + val fullRow = new UnsafeRow(fullRowType) + val row = fullRow.deleteField(localEntriesIndex) + ec.set(0, globalsBc.value) + ec.set(1, row) + fullRow -> row + })(newEntryType, { case ((fullRow, row), rv, rvb) => + fullRow.set(rv) + val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex) + rvb.startArray(localNCols) + var i = 0 + while (i < localNCols) { + val entry = entries(i) + ec.set(2, localColValuesBc.value(i)) + ec.set(3, entry) + rvb.addAnnotation(newEntryType, f()) + i += 1 + } + rvb.endArray() + }) } - assert(finalNames.areDistinct()) - - val newEntryType = TStruct(finalNames.zip(types): _*) - val fullRowType = rvRowType - val localEntriesIndex = entriesIndex - val localNCols = numCols - val localColValuesBc = colValuesBc - - insertEntries(() => { - val fullRow = new UnsafeRow(fullRowType) - val row = fullRow.deleteField(localEntriesIndex) - ec.set(0, globalsBc.value) - ec.set(1, row) - fullRow -> row - })(newEntryType, { case ((fullRow, row), rv, rvb) => - fullRow.set(rv) - val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex) - rvb.startArray(localNCols) - var i = 0 - while (i < localNCols) { - val entry = entries(i) - ec.set(2, localColValuesBc.value(i)) - ec.set(3, entry) - val results = f() - var j = 0 - rvb.startStruct() - while (j < types.length) { - rvb.addAnnotation(types(j), results(j)) - j += 1 - } - rvb.endStruct() - i += 1 - } - rvb.endArray() - }) } diff --git a/src/test/scala/is/hail/io/HardCallsSuite.scala b/src/test/scala/is/hail/io/HardCallsSuite.scala index 824b0afceaa..6a272a6c491 100644 --- a/src/test/scala/is/hail/io/HardCallsSuite.scala +++ b/src/test/scala/is/hail/io/HardCallsSuite.scala @@ -9,7 +9,7 @@ import org.testng.annotations.Test class HardCallsSuite extends SparkSuite { @Test def test() { val p = forAll(MatrixTable.gen(hc, VSMSubgen.random)) { vds => - val hard = vds.selectEntries("g.GT") + val hard = vds.selectEntries("{GT: g.GT}") assert(hard.queryEntries("AGG.map(g => g.GT).counter()") == vds.queryEntries("AGG.map(g => g.GT).counter()")) diff --git a/src/test/scala/is/hail/variant/vsm/VSMSuite.scala b/src/test/scala/is/hail/variant/vsm/VSMSuite.scala index f5fc0f54c0e..a6679d577cb 100644 --- a/src/test/scala/is/hail/variant/vsm/VSMSuite.scala +++ b/src/test/scala/is/hail/variant/vsm/VSMSuite.scala @@ -154,7 +154,7 @@ class VSMSuite extends SparkSuite { .indexRows("rowIdx") .indexCols("colIdx") - mt.selectEntries("x = (g.GT.nNonRefAlleles().toInt64 + va.rowIdx + sa.colIdx.toInt64 + 1L).toFloat64") + mt.selectEntries("{x: (g.GT.nNonRefAlleles().toInt64 + va.rowIdx + sa.colIdx.toInt64 + 1L).toFloat64}") .writeBlockMatrix(dirname, "x", blockSize) val data = mt.entriesTable() From 5bbb3b53e1b8029da9b1f354e8125b7b794d5152 Mon Sep 17 00:00:00 2001 From: wang Date: Mon, 19 Mar 2018 20:02:48 -0400 Subject: [PATCH 2/9] remove annotateEntriesExpr and dropEntries from scala MatrixTable --- .../scala/is/hail/variant/MatrixTable.scala | 114 +----------------- .../scala/is/hail/io/ExportVCFSuite.scala | 15 +-- .../scala/is/hail/utils/RichMatrixTable.scala | 12 ++ 3 files changed, 21 insertions(+), 120 deletions(-) diff --git a/src/main/scala/is/hail/variant/MatrixTable.scala b/src/main/scala/is/hail/variant/MatrixTable.scala index 57748d19b9f..a06c8a8980a 100644 --- a/src/main/scala/is/hail/variant/MatrixTable.scala +++ b/src/main/scala/is/hail/variant/MatrixTable.scala @@ -1338,6 +1338,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { } def selectEntries(expr: String): MatrixTable = { + println(expr) val ec = entryEC val entryAST = Parser.parseToAST(expr, ec) @@ -1378,52 +1379,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { } } - - def dropEntries(fields: java.util.ArrayList[String]): MatrixTable = dropEntries(fields.asScala.toArray: _*) - - def dropEntries(fields: String*): MatrixTable = { - if (fields.isEmpty) - return this - assert(fields.areDistinct()) - val dropSet = fields.toSet - val allEntryFields = entryType.fieldNames.toSet - assert(fields.forall(allEntryFields.contains)) - - val keepIndices = entryType.fields - .filter(f => !dropSet.contains(f.name)) - .map(f => f.index) - .toArray - val newEntryType = TStruct(keepIndices.map(entryType.fields(_)).map(f => f.name -> f.typ): _*) - - val fullRowType = rvRowType - val localEntriesIndex = entriesIndex - val localEntriesType = matrixType.entryArrayType - val localEntryType = entryType - val localNCols = numCols - // FIXME: replace with physical type - insertEntries(noOp)(newEntryType, { case (_, rv, rvb) => - val entriesOffset = fullRowType.loadField(rv, localEntriesIndex) - rvb.startArray(localNCols) - var i = 0 - while (i < localNCols) { - if (localEntriesType.isElementMissing(rv.region, entriesOffset, i)) - rvb.setMissing() - else { - val eltOffset = localEntriesType.loadElement(rv.region, entriesOffset, localNCols, i) - rvb.startStruct() - var j = 0 - while (j < keepIndices.length) { - rvb.addField(localEntryType, rv.region, eltOffset, keepIndices(j)) - j += 1 - } - rvb.endStruct() - } - i += 1 - } - rvb.endArray() - }) - } - def nPartitions: Int = rvd.partitions.length def annotateRowsVDS(right: MatrixTable, root: String): MatrixTable = { @@ -1554,73 +1509,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { new Table(hc, TableLiteral(TableValue(TableType(rvRowType.rename(m), rowKey, globalType), globals, rvd))) } - def annotateEntriesExpr(expr: String): MatrixTable = { - val symTab = Map( - "va" -> (0, rowType), - "sa" -> (1, colType), - "g" -> (2, entryType), - "global" -> (3, globalType)) - val ec = EvalContext(symTab) - - val globalsBc = globals.broadcast - - val asts = Parser.parseAnnotationExprsToAST(expr, ec, Some(Annotation.ENTRY_HEAD)) - - val irs = asts.flatMap { case (f, a) => a.toIR().map((f, _)) } - - val colValuesIsSmall = colType.size == 1 && colType.types.head.isOfType(TString()) - if (irs.length == asts.length && colValuesIsSmall) { - val newEntries = ir.InsertFields(ir.Ref("g"), irs) - - new MatrixTable(hc, MapEntries(ast, newEntries)) - } else { - - val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, Some(Annotation.ENTRY_HEAD)) - - val inserterBuilder = new ArrayBuilder[Inserter]() - val newEntryType = (paths, types).zipped.foldLeft(entryType) { case (gsig, (ids, signature)) => - val (s, i) = gsig.structInsert(signature, ids) - inserterBuilder += i - s - } - val inserters = inserterBuilder.result() - - val localNSamples = numCols - val fullRowType = rvRowType - val localColValuesBc = colValuesBc - val localEntriesIndex = entriesIndex - - insertEntries(() => { - val fullRow = new UnsafeRow(fullRowType) - val row = fullRow.deleteField(localEntriesIndex) - (fullRow, row) - })(newEntryType, { case ((fullRow, row), rv, rvb) => - fullRow.set(rv) - val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex) - - rvb.startArray(localNSamples) - - var i = 0 - while (i < localNSamples) { - val entry = entries(i) - ec.setAll(row, - localColValuesBc.value(i), - entry, - globalsBc.value) - - val newEntry = f().zip(inserters) - .foldLeft(entry) { case (ga, (a, inserter)) => - inserter(ga, a) - } - rvb.addAnnotation(newEntryType, newEntry) - - i += 1 - } - rvb.endArray() - }) - } - } - def filterCols(p: (Annotation, Int) => Boolean): MatrixTable = { copyAST(ast = MatrixLiteral(matrixType, value.filterCols(p))) } diff --git a/src/test/scala/is/hail/io/ExportVCFSuite.scala b/src/test/scala/is/hail/io/ExportVCFSuite.scala index 0cd0ad9bac3..b906afff185 100644 --- a/src/test/scala/is/hail/io/ExportVCFSuite.scala +++ b/src/test/scala/is/hail/io/ExportVCFSuite.scala @@ -9,6 +9,7 @@ import is.hail.io.vcf.ExportVCF import is.hail.utils._ import is.hail.variant.{MatrixTable, VSMSubgen, Variant} import org.testng.annotations.Test +import is.hail.testUtils._ import scala.io.Source import scala.language.postfixOps @@ -210,13 +211,13 @@ class ExportVCFSuite extends SparkSuite { TestUtils.interceptFatal("Invalid type for format field 'BOOL'. Found 'bool'.") { ExportVCF(vds - .annotateEntriesExpr("g.BOOL = true"), + .annotateEntriesExpr(("BOOL","true")), out) } TestUtils.interceptFatal("Invalid type for format field 'AA'.") { ExportVCF(vds - .annotateEntriesExpr("g.AA = [[0]]"), + .annotateEntriesExpr(("AA", "[[0]]")), out) } } @@ -281,15 +282,15 @@ class ExportVCFSuite extends SparkSuite { val callArrayFields = schema.fields.filter(fd => fd.typ == TArray(TCall())).map(_.name) val callSetFields = schema.fields.filter(fd => fd.typ == TSet(TCall())).map(_.name) - val callAnnots = callFields.map(name => s"g.$name = let c = g.$name in " + - s"if (c.ploidy == 0 || (c.ploidy == 1 && c.isPhased())) Call(0, 0, false) else c") + val callAnnots = callFields.map(name => (name, s"let c = g.$name in " + + s"if (c.ploidy == 0 || (c.ploidy == 1 && c.isPhased())) Call(0, 0, false) else c")) - val callContainerAnnots = (callArrayFields ++ callSetFields).map(name => s"g.$name = " + - s"g.$name.map(c => if (c.ploidy == 0 || (c.ploidy == 1 && c.isPhased())) Call(0, 0, false) else c)") + val callContainerAnnots = (callArrayFields ++ callSetFields).map(name => (name, + s"g.$name.map(c => if (c.ploidy == 0 || (c.ploidy == 1 && c.isPhased())) Call(0, 0, false) else c)")) val annots = callAnnots ++ callContainerAnnots - val vsmAnn = if (annots.nonEmpty) vsm.annotateEntriesExpr(annots.mkString(",")) else vsm + val vsmAnn = if (annots.nonEmpty) vsm.annotateEntriesExpr(annots: _*) else vsm hadoopConf.delete(out, recursive = true) ExportVCF(vsmAnn, out) diff --git a/src/test/scala/is/hail/utils/RichMatrixTable.scala b/src/test/scala/is/hail/utils/RichMatrixTable.scala index 7ee3035e448..3ebc34af3c7 100644 --- a/src/test/scala/is/hail/utils/RichMatrixTable.scala +++ b/src/test/scala/is/hail/utils/RichMatrixTable.scala @@ -43,6 +43,18 @@ class RichMatrixTable(vsm: MatrixTable) { vsm.annotateCols(t, i) { case (_, i) => annotation(sampleIds(i)) } } + def annotateEntriesExpr(exprs: (String, String)*): MatrixTable = { + val exprMap = exprs.toMap + val expr = (vsm.matrixType.entryType.fieldNames ++ exprMap.keys.filter(!vsm.matrixType.entryType.fieldNames.contains(_))).map { n => + if (exprMap.keySet.contains(n)) + s"$n: ${ exprMap(n) }" + else + s"$n: g.`$n`" + }.mkString(", ") + + vsm.selectEntries(s"{$expr}") + } + def querySA(code: String): (Type, Querier) = { val st = Map(Annotation.COL_HEAD -> (0, vsm.colType)) val ec = EvalContext(st) From 76fbaeaea690a471dbce45ca6f225c40c74e9408 Mon Sep 17 00:00:00 2001 From: wang Date: Mon, 19 Mar 2018 23:11:28 -0400 Subject: [PATCH 3/9] python wip --- python/hail/matrixtable.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/python/hail/matrixtable.py b/python/hail/matrixtable.py index 19e0b0a7846..d29d200d6b5 100644 --- a/python/hail/matrixtable.py +++ b/python/hail/matrixtable.py @@ -940,16 +940,14 @@ def annotate_entries(self, **named_exprs: NamedExprs) -> 'MatrixTable': :class:`.MatrixTable` Matrix table with new row-and-column-indexed field(s). """ - exprs = [] + + named_exprs = {k: to_expr(v) for k, v in named_exprs.items()} - base, cleanup = self._process_joins(*named_exprs.values()) for k, v in named_exprs.items(): - analyze('MatrixTable.annotate_entries', v, self._entry_indices) - exprs.append('g.{k} = {v}'.format(k=escape_id(k), v=v._ast.to_hql())) check_collisions(self._fields, k, self._entry_indices) - m = MatrixTable(base._jvds.annotateEntriesExpr(",\n".join(exprs))) - return cleanup(m) + + return self._select_entries("MatrixTable.annotate_entries", self.entries.annotate(**named_exprs)) def select_globals(self, *exprs: FieldRefArgs, **named_exprs: NamedExprs) -> 'MatrixTable': """Select existing global fields or create new fields by name, dropping the rest. @@ -1188,14 +1186,10 @@ def select_entries(self, *exprs: FieldRefArgs, **named_exprs: NamedExprs) -> 'Ma """ exprs = [to_expr(e) if not isinstance(e, str) else self[e] for e in exprs] named_exprs = {k: to_expr(v) for k, v in named_exprs.items()} - strs = [] - all_exprs = [] - base, cleanup = self._process_joins(*itertools.chain(exprs, named_exprs.values())) + assignments = OrderedDict() - ids = [] for e in exprs: all_exprs.append(e) - analyze('MatrixTable.select_entries', e, self._entry_indices) if not e._indices == self._entry_indices: # detect row or col fields here raise ExpressionException("method 'select_entries' parameter 'exprs' expects entry-indexed fields," @@ -1205,14 +1199,10 @@ def select_entries(self, *exprs: FieldRefArgs, **named_exprs: NamedExprs) -> 'Ma strs.append(e._ast.to_hql()) ids.append(e._ast.name) for k, e in named_exprs.items(): - all_exprs.append(e) - analyze('MatrixTable.select_entries', e, self._entry_indices) check_collisions(self._fields, k, self._entry_indices) - strs.append('{} = {}'.format(escape_id(k), e._ast.to_hql())) - ids.append(k) - check_field_uniqueness(ids) - m = MatrixTable(base._jvds.selectEntries(strs)) - return cleanup(m) + assignments[k] = e + check_field_uniqueness(assignments.keys()) + return self._select_entries("MatrixTable.select_entries", hl.struct(**assignments)) @typecheck_method(exprs=oneof(str, Expression)) def drop(self, *exprs: FieldRefArgs) -> 'MatrixTable': @@ -2613,6 +2603,12 @@ def add_col_index(self, name: str = 'col_idx') -> 'MatrixTable': def _same(self, other, tolerance=1e-6): return self._jvds.same(other._jvds, tolerance) + @typecheck_method(caller=str, s=expr_struct()) + def _select_entries(self, caller, s): + base, cleanup = self._process_joins(s) + analyze(caller, s, self._row_indices) + return cleanup(MatrixTable(base._jt.select(s._ast.to_hql()))) + @typecheck(datasets=matrix_table_type) def union_rows(*datasets: Tuple['MatrixTable']) -> 'MatrixTable': """Take the union of dataset rows. From 813958e2046892ff148c91ba7430c4c3e72975fb Mon Sep 17 00:00:00 2001 From: wang Date: Tue, 20 Mar 2018 17:51:56 -0400 Subject: [PATCH 4/9] wip --- python/hail/matrixtable.py | 27 +++++++++-------- python/hail/table.py | 9 ++++-- python/hail/tests/test_api.py | 8 +++-- .../scala/is/hail/variant/MatrixTable.scala | 29 +++++++++++++++++-- .../scala/is/hail/expr/TableIRSuite.scala | 9 ++++++ 5 files changed, 61 insertions(+), 21 deletions(-) diff --git a/python/hail/matrixtable.py b/python/hail/matrixtable.py index d29d200d6b5..ef055daae83 100644 --- a/python/hail/matrixtable.py +++ b/python/hail/matrixtable.py @@ -1,5 +1,6 @@ import itertools from typing import * +from collections import OrderedDict import hail import hail as hl @@ -947,7 +948,7 @@ def annotate_entries(self, **named_exprs: NamedExprs) -> 'MatrixTable': for k, v in named_exprs.items(): check_collisions(self._fields, k, self._entry_indices) - return self._select_entries("MatrixTable.annotate_entries", self.entries.annotate(**named_exprs)) + return self._select_entries("MatrixTable.annotate_entries", self.entry.annotate(**named_exprs)) def select_globals(self, *exprs: FieldRefArgs, **named_exprs: NamedExprs) -> 'MatrixTable': """Select existing global fields or create new fields by name, dropping the rest. @@ -1280,10 +1281,11 @@ def drop(self, *exprs: FieldRefArgs) -> 'MatrixTable': new_col_fields = [f for f in m.col if f not in fields_to_drop] m = m.select_cols(*new_col_fields) - entry_fields = [x for x in fields_to_drop if self._fields[x]._indices == self._entry_indices] - if any(self._fields[field]._indices == self._entry_indices for field in fields_to_drop): - # need to drop entry fields - m = MatrixTable(m._jvds.dropEntries(entry_fields)) + + + entry_fields = [field for field in fields_to_drop if self._fields[field]._indices == self._entry_indices] + if entry_fields: + m = m._select_entries("MatrixTable.drop_entries", m.entry.drop(*entry_fields)) return m @@ -2182,6 +2184,7 @@ def index_entries(self, row_exprs: Tuple[Expression], col_exprs: Tuple[Expressio uids.append(row_uid) col_uid = Env.get_uid() uids.append(col_uid) + uids.append(col_uid) def joiner(left: MatrixTable): localized = Table(self._jvds.localizeEntries(row_uid)) @@ -2189,7 +2192,6 @@ def joiner(left: MatrixTable): src_cols_indexed = src_cols_indexed.annotate(**{col_uid: hl.int32(src_cols_indexed[col_uid])}) left = left._annotate_all(row_exprs = {row_uid: localized.index(*row_exprs)[row_uid]}, col_exprs = {col_uid: src_cols_indexed.index(*col_exprs)[col_uid]}) - return left.annotate_entries(**{uid: left[row_uid][left[col_uid]]}) return construct_expr(Select(TopLevelReference('g', self._entry_indices), uid), @@ -2228,12 +2230,9 @@ def _annotate_all(self, check_collisions(self._fields, k, self._col_indices) jmt = jmt.annotateColsExpr(",\n".join(col_strs)) if entry_exprs: - entry_strs = [] - for k, v in entry_exprs.items(): - analyze('MatrixTable.annotate_entries', v, self._entry_indices) - entry_strs.append('g.{k} = {v}'.format(k=escape_id(k), v=v._ast.to_hql())) - check_collisions(self._fields, k, self._entry_indices) - jmt = jmt.annotateEntriesExpr(",\n".join(entry_strs)) + entry_struct = self.entry.annotate(**entry_exprs) + analyze("MatrixTable.annotate_entries", entry_struct, self._entry_indices) + jmt = jmt.selectEntries(entry_struct._ast.to_hql()) if global_exprs: global_strs = [] for k, v in global_exprs.items(): @@ -2606,8 +2605,8 @@ def _same(self, other, tolerance=1e-6): @typecheck_method(caller=str, s=expr_struct()) def _select_entries(self, caller, s): base, cleanup = self._process_joins(s) - analyze(caller, s, self._row_indices) - return cleanup(MatrixTable(base._jt.select(s._ast.to_hql()))) + analyze(caller, s, self._entry_indices) + return cleanup(MatrixTable(base._jvds.selectEntries(s._ast.to_hql()))) @typecheck(datasets=matrix_table_type) def union_rows(*datasets: Tuple['MatrixTable']) -> 'MatrixTable': diff --git a/python/hail/table.py b/python/hail/table.py index 2a6e98c3011..09d9a1b8bc9 100644 --- a/python/hail/table.py +++ b/python/hail/table.py @@ -1192,8 +1192,11 @@ def joiner(left): if is_row_key or is_partition_key: # no vds_key (way faster) def joiner(left): - return MatrixTable(left._jvds.annotateRowsTable( + left.annotate_rows(es=hl.agg.collect(left.x)).rows().show() + rt = MatrixTable(left._jvds.annotateRowsTable( right._jt, uid, False)) + rt.annotate_rows(es=hl.agg.collect(rt.x)).rows().show() + return rt return construct_expr(Select(TopLevelReference('va', src._row_indices), uid), new_schema, indices, aggregations, joins.push(Join(joiner, [uid], uid, exprs))) @@ -1237,9 +1240,9 @@ def joiner(left): if len(exprs) == len(src.col_key) and all([ exprs[i] is src[list(src.col_key)[i]] for i in range(len(exprs))]): # no vds_key (faster) - joiner = lambda left: MatrixTable(left._jvds.annotateColsTable( - right._jt, None, uid, False)) + joiner = lambda left: MatrixTable(left._jvds.annotateColsTable(right._jt, None, uid, False)) else: + # use vds_key # use vds_key joiner = lambda left: MatrixTable(left._jvds.annotateColsTable( right._jt, [e._ast.to_hql() for e in exprs], uid, False)) diff --git a/python/hail/tests/test_api.py b/python/hail/tests/test_api.py index 8c6e5c4b883..1e4d2a28751 100644 --- a/python/hail/tests/test_api.py +++ b/python/hail/tests/test_api.py @@ -799,13 +799,17 @@ def test_computed_key_join_3(self): hl.is_missing(rt['value'])))) def test_entry_join_self(self): - mt1 = hl.utils.range_matrix_table(10, 10, n_partitions=4) - mt1 = mt1.annotate_entries(x = mt1.row_idx + mt1.col_idx) + mt1 = hl.utils.range_matrix_table(6, 6, n_partitions=2) + mt1 = mt1.annotate_entries(x = 100*mt1.row_idx + 10*mt1.col_idx) + mt1.entries().show(20) self.assertEqual(mt1[mt1.row_idx, mt1.col_idx].dtype, mt1.entry.dtype) mt_join = mt1.annotate_entries(x2 = mt1[mt1.row_idx, mt1.col_idx].x) mt_join_entries = mt_join.entries() + + mt_join_entries.show(20) + self.assertTrue(mt_join_entries.all(mt_join_entries.x == mt_join_entries.x2)) def test_entry_join_const(self): diff --git a/src/main/scala/is/hail/variant/MatrixTable.scala b/src/main/scala/is/hail/variant/MatrixTable.scala index a06c8a8980a..b993336e712 100644 --- a/src/main/scala/is/hail/variant/MatrixTable.scala +++ b/src/main/scala/is/hail/variant/MatrixTable.scala @@ -976,6 +976,21 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { } def orderedRVDLeftJoinDistinctAndInsert(right: OrderedRVD, root: String, product: Boolean): MatrixTable = { + println(s"unsafe row rdd") + println(this.unsafeRowRDD().collect().mkString("\n")) + + val localRowType = right.rowType + val urrdd = right.rdd.mapPartitions{ it => + val ur = new UnsafeRow(localRowType) + it.map { rv => + ur.set(rv) + ur.copy() + } + } + + println(s"unsafe row rdd2") + println(urrdd.collect().mkString("\n")) + assert(!rowKey.contains(root)) assert(right.typ.pkType.types.map(_.deepOptional()) .sameElements(rowPartitionKeyTypes.map(_.deepOptional()))) @@ -1005,7 +1020,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { val newMatrixType = matrixType.copy(rvRowType = newRVType) val intermediateMatrixType = newMatrixType.copy(rowKey = newMatrixType.rowPartitionKey) - copyMT(matrixType = newMatrixType, + val res = copyMT(matrixType = newMatrixType, rvd = OrderedRVD( newMatrixType.orvdType, leftRVD.partitioner, @@ -1017,6 +1032,11 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { it.map { jrv => val lrv = jrv.rvLeft + println(s"leftRV: ${jrv.rvLeft}, rightRV: ${jrv.rvRight}") + if (jrv.rvLeft != null) + println(s"left: ${new UnsafeRow(leftRowType, jrv.rvLeft)}") + if (jrv.rvRight != null) + println(s"right: ${new UnsafeRow(rightRowType, jrv.rvRight)}") rvb.set(lrv.region) rvb.start(newRVType) ins(lrv.region, lrv.offset, rvb, @@ -1046,6 +1066,9 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { } }) ) + println(s"result") + println(res.unsafeRowRDD().collect().mkString("\n")) + res } private def annotateRowsIntervalTable(kt: Table, root: String, product: Boolean): MatrixTable = { @@ -1134,6 +1157,8 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { } def annotateRowsTable(kt: Table, root: String, product: Boolean = false): MatrixTable = { + println(kt.showString()) + assert(!rowKey.contains(root)) val keyTypes = kt.keyFields.map(_.typ) @@ -1338,7 +1363,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { } def selectEntries(expr: String): MatrixTable = { - println(expr) val ec = entryEC val entryAST = Parser.parseToAST(expr, ec) @@ -1348,6 +1372,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { case Some(ir) => new MatrixTable(hc, MapEntries(ast, ir)) case None => + info(s"Using AST") val (t, f) = Parser.parseExpr(expr, ec) val newEntryType = t.asInstanceOf[TStruct] val globalsBc = globals.broadcast diff --git a/src/test/scala/is/hail/expr/TableIRSuite.scala b/src/test/scala/is/hail/expr/TableIRSuite.scala index e1d9bf6202c..773dca3bfc9 100644 --- a/src/test/scala/is/hail/expr/TableIRSuite.scala +++ b/src/test/scala/is/hail/expr/TableIRSuite.scala @@ -3,6 +3,8 @@ package is.hail.expr import is.hail.SparkSuite import is.hail.expr.types._ import is.hail.table.Table +import is.hail.variant.MatrixTable +import is.hail.testUtils._ import org.apache.spark.sql.Row import org.testng.annotations.Test @@ -31,4 +33,11 @@ class TableIRSuite extends SparkSuite { ir.ApplyBinaryPrimOp(ir.EQ(), ir.GetField(ir.Ref("row"), "field1"), ir.GetField(ir.Ref("global"), "g")))) assert(kt2.count() == 1) } + + @Test def testAnnotateRowsTable() { + val mt = MatrixTable.range(hc, 6, 6, Some(4)).annotateEntriesExpr(("x", "va.row_idx * 10 + sa.col_idx")) + val entries = mt.localizeEntries("x") + val joined = mt.annotateRowsTable(entries, "x2", false) + println(joined.unsafeRowRDD().collect().mkString("\n")) + } } From 25db5113788eeb9534fec29921d121750aabcfba Mon Sep 17 00:00:00 2001 From: wang Date: Tue, 20 Mar 2018 18:15:56 -0400 Subject: [PATCH 5/9] wip --- python/hail/matrixtable.py | 1 - python/hail/table.py | 5 ++--- src/main/scala/is/hail/variant/MatrixTable.scala | 4 +++- src/test/scala/is/hail/utils/RichMatrixTable.scala | 13 ++----------- 4 files changed, 7 insertions(+), 16 deletions(-) diff --git a/python/hail/matrixtable.py b/python/hail/matrixtable.py index ef055daae83..32d1785b590 100644 --- a/python/hail/matrixtable.py +++ b/python/hail/matrixtable.py @@ -2184,7 +2184,6 @@ def index_entries(self, row_exprs: Tuple[Expression], col_exprs: Tuple[Expressio uids.append(row_uid) col_uid = Env.get_uid() uids.append(col_uid) - uids.append(col_uid) def joiner(left: MatrixTable): localized = Table(self._jvds.localizeEntries(row_uid)) diff --git a/python/hail/table.py b/python/hail/table.py index 09d9a1b8bc9..0170e6b8062 100644 --- a/python/hail/table.py +++ b/python/hail/table.py @@ -1192,10 +1192,10 @@ def joiner(left): if is_row_key or is_partition_key: # no vds_key (way faster) def joiner(left): - left.annotate_rows(es=hl.agg.collect(left.x)).rows().show() + #left.annotate_rows(es=hl.agg.collect(left.x)).rows().show() rt = MatrixTable(left._jvds.annotateRowsTable( right._jt, uid, False)) - rt.annotate_rows(es=hl.agg.collect(rt.x)).rows().show() + #rt.annotate_rows(es=hl.agg.collect(rt.x)).rows().show() return rt return construct_expr(Select(TopLevelReference('va', src._row_indices), uid), new_schema, @@ -1242,7 +1242,6 @@ def joiner(left): # no vds_key (faster) joiner = lambda left: MatrixTable(left._jvds.annotateColsTable(right._jt, None, uid, False)) else: - # use vds_key # use vds_key joiner = lambda left: MatrixTable(left._jvds.annotateColsTable( right._jt, [e._ast.to_hql() for e in exprs], uid, False)) diff --git a/src/main/scala/is/hail/variant/MatrixTable.scala b/src/main/scala/is/hail/variant/MatrixTable.scala index b993336e712..c52448b3b4f 100644 --- a/src/main/scala/is/hail/variant/MatrixTable.scala +++ b/src/main/scala/is/hail/variant/MatrixTable.scala @@ -1363,6 +1363,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { } def selectEntries(expr: String): MatrixTable = { + println(expr) val ec = entryEC val entryAST = Parser.parseToAST(expr, ec) @@ -1396,7 +1397,8 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { val entry = entries(i) ec.set(2, localColValuesBc.value(i)) ec.set(3, entry) - rvb.addAnnotation(newEntryType, f()) + val result = f() + rvb.addAnnotation(newEntryType, result) i += 1 } rvb.endArray() diff --git a/src/test/scala/is/hail/utils/RichMatrixTable.scala b/src/test/scala/is/hail/utils/RichMatrixTable.scala index 3ebc34af3c7..e462e5f7b3f 100644 --- a/src/test/scala/is/hail/utils/RichMatrixTable.scala +++ b/src/test/scala/is/hail/utils/RichMatrixTable.scala @@ -43,17 +43,8 @@ class RichMatrixTable(vsm: MatrixTable) { vsm.annotateCols(t, i) { case (_, i) => annotation(sampleIds(i)) } } - def annotateEntriesExpr(exprs: (String, String)*): MatrixTable = { - val exprMap = exprs.toMap - val expr = (vsm.matrixType.entryType.fieldNames ++ exprMap.keys.filter(!vsm.matrixType.entryType.fieldNames.contains(_))).map { n => - if (exprMap.keySet.contains(n)) - s"$n: ${ exprMap(n) }" - else - s"$n: g.`$n`" - }.mkString(", ") - - vsm.selectEntries(s"{$expr}") - } + def annotateEntriesExpr(exprs: (String, String)*): MatrixTable = + vsm.selectEntries(s"annotate(g, {${ exprs.map { case (n, e) => s"`$n`: $e" }.mkString(",") }})") def querySA(code: String): (Type, Querier) = { val st = Map(Annotation.COL_HEAD -> (0, vsm.colType)) From aaa1b3578b67e82fa365fb3da73413f45cd26f08 Mon Sep 17 00:00:00 2001 From: wang Date: Wed, 21 Mar 2018 11:31:32 -0400 Subject: [PATCH 6/9] fixed selectEntries bug --- python/hail/matrixtable.py | 2 -- python/hail/table.py | 8 ++--- python/hail/tests/test_api.py | 9 ++---- .../scala/is/hail/variant/MatrixTable.scala | 30 ++----------------- .../scala/is/hail/expr/TableIRSuite.scala | 7 ----- 5 files changed, 8 insertions(+), 48 deletions(-) diff --git a/python/hail/matrixtable.py b/python/hail/matrixtable.py index 32d1785b590..6d38b983d06 100644 --- a/python/hail/matrixtable.py +++ b/python/hail/matrixtable.py @@ -941,8 +941,6 @@ def annotate_entries(self, **named_exprs: NamedExprs) -> 'MatrixTable': :class:`.MatrixTable` Matrix table with new row-and-column-indexed field(s). """ - - named_exprs = {k: to_expr(v) for k, v in named_exprs.items()} for k, v in named_exprs.items(): diff --git a/python/hail/table.py b/python/hail/table.py index 0170e6b8062..2a6e98c3011 100644 --- a/python/hail/table.py +++ b/python/hail/table.py @@ -1192,11 +1192,8 @@ def joiner(left): if is_row_key or is_partition_key: # no vds_key (way faster) def joiner(left): - #left.annotate_rows(es=hl.agg.collect(left.x)).rows().show() - rt = MatrixTable(left._jvds.annotateRowsTable( + return MatrixTable(left._jvds.annotateRowsTable( right._jt, uid, False)) - #rt.annotate_rows(es=hl.agg.collect(rt.x)).rows().show() - return rt return construct_expr(Select(TopLevelReference('va', src._row_indices), uid), new_schema, indices, aggregations, joins.push(Join(joiner, [uid], uid, exprs))) @@ -1240,7 +1237,8 @@ def joiner(left): if len(exprs) == len(src.col_key) and all([ exprs[i] is src[list(src.col_key)[i]] for i in range(len(exprs))]): # no vds_key (faster) - joiner = lambda left: MatrixTable(left._jvds.annotateColsTable(right._jt, None, uid, False)) + joiner = lambda left: MatrixTable(left._jvds.annotateColsTable( + right._jt, None, uid, False)) else: # use vds_key joiner = lambda left: MatrixTable(left._jvds.annotateColsTable( diff --git a/python/hail/tests/test_api.py b/python/hail/tests/test_api.py index 1e4d2a28751..8da1c90da5a 100644 --- a/python/hail/tests/test_api.py +++ b/python/hail/tests/test_api.py @@ -799,17 +799,14 @@ def test_computed_key_join_3(self): hl.is_missing(rt['value'])))) def test_entry_join_self(self): - mt1 = hl.utils.range_matrix_table(6, 6, n_partitions=2) - mt1 = mt1.annotate_entries(x = 100*mt1.row_idx + 10*mt1.col_idx) + mt1 = hl.utils.range_matrix_table(10, 10, n_partitions=4) + mt1 = mt1.annotate_entries(x = 10*mt1.row_idx + mt1.col_idx) - mt1.entries().show(20) self.assertEqual(mt1[mt1.row_idx, mt1.col_idx].dtype, mt1.entry.dtype) - mt_join = mt1.annotate_entries(x2 = mt1[mt1.row_idx, mt1.col_idx].x) + mt_join = mt1.annotate_entries(x2 = mt1[mt1.row_idx, mt1.col_idx].x + mt1.x3) mt_join_entries = mt_join.entries() - mt_join_entries.show(20) - self.assertTrue(mt_join_entries.all(mt_join_entries.x == mt_join_entries.x2)) def test_entry_join_const(self): diff --git a/src/main/scala/is/hail/variant/MatrixTable.scala b/src/main/scala/is/hail/variant/MatrixTable.scala index c52448b3b4f..c8760ce2f55 100644 --- a/src/main/scala/is/hail/variant/MatrixTable.scala +++ b/src/main/scala/is/hail/variant/MatrixTable.scala @@ -976,21 +976,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { } def orderedRVDLeftJoinDistinctAndInsert(right: OrderedRVD, root: String, product: Boolean): MatrixTable = { - println(s"unsafe row rdd") - println(this.unsafeRowRDD().collect().mkString("\n")) - - val localRowType = right.rowType - val urrdd = right.rdd.mapPartitions{ it => - val ur = new UnsafeRow(localRowType) - it.map { rv => - ur.set(rv) - ur.copy() - } - } - - println(s"unsafe row rdd2") - println(urrdd.collect().mkString("\n")) - assert(!rowKey.contains(root)) assert(right.typ.pkType.types.map(_.deepOptional()) .sameElements(rowPartitionKeyTypes.map(_.deepOptional()))) @@ -1020,7 +1005,7 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { val newMatrixType = matrixType.copy(rvRowType = newRVType) val intermediateMatrixType = newMatrixType.copy(rowKey = newMatrixType.rowPartitionKey) - val res = copyMT(matrixType = newMatrixType, + copyMT(matrixType = newMatrixType, rvd = OrderedRVD( newMatrixType.orvdType, leftRVD.partitioner, @@ -1032,11 +1017,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { it.map { jrv => val lrv = jrv.rvLeft - println(s"leftRV: ${jrv.rvLeft}, rightRV: ${jrv.rvRight}") - if (jrv.rvLeft != null) - println(s"left: ${new UnsafeRow(leftRowType, jrv.rvLeft)}") - if (jrv.rvRight != null) - println(s"right: ${new UnsafeRow(rightRowType, jrv.rvRight)}") rvb.set(lrv.region) rvb.start(newRVType) ins(lrv.region, lrv.offset, rvb, @@ -1066,9 +1046,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { } }) ) - println(s"result") - println(res.unsafeRowRDD().collect().mkString("\n")) - res } private def annotateRowsIntervalTable(kt: Table, root: String, product: Boolean): MatrixTable = { @@ -1157,8 +1134,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { } def annotateRowsTable(kt: Table, root: String, product: Boolean = false): MatrixTable = { - println(kt.showString()) - assert(!rowKey.contains(root)) val keyTypes = kt.keyFields.map(_.typ) @@ -1373,7 +1348,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { case Some(ir) => new MatrixTable(hc, MapEntries(ast, ir)) case None => - info(s"Using AST") val (t, f) = Parser.parseExpr(expr, ec) val newEntryType = t.asInstanceOf[TStruct] val globalsBc = globals.broadcast @@ -1386,10 +1360,10 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { val fullRow = new UnsafeRow(fullRowType) val row = fullRow.deleteField(localEntriesIndex) ec.set(0, globalsBc.value) - ec.set(1, row) fullRow -> row })(newEntryType, { case ((fullRow, row), rv, rvb) => fullRow.set(rv) + ec.set(1, row) val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex) rvb.startArray(localNCols) var i = 0 diff --git a/src/test/scala/is/hail/expr/TableIRSuite.scala b/src/test/scala/is/hail/expr/TableIRSuite.scala index 773dca3bfc9..0eb18f7bd5f 100644 --- a/src/test/scala/is/hail/expr/TableIRSuite.scala +++ b/src/test/scala/is/hail/expr/TableIRSuite.scala @@ -33,11 +33,4 @@ class TableIRSuite extends SparkSuite { ir.ApplyBinaryPrimOp(ir.EQ(), ir.GetField(ir.Ref("row"), "field1"), ir.GetField(ir.Ref("global"), "g")))) assert(kt2.count() == 1) } - - @Test def testAnnotateRowsTable() { - val mt = MatrixTable.range(hc, 6, 6, Some(4)).annotateEntriesExpr(("x", "va.row_idx * 10 + sa.col_idx")) - val entries = mt.localizeEntries("x") - val joined = mt.annotateRowsTable(entries, "x2", false) - println(joined.unsafeRowRDD().collect().mkString("\n")) - } } From bab6aa13b3b0dd1912a7ad097ff07b50661f4d0d Mon Sep 17 00:00:00 2001 From: wang Date: Wed, 21 Mar 2018 11:39:21 -0400 Subject: [PATCH 7/9] cleanup --- python/hail/tests/test_api.py | 2 +- src/main/scala/is/hail/variant/MatrixTable.scala | 1 - src/test/scala/is/hail/expr/TableIRSuite.scala | 2 -- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/python/hail/tests/test_api.py b/python/hail/tests/test_api.py index 8da1c90da5a..c7f43abdbb8 100644 --- a/python/hail/tests/test_api.py +++ b/python/hail/tests/test_api.py @@ -804,7 +804,7 @@ def test_entry_join_self(self): self.assertEqual(mt1[mt1.row_idx, mt1.col_idx].dtype, mt1.entry.dtype) - mt_join = mt1.annotate_entries(x2 = mt1[mt1.row_idx, mt1.col_idx].x + mt1.x3) + mt_join = mt1.annotate_entries(x2 = mt1[mt1.row_idx, mt1.col_idx].x) mt_join_entries = mt_join.entries() self.assertTrue(mt_join_entries.all(mt_join_entries.x == mt_join_entries.x2)) diff --git a/src/main/scala/is/hail/variant/MatrixTable.scala b/src/main/scala/is/hail/variant/MatrixTable.scala index c8760ce2f55..8094db022fc 100644 --- a/src/main/scala/is/hail/variant/MatrixTable.scala +++ b/src/main/scala/is/hail/variant/MatrixTable.scala @@ -1338,7 +1338,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) { } def selectEntries(expr: String): MatrixTable = { - println(expr) val ec = entryEC val entryAST = Parser.parseToAST(expr, ec) diff --git a/src/test/scala/is/hail/expr/TableIRSuite.scala b/src/test/scala/is/hail/expr/TableIRSuite.scala index 0eb18f7bd5f..e1d9bf6202c 100644 --- a/src/test/scala/is/hail/expr/TableIRSuite.scala +++ b/src/test/scala/is/hail/expr/TableIRSuite.scala @@ -3,8 +3,6 @@ package is.hail.expr import is.hail.SparkSuite import is.hail.expr.types._ import is.hail.table.Table -import is.hail.variant.MatrixTable -import is.hail.testUtils._ import org.apache.spark.sql.Row import org.testng.annotations.Test From 91e4f5c12acc683b212b3b2b864cd54e832667ef Mon Sep 17 00:00:00 2001 From: wang Date: Wed, 21 Mar 2018 12:58:17 -0400 Subject: [PATCH 8/9] fix --- python/hail/matrixtable.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/hail/matrixtable.py b/python/hail/matrixtable.py index 6d38b983d06..34545898b8d 100644 --- a/python/hail/matrixtable.py +++ b/python/hail/matrixtable.py @@ -1188,15 +1188,12 @@ def select_entries(self, *exprs: FieldRefArgs, **named_exprs: NamedExprs) -> 'Ma assignments = OrderedDict() for e in exprs: - all_exprs.append(e) if not e._indices == self._entry_indices: # detect row or col fields here raise ExpressionException("method 'select_entries' parameter 'exprs' expects entry-indexed fields," " found indices {}".format(list(e._indices.axes))) if e._ast.search(lambda ast: not isinstance(ast, TopLevelReference) and not isinstance(ast, Select)): raise ExpressionException("method 'select_entries' expects keyword arguments for complex expressions") - strs.append(e._ast.to_hql()) - ids.append(e._ast.name) for k, e in named_exprs.items(): check_collisions(self._fields, k, self._entry_indices) assignments[k] = e From 2db0a72f09e8b22480eb4580882776bc4632f032 Mon Sep 17 00:00:00 2001 From: wang Date: Wed, 21 Mar 2018 16:09:50 -0400 Subject: [PATCH 9/9] fix select_entries --- python/hail/matrixtable.py | 3 +-- python/hail/tests/test_api.py | 9 +++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/hail/matrixtable.py b/python/hail/matrixtable.py index 34545898b8d..c81137e7a68 100644 --- a/python/hail/matrixtable.py +++ b/python/hail/matrixtable.py @@ -1194,6 +1194,7 @@ def select_entries(self, *exprs: FieldRefArgs, **named_exprs: NamedExprs) -> 'Ma " found indices {}".format(list(e._indices.axes))) if e._ast.search(lambda ast: not isinstance(ast, TopLevelReference) and not isinstance(ast, Select)): raise ExpressionException("method 'select_entries' expects keyword arguments for complex expressions") + assignments[e._ast.name] = e for k, e in named_exprs.items(): check_collisions(self._fields, k, self._entry_indices) assignments[k] = e @@ -1276,8 +1277,6 @@ def drop(self, *exprs: FieldRefArgs) -> 'MatrixTable': new_col_fields = [f for f in m.col if f not in fields_to_drop] m = m.select_cols(*new_col_fields) - - entry_fields = [field for field in fields_to_drop if self._fields[field]._indices == self._entry_indices] if entry_fields: m = m._select_entries("MatrixTable.drop_entries", m.entry.drop(*entry_fields)) diff --git a/python/hail/tests/test_api.py b/python/hail/tests/test_api.py index c7f43abdbb8..2fc56ee0ea1 100644 --- a/python/hail/tests/test_api.py +++ b/python/hail/tests/test_api.py @@ -616,6 +616,15 @@ def test_query(self): qgs = vds.aggregate_entries(hl.Struct(x=agg.collect(agg.filter(False, vds.y1)), y=agg.collect(agg.filter(hl.rand_bool(0.1), vds.GT)))) + def test_select_entries(self): + mt = hl.utils.range_matrix_table(10, 10, n_partitions=4) + mt = mt.annotate_entries(a=hl.struct(b=mt.row_idx, c=mt.col_idx), foo=mt.row_idx * 10 + mt.col_idx) + mt = mt.select_entries(mt.a.b, mt.a.c, mt.foo) + mt = mt.annotate_entries(bc=mt.b * 10 + mt.c) + mt_entries = mt.entries() + + assert(mt_entries.all(mt_entries.bc == mt_entries.foo)) + def test_drop(self): vds = self.get_vds() vds = vds.annotate_globals(foo=5)