Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler] WriteValue Stage Locally #12798

Merged
merged 3 commits on Mar 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions hail/python/test/hail/linalg/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,8 +1156,6 @@ def test_write_overwrite(self):
bm2.write(path, overwrite=True)
self._assert_eq(BlockMatrix.read(path), bm2)

@fails_service_backend()
@fails_local_backend()
def test_stage_locally(self):
nd = np.arange(0, 80, dtype=float).reshape(8, 10)
with hl.TemporaryDirectory(ensure_exists=False) as bm_uri:
Expand Down
11 changes: 6 additions & 5 deletions hail/src/main/scala/is/hail/expr/ir/BlockMatrixWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,15 @@ case class BlockMatrixNativeWriter(
def loweredTyp: Type = TVoid

override def lower(ctx: ExecuteContext, s: BlockMatrixStage2, evalCtx: IRBuilder, eltR: TypeWithRequiredness): IR = {
if (stageLocally)
throw new LowererUnsupportedOperation(s"stageLocally not supported in BlockMatrixWrite lowering")
val etype = EBlockMatrixNDArray(EType.fromTypeAndAnalysis(s.typ.elementType, eltR), encodeRowMajor = forceRowMajor, required = true)
val spec = TypedCodecSpec(etype, TNDArray(s.typ.elementType, Nat(2)), BlockMatrix.bufferSpec)

val paths = s.collectBlocks(evalCtx, "block_matrix_native_writer") { (ctx, idx, block) =>
val filepath = strConcat(s"$path/parts/part-", idx, UUID4())
WriteValue(block, filepath, spec)
val paths = s.collectBlocks(evalCtx, "block_matrix_native_writer") { (_, idx, block) =>
val suffix = strConcat("parts/part-", idx, UUID4())
val filepath = strConcat(s"$path/", suffix)
WriteValue(block, filepath, spec,
if (stageLocally) Some(strConcat(s"${ctx.localTmpdir}/", suffix)) else None
)
}
RelationalWriter.scoped(path, overwrite, None)(WriteMetadata(paths, BlockMatrixNativeMetadataWriter(path, stageLocally, s.typ)))
}
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Children.scala
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ object Children {
case WritePartition(stream, ctx, _) => Array(stream, ctx)
case WriteMetadata(writeAnnotations, _) => Array(writeAnnotations)
case ReadValue(path, _, _) => Array(path)
case WriteValue(value, path, spec) => Array(value, path)
case WriteValue(value, path, _, staged) => Array(value, path) ++ staged.toArray[IR]
case LiftMeOut(child) => Array(child)
}
}
9 changes: 6 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/Copy.scala
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,12 @@ object Copy {
case ReadValue(path, spec, requestedType) =>
assert(newChildren.length == 1)
ReadValue(newChildren(0).asInstanceOf[IR], spec, requestedType)
case WriteValue(value, path, spec) =>
assert(newChildren.length == 2)
WriteValue(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], spec)
case WriteValue(_, _, spec, _) =>
assert(newChildren.length == 2 || newChildren.length == 3)
val value = newChildren(0).asInstanceOf[IR]
val path = newChildren(1).asInstanceOf[IR]
val stage = if (newChildren.length == 3) Some(newChildren(2).asInstanceOf[IR]) else None
WriteValue(value, path, spec, stage)
case LiftMeOut(_) =>
LiftMeOut(newChildren(0).asInstanceOf[IR])
}
Expand Down
15 changes: 10 additions & 5 deletions hail/src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package is.hail.expr.ir

import is.hail.HailContext
import is.hail.annotations._
import is.hail.asm4s._
import is.hail.backend.{BackendContext, ExecuteContext, HailTaskContext}
Expand All @@ -9,16 +8,16 @@ import is.hail.expr.ir.analyses.{ComputeMethodSplits, ControlFlowPreventsSplit,
import is.hail.expr.ir.lowering.TableStageDependency
import is.hail.expr.ir.ndarrays.EmitNDArray
import is.hail.expr.ir.streams.{EmitStream, StreamProducer, StreamUtils}
import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer, TypedCodecSpec}
import is.hail.io.fs.FS
import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer, TypedCodecSpec}
import is.hail.linalg.{BLAS, LAPACK, LinalgCodeUtils}
import is.hail.types.physical._
import is.hail.types.physical.stypes._
import is.hail.types.physical.stypes.concrete._
import is.hail.types.physical.stypes.interfaces._
import is.hail.types.physical.stypes.primitives._
import is.hail.types.virtual._
import is.hail.types.{RIterable, TypeWithRequiredness, VirtualTypeWithReq, tcoerce}
import is.hail.types.{TypeWithRequiredness, VirtualTypeWithReq, tcoerce}
import is.hail.utils._
import is.hail.variant.ReferenceGenome

Expand Down Expand Up @@ -2274,13 +2273,19 @@ class Emit[C](
decoded
}

case WriteValue(value, path, spec) =>
case WriteValue(value, path, spec, stagingFile) =>
emitI(path).flatMap(cb) { case pv: SStringValue =>
emitI(value).map(cb) { v =>
val ob = cb.memoize[OutputBuffer](spec.buildCodeOutputBuffer(mb.createUnbuffered(pv.asString.loadString(cb))))
val s = stagingFile.map(emitI(_).get(cb).asString)
val ob = cb.memoize[OutputBuffer](spec.buildCodeOutputBuffer(mb.createUnbuffered(
s.getOrElse(pv).loadString(cb))
))
spec.encodedType.buildEncoder(v.st, cb.emb.ecb)
.apply(cb, v, ob)
cb += ob.invoke[Unit]("close")
s.foreach { stage =>
cb += mb.getFS.invoke[String, String, Boolean, Unit]("copy", stage.loadString(cb), pv.loadString(cb), const(true))
}
pv
}
}
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ final case class WritePartition(value: IR, writeCtx: IR, writer: PartitionWriter
final case class WriteMetadata(writeAnnotations: IR, writer: MetadataWriter) extends IR

final case class ReadValue(path: IR, spec: AbstractTypedCodecSpec, requestedType: Type) extends IR
final case class WriteValue(value: IR, path: IR, spec: AbstractTypedCodecSpec) extends IR
final case class WriteValue(value: IR, path: IR, spec: AbstractTypedCodecSpec, stagingFile: Option[IR] = None) extends IR

class PrimitiveIR(val self: IR) extends AnyVal {
def +(other: IR): IR = {
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/InferType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ object InferType {
case WritePartition(value, writeCtx, writer) => writer.returnType
case _: WriteMetadata => TVoid
case ReadValue(_, _, typ) => typ
case WriteValue(value, path, spec) => TString
case _: WriteValue => TString
case LiftMeOut(child) => child.typ
}
}
Expand Down
8 changes: 4 additions & 4 deletions hail/src/main/scala/is/hail/expr/ir/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1537,10 +1537,10 @@ object IRParser {
case "WriteValue" =>
import AbstractRVDSpec.formats
val spec = JsonMethods.parse(string_literal(it)).extract[AbstractTypedCodecSpec]
for {
value <- ir_value_expr(env)(it)
path <- ir_value_expr(env)(it)
} yield WriteValue(value, path, spec)
ir_value_children(env)(it).map {
case Array(value, path) => WriteValue(value, path, spec)
case Array(value, path, stagingFile) => WriteValue(value, path, spec, Some(stagingFile))
}
case "LiftMeOut" => ir_value_expr(env)(it).map(LiftMeOut)
case "ReadPartition" =>
val rowType = tcoerce[TStruct](type_expr(it))
Expand Down
4 changes: 1 addition & 3 deletions hail/src/main/scala/is/hail/expr/ir/Pretty.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package is.hail.expr.ir

import is.hail.HailContext
import is.hail.backend.ExecuteContext
import is.hail.expr.JSONAnnotationImpex
import is.hail.expr.ir.Pretty.prettyBooleanLiteral
Expand All @@ -12,7 +11,6 @@ import is.hail.utils.prettyPrint._
import is.hail.utils.richUtils.RichIterable
import is.hail.utils.{space => _, _}
import org.json4s.DefaultFormats
import org.json4s.JsonAST.JString
import org.json4s.jackson.{JsonMethods, Serialization}

import scala.collection.mutable
Expand Down Expand Up @@ -432,7 +430,7 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int,
single(prettyStringLiteral(JsonMethods.compact(writer.toJValue), elide = elideLiterals))
case ReadValue(_, spec, reqType) =>
FastSeq(prettyStringLiteral(spec.toString), reqType.parsableString())
case WriteValue(_, _, spec) => single(prettyStringLiteral(spec.toString))
case WriteValue(_, _, spec, _) => single(prettyStringLiteral(spec.toString))
case MakeNDArray(_, _, _, errorId) => FastSeq(errorId.toString)

case _ => Iterable.empty
Expand Down
3 changes: 2 additions & 1 deletion hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -539,8 +539,9 @@ object TypeCheck {
case x@ReadValue(path, spec, requestedType) =>
assert(path.typ == TString)
assert(spec.encodedType.decodedPType(requestedType).virtualType == requestedType)
case x@WriteValue(value, path, spec) =>
case WriteValue(_, path, _, stagingFile) =>
assert(path.typ == TString)
assert(stagingFile.forall(_.typ == TString))
case LiftMeOut(_) =>
case Consume(_) =>
case TableMapRows(child, newRow) =>
Expand Down
1 change: 1 addition & 0 deletions hail/src/test/scala/is/hail/expr/ir/IRSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2828,6 +2828,7 @@ class IRSuite extends HailSuite {
RelationalWriter("path", overwrite = false, None)),
ReadValue(Str("foo"), TypedCodecSpec(PCanonicalStruct("foo" -> PInt32(), "bar" -> PCanonicalString()), BufferSpec.default), TStruct("foo" -> TInt32)),
WriteValue(I32(1), Str("foo"), TypedCodecSpec(PInt32(), BufferSpec.default)),
WriteValue(I32(1), Str("foo"), TypedCodecSpec(PInt32(), BufferSpec.default), Some(Str("/tmp/uid/part"))),
LiftMeOut(I32(1)),
RelationalLet("x", I32(0), I32(0)),
TailLoop("y", IndexedSeq("x" -> I32(0)), Recur("y", FastSeq(I32(4)), TInt32))
Expand Down