[query] Expand ReadValue beyond hail EType deserialization (#12948)
This represents a small redesign of ValueWriter as well.

Add deserializers (subclasses of ValueReader) and have ReadValue use
them. A ValueReader deserializes a single hail value from an input
stream.

This change also alters the semantics of ValueWriters. Readers and
writers no longer manage their own I/O resources; it is now the
responsibility of 'callers' to open and close the underlying streams.
However, programmers must be careful: native serialization and
deserialization still require creating input/output buffers over those
streams, and for InputBuffers in particular, the underlying stream may
be left in an unusable state afterward.
chrisvittal authored May 16, 2023
1 parent 0aaf389 commit 532e688
Showing 18 changed files with 117 additions and 78 deletions.
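
Illustratively, the new ownership contract reduces to the pattern below. This is a minimal, self-contained sketch using plain JVM streams; SimpleValueReader, SimpleValueWriter, StringCodec, and CallerOwnedStreams are hypothetical stand-ins, not Hail classes. Only the rule that callers open and close the streams is taken from this commit.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, OutputStream}

// Hypothetical stand-ins for ValueReader/ValueWriter, reduced to plain JVM I/O.
trait SimpleValueReader[A] { def readValue(is: InputStream): A }
trait SimpleValueWriter[A] { def writeValue(a: A, os: OutputStream): Unit }

object StringCodec extends SimpleValueReader[String] with SimpleValueWriter[String] {
  def readValue(is: InputStream): String = new String(is.readAllBytes(), "UTF-8")
  def writeValue(a: String, os: OutputStream): Unit = os.write(a.getBytes("UTF-8"))
}

object CallerOwnedStreams extends App {
  // The caller opens the stream, hands it to the writer, and closes it itself.
  val os = new ByteArrayOutputStream()
  try StringCodec.writeValue("hello", os)
  finally os.close()

  // Likewise for reads: the reader never closes the stream it is given.
  val is: InputStream = new ByteArrayInputStream(os.toByteArray)
  try assert(StringCodec.readValue(is) == "hello")
  finally is.close()
}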
7 changes: 5 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala
@@ -163,13 +163,14 @@ class BlockMatrixNativeReader(

val vType = TNDArray(fullType.elementType, Nat(2))
val spec = TypedCodecSpec(EBlockMatrixNDArray(EFloat64(required = true), required = true), vType, BlockMatrix.bufferSpec)
val reader = ETypeValueReader(spec)

def blockIR(ctx: IR): IR = {
val path = Apply("concat", FastSeq(),
FastSeq(Str(s"${ params.path }/parts/"), ctx),
TString, ErrorIDs.NO_ERROR)

ReadValue(path, spec, vType)
ReadValue(path, reader, vType)
}

BlockMatrixStage2(
@@ -207,9 +208,11 @@ case class BlockMatrixBinaryReader(path: String, shape: IndexedSeq[Long], blockS
}

override def lower(ctx: ExecuteContext, evalCtx: IRBuilder): BlockMatrixStage2 = {
// FIXME numpy should be its own value reader
val readFromNumpyEType = ENumpyBinaryNDArray(nRows, nCols, true)
val readFromNumpySpec = TypedCodecSpec(readFromNumpyEType, TNDArray(TFloat64, Nat(2)), new StreamBufferSpec())
val nd = evalCtx.memoize(ReadValue(Str(path), readFromNumpySpec, TNDArray(TFloat64, nDimsBase = Nat(2))))
val reader = ETypeValueReader(readFromNumpySpec)
val nd = evalCtx.memoize(ReadValue(Str(path), reader, TNDArray(TFloat64, nDimsBase = Nat(2))))

val typ = fullType
val contexts = BMSContexts.tabulate(typ, evalCtx) { (blockRow, blockCol) =>
5 changes: 3 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/BlockMatrixWriter.scala
@@ -50,7 +50,7 @@ case class BlockMatrixNativeWriter(
override def lower(ctx: ExecuteContext, s: BlockMatrixStage2, evalCtx: IRBuilder, eltR: TypeWithRequiredness): IR = {
val etype = EBlockMatrixNDArray(EType.fromTypeAndAnalysis(s.typ.elementType, eltR), encodeRowMajor = forceRowMajor, required = true)
val spec = TypedCodecSpec(etype, TNDArray(s.typ.elementType, Nat(2)), BlockMatrix.bufferSpec)
val writer = ETypeFileValueWriter(spec)
val writer = ETypeValueWriter(spec)

val paths = s.collectBlocks(evalCtx, "block_matrix_native_writer") { (_, idx, block) =>
val suffix = strConcat("parts/part-", idx, UUID4())
@@ -124,9 +124,10 @@ case class BlockMatrixBinaryWriter(path: String) extends BlockMatrixWriter {
override def lower(ctx: ExecuteContext, s: BlockMatrixStage2, evalCtx: IRBuilder, eltR: TypeWithRequiredness): IR = {
val nd = s.collectLocal(evalCtx, "block_matrix_binary_writer")

// FIXME remove numpy encoder
val etype = ENumpyBinaryNDArray(s.typ.nRows, s.typ.nCols, true)
val spec = TypedCodecSpec(etype, TNDArray(s.typ.elementType, Nat(2)), new StreamBufferSpec())
val writer = ETypeFileValueWriter(spec)
val writer = ETypeValueWriter(spec)
WriteValue(nd, Str(path), writer)
}
}
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/Copy.scala
@@ -399,9 +399,9 @@ object Copy {
case WriteMetadata(ctx, writer) =>
assert(newChildren.length == 1)
WriteMetadata(newChildren(0).asInstanceOf[IR], writer)
case ReadValue(path, spec, requestedType) =>
case ReadValue(path, reader, requestedType) =>
assert(newChildren.length == 1)
ReadValue(newChildren(0).asInstanceOf[IR], spec, requestedType)
ReadValue(newChildren(0).asInstanceOf[IR], reader, requestedType)
case WriteValue(_, _, writer, _) =>
assert(newChildren.length == 2 || newChildren.length == 3)
val value = newChildren(0).asInstanceOf[IR]
14 changes: 7 additions & 7 deletions hail/src/main/scala/is/hail/expr/ir/Emit.scala
@@ -2280,21 +2280,21 @@ class Emit[C](
case CastToArray(a) =>
emitI(a).map(cb) { ind => ind.asIndexable.castToArray(cb) }

case ReadValue(path, spec, requestedType) =>
case ReadValue(path, reader, requestedType) =>
emitI(path).map(cb) { pv =>
val ib = cb.memoize[InputBuffer](spec.buildCodeInputBuffer(
mb.openUnbuffered(pv.asString.loadString(cb), checkCodec = true)))
val decoded = spec.encodedType.buildDecoder(requestedType, mb.ecb)(cb, region, ib)
cb += ib.close()
val is = cb.memoize(mb.openUnbuffered(pv.asString.loadString(cb), checkCodec = true))
val decoded = reader.readValue(cb, requestedType, region, is)
cb += is.invoke[Unit]("close")
decoded
}

case WriteValue(value, path, writer, stagingFile) =>
emitI(path).flatMap(cb) { case pv: SStringValue =>
emitI(value).map(cb) { v =>
val s = stagingFile.map(emitI(_).get(cb).asString)
val p = EmitCode.present(mb, s.getOrElse(pv))
writer.writeValue(cb, v, p)
val os = cb.memoize(mb.createUnbuffered(s.getOrElse(pv).loadString(cb)))
writer.writeValue(cb, v, os)
cb += os.invoke[Unit]("close")
s.foreach { stage =>
cb += mb.getFS.invoke[String, String, Boolean, Unit]("copy", stage.loadString(cb), pv.loadString(cb), const(true))
}
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/IR.scala
@@ -921,7 +921,7 @@ final case class ReadPartition(context: IR, rowType: TStruct, reader: PartitionR
final case class WritePartition(value: IR, writeCtx: IR, writer: PartitionWriter) extends IR
final case class WriteMetadata(writeAnnotations: IR, writer: MetadataWriter) extends IR

final case class ReadValue(path: IR, spec: AbstractTypedCodecSpec, requestedType: Type) extends IR
final case class ReadValue(path: IR, reader: ValueReader, requestedType: Type) extends IR
final case class WriteValue(value: IR, path: IR, writer: ValueWriter, stagingFile: Option[IR] = None) extends IR

class PrimitiveIR(val self: IR) extends AnyVal {
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/InferType.scala
@@ -286,7 +286,7 @@ object InferType {
case WritePartition(value, writeCtx, writer) => writer.returnType
case _: WriteMetadata => TVoid
case ReadValue(_, _, typ) => typ
case WriteValue(_, _, writer, _) => writer.returnType
case _: WriteValue => TString
case LiftMeOut(child) => child.typ
}
}
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/MatrixWriter.scala
@@ -1509,7 +1509,7 @@ case class MatrixBlockMatrixWriter(
val elementType = tm.entryType.fieldType(entryField)
val etype = EBlockMatrixNDArray(EType.fromTypeAndAnalysis(elementType, rm.entryType.field(entryField)), encodeRowMajor = true, required = true)
val spec = TypedCodecSpec(etype, TNDArray(tm.entryType.fieldType(entryField), Nat(2)), BlockMatrix.bufferSpec)
val writer = ETypeFileValueWriter(spec)
val writer = ETypeValueWriter(spec)

val pathsWithColMajorIndices = tableOfNDArrays.mapCollect("matrix_block_matrix_writer") { partition =>
ToArray(mapIR(partition) { singleNDArrayTuple =>
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/Parser.scala
@@ -1542,11 +1542,11 @@ object IRParser {
WriteMetadata(ctx, writer)
}
case "ReadValue" =>
import AbstractRVDSpec.formats
val spec = JsonMethods.parse(string_literal(it)).extract[AbstractTypedCodecSpec]
import ValueReader.formats
val reader = JsonMethods.parse(string_literal(it)).extract[ValueReader]
val typ = type_expr(it)
ir_value_expr(env)(it).map { path =>
ReadValue(path, spec, typ)
ReadValue(path, reader, typ)
}
case "WriteValue" =>
import ValueWriter.formats
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/Pretty.scala
@@ -428,8 +428,8 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int,
single(prettyStringLiteral(JsonMethods.compact(writer.toJValue)))
case WriteMetadata(writeAnnotations, writer) =>
single(prettyStringLiteral(JsonMethods.compact(writer.toJValue), elide = elideLiterals))
case ReadValue(_, spec, reqType) =>
FastSeq(prettyStringLiteral(spec.toString), reqType.parsableString())
case ReadValue(_, reader, reqType) =>
FastSeq(prettyStringLiteral(JsonMethods.compact(reader.toJValue)), reqType.parsableString())
case WriteValue(_, _, writer, _) =>
single(prettyStringLiteral(JsonMethods.compact(writer.toJValue)))
case MakeNDArray(_, _, _, errorId) => FastSeq(errorId.toString)
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/Requiredness.scala
@@ -739,9 +739,9 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {
val streamtype = tcoerce[RIterable](lookup(value))
val ctxType = lookup(writeCtx)
writer.unionTypeRequiredness(requiredness, ctxType, streamtype)
case ReadValue(path, spec, rt) =>
case ReadValue(path, reader, rt) =>
requiredness.union(lookup(path).required)
requiredness.fromPType(spec.encodedType.decodedPType(rt))
reader.unionRequiredness(rt, requiredness)
case In(_, t) => t match {
case SCodeEmitParamType(et) => requiredness.unionFrom(et.typeWithRequiredness.r)
case SingleCodeEmitParamType(required, StreamSingleCodeType(_, eltType, eltRequired)) =>
9 changes: 6 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala
@@ -546,13 +546,16 @@ object TypeCheck {
assert(x.typ == writer.returnType)
case WriteMetadata(writeAnnotations, writer) =>
assert(writeAnnotations.typ == writer.annotationType)
case x@ReadValue(path, spec, requestedType) =>
case x@ReadValue(path, reader, requestedType) =>
assert(path.typ == TString)
assert(spec.encodedType.decodedPType(requestedType).virtualType == requestedType)
reader match {
case reader: ETypeValueReader =>
assert(reader.spec.encodedType.decodedPType(requestedType).virtualType == requestedType)
case _ => // do nothing, we can't in general typecheck an arbitrary value reader
}
case WriteValue(_, path, writer, stagingFile) =>
assert(path.typ == TString)
assert(stagingFile.forall(_.typ == TString))
assert(writer.returnType == TString || writer.returnType == TBinary || writer.returnType == TVoid)
case LiftMeOut(_) =>
case Consume(_) =>
case TableMapRows(child, newRow) =>
52 changes: 52 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/ValueReader.scala
@@ -0,0 +1,52 @@
package is.hail.expr.ir

import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.io.{AbstractTypedCodecSpec, BufferSpec, TypedCodecSpec}
import is.hail.types.TypeWithRequiredness
import is.hail.types.encoded._
import is.hail.types.physical._
import is.hail.types.physical.stypes.{SCode, SType, SValue}
import is.hail.types.physical.stypes.concrete.SStackStruct
import is.hail.types.virtual._
import is.hail.utils._

import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints}

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, OutputStream}

object ValueReader {
implicit val formats: Formats = new DefaultFormats() {
override val typeHints = ShortTypeHints(List(
classOf[ETypeValueReader],
classOf[AbstractTypedCodecSpec],
classOf[TypedCodecSpec]),
typeHintFieldName = "name"
) + BufferSpec.shortTypeHints
} +
new TStructSerializer +
new TypeSerializer +
new PTypeSerializer +
new ETypeSerializer
}

abstract class ValueReader {
def unionRequiredness(requestedType: Type, requiredness: TypeWithRequiredness): Unit

def readValue(cb: EmitCodeBuilder, t: Type, region: Value[Region], is: Value[InputStream]): SValue

def toJValue: JValue = Extraction.decompose(this)(ValueReader.formats)
}


final case class ETypeValueReader(spec: AbstractTypedCodecSpec) extends ValueReader {
def unionRequiredness(requestedType: Type, requiredness: TypeWithRequiredness): Unit =
requiredness.fromPType(spec.encodedType.decodedPType(requestedType))

def readValue(cb: EmitCodeBuilder, t: Type, region: Value[Region], is: Value[InputStream]): SValue = {
val decoder = spec.encodedType.buildDecoder(t, cb.emb.ecb)
val ib = cb.memoize(spec.buildCodeInputBuffer(is))
val ret = decoder.apply(cb, region, ib)
ret
}
}
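
Because ValueReader.formats registers the same type hints as ValueWriter's, a reader round-trips through the JSON text form that Pretty emits and IRParser consumes. A minimal sketch, assuming a wire-format codec like the one the table lowering below constructs:

import is.hail.expr.ir.{ETypeValueReader, ValueReader}
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.types.physical.{PCanonicalBinary, PCanonicalTuple}
import org.json4s.jackson.JsonMethods

object ReaderJsonRoundTrip extends App {
  import ValueReader.formats

  val codecSpec = TypedCodecSpec(PCanonicalTuple(true, PCanonicalBinary(true)), BufferSpec.wireSpec)
  val reader: ValueReader = ETypeValueReader(codecSpec)

  // Serialize the reader as Pretty does, then recover it as IRParser does.
  val json = JsonMethods.compact(reader.toJValue)
  val roundTripped = JsonMethods.parse(json).extract[ValueReader]
  assert(roundTripped == reader) // case-class equality; both wrap the same spec
}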
36 changes: 5 additions & 31 deletions hail/src/main/scala/is/hail/expr/ir/ValueWriter.scala
@@ -1,8 +1,6 @@
package is.hail.expr.ir

import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.backend.ExecuteContext
import is.hail.io.{AbstractTypedCodecSpec, BufferSpec, TypedCodecSpec}
import is.hail.types.encoded._
import is.hail.types.physical._
@@ -18,7 +16,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, Output
object ValueWriter {
implicit val formats: Formats = new DefaultFormats() {
override val typeHints = ShortTypeHints(List(
classOf[ETypeFileValueWriter],
classOf[ETypeValueWriter],
classOf[AbstractTypedCodecSpec],
classOf[TypedCodecSpec]),
typeHintFieldName = "name"
@@ -31,40 +29,16 @@ object ValueWriter {
}

abstract class ValueWriter {
-final def writeValue(cb: EmitCodeBuilder, value: SValue, _args: EmitCode*): SValue = {
-val args = _args.toIndexedSeq
-val argsTypes = args.map(_.st.virtualType)
-assert(argsTypes == argumentTypes, s"argument mismatch, required argument types $argumentTypes, got $argsTypes")
-_writeValue(cb, value, args)
-}
-protected def _writeValue(cb: EmitCodeBuilder, value: SValue, args: IndexedSeq[EmitCode]): SValue
-def argumentTypes: IndexedSeq[Type]
-// one of void, binary, or string, checked in TypeCheck
-def returnType: Type
+def writeValue(cb: EmitCodeBuilder, value: SValue, os: Value[OutputStream]): Unit

def toJValue: JValue = Extraction.decompose(this)(ValueWriter.formats)
}

-abstract class ETypeValueWriter extends ValueWriter {
-def spec: AbstractTypedCodecSpec
-
-final def serialize(cb: EmitCodeBuilder, value: SValue, os: Value[OutputStream]) = {
+final case class ETypeValueWriter(spec: AbstractTypedCodecSpec) extends ValueWriter {
+def writeValue(cb: EmitCodeBuilder, value: SValue, os: Value[OutputStream]): Unit = {
val encoder = spec.encodedType.buildEncoder(value.st, cb.emb.ecb)
val ob = cb.memoize(spec.buildCodeOutputBuffer(os))
encoder.apply(cb, value, ob)
-cb += ob.invoke[Unit]("close")
-}
-}
-
-final case class ETypeFileValueWriter(spec: AbstractTypedCodecSpec) extends ETypeValueWriter {
-protected def _writeValue(cb: EmitCodeBuilder, value: SValue, args: IndexedSeq[EmitCode]): SValue = {
-val IndexedSeq(path_) = args
-val path = path_.toI(cb).get(cb).asString
-val os = cb.memoize(cb.emb.createUnbuffered(path.loadString(cb)))
-serialize(cb, value, os) // takes ownership and closes the stream
-path
+cb += ob.invoke[Unit]("flush")
}
-
-val argumentTypes: IndexedSeq[Type] = FastIndexedSeq(/*path=*/TString)
-val returnType: Type = TString
}
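
Note the flush at the end of the new writeValue: the writer drains the OutputBuffer it wrapped around the caller's stream, but it no longer closes that stream. A small plain-JVM illustration of why that split matters (the names below are hypothetical, not Hail code):

import java.io.{BufferedOutputStream, ByteArrayOutputStream}

object FlushNotCloseDemo extends App {
  val callerStream = new ByteArrayOutputStream()             // owned by the caller
  val writerBuffer = new BufferedOutputStream(callerStream)  // stands in for the OutputBuffer

  writerBuffer.write("payload".getBytes("UTF-8"))
  writerBuffer.flush() // the writer's job ends here: bytes pushed through, stream left open

  callerStream.write('\n') // the caller may keep using its stream...
  callerStream.close()     // ...and closes it exactly once
  assert(callerStream.toString("UTF-8") == "payload\n")
}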
20 changes: 11 additions & 9 deletions hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala
@@ -645,7 +645,8 @@ object LowerTableIR {
val tmpDir = ctx.createTmpPath("aggregate_intermediates/")

val codecSpec = TypedCodecSpec(PCanonicalTuple(true, aggs.aggs.map(_ => PCanonicalBinary(true)): _*), BufferSpec.wireSpec)
val writer = ETypeFileValueWriter(codecSpec)
val writer = ETypeValueWriter(codecSpec)
val reader = ETypeValueReader(codecSpec)
lcWithInitBinding.mapCollectWithGlobals("table_aggregate")({ part: IR =>
Let("global", lc.globals,
RunAgg(
Expand All @@ -671,7 +672,7 @@ object LowerTableIR {
if (useInitStates) {
initFromSerializedStates
} else {
bindIR(ReadValue(ArrayRef(partArrayRef, 0), codecSpec, codecSpec.encodedVirtualType)) { serializedTuple =>
bindIR(ReadValue(ArrayRef(partArrayRef, 0), reader, reader.spec.encodedVirtualType)) { serializedTuple =>
Begin(
aggs.aggs.zipWithIndex.map { case (sig, i) =>
InitFromSerializedValue(i, GetTupleElement(serializedTuple, i), sig.state)
Expand All @@ -680,7 +681,7 @@ object LowerTableIR {
},
forIR(StreamRange(if (useInitStates) 0 else 1, ArrayLen(partArrayRef), 1, requiresMemoryManagementPerElement = true)) { fileIdx =>

bindIR(ReadValue(ArrayRef(partArrayRef, fileIdx), codecSpec, codecSpec.encodedVirtualType)) { serializedTuple =>
bindIR(ReadValue(ArrayRef(partArrayRef, fileIdx), reader, reader.spec.encodedVirtualType)) { serializedTuple =>
Begin(
aggs.aggs.zipWithIndex.map { case (sig, i) =>
CombOpValue(i, GetTupleElement(serializedTuple, i), sig)
@@ -1241,7 +1242,8 @@ object LowerTableIR {
val tmpDir = ctx.createTmpPath("aggregate_intermediates/")

val codecSpec = TypedCodecSpec(PCanonicalTuple(true, aggs.aggs.map(_ => PCanonicalBinary(true)): _*), BufferSpec.wireSpec)
val writer = ETypeFileValueWriter(codecSpec)
val writer = ETypeValueWriter(codecSpec)
val reader = ETypeValueReader(codecSpec)
val partitionPrefixSumFiles = lcWithInitBinding.mapCollectWithGlobals("table_scan_write_prefix_sums")({ part: IR =>
Let("global", lcWithInitBinding.globals,
RunAgg(
@@ -1260,15 +1262,15 @@

def combineGroup(partArrayRef: IR): IR = {
Begin(FastIndexedSeq(
bindIR(ReadValue(ArrayRef(partArrayRef, 0), codecSpec, codecSpec.encodedVirtualType)) { serializedTuple =>
bindIR(ReadValue(ArrayRef(partArrayRef, 0), reader, reader.spec.encodedVirtualType)) { serializedTuple =>
Begin(
aggs.aggs.zipWithIndex.map { case (sig, i) =>
InitFromSerializedValue(i, GetTupleElement(serializedTuple, i), sig.state)
})
},
forIR(StreamRange(1, ArrayLen(partArrayRef), 1, requiresMemoryManagementPerElement = true)) { fileIdx =>

bindIR(ReadValue(ArrayRef(partArrayRef, fileIdx), codecSpec, codecSpec.encodedVirtualType)) { serializedTuple =>
bindIR(ReadValue(ArrayRef(partArrayRef, fileIdx), reader, reader.spec.encodedVirtualType)) { serializedTuple =>
Begin(
aggs.aggs.zipWithIndex.map { case (sig, i) =>
CombOpValue(i, GetTupleElement(serializedTuple, i), sig)
@@ -1352,13 +1354,13 @@
ToArray(RunAggScan(
ToStream(GetField(context, "partialSums"), requiresMemoryManagementPerElement = true),
elt.name,
bindIR(ReadValue(prev, codecSpec, codecSpec.encodedVirtualType)) { serializedTuple =>
bindIR(ReadValue(prev, reader, reader.spec.encodedVirtualType)) { serializedTuple =>
Begin(
aggs.aggs.zipWithIndex.map { case (sig, i) =>
InitFromSerializedValue(i, GetTupleElement(serializedTuple, i), sig.state)
})
},
bindIR(ReadValue(elt, codecSpec, codecSpec.encodedVirtualType)) { serializedTuple =>
bindIR(ReadValue(elt, reader, reader.spec.encodedVirtualType)) { serializedTuple =>
Begin(
aggs.aggs.zipWithIndex.map { case (sig, i) =>
CombOpValue(i, GetTupleElement(serializedTuple, i), sig)
Expand All @@ -1381,7 +1383,7 @@ object LowerTableIR {
}
}
}
(partitionPrefixSumFiles, { (file: IR) => ReadValue(file, codecSpec, codecSpec.encodedVirtualType) })
(partitionPrefixSumFiles, { (file: IR) => ReadValue(file, reader, reader.spec.encodedVirtualType) })

} else {
val partitionAggs = lcWithInitBinding.mapCollectWithGlobals("table_scan_prefix_sums_singlestage")({ part: IR =>
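
The lowering above threads one codec spec into a writer/reader pair so that spilled aggregation state is read back exactly as it was written. The pattern in isolation, as a sketch (spillAndReload is a hypothetical helper, not part of this commit):

import is.hail.expr.ir._
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.types.physical.{PCanonicalBinary, PCanonicalTuple}

// Persist an intermediate value at tmpPath and immediately read it back.
def spillAndReload(value: IR, tmpPath: IR): IR = {
  val codecSpec = TypedCodecSpec(PCanonicalTuple(true, PCanonicalBinary(true)), BufferSpec.wireSpec)
  val writer = ETypeValueWriter(codecSpec)
  val reader = ETypeValueReader(codecSpec)
  // WriteValue now infers to TString (the path written), so its result can
  // feed ReadValue directly.
  bindIR(WriteValue(value, tmpPath, writer)) { path =>
    ReadValue(path, reader, codecSpec.encodedVirtualType)
  }
}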
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/types/encoded/ENumpyBinaryNDArray.scala
@@ -12,6 +12,7 @@ import is.hail.types.physical.stypes.primitives.SFloat64
import is.hail.types.virtual.{TNDArray, Type}
import is.hail.utils.FastIndexedSeq

// FIXME numpy format should not be a hail native serialized format, move this to ValueReader/Writer
final case class ENumpyBinaryNDArray(nRows: Long, nCols: Long, required: Boolean) extends EType {
type DecodedPType = PCanonicalNDArray
val elementType = EFloat64(true)
Expand Down