Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement annotate/drop in terms of select_entries #3200

Merged
merged 9 commits into from
Mar 22, 2018
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 21 additions & 32 deletions python/hail/matrixtable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
from typing import *
from collections import OrderedDict

import hail
import hail as hl
Expand Down Expand Up @@ -940,16 +941,12 @@ def annotate_entries(self, **named_exprs: NamedExprs) -> 'MatrixTable':
:class:`.MatrixTable`
Matrix table with new row-and-column-indexed field(s).
"""
exprs = []
named_exprs = {k: to_expr(v) for k, v in named_exprs.items()}
base, cleanup = self._process_joins(*named_exprs.values())

for k, v in named_exprs.items():
analyze('MatrixTable.annotate_entries', v, self._entry_indices)
exprs.append('g.{k} = {v}'.format(k=escape_id(k), v=v._ast.to_hql()))
check_collisions(self._fields, k, self._entry_indices)
m = MatrixTable(base._jvds.annotateEntriesExpr(",\n".join(exprs)))
return cleanup(m)

return self._select_entries("MatrixTable.annotate_entries", self.entry.annotate(**named_exprs))

def select_globals(self, *exprs: FieldRefArgs, **named_exprs: NamedExprs) -> 'MatrixTable':
"""Select existing global fields or create new fields by name, dropping the rest.
Expand Down Expand Up @@ -1188,31 +1185,20 @@ def select_entries(self, *exprs: FieldRefArgs, **named_exprs: NamedExprs) -> 'Ma
"""
exprs = [to_expr(e) if not isinstance(e, str) else self[e] for e in exprs]
named_exprs = {k: to_expr(v) for k, v in named_exprs.items()}
strs = []
all_exprs = []
base, cleanup = self._process_joins(*itertools.chain(exprs, named_exprs.values()))
assignments = OrderedDict()

ids = []
for e in exprs:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't look like anything in `exprs` is getting passed to `_select_entries`. I might be missing something, but if not, you should add a test that fails and then fix `select_entries` to make it pass.

all_exprs.append(e)
analyze('MatrixTable.select_entries', e, self._entry_indices)
if not e._indices == self._entry_indices:
# detect row or col fields here
raise ExpressionException("method 'select_entries' parameter 'exprs' expects entry-indexed fields,"
" found indices {}".format(list(e._indices.axes)))
if e._ast.search(lambda ast: not isinstance(ast, TopLevelReference) and not isinstance(ast, Select)):
raise ExpressionException("method 'select_entries' expects keyword arguments for complex expressions")
strs.append(e._ast.to_hql())
ids.append(e._ast.name)
for k, e in named_exprs.items():
all_exprs.append(e)
analyze('MatrixTable.select_entries', e, self._entry_indices)
check_collisions(self._fields, k, self._entry_indices)
strs.append('{} = {}'.format(escape_id(k), e._ast.to_hql()))
ids.append(k)
check_field_uniqueness(ids)
m = MatrixTable(base._jvds.selectEntries(strs))
return cleanup(m)
assignments[k] = e
check_field_uniqueness(assignments.keys())
return self._select_entries("MatrixTable.select_entries", hl.struct(**assignments))

@typecheck_method(exprs=oneof(str, Expression))
def drop(self, *exprs: FieldRefArgs) -> 'MatrixTable':
Expand Down Expand Up @@ -1290,10 +1276,11 @@ def drop(self, *exprs: FieldRefArgs) -> 'MatrixTable':
new_col_fields = [f for f in m.col if f not in fields_to_drop]
m = m.select_cols(*new_col_fields)

entry_fields = [x for x in fields_to_drop if self._fields[x]._indices == self._entry_indices]
if any(self._fields[field]._indices == self._entry_indices for field in fields_to_drop):
# need to drop entry fields
m = MatrixTable(m._jvds.dropEntries(entry_fields))


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why so much space?

entry_fields = [field for field in fields_to_drop if self._fields[field]._indices == self._entry_indices]
if entry_fields:
m = m._select_entries("MatrixTable.drop_entries", m.entry.drop(*entry_fields))

return m

Expand Down Expand Up @@ -2199,7 +2186,6 @@ def joiner(left: MatrixTable):
src_cols_indexed = src_cols_indexed.annotate(**{col_uid: hl.int32(src_cols_indexed[col_uid])})
left = left._annotate_all(row_exprs = {row_uid: localized.index(*row_exprs)[row_uid]},
col_exprs = {col_uid: src_cols_indexed.index(*col_exprs)[col_uid]})

return left.annotate_entries(**{uid: left[row_uid][left[col_uid]]})

return construct_expr(Select(TopLevelReference('g', self._entry_indices), uid),
Expand Down Expand Up @@ -2238,12 +2224,9 @@ def _annotate_all(self,
check_collisions(self._fields, k, self._col_indices)
jmt = jmt.annotateColsExpr(",\n".join(col_strs))
if entry_exprs:
entry_strs = []
for k, v in entry_exprs.items():
analyze('MatrixTable.annotate_entries', v, self._entry_indices)
entry_strs.append('g.{k} = {v}'.format(k=escape_id(k), v=v._ast.to_hql()))
check_collisions(self._fields, k, self._entry_indices)
jmt = jmt.annotateEntriesExpr(",\n".join(entry_strs))
entry_struct = self.entry.annotate(**entry_exprs)
analyze("MatrixTable.annotate_entries", entry_struct, self._entry_indices)
jmt = jmt.selectEntries(entry_struct._ast.to_hql())
if global_exprs:
global_strs = []
for k, v in global_exprs.items():
Expand Down Expand Up @@ -2613,6 +2596,12 @@ def add_col_index(self, name: str = 'col_idx') -> 'MatrixTable':
def _same(self, other, tolerance=1e-6):
    """Compare this matrix table with `other` via the backend `same` method.

    Parameters
    ----------
    other
        Object exposing a `_jvds` handle (a matrix table in practice).
    tolerance
        Forwarded unchanged to the backend comparison; defaults to 1e-6.
        NOTE(review): presumably a floating-point comparison tolerance --
        confirm against the Scala implementation.

    Returns
    -------
    Whatever the backend `same` call returns (presumably a boolean).
    """
    this_jmt = self._jvds
    that_jmt = other._jvds
    return this_jmt.same(that_jmt, tolerance)

@typecheck_method(caller=str, s=expr_struct())
def _select_entries(self, caller, s):
    """Set the entry struct of this matrix table to the struct expression `s`.

    Shared implementation behind the public entry-modifying methods, which
    build a struct expression and delegate here.

    Parameters
    ----------
    caller : str
        Name of the public method on whose behalf this runs (e.g.
        ``"MatrixTable.annotate_entries"``); forwarded to `analyze`.
    s
        Struct expression describing the new entry schema.

    Returns
    -------
    MatrixTable
        The result of the backend `selectEntries` call, passed through the
        cleanup function returned by `_process_joins`.
    """
    # Resolve joins referenced by `s` first; this yields the table to operate
    # on and a callback to apply to the result afterwards.
    joined, finalize = self._process_joins(s)
    # NOTE(review): presumably checks that `s` only uses entry-level indices.
    analyze(caller, s, self._entry_indices)
    jmt = joined._jvds.selectEntries(s._ast.to_hql())
    return finalize(MatrixTable(jmt))

@typecheck(datasets=matrix_table_type)
def union_rows(*datasets: Tuple['MatrixTable']) -> 'MatrixTable':
"""Take the union of dataset rows.
Expand Down
3 changes: 2 additions & 1 deletion python/hail/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,12 +800,13 @@ def test_computed_key_join_3(self):

