
Commit

[NSE-569] CPU overhead on fine grain / concurrent off-heap acquire operations (#590)

* [NSE-569] CPU overhead on fine grain / concurrent off-heap acquire operations

* fixup

* fixup
zhztheplayer authored Dec 1, 2021
1 parent b6487d7 commit f2fe8eb
Showing 8 changed files with 70 additions and 22 deletions.
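
In summary, this commit makes SparkManagedAllocationListener reserve Spark-managed memory in fixed 8 MB blocks instead of calling the memory consumer for every individual Arrow allocation, and it routes record batches deserialized from native buffers through a dedicated buffer-import allocator that bypasses that listener entirely. The short standalone example below is not part of the commit; it only illustrates the ceiling-to-blocks rounding the listener relies on.

// Standalone illustration (not part of the commit) of the ceiling-to-blocks rounding
// that SparkManagedAllocationListener uses below, with BLOCK_SIZE = 8 MB.
public class BlockRoundingExample {
  static final long BLOCK_SIZE = 8L * 1024 * 1024;

  // Number of whole 8 MB blocks needed to cover `bytes` (0 bytes needs 0 blocks).
  static long blocksFor(long bytes) {
    return bytes == 0L ? 0L : (bytes - 1L) / BLOCK_SIZE + 1L;
  }

  public static void main(String[] args) {
    System.out.println(blocksFor(1));              // 1: even a 1-byte reservation claims a full block
    System.out.println(blocksFor(BLOCK_SIZE));     // 1: exactly one block
    System.out.println(blocksFor(BLOCK_SIZE + 1)); // 2: crossing a block boundary claims another block
  }
}
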
@@ -20,24 +20,62 @@
import org.apache.arrow.memory.AllocationListener;

public class SparkManagedAllocationListener implements AllocationListener {
public static long BLOCK_SIZE = 8L * 1024 * 1024; // 8MB per block

private final NativeSQLMemoryConsumer consumer;
private final NativeSQLMemoryMetrics metrics;

private long bytesReserved = 0L;
private long blocksReserved = 0L;

public SparkManagedAllocationListener(NativeSQLMemoryConsumer consumer, NativeSQLMemoryMetrics metrics) {
this.consumer = consumer;
this.metrics = metrics;
}

@Override
public void onPreAllocation(long size) {
consumer.acquire(size);
metrics.inc(size);
long requiredBlocks = updateReservation(size);
if (requiredBlocks < 0) {
throw new IllegalStateException();
}
if (requiredBlocks == 0) {
return;
}
long toBeAcquired = requiredBlocks * BLOCK_SIZE;
consumer.acquire(toBeAcquired);
metrics.inc(toBeAcquired);
}

@Override
public void onRelease(long size) {
consumer.free(size);
metrics.inc(-size);
long requiredBlocks = updateReservation(-size);
if (requiredBlocks > 0) {
throw new IllegalStateException();
}
if (requiredBlocks == 0) {
return;
}
long toBeReleased = -requiredBlocks * BLOCK_SIZE;
consumer.free(toBeReleased);
metrics.inc(-toBeReleased);
}

public long updateReservation(long bytesToAdd) {
synchronized (this) {
long newBytesReserved = bytesReserved + bytesToAdd;
final long newBlocksReserved;
// ceiling
if (newBytesReserved == 0L) {
// 0 is the special case in ceiling algorithm
newBlocksReserved = 0L;
} else {
newBlocksReserved = (newBytesReserved - 1L) / BLOCK_SIZE + 1L;
}
long requiredBlocks = newBlocksReserved - blocksReserved;
bytesReserved = newBytesReserved;
blocksReserved = newBlocksReserved;
return requiredBlocks;
}
}
}
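
To see how this amortizes the per-allocation overhead, the following standalone trace applies the same bookkeeping to ten 1 MB allocations followed by one bulk release; the consumer and metrics calls are replaced with printlns, so this is an illustration rather than project code.

// A minimal, standalone trace of the reservation bookkeeping above. It reuses the same
// rounding logic but replaces the NativeSQLMemoryConsumer/metrics calls with printlns.
public class ReservationTrace {
  static final long BLOCK_SIZE = 8L * 1024 * 1024; // matches SparkManagedAllocationListener.BLOCK_SIZE
  static long bytesReserved = 0L;
  static long blocksReserved = 0L;

  // Same ceiling-based bookkeeping as updateReservation() in the listener above.
  static long updateReservation(long bytesToAdd) {
    long newBytesReserved = bytesReserved + bytesToAdd;
    long newBlocksReserved =
        newBytesReserved == 0L ? 0L : (newBytesReserved - 1L) / BLOCK_SIZE + 1L;
    long requiredBlocks = newBlocksReserved - blocksReserved;
    bytesReserved = newBytesReserved;
    blocksReserved = newBlocksReserved;
    return requiredBlocks;
  }

  public static void main(String[] args) {
    long oneMb = 1024L * 1024L;
    for (int i = 1; i <= 10; i++) {
      long blocks = updateReservation(oneMb); // ten fine-grained 1 MB "allocations"
      if (blocks > 0) {
        // Only allocations #1 and #9 cross an 8 MB block boundary, so the Spark
        // memory consumer would be hit twice instead of ten times.
        System.out.println("alloc #" + i + ": acquire " + blocks * BLOCK_SIZE + " bytes");
      }
    }
    long released = updateReservation(-10L * oneMb); // release everything at once
    System.out.println("release: free " + (-released * BLOCK_SIZE) + " bytes"); // 16 MB total
  }
}
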
@@ -72,6 +72,12 @@ object SparkMemoryUtils extends Logging {
sparkManagedAllocationListener
}

val allocListenerForBufferImport: AllocationListener = if (isArrowAutoReleaseEnabled) {
MemoryChunkCleaner.gcTrigger()
} else {
AllocationListener.NOOP
}

private def collectStackForDebug = {
if (DEBUG) {
val out = new ByteOutputStream()
@@ -99,6 +105,10 @@
alloc
}

val taskDefaultAllocatorForBufferImport: BufferAllocator = taskDefaultAllocator
.newChildAllocator("CHILD-ALLOC-BUFFER-IMPORT", allocListenerForBufferImport, 0L,
Long.MaxValue)

val defaultMemoryPool: NativeMemoryPoolWrapper = {
val rl = new SparkManagedReservationListener(
new NativeSQLMemoryConsumer(getTaskMemoryManager(), Spiller.NO_OP),
@@ -283,6 +293,13 @@
getTaskMemoryResources().taskDefaultAllocator
}

def contextAllocatorForBufferImport(): BufferAllocator = {
if (!inSparkTask()) {
return globalAllocator()
}
getTaskMemoryResources().taskDefaultAllocatorForBufferImport
}

def contextMemoryPool(): NativeMemoryPool = {
if (!inSparkTask()) {
return globalMemoryPool()
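
The Scala changes above add a second, parallel allocator path for buffers imported from native code: its allocation listener is MemoryChunkCleaner.gcTrigger() when Arrow auto-release is enabled and AllocationListener.NOOP otherwise, so imported buffers are not accounted through SparkManagedAllocationListener. The sketch below shows the resulting call-site pattern; the helper class and its boolean flag are hypothetical, introduced only for illustration.

import org.apache.arrow.memory.BufferAllocator;
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils;

// Hypothetical helper, for illustration only; it is not part of this commit.
public class AllocatorSelectionSketch {
  // Batches whose buffers are imported from native code go through the dedicated
  // buffer-import allocator; other Arrow allocations keep the Spark-managed one.
  static BufferAllocator allocatorFor(boolean importedFromNative) {
    return importedFromNative
        ? SparkMemoryUtils.contextAllocatorForBufferImport()
        : SparkMemoryUtils.contextAllocator();
  }
}
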
@@ -32,6 +32,7 @@
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils;

/** Parquet Reader Class. */
public class ParquetReader implements AutoCloseable {
@@ -41,7 +42,6 @@ public class ParquetReader implements AutoCloseable {
/** last readed length of a record batch. */
private long lastReadLength;

private BufferAllocator allocator;
private ParquetReaderJniWrapper jniWrapper;

/**
@@ -51,13 +51,11 @@
* @param rowGroupIndices An array to indicate which rowGroup to read.
* @param columnIndices An array to indicate which columns to read.
* @param batchSize number of rows expected to be read in one batch.
* @param allocator A BufferAllocator reference.
* @throws IOException throws io exception in case of native failure.
*/
public ParquetReader(String path, int[] rowGroupIndices, int[] columnIndices,
long batchSize, BufferAllocator allocator, String tmp_dir) throws IOException {
long batchSize, String tmp_dir) throws IOException {
this.jniWrapper = new ParquetReaderJniWrapper(tmp_dir);
this.allocator = allocator;
this.nativeInstanceId = jniWrapper.nativeOpenParquetReader(path, batchSize);
jniWrapper.nativeInitParquetReader(nativeInstanceId, columnIndices, rowGroupIndices);
}
@@ -76,7 +74,6 @@ public ParquetReader(String path, int[] rowGroupIndices, int[] columnIndices,
public ParquetReader(String path, long startPos, long endPos, int[] columnIndices,
long batchSize, BufferAllocator allocator, String tmp_dir) throws IOException {
this.jniWrapper = new ParquetReaderJniWrapper(tmp_dir);
this.allocator = allocator;
this.nativeInstanceId = jniWrapper.nativeOpenParquetReader(path, batchSize);
jniWrapper.nativeInitParquetReader2(
nativeInstanceId, columnIndices, startPos, endPos);
@@ -93,7 +90,7 @@ public Schema getSchema() throws IOException {

try (MessageChannelReader schemaReader = new MessageChannelReader(
new ReadChannel(new ByteArrayReadableSeekableByteChannel(schemaBytes)),
allocator)) {
SparkMemoryUtils.contextAllocator())) {
MessageResult result = schemaReader.readNext();
if (result == null) {
throw new IOException("Unexpected end of input. Missing schema.");
@@ -115,8 +112,8 @@ public ArrowRecordBatch readNext() throws IOException {
if (serializedBatch == null) {
return null;
}
ArrowRecordBatch batch = UnsafeRecordBatchSerializer.deserializeUnsafe(allocator,
serializedBatch);
ArrowRecordBatch batch = UnsafeRecordBatchSerializer.deserializeUnsafe(
SparkMemoryUtils.contextAllocatorForBufferImport(), serializedBatch);
if (batch == null) {
throw new IllegalArgumentException("failed to build record batch");
}
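
Since the allocator argument is gone from the first constructor, ParquetReader now pulls its allocators from SparkMemoryUtils internally. The following usage sketch assumes it lives in the same package as ParquetReader; the file path, indices, batch size, and temp directory are placeholders, not values from this commit.

import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Schema;

// Hypothetical usage sketch; assumes it compiles in the same package as ParquetReader.
public class ParquetReaderUsage {
  public static void main(String[] args) throws Exception {
    try (ParquetReader reader = new ParquetReader(
        "/tmp/example.parquet", // path (placeholder)
        new int[] {0},          // rowGroupIndices
        new int[] {0, 1},       // columnIndices
        4096L,                  // batchSize
        "/tmp")) {              // tmp_dir -- note: no BufferAllocator argument anymore
      Schema schema = reader.getSchema();
      System.out.println("schema: " + schema);
      ArrowRecordBatch batch;
      while ((batch = reader.readNext()) != null) {
        batch.close(); // release the buffers imported for this batch
      }
    }
  }
}
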
@@ -68,7 +68,7 @@ public boolean hasNext() throws IOException {
}

public ArrowRecordBatch next() throws IOException {
BufferAllocator allocator = SparkMemoryUtils.contextAllocator();
BufferAllocator allocator = SparkMemoryUtils.contextAllocatorForBufferImport();
if (nativeHandler == 0) {
return null;
}
@@ -132,7 +132,7 @@ public ArrowRecordBatch process(Schema schema, ArrowRecordBatch recordBatch,
if (nativeHandler == 0) {
return null;
}
BufferAllocator allocator = SparkMemoryUtils.contextAllocator();
BufferAllocator allocator = SparkMemoryUtils.contextAllocatorForBufferImport();
byte[] serializedRecordBatch;
if (selectionVector != null) {
int selectionVectorRecordCount = selectionVector.getRecordCount();
@@ -159,7 +159,7 @@ public void evaluate(ColumnarNativeIterator batchItr)
public ArrowRecordBatch[] evaluate2(ArrowRecordBatch recordBatch) throws RuntimeException, IOException {
byte[] bytes = UnsafeRecordBatchSerializer.serializeUnsafe(recordBatch);
byte[][] serializedBatchArray = jniWrapper.nativeEvaluate2(nativeHandler, bytes);
BufferAllocator allocator = SparkMemoryUtils.contextAllocator();
BufferAllocator allocator = SparkMemoryUtils.contextAllocatorForBufferImport();
ArrowRecordBatch[] recordBatchList = new ArrowRecordBatch[serializedBatchArray.length];
for (int i = 0; i < serializedBatchArray.length; i++) {
if (serializedBatchArray[i] == null) {
@@ -191,7 +191,7 @@ public ArrowRecordBatch[] evaluate(ArrowRecordBatch recordBatch, SelectionVector
bufSizes[idx++] = bufLayout.getSize();
}

BufferAllocator allocator = SparkMemoryUtils.contextAllocator();
BufferAllocator allocator = SparkMemoryUtils.contextAllocatorForBufferImport();

byte[][] serializedBatchArray;
if (selectionVector != null) {
@@ -237,7 +237,7 @@ public void SetMember(ArrowRecordBatch recordBatch) throws RuntimeException, IOE
}

public ArrowRecordBatch[] finish() throws RuntimeException, IOException {
BufferAllocator allocator = SparkMemoryUtils.contextAllocator();
BufferAllocator allocator = SparkMemoryUtils.contextAllocatorForBufferImport();
byte[][] serializedBatchArray = jniWrapper.nativeFinish(nativeHandler);
ArrowRecordBatch[] recordBatchList = new ArrowRecordBatch[serializedBatchArray.length];
for (int i = 0; i < serializedBatchArray.length; i++) {
@@ -67,7 +67,7 @@ private class ArrowColumnarBatchSerializerInstance(
SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)

private val allocator: BufferAllocator = SparkMemoryUtils
.contextAllocator()
.contextAllocatorForBufferImport()
.newChildAllocator("ArrowColumnarBatch deserialize", 0, Long.MaxValue)

private var reader: ArrowStreamReader = _
@@ -100,7 +100,6 @@ class PartitioningSuite extends QueryTest with SharedSparkSession {
val df = spark.sql("SELECT COUNT(*) AS cnt FROM ltab, rtab WHERE ltab.id = rtab.id")
df.explain(true)
df.show()
Thread.sleep(1000000)
}
}

@@ -132,7 +132,6 @@ class TPCDSSuite extends QueryTest with SharedSparkSession {
"ws_item_sk = i_item_sk LIMIT 10")
df.explain(true)
df.show()
Thread.sleep(1000000)
}
}

Expand All @@ -142,7 +141,6 @@ class TPCDSSuite extends QueryTest with SharedSparkSession {
"web_sales) LIMIT 10")
df.explain()
df.show()
Thread.sleep(1000000)
}
}

@@ -165,7 +163,6 @@
)
df.explain(true)
df.show()
Thread.sleep(1000000)
}
}

