Skip to content

Commit

Permalink
added Direchlet mixture option to BaldingNicholsModel
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloom22 committed Mar 21, 2018
1 parent 8042983 commit e6ff453
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 29 deletions.
23 changes: 14 additions & 9 deletions python/hail/methods/statgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,7 +1301,7 @@ def hwe_normalized_pca(dataset, k=10, compute_loadings=False, as_array=False):
:math:`i` and :math:`j` of :math:`M`; in terms of :math:`C` it is
.. math::
\frac{1}{m}\sum_{l\in\mathcal{C}_i\cap\mathcal{C}_j}\frac{(C_{il}-2p_l)(C_{jl} - 2p_l)}{2p_l(1-p_l)}
where :math:`\mathcal{C}_i = \{l \mid C_{il} \text{ is non-missing}\}`. In
Expand All @@ -1323,7 +1323,7 @@ def hwe_normalized_pca(dataset, k=10, compute_loadings=False, as_array=False):
Parameters
----------
dataset : :class:`.MatrixTable`
Dataset.
Matrix table with entry-indexed ``GT`` field of type :py:data:`.tcall`.
k : :obj:`int`
Number of principal components.
compute_loadings : :obj:`bool`
Expand Down Expand Up @@ -1379,7 +1379,7 @@ def pca(entry_expr, k=10, compute_loadings=False, as_array=False):
1s encoding missingness of genotype calls.
>>> eigenvalues, scores, _ = hl.pca(hl.int(hl.is_defined(dataset.GT)),
... k=2)
... k=2)
Warning
-------
Expand Down Expand Up @@ -1438,8 +1438,6 @@ def pca(entry_expr, k=10, compute_loadings=False, as_array=False):
Parameters
----------
dataset : :class:`.MatrixTable`
Dataset.
entry_expr : :class:`.Expression`
Numeric expression for matrix entries.
k : :obj:`int`
Expand Down Expand Up @@ -2258,10 +2256,11 @@ def realized_relationship_matrix(call_expr):
fst=nullable(listof(numeric)),
af_dist=oneof(UniformDist, BetaDist, TruncatedBetaDist),
seed=int,
reference_genome=reference_genome_type)
reference_genome=reference_genome_type,
mixture=bool)
def balding_nichols_model(n_populations, n_samples, n_variants, n_partitions=None,
pop_dist=None, fst=None, af_dist=UniformDist(0.1, 0.9),
seed=0, reference_genome='default'):
seed=0, reference_genome='default', mixture=False):
r"""Generate a matrix table of variants, samples, and genotypes using the
Balding-Nichols model.
Expand Down Expand Up @@ -2324,7 +2323,7 @@ def balding_nichols_model(n_populations, n_samples, n_variants, n_partitions=Non
population allele frequencies by :math:`p_{k, m}`, and diploid, unphased
genotype calls by :math:`g_{n, m}` (0, 1, and 2 correspond to homozygous
reference, heterozygous, and homozygous variant, respectively).
The generative model is then given by:
.. math::
Expand Down Expand Up @@ -2354,6 +2353,7 @@ def balding_nichols_model(n_populations, n_samples, n_variants, n_partitions=Non
- `ancestral_af_dist` (:class:`.tstruct`) -- Description of the ancestral allele
frequency distribution.
- `seed` (:py:data:`.tint32`) -- Random seed.
- `mixture` (:py:data:`.tbool`) -- Value of `mixture` parameter.
Row fields:
Expand Down Expand Up @@ -2397,6 +2397,10 @@ def balding_nichols_model(n_populations, n_samples, n_variants, n_partitions=Non
Random seed.
reference_genome : :obj:`str` or :class:`.ReferenceGenome`
Reference genome to use.
mixture : :obj:`bool`
Treat `pop_dist` as the parameters of a Dirichlet distribution,
as in the Prichard-Stevens-Donnelly model. This feature is
EXPERIMENTAL and currently undocumented and untested.
Returns
-------
Expand All @@ -2420,7 +2424,8 @@ def balding_nichols_model(n_populations, n_samples, n_variants, n_partitions=Non
jvm_fst_opt,
af_dist._jrep(),
seed,
reference_genome._jrep)
reference_genome._jrep,
mixture)
return MatrixTable(jmt)


Expand Down
5 changes: 3 additions & 2 deletions src/main/scala/is/hail/HailContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -603,8 +603,9 @@ class HailContext private(val sc: SparkContext,
fst: Option[Array[Double]] = None,
afDist: Distribution = UniformDist(0.1, 0.9),
seed: Int = 0,
rg: ReferenceGenome = ReferenceGenome.defaultReference): MatrixTable =
BaldingNicholsModel(this, populations, samples, variants, popDist, fst, seed, nPartitions, afDist, rg)
rg: ReferenceGenome = ReferenceGenome.defaultReference,
mixture: Boolean = false): MatrixTable =
BaldingNicholsModel(this, populations, samples, variants, popDist, fst, seed, nPartitions, afDist, rg, mixture)

def genDataset(): MatrixTable = VSMSubgen.realistic.gen(this).sample()

Expand Down
59 changes: 42 additions & 17 deletions src/main/scala/is/hail/stats/BaldingNicholsModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,22 @@ import is.hail.annotations._
import is.hail.expr.types._
import is.hail.rvd.OrderedRVD
import is.hail.utils._
import is.hail.variant.{Call, Call2, ReferenceGenome, MatrixTable}
import is.hail.variant.{Call2, ReferenceGenome, MatrixTable}
import org.apache.commons.math3.random.JDKRandomGenerator

object BaldingNicholsModel {

def apply(hc: HailContext, nPops: Int, nSamples: Int, nVariants: Int,
popDistArrayOpt: Option[Array[Double]], FstOfPopArrayOpt: Option[Array[Double]],
seed: Int, nPartitionsOpt: Option[Int], af_dist: Distribution,
rg: ReferenceGenome = ReferenceGenome.defaultReference): MatrixTable = {
def apply(hc: HailContext,
nPops: Int,
nSamples: Int,
nVariants: Int,
popDistArrayOpt: Option[Array[Double]],
FstOfPopArrayOpt: Option[Array[Double]],
seed: Int,
nPartitionsOpt: Option[Int],
af_dist: Distribution,
rg: ReferenceGenome = ReferenceGenome.defaultReference,
mixture: Boolean = false): MatrixTable = {

val sc = hc.sc

Expand Down Expand Up @@ -69,17 +76,24 @@ object BaldingNicholsModel {
Rand.generator.setSeed(seed)

val popDist_k = popDist
popDist_k :/= sum(popDist_k)

val popDistRV = Multinomial(popDist_k)
val popOfSample_n: DenseVector[Int] = DenseVector.fill[Int](N)(popDistRV.draw())
val popOfSample_n = DenseMatrix.zeros[Double](if (mixture) K else 1, N)

if (mixture) {
val popDistRV = Dirichlet(popDist_k)
(0 until N).foreach(j => popOfSample_n(::, j) := popDistRV.draw())
} else {
popDist_k :/= sum(popDist_k)
val popDistRV = Multinomial(popDist_k)
(0 until N).foreach(j => popOfSample_n(0, j) = popDistRV.draw())
}

val popOfSample_nBc = sc.broadcast(popOfSample_n)

val Fst_k = FstOfPop
val Fst1_k = (1d - Fst_k) /:/ Fst_k
val Fst1_kBc = sc.broadcast(Fst1_k)

val saSignature = TStruct("sample_idx" -> TInt32(), "pop" -> TInt32())
val saSignature = TStruct("sample_idx" -> TInt32(), "pop" -> (if (mixture) TArray(TFloat64()) else TInt32()))
val vaSignature = TStruct("ancestralAF" -> TFloat64(), "AF" -> TArray(TFloat64()))

val ancestralAFAnnotation = af_dist match {
Expand All @@ -88,7 +102,7 @@ object BaldingNicholsModel {
case TruncatedBetaDist(a, b, min, max) => Annotation("TruncatedBetaDist", a, b, min, max)
}
val globalAnnotation =
Annotation(K, N, M, popDistArray: IndexedSeq[Double], FstOfPopArray: IndexedSeq[Double], ancestralAFAnnotation, seed)
Annotation(K, N, M, popDistArray: IndexedSeq[Double], FstOfPopArray: IndexedSeq[Double], ancestralAFAnnotation, seed, mixture)

val ancestralAFAnnotationSignature = af_dist match {
case UniformDist(min, max) => TStruct("type" -> TString(), "min" -> TFloat64(), "max" -> TFloat64())
Expand All @@ -103,7 +117,8 @@ object BaldingNicholsModel {
"pop_dist" -> TArray(TFloat64()),
"fst" -> TArray(TFloat64()),
"ancestral_af_dist" -> ancestralAFAnnotationSignature,
"seed" -> TInt32())
"seed" -> TInt32(),
"mixture" -> TBoolean())

val matrixType: MatrixType = MatrixType.fromParts(
globalType = globalSignature,
Expand All @@ -130,9 +145,11 @@ object BaldingNicholsModel {

val ancestralAF = af_dist.getBreezeDist(perVariantRandomBasis).draw()

val popAF_k: IndexedSeq[Double] = Array.tabulate(K) { k =>
new Beta(ancestralAF * Fst1_kBc.value(k), (1 - ancestralAF) * Fst1_kBc.value(k))(perVariantRandomBasis).draw()
}
val popAF_k: DenseVector[Double] = DenseVector(
Array.tabulate(K) { k =>
new Beta(ancestralAF * Fst1_kBc.value(k), (1 - ancestralAF) * Fst1_kBc.value(k))(perVariantRandomBasis)
.draw()
})

region.clear()
rvb.start(rvType)
Expand Down Expand Up @@ -165,7 +182,11 @@ object BaldingNicholsModel {
i = 0
val unif = new Uniform(0, 1)(perVariantRandomBasis)
while (i < N) {
val p = popAF_k(popOfSample_nBc.value(i))
val p =
if (mixture)
popOfSample_nBc.value(::, i) dot popAF_k
else
popAF_k(popOfSample_nBc.value(0, i).toInt)
val pSq = p * p
val x = unif.draw()
val c =
Expand All @@ -188,7 +209,11 @@ object BaldingNicholsModel {
}
}

val sampleAnnotations = (0 until N).map { i => Annotation(i, popOfSample_n(i)) }.toArray
val sampleAnnotations: Array[Annotation] =
if (mixture)
Array.tabulate(N)(i => Annotation(i, popOfSample_n(::, i).data.toIndexedSeq))
else
Array.tabulate(N)(i => Annotation(i, popOfSample_n(0, i).toInt))

// FIXME: should use fast keys
val ordrdd = OrderedRVD(matrixType.orvdType, rdd, None, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package is.hail.stats

import breeze.stats._
import is.hail.SparkSuite
import is.hail.variant.{Call, Locus, Variant}
import is.hail.variant.{Call, Variant}
import is.hail.testUtils._
import org.apache.spark.sql.Row
import org.testng.Assert.assertEquals
Expand Down

0 comments on commit e6ff453

Please sign in to comment.