def test_entry_join_self(self):
mt1 = hl.utils.range_matrix_table(10, 10, n_partitions=4)
mt1 = mt1.annotate_entries(x = mt1.row_idx + mt1.col_idx)
mt1 = mt1.annotate_entries(x = 10*mt1.row_idx + mt1.col_idx)

self.assertEqual(mt1[mt1.row_idx, mt1.col_idx].dtype, mt1.entry.dtype)

mt_join = mt1.annotate_entries(x2 = mt1[mt1.row_idx, mt1.col_idx].x)
mt_join_entries = mt_join.entries()

self.assertTrue(mt_join_entries.all(mt_join_entries.x == mt_join_entries.x2))

def test_entry_join_const(self):
Expand Down
201 changes: 37 additions & 164 deletions src/main/scala/is/hail/variant/MatrixTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1337,106 +1337,46 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
rvd = rvd.mapPartitionsPreservesPartitioning(newMatrixType.orvdType)(mapPartitionsF))
}

def selectEntries(selectExprs: java.util.ArrayList[String]): MatrixTable = selectEntries(selectExprs.asScala.toArray: _*)

def selectEntries(exprs: String*): MatrixTable = {
def selectEntries(expr: String): MatrixTable = {
val ec = entryEC
val globalsBc = globals.broadcast

val (paths, types, f) = Parser.parseSelectExprs(exprs.toArray, ec)
val topLevelFields = mutable.Set.empty[String]

val finalNames = paths.map {
// assignment
case Left(name) => name
case Right(path) =>
assert(path.head == Annotation.ENTRY_HEAD)
path match {
case List(Annotation.ENTRY_HEAD, name) => topLevelFields += name
}
path.last
}
assert(finalNames.areDistinct())

val newEntryType = TStruct(finalNames.zip(types): _*)
val fullRowType = rvRowType
val localEntriesIndex = entriesIndex
val localNCols = numCols
val localColValuesBc = colValuesBc

insertEntries(() => {
val fullRow = new UnsafeRow(fullRowType)
val row = fullRow.deleteField(localEntriesIndex)
ec.set(0, globalsBc.value)
ec.set(1, row)
fullRow -> row
})(newEntryType, { case ((fullRow, row), rv, rvb) =>
fullRow.set(rv)
val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex)
rvb.startArray(localNCols)
var i = 0
while (i < localNCols) {
val entry = entries(i)
ec.set(2, localColValuesBc.value(i))
ec.set(3, entry)
val results = f()
var j = 0
rvb.startStruct()
while (j < types.length) {
rvb.addAnnotation(types(j), results(j))
j += 1
}
rvb.endStruct()
i += 1
}
rvb.endArray()
})
}


