From 4d5f1704a5cafc3b5d2cde35a307f260e8376be1 Mon Sep 17 00:00:00 2001 From: Christopher Vittal Date: Tue, 21 Mar 2023 14:07:19 -0500 Subject: [PATCH] [lowering] Fix MatrixBlockMatrixWriter lowering when some partitions are empty (#12797) Using changePartitionerNoRepartition with a partitioner with a different number of partitions cannot be correct and will result in dropped data. Just construct the new partitioner (based on row index, so it will comply with invariants) directly. --- hail/python/test/hail/linalg/test_linalg.py | 8 ++++++++ hail/src/main/scala/is/hail/expr/ir/MatrixWriter.scala | 6 +++--- .../scala/is/hail/expr/ir/lowering/LowerTableIR.scala | 4 +++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/hail/python/test/hail/linalg/test_linalg.py b/hail/python/test/hail/linalg/test_linalg.py index 6e41958981c..b1a7758b56b 100644 --- a/hail/python/test/hail/linalg/test_linalg.py +++ b/hail/python/test/hail/linalg/test_linalg.py @@ -169,6 +169,14 @@ def test_from_entry_expr_simple(self): a4 = hl.eval(BlockMatrix.read(path).to_ndarray()) self._assert_eq(a1, a4) + def test_from_entry_expr_empty_parts(self): + with hl.TemporaryDirectory(ensure_exists=False) as path: + mt = hl.balding_nichols_model(n_populations=5, n_variants=2000, n_samples=20, n_partitions=200) + mt = mt.filter_rows((mt.locus.position <= 500) | (mt.locus.position > 1500)).checkpoint(path) + bm = BlockMatrix.from_entry_expr(mt.GT.n_alt_alleles()) + nd = (bm @ bm.T).to_numpy() + assert nd.shape == (1000, 1000) + def test_from_entry_expr_options(self): def build_mt(a): data = [{'v': 0, 's': 0, 'x': a[0]}, diff --git a/hail/src/main/scala/is/hail/expr/ir/MatrixWriter.scala b/hail/src/main/scala/is/hail/expr/ir/MatrixWriter.scala index fde22c118a5..637a5ec5384 100644 --- a/hail/src/main/scala/is/hail/expr/ir/MatrixWriter.scala +++ b/hail/src/main/scala/is/hail/expr/ir/MatrixWriter.scala @@ -1490,8 +1490,8 @@ case class MatrixBlockMatrixWriter( val inputRowIntervals = inputPartStarts.zip(inputPartStops).map{ case (intervalStart, intervalEnd) => Interval(Row(intervalStart.toInt), Row(intervalEnd.toInt), true, false) } - val rowIdxPartitioner = RVDPartitioner.generate(ctx.stateManager, TStruct((perRowIdxId, TInt32)), inputRowIntervals) + val rowIdxPartitioner = new RVDPartitioner(ctx.stateManager, TStruct((perRowIdxId, TInt32)), inputRowIntervals) val keyedByRowIdx = partsZippedWithIdx.changePartitionerNoRepartition(rowIdxPartitioner) // Now create a partitioner that makes appropriately sized blocks @@ -1505,12 +1505,12 @@ case class MatrixBlockMatrixWriter( val rowsInBlockSizeGroups: TableStage = keyedByRowIdx.repartitionNoShuffle(blockSizeGroupsPartitioner) def createBlockMakingContexts(tablePartsStreamIR: IR): IR = { - flatten(zip2(tablePartsStreamIR, rangeIR(numBlockRows), ArrayZipBehavior.AssertSameLength) { case (tableSinglePartCtx, blockColIdx) => + flatten(zip2(tablePartsStreamIR, rangeIR(numBlockRows), ArrayZipBehavior.AssertSameLength) { case (tableSinglePartCtx, blockRowIdx) => mapIR(rangeIR(I32(numBlockCols))){ blockColIdx => MakeStruct(FastIndexedSeq("oldTableCtx" -> tableSinglePartCtx, "blockStart" -> (blockColIdx * I32(blockSize)), "blockSize" -> If(blockColIdx ceq I32(numBlockCols - 1), I32(lastBlockNumCols), I32(blockSize)), "blockColIdx" -> blockColIdx, - "blockRowIdx" -> blockColIdx)) + "blockRowIdx" -> blockRowIdx)) } }) } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala index 13dae86ac00..70b395845c5 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala @@ -237,8 +237,10 @@ class TableStage( def getNumPartitions(): IR = TableStage.wrapInBindings(StreamLen(contexts), letBindings) - def changePartitionerNoRepartition(newPartitioner: RVDPartitioner): TableStage = + def changePartitionerNoRepartition(newPartitioner: RVDPartitioner): TableStage = { + require(partitioner.numPartitions == newPartitioner.numPartitions) copy(partitioner = newPartitioner) + } def strictify(allowedOverlap: Int = kType.size - 1): TableStage = { val newPart = partitioner.strictify(allowedOverlap)