Skip to content

Commit

Permalink
[vds] Fix local_to_global to support non-ascending local alleles
Browse files Browse the repository at this point in the history
  • Loading branch information
tpoterba committed Dec 7, 2022
1 parent 1940d9e commit b51db68
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 41 deletions.
3 changes: 2 additions & 1 deletion hail/python/hail/ir/register_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ def ndarray_floating_point_divide(arg_type, ret_type):
register_function("UnphasedDiploidGtIndexCall", (dtype("int32"),), dtype("call"))
register_function("lgt_to_gt", (dtype("call"), dtype("array<int32>"),), dtype("call"))
register_function("allele_to_genotype_reindex", (dtype("array<int32>"),), dtype("array<int32>"))
register_function("local_to_global", (dtype("array<?T>"), dtype("array<int32>"), dtype("int32"), dtype("?T")), dtype("array<?T>"))
register_function("local_to_global_g", (dtype("array<?T>"), dtype("array<int32>"), dtype("int32"), dtype("?T")), dtype("array<?T>"))
register_function("local_to_global_a_r", (dtype("array<?T>"), dtype("array<int32>"), dtype("int32"), dtype("?T"), dtype("bool")), dtype("array<?T>"))
register_function("index", (dtype("call"), dtype("int32"),), dtype("int32"))
register_function("sign", (dtype("int64"),), dtype("int64"))
register_function("sign", (dtype("float64"),), dtype("float64"))
Expand Down
14 changes: 6 additions & 8 deletions hail/python/hail/vds/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,19 @@ def local_to_global(array, local_alleles, n_alleles, fill_value, number):
-------
:class:`.ArrayExpression`
"""
try:
fill_value = hl.coercer_from_dtype(array.dtype.element_type).coerce(fill_value)
except Exception as e:
raise ValueError(f'fill_value type {fill_value.dtype} is incompatible with array type {array.dtype}') from e

if number == 'G':
local_alleles = _func("allele_to_genotype_reindex", hl.tarray(hl.tint32), local_alleles)
n_alleles = hl.triangle(n_alleles)
omit_first = False
return _func("local_to_global_g", array.dtype, array, local_alleles, n_alleles, fill_value)
elif number == 'R':
omit_first = False
elif number == 'A':
omit_first = True
else:
raise ValueError(f'unrecognized number {number}')

try:
fill_value = hl.coercer_from_dtype(array.dtype.element_type).coerce(fill_value)
except Exception as e:
raise ValueError(f'fill_value type {fill_value.dtype} is incompatible with array type {array.dtype}') from e

return _func("local_to_global", array.dtype, array, local_alleles, n_alleles, fill_value, hl.bool(omit_first))
return _func("local_to_global_a_r", array.dtype, array, local_alleles, n_alleles, fill_value, hl.bool(omit_first))
10 changes: 10 additions & 0 deletions hail/python/test/hail/vds/test_vds_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,13 @@ def test_local_to_global():
assert hl.eval(hl.vds.local_to_global(lpl, local_alleles, 4, 999, number='G')) == [1001, 1002, 1003, 999, 999, 999, 1004, 0, 999, 1005]
assert hl.eval(hl.vds.local_to_global(lad, [0,1,2], 3, 0, number='R')) == lad
assert hl.eval(hl.vds.local_to_global(lpl, [0,1,2], 3, 999, number='G')) == lpl

def test_local_to_global_alleles_non_increasing():
local_alleles = [0, 3, 1]
lad = [1, 10, 9]
lpl = [1001, 1004, 0, 1002, 1003, 1005]

assert hl.eval(hl.vds.local_to_global(lad, local_alleles, 4, 0, number='R')) == [1, 9, 0, 10]
assert hl.eval(hl.vds.local_to_global(lpl, local_alleles, 4, 999, number='G')) == [1001, 1002, 1005, 999, 999, 999, 1004, 1003, 999, 0]

assert hl.eval(hl.vds.local_to_global([0, 1, 2, 3, 4, 5], [0, 2, 1], 3, 0, number='G')) == [0, 3, 5, 1, 4, 2]
107 changes: 75 additions & 32 deletions hail/src/main/scala/is/hail/expr/ir/functions/ArrayFunctions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,64 @@ object ArrayFunctions extends RegistryFunctions {
finish(cb)
}

registerIEmitCode5("local_to_global", TArray(TVariable("T")), TArray(TInt32), TInt32, TVariable("T"), TBoolean, TArray(TVariable("T")),
registerIEmitCode4("local_to_global_g", TArray(TVariable("T")), TArray(TInt32), TInt32, TVariable("T"), TArray(TVariable("T")),
{ case (rt, inArrayET, la, n, _) => EmitType(PCanonicalArray(PType.canonical(inArrayET.st.asInstanceOf[SContainer].elementType.storageType())).sType, inArrayET.required && la.required && n.required) })(
{ case (cb, region, rt: SIndexablePointer, err, array, localAlleles, nTotalAlleles, fillInValue) =>

IEmitCode.multiMapEmitCodes(cb, FastIndexedSeq(array, localAlleles, nTotalAlleles)) {
case IndexedSeq(array: SIndexableValue, localAlleles: SIndexableValue, _nTotalAlleles: SInt32Value) =>
def triangle(x: Value[Int]): Code[Int] = (x * (x + 1)) / 2
val nTotalAlleles =_nTotalAlleles.value
val nGenotypes = cb.memoize(triangle(nTotalAlleles))
val pt = rt.pType.asInstanceOf[PCanonicalArray]
cb.ifx(nTotalAlleles < 0, cb._fatalWithError(err, "local_to_global: n_total_alleles less than 0: ", nGenotypes.toS))
val localLen = array.loadLength()
val laLen = localAlleles.loadLength()
cb.ifx(localLen cne triangle(laLen), cb._fatalWithError(err, "local_to_global: array should be the triangle number of local alleles: found: ", localLen.toS, " elements, and", laLen.toS, " alleles"))

val fillIn = cb.memoize(fillInValue)

val (push, finish) = pt.constructFromIndicesUnsafe(cb, region, nGenotypes, false)

// fill in if necessary
cb.ifx(localLen cne nGenotypes, {
val i = cb.newLocal[Int]("i", 0)
cb.whileLoop(i < nGenotypes, {
push(cb, i, fillIn.toI(cb))
cb.assign(i, i + 1)
})
})


val i = cb.newLocal[Int]("la_i", 0)
val laGIndexer = cb.newLocal[Int]("g_indexer", 0)
cb.whileLoop(i < laLen, {
val lai = localAlleles.loadElement(cb, i).get(cb, "local_to_global: local alleles elements cannot be missing", err).asInt32.value

val j = cb.newLocal[Int]("la_j", 0)
cb.whileLoop(j <= i, {
val laj = localAlleles.loadElement(cb, j).get(cb, "local_to_global: local alleles elements cannot be missing", err).asInt32.value

val dest = cb.newLocal[Int]("dest")
cb.ifx(lai >= laj, {
cb.assign(dest, triangle(lai) + laj)
}, {
cb.assign(dest, triangle(laj) + lai)
})

push(cb, dest, array.loadElement(cb, laGIndexer))
cb.assign(laGIndexer, laGIndexer + 1)
cb.assign(j, j+1)
})

cb.assign(i, i+1)
})

finish(cb)
}
})

registerIEmitCode5("local_to_global_a_r", TArray(TVariable("T")), TArray(TInt32), TInt32, TVariable("T"), TBoolean, TArray(TVariable("T")),
{case (rt, inArrayET, la, n, _, omitFirst) => EmitType(PCanonicalArray(PType.canonical(inArrayET.st.asInstanceOf[SContainer].elementType.storageType())).sType, inArrayET.required && la.required && n.required && omitFirst.required)})(
{ case (cb, region, rt: SIndexablePointer, err, array, localAlleles, nTotalAlleles, fillInValue, omitFirstElement) =>

Expand All @@ -363,40 +420,26 @@ object ArrayFunctions extends RegistryFunctions {
cb.assign(idxAdjustmentForOmitFirst, 0))

val globalLen = cb.memoize(nTotalAlleles - idxAdjustmentForOmitFirst)
val (push, finish) = pt.constructFromFunctions(cb, region, globalLen, false)

val currIdxGlobal = cb.newLocal[Int]("idxGlobal", 0)
val currIdxLocal = cb.newLocal[Int]("idxLocal", idxAdjustmentForOmitFirst)
val nextArrayValue = cb.emb.newEmitLocal(array.st.elementEmitType)
val nextToSet = cb.newLocal[Int]("nextToSet", -1)

val LreadNextLocalIndex = CodeLabel()
val LloopStart = CodeLabel()
val Lend = CodeLabel()

cb.define(LreadNextLocalIndex)
cb.ifx(currIdxLocal < localLen, {
val nextLA = localAlleles.loadElement(cb, currIdxLocal).get(cb, "local alleles elements cannot be missing", err).asInt32.value
cb.ifx(nextLA <= nextToSet, cb._fatalWithError(err,"local_to_global: local alleles not strictly increasing: ", cb.strValue(localAlleles)))
cb.assign(nextToSet, nextLA)
cb.assign(nextArrayValue, array.loadElement(cb, currIdxLocal))
}, {
cb.assign(nextToSet, globalLen)

val (push, finish) = pt.constructFromIndicesUnsafe(cb, region, globalLen, false)

// fill in if necessary
cb.ifx(localLen cne globalLen, {
val i = cb.newLocal[Int]("i", 0)
cb.whileLoop(i < globalLen, {
push(cb, i, fillIn.toI(cb))
cb.assign(i, i + 1)
})
})
cb.assign(currIdxLocal, currIdxLocal + 1)

cb.define(LloopStart)
cb.ifx(currIdxGlobal >= globalLen, cb.goto(Lend))
cb.ifx(currIdxGlobal ceq nextToSet, {
push(cb, nextArrayValue.toI(cb))
cb.assign(currIdxGlobal, currIdxGlobal + 1)
cb.goto(LreadNextLocalIndex)

val i = cb.newLocal[Int]("la_i", 0)
cb.whileLoop(i < localLen, {
val lai = localAlleles.loadElement(cb, i + idxAdjustmentForOmitFirst).get(cb, "local_to_global: local alleles elements cannot be missing", err).asInt32.value
push(cb, cb.memoize(lai - idxAdjustmentForOmitFirst), array.loadElement(cb, i))

cb.assign(i, i + 1)
})
push(cb, fillIn.toI(cb))
cb.assign(currIdxGlobal, currIdxGlobal + 1)
cb.goto(LloopStart)

cb.define(Lend)
finish(cb)
}
})
Expand Down
22 changes: 22 additions & 0 deletions hail/src/main/scala/is/hail/types/physical/PCanonicalArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,28 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false)
(push, finish)
}

def constructFromIndicesUnsafe(cb: EmitCodeBuilder, region: Value[Region], length: Value[Int], deepCopy: Boolean):
(((EmitCodeBuilder, Value[Int], IEmitCode) => Unit, (EmitCodeBuilder => SIndexablePointerValue))) = {

val addr = cb.newLocal[Long]("pcarray_construct2_addr", allocate(region, length))
stagedInitialize(cb, addr, length, setMissing = false)
val firstElementAddress = cb.newLocal[Long]("pcarray_construct2_first_addr", firstElementOffset(addr, length))

val push: (EmitCodeBuilder, Value[Int], IEmitCode) => Unit = {
case (cb, idx, iec) =>
iec.consume(cb,
setElementMissing(cb, addr, idx),
{ sc =>
elementType.storeAtAddress(cb, firstElementAddress + idx.toL * elementByteSize, region, sc, deepCopy = deepCopy)
})
}
val finish: EmitCodeBuilder => SIndexablePointerValue = { (cb: EmitCodeBuilder) =>
new SIndexablePointerValue(sType, addr, length, firstElementAddress)
}
(push, finish)
}


def loadFromNested(addr: Code[Long]): Code[Long] = Region.loadAddress(addr)

override def unstagedLoadFromNested(addr: Long): Long = Region.loadAddress(addr)
Expand Down

0 comments on commit b51db68

Please sign in to comment.