[query] add hl.utils.genomic_range_table #12679

Merged · 6 commits · Feb 16, 2023
10 changes: 9 additions & 1 deletion hail/python/hail/expr/types.py
@@ -16,7 +16,7 @@
from .nat import NatBase, NatLiteral
from .type_parsing import type_grammar, type_node_visitor
from .. import genetics
from ..typecheck import typecheck, typecheck_method, oneof, transformed
from ..typecheck import typecheck, typecheck_method, oneof, transformed, nullable
from ..utils.struct import Struct
from ..utils.byte_reader import ByteReader
from ..utils.misc import lookup_bit
@@ -1780,6 +1780,14 @@ class tlocus(HailType):

struct_repr = tstruct(contig=_tstr(), pos=_tint32())

@classmethod
@typecheck_method(reference_genome=nullable(reference_genome_type))
def _schema_from_rg(cls, reference_genome='default'):
# must match TLocus.schemaFromRG
if reference_genome is None:
return hl.tstruct(contig=hl.tstr, position=hl.tint32)
return cls(reference_genome)

@typecheck_method(reference_genome=reference_genome_type)
def __init__(self, reference_genome='default'):
self._rg = reference_genome
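For context, `_schema_from_rg` gives the row type that `genomic_range_table` will use: a true locus type when a reference genome is supplied, a plain struct otherwise. A minimal sketch of the two shapes (assumes a working Hail installation; not part of this diff):

    import hail as hl
    from hail.expr.types import tlocus

    # With a reference genome: a true locus type.
    assert tlocus._schema_from_rg('GRCh38') == hl.tlocus('GRCh38')
    # Without one: a bare struct carrying the same fields.
    assert tlocus._schema_from_rg(None) == hl.tstruct(contig=hl.tstr, position=hl.tint32)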
41 changes: 41 additions & 0 deletions hail/python/hail/ir/table_ir.py
@@ -5,6 +5,7 @@
import hail.ir.ir as ir
from hail.ir.utils import modify_deep_field, zip_with_index, default_row_uid, default_col_uid
from hail.ir.ir import unify_uid_types, pad_uid, concat_uids
from hail.genetics import ReferenceGenome
from hail.utils import FatalError
from hail.utils.java import Env
from hail.utils.misc import escape_str, parsable_strings, escape_id
@@ -197,6 +198,46 @@ def _compute_type(self, deep_typecheck):
['idx'])


class TableGenomicRange(TableIR):
def __init__(self, n: int, n_partitions: Optional[int], reference_genome: Optional[ReferenceGenome]):
super().__init__()
self.n = n
self.n_partitions = n_partitions
self.reference_genome = reference_genome

def _handle_randomness(self, uid_field_name):
assert uid_field_name is not None
if self.reference_genome is not None:
global_position = ir.Apply(
'locusToGlobalPos',
tint64,
ir.GetField(ir.Ref('row', self.typ.row_type), 'locus'))
else:
global_position = ir.Cast(
ir.GetField(ir.GetField(ir.Ref('row', self.typ.row_type), 'locus'), 'position'),
tint64)

new_row = ir.InsertFields(
ir.Ref('row', self.typ.row_type),
[(uid_field_name, global_position)],
None)
return TableMapRows(self, new_row)

def head_str(self):
reference_genome = self.reference_genome.name if self.reference_genome else None
return f'{self.n} {self.n_partitions} {reference_genome}'

def _eq(self, other):
return self.n == other.n and \
self.n_partitions == other.n_partitions and \
self.reference_genome == other.reference_genome

def _compute_type(self, deep_typecheck):
return hl.ttable(hl.tstruct(),
hl.tstruct(locus=hl.tlocus._schema_from_rg(self.reference_genome)),
['locus'])


class TableMapGlobals(TableIR):
def __init__(self, child, new_globals):
super().__init__(child, new_globals)
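Note that `_handle_randomness` reuses the locus's global position as the row UID, so no extra index column is materialized. A standalone sketch of the global-position idea (the contig lengths are illustrative and `global_pos` is a hypothetical helper, not the Hail API):

    # A locus's global position is its 0-based offset into the genome laid end to end.
    contig_lengths = {'chr1': 248956422, 'chr2': 242193529}  # illustrative GRCh38 values

    def global_pos(contig: str, position: int) -> int:
        offset = 0
        for name, length in contig_lengths.items():
            if name == contig:
                return offset + position - 1  # positions are 1-based
            offset += length
        raise ValueError(f'unknown contig {contig!r}')

    assert global_pos('chr1', 1) == 0
    assert global_pos('chr2', 1) == 248956422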
20 changes: 13 additions & 7 deletions hail/python/hail/methods/statgen.py
@@ -3076,7 +3076,7 @@ def balding_nichols_model(n_populations: int,
n_variants: int,
n_partitions: Optional[int] = None,
pop_dist: Optional[List[int]] = None,
fst: Optional[List[int]] = None,
fst: Optional[List[Union[float, int]]] = None,
af_dist: Optional[hl.Expression] = None,
reference_genome: str = 'default',
mixture: bool = False,
@@ -3316,7 +3316,10 @@ def balding_nichols_model(n_populations: int,

# generate matrix table

bn = hl.utils.range_matrix_table(n_variants, n_samples, n_partitions)
bn = hl.utils.genomic_range_table(n_variants, n_partitions, reference_genome=reference_genome)
bn = bn.annotate(alleles=['A', 'C'])
bn = bn._key_by_assert_sorted('locus', 'alleles')

bn = bn.annotate_globals(
bn=hl.struct(n_populations=n_populations,
n_samples=n_samples,
@@ -3327,12 +3330,15 @@
mixture=mixture))
# col info
pop_f = hl.rand_dirichlet if mixture else hl.rand_cat
bn = bn.key_cols_by(sample_idx=bn.col_idx)
bn = bn.select_cols(pop=pop_f(pop_dist))

bn = bn.annotate_globals(cols=hl.range(n_samples).map(
lambda idx: hl.struct(
sample_idx=idx,
pop=pop_f(pop_dist)
)
))
bn = bn.annotate(entries=hl.range(n_samples).map(lambda _: hl.struct()))
bn = bn._unlocalize_entries('entries', 'cols', ['sample_idx'])
# row info
bn = bn.key_rows_by(locus=hl.locus_from_global_position(bn.row_idx, reference_genome=reference_genome),
alleles=['A', 'C'])
bn = bn.select_rows(ancestral_af=af_dist,
af=hl.bind(lambda ancestral:
hl.array([(1 - x) / x for x in fst])
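The net effect of this rewrite is that `balding_nichols_model` rows are keyed by real loci from the start rather than re-keyed afterwards. A hedged usage sketch (the output comment is illustrative):

    import hail as hl

    bn = hl.balding_nichols_model(n_populations=3, n_samples=4, n_variants=10,
                                  reference_genome='GRCh37')
    bn.rows().select().show(3)  # first loci: 1:1, 1:2, 1:3, each with alleles ['A', 'C']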
2 changes: 2 additions & 0 deletions hail/python/hail/utils/__init__.py
@@ -12,6 +12,7 @@
from .tutorial import get_1kg, get_hgdp, get_movie_lens
from .deduplicate import deduplicate
from .jsonx import JSONEncoder
from .genomic_range_table import genomic_range_table

__all__ = ['hadoop_open',
'hadoop_copy',
@@ -53,4 +54,5 @@
'guess_cloud_spark_provider',
'no_service_backend',
'JSONEncoder',
'genomic_range_table',
]
50 changes: 50 additions & 0 deletions hail/python/hail/utils/genomic_range_table.py
@@ -0,0 +1,50 @@
from typing import Optional

import hail as hl
from .misc import check_nonnegative_and_in_range, check_positive_and_in_range
from ..genetics.reference_genome import reference_genome_type
from ..typecheck import typecheck, nullable


@typecheck(n=int, n_partitions=nullable(int), reference_genome=nullable(reference_genome_type))
def genomic_range_table(n: int,
n_partitions: Optional[int] = None,
reference_genome='default'
) -> 'hl.Table':
"""Construct a table with a locus and no other fields.

Examples
--------

>>> ht = hl.utils.genomic_range_table(100)
>>> ht.count()
100

Notes
-----
The resulting table contains one field:

- `locus` (:py:data:`.tlocus`) - Genomic locus (key).

The loci appear in sequential ascending order.

Parameters
----------
n : int
Number of loci. Must be less than 2 ** 31.
n_partitions : int, optional
Number of partitions (uses Spark default parallelism if None).
reference_genome : :class:`str` or :class:`.ReferenceGenome`
Reference genome to use for creating the loci.

Returns
-------
:class:`.Table`
"""
check_nonnegative_and_in_range('genomic_range_table', 'n', n)
if n_partitions is not None:
    check_positive_and_in_range('genomic_range_table', 'n_partitions', n_partitions)
if n >= (1 << 31):
raise ValueError(f'`n`, {n}, must be less than 2 ** 31.')

return hl.Table(hl.ir.TableGenomicRange(n, n_partitions, reference_genome))
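As a usage note, the row type follows `_schema_from_rg`: passing `reference_genome=None` yields a struct-typed `locus` field. A small sketch (assumes a working Hail installation; output comments are illustrative):

    import hail as hl

    ht = hl.utils.genomic_range_table(5, reference_genome='GRCh38')
    print(ht.locus.dtype)   # locus<GRCh38>
    ht2 = hl.utils.genomic_range_table(5, reference_genome=None)
    print(ht2.locus.dtype)  # struct{contig: str, position: int32}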
12 changes: 12 additions & 0 deletions hail/python/test/hail/expr/test_types.py
@@ -2,6 +2,7 @@

from hail.expr import coercer_from_dtype
from hail.expr.types import *
from hail.genetics import reference_genome
from ..helpers import *
from hail.utils.java import Env

@@ -139,3 +140,14 @@ def test_get_context(self):
for types, rgs in types_and_rgs:
for t in types:
self.assertEqual(t.get_context().references, rgs)

def test_tlocus_schema_from_rg_matches_scala(self):
def locus_from_import_vcf(rg) -> HailType:
    return hl.import_vcf(resource('sample2.vcf'), reference_genome=rg).locus.dtype

assert tlocus._schema_from_rg(None) == locus_from_import_vcf(None)
assert tlocus._schema_from_rg('GRCh37') == locus_from_import_vcf('GRCh37')
assert tlocus._schema_from_rg('GRCh38') == locus_from_import_vcf('GRCh38')
assert tlocus._schema_from_rg('GRCm38') == locus_from_import_vcf('GRCm38')
assert tlocus._schema_from_rg('CanFam3') == locus_from_import_vcf('CanFam3')
assert tlocus._schema_from_rg('default') == locus_from_import_vcf('default')
8 changes: 8 additions & 0 deletions hail/python/test/hail/utils/test_genomic_range_table.py
@@ -0,0 +1,8 @@
import hail as hl


def test_genomic_range_table():
actual = hl.utils.genomic_range_table(10, reference_genome='GRCh38').collect()
expected = [hl.Struct(locus=hl.genetics.Locus("chr1", pos + 1, reference_genome='GRCh38'))
            for pos in range(10)]
assert actual == expected
6 changes: 6 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Parser.scala
@@ -1663,6 +1663,12 @@ object IRParser {
val n = int32_literal(it)
val nPartitions = opt(it, int32_literal)
done(TableRange(n, nPartitions.getOrElse(HailContext.backend.defaultParallelism)))
case "TableGenomicRange" =>
val n = int32_literal(it)
val nPartitions = opt(it, int32_literal)
val optRgStr = opt(it, identifier)
val optRg = optRgStr.map(env.typEnv.getReferenceGenome)
done(TableGenomicRange(n, nPartitions.getOrElse(HailContext.backend.defaultParallelism), optRg))
case "TableUnion" => table_ir_children(env.onlyRelational)(it).map(TableUnion(_))
case "TableOrderBy" =>
val sortFields = sort_fields(it)
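For reference, the parser consumes the textual form produced by the Python `head_str` above: `n`, an optional partition count, and an optional reference genome identifier. A hand-written example of that textual IR (not captured from a real run):

    (TableGenomicRange 100 8 GRCh38)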
2 changes: 2 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Pretty.scala
@@ -346,6 +346,8 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int,
case TableKeyBy(_, keys, isSorted) =>
FastSeq(prettyIdentifiers(keys), Pretty.prettyBooleanLiteral(isSorted))
case TableRange(n, nPartitions) => FastSeq(n.toString, nPartitions.toString)
case TableGenomicRange(n, nPartitions, referenceGenome) => FastSeq(
  n.toString, nPartitions.toString, referenceGenome.map(_.name).getOrElse("None"))
case TableRepartition(_, n, strategy) => FastSeq(n.toString, strategy.toString)
case TableHead(_, n) => single(n.toString)
case TableTail(_, n) => single(n.toString)
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala
@@ -334,6 +334,7 @@ object PruneDeadFields {
case TableParallelize(rowsAndGlobal, _) =>
memoizeValueIR(ctx, rowsAndGlobal, TStruct("rows" -> TArray(requestedType.rowType), "global" -> requestedType.globalType), memo)
case TableRange(_, _) =>
case TableGenomicRange(_, _, _) =>
case TableRepartition(child, _, _) => memoizeTableIR(ctx, child, requestedType, memo)
case TableHead(child, _) => memoizeTableIR(ctx, child, TableType(
key = child.typ.key,
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/expr/ir/Requiredness.scala
@@ -313,6 +313,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {
requiredness.rowType.unionFields(rowReq.r.asInstanceOf[RStruct])
requiredness.globalType.unionFields(globalReq.r.asInstanceOf[RStruct])
case TableRange(_, _) =>
case TableGenomicRange(_, _, _) =>

// pass through TableIR child
case TableKeyBy(child, _, _) => requiredness.unionFrom(lookup(child))
8 changes: 8 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Simplify.scala
@@ -434,6 +434,7 @@ object Simplify {
case TableCount(TableLeftJoinRightDistinct(child, _, _)) => TableCount(child)
case TableCount(TableIntervalJoin(child, _, _, _)) => TableCount(child)
case TableCount(TableRange(n, _)) => I64(n)
case TableCount(TableGenomicRange(n, _, _)) => I64(n)
case TableCount(TableParallelize(rowsAndGlobal, _)) => Cast(ArrayLen(GetField(rowsAndGlobal, "rows")), TInt64)
case TableCount(TableRename(child, _, _)) => TableCount(child)
case TableCount(TableAggregateByKey(child, _)) => TableCount(TableDistinct(child))
@@ -743,6 +744,7 @@
case MatrixColsTable(MatrixKeyRowsBy(child, _, _)) => MatrixColsTable(child)

case TableRepartition(TableRange(nRows, _), nParts, _) => TableRange(nRows, nParts)
case TableRepartition(TableGenomicRange(nRows, _, referenceGenome), nParts, _) => TableGenomicRange(nRows, nParts, referenceGenome)

case TableMapGlobals(TableMapGlobals(child, ng1), ng2) =>
val uid = genUID()
@@ -763,6 +765,12 @@
else
tr

case TableHead(tr@TableGenomicRange(nRows, nPar, referenceGenome), n) if canRepartition =>
if (n < nRows)
TableGenomicRange(n.toInt, (nPar.toFloat * n / nRows).toInt.max(1), referenceGenome)
else
tr

case TableHead(TableMapGlobals(child, newGlobals), n) =>
TableMapGlobals(TableHead(child, n), newGlobals)

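The `TableHead` rule above shrinks the partition count proportionally to the rows kept, flooring at one partition. Worked numbers as a sketch (Python, mirroring the Scala arithmetic):

    n_rows, n_partitions = 1000, 16   # original TableGenomicRange
    n = 250                           # rows kept by TableHead
    new_partitions = max(int(n_partitions * n / n_rows), 1)
    assert new_partitions == 4        # 16 * 250 / 1000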
91 changes: 91 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/TableIR.scala
@@ -1826,6 +1826,97 @@ case class TableRange(n: Int, nPartitions: Int) extends TableIR {
}
}

object TableGenomicRange {
def toLocus(
referenceGenome: Option[ReferenceGenome]
): Int => Annotation = referenceGenome match {
case Some(rg) => (x: Int) => rg.globalPosToLocus(x.toLong)
case None => (idx: Int) =>
  Row("1", idx + 1)
}
}

case class TableGenomicRange(
n: Int,
nPartitions: Int,
referenceGenome: Option[ReferenceGenome]
) extends TableIR {
require(n < Int.MaxValue)
require(n >= 0)
require(nPartitions > 0)
private val nPartitionsAdj = math.max(math.min(n, nPartitions), 1)
val children: IndexedSeq[BaseIR] = Array.empty[BaseIR]

def copy(newChildren: IndexedSeq[BaseIR]): TableGenomicRange = {
assert(newChildren.isEmpty)
TableGenomicRange(n, nPartitions, referenceGenome)
}

private[this] val partCounts = partition(n, nPartitionsAdj)

override val partitionCounts = Some(partCounts.map(_.toLong).toFastIndexedSeq)

lazy val rowCountUpperBound: Option[Long] = Some(n.toLong)

val typ: TableType = TableType(
TStruct("locus" -> TLocus.schemaFromRG(referenceGenome)),
Array("locus"),
TStruct.empty)

protected[ir] override def execute(ctx: ExecuteContext, r: TableRunContext): TableExecuteIntermediate = {
val localLocusType = PCanonicalLocus.schemaFromRG(referenceGenome, true)
val unstagedStoreLocusFromGlobalPos: (Long, Int, Region) => Unit = localLocusType match {
case x: PCanonicalLocus =>
val rgBc = x.rgBc

{ (off: Long, globalPos: Int, region: Region) =>
val l = rgBc.value.globalPosToLocus(globalPos)
x.unstagedStoreJavaObjectAtAddress(off, l, region)
}
case x: PCanonicalStruct =>
{ (off: Long, globalPos: Int, region: Region) =>
Region.storeLong(
x.fieldOffset(off, 0),
x.field("locus").asInstanceOf[PCanonicalString].unstagedStoreJavaObject("1", region))
Region.storeInt(x.fieldOffset(off, 1), globalPos)
}
}
val localRowType = PCanonicalStruct(true, "locus" -> localLocusType)
val localPartCounts = partCounts
val partStarts = partCounts.scanLeft(0)(_ + _)

val partLocusStarts = partStarts.map(TableGenomicRange.toLocus(referenceGenome))
new TableValueIntermediate(TableValue(ctx, typ,
BroadcastRow.empty(ctx),
new RVD(
RVDType(localRowType, Array("locus")),
new RVDPartitioner(
Array("locus"),
typ.rowType,
Array.tabulate(nPartitionsAdj) { i =>
val start = partLocusStarts(i)
val end = partLocusStarts(i + 1)
Interval(Row(start), Row(end), includesStart = true, includesEnd = false)
}
),
ContextRDD.parallelize(Range(0, nPartitionsAdj), nPartitionsAdj)
.cmapPartitionsWithIndex { case (i, ctx, _) =>
val region = ctx.region

val start = partStarts(i)
Iterator.range(start, (start + localPartCounts(i)))
.map { j =>
val off = localRowType.allocate(region)
unstagedStoreLocusFromGlobalPos(
localRowType.fieldOffset(off, 0),
j,
region)
off
}
})))
}
}

case class TableFilter(child: TableIR, pred: IR) extends TableIR {
val children: IndexedSeq[BaseIR] = Array(child, pred)

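The partitioner above turns cumulative row counts into half-open locus intervals via `toLocus`. A sketch of that bookkeeping (Python, mirroring `partCounts`/`partStarts`; assumes no reference genome, so row i lands on contig "1" at 1-based position i + 1):

    from itertools import accumulate

    part_counts = [4, 3, 3]                      # partition(10, 3)
    part_starts = [0, *accumulate(part_counts)]  # [0, 4, 7, 10]
    # Partition i spans [toLocus(part_starts[i]), toLocus(part_starts[i + 1])).
    intervals = [(('1', part_starts[i] + 1), ('1', part_starts[i + 1] + 1))
                 for i in range(len(part_counts))]
    assert intervals[0] == (('1', 1), ('1', 5))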
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/expr/ir/lowering/CanLowerEfficiently.scala
@@ -42,6 +42,7 @@ object CanLowerEfficiently {
case t: TableRepartition => fail(s"TableRepartition has no lowered implementation")
case t: TableParallelize =>
case t: TableRange =>
case t: TableGenomicRange =>
case TableKeyBy(child, keys, isSorted) =>
case t: TableOrderBy =>
case t: TableFilter =>