Skip to content

Commit

Permalink
[NSE-518] Arrow buffer cleanup: Support both manual release and auto release as a hybrid mode (oap-project#519)
Browse files Browse the repository at this point in the history

* [NSE-518] Arrow buffer cleanup: Support both manual release and auto release as a hybrid mode

* TO BE REVERTED

* fixup

* fix

* fixup

* fixup

* fixup

* fixup

* fixup

Closes oap-project#518
  • Loading branch information
zhztheplayer authored and rui-mo committed Oct 12, 2021
1 parent ecc1f68 commit c2935a2
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,15 @@ import com.intel.oap.spark.sql.execution.datasources.v2.arrow._
import com.sun.xml.internal.messaging.saaj.util.ByteOutputStream
import org.apache.arrow.dataset.jni.NativeMemoryPool
import org.apache.arrow.memory.AllocationListener
import org.apache.arrow.memory.AllocationOutcome
import org.apache.arrow.memory.AutoBufferLedger
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.memory.BufferLedger
import org.apache.arrow.memory.DirectAllocationListener
import org.apache.arrow.memory.ImmutableConfig
import org.apache.arrow.memory.LegacyBufferLedger
import org.apache.arrow.memory.MemoryChunkCleaner
import org.apache.arrow.memory.MemoryChunkManager
import org.apache.arrow.memory.RootAllocator

import org.apache.spark.SparkEnv
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.TaskCompletionListener
Expand All @@ -46,36 +44,6 @@ object SparkMemoryUtils extends Logging {

private val DEBUG: Boolean = false

/**
 * Composite [[AllocationListener]] that relays every callback to each wrapped listener,
 * visiting them in the order they were passed to the constructor.
 *
 * @param listeners the delegate listeners to notify
 */
class AllocationListenerList(listeners: AllocationListener *)
  extends AllocationListener {

  /** Forwards the pre-allocation notification to every delegate. */
  override def onPreAllocation(size: Long): Unit =
    for (delegate <- listeners) delegate.onPreAllocation(size)

  /** Forwards the allocation notification to every delegate. */
  override def onAllocation(size: Long): Unit =
    for (delegate <- listeners) delegate.onAllocation(size)

  /** Forwards the release notification to every delegate. */
  override def onRelease(size: Long): Unit =
    for (delegate <- listeners) delegate.onRelease(size)

  /**
   * Asks the delegates one by one and returns true only if all of them return true.
   * Evaluation stops at the first delegate that returns false, so delegates after it
   * are not consulted; with no delegates the result is true.
   */
  override def onFailedAllocation(size: Long, outcome: AllocationOutcome): Boolean =
    !listeners.exists(delegate => !delegate.onFailedAllocation(size, outcome))

  /** Forwards the child-added notification to every delegate. */
  override def onChildAdded(
      parentAllocator: BufferAllocator,
      childAllocator: BufferAllocator): Unit =
    for (delegate <- listeners) delegate.onChildAdded(parentAllocator, childAllocator)

  /** Forwards the child-removed notification to every delegate. */
  override def onChildRemoved(
      parentAllocator: BufferAllocator,
      childAllocator: BufferAllocator): Unit =
    for (delegate <- listeners) delegate.onChildRemoved(parentAllocator, childAllocator)
}

class TaskMemoryResources {
if (!inSparkTask()) {
throw new IllegalStateException("Creating TaskMemoryResources instance out of Spark task")
Expand All @@ -88,19 +56,18 @@ object SparkMemoryUtils extends Logging {
.getConfString("spark.oap.sql.columnar.autorelease", "false").toBoolean
}

val ledgerFactory: BufferLedger.Factory = if (isArrowAutoReleaseEnabled) {
AutoBufferLedger.newFactory()
val memoryChunkManagerFactory: MemoryChunkManager.Factory = if (isArrowAutoReleaseEnabled) {
MemoryChunkCleaner.newFactory(MemoryChunkCleaner.Mode.HYBRID_WITH_LOG)
} else {
LegacyBufferLedger.FACTORY
MemoryChunkManager.FACTORY
}

val sparkManagedAllocationListener = new SparkManagedAllocationListener(
new NativeSQLMemoryConsumer(getTaskMemoryManager(), Spiller.NO_OP),
sharedMetrics)
val directAllocationListener = DirectAllocationListener.INSTANCE

val allocListener: AllocationListener = if (isArrowAutoReleaseEnabled) {
new AllocationListenerList(sparkManagedAllocationListener, directAllocationListener)
MemoryChunkCleaner.gcTrigger(sparkManagedAllocationListener)
} else {
sparkManagedAllocationListener
}
Expand All @@ -121,12 +88,13 @@ object SparkMemoryUtils extends Logging {

private val memoryPools = new util.ArrayList[NativeMemoryPoolWrapper]()

val defaultAllocator: BufferAllocator = {
val alloc = new RootAllocator(ImmutableConfig.builder()
.maxAllocation(Long.MaxValue)
.bufferLedgerFactory(ledgerFactory)
.listener(allocListener)
.build)
val taskDefaultAllocator: BufferAllocator = {
val alloc = new RootAllocator(
RootAllocator.configBuilder()
.maxAllocation(Long.MaxValue)
.memoryChunkManagerFactory(memoryChunkManagerFactory)
.listener(allocListener)
.build)
allocators.add(alloc)
alloc
}
Expand Down Expand Up @@ -154,7 +122,7 @@ object SparkMemoryUtils extends Logging {
val al = new SparkManagedAllocationListener(
new NativeSQLMemoryConsumer(getTaskMemoryManager(), spiller),
sharedMetrics)
val parent = defaultAllocator
val parent = taskDefaultAllocator
val alloc = parent.newChildAllocator("Spark Managed Allocator - " +
UUID.randomUUID().toString, al, 0, parent.getLimit).asInstanceOf[BufferAllocator]
allocators.add(alloc)
Expand Down Expand Up @@ -198,7 +166,7 @@ object SparkMemoryUtils extends Logging {
}

def release(): Unit = {
ledgerFactory match {
memoryChunkManagerFactory match {
case closeable: AutoCloseable =>
closeable.close()
case _ =>
Expand All @@ -208,7 +176,11 @@ object SparkMemoryUtils extends Logging {
if (allocated == 0L) {
close(allocator)
} else {
softClose(allocator)
if (isArrowAutoReleaseEnabled) {
close(allocator)
} else {
softClose(allocator)
}
}
}
for (pool <- memoryPools.asScala) {
Expand Down Expand Up @@ -271,15 +243,19 @@ object SparkMemoryUtils extends Logging {
}
}

private val allocator = new RootAllocator(
ImmutableConfig.builder()
.maxAllocation(Long.MaxValue)
.bufferLedgerFactory(AutoBufferLedger.newFactory())
.listener(DirectAllocationListener.INSTANCE)
// Value of the MEMORY_OFFHEAP_SIZE entry read from the active SparkEnv's configuration;
// used as the allocation cap of the global allocator.
// NOTE(review): presumably this is spark.memory.offHeap.size in bytes -- confirm what
// value is returned when off-heap memory is not configured.
private val maxAllocationSize = SparkEnv.get.conf.get(MEMORY_OFFHEAP_SIZE)

// Root allocator handed out by globalAllocator(). It is capped at maxAllocationSize and
// wired with a MemoryChunkCleaner chunk-manager factory plus the cleaner's GC-trigger
// listener.
private val globalAlloc = {
  val globalConfig = RootAllocator.configBuilder()
    .maxAllocation(maxAllocationSize)
    .memoryChunkManagerFactory(MemoryChunkCleaner.newFactory())
    .listener(MemoryChunkCleaner.gcTrigger())
    .build
  new RootAllocator(globalConfig)
}

def globalAllocator(): BufferAllocator = {
allocator
globalAlloc
}

def globalMemoryPool(): NativeMemoryPool = {
Expand All @@ -304,7 +280,7 @@ object SparkMemoryUtils extends Logging {
if (!inSparkTask()) {
return globalAllocator()
}
getTaskMemoryResources().defaultAllocator
getTaskMemoryResources().taskDefaultAllocator
}

def contextMemoryPool(): NativeMemoryPool = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public static JniUtils getInstance(String tmp_dir) throws IOException {
try {
INSTANCE = new JniUtils(tmp_dir);
} catch (IllegalAccessException ex) {
throw new IOException("IllegalAccess");
throw new IOException("IllegalAccess", ex);
}
}
}
Expand Down

0 comments on commit c2935a2

Please sign in to comment.