Skip to content

Commit

Permalink
[lowering] Fix MatrixBlockMatrixWriter lowering when some partitions are empty (#12797)
Browse files · Browse the repository at this point in the history

Using changePartitionerNoRepartition with a partitioner with a different
number of partitions cannot be correct and will result in dropped data.
Just construct the new partitioner (based on row index, so it will
comply with invariants) directly.
  • Loading branch information
chrisvittal authored Mar 21, 2023
1 parent 2eab792 commit 4d5f170
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
8 changes: 8 additions & 0 deletions hail/python/test/hail/linalg/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,14 @@ def test_from_entry_expr_simple(self):
a4 = hl.eval(BlockMatrix.read(path).to_ndarray())
self._assert_eq(a1, a4)

def test_from_entry_expr_empty_parts(self):
    # Regression test for hail PR #12797: constructing a BlockMatrix from a
    # matrix table in which some partitions are empty must not drop rows.
    with hl.TemporaryDirectory(ensure_exists=False) as path:
        # 2000 variants spread over 200 partitions; the row filter keeps only
        # positions <= 500 or > 1500, which (presumably) leaves many of the
        # middle partitions empty — the condition that triggered the bug.
        mt = hl.balding_nichols_model(n_populations=5, n_variants=2000, n_samples=20, n_partitions=200)
        mt = mt.filter_rows((mt.locus.position <= 500) | (mt.locus.position > 1500)).checkpoint(path)
        bm = BlockMatrix.from_entry_expr(mt.GT.n_alt_alleles())
        # 1000 rows survive the filter, so the gram matrix must be 1000 x 1000;
        # before the fix, data from non-empty partitions could be dropped.
        nd = (bm @ bm.T).to_numpy()
        assert nd.shape == (1000, 1000)

def test_from_entry_expr_options(self):
def build_mt(a):
data = [{'v': 0, 's': 0, 'x': a[0]},
Expand Down
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/MatrixWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1490,8 +1490,8 @@ case class MatrixBlockMatrixWriter(
val inputRowIntervals = inputPartStarts.zip(inputPartStops).map{ case (intervalStart, intervalEnd) =>
Interval(Row(intervalStart.toInt), Row(intervalEnd.toInt), true, false)
}
val rowIdxPartitioner = RVDPartitioner.generate(ctx.stateManager, TStruct((perRowIdxId, TInt32)), inputRowIntervals)

val rowIdxPartitioner = new RVDPartitioner(ctx.stateManager, TStruct((perRowIdxId, TInt32)), inputRowIntervals)
val keyedByRowIdx = partsZippedWithIdx.changePartitionerNoRepartition(rowIdxPartitioner)

// Now create a partitioner that makes appropriately sized blocks
Expand All @@ -1505,12 +1505,12 @@ case class MatrixBlockMatrixWriter(
val rowsInBlockSizeGroups: TableStage = keyedByRowIdx.repartitionNoShuffle(blockSizeGroupsPartitioner)

def createBlockMakingContexts(tablePartsStreamIR: IR): IR = {
flatten(zip2(tablePartsStreamIR, rangeIR(numBlockRows), ArrayZipBehavior.AssertSameLength) { case (tableSinglePartCtx, blockColIdx) =>
flatten(zip2(tablePartsStreamIR, rangeIR(numBlockRows), ArrayZipBehavior.AssertSameLength) { case (tableSinglePartCtx, blockRowIdx) =>
mapIR(rangeIR(I32(numBlockCols))){ blockColIdx =>
MakeStruct(FastIndexedSeq("oldTableCtx" -> tableSinglePartCtx, "blockStart" -> (blockColIdx * I32(blockSize)),
"blockSize" -> If(blockColIdx ceq I32(numBlockCols - 1), I32(lastBlockNumCols), I32(blockSize)),
"blockColIdx" -> blockColIdx,
"blockRowIdx" -> blockColIdx))
"blockRowIdx" -> blockRowIdx))
}
})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,10 @@ class TableStage(

def getNumPartitions(): IR = TableStage.wrapInBindings(StreamLen(contexts), letBindings)

def changePartitionerNoRepartition(newPartitioner: RVDPartitioner): TableStage =
def changePartitionerNoRepartition(newPartitioner: RVDPartitioner): TableStage = {
require(partitioner.numPartitions == newPartitioner.numPartitions)
copy(partitioner = newPartitioner)
}

def strictify(allowedOverlap: Int = kType.size - 1): TableStage = {
val newPart = partitioner.strictify(allowedOverlap)
Expand Down

0 comments on commit 4d5f170

Please sign in to comment.