[query] Stage (matrix)table writes locally #12773

Merged · 9 commits · Mar 14, 2023
2 changes: 0 additions & 2 deletions hail/python/test/hail/matrixtable/test_matrix_table.py
@@ -1287,8 +1287,6 @@ def test_hw_func_and_agg_agree(self):
rt_one_sided = mt_one_sided.rows()
self.assertTrue(rt_one_sided.all(rt_one_sided.hw == rt_one_sided.hw2))

@fails_service_backend()
@fails_local_backend()
def test_write_stage_locally(self):
mt = self.get_mt()
f = new_temp_file(extension='mt')
2 changes: 0 additions & 2 deletions hail/python/test/hail/table/test_table.py
@@ -914,8 +914,6 @@ def test_export_delim(self):
with hl.hadoop_open(tmp_file, 'r') as f_in:
assert f_in.read() == 'idx,foo\n0,3\n'

@fails_service_backend()
@fails_local_backend()
def test_write_stage_locally(self):
t = hl.utils.range_table(5)
f = new_temp_file(extension='ht')
200 changes: 112 additions & 88 deletions hail/src/main/scala/is/hail/expr/ir/MatrixWriter.scala
@@ -1,6 +1,5 @@
package is.hail.expr.ir

import scala.language.existentials
import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.backend.ExecuteContext
@@ -17,24 +16,25 @@ import is.hail.io.plink.{BitPacker, ExportPlink}
import is.hail.io.vcf.{ExportVCF, TabixVCF}
import is.hail.linalg.BlockMatrix
import is.hail.rvd.{IndexSpec, RVDPartitioner, RVDSpecMaker}
import is.hail.types._
import is.hail.types.encoded.{EBaseStruct, EBlockMatrixNDArray, EType}
import is.hail.types.physical.stypes.{EmitType, SValue}
import is.hail.types.physical._
import is.hail.types.physical.stypes.concrete.{SJavaArrayString, SJavaArrayStringValue, SJavaString, SStackStruct}
import is.hail.types.physical.stypes.interfaces._
import is.hail.types.physical.stypes.primitives._
import is.hail.types.physical.{PBooleanRequired, PCanonicalBaseStruct, PCanonicalString, PCanonicalStruct, PInt64, PStruct, PType}
import is.hail.types.physical.stypes.{EmitType, SValue}
import is.hail.types.virtual._
import is.hail.types._
import is.hail.types.physical.stypes.concrete.{SJavaArrayString, SJavaArrayStringValue, SJavaString, SStackStruct}
import is.hail.types.physical.stypes.interfaces.{SBaseStructValue, SIndexableValue, SStringValue}
import is.hail.types.physical.stypes.primitives.{SBooleanValue, SInt64Value}
import is.hail.utils._
import is.hail.utils.richUtils.ByteTrackingOutputStream
import is.hail.variant.{Call, Locus, ReferenceGenome}
import is.hail.variant.{Call, ReferenceGenome}
import org.apache.spark.sql.Row
import org.json4s.jackson.JsonMethods
import org.json4s.{DefaultFormats, Formats, ShortTypeHints}

import java.io.{InputStream, OutputStream}
import java.nio.file.{FileSystems, Path}
import java.util.UUID
import scala.language.existentials

object MatrixWriter {
implicit val formats: Formats = new DefaultFormats() {
@@ -113,9 +113,14 @@ case class MatrixNativeWriter(
val globalWriter = PartitionNativeWriter(globalSpec, IndexedSeq(), s"$path/globals/rows/parts/", None, None)
val colWriter = PartitionNativeWriter(colSpec, IndexedSeq(), s"$path/cols/rows/parts/", None, None)
val rowWriter = SplitPartitionNativeWriter(
rowSpec, s"$path/rows/rows/parts/",
entrySpec, s"$path/entries/rows/parts/",
pKey.virtualType.fieldNames, Some(s"$path/index/" -> pKey), if (stageLocally) Some(ctx.localTmpdir) else None)
rowSpec,
s"$path/rows/rows/parts/",
entrySpec,
s"$path/entries/rows/parts/",
pKey.virtualType.fieldNames,
Some(s"$path/index/" -> pKey),
if (stageLocally) Some(FileSystems.getDefault.getPath(ctx.localTmpdir, s"hail_stage_tmp_${UUID.randomUUID}")) else None
)

val globalTableWriter = TableSpecWriter(s"$path/globals", TableType(tm.globalType, FastIndexedSeq(), TStruct.empty), "rows", "globals", "../references", log = false)
val colTableWriter = TableSpecWriter(s"$path/cols", tm.colsTableType.copy(key = FastIndexedSeq[String]()), "rows", "../globals/rows", "../references", log = false)
@@ -185,13 +190,16 @@ case class MatrixNativeWriter(
}
}

case class SplitPartitionNativeWriter(
spec1: AbstractTypedCodecSpec, partPrefix1: String,
spec2: AbstractTypedCodecSpec, partPrefix2: String,
keyFieldNames: IndexedSeq[String],
index: Option[(String, PStruct)], localDir: Option[String]) extends PartitionWriter {
def stageLocally: Boolean = localDir.isDefined
def hasIndex: Boolean = index.isDefined
case class SplitPartitionNativeWriter(spec1: AbstractTypedCodecSpec,
partPrefix1: String,
spec2: AbstractTypedCodecSpec,
partPrefix2: String,
keyFieldNames: IndexedSeq[String],
index: Option[(String, PStruct)],
stageFolder: Option[Path]
)
extends PartitionWriter {

val filenameType = PCanonicalString(required = true)
def pContextType = PCanonicalString()

@@ -210,30 +218,58 @@ case class SplitPartitionNativeWriter(
r.union(streamType.required)
}

if (stageLocally)
throw new LowererUnsupportedOperation("stageLocally option not yet implemented")
def ifIndexed[T >: Null](obj: => T): T = if (hasIndex) obj else null

def consumeStream(
ctx: ExecuteContext,
cb: EmitCodeBuilder,
stream: StreamProducer,
context: EmitCode,
region: Value[Region]): IEmitCode = {
def consumeStream(ctx: ExecuteContext,
cb: EmitCodeBuilder,
stream: StreamProducer,
context: EmitCode,
region: Value[Region]
): IEmitCode = {
val iAnnotationType = PCanonicalStruct(required = true, "entries_offset" -> PInt64())
val mb = cb.emb

val indexWriter = ifIndexed { StagedIndexWriter.withDefaults(index.get._2, mb.ecb, annotationType = iAnnotationType,
branchingFactor = Option(mb.ctx.getFlag("index_branching_factor")).map(_.toInt).getOrElse(4096)) }
val writeIndexInfo = index.map { case (name, ktype) =>
@ehigham (Member, Author) commented on Mar 9, 2023:

I'm allergic to null, having been bitten by it too many times. Keeping things within the Option allows us to avoid things like null refs and ifIndexed.
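A minimal sketch of the contrast being described, using hypothetical stand-in names rather than the real emit-code types:

```scala
// Illustrative only: IndexSpec0 and buildIndexWriter are hypothetical stand-ins.
case class IndexSpec0(name: String, keyType: String)

// Null-based style (what a helper like ifIndexed supports): every caller must remember the null check.
def ifIndexed[T >: Null](hasIndex: Boolean)(obj: => T): T = if (hasIndex) obj else null

// Option-based style: the presence check travels with the value, so downstream code
// uses map/foreach instead of null guards.
def buildIndexWriter(index: Option[IndexSpec0]): Option[String] =
  index.map { case IndexSpec0(name, keyType) => s"index writer for $name: $keyType" }

// buildIndexWriter(Some(IndexSpec0("rows", "struct{locus,alleles}"))).foreach(println)
```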

val bfactor = Option(mb.ctx.getFlag("index_branching_factor")).map(_.toInt).getOrElse(4096)
(name, ktype, StagedIndexWriter.withDefaults(ktype, mb.ecb, annotationType = iAnnotationType, branchingFactor = bfactor))
}

context.toI(cb).map(cb) { pctx =>
val filename1 = mb.newLocal[String]("filename1")
val os1 = mb.newLocal[ByteTrackingOutputStream]("write_os1")
val ob1 = mb.newLocal[OutputBuffer]("write_ob1")
val filename2 = mb.newLocal[String]("filename2")
val os2 = mb.newLocal[ByteTrackingOutputStream]("write_os2")
val ob2 = mb.newLocal[OutputBuffer]("write_ob2")
val n = mb.newLocal[Long]("partition_count")
@ehigham (Member, Author) commented on lines -230 to -236:

This code duplication was crying out for parameterisation. Since I'm still using locals, there shouldn't be too much difference in the generated code.
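A simplified, non-codegen illustration of that parameterisation (plain Scala with hypothetical helpers, not the EmitCodeBuilder locals used in the diff):

```scala
import java.io.{ByteArrayOutputStream, OutputStream}

// Stand-in for mb.createUnbuffered; in the real code this opens a part file on the target filesystem.
def openStream(filename: String): OutputStream = new ByteArrayOutputStream()

// Instead of duplicating the "filename1/os1" and "filename2/os2" blocks, map over the
// prefixes once and unzip the per-stream values.
def openStreams(prefixes: Seq[String], ctx: String): (Seq[String], Seq[OutputStream]) =
  prefixes.map { prefix =>
    val filename = s"$prefix$ctx" // one part file per prefix, e.g. rows/... and entries/...
    (filename, openStream(filename))
  }.unzip
```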


val ctxValue = pctx.asString.loadString(cb)
val (filenames, stages, buffers) =
FastIndexedSeq(partPrefix1, partPrefix2)
.map(const)
.zipWithIndex
.map { case (prefix, i) =>
val filename = mb.newLocal[String](s"filename$i")
cb.assign(filename, prefix.concat(ctxValue))

val stagingFile = stageFolder.map { folder =>
val stage = mb.newLocal[String](s"stage$i")
cb.assign(stage, const(s"$folder/$i/").concat(ctxValue))
stage
}

val ostream = mb.newLocal[ByteTrackingOutputStream](s"write_os$i")
cb.assign(ostream, Code.newInstance[ByteTrackingOutputStream, OutputStream](
mb.createUnbuffered(stagingFile.getOrElse(filename).get)
))

val buffer = mb.newLocal[OutputBuffer](s"write_ob$i")
cb.assign(buffer, spec1.buildCodeOutputBuffer(Code.checkcast[OutputStream](ostream)))

(filename, stagingFile, buffer)
}
.unzip3

writeIndexInfo.foreach { case (name, _, writer) =>
val indexFile = cb.newLocal[String]("indexFile")
cb.assign(indexFile, const(name).concat(ctxValue).concat(".idx"))
writer.init(cb, indexFile)
}

val pCount = mb.newLocal[Long]("partition_count")
cb.assign(pCount, 0L)

val distinctlyKeyed = mb.newLocal[Boolean]("distinctlyKeyed")
cb.assign(distinctlyKeyed, !keyFieldNames.isEmpty) // True until proven otherwise, if there's a key to care about at all.

@@ -242,33 +278,40 @@ case class SplitPartitionNativeWriter(
val firstSeenSettable = mb.newEmitLocal("pnw_firstSeen", keyEmitType)
val lastSeenSettable = mb.newEmitLocal("pnw_lastSeen", keyEmitType)
val lastSeenRegion = cb.newLocal[Region]("last_seen_region")

// Start off missing, we will use this to determine if we haven't processed any rows yet.
cb.assign(firstSeenSettable, EmitCode.missing(cb.emb, keyEmitType.st))
cb.assign(lastSeenSettable, EmitCode.missing(cb.emb, keyEmitType.st))
cb.assign(lastSeenRegion, Region.stagedCreate(Region.TINY, region.getPool()))


def writeFile(cb: EmitCodeBuilder, codeRow: EmitCode): Unit = {
val row = codeRow.toI(cb).get(cb, "row can't be missing").asBaseStruct

if (hasIndex) {
indexWriter.add(cb, {
val indexKeyPType = index.get._2
IEmitCode.present(cb, indexKeyPType.asInstanceOf[PCanonicalBaseStruct]
.constructFromFields(cb, stream.elementRegion,
indexKeyPType.fields.map(f => EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name))),
deepCopy = false))
}, ob1.invoke[Long]("indexOffset"), {
IEmitCode.present(cb,
iAnnotationType.constructFromFields(cb, stream.elementRegion,
FastIndexedSeq(EmitCode.present(cb.emb, primitive(cb.memoize(ob2.invoke[Long]("indexOffset"))))),
deepCopy = false))
})
val specs = FastIndexedSeq(spec1, spec2)
stream.memoryManagedConsume(region, cb) { cb =>
val row = stream.element.toI(cb).get(cb, "row can't be missing").asBaseStruct

writeIndexInfo.foreach { case (_, keyType, writer) =>
writer.add(cb, {
IEmitCode.present(cb, keyType.asInstanceOf[PCanonicalBaseStruct]
.constructFromFields(cb, stream.elementRegion, keyType.fields.map { f =>
EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name))
},
deepCopy = false
)
)
},
buffers(0).invoke[Long]("indexOffset"), {
IEmitCode.present(cb,
iAnnotationType.constructFromFields(cb, stream.elementRegion,
FastIndexedSeq(EmitCode.present(cb.emb, primitive(cb.memoize(buffers(1).invoke[Long]("indexOffset"))))),
deepCopy = false
)
)
}
)
}

val key = SStackStruct.constructFromArgs(cb, stream.elementRegion, keyType, keyType.fields.map { f =>
EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name))
}:_*)
}: _*)

if (!keyFieldNames.isEmpty) {
cb.ifx(distinctlyKeyed, {
@@ -287,44 +330,25 @@ case class SplitPartitionNativeWriter(
cb.assign(lastSeenSettable, IEmitCode.present(cb, key.copyToRegion(cb, lastSeenRegion, lastSeenSettable.st)))
}

cb += ob1.writeByte(1.asInstanceOf[Byte])

spec1.encodedType.buildEncoder(row.st, cb.emb.ecb)
.apply(cb, row, ob1)

cb += ob2.writeByte(1.asInstanceOf[Byte])
buffers.zip(specs).foreach { case (buff, spec) =>
cb += buff.writeByte(1.asInstanceOf[Byte])
spec.encodedType.buildEncoder(row.st, cb.emb.ecb).apply(cb, row, buff)
}

spec2.encodedType.buildEncoder(row.st, cb.emb.ecb)
.apply(cb, row, ob2)
cb.assign(n, n + 1L)
cb.assign(pCount, pCount + 1L)
}

cb.assign(filename1, pctx.asString.loadString(cb))
if (hasIndex) {
val indexFile = cb.newLocal[String]("indexFile")
cb.assign(indexFile, const(index.get._1).concat(filename1).concat(".idx"))
indexWriter.init(cb, indexFile)
}
cb.assign(filename2, const(partPrefix2).concat(filename1))
cb.assign(filename1, const(partPrefix1).concat(filename1))
cb.assign(os1, Code.newInstance[ByteTrackingOutputStream, OutputStream](mb.createUnbuffered(filename1)))
cb.assign(os2, Code.newInstance[ByteTrackingOutputStream, OutputStream](mb.createUnbuffered(filename2)))
cb.assign(ob1, spec1.buildCodeOutputBuffer(Code.checkcast[OutputStream](os1)))
cb.assign(ob2, spec2.buildCodeOutputBuffer(Code.checkcast[OutputStream](os2)))
cb.assign(n, 0L)
writeIndexInfo.foreach(_._3.close(cb))

stream.memoryManagedConsume(region, cb) { cb =>
writeFile(cb, stream.element)
buffers.foreach { buff =>
cb += buff.writeByte(0.asInstanceOf[Byte])
cb += buff.flush()
cb += buff.close()
}

cb += ob1.writeByte(0.asInstanceOf[Byte])
cb += ob2.writeByte(0.asInstanceOf[Byte])
if (hasIndex)
indexWriter.close(cb)
cb += ob1.flush()
cb += ob2.flush()
cb += os1.invoke[Unit]("close")
cb += os2.invoke[Unit]("close")
@ehigham (Member, Author) commented on lines -326 to -327:

All OutputBuffer implementations seem to close their underlying output stream, so there's no need to maintain a reference to them here.
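A schematic of the behaviour being relied on, not Hail's actual OutputBuffer code: when closing the buffer also closes the stream it wraps, the generated code only needs to keep a reference to the buffer.

```scala
import java.io.{BufferedOutputStream, OutputStream}

// Hypothetical wrapper: close() flushes and closes the underlying stream as well,
// so callers never need to hold `out` separately just to close it.
final class WrappingBuffer(out: OutputStream) {
  private val buf = new BufferedOutputStream(out)
  def writeByte(b: Byte): Unit = buf.write(b.toInt)
  def flush(): Unit = buf.flush()
  def close(): Unit = buf.close() // BufferedOutputStream.close() also closes `out`
}
```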

@ehigham (Member, Author) commented on lines +343 to -327:

There's a slight difference in ordering with respect to the index writer. I didn't think that mattered, as the output buffers and index writer seem to be quite separate.

stages.flatMap(_.toIterable).zip(filenames).foreach { case (source, destination) =>
cb += mb.getFS.invoke[String, String, Boolean, Unit]("copy", source, destination, const(true))
}

lastSeenSettable.loadI(cb).consume(cb, { /* do nothing */ }, { lastSeen =>
cb.assign(lastSeenSettable, IEmitCode.present(cb, lastSeen.copyToRegion(cb, region, lastSeenSettable.st)))
Expand All @@ -333,7 +357,7 @@ case class SplitPartitionNativeWriter(

SStackStruct.constructFromArgs(cb, region, returnType.asInstanceOf[TBaseStruct],
EmitCode.present(mb, pctx),
EmitCode.present(mb, new SInt64Value(n)),
EmitCode.present(mb, new SInt64Value(pCount)),
EmitCode.present(mb, new SBooleanValue(distinctlyKeyed)),
firstSeenSettable,
lastSeenSettable