Skip to content

Commit

Permalink
[query] Make caching in PartitionNativeIntervalReader more aggressive (
Browse files Browse the repository at this point in the history
…#12600)

* [query] Make caching in PartitionNativeIntervalReader more aggressive

Add finalizers to HailTaskContext to clean up open indices.

* remove log

* comment
  • Loading branch information
tpoterba authored Feb 15, 2023
1 parent d03121f commit 2fbf68b
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 19 deletions.
31 changes: 31 additions & 0 deletions hail/src/main/scala/is/hail/backend/HailTaskContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,26 @@ package is.hail.backend
import is.hail.annotations.RegionPool
import is.hail.utils._

import java.io.Closeable

/**
  * A registry of [[java.io.Closeable]] resources whose cleanup is deferred until
  * [[closeAll]] is invoked.
  *
  * Resources are appended with [[addCloseable]]. [[clear]] drops the registered
  * resources WITHOUT closing them, so a caller that closes a resource itself should
  * call [[clear]] to avoid it being closed a second time by [[closeAll]].
  */
class TaskFinalizer {
  /** Resources registered so far, in registration order. */
  val closeables = new BoxedArrayBuilder[Closeable]()

  /** Forgets every registered resource; nothing is closed by this call. */
  def clear(): Unit = {
    closeables.clear()
  }

  /** Registers `c` so that a later [[closeAll]] closes it. */
  def addCloseable(c: Closeable): Unit = {
    closeables += c
  }

  /**
    * Closes every registered resource in registration order.
    *
    * A non-fatal failure while closing one resource does not prevent the remaining
    * resources from being closed. After all resources have been attempted, the first
    * failure is rethrown with any later failures attached as suppressed exceptions.
    * Fatal errors (as classified by `scala.util.control.NonFatal`) propagate immediately.
    *
    * The registry is not emptied by this call.
    */
  def closeAll(): Unit = {
    var firstFailure: Throwable = null
    var i = 0
    while (i < closeables.size) {
      try {
        closeables(i).close()
      } catch {
        case scala.util.control.NonFatal(e) =>
          if (firstFailure == null)
            firstFailure = e
          else if (e ne firstFailure) // addSuppressed rejects self-suppression
            firstFailure.addSuppressed(e)
      }
      i += 1
    }
    if (firstFailure != null)
      throw firstFailure
  }
}

abstract class HailTaskContext extends AutoCloseable {
def stageId(): Int

Expand All @@ -20,10 +40,21 @@ abstract class HailTaskContext extends AutoCloseable {
s"${ stageId() }-${ partitionId() }-${ attemptNumber() }-$fileUUID"
}

// Every TaskFinalizer handed out by newFinalizer() is recorded here so that close() can run it.
val finalizers = new BoxedArrayBuilder[TaskFinalizer]()

/**
  * Creates a fresh, empty [[TaskFinalizer]], records it in `finalizers`, and returns it
  * to the caller so resources can be registered on it.
  */
def newFinalizer(): TaskFinalizer = {
  val created = new TaskFinalizer()
  finalizers += created
  created
}

/**
  * Logs a usage report for this task, runs every registered [[TaskFinalizer]], and then
  * closes the region pool.
  *
  * A non-fatal failure in one finalizer does not stop the remaining finalizers from
  * running; the first such failure is rethrown afterwards with later ones attached as
  * suppressed exceptions. The pool is closed in a `finally` so it is released even when
  * a finalizer fails.
  */
def close(): Unit = {
  log.info(s"TaskReport: stage=${ stageId() }, partition=${ partitionId() }, attempt=${ attemptNumber() }, " +
    s"peakBytes=${ thePool.getHighestTotalUsage }, peakBytesReadable=${ formatSpace(thePool.getHighestTotalUsage) }, "+
    s"chunks requested=${thePool.getUsage._1}, cache hits=${thePool.getUsage._2}")
  try {
    var firstFailure: Throwable = null
    var i = 0
    while (i < finalizers.size) {
      try {
        finalizers(i).closeAll()
      } catch {
        case scala.util.control.NonFatal(e) =>
          if (firstFailure == null)
            firstFailure = e
          else if (e ne firstFailure) // addSuppressed rejects self-suppression
            firstFailure.addSuppressed(e)
      }
      i += 1
    }
    if (firstFailure != null)
      throw firstFailure
  } finally {
    // always release the pool, even if cleaning up task resources failed
    thePool.close()
  }
}
}
51 changes: 35 additions & 16 deletions hail/src/main/scala/is/hail/expr/ir/TableIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package is.hail.expr.ir
import is.hail.HailContext
import is.hail.annotations._
import is.hail.asm4s._
import is.hail.backend.{ExecuteContext, HailTaskContext}
import is.hail.backend.{ExecuteContext, HailTaskContext, TaskFinalizer}
import is.hail.backend.spark.{SparkBackend, SparkTaskContext}
import is.hail.expr.ir
import is.hail.expr.ir.functions.IntervalFunctions._
Expand Down Expand Up @@ -35,7 +35,7 @@ import org.json4s.JsonAST.JString
import org.json4s.jackson.JsonMethods
import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints}

import java.io.{DataInputStream, DataOutputStream, InputStream}
import java.io.{Closeable, DataInputStream, DataOutputStream, InputStream}
import scala.reflect.ClassTag

object TableIR {
Expand Down Expand Up @@ -774,6 +774,7 @@ case class PartitionNativeIntervalReader(tablePath: String, tableSpec: AbstractT

val currIdxInPartition = mb.genFieldThisRef[Long]("n_to_read")
val stopIdxInPartition = mb.genFieldThisRef[Long]("n_to_read")
val finalizer = mb.genFieldThisRef[TaskFinalizer]("finalizer")

val startPartitionIndex = mb.genFieldThisRef[Int]("start_part")
val currPartitionIdx = mb.genFieldThisRef[Int]("curr_part")
Expand All @@ -782,6 +783,8 @@ case class PartitionNativeIntervalReader(tablePath: String, tableSpec: AbstractT

// leave the index open/initialized to allow queries to reuse the same index for the same file
val indexInitialized = mb.genFieldThisRef[Boolean]("index_init")
val indexCachedIndex = mb.genFieldThisRef[Int]("index_last_idx")
val streamFirst = mb.genFieldThisRef[Boolean]("stream_first")

val region = mb.genFieldThisRef[Region]("pnr_region")

Expand Down Expand Up @@ -811,9 +814,11 @@ case class PartitionNativeIntervalReader(tablePath: String, tableSpec: AbstractT
cb.assign(currPartitionIdx, startPartitionIndex)


cb.assign(indexInitialized, false) // basically "!first"
cb.assign(streamFirst, true)
cb.assign(currIdxInPartition, 0L)
cb.assign(stopIdxInPartition, 0L)

cb.assign(finalizer, cb.emb.ecb.getTaskContext.invoke[TaskFinalizer]("newFinalizer"))
}

override val elementRegion: Settable[Region] = region
Expand All @@ -825,19 +830,34 @@ case class PartitionNativeIntervalReader(tablePath: String, tableSpec: AbstractT
cb.ifx(currPartitionIdx >= rowsSpec.partitioner.numPartitions || currPartitionIdx > lastIncludedPartitionIdx,
cb.goto(LendOfStream))

// open the next index
val requiresIndexInit = cb.newLocal[Boolean]("requiresIndexInit")

cb.ifx(indexInitialized, {
index.close(cb)
cb += ib.close()
cb.ifx(streamFirst, {
// if first, reuse open index from previous time the stream was run if possible
// this is a common case if looking up nearby keys
cb.assign(requiresIndexInit, !(indexInitialized && (indexCachedIndex ceq currPartitionIdx)))
}, {
cb.assign(indexInitialized, true)
// if not first, then the index must be open to the previous partition and needs to be reinitialized
cb.assign(streamFirst, false)
cb.assign(requiresIndexInit, true)
})

val partPath = partitionPathsRuntime.loadElement(cb, currPartitionIdx).get(cb).asString.loadString(cb)
val idxPath = indexPathsRuntime.loadElement(cb, currPartitionIdx).get(cb).asString.loadString(cb)
index.initialize(cb, idxPath)
cb.assign(ib, spec.buildCodeInputBuffer(Code.newInstance[ByteTrackingInputStream, InputStream](cb.emb.open(partPath, false))))
cb.ifx(requiresIndexInit, {
cb.ifx(indexInitialized, {
cb += finalizer.invoke[Unit]("clear")
index.close(cb)
cb += ib.close()
}, {
cb.assign(indexInitialized, true)
})
cb.assign(indexCachedIndex, currPartitionIdx)
val partPath = partitionPathsRuntime.loadElement(cb, currPartitionIdx).get(cb).asString.loadString(cb)
val idxPath = indexPathsRuntime.loadElement(cb, currPartitionIdx).get(cb).asString.loadString(cb)
index.initialize(cb, idxPath)
cb.assign(ib, spec.buildCodeInputBuffer(Code.newInstance[ByteTrackingInputStream, InputStream](cb.emb.open(partPath, false))))
index.addToFinalizer(cb, finalizer)
cb += finalizer.invoke[Closeable, Unit]("addCloseable", ib)
})

cb.ifx(currPartitionIdx ceq lastIncludedPartitionIdx, {
cb.ifx(currPartitionIdx ceq startPartitionIndex, {
Expand Down Expand Up @@ -918,10 +938,9 @@ case class PartitionNativeIntervalReader(tablePath: String, tableSpec: AbstractT
}

override def close(cb: EmitCodeBuilder): Unit = {
cb.ifx(indexInitialized, {
index.close(cb)
cb += ib.close()
})
// no cleanup! leave the index open for the next time the stream is run.
// the task finalizer will clean up the last open index, so this node
// leaks 2 open file handles until the end of the task.
}
}
SStreamValue(producer)
Expand Down
10 changes: 8 additions & 2 deletions hail/src/main/scala/is/hail/io/index/StagedIndexReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@ package is.hail.io.index

import is.hail.annotations._
import is.hail.asm4s._
import is.hail.backend.TaskFinalizer
import is.hail.expr.ir.functions.IntervalFunctions.{arrayOfStructFindIntervalRange, compareStructWithPartitionIntervalEndpoint}
import is.hail.expr.ir.{BinarySearch, EmitCode, EmitCodeBuilder, EmitMethodBuilder, EmitValue, IEmitCode}
import is.hail.io.fs.FS
import is.hail.rvd.AbstractIndexSpec
import is.hail.types.physical.stypes.concrete.{SStackInterval, SStackIntervalValue, SStackStruct, SStackStructSettable, SStackStructValue}
import is.hail.types.physical.stypes.concrete._
import is.hail.types.physical.stypes.interfaces._
import is.hail.types.physical.stypes.{SSettable, SValue}
import is.hail.types.physical.{PCanonicalArray, PCanonicalBaseStruct}
import is.hail.types.virtual.{TInt64, TTuple}
import is.hail.utils._

import java.io.InputStream
import java.io.{Closeable, InputStream}

case class VariableMetadata(
branchingFactor: Int,
Expand Down Expand Up @@ -56,6 +57,11 @@ class StagedIndexReader(emb: EmitMethodBuilder[_], spec: AbstractIndexSpec) {

}

/**
  * Emits code that registers this reader's `cache` and `is` fields with `finalizer`
  * by invoking `addCloseable` on it once for each, so that the finalizer's `closeAll`
  * will close them.
  *
  * These are the same two fields that `close` below tears down directly
  * (`is.close` and `cache.free`).
  *
  * @param cb        code builder the generated statements are appended to
  * @param finalizer staged reference to the [[TaskFinalizer]] that should own the resources
  */
def addToFinalizer(cb: EmitCodeBuilder, finalizer: Value[TaskFinalizer]): Unit = {
// the cache is registered first, then the underlying input stream
cb += finalizer.invoke[Closeable, Unit]("addCloseable", cache)
cb += finalizer.invoke[Closeable, Unit]("addCloseable", is)
}

def close(cb: EmitCodeBuilder): Unit = {
cb += is.invoke[Unit]("close")
cb += cache.invoke[Unit]("free")
Expand Down
5 changes: 4 additions & 1 deletion hail/src/main/scala/is/hail/utils/Cache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package is.hail.utils

import is.hail.annotations.{Region, RegionMemory}

import java.io.Closeable
import java.util
import java.util.Map.Entry

Expand All @@ -17,7 +18,7 @@ class Cache[K, V](capacity: Int) {
def size: Int = synchronized { m.size() }
}

class LongToRegionValueCache(capacity: Int) {
class LongToRegionValueCache(capacity: Int) extends Closeable {
private[this] val m = new util.LinkedHashMap[Long, (RegionMemory, Long)](capacity, 0.75f, true) {
override def removeEldestEntry(eldest: Entry[Long, (RegionMemory, Long)]): Boolean = {
val b = (size() > capacity)
Expand Down Expand Up @@ -50,4 +51,6 @@ class LongToRegionValueCache(capacity: Int) {
m.forEach((k, v) => v._1.release())
m.clear()
}

/** [[java.io.Closeable]] entry point; simply delegates to `free()`. */
def close(): Unit = {
  free()
}
}

0 comments on commit 2fbf68b

Please sign in to comment.