def dropEntries(fields: java.util.ArrayList[String]): MatrixTable = dropEntries(fields.asScala.toArray: _*)

def dropEntries(fields: String*): MatrixTable = {
if (fields.isEmpty)
return this
assert(fields.areDistinct())
val dropSet = fields.toSet
val allEntryFields = entryType.fieldNames.toSet
assert(fields.forall(allEntryFields.contains))

val keepIndices = entryType.fields
.filter(f => !dropSet.contains(f.name))
.map(f => f.index)
.toArray
val newEntryType = TStruct(keepIndices.map(entryType.fields(_)).map(f => f.name -> f.typ): _*)

val fullRowType = rvRowType
val localEntriesIndex = entriesIndex
val localEntriesType = matrixType.entryArrayType
val localEntryType = entryType
val localNCols = numCols
// FIXME: replace with physical type
insertEntries(noOp)(newEntryType, { case (_, rv, rvb) =>
val entriesOffset = fullRowType.loadField(rv, localEntriesIndex)
rvb.startArray(localNCols)
var i = 0
while (i < localNCols) {
if (localEntriesType.isElementMissing(rv.region, entriesOffset, i))
rvb.setMissing()
else {
val eltOffset = localEntriesType.loadElement(rv.region, entriesOffset, localNCols, i)
rvb.startStruct()
var j = 0
while (j < keepIndices.length) {
rvb.addField(localEntryType, rv.region, eltOffset, keepIndices(j))
j += 1
val entryAST = Parser.parseToAST(expr, ec)
assert(entryAST.`type`.isInstanceOf[TStruct])

entryAST.toIR() match {
case Some(ir) =>
new MatrixTable(hc, MapEntries(ast, ir))
case None =>
val (t, f) = Parser.parseExpr(expr, ec)
val newEntryType = t.asInstanceOf[TStruct]
val globalsBc = globals.broadcast
val fullRowType = rvRowType
val localEntriesIndex = entriesIndex
val localNCols = numCols
val localColValuesBc = colValuesBc

insertEntries(() => {
val fullRow = new UnsafeRow(fullRowType)
val row = fullRow.deleteField(localEntriesIndex)
ec.set(0, globalsBc.value)
fullRow -> row
})(newEntryType, { case ((fullRow, row), rv, rvb) =>
fullRow.set(rv)
ec.set(1, row)
val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex)
rvb.startArray(localNCols)
var i = 0
while (i < localNCols) {
val entry = entries(i)
ec.set(2, localColValuesBc.value(i))
ec.set(3, entry)
val result = f()
rvb.addAnnotation(newEntryType, result)
i += 1
}
rvb.endStruct()
}
i += 1
}
rvb.endArray()
})
rvb.endArray()
})
}
}

def nPartitions: Int = rvd.partitions.length
Expand Down Expand Up @@ -1569,73 +1509,6 @@ class MatrixTable(val hc: HailContext, val ast: MatrixIR) {
new Table(hc, TableLiteral(TableValue(TableType(rvRowType.rename(m), rowKey, globalType), globals, rvd)))
}

def annotateEntriesExpr(expr: String): MatrixTable = {
val symTab = Map(
"va" -> (0, rowType),
"sa" -> (1, colType),
"g" -> (2, entryType),
"global" -> (3, globalType))
val ec = EvalContext(symTab)

val globalsBc = globals.broadcast

val asts = Parser.parseAnnotationExprsToAST(expr, ec, Some(Annotation.ENTRY_HEAD))

val irs = asts.flatMap { case (f, a) => a.toIR().map((f, _)) }

val colValuesIsSmall = colType.size == 1 && colType.types.head.isOfType(TString())
if (irs.length == asts.length && colValuesIsSmall) {
val newEntries = ir.InsertFields(ir.Ref("g"), irs)

new MatrixTable(hc, MapEntries(ast, newEntries))
} else {

val (paths, types, f) = Parser.parseAnnotationExprs(expr, ec, Some(Annotation.ENTRY_HEAD))

val inserterBuilder = new ArrayBuilder[Inserter]()
val newEntryType = (paths, types).zipped.foldLeft(entryType) { case (gsig, (ids, signature)) =>
val (s, i) = gsig.structInsert(signature, ids)
inserterBuilder += i
s
}
val inserters = inserterBuilder.result()

val localNSamples = numCols
val fullRowType = rvRowType
val localColValuesBc = colValuesBc
val localEntriesIndex = entriesIndex

insertEntries(() => {
val fullRow = new UnsafeRow(fullRowType)
val row = fullRow.deleteField(localEntriesIndex)
(fullRow, row)
})(newEntryType, { case ((fullRow, row), rv, rvb) =>
fullRow.set(rv)
val entries = fullRow.getAs[IndexedSeq[Annotation]](localEntriesIndex)

rvb.startArray(localNSamples)

var i = 0
while (i < localNSamples) {
val entry = entries(i)
ec.setAll(row,
localColValuesBc.value(i),
entry,
globalsBc.value)

val newEntry = f().zip(inserters)
.foldLeft(entry) { case (ga, (a, inserter)) =>
inserter(ga, a)
}
rvb.addAnnotation(newEntryType, newEntry)

i += 1
}
rvb.endArray()
})
}
}

def filterCols(p: (Annotation, Int) => Boolean): MatrixTable = {
copyAST(ast = MatrixLiteral(matrixType, value.filterCols(p)))
}
Expand Down
15 changes: 8 additions & 7 deletions src/test/scala/is/hail/io/ExportVCFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import is.hail.io.vcf.ExportVCF
import is.hail.utils._
import is.hail.variant.{MatrixTable, VSMSubgen, Variant}
import org.testng.annotations.Test
import is.hail.testUtils._

import scala.io.Source
import scala.language.postfixOps
Expand Down Expand Up @@ -210,13 +211,13 @@ class ExportVCFSuite extends SparkSuite {

TestUtils.interceptFatal("Invalid type for format field 'BOOL'. Found 'bool'.") {
ExportVCF(vds
.annotateEntriesExpr("g.BOOL = true"),
.annotateEntriesExpr(("BOOL","true")),
out)
}

TestUtils.interceptFatal("Invalid type for format field 'AA'.") {
ExportVCF(vds
.annotateEntriesExpr("g.AA = [[0]]"),
.annotateEntriesExpr(("AA", "[[0]]")),
out)
}
}
Expand Down Expand Up @@ -281,15 +282,15 @@ class ExportVCFSuite extends SparkSuite {
val callArrayFields = schema.fields.filter(fd => fd.typ == TArray(TCall())).map(_.name)
val callSetFields = schema.fields.filter(fd => fd.typ == TSet(TCall())).map(_.name)

val callAnnots = callFields.map(name => s"g.$name = let c = g.$name in " +
s"if (c.ploidy == 0 || (c.ploidy == 1 && c.isPhased())) Call(0, 0, false) else c")
val callAnnots = callFields.map(name => (name, s"let c = g.$name in " +
s"if (c.ploidy == 0 || (c.ploidy == 1 && c.isPhased())) Call(0, 0, false) else c"))

val callContainerAnnots = (callArrayFields ++ callSetFields).map(name => s"g.$name = " +
s"g.$name.map(c => if (c.ploidy == 0 || (c.ploidy == 1 && c.isPhased())) Call(0, 0, false) else c)")
val callContainerAnnots = (callArrayFields ++ callSetFields).map(name => (name,
s"g.$name.map(c => if (c.ploidy == 0 || (c.ploidy == 1 && c.isPhased())) Call(0, 0, false) else c)"))

val annots = callAnnots ++ callContainerAnnots

val vsmAnn = if (annots.nonEmpty) vsm.annotateEntriesExpr(annots.mkString(",")) else vsm
val vsmAnn = if (annots.nonEmpty) vsm.annotateEntriesExpr(annots: _*) else vsm

hadoopConf.delete(out, recursive = true)
ExportVCF(vsmAnn, out)
Expand Down
Loading