Skip to content

Commit

Permalink
Merge pull request #4 from cseed/ktfix1
Browse files Browse the repository at this point in the history
Ktfix1
  • Loading branch information
jigold authored Nov 21, 2016
2 parents e95b6b2 + 27222d0 commit 1142e58
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 69 deletions.
16 changes: 7 additions & 9 deletions python/pyhail/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def import_gen(self, path, tolerance=0.2, sample_file=None, npartitions=None, ch
return self.run_command(None, pargs)

def import_keytable(self, path, key_names, npartitions=None, config=None):
"""Import tabular file as KeyTable
"""Import delimited text file (text table) as KeyTable.
:param path: files to import.
:type path: str or list of str
Expand All @@ -250,7 +250,7 @@ def import_keytable(self, path, key_names, npartitions=None, config=None):
:type npartitions: int or None
:param config: Configuration options for importing text files
:type config: :class:`.TextTableConfig`
:type config: :class:`.TextTableConfig` or None
:rtype: :class:`.KeyTable`
"""
Expand All @@ -262,20 +262,18 @@ def import_keytable(self, path, key_names, npartitions=None, config=None):
pathArgs.append(p)

if not isinstance(key_names, str):
key_names = ",".join(key_names)
key_names = ','.join(key_names)

if not npartitions:
npartitions = self.sc.defaultMinPartitions

if not config:
config = TextTableConfig()._jobj(self)
elif isinstance(config, TextTableConfig):
config = config._jobj(self)
config = TextTableConfig()

return KeyTable(self, self.jvm.org.broadinstitute.hail.keytable.KeyTable.importTextTable(self.jsc, jarray(self.gateway, self.jvm.java.lang.String, pathArgs),
key_names, npartitions, config))
return KeyTable(self, self.jvm.org.broadinstitute.hail.keytable.KeyTable.importTextTable(
self.jsc, jarray(self.gateway, self.jvm.java.lang.String, pathArgs), key_names, npartitions, config.to_java(self)))

def import_plink(self, bed, bim, fam, npartitions=None, delimiter='\\\\s+', missing="NA", quantpheno=False):
def import_plink(self, bed, bim, fam, npartitions=None, delimiter='\\\\s+', missing='NA', quantpheno=False):
"""
Import PLINK binary file (.bed, .bim, .fam) as VariantDataset
Expand Down
2 changes: 1 addition & 1 deletion python/pyhail/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self, hc, jvds):
self.hc = hc
self.jvds = jvds

def aggregate_by_key(self, key_code=None, agg_code=None):
def aggregate_by_key(self, key_code, agg_code):
"""Aggregate by user-defined key and aggregation expressions.
Equivalent of a group-by operation in SQL.
Expand Down
9 changes: 5 additions & 4 deletions python/pyhail/keytable.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from pyhail.utils import Type

class KeyTable(object):
""":class:`.KeyTable` is Hail's version of a SQL
table where fields can be designated as keys.
""":class:`.KeyTable` is Hail's version of a SQL table where fields
can be designated as keys.
"""

def __init__(self, hc, jkt):
Expand Down Expand Up @@ -85,7 +86,7 @@ def filter(self, code, keep=True):
"""
return KeyTable(self.hc, self.jkt.filter(code, keep))

def annotate(self, code, key_names=None):
def annotate(self, code, key_names=''):
"""Add fields to key-table.
:param str code: Annotation expression.
Expand Down Expand Up @@ -144,4 +145,4 @@ def exists(self, code):
:rtype: bool
"""
return self.jkt.exists(code)
return self.jkt.exists(code)
11 changes: 2 additions & 9 deletions python/pyhail/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def test_dataset(self):
sample2.variants_to_pandas()

sample_split.annotate_variants_expr("va.nHet = gs.filter(g => g.isHet).count()")
kt = sample_split.aggregate_by_key("Variant = v", "nHet = gs.filter(g => g.isHet).count()")

kt = sample_split.aggregate_by_key("Variant = v", "nHet = g.map(g => g.isHet.toInt).sum().toLong")

def test_keytable(self):
# Import
Expand Down Expand Up @@ -245,13 +246,5 @@ def test_keytable(self):
self.assertFalse(kt.forall('Status == "CASE"'))
self.assertTrue(kt.exists('Status == "CASE"'))









def tearDown(self):
self.sc.stop()
10 changes: 10 additions & 0 deletions python/pyhail/type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

class Type(object):
def __init__(self, jtype):
self.jtype = jtype

def __repr__(self):
return self.jtype.toString()

def __str__(self):
return self.jtype.toPrettyString(False, False)
25 changes: 8 additions & 17 deletions python/pyhail/utils.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,20 @@

class Type(object):
def __init__(self, jtype):
self.jtype = jtype

def __repr__(self):
return self.jtype.toString()

def __str__(self):
return self.jtype.toPrettyString(False, False)

class TextTableConfig(object):
""":class:`.TextTableConfig` specifies additional options for importing TSV files.
"""Configuration for delimited (text table) files.
:param bool noheader: File has no header and columns should be indicated by `_1, _2, ... _N' (0-indexed)
:param bool impute: Impute column types from the file
:param str comment: Skip lines beginning with the given pattern
:param comment: Skip lines beginning with the given pattern
:type comment: str or None
:param str delimiter: Field delimiter regex
:param str missing: Specify identifier to be treated as missing
:param str types: Define types of fields in annotations files
:param types: Define types of fields in annotations files
:type types: str or None
"""
def __init__(self, noheader = False, impute = False,
comment = None, delimiter = "\t", missing = "NA", types = None):
Expand All @@ -45,12 +37,11 @@ def __str__(self):

return " ".join(res)

def _jobj(self, hc):
"""Convert to java TextTableConfiguration object
def to_java(self, hc):
"""Convert to Java TextTableConfiguration object.
:param :class:`.HailContext` hc: Hail spark context.
:param :class:`.HailContext` The Hail context.
"""
return hc.jvm.org.broadinstitute.hail.utils.TextTableConfiguration.apply(self.types, self.comment,
self.delimiter, self.missing,
self.noheader, self.impute)

5 changes: 2 additions & 3 deletions src/main/scala/org/broadinstitute/hail/expr/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,8 @@ case class MappedAggregable(parent: TAggregable, elementType: Type, mapF: (Any)
val parentF = parent.f
(a: Any) => {
val prev = parentF(a)
if (prev != null) {
if (prev != null)
mapF(prev)
}
else
null
}
Expand Down Expand Up @@ -320,7 +319,7 @@ case class TArray(elementType: Type) extends TIterable {

override def str(a: Annotation): String = JsonMethods.compact(toJSON(a))

override def genValue: Gen[Annotation] = Gen.buildableOf[Array, Annotation](elementType.genValue).map(x => x: IndexedSeq[Annotation])
override def genValue: Gen[Annotation] = Gen.buildableOf[IndexedSeq, Annotation](elementType.genValue)
}

case class TSet(elementType: Type) extends TIterable {
Expand Down
11 changes: 3 additions & 8 deletions src/main/scala/org/broadinstitute/hail/keytable/KeyTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ object KeyTable extends Serializable with TextExporter {
else
TextTableReader.read(sc)(files, config, nPartitions)


val keyNamesValid = keyNameArray.forall { k =>
val res = struct.selfField(k).isDefined
if (!res)
Expand Down Expand Up @@ -83,7 +82,7 @@ object KeyTable extends Serializable with TextExporter {

case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, valueSignature: TStruct) {

require(fieldNames.toSet.size == fieldNames.length)
require(fieldNames.areDistinct())

def signature = keySignature.merge(valueSignature)._1

Expand Down Expand Up @@ -292,12 +291,10 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v

val f: () => Option[Boolean] = Parser.parse[Boolean](code, ec, TBoolean)

val p = (k: Annotation, v: Annotation) => {
rdd.forall { case (k, v) =>
KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal)
f().getOrElse(false)
}

rdd.forall { case (k, v) => p(k, v) }
}

def exists(code: String): Boolean = {
Expand All @@ -307,12 +304,10 @@ case class KeyTable(rdd: RDD[(Annotation, Annotation)], keySignature: TStruct, v

val f: () => Option[Boolean] = Parser.parse[Boolean](code, ec, TBoolean)

val p = (k: Annotation, v: Annotation) => {
rdd.exists { case (k, v) =>
KeyTable.setEvalContext(ec, k, v, nKeysLocal, nValuesLocal)
f().getOrElse(false)
}

rdd.exists { case (k, v) => p(k, v) }
}

def export(sc: SparkContext, output: String, typesFile: String) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -605,17 +605,20 @@ class VariantSampleMatrix[T](val metadata: VariantMetadata,
"va" -> (1, vaSignature),
"s" -> (2, TSample),
"sa" -> (3, saSignature),
"global" -> (4, globalSignature)))
"global" -> (4, globalSignature),
"g" -> (5, TGenotype)))

val symTab = Map(
val ec = EvalContext(Map(
"v" -> (0, TVariant),
"va" -> (1, vaSignature),
"s" -> (2, TSample),
"sa" -> (3, saSignature),
"global" -> (4, globalSignature),
"gs" -> (-1, BaseAggregable(aggregationEC, TGenotype)))
"gs" -> (-1, BaseAggregable(aggregationEC, TGenotype))))

val ec = EvalContext(symTab)
val ktEC = EvalContext(
aggregationEC.st.map { case (name, (i, t)) => name -> (-1, KeyTableAggregable(aggregationEC, t.asInstanceOf[Type], i)) }
)

ec.set(4, globalAnnotation)
aggregationEC.set(4, globalAnnotation)
Expand All @@ -628,26 +631,25 @@ class VariantSampleMatrix[T](val metadata: VariantMetadata,

val (aggNameParseTypes, aggF) =
if (aggCond != null)
Parser.parseAnnotationArgs(aggCond, ec, None)
Parser.parseAnnotationArgs(aggCond, ktEC, None)
else
(Array.empty[(List[String], Type)], Array.empty[() => Any])

val keyNames = keyNameParseTypes.map(_._1.head)
val aggNames = aggNameParseTypes.map(_._1.head)

val keySignature = TStruct(keyNameParseTypes.map{ case (n, t) => (n.head, t) }: _*)
val valueSignature = TStruct(aggNameParseTypes.map{ case (n, t) => (n.head, t) }: _*)
val keySignature = TStruct(keyNameParseTypes.map { case (n, t) => (n.head, t) }: _*)
val valueSignature = TStruct(aggNameParseTypes.map { case (n, t) => (n.head, t) }: _*)

val (zVals, _, combOp, resultOp) = Aggregators.makeFunctions(aggregationEC)
val aggFunctions = aggregationEC.aggregationFunctions.map(_._1)

val localGlobalAnnotation = globalAnnotation

val seqOp = (array: Array[Aggregator], b: (Any, Any, Any, Any, Any)) => {
val (v, va, s, sa, aggT) = b
ec.set(0, v)
ec.set(1, va)
ec.set(2, s)
ec.set(3, sa)
val seqOp = (array: Array[Aggregator], r: Annotation) => {
KeyTable.setEvalContext(aggregationEC, r, 6)
for (i <- array.indices) {
array(i).seqOp(aggT)
array(i).seqOp(aggFunctions(i)(r))
}
array
}
Expand All @@ -656,7 +658,7 @@ class VariantSampleMatrix[T](val metadata: VariantMetadata,
it.map { case (v, va, s, sa, g) =>
ec.setAll(v, va, s, sa, g)
val key = Annotation.fromSeq(keyF.map(_ ()))
(key, (v, va, s, sa, g))
(key, Annotation(v, va, s, sa, localGlobalAnnotation, g))
}
}.aggregateByKey(zVals)(seqOp, combOp)
.map { case (k, agg) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class AggregateByKeySuite extends SparkSuite {
var s = State(sc, sqlContext)
s = ImportVCF.run(s, Array(inputVCF))
s = AnnotateSamplesExpr.run(s, Array("-c", "sa.nHet = gs.filter(g => g.isHet).count()"))
val kt = s.vds.aggregateByKey("Sample = s", "nHet = gs.filter(g => g.isHet).count()")
val kt = s.vds.aggregateByKey("Sample = s", "nHet = g.map(g => g.isHet.toInt).sum().toLong")

val (_, ktHetQuery) = kt.query("nHet")
val (_, ktSampleQuery) = kt.query("Sample")
Expand All @@ -29,7 +29,7 @@ class AggregateByKeySuite extends SparkSuite {
var s = State(sc, sqlContext)
s = ImportVCF.run(s, Array(inputVCF))
s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()"))
val kt = s.vds.aggregateByKey("Variant = v", "nHet = gs.filter(g => g.isHet).count()")
val kt = s.vds.aggregateByKey("Variant = v", "nHet = g.map(g => g.isHet.toInt).sum().toLong")

val (_, ktHetQuery) = kt.query("nHet")
val (_, ktVariantQuery) = kt.query("Variant")
Expand All @@ -48,7 +48,7 @@ class AggregateByKeySuite extends SparkSuite {
s = ImportVCF.run(s, Array(inputVCF))
s = AnnotateVariantsExpr.run(s, Array("-c", "va.nHet = gs.filter(g => g.isHet).count()"))
s = AnnotateGlobalExpr.run(s, Array("-c", "global.nHet = variants.map(v => va.nHet).sum().toLong"))
val kt = s.vds.aggregateByKey(null, "nHet = gs.filter(g => g.isHet).count()")
val kt = s.vds.aggregateByKey(null, "nHet = g.map(g => g.isHet.toInt).sum().toLong")

val (_, ktHetQuery) = kt.query("nHet")
val (_, globalHetResult) = s.vds.queryGlobal("global.nHet")
Expand Down

0 comments on commit 1142e58

Please sign in to comment.