diff --git a/.travis.yml b/.travis.yml index c19abb3a..e3bafd2d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,5 +8,4 @@ jobs: install: - #empty install step script: - - cd ${TRAVIS_BUILD_DIR}/oap-shuffle/remote-shuffle/ - - mvn -q test + - mvn -q clean install diff --git a/pom.xml b/pom.xml index 79d831e1..c5ea1e29 100644 --- a/pom.xml +++ b/pom.xml @@ -259,8 +259,8 @@ - shuffle-hadoop shuffle-daos + shuffle-hadoop diff --git a/shuffle-daos/pom.xml b/shuffle-daos/pom.xml index 91588e6f..9e4f28a8 100644 --- a/shuffle-daos/pom.xml +++ b/shuffle-daos/pom.xml @@ -121,6 +121,9 @@ org.scalatest scalatest-maven-plugin + + -Xmx2048m + test diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReader.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReader.java index fde2ac56..20a48bff 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReader.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReader.java @@ -24,221 +24,185 @@ package org.apache.spark.shuffle.daos; import io.daos.obj.DaosObject; -import io.daos.obj.IODataDesc; -import io.netty.util.internal.ObjectPool; +import io.netty.buffer.ByteBuf; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; +import org.apache.spark.shuffle.ShuffleReadMetricsReporter; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManagerId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Tuple2; +import scala.Tuple3; +import java.io.IOException; +import java.util.LinkedHashMap; import java.util.Map; -import java.util.concurrent.*; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.locks.Condition; -import java.util.concurrent.locks.Lock; /** - * A class with {@link DaosObject} wrapped to read data from DAOS in either caller's thread or - * dedicated executor thread. The actual read is performed by {@link DaosObject#fetch(IODataDesc)}. + * A abstract class with {@link DaosObject} wrapped to read data from DAOS. */ -public class DaosReader { +public interface DaosReader { - private DaosObject object; - private Map bufferSourceMap = new ConcurrentHashMap<>(); - - private BoundThreadExecutors executors; - - private Map readerMap; - - private static Logger logger = LoggerFactory.getLogger(DaosReader.class); + DaosObject getObject(); /** - * construct DaosReader with object and dedicated read executors. + * release resources bound with this reader. * - * @param object - * opened DaosObject - * @param executors - * null means read in caller's thread. Submit {@link ReadTask} to dedicate executor retrieved by - * {@link #nextReaderExecutor()} otherwise. + * @param force + * force close even if there is on-going read */ - public DaosReader(DaosObject object, BoundThreadExecutors executors) { - this.object = object; - this.executors = executors; - } - - public DaosObject getObject() { - return object; - } - - public boolean hasExecutors() { - return executors != null; - } + void close(boolean force); /** - * next executor. null if there is no executors being set. + * set global readMap and hook this reader for releasing resources. * - * @return shareable executor instance. null means no executor set. + * @param readerMap + * global reader map */ - public BoundThreadExecutors.SingleThreadExecutor nextReaderExecutor() { - if (executors != null) { - return executors.nextExecutor(); - } - return null; - } + void setReaderMap(Map readerMap); /** - * release resources of all {@link org.apache.spark.shuffle.daos.DaosShuffleInputStream.BufferSource} - * bound with this reader. + * prepare read with some parameters. + * + * @param partSizeMap + * @param maxBytesInFlight + * how many bytes can be read concurrently + * @param maxReqSizeShuffleToMem + * maximum data can be put in memory + * @param metrics + * @return */ - public void close() { - // force releasing - bufferSourceMap.forEach((k, v) -> k.cleanup(true)); - bufferSourceMap.clear(); - if (readerMap != null) { - readerMap.remove(this); - readerMap = null; - } - } - - @Override - public String toString() { - return "DaosReader{" + - "object=" + object + - '}'; - } + void prepare(LinkedHashMap, Tuple3> partSizeMap, + long maxBytesInFlight, long maxReqSizeShuffleToMem, ShuffleReadMetricsReporter metrics); /** - * register buffer source for resource cleanup. + * current map/reduce id being requested. * - * @param source - * BufferSource instance + * @return map/reduce id tuple */ - public void register(DaosShuffleInputStream.BufferSource source) { - bufferSourceMap.put(source, 1); - } + Tuple2 curMapReduceId(); /** - * unregister buffer source if source is release already. + * get available buffer after iterating current buffer, next buffer in current desc and next desc. * - * @param source - * BufferSource instance + * @return buffer with data read from DAOS + * @throws IOException */ - public void unregister(DaosShuffleInputStream.BufferSource source) { - bufferSourceMap.remove(source); - } + ByteBuf nextBuf() throws IOException; /** - * set global readMap and hook this reader for releasing resources. + * All data from current map output is read and + * reach to data from next map? * - * @param readerMap - * global reader map + * @return true or false */ - public void setReaderMap(Map readerMap) { - readerMap.put(this, 0); - this.readerMap = readerMap; - } + boolean isNextMap(); /** - * Task to read from DAOS. Task itself is cached to reduce GC time. - * To reuse task for different reads, prepare and reset {@link ReadTaskContext} by calling - * {@link #newInstance(ReadTaskContext)} + * upper layer should call this method to read more map output */ - static final class ReadTask implements Runnable { - private ReadTaskContext context; - private final ObjectPool.Handle handle; + void setNextMap(boolean b); - private static final ObjectPool objectPool = ObjectPool.newPool(handle -> new ReadTask(handle)); + /** + * check if all data from current map output is read. + */ + void checkPartitionSize() throws IOException; - private static final Logger log = LoggerFactory.getLogger(ReadTask.class); + /** + * check if all map outputs are read. + * + * @throws IOException + */ + void checkTotalPartitions() throws IOException; - static ReadTask newInstance(ReadTaskContext context) { - ReadTask task = objectPool.get(); - task.context = context; - return task; + /** + * reader configurations, please check configs prefixed with SHUFFLE_DAOS_READ in {@link package$#MODULE$}. + */ + final class ReaderConfig { + private long minReadSize; + private long maxBytesInFlight; + private long maxMem; + private int readBatchSize; + private int waitDataTimeMs; + private int waitTimeoutTimes; + private boolean fromOtherThread; + + private static final Logger log = LoggerFactory.getLogger(ReaderConfig.class); + + public ReaderConfig() { + this(true); } - private ReadTask(ObjectPool.Handle handle) { - this.handle = handle; + private ReaderConfig(boolean load) { + if (load) { + initialize(); + } } - @Override - public void run() { - boolean cancelled = context.cancelled; - try { - if (!cancelled) { - context.object.fetch(context.desc); - } - } catch (Exception e) { - log.error("failed to read for " + context.desc, e); - } finally { - // release desc buffer and keep data buffer - context.desc.release(cancelled); - context.signal(); - context = null; - handle.recycle(this); + private void initialize() { + SparkConf conf = SparkEnv.get().conf(); + minReadSize = (long)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_MINIMUM_SIZE()) * 1024; + this.maxBytesInFlight = -1L; + this.maxMem = -1L; + this.readBatchSize = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_BATCH_SIZE()); + this.waitDataTimeMs = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_WAIT_DATA_MS()); + this.waitTimeoutTimes = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_WAIT_DATA_TIMEOUT_TIMES()); + this.fromOtherThread = (boolean)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_FROM_OTHER_THREAD()); + if (log.isDebugEnabled()) { + log.debug("minReadSize: " + minReadSize); + log.debug("maxBytesInFlight: " + maxBytesInFlight); + log.debug("maxMem: " + maxMem); + log.debug("readBatchSize: " + readBatchSize); + log.debug("waitDataTimeMs: " + waitDataTimeMs); + log.debug("waitTimeoutTimes: " + waitTimeoutTimes); + log.debug("fromOtherThread: " + fromOtherThread); } } - } - /** - * Context for read task. It holds all other object to read and sync between caller thread and read thread. - * It should be cached in caller thread for reusing. - */ - static final class ReadTaskContext extends LinkedTaskContext { - - /** - * constructor with all parameters. Some of them can be reused later. - * - * @param object - * DAOS object to fetch data from DAOS - * @param counter - * counter to indicate how many data ready for being consumed - * @param takeLock - * lock to work with notEmpty condition to signal caller thread there is data ready to be consumed - * @param notEmpty - * condition to signal there is some data ready - * @param desc - * desc object to describe which part of data to be fetch and hold returned data - * @param mapReduceId - * to track which map reduce ID this task fetches data for - */ - ReadTaskContext(DaosObject object, AtomicInteger counter, Lock takeLock, Condition notEmpty, - IODataDesc desc, Object mapReduceId) { - super(object, counter, takeLock, notEmpty); - this.desc = desc; - this.morePara = mapReduceId; + public ReaderConfig copy(long maxBytesInFlight, long maxMem) { + ReaderConfig rc = new ReaderConfig(false); + rc.maxMem = maxMem; + rc.minReadSize = minReadSize; + rc.readBatchSize = readBatchSize; + rc.waitDataTimeMs = waitDataTimeMs; + rc.waitTimeoutTimes = waitTimeoutTimes; + rc.fromOtherThread = fromOtherThread; + if (maxBytesInFlight < rc.minReadSize) { + rc.maxBytesInFlight = minReadSize; + } else { + rc.maxBytesInFlight = maxBytesInFlight; + } + return rc; } - @Override - public ReadTaskContext getNext() { - return (ReadTaskContext) next; + public int getReadBatchSize() { + return readBatchSize; } - public Tuple2 getMapReduceId() { - return (Tuple2) morePara; + public int getWaitDataTimeMs() { + return waitDataTimeMs; } - } - /** - * Thread factory for DAOS read tasks. - */ - protected static class ReadThreadFactory implements ThreadFactory { - private AtomicInteger id = new AtomicInteger(0); - - @Override - public Thread newThread(Runnable runnable) { - Thread t; - String name = "daos_read_" + id.getAndIncrement(); - if (runnable == null) { - t = new Thread(name); - } else { - t = new Thread(runnable, name); - } - t.setDaemon(true); - t.setUncaughtExceptionHandler((thread, throwable) -> - logger.error("exception occurred in thread " + name, throwable)); - return t; + public int getWaitTimeoutTimes() { + return waitTimeoutTimes; + } + + public long getMaxBytesInFlight() { + return maxBytesInFlight; } - } + public long getMaxMem() { + return maxMem; + } + + public long getMinReadSize() { + return minReadSize; + } + + public boolean isFromOtherThread() { + return fromOtherThread; + } + } } diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderSync.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderSync.java new file mode 100644 index 00000000..19974540 --- /dev/null +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosReaderSync.java @@ -0,0 +1,657 @@ +/* + * (C) Copyright 2018-2020 Intel Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * GOVERNMENT LICENSE RIGHTS-OPEN SOURCE SOFTWARE + * The Government's rights to use, modify, reproduce, release, perform, display, + * or disclose this software are subject to the terms of the Apache License as + * provided in Contract No. B609815. + * Any reproduction of computer software, computer software documentation, or + * portions thereof marked with this legend must also reproduce the markings. + */ + +package org.apache.spark.shuffle.daos; + +import io.daos.obj.DaosObject; +import io.daos.obj.IODataDesc; +import io.netty.buffer.ByteBuf; +import io.netty.util.internal.ObjectPool; +import org.apache.spark.shuffle.ShuffleReadMetricsReporter; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManagerId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Tuple2; +import scala.Tuple3; + +import java.io.IOException; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; + +/** + * A class with {@link DaosObject} wrapped to read data from DAOS in either caller's thread or + * dedicated executor thread. The actual read is performed by {@link DaosObject#fetch(IODataDesc)}. + * + * User just calls {@link #nextBuf()} and reads from buffer repeatedly until no buffer returned. + * Reader determines when and how (caller thread or from dedicated thread) based on config, to read from DAOS + * as well as controlling buffer size and task batch size. It also has some fault tolerance ability, like + * re-read from caller thread if task doesn't respond from the dedicated threads. + */ +public class DaosReaderSync extends TaskSubmitter implements DaosReader { + + private DaosObject object; + + private Map readerMap; + + private ReaderConfig config; + + protected LinkedHashMap, Tuple3> partSizeMap; + + private Iterator> mapIdIt; + + private ShuffleReadMetricsReporter metrics; + + protected long currentPartSize; + + protected Tuple2 curMapReduceId; + protected Tuple2 lastMapReduceIdForSubmit; + protected Tuple2 lastMapReduceIdForReturn; + protected int curOffset; + protected boolean nextMap; + + protected int totalParts; + protected int partsRead; + + private ReadTaskContext selfCurrentCtx; + private IODataDesc currentDesc; + private IODataDesc.Entry currentEntry; + + private boolean fromOtherThread; + + private int entryIdx; + + private static Logger logger = LoggerFactory.getLogger(DaosReader.class); + + /** + * construct DaosReader with object and dedicated read executors. + * + * @param object + * opened DaosObject + * @param config + * reader configuration + * @param executor + * single thread executor + */ + public DaosReaderSync(DaosObject object, ReaderConfig config, BoundThreadExecutors.SingleThreadExecutor executor) { + super(executor); + this.object = object; + this.config = config; + this.fromOtherThread = config.isFromOtherThread(); + if (fromOtherThread && executor == null) { + throw new IllegalArgumentException("executor should not be null if read from other thread"); + } + } + + @Override + public DaosObject getObject() { + return object; + } + + @Override + public void close(boolean force) { + boolean allReleased = true; + allReleased &= cleanupSubmitted(force); + allReleased &= cleanupConsumed(force); + if (allReleased) { + if (readerMap != null) { + readerMap.remove(this); + readerMap = null; + } + } + } + + @Override + public void setReaderMap(Map readerMap) { + readerMap.put(this, 0); + this.readerMap = readerMap; + } + + public boolean hasExecutors() { + return executor != null; + } + + /** + * invoke this method when fromOtherThread is false. + * + * @return + * @throws {@link IOException} + */ + public ByteBuf readBySelf() throws IOException { + if (lastCtx != null) { // duplicated IODataDescs which were submitted to other thread, but cancelled + ByteBuf buf = readDuplicated(false); + if (buf != null) { + return buf; + } + } + // all submitted were duplicated. Now start from mapId iterator. + IODataDesc desc = createNextDesc(config.getMaxBytesInFlight()); + return getBySelf(desc, lastMapReduceIdForSubmit); + } + + /** + * get available buffer after iterating current buffer, next buffer in current desc and next desc. + * + * @return buffer with data read from DAOS + * @throws IOException + */ + public ByteBuf nextBuf() throws IOException { + ByteBuf buf = tryCurrentEntry(); + if (buf != null) { + return buf; + } + // next entry + buf = tryCurrentDesc(); + if (buf != null) { + return buf; + } + // from next partition + if (fromOtherThread) { + // next ready queue + if (headCtx != null) { + return tryNextTaskContext(); + } + // get data by self and submit request for remaining data + return getBySelfAndSubmitMore(config.getMinReadSize()); + } + // get data by self after fromOtherThread disabled + return readBySelf(); + } + + @Override + public boolean isNextMap() { + return nextMap; + } + + @Override + public void setNextMap(boolean nextMap) { + this.nextMap = nextMap; + } + + private ByteBuf tryNextTaskContext() throws IOException { + // make sure there are still some read tasks waiting/running/returned from other threads + // or they are readDuplicated by self + if (totalSubmitted == 0 || selfCurrentCtx == lastCtx) { + return getBySelfAndSubmitMore(config.getMaxBytesInFlight()); + } + if (totalSubmitted < 0) { + throw new IllegalStateException("total submitted should be no less than 0. " + totalSubmitted); + } + try { + IODataDesc desc; + if ((desc = tryGetFromOtherThread()) != null) { + submitMore(); + return validateLastEntryAndGetBuf(desc.getEntry(entryIdx)); + } + // duplicate and get data by self + return readDuplicated(true); + } catch (InterruptedException e) { + throw new IOException("read interrupted.", e); + } + } + + /** + * we have to duplicate submitted desc since mapId was moved. + * + * @return + * @throws IOException + */ + private ByteBuf readDuplicated(boolean expectNotNullCtx) throws IOException { + ReadTaskContext context = getNextNonReturnedCtx(); + if (context == null) { + if (expectNotNullCtx) { + throw new IllegalStateException("context should not be null. totalSubmitted: " + totalSubmitted); + } + if (!fromOtherThread) { + lastCtx = null; + } + return null; + } + IODataDesc newDesc = context.getDesc().duplicate(); + ByteBuf buf = getBySelf(newDesc, context.getMapReduceId()); + selfCurrentCtx = context; + return buf; + } + + private IODataDesc tryGetFromOtherThread() throws InterruptedException, IOException { + IODataDesc desc = tryGetValidCompleted(); + if (desc != null) { + return desc; + } + // check completion + if ((!mapIdIt.hasNext()) && curMapReduceId == null && totalSubmitted == 0) { + return null; + } + // wait for specified time + desc = waitForValidFromOtherThread(); + if (desc != null) { + return desc; + } + // check wait times and cancel task + // TODO: stop reading from other threads? + cancelTasks(false); + return null; + } + + private IODataDesc waitForValidFromOtherThread() throws InterruptedException, IOException { + IODataDesc desc; + while (true) { + long start = System.nanoTime(); + boolean timeout = waitForCondition(config.getWaitDataTimeMs()); + metrics.incFetchWaitTime(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); + if (timeout) { + exceedWaitTimes++; + if (logger.isDebugEnabled()) { + logger.debug("exceed wait: {}ms, times: {}", config.getWaitDataTimeMs(), exceedWaitTimes); + } + if (exceedWaitTimes >= config.getWaitTimeoutTimes()) { + return null; + } + } + // get some results after wait + desc = tryGetValidCompleted(); + if (desc != null) { + return desc; + } + } + } + + protected IODataDesc tryGetValidCompleted() throws IOException { + if (moveForward()) { + return currentDesc; + } + return null; + } + + private ByteBuf tryCurrentDesc() throws IOException { + if (currentDesc != null) { + ByteBuf buf; + while (entryIdx < currentDesc.getNbrOfEntries()) { + IODataDesc.Entry entry = currentDesc.getEntry(entryIdx); + buf = validateLastEntryAndGetBuf(entry); + if (buf.readableBytes() > 0) { + return buf; + } + entryIdx++; + } + entryIdx = 0; + // no need to release desc since all its entries are released in tryCurrentEntry and + // internal buffers are released after object.fetch + // reader.close will release all in case of failure + currentDesc = null; + } + return null; + } + + private ByteBuf tryCurrentEntry() { + if (currentEntry != null && !currentEntry.isFetchBufReleased()) { + ByteBuf buf = currentEntry.getFetchedData(); + if (buf.readableBytes() > 0) { + return buf; + } + // release buffer as soon as possible + currentEntry.releaseDataBuffer(); + entryIdx++; + } + // not null currentEntry since it will be used for size validation + return null; + } + + /** + * for first read. + * + * @param selfReadLimit + * @return + * @throws IOException + */ + private ByteBuf getBySelfAndSubmitMore(long selfReadLimit) throws IOException { + entryIdx = 0; + // fetch the next by self + IODataDesc desc = createNextDesc(selfReadLimit); + Tuple2 mapreduceId = lastMapReduceIdForSubmit; + try { + if (fromOtherThread) { + submitMore(); + } + } catch (Exception e) { + desc.release(); + if (e instanceof IOException) { + throw (IOException)e; + } + throw new IOException("failed to submit more", e); + } + // first time read from reduce task + return getBySelf(desc, mapreduceId); + } + + private void submitMore() throws IOException { + while (totalSubmitted < config.getReadBatchSize() && totalInMemSize < config.getMaxMem()) { + IODataDesc taskDesc = createNextDesc(config.getMaxBytesInFlight()); + if (taskDesc == null) { + break; + } + submit(taskDesc, lastMapReduceIdForSubmit); + } + } + + private ByteBuf getBySelf(IODataDesc desc, Tuple2 mapreduceId) throws IOException { + // get data by self, no need to release currentDesc + if (desc == null) { // reach end + return null; + } + boolean releaseBuf = false; + try { + object.fetch(desc); + currentDesc = desc; + ByteBuf buf = validateLastEntryAndGetBuf(desc.getEntry(entryIdx)); + lastMapReduceIdForReturn = mapreduceId; + return buf; + } catch (IOException | IllegalStateException e) { + releaseBuf = true; + throw e; + } finally { + desc.release(releaseBuf); + } + } + + private IODataDesc createNextDesc(long sizeLimit) throws IOException { + long remaining = sizeLimit; + int reduceId = -1; + long mapId; + IODataDesc desc = null; + while (remaining > 0) { + nextMapReduceId(); + if (curMapReduceId == null) { + break; + } + if (reduceId > 0 && curMapReduceId._2 != reduceId) { // make sure entries under same reduce + break; + } + reduceId = curMapReduceId._2; + mapId = curMapReduceId._1; + lastMapReduceIdForSubmit = curMapReduceId; + long readSize = partSizeMap.get(curMapReduceId)._1() - curOffset; + long offset = curOffset; + if (readSize > remaining) { + readSize = remaining; + curOffset += readSize; + } else { + curOffset = 0; + curMapReduceId = null; + } + if (desc == null) { + desc = object.createDataDescForFetch(String.valueOf(reduceId), IODataDesc.IodType.ARRAY, 1); + } + desc.addEntryForFetch(String.valueOf(mapId), (int)offset, (int)readSize); + remaining -= readSize; + } + return desc; + } + + private void nextMapReduceId() { + if (curMapReduceId != null) { + return; + } + curOffset = 0; + if (mapIdIt.hasNext()) { + curMapReduceId = mapIdIt.next(); + partsRead++; + } else { + curMapReduceId = null; + } + } + + private ByteBuf validateLastEntryAndGetBuf(IODataDesc.Entry entry) throws IOException { + ByteBuf buf = entry.getFetchedData(); + int byteLen = buf.readableBytes(); + nextMap = false; + if (currentEntry != null && entry != currentEntry) { + if (entry.getKey().equals(currentEntry.getKey())) { + currentPartSize += byteLen; + } else { + checkPartitionSize(); + nextMap = true; + currentPartSize = byteLen; + } + } + currentEntry = entry; + metrics.incRemoteBytesRead(byteLen); + return buf; + } + + @Override + public void checkPartitionSize() throws IOException { + if (lastMapReduceIdForReturn == null) { + return; + } + // partition size is not accurate after compress/decompress + long size = partSizeMap.get(lastMapReduceIdForReturn)._1(); + if (size < 35 * 1024 * 1024 * 1024 && currentPartSize * 1.1 < size) { + throw new IOException("expect partition size " + partSizeMap.get(lastMapReduceIdForReturn) + + ", actual size " + currentPartSize + ", mapId and reduceId: " + lastMapReduceIdForReturn); + } + metrics.incRemoteBlocksFetched(1); + } + + @Override + public void checkTotalPartitions() throws IOException { + if (partsRead != totalParts) { + throw new IOException("expect total partitions to be read: " + totalParts + ", actual read: " + partsRead); + } + } + + @Override + public void prepare(LinkedHashMap, Tuple3> partSizeMap, + long maxBytesInFlight, long maxReqSizeShuffleToMem, + ShuffleReadMetricsReporter metrics) { + this.partSizeMap = partSizeMap; + this.config = config.copy(maxBytesInFlight, maxReqSizeShuffleToMem); + this.metrics = metrics; + this.totalParts = partSizeMap.size(); + mapIdIt = partSizeMap.keySet().iterator(); + } + + @Override + public Tuple2 curMapReduceId() { + return lastMapReduceIdForSubmit; + } + + @Override + protected ReadTaskContext getNextNonReturnedCtx() { + // in case no even single return from other thread + // check selfCurrentCtx since the wait could span multiple contexts/descs + ReadTaskContext curCtx = selfCurrentCtx == null ? + getCurrentCtx() : selfCurrentCtx; + if (curCtx == null) { + return getHeadCtx(); + } + // no consumedStack push and no totalInMemSize and totalSubmitted update + // since they will be updated when the task context finally returned + return curCtx.getNext(); + } + + @Override + protected boolean consumed(LinkedTaskContext consumed) { + return !consumed.isCancelled(); + } + + @Override + protected boolean validateReturned(LinkedTaskContext context) throws IOException { + if (context.isCancelled()) { + return false; + } + selfCurrentCtx = null; // non-cancelled currentCtx overrides selfCurrentCtx + lastMapReduceIdForReturn = ((ReadTaskContext)context).getMapReduceId(); + IODataDesc desc = context.getDesc(); + if (!desc.isSucceeded()) { + String msg = "failed to get data from DAOS, desc: " + desc.toString(4096); + if (desc.getCause() != null) { + throw new IOException(msg, desc.getCause()); + } else { + throw new IllegalStateException(msg + "\nno exception got. logic error or crash?"); + } + } + currentDesc = desc; + return true; + } + + @Override + protected Runnable newTask(LinkedTaskContext context) { + return ReadTask.newInstance((ReadTaskContext) context); + } + + @Override + protected LinkedTaskContext createTaskContext(IODataDesc desc, Object morePara) { + return new ReadTaskContext(object, counter, lock, condition, desc, morePara); + } + + @Override + public ReadTaskContext getCurrentCtx() { + return (ReadTaskContext) currentCtx; + } + + @Override + public ReadTaskContext getHeadCtx() { + return (ReadTaskContext) headCtx; + } + + @Override + public ReadTaskContext getLastCtx() { + return (ReadTaskContext) lastCtx; + } + + @Override + public String toString() { + return "DaosReaderSync{" + + "object=" + object + + '}'; + } + + /** + * Task to read from DAOS. Task itself is cached to reduce GC time. + * To reuse task for different reads, prepare and reset {@link ReadTaskContext} by calling + * {@link #newInstance(ReadTaskContext)} + */ + static final class ReadTask implements Runnable { + private ReadTaskContext context; + private final ObjectPool.Handle handle; + + private static final ObjectPool objectPool = ObjectPool.newPool(handle -> new ReadTask(handle)); + + private static final Logger log = LoggerFactory.getLogger(ReadTask.class); + + static ReadTask newInstance(ReadTaskContext context) { + ReadTask task = objectPool.get(); + task.context = context; + return task; + } + + private ReadTask(ObjectPool.Handle handle) { + this.handle = handle; + } + + @Override + public void run() { + boolean cancelled = context.cancelled; + try { + if (!cancelled) { + context.object.fetch(context.desc); + } + } catch (Exception e) { + log.error("failed to read for " + context.desc, e); + } finally { + // release desc buffer and keep data buffer + context.desc.release(cancelled); + context.signal(); + context = null; + handle.recycle(this); + } + } + } + + /** + * Context for read task. It holds all other object to read and sync between caller thread and read thread. + * It should be cached in caller thread for reusing. + */ + static final class ReadTaskContext extends LinkedTaskContext { + + /** + * constructor with all parameters. Some of them can be reused later. + * + * @param object + * DAOS object to fetch data from DAOS + * @param counter + * counter to indicate how many data ready for being consumed + * @param takeLock + * lock to work with notEmpty condition to signal caller thread there is data ready to be consumed + * @param notEmpty + * condition to signal there is some data ready + * @param desc + * desc object to describe which part of data to be fetch and hold returned data + * @param mapReduceId + * to track which map reduce ID this task fetches data for + */ + ReadTaskContext(DaosObject object, AtomicInteger counter, Lock takeLock, Condition notEmpty, + IODataDesc desc, Object mapReduceId) { + super(object, counter, takeLock, notEmpty); + this.desc = desc; + this.morePara = mapReduceId; + } + + @Override + public ReadTaskContext getNext() { + return (ReadTaskContext) next; + } + + public Tuple2 getMapReduceId() { + return (Tuple2) morePara; + } + } + + /** + * Thread factory for DAOS read tasks. + */ + protected static class ReadThreadFactory implements ThreadFactory { + private AtomicInteger id = new AtomicInteger(0); + + @Override + public Thread newThread(Runnable runnable) { + Thread t; + String name = "daos_read_" + id.getAndIncrement(); + if (runnable == null) { + t = new Thread(name); + } else { + t = new Thread(runnable, name); + } + t.setDaemon(true); + t.setUncaughtExceptionHandler((thread, throwable) -> + logger.error("exception occurred in thread " + name, throwable)); + return t; + } + } +} diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleIO.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleIO.java index 26990085..dd0aea90 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleIO.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleIO.java @@ -25,10 +25,7 @@ import io.daos.obj.DaosObjClient; import io.daos.obj.DaosObject; -import io.daos.obj.DaosObjectException; -import io.daos.obj.DaosObjectId; import org.apache.spark.SparkConf; -import org.apache.spark.launcher.SparkLauncher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,6 +47,8 @@ public class DaosShuffleIO { private DaosObjClient objClient; + private Map objectMap = new ConcurrentHashMap<>(); + private SparkConf conf; private Map driverConf; @@ -58,21 +57,9 @@ public class DaosShuffleIO { private String contId; - private String ranks; - private boolean removeShuffleData; - private DaosWriter.WriteConfig writeConfig; - - private Map readerMap = new ConcurrentHashMap<>(); - - private Map writerMap = new ConcurrentHashMap<>(); - - private Map objectMap = new ConcurrentHashMap<>(); - - private BoundThreadExecutors readerExes; - - private BoundThreadExecutors writerExes; + private IOManager ioManager; private static final Logger logger = LoggerFactory.getLogger(DaosShuffleIO.class); @@ -84,61 +71,11 @@ public class DaosShuffleIO { */ public DaosShuffleIO(SparkConf conf) { this.conf = conf; - this.writeConfig = loadWriteConfig(conf); - this.readerExes = createReaderExes(); - this.writerExes = createWriterExes(); + boolean async = (boolean)conf.get(package$.MODULE$.SHUFFLE_DAOS_IO_ASYNC()); + this.ioManager = async ? new IOManagerAsync(conf, objectMap) : new IOManagerSync(conf, objectMap); this.removeShuffleData = (boolean)conf.get(package$.MODULE$.SHUFFLE_DAOS_REMOVE_SHUFFLE_DATA()); } - protected static DaosWriter.WriteConfig loadWriteConfig(SparkConf conf) { - DaosWriter.WriteConfig config = new DaosWriter.WriteConfig(); - config.warnSmallWrite((boolean)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_WARN_SMALL_SIZE())); - config.bufferSize((int) ((long)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_SINGLE_BUFFER_SIZE()) - * 1024 * 1024)); - config.minSize((int) ((long)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_MINIMUM_SIZE()) * 1024)); - config.timeoutTimes((int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_WAIT_DATA_TIMEOUT_TIMES())); - config.waitTimeMs((int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_WAIT_MS())); - config.totalInMemSize((long)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_MAX_BYTES_IN_FLIGHT()) * 1024); - config.totalSubmittedLimit((int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_SUBMITTED_LIMIT())); - config.threads((int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_THREADS())); - config.fromOtherThreads((boolean)conf - .get(package$.MODULE$.SHUFFLE_DAOS_WRITE_IN_OTHER_THREAD())); - logger.info("write configs, " + config); - return config; - } - - private BoundThreadExecutors createWriterExes() { - if (writeConfig.isFromOtherThreads()) { - BoundThreadExecutors executors; - int threads = writeConfig.getThreads(); - if (threads == -1) { - threads = conf.getInt(SparkLauncher.EXECUTOR_CORES, 1); - } - executors = new BoundThreadExecutors("write_executors", threads, - new DaosWriter.WriteThreadFactory()); - logger.info("created BoundThreadExecutors with " + threads + " threads for write"); - return executors; - } - return null; - } - - private BoundThreadExecutors createReaderExes() { - boolean fromOtherThread = (boolean)conf - .get(package$.MODULE$.SHUFFLE_DAOS_READ_FROM_OTHER_THREAD()); - if (fromOtherThread) { - BoundThreadExecutors executors; - int threads = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_THREADS()); - if (threads == -1) { - threads = conf.getInt(SparkLauncher.EXECUTOR_CORES, 1); - } - executors = new BoundThreadExecutors("read_executors", threads, - new DaosReader.ReadThreadFactory()); - logger.info("created BoundThreadExecutors with " + threads + " threads for read"); - return executors; - } - return null; - } - /** * connect DAOS server. * @@ -149,18 +86,14 @@ public void initialize(Map driverConf) throws IOException { this.driverConf = driverConf; poolId = conf.get(package$.MODULE$.SHUFFLE_DAOS_POOL_UUID()); contId = conf.get(package$.MODULE$.SHUFFLE_DAOS_CONTAINER_UUID()); - ranks = conf.get(package$.MODULE$.SHUFFLE_DAOS_POOL_RANKS()); if (poolId == null || contId == null) { throw new IllegalArgumentException("DaosShuffleManager needs pool id and container id"); } objClient = new DaosObjClient.DaosObjClientBuilder() - .poolId(poolId).containerId(contId).ranks(ranks) + .poolId(poolId).containerId(contId) .build(); - } - - private long parseAppId(String appId) { - return Long.valueOf(appId.replaceAll("\\D", "")); + ioManager.setObjClient(objClient); } /** @@ -177,20 +110,7 @@ private long parseAppId(String appId) { */ public DaosWriter getDaosWriter(int numPartitions, int shuffleId, long mapId) throws IOException { - long appId = parseAppId(conf.getAppId()); - if (logger.isDebugEnabled()) { - logger.debug("getting daoswriter for app id: " + appId + ", shuffle id: " + shuffleId + ", map id: " + mapId + - ", numPartitions: " + numPartitions); - } - DaosWriter.WriteParam param = new DaosWriter.WriteParam(); - param.numPartitions(numPartitions) - .shuffleId(shuffleId) - .mapId(mapId) - .config(writeConfig); - DaosWriter writer = new DaosWriter(param, getObject(appId, shuffleId), - writerExes == null ? null : writerExes.nextExecutor()); - writer.setWriterMap(writerMap); - return writer; + return ioManager.getDaosWriter(numPartitions, shuffleId, mapId); } /** @@ -198,43 +118,18 @@ public DaosWriter getDaosWriter(int numPartitions, int shuffleId, long mapId) * * @param shuffleId * @return DaosReader - * @throws DaosObjectException + * @throws IOException */ - public DaosReader getDaosReader(int shuffleId) throws DaosObjectException { - long appId = parseAppId(conf.getAppId()); - if (logger.isDebugEnabled()) { - logger.debug("getting daosreader for app id: " + appId + ", shuffle id: " + shuffleId); - } - DaosReader reader = new DaosReader(getObject(appId, shuffleId), readerExes); - reader.setReaderMap(readerMap); - return reader; + public DaosReader getDaosReader(int shuffleId) throws IOException { + return ioManager.getDaosReader(shuffleId); } private String getKey(long appId, int shuffleId) { return appId + "" + shuffleId; } - private DaosObject getObject(long appId, int shuffleId) throws DaosObjectException { - String key = getKey(appId, shuffleId); - DaosObject object = objectMap.get(key); - if (object == null) { - DaosObjectId id = new DaosObjectId(appId, shuffleId); - id.encode(); - object = objClient.getObject(id); - objectMap.putIfAbsent(key, object); - DaosObject activeObject = objectMap.get(key); - if (activeObject != object) { // release just created DaosObject - object.close(); - object = activeObject; - } - } - // open just once in multiple threads - if (!object.isOpen()) { - synchronized (object) { - object.open(); - } - } - return object; + public IOManager getIoManager() { + return ioManager; } /** @@ -244,7 +139,7 @@ private DaosObject getObject(long appId, int shuffleId) throws DaosObjectExcepti * @return */ public boolean removeShuffle(int shuffleId) { - long appId = parseAppId(conf.getAppId()); + long appId = IOManager.parseAppId(conf.getAppId()); logger.info("punching daos object for app id: " + appId + ", shuffle id: " + shuffleId); try { DaosObject object = objectMap.remove(getKey(appId, shuffleId)); @@ -267,18 +162,7 @@ public boolean removeShuffle(int shuffleId) { * @throws IOException */ public void close() throws IOException { - if (readerExes != null) { - readerExes.stop(); - readerMap.keySet().forEach(r -> r.close()); - readerMap.clear(); - readerExes = null; - } - if (writerExes != null) { - writerExes.stop(); - writerMap.keySet().forEach(r -> r.close()); - writerMap.clear(); - writerExes = null; - } + ioManager.close(); objClient.forceClose(); } } diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleInputStream.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleInputStream.java index 264168a3..a1b5f5c9 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleInputStream.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosShuffleInputStream.java @@ -24,10 +24,7 @@ package org.apache.spark.shuffle.daos; import io.daos.obj.DaosObject; -import io.daos.obj.IODataDesc; import io.netty.buffer.ByteBuf; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkEnv; import org.apache.spark.shuffle.ShuffleReadMetricsReporter; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManagerId; @@ -40,7 +37,6 @@ import java.io.IOException; import java.io.InputStream; import java.util.*; -import java.util.concurrent.*; @NotThreadSafe /** @@ -58,14 +54,8 @@ public class DaosShuffleInputStream extends InputStream { private DaosObject object; - private BoundThreadExecutors.SingleThreadExecutor executor; - - private ReaderConfig config; - private ShuffleReadMetricsReporter metrics; - private boolean fromOtherThread; - private volatile boolean cleaned; private boolean completed; @@ -73,14 +63,11 @@ public class DaosShuffleInputStream extends InputStream { // ensure the order of partition // (mapid, reduceid) -> (length, BlockId, BlockManagerId) private LinkedHashMap, Tuple3> partSizeMap; - private Iterator> mapIdIt; - - private BufferSource source; private static final Logger log = LoggerFactory.getLogger(DaosShuffleInputStream.class); /** - * constructor with ordered map outputs info. Check {@link ReaderConfig} for more paras controlling + * constructor with ordered map outputs info. Check {@link DaosReader.ReaderConfig} for more paras controlling * how data being read from DAOS. * * @param reader @@ -89,7 +76,7 @@ public class DaosShuffleInputStream extends InputStream { * ordered map outputs info. They are organize as (mapid, reduceid) -> (length, BlockId, BlockManagerId) * @param maxBytesInFlight * how many bytes can be read concurrently - * @param maxMem + * @param maxReqSizeShuffleToMem * maximum data can be put in memory * @param metrics * read metrics @@ -97,51 +84,45 @@ public class DaosShuffleInputStream extends InputStream { public DaosShuffleInputStream( DaosReader reader, LinkedHashMap, Tuple3> partSizeMap, - long maxBytesInFlight, long maxMem, ShuffleReadMetricsReporter metrics) { + long maxBytesInFlight, long maxReqSizeShuffleToMem, + ShuffleReadMetricsReporter metrics) { this.partSizeMap = partSizeMap; this.reader = reader; - this.config = new ReaderConfig(maxBytesInFlight, maxMem); - this.fromOtherThread = config.fromOtherThread; - if (fromOtherThread) { - this.executor = reader.nextReaderExecutor(); - } - this.source = new BufferSource(executor); - reader.register(source); + reader.prepare(partSizeMap, maxBytesInFlight, maxReqSizeShuffleToMem, metrics); this.object = reader.getObject(); this.metrics = metrics; - this.mapIdIt = partSizeMap.keySet().iterator(); } public BlockId getCurBlockId() { - if (source.lastMapReduceIdForSubmit == null) { + if (reader.curMapReduceId() == null) { return null; } - return partSizeMap.get(source.lastMapReduceIdForSubmit)._2(); + return partSizeMap.get(reader.curMapReduceId())._2(); } public BlockManagerId getCurOriginAddress() { - if (source.lastMapReduceIdForSubmit == null) { + if (reader.curMapReduceId() == null) { return null; } - return partSizeMap.get(source.lastMapReduceIdForSubmit)._3(); + return partSizeMap.get(reader.curMapReduceId())._3(); } public long getCurMapIndex() { - if (source.lastMapReduceIdForSubmit == null) { + if (reader.curMapReduceId() == null) { return -1; } - return source.lastMapReduceIdForSubmit._1; + return reader.curMapReduceId()._1; } @Override public int read() throws IOException { while (!completed) { - ByteBuf buf = source.nextBuf(); + ByteBuf buf = reader.nextBuf(); if (buf == null) { // reach end complete(); return -1; } - if (source.newMap) { // indication to close upper layer object inputstream + if (reader.isNextMap()) { // indication to close upper layer object inputstream return -1; } if (buf.readableBytes() >= 1) { @@ -160,13 +141,13 @@ public int read(byte[] bytes) throws IOException { public int read(byte[] bytes, int offset, int length) throws IOException { int len = length; while (!completed) { - ByteBuf buf = source.nextBuf(); + ByteBuf buf = reader.nextBuf(); if (buf == null) { // reach end complete(); int r = length - len; return r == 0 ? -1 : r; } - if (source.newMap) { // indication to close upper layer object inputstream + if (reader.isNextMap()) { // indication to close upper layer object inputstream int r = length - len; return r == 0 ? -1 : r; } @@ -186,24 +167,20 @@ public int read(byte[] bytes, int offset, int length) throws IOException { * upper layer should call this method to read more map output */ public void nextMap() { - source.newMap = false; + reader.setNextMap(false); } private void complete() throws IOException { if (!completed) { - source.checkPartitionSize(); - source.checkTotalPartitions(); + reader.checkPartitionSize(); + reader.checkTotalPartitions(); completed = true; } } private void cleanup() { if (!cleaned) { - boolean allReleased = source.cleanup(false); - if (allReleased) { - reader.unregister(source); - } - source = null; + reader.close(false); cleaned = true; completed = true; } @@ -231,460 +208,4 @@ public void close(boolean force) { public boolean isCompleted() { return completed; } - - /** - * Source of map output data. User just calls {@link #nextBuf()} and reads from buffer repeatedly until no buffer - * returned. - * BufferSource does all other dirty things, like when and how (caller thread or from dedicated thread) to - * read from DAOS as well as controlling buffer size and task batch size. - * It also has some fault tolerance ability, like re-read from caller thread if task doesn't respond from the - * dedicated threads. - */ - public class BufferSource extends TaskSubmitter { - private DaosReader.ReadTaskContext selfCurrentCtx; - private IODataDesc currentDesc; - private IODataDesc.Entry currentEntry; - private long currentPartSize; - - private int entryIdx; - private Tuple2 curMapReduceId; - private Tuple2 lastMapReduceIdForSubmit; - private Tuple2 lastMapReduceIdForReturn; - private int curOffset; - private boolean newMap; - - private int totalParts = partSizeMap.size(); - private int partsRead; - - protected BufferSource(BoundThreadExecutors.SingleThreadExecutor executor) { - super(executor); - } - - /** - * invoke this method when fromOtherThread is false. - * - * @return - * @throws {@link IOException} - */ - public ByteBuf readBySelf() throws IOException { - if (lastCtx != null) { // duplicated IODataDescs which were submitted to other thread, but cancelled - ByteBuf buf = readDuplicated(false); - if (buf != null) { - return buf; - } - } - // all submitted were duplicated. Now start from mapId iterator. - IODataDesc desc = createNextDesc(config.maxBytesInFlight); - return getBySelf(desc, lastMapReduceIdForSubmit); - } - - /** - * get available buffer after iterating current buffer, next buffer in current desc and next desc. - * - * @return buffer with data read from DAOS - * @throws IOException - */ - public ByteBuf nextBuf() throws IOException { - ByteBuf buf = tryCurrentEntry(); - if (buf != null) { - return buf; - } - // next entry - buf = tryCurrentDesc(); - if (buf != null) { - return buf; - } - // from next partition - if (fromOtherThread) { - // next ready queue - if (headCtx != null) { - return tryNextTaskContext(); - } - // get data by self and submit request for remaining data - return getBySelfAndSubmitMore(config.minReadSize); - } - // get data by self after fromOtherThread disabled - return readBySelf(); - } - - private ByteBuf tryNextTaskContext() throws IOException { - // make sure there are still some read tasks waiting/running/returned from other threads - // or they are readDuplicated by self - if (totalSubmitted == 0 || selfCurrentCtx == lastCtx) { - return getBySelfAndSubmitMore(config.maxBytesInFlight); - } - if (totalSubmitted < 0) { - throw new IllegalStateException("total submitted should be no less than 0. " + totalSubmitted); - } - try { - IODataDesc desc; - if ((desc = tryGetFromOtherThread()) != null) { - submitMore(); - return validateLastEntryAndGetBuf(desc.getEntry(entryIdx)); - } - // duplicate and get data by self - return readDuplicated(true); - } catch (InterruptedException e) { - throw new IOException("read interrupted.", e); - } - } - - /** - * we have to duplicate submitted desc since mapId was moved. - * - * @return - * @throws IOException - */ - private ByteBuf readDuplicated(boolean expectNotNullCtx) throws IOException { - DaosReader.ReadTaskContext context = getNextNonReturnedCtx(); - if (context == null) { - if (expectNotNullCtx) { - throw new IllegalStateException("context should not be null. totalSubmitted: " + totalSubmitted); - } - if (!fromOtherThread) { - lastCtx = null; - } - return null; - } - IODataDesc newDesc = context.getDesc().duplicate(); - ByteBuf buf = getBySelf(newDesc, context.getMapReduceId()); - selfCurrentCtx = context; - return buf; - } - - @Override - protected DaosReader.ReadTaskContext getNextNonReturnedCtx() { - // in case no even single return from other thread - // check selfCurrentCtx since the wait could span multiple contexts/descs - DaosReader.ReadTaskContext curCtx = selfCurrentCtx == null ? - getCurrentCtx() : selfCurrentCtx; - if (curCtx == null) { - return getHeadCtx(); - } - // no consumedStack push and no totalInMemSize and totalSubmitted update - // since they will be updated when the task context finally returned - return curCtx.getNext(); - } - - private IODataDesc tryGetFromOtherThread() throws InterruptedException, IOException { - IODataDesc desc = tryGetValidCompleted(); - if (desc != null) { - return desc; - } - // check completion - if ((!mapIdIt.hasNext()) && curMapReduceId == null && totalSubmitted == 0) { - return null; - } - // wait for specified time - desc = waitForValidFromOtherThread(); - if (desc != null) { - return desc; - } - // check wait times and cancel task - // TODO: stop reading from other threads? - cancelTasks(false); - return null; - } - - private IODataDesc waitForValidFromOtherThread() throws InterruptedException, IOException { - IODataDesc desc; - while (true) { - long start = System.nanoTime(); - boolean timeout = waitForCondition(config.waitDataTimeMs); - metrics.incFetchWaitTime(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); - if (timeout) { - exceedWaitTimes++; - if (log.isDebugEnabled()) { - log.debug("exceed wait: {}ms, times: {}", config.waitDataTimeMs, exceedWaitTimes); - } - if (exceedWaitTimes >= config.waitTimeoutTimes) { - return null; - } - } - // get some results after wait - desc = tryGetValidCompleted(); - if (desc != null) { - return desc; - } - } - } - - protected IODataDesc tryGetValidCompleted() throws IOException { - if (moveForward()) { - return currentDesc; - } - return null; - } - - @Override - protected boolean consumed(LinkedTaskContext consumed) { - return !consumed.isCancelled(); - } - - @Override - protected boolean validateReturned(LinkedTaskContext context) throws IOException { - if (context.isCancelled()) { - return false; - } - selfCurrentCtx = null; // non-cancelled currentCtx overrides selfCurrentCtx - lastMapReduceIdForReturn = ((DaosReader.ReadTaskContext)context).getMapReduceId(); - IODataDesc desc = context.getDesc(); - if (!desc.isSucceeded()) { - String msg = "failed to get data from DAOS, desc: " + desc.toString(4096); - if (desc.getCause() != null) { - throw new IOException(msg, desc.getCause()); - } else { - throw new IllegalStateException(msg + "\nno exception got. logic error or crash?"); - } - } - currentDesc = desc; - return true; - } - - private ByteBuf tryCurrentDesc() throws IOException { - if (currentDesc != null) { - ByteBuf buf; - while (entryIdx < currentDesc.getNbrOfEntries()) { - IODataDesc.Entry entry = currentDesc.getEntry(entryIdx); - buf = validateLastEntryAndGetBuf(entry); - if (buf.readableBytes() > 0) { - return buf; - } - entryIdx++; - } - entryIdx = 0; - // no need to release desc since all its entries are released in tryCurrentEntry and - // internal buffers are released after object.fetch - // reader.close will release all in case of failure - currentDesc = null; - } - return null; - } - - private ByteBuf tryCurrentEntry() { - if (currentEntry != null && !currentEntry.isFetchBufReleased()) { - ByteBuf buf = currentEntry.getFetchedData(); - if (buf.readableBytes() > 0) { - return buf; - } - // release buffer as soon as possible - currentEntry.releaseDataBuffer(); - entryIdx++; - } - // not null currentEntry since it will be used for size validation - return null; - } - - /** - * for first read. - * - * @param selfReadLimit - * @return - * @throws IOException - */ - private ByteBuf getBySelfAndSubmitMore(long selfReadLimit) throws IOException { - entryIdx = 0; - // fetch the next by self - IODataDesc desc = createNextDesc(selfReadLimit); - Tuple2 mapreduceId = lastMapReduceIdForSubmit; - try { - if (fromOtherThread) { - submitMore(); - } - } catch (Exception e) { - desc.release(); - if (e instanceof IOException) { - throw (IOException)e; - } - throw new IOException("failed to submit more", e); - } - // first time read from reduce task - return getBySelf(desc, mapreduceId); - } - - private void submitMore() throws IOException { - while (totalSubmitted < config.readBatchSize && totalInMemSize < config.maxMem) { - IODataDesc taskDesc = createNextDesc(config.maxBytesInFlight); - if (taskDesc == null) { - break; - } - submit(taskDesc, lastMapReduceIdForSubmit); - } - } - - @Override - protected Runnable newTask(LinkedTaskContext context) { - return DaosReader.ReadTask.newInstance((DaosReader.ReadTaskContext) context); - } - - @Override - protected LinkedTaskContext createTaskContext(IODataDesc desc, Object morePara) { - return new DaosReader.ReadTaskContext(object, counter, lock, condition, desc, morePara); - } - - private ByteBuf getBySelf(IODataDesc desc, Tuple2 mapreduceId) throws IOException { - // get data by self, no need to release currentDesc - if (desc == null) { // reach end - return null; - } - boolean releaseBuf = false; - try { - object.fetch(desc); - currentDesc = desc; - ByteBuf buf = validateLastEntryAndGetBuf(desc.getEntry(entryIdx)); - lastMapReduceIdForReturn = mapreduceId; - return buf; - } catch (IOException | IllegalStateException e) { - releaseBuf = true; - throw e; - } finally { - desc.release(releaseBuf); - } - } - - private IODataDesc createNextDesc(long sizeLimit) throws IOException { - long remaining = sizeLimit; - int reduceId = -1; - long mapId; - IODataDesc desc = null; - while (remaining > 0) { - nextMapReduceId(); - if (curMapReduceId == null) { - break; - } - if (reduceId > 0 && curMapReduceId._2 != reduceId) { // make sure entries under same reduce - break; - } - reduceId = curMapReduceId._2; - mapId = curMapReduceId._1; - lastMapReduceIdForSubmit = curMapReduceId; - long readSize = partSizeMap.get(curMapReduceId)._1() - curOffset; - long offset = curOffset; - if (readSize > remaining) { - readSize = remaining; - curOffset += readSize; - } else { - curOffset = 0; - curMapReduceId = null; - } - if (desc == null) { - desc = object.createDataDescForFetch(String.valueOf(reduceId), IODataDesc.IodType.ARRAY, 1); - } - desc.addEntryForFetch(String.valueOf(mapId), (int)offset, (int)readSize); - remaining -= readSize; - } - return desc; - } - - private void nextMapReduceId() { - if (curMapReduceId != null) { - return; - } - curOffset = 0; - if (mapIdIt.hasNext()) { - curMapReduceId = mapIdIt.next(); - partsRead++; - } else { - curMapReduceId = null; - } - } - - private ByteBuf validateLastEntryAndGetBuf(IODataDesc.Entry entry) throws IOException { - ByteBuf buf = entry.getFetchedData(); - int byteLen = buf.readableBytes(); - newMap = false; - if (currentEntry != null && entry != currentEntry) { - if (entry.getKey().equals(currentEntry.getKey())) { - currentPartSize += byteLen; - } else { - checkPartitionSize(); - newMap = true; - currentPartSize = byteLen; - } - } - currentEntry = entry; - metrics.incRemoteBytesRead(byteLen); - return buf; - } - - private void checkPartitionSize() throws IOException { - if (lastMapReduceIdForReturn == null) { - return; - } - // partition size is not accurate after compress/decompress - long size = partSizeMap.get(lastMapReduceIdForReturn)._1(); - if (size < 35 * 1024 * 1024 * 1024 && currentPartSize * 1.1 < size) { - throw new IOException("expect partition size " + partSizeMap.get(lastMapReduceIdForReturn) + - ", actual size " + currentPartSize + ", mapId and reduceId: " + lastMapReduceIdForReturn); - } - metrics.incRemoteBlocksFetched(1); - } - - public boolean cleanup(boolean force) { - boolean allReleased = true; - if (!cleaned) { - allReleased &= cleanupSubmitted(force); - allReleased &= cleanupConsumed(force); - } - return allReleased; - } - - public void checkTotalPartitions() throws IOException { - if (partsRead != totalParts) { - throw new IOException("expect total partitions to be read: " + totalParts + ", actual read: " + partsRead); - } - } - - @Override - public DaosReader.ReadTaskContext getCurrentCtx() { - return (DaosReader.ReadTaskContext) currentCtx; - } - - @Override - public DaosReader.ReadTaskContext getHeadCtx() { - return (DaosReader.ReadTaskContext) headCtx; - } - - @Override - public DaosReader.ReadTaskContext getLastCtx() { - return (DaosReader.ReadTaskContext) lastCtx; - } - } - - /** - * reader configurations, please check configs prefixed with SHUFFLE_DAOS_READ in {@link package$#MODULE$}. - */ - private static final class ReaderConfig { - private long minReadSize; - private long maxBytesInFlight; - private long maxMem; - private int readBatchSize; - private int waitDataTimeMs; - private int waitTimeoutTimes; - private boolean fromOtherThread; - - private ReaderConfig(long maxBytesInFlight, long maxMem) { - SparkConf conf = SparkEnv.get().conf(); - minReadSize = (long)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_MINIMUM_SIZE()) * 1024; - if (maxBytesInFlight < minReadSize) { - this.maxBytesInFlight = minReadSize; - } else { - this.maxBytesInFlight = maxBytesInFlight; - } - this.maxMem = maxMem; - this.readBatchSize = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_BATCH_SIZE()); - this.waitDataTimeMs = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_WAIT_DATA_MS()); - this.waitTimeoutTimes = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_WAIT_DATA_TIMEOUT_TIMES()); - this.fromOtherThread = (boolean)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_FROM_OTHER_THREAD()); - if (log.isDebugEnabled()) { - log.debug("minReadSize: " + minReadSize); - log.debug("maxBytesInFlight: " + maxBytesInFlight); - log.debug("maxMem: " + maxMem); - log.debug("readBatchSize: " + readBatchSize); - log.debug("waitDataTimeMs: " + waitDataTimeMs); - log.debug("waitTimeoutTimes: " + waitTimeoutTimes); - log.debug("fromOtherThread: " + fromOtherThread); - } - } - } } diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriter.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriter.java index 8b9714f2..f3922294 100644 --- a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriter.java +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriter.java @@ -23,80 +23,20 @@ package org.apache.spark.shuffle.daos; -import io.daos.BufferAllocator; -import io.daos.DaosIOException; -import io.daos.obj.DaosObject; -import io.daos.obj.IODataDesc; -import io.netty.buffer.ByteBuf; -import io.netty.util.internal.ObjectPool; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; -import java.util.*; -import java.util.concurrent.ThreadFactory; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.locks.Condition; -import java.util.concurrent.locks.Lock; /** * A DAOS writer per map task which may have multiple map output partitions. - * Each partition has one corresponding {@link NativeBuffer} which caches records until - * a specific {@link #flush(int)} call being made. Then {@link NativeBuffer} creates - * {@link IODataDesc} and write to DAOS in either caller thread or other dedicated thread. + * Each partition has one corresponding buffer which caches records until + * a specific {@link #flush(int)} call being made. Data is written to DAOS + * in either caller thread or other dedicated thread. */ -public class DaosWriter extends TaskSubmitter { - - private DaosObject object; - - private String mapId; - - private WriteParam param; - - private WriteConfig config; - - private Map writerMap; - - private NativeBuffer[] partitionBufArray; - - private int totalTimeoutTimes; - - private int totalWriteTimes; - - private int totalBySelfTimes; - - private volatile boolean cleaned; - - private static Logger LOG = LoggerFactory.getLogger(DaosWriter.class); - - /** - * construct DaosWriter with object and dedicated read executors. - * - * @param param - * write parameters - * @param object - * opened DaosObject - * @param executor - * null means write in caller's thread. Submit {@link WriteTask} to it otherwise. - */ - public DaosWriter(DaosWriter.WriteParam param, DaosObject object, - BoundThreadExecutors.SingleThreadExecutor executor) { - super(executor); - this.param = param; - this.config = param.config; - this.partitionBufArray = new NativeBuffer[param.numPartitions]; - this.mapId = String.valueOf(param.mapId); - this.object = object; - } - - private NativeBuffer getNativeBuffer(int partitionId) { - NativeBuffer buffer = partitionBufArray[partitionId]; - if (buffer == null) { - buffer = new NativeBuffer(partitionId, config.bufferSize); - partitionBufArray[partitionId] = buffer; - } - return buffer; - } +public interface DaosWriter { /** * write to buffer. @@ -104,9 +44,7 @@ private NativeBuffer getNativeBuffer(int partitionId) { * @param partitionId * @param b */ - public void write(int partitionId, int b) { - getNativeBuffer(partitionId).write(b); - } + void write(int partitionId, int b); /** * write to buffer. @@ -114,9 +52,7 @@ public void write(int partitionId, int b) { * @param partitionId * @param array */ - public void write(int partitionId, byte[] array) { - getNativeBuffer(partitionId).write(array); - } + void write(int partitionId, byte[] array); /** * write to buffer. @@ -126,9 +62,7 @@ public void write(int partitionId, byte[] array) { * @param offset * @param len */ - public void write(int partitionId, byte[] array, int offset, int len) { - getNativeBuffer(partitionId).write(array, offset, len); - } + void write(int partitionId, byte[] array, int offset, int len); /** * get length of all partitions. @@ -137,32 +71,7 @@ public void write(int partitionId, byte[] array, int offset, int len) { * @param numPartitions * @return array of partition lengths */ - public long[] getPartitionLens(int numPartitions) { - if (LOG.isDebugEnabled()) { - LOG.debug("partition map size: " + partitionBufArray.length); - for (int i = 0; i < numPartitions; i++) { - NativeBuffer nb = partitionBufArray[i]; - if (nb != null) { - LOG.debug("id: " + i + ", native buffer: " + nb.partitionId + ", " + - nb.totalSize + ", " + nb.roundSize); - } - } - } - long[] lens = new long[numPartitions]; - for (int i = 0; i < numPartitions; i++) { - NativeBuffer nb = partitionBufArray[i]; - if (nb != null) { - lens[i] = nb.totalSize; - if (nb.roundSize != 0 || !nb.bufList.isEmpty()) { - throw new IllegalStateException("round size should be 0, " + nb.roundSize + ", buflist should be empty, " + - nb.bufList.size()); - } - } else { - lens[i] = 0; - } - } - return lens; - } + long[] getPartitionLens(int numPartitions); /** * Flush specific partition to DAOS. @@ -170,280 +79,17 @@ public long[] getPartitionLens(int numPartitions) { * @param partitionId * @throws IOException */ - public void flush(int partitionId) throws IOException { - NativeBuffer buffer = partitionBufArray[partitionId]; - if (buffer == null) { - return; - } - IODataDesc desc = buffer.createUpdateDesc(); - if (desc == null) { - return; - } - totalWriteTimes++; - if (config.warnSmallWrite && buffer.roundSize < config.minSize) { - LOG.warn("too small partition size {}, shuffle {}, map {}, partition {}", - buffer.roundSize, param.shuffleId, mapId, partitionId); - } - if (executor == null) { // run write by self - runBySelf(desc, buffer); - return; - } - submitToOtherThreads(desc, buffer); - } - - private void runBySelf(IODataDesc desc, NativeBuffer buffer) throws IOException { - totalBySelfTimes++; - try { - object.update(desc); - } catch (IOException e) { - throw new IOException("failed to write partition of " + desc, e); - } finally { - desc.release(); - buffer.reset(true); - } - } - - private void submitToOtherThreads(IODataDesc desc, NativeBuffer buffer) throws IOException { - // move forward to release write buffers - moveForward(); - // check if we need to wait submitted tasks to be executed - if (goodForSubmit()) { - submitAndReset(desc, buffer); - return; - } - // to wait - int timeoutTimes = 0; - try { - while (!goodForSubmit()) { - boolean timeout = waitForCondition(config.waitTimeMs); - moveForward(); - if (timeout) { - timeoutTimes++; - if (LOG.isDebugEnabled()) { - LOG.debug("wait daos write timeout times: " + timeoutTimes); - } - if (timeoutTimes >= config.timeoutTimes) { - totalTimeoutTimes += timeoutTimes; - runBySelf(desc, buffer); - return; - } - } - } - } catch (InterruptedException e) { - desc.release(); - Thread.currentThread().interrupt(); - throw new IOException("interrupted when wait daos write", e); - } - // submit write task after some wait - totalTimeoutTimes += timeoutTimes; - submitAndReset(desc, buffer); - } - - private boolean goodForSubmit() { - return totalInMemSize < config.totalInMemSize && totalSubmitted < config.totalSubmittedLimit; - } - - private void submitAndReset(IODataDesc desc, NativeBuffer buffer) { - try { - submit(desc, buffer.bufList); - } finally { - buffer.reset(false); - } - } - - private void cleanup(boolean force) { - if (cleaned) { - return; - } - boolean allReleased = true; - allReleased &= cleanupSubmitted(force); - allReleased &= cleanupConsumed(force); - if (allReleased) { - cleaned = true; - } - } + void flush(int partitionId) throws IOException; /** - * wait write task to be completed and clean up resources. + * close writer. */ - public void close() { - try { - close(true); - } catch (Exception e) { - throw new IllegalStateException("failed to complete all write tasks and cleanup", e); - } - } - - private void close(boolean force) throws Exception { - if (partitionBufArray != null) { - waitCompletion(force); - partitionBufArray = null; - object = null; - if (LOG.isDebugEnabled()) { - LOG.debug("total writes: " + totalWriteTimes + ", total timeout times: " + totalTimeoutTimes + - ", total write-by-self times: " + totalBySelfTimes + ", total timeout times/total writes: " + - ((float) totalTimeoutTimes) / totalWriteTimes); - } - } - cleanup(force); - if (writerMap != null && (force || cleaned)) { - writerMap.remove(this); - writerMap = null; - } - } - - private void waitCompletion(boolean force) throws Exception { - if (!force) { - return; - } - try { - while (totalSubmitted > 0) { - waitForCondition(config.waitTimeMs); - moveForward(); - } - } catch (Exception e) { - LOG.error("failed to wait completion of daos writing", e); - throw e; - } - } - - public void setWriterMap(Map writerMap) { - writerMap.put(this, 0); - this.writerMap = writerMap; - } - - @Override - protected Runnable newTask(LinkedTaskContext context) { - return WriteTask.newInstance((WriteTaskContext) context); - } - - @Override - protected LinkedTaskContext createTaskContext(IODataDesc desc, Object morePara) { - return new WriteTaskContext(object, counter, lock, condition, desc, morePara); - } - - @Override - protected boolean validateReturned(LinkedTaskContext context) throws IOException { - if (!context.desc.isSucceeded()) { - throw new DaosIOException("write is not succeeded: " + context.desc); - } - return false; - } - - @Override - protected boolean consumed(LinkedTaskContext context) { - // release write buffers - List bufList = (List) context.morePara; - bufList.forEach(b -> b.release()); - bufList.clear(); - return true; - } - - /** - * Write data to one or multiple netty direct buffers which will be written to DAOS without copy - */ - private class NativeBuffer implements Comparable { - private int partitionId; - private String partitionIdKey; - private int bufferSize; - private int idx = -1; - private List bufList = new ArrayList<>(); - private long totalSize; - private long roundSize; - - NativeBuffer(int partitionId, int bufferSize) { - this.partitionId = partitionId; - this.partitionIdKey = String.valueOf(partitionId); - this.bufferSize = bufferSize; - } - - private ByteBuf addNewByteBuf(int len) { - ByteBuf buf; - try { - buf = BufferAllocator.objBufWithNativeOrder(Math.max(bufferSize, len)); - } catch (OutOfMemoryError e) { - LOG.error("too big buffer size: " + Math.max(bufferSize, len)); - throw e; - } - bufList.add(buf); - idx++; - return buf; - } - - private ByteBuf getBuffer(int len) { - if (idx < 0) { - return addNewByteBuf(len); - } - return bufList.get(idx); - } - - public void write(int b) { - ByteBuf buf = getBuffer(1); - if (buf.writableBytes() < 1) { - buf = addNewByteBuf(1); - } - buf.writeByte(b); - roundSize += 1; - } - - public void write(byte[] b) { - write(b, 0, b.length); - } - - public void write(byte[] b, int offset, int len) { - if (len <= 0) { - return; - } - ByteBuf buf = getBuffer(len); - int avail = buf.writableBytes(); - int gap = len - avail; - if (gap <= 0) { - buf.writeBytes(b, offset, len); - } else { - buf.writeBytes(b, offset, avail); - buf = addNewByteBuf(gap); - buf.writeBytes(b, avail, gap); - } - roundSize += len; - } - - public IODataDesc createUpdateDesc() throws IOException { - if (roundSize == 0 || bufList.isEmpty()) { - return null; - } - long bufSize = 0; - IODataDesc desc = object.createDataDescForUpdate(partitionIdKey, IODataDesc.IodType.ARRAY, 1); - for (ByteBuf buf : bufList) { - desc.addEntryForUpdate(mapId, (int) totalSize, buf); - bufSize += buf.readableBytes(); - } - if (roundSize != bufSize) { - throw new IOException("expect update size: " + roundSize + ", actual: " + bufSize); - } - return desc; - } - - public void reset(boolean release) { - if (release) { - bufList.forEach(b -> b.release()); - } - // release==false, buffers will be released when tasks are executed and consumed - bufList.clear(); - idx = -1; - totalSize += roundSize; - roundSize = 0; - } - - @Override - public int compareTo(NativeBuffer nativeBuffer) { - return partitionId - nativeBuffer.partitionId; - } - } + void close(); /** * Write configurations. Please check configs prefixed with SHUFFLE_DAOS_WRITE in {@link package$#MODULE$}. */ - public static class WriteConfig { + class WriterConfig { private int bufferSize; private int minSize; private boolean warnSmallWrite; @@ -454,50 +100,26 @@ public static class WriteConfig { private int threads; private boolean fromOtherThreads; - public WriteConfig bufferSize(int bufferSize) { - this.bufferSize = bufferSize; - return this; - } - - public WriteConfig minSize(int minSize) { - this.minSize = minSize; - return this; - } - - public WriteConfig warnSmallWrite(boolean warnSmallWrite) { - this.warnSmallWrite = warnSmallWrite; - return this; - } - - public WriteConfig waitTimeMs(long waitTimeMs) { - this.waitTimeMs = waitTimeMs; - return this; - } - - public WriteConfig timeoutTimes(int timeoutTimes) { - this.timeoutTimes = timeoutTimes; - return this; - } - - public WriteConfig totalInMemSize(long totalInMemSize) { - this.totalInMemSize = totalInMemSize; - return this; - } + private static final Logger logger = LoggerFactory.getLogger(WriterConfig.class); - public WriteConfig totalSubmittedLimit(int totalSubmittedLimit) { - this.totalSubmittedLimit = totalSubmittedLimit; - return this; - } - - public WriteConfig threads(int threads) { - this.threads = threads; - return this; + WriterConfig() { + SparkConf conf = SparkEnv.get().conf(); + warnSmallWrite = (boolean) conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_WARN_SMALL_SIZE()); + bufferSize = (int) ((long) conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_SINGLE_BUFFER_SIZE()) + * 1024 * 1024); + minSize = (int) ((long)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_MINIMUM_SIZE()) * 1024); + timeoutTimes = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_WAIT_DATA_TIMEOUT_TIMES()); + waitTimeMs = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_WAIT_MS()); + totalInMemSize = (long)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_MAX_BYTES_IN_FLIGHT()) * 1024; + totalSubmittedLimit = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_SUBMITTED_LIMIT()); + threads = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_WRITE_THREADS()); + fromOtherThreads = (boolean)conf + .get(package$.MODULE$.SHUFFLE_DAOS_WRITE_IN_OTHER_THREAD()); + if (logger.isDebugEnabled()) { + logger.debug(toString()); + } } - public WriteConfig fromOtherThreads(boolean fromOtherThreads) { - this.fromOtherThreads = fromOtherThreads; - return this; - } public int getBufferSize() { return bufferSize; @@ -550,141 +172,4 @@ public String toString() { '}'; } } - - public static class WriteParam { - private int numPartitions; - private int shuffleId; - private long mapId; - private WriteConfig config; - - public WriteParam numPartitions(int numPartitions) { - this.numPartitions = numPartitions; - return this; - } - - public WriteParam shuffleId(int shuffleId) { - this.shuffleId = shuffleId; - return this; - } - - public WriteParam mapId(long mapId) { - this.mapId = mapId; - return this; - } - - public WriteParam config(WriteConfig config) { - this.config = config; - return this; - } - } - - /** - * Task to write data to DAOS. Task itself is cached to reduce GC time. - * To reuse task for different writes, prepare and reset {@link WriteTaskContext} by calling - * {@link #newInstance(WriteTaskContext)} - */ - static final class WriteTask implements Runnable { - private final ObjectPool.Handle handle; - private WriteTaskContext context; - - private static final ObjectPool objectPool = ObjectPool.newPool(handle -> new WriteTask(handle)); - - private static final Logger log = LoggerFactory.getLogger(WriteTask.class); - - static WriteTask newInstance(WriteTaskContext context) { - WriteTask task = objectPool.get(); - task.context = context; - return task; - } - - private WriteTask(ObjectPool.Handle handle) { - this.handle = handle; - } - - @Override - public void run() { - boolean cancelled = context.cancelled; - try { - if (!cancelled) { - context.object.update(context.desc); - } - } catch (Exception e) { - log.error("failed to write for " + context.desc, e); - } finally { - context.desc.release(); - context.signal(); - context = null; - handle.recycle(this); - } - } - } - - /** - * Context for write task. It holds all other object to read and sync between caller thread and write thread. - * It should be cached in caller thread for reusing. - */ - static final class WriteTaskContext extends LinkedTaskContext { - - /** - * constructor with all parameters. Some of them can be reused later. - * - * @param object - * DAOS object to fetch data from DAOS - * @param counter - * counter to indicate how many write is on-going - * @param writeLock - * lock to work with notFull condition to signal caller thread to submit more write task - * @param notFull - * condition to signal caller thread - * @param desc - * desc object to describe where to write data - * @param bufList - * list of buffers to write to DAOS - */ - WriteTaskContext(DaosObject object, AtomicInteger counter, Lock writeLock, Condition notFull, - IODataDesc desc, Object bufList) { - super(object, counter, writeLock, notFull); - this.desc = desc; - List myBufList = new ArrayList<>(); - myBufList.addAll((List) bufList); - this.morePara = myBufList; - } - - @Override - public WriteTaskContext getNext() { - return (WriteTaskContext) next; - } - - @Override - public void reuse(IODataDesc desc, Object morePara) { - List myBufList = (List) this.morePara; - if (!myBufList.isEmpty()) { - throw new IllegalStateException("bufList in reusing write task context should be empty"); - } - myBufList.addAll((List) morePara); - super.reuse(desc, myBufList); - } - } - - /** - * Thread factory for write - */ - protected static class WriteThreadFactory implements ThreadFactory { - private AtomicInteger id = new AtomicInteger(0); - - @Override - public Thread newThread(Runnable runnable) { - Thread t; - String name = "daos_write_" + id.getAndIncrement(); - if (runnable == null) { - t = new Thread(name); - } else { - t = new Thread(runnable, name); - } - t.setDaemon(true); - t.setUncaughtExceptionHandler((thread, throwable) -> - LOG.error("exception occurred in thread " + name, throwable)); - return t; - } - } } diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterSync.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterSync.java new file mode 100644 index 00000000..b1165ce7 --- /dev/null +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/DaosWriterSync.java @@ -0,0 +1,559 @@ +/* + * (C) Copyright 2018-2020 Intel Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * GOVERNMENT LICENSE RIGHTS-OPEN SOURCE SOFTWARE + * The Government's rights to use, modify, reproduce, release, perform, display, + * or disclose this software are subject to the terms of the Apache License as + * provided in Contract No. B609815. + * Any reproduction of computer software, computer software documentation, or + * portions thereof marked with this legend must also reproduce the markings. + */ + +package org.apache.spark.shuffle.daos; + +import io.daos.BufferAllocator; +import io.daos.DaosIOException; +import io.daos.obj.DaosObject; +import io.daos.obj.IODataDesc; +import io.netty.buffer.ByteBuf; +import io.netty.util.internal.ObjectPool; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; + +/** + * A implementation of {@link DaosWriter} bases on synchronous DAOS Object API. + * + * For each partition, there is one corresponding {@link NativeBuffer} creates + * {@link IODataDesc} and write to DAOS in either caller thread or other dedicated thread. + */ +public class DaosWriterSync extends TaskSubmitter implements DaosWriter { + + private DaosObject object; + + private String mapId; + + private WriteParam param; + + private WriterConfig config; + + private Map writerMap; + + private NativeBuffer[] partitionBufArray; + + private int totalTimeoutTimes; + + private int totalWriteTimes; + + private int totalBySelfTimes; + + private volatile boolean cleaned; + + private static Logger LOG = LoggerFactory.getLogger(DaosWriterSync.class); + + /** + * construct DaosWriter with object and dedicated read executors. + * + * @param param + * write parameters + * @param object + * opened DaosObject + * @param executor + * null means write in caller's thread. Submit {@link WriteTask} to it otherwise. + */ + public DaosWriterSync(DaosObject object, DaosWriterSync.WriteParam param, + BoundThreadExecutors.SingleThreadExecutor executor) { + super(executor); + this.param = param; + this.config = param.config; + this.partitionBufArray = new NativeBuffer[param.numPartitions]; + this.mapId = String.valueOf(param.mapId); + this.object = object; + } + + private NativeBuffer getNativeBuffer(int partitionId) { + NativeBuffer buffer = partitionBufArray[partitionId]; + if (buffer == null) { + buffer = new NativeBuffer(partitionId, config.getBufferSize()); + partitionBufArray[partitionId] = buffer; + } + return buffer; + } + + @Override + public void write(int partitionId, int b) { + getNativeBuffer(partitionId).write(b); + } + + @Override + public void write(int partitionId, byte[] array) { + getNativeBuffer(partitionId).write(array); + } + + @Override + public void write(int partitionId, byte[] array, int offset, int len) { + getNativeBuffer(partitionId).write(array, offset, len); + } + + @Override + public long[] getPartitionLens(int numPartitions) { + if (LOG.isDebugEnabled()) { + LOG.debug("partition map size: " + partitionBufArray.length); + for (int i = 0; i < numPartitions; i++) { + NativeBuffer nb = partitionBufArray[i]; + if (nb != null) { + LOG.debug("id: " + i + ", native buffer: " + nb.partitionId + ", " + + nb.totalSize + ", " + nb.roundSize); + } + } + } + long[] lens = new long[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + NativeBuffer nb = partitionBufArray[i]; + if (nb != null) { + lens[i] = nb.totalSize; + if (nb.roundSize != 0 || !nb.bufList.isEmpty()) { + throw new IllegalStateException("round size should be 0, " + nb.roundSize + ", buflist should be empty, " + + nb.bufList.size()); + } + } else { + lens[i] = 0; + } + } + return lens; + } + + @Override + public void flush(int partitionId) throws IOException { + NativeBuffer buffer = partitionBufArray[partitionId]; + if (buffer == null) { + return; + } + IODataDesc desc = buffer.createUpdateDesc(); + if (desc == null) { + return; + } + totalWriteTimes++; + if (config.isWarnSmallWrite() && buffer.roundSize < config.getMinSize()) { + LOG.warn("too small partition size {}, shuffle {}, map {}, partition {}", + buffer.roundSize, param.shuffleId, mapId, partitionId); + } + if (executor == null) { // run write by self + runBySelf(desc, buffer); + return; + } + submitToOtherThreads(desc, buffer); + } + + private void runBySelf(IODataDesc desc, NativeBuffer buffer) throws IOException { + totalBySelfTimes++; + try { + object.update(desc); + } catch (IOException e) { + throw new IOException("failed to write partition of " + desc, e); + } finally { + desc.release(); + buffer.reset(true); + } + } + + private void submitToOtherThreads(IODataDesc desc, NativeBuffer buffer) throws IOException { + // move forward to release write buffers + moveForward(); + // check if we need to wait submitted tasks to be executed + if (goodForSubmit()) { + submitAndReset(desc, buffer); + return; + } + // to wait + int timeoutTimes = 0; + try { + while (!goodForSubmit()) { + boolean timeout = waitForCondition(config.getWaitTimeMs()); + moveForward(); + if (timeout) { + timeoutTimes++; + if (LOG.isDebugEnabled()) { + LOG.debug("wait daos write timeout times: " + timeoutTimes); + } + if (timeoutTimes >= config.getTimeoutTimes()) { + totalTimeoutTimes += timeoutTimes; + runBySelf(desc, buffer); + return; + } + } + } + } catch (InterruptedException e) { + desc.release(); + Thread.currentThread().interrupt(); + throw new IOException("interrupted when wait daos write", e); + } + // submit write task after some wait + totalTimeoutTimes += timeoutTimes; + submitAndReset(desc, buffer); + } + + private boolean goodForSubmit() { + return totalInMemSize < config.getTotalInMemSize() && totalSubmitted < config.getTotalSubmittedLimit(); + } + + private void submitAndReset(IODataDesc desc, NativeBuffer buffer) { + try { + submit(desc, buffer.bufList); + } finally { + buffer.reset(false); + } + } + + private void cleanup(boolean force) { + if (cleaned) { + return; + } + boolean allReleased = true; + allReleased &= cleanupSubmitted(force); + allReleased &= cleanupConsumed(force); + if (allReleased) { + cleaned = true; + } + } + + /** + * wait write task to be completed and clean up resources. + */ + @Override + public void close() { + try { + close(true); + } catch (Exception e) { + throw new IllegalStateException("failed to complete all write tasks and cleanup", e); + } + } + + private void close(boolean force) throws Exception { + if (partitionBufArray != null) { + waitCompletion(force); + partitionBufArray = null; + object = null; + if (LOG.isDebugEnabled()) { + LOG.debug("total writes: " + totalWriteTimes + ", total timeout times: " + totalTimeoutTimes + + ", total write-by-self times: " + totalBySelfTimes + ", total timeout times/total writes: " + + ((float) totalTimeoutTimes) / totalWriteTimes); + } + } + cleanup(force); + if (writerMap != null && (force || cleaned)) { + writerMap.remove(this); + writerMap = null; + } + } + + private void waitCompletion(boolean force) throws Exception { + if (!force) { + return; + } + try { + while (totalSubmitted > 0) { + waitForCondition(config.getWaitTimeMs()); + moveForward(); + } + } catch (Exception e) { + LOG.error("failed to wait completion of daos writing", e); + throw e; + } + } + + public void setWriterMap(Map writerMap) { + writerMap.put(this, 0); + this.writerMap = writerMap; + } + + @Override + protected Runnable newTask(LinkedTaskContext context) { + return WriteTask.newInstance((WriteTaskContext) context); + } + + @Override + protected LinkedTaskContext createTaskContext(IODataDesc desc, Object morePara) { + return new WriteTaskContext(object, counter, lock, condition, desc, morePara); + } + + @Override + protected boolean validateReturned(LinkedTaskContext context) throws IOException { + if (!context.desc.isSucceeded()) { + throw new DaosIOException("write is not succeeded: " + context.desc); + } + return false; + } + + @Override + protected boolean consumed(LinkedTaskContext context) { + // release write buffers + @SuppressWarnings("unchecked") + List bufList = (List) context.morePara; + bufList.forEach(b -> b.release()); + bufList.clear(); + return true; + } + + /** + * Write data to one or multiple netty direct buffers which will be written to DAOS without copy + */ + private class NativeBuffer implements Comparable { + private int partitionId; + private String partitionIdKey; + private int bufferSize; + private int idx = -1; + private List bufList = new ArrayList<>(); + private long totalSize; + private long roundSize; + + NativeBuffer(int partitionId, int bufferSize) { + this.partitionId = partitionId; + this.partitionIdKey = String.valueOf(partitionId); + this.bufferSize = bufferSize; + } + + private ByteBuf addNewByteBuf(int len) { + ByteBuf buf; + try { + buf = BufferAllocator.objBufWithNativeOrder(Math.max(bufferSize, len)); + } catch (OutOfMemoryError e) { + LOG.error("too big buffer size: " + Math.max(bufferSize, len)); + throw e; + } + bufList.add(buf); + idx++; + return buf; + } + + private ByteBuf getBuffer(int len) { + if (idx < 0) { + return addNewByteBuf(len); + } + return bufList.get(idx); + } + + public void write(int b) { + ByteBuf buf = getBuffer(1); + if (buf.writableBytes() < 1) { + buf = addNewByteBuf(1); + } + buf.writeByte(b); + roundSize += 1; + } + + public void write(byte[] b) { + write(b, 0, b.length); + } + + public void write(byte[] b, int offset, int len) { + if (len <= 0) { + return; + } + ByteBuf buf = getBuffer(len); + int avail = buf.writableBytes(); + int gap = len - avail; + if (gap <= 0) { + buf.writeBytes(b, offset, len); + } else { + buf.writeBytes(b, offset, avail); + buf = addNewByteBuf(gap); + buf.writeBytes(b, avail, gap); + } + roundSize += len; + } + + public IODataDesc createUpdateDesc() throws IOException { + if (roundSize == 0 || bufList.isEmpty()) { + return null; + } + long bufSize = 0; + IODataDesc desc = object.createDataDescForUpdate(partitionIdKey, IODataDesc.IodType.ARRAY, 1); + for (ByteBuf buf : bufList) { + desc.addEntryForUpdate(mapId, (int) totalSize, buf); + bufSize += buf.readableBytes(); + } + if (roundSize != bufSize) { + throw new IOException("expect update size: " + roundSize + ", actual: " + bufSize); + } + return desc; + } + + public void reset(boolean release) { + if (release) { + bufList.forEach(b -> b.release()); + } + // release==false, buffers will be released when tasks are executed and consumed + bufList.clear(); + idx = -1; + totalSize += roundSize; + roundSize = 0; + } + + @Override + public int compareTo(NativeBuffer nativeBuffer) { + return partitionId - nativeBuffer.partitionId; + } + } + + public static class WriteParam { + private int numPartitions; + private int shuffleId; + private long mapId; + private DaosWriter.WriterConfig config; + + public WriteParam numPartitions(int numPartitions) { + this.numPartitions = numPartitions; + return this; + } + + public WriteParam shuffleId(int shuffleId) { + this.shuffleId = shuffleId; + return this; + } + + public WriteParam mapId(long mapId) { + this.mapId = mapId; + return this; + } + + public WriteParam config(DaosWriter.WriterConfig config) { + this.config = config; + return this; + } + } + + /** + * Task to write data to DAOS. Task itself is cached to reduce GC time. + * To reuse task for different writes, prepare and reset {@link WriteTaskContext} by calling + * {@link #newInstance(WriteTaskContext)} + */ + static final class WriteTask implements Runnable { + private final ObjectPool.Handle handle; + private WriteTaskContext context; + + private static final ObjectPool objectPool = ObjectPool.newPool(handle -> new WriteTask(handle)); + + private static final Logger log = LoggerFactory.getLogger(WriteTask.class); + + static WriteTask newInstance(WriteTaskContext context) { + WriteTask task = objectPool.get(); + task.context = context; + return task; + } + + private WriteTask(ObjectPool.Handle handle) { + this.handle = handle; + } + + @Override + public void run() { + boolean cancelled = context.cancelled; + try { + if (!cancelled) { + context.object.update(context.desc); + } + } catch (Exception e) { + log.error("failed to write for " + context.desc, e); + } finally { + context.desc.release(); + context.signal(); + context = null; + handle.recycle(this); + } + } + } + + /** + * Context for write task. It holds all other object to read and sync between caller thread and write thread. + * It should be cached in caller thread for reusing. + */ + static final class WriteTaskContext extends LinkedTaskContext { + + /** + * constructor with all parameters. Some of them can be reused later. + * + * @param object + * DAOS object to fetch data from DAOS + * @param counter + * counter to indicate how many write is on-going + * @param writeLock + * lock to work with notFull condition to signal caller thread to submit more write task + * @param notFull + * condition to signal caller thread + * @param desc + * desc object to describe where to write data + * @param bufList + * list of buffers to write to DAOS + */ + WriteTaskContext(DaosObject object, AtomicInteger counter, Lock writeLock, Condition notFull, + IODataDesc desc, Object bufList) { + super(object, counter, writeLock, notFull); + this.desc = desc; + @SuppressWarnings("unchecked") + List myBufList = new ArrayList<>(); + myBufList.addAll((List) bufList); + this.morePara = myBufList; + } + + @Override + public WriteTaskContext getNext() { + @SuppressWarnings("unchecked") + WriteTaskContext ctx = (WriteTaskContext) next; + return ctx; + } + + @Override + public void reuse(IODataDesc desc, Object morePara) { + @SuppressWarnings("unchecked") + List myBufList = (List) this.morePara; + if (!myBufList.isEmpty()) { + throw new IllegalStateException("bufList in reusing write task context should be empty"); + } + myBufList.addAll((List) morePara); + super.reuse(desc, myBufList); + } + } + + /** + * Thread factory for write + */ + protected static class WriteThreadFactory implements ThreadFactory { + private AtomicInteger id = new AtomicInteger(0); + + @Override + public Thread newThread(Runnable runnable) { + Thread t; + String name = "daos_write_" + id.getAndIncrement(); + if (runnable == null) { + t = new Thread(name); + } else { + t = new Thread(runnable, name); + } + t.setDaemon(true); + t.setUncaughtExceptionHandler((thread, throwable) -> + LOG.error("exception occurred in thread " + name, throwable)); + return t; + } + } +} diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManager.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManager.java new file mode 100644 index 00000000..aebea83f --- /dev/null +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManager.java @@ -0,0 +1,88 @@ +/* + * (C) Copyright 2018-2020 Intel Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * GOVERNMENT LICENSE RIGHTS-OPEN SOURCE SOFTWARE + * The Government's rights to use, modify, reproduce, release, perform, display, + * or disclose this software are subject to the terms of the Apache License as + * provided in Contract No. B609815. + * Any reproduction of computer software, computer software documentation, or + * portions thereof marked with this legend must also reproduce the markings. + */ + +package org.apache.spark.shuffle.daos; + +import io.daos.obj.DaosObjClient; +import io.daos.obj.DaosObject; +import io.daos.obj.DaosObjectException; +import io.daos.obj.DaosObjectId; +import org.apache.spark.SparkConf; + +import java.io.IOException; +import java.util.Map; + +public abstract class IOManager { + + protected Map objectMap; + + protected SparkConf conf; + + protected DaosObjClient objClient; + + protected IOManager(SparkConf conf, Map objectMap) { + this.conf = conf; + this.objectMap = objectMap; + } + + private String getKey(long appId, int shuffleId) { + return appId + "" + shuffleId; + } + + protected static long parseAppId(String appId) { + return Long.valueOf(appId.replaceAll("\\D", "")); + } + + protected DaosObject getObject(long appId, int shuffleId) throws DaosObjectException { + String key = getKey(appId, shuffleId); + DaosObject object = objectMap.get(key); + if (object == null) { + DaosObjectId id = new DaosObjectId(appId, shuffleId); + id.encode(); + object = objClient.getObject(id); + objectMap.putIfAbsent(key, object); + DaosObject activeObject = objectMap.get(key); + if (activeObject != object) { // release just created DaosObject + object.close(); + object = activeObject; + } + } + // open just once in multiple threads + if (!object.isOpen()) { + synchronized (object) { + object.open(); + } + } + return object; + } + + public void setObjClient(DaosObjClient objClient) { + this.objClient = objClient; + } + + abstract DaosWriter getDaosWriter(int numPartitions, int shuffleId, long mapId) throws IOException; + + abstract DaosReader getDaosReader(int shuffleId) throws IOException; + + abstract void close() throws IOException; +} diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManagerAsync.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManagerAsync.java new file mode 100644 index 00000000..122493c3 --- /dev/null +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManagerAsync.java @@ -0,0 +1,29 @@ +package org.apache.spark.shuffle.daos; + +import io.daos.obj.DaosObject; +import org.apache.spark.SparkConf; + +import java.io.IOException; +import java.util.Map; + +public class IOManagerAsync extends IOManager { + + public IOManagerAsync(SparkConf conf, Map objectMap) { + super(conf, objectMap); + } + + @Override + DaosWriter getDaosWriter(int numPartitions, int shuffleId, long mapId) throws IOException { + return null; + } + + @Override + DaosReader getDaosReader(int shuffleId) throws IOException { + return null; + } + + @Override + void close() throws IOException { + + } +} diff --git a/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManagerSync.java b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManagerSync.java new file mode 100644 index 00000000..d49f79bf --- /dev/null +++ b/shuffle-daos/src/main/java/org/apache/spark/shuffle/daos/IOManagerSync.java @@ -0,0 +1,113 @@ +package org.apache.spark.shuffle.daos; + +import io.daos.obj.DaosObject; +import org.apache.spark.SparkConf; +import org.apache.spark.launcher.SparkLauncher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class IOManagerSync extends IOManager { + + private BoundThreadExecutors readerExes; + + private BoundThreadExecutors writerExes; + + private DaosReader.ReaderConfig readerConfig; + + private DaosWriter.WriterConfig writerConfig; + + private Map readerMap = new ConcurrentHashMap<>(); + + private Map writerMap = new ConcurrentHashMap<>(); + + private Logger logger = LoggerFactory.getLogger(IOManagerSync.class); + + public IOManagerSync(SparkConf conf, Map objectMap) { + super(conf, objectMap); + readerConfig = new DaosReader.ReaderConfig(); + writerConfig = new DaosWriter.WriterConfig(); + readerExes = createReaderExes(); + writerExes = createWriterExes(); + } + + private BoundThreadExecutors createWriterExes() { + if (writerConfig.isFromOtherThreads()) { + BoundThreadExecutors executors; + int threads = writerConfig.getThreads(); + if (threads == -1) { + threads = conf.getInt(SparkLauncher.EXECUTOR_CORES, 1); + } + executors = new BoundThreadExecutors("write_executors", threads, + new DaosWriterSync.WriteThreadFactory()); + logger.info("created BoundThreadExecutors with " + threads + " threads for write"); + return executors; + } + return null; + } + + private BoundThreadExecutors createReaderExes() { + if (readerConfig.isFromOtherThread()) { + BoundThreadExecutors executors; + int threads = (int)conf.get(package$.MODULE$.SHUFFLE_DAOS_READ_THREADS()); + if (threads == -1) { + threads = conf.getInt(SparkLauncher.EXECUTOR_CORES, 1); + } + executors = new BoundThreadExecutors("read_executors", threads, + new DaosReaderSync.ReadThreadFactory()); + logger.info("created BoundThreadExecutors with " + threads + " threads for read"); + return executors; + } + return null; + } + + @Override + public DaosWriterSync getDaosWriter(int numPartitions, int shuffleId, long mapId) throws IOException { + long appId = parseAppId(conf.getAppId()); + if (logger.isDebugEnabled()) { + logger.debug("getting daoswriter for app id: " + appId + ", shuffle id: " + shuffleId + ", map id: " + mapId + + ", numPartitions: " + numPartitions); + } + DaosWriterSync.WriteParam param = new DaosWriterSync.WriteParam(); + param.numPartitions(numPartitions) + .shuffleId(shuffleId) + .mapId(mapId) + .config(writerConfig); + DaosWriterSync writer = new DaosWriterSync(getObject(appId, shuffleId), param, + writerExes == null ? null : writerExes.nextExecutor()); + writer.setWriterMap(writerMap); + return writer; + } + + @Override + DaosReader getDaosReader(int shuffleId) throws IOException { + long appId = parseAppId(conf.getAppId()); + if (logger.isDebugEnabled()) { + logger.debug("getting daosreader for app id: " + appId + ", shuffle id: " + shuffleId); + } + DaosReaderSync reader = new DaosReaderSync(getObject(appId, shuffleId), readerConfig, + readerExes == null ? null : readerExes.nextExecutor()); + reader.setReaderMap(readerMap); + return reader; + } + + @Override + void close() throws IOException { + if (readerExes != null) { + readerExes.stop(); + readerExes = null; + } + readerMap.keySet().forEach(r -> r.close(true)); + readerMap.clear(); + if (writerExes != null) { + writerExes.stop(); + writerExes = null; + } + writerMap.keySet().forEach(r -> r.close()); + writerMap.clear(); + objClient.forceClose(); + } +} diff --git a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/MapPartitionsWriter.scala b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/MapPartitionsWriter.scala index d8ebb83e..fc7dd323 100644 --- a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/MapPartitionsWriter.scala +++ b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/MapPartitionsWriter.scala @@ -29,7 +29,6 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.collection.SizeTrackingAppendOnlyMap class MapPartitionsWriter[K, V, C]( shuffleId: Int, @@ -88,7 +87,6 @@ class MapPartitionsWriter[K, V, C]( def insertAll(records: Iterator[Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high - val start = System.nanoTime(); val shouldCombine = aggregator.isDefined if (shouldCombine) { // Combine values in-memory first using our AppendOnlyMap diff --git a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/package.scala b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/package.scala index f38f68a9..91f8b181 100644 --- a/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/package.scala +++ b/shuffle-daos/src/main/scala/org/apache/spark/shuffle/daos/package.scala @@ -42,12 +42,6 @@ package object daos { .stringConf .createWithDefault(null) - val SHUFFLE_DAOS_POOL_RANKS = - ConfigBuilder("spark.shuffle.daos.ranks") - .version("3.0.0") - .stringConf - .createWithDefault("0") - val SHUFFLE_DAOS_REMOVE_SHUFFLE_DATA = ConfigBuilder("spark.shuffle.remove.shuffle.data") .doc("remove shuffle data from DAOS after shuffle completed. Default is true") @@ -146,6 +140,15 @@ package object daos { s"The DAOS write max bytes in flight must be positive") .createWithDefaultString("20480k") + val SHUFFLE_DAOS_IO_ASYNC = + ConfigBuilder("spark.shuffle.daos.io.async") + .doc("perform shuffle IO asynchronously. Default is true") + .version("3.0.0") + .booleanConf + .createWithDefault(true) + + /* =====configs below for DAOS synchronous API===== */ + val SHUFFLE_DAOS_READ_THREADS = ConfigBuilder("spark.shuffle.daos.read.threads") .doc("number of threads for each executor to read shuffle data concurrently. -1 means use number of executor " + diff --git a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleIOTest.java b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleIOTest.java index 11bbdc6e..42135162 100644 --- a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleIOTest.java +++ b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleIOTest.java @@ -26,7 +26,9 @@ import io.daos.obj.DaosObjClient; import io.daos.obj.DaosObject; import io.daos.obj.DaosObjectId; +import org.apache.hadoop.security.UserGroupInformation; import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -46,7 +48,7 @@ @RunWith(PowerMockRunner.class) @PowerMockIgnore("javax.management.*") -@PrepareForTest(DaosShuffleIO.class) +@PrepareForTest(IOManager.class) @SuppressStaticInitializationFor("io.daos.obj.DaosObjClient") public class DaosShuffleIOTest { @@ -54,53 +56,63 @@ public class DaosShuffleIOTest { public void testSingleObjectInstanceOpen() throws Exception { SparkConf testConf = new SparkConf(false); testConf.set(package$.MODULE$.SHUFFLE_DAOS_READ_FROM_OTHER_THREAD(), false); + testConf.set(package$.MODULE$.SHUFFLE_DAOS_IO_ASYNC(), false); long appId = 1234567; int shuffleId = 1; testConf.set("spark.app.id", String.valueOf(appId)); - Field clientField = DaosShuffleIO.class.getDeclaredField("objClient"); + Field clientField = IOManager.class.getDeclaredField("objClient"); clientField.setAccessible(true); - DaosShuffleIO io = new DaosShuffleIO(testConf); - DaosObjectId id = PowerMockito.mock(DaosObjectId.class); - PowerMockito.whenNew(DaosObjectId.class).withArguments(appId, Long.valueOf(shuffleId)).thenReturn(id); - Mockito.doNothing().when(id).encode(); - Mockito.when(id.isEncoded()).thenReturn(true); - DaosObject daosObject = PowerMockito.mock(DaosObject.class); - DaosObjClient client = PowerMockito.mock(DaosObjClient.class); - Mockito.when(client.getObject(id)).thenReturn(daosObject); + UserGroupInformation.setLoginUser(UserGroupInformation.createRemoteUser("test")); + SparkContext sc = new SparkContext("local", "test", testConf); - AtomicBoolean open = new AtomicBoolean(false); - Mockito.when(daosObject.isOpen()).then(invocationOnMock -> - open.get() - ); - Mockito.doAnswer(invocationOnMock -> { - open.compareAndSet(false, true); - return invocationOnMock; - }).when(daosObject).open(); - clientField.set(io, client); + try { - int numThreads = 50; - ExecutorService es = Executors.newFixedThreadPool(numThreads); - AtomicInteger count = new AtomicInteger(0); + DaosShuffleIO io = new DaosShuffleIO(testConf); - Runnable r = () -> { - try { - DaosReader reader = io.getDaosReader(shuffleId); - if (reader.getObject() == daosObject && reader.getObject().isOpen()) { - count.incrementAndGet(); + DaosObjectId id = PowerMockito.mock(DaosObjectId.class); + PowerMockito.whenNew(DaosObjectId.class).withArguments(appId, Long.valueOf(shuffleId)).thenReturn(id); + Mockito.doNothing().when(id).encode(); + Mockito.when(id.isEncoded()).thenReturn(true); + DaosObject daosObject = PowerMockito.mock(DaosObject.class); + DaosObjClient client = PowerMockito.mock(DaosObjClient.class); + Mockito.when(client.getObject(id)).thenReturn(daosObject); + + AtomicBoolean open = new AtomicBoolean(false); + Mockito.when(daosObject.isOpen()).then(invocationOnMock -> + open.get() + ); + Mockito.doAnswer(invocationOnMock -> { + open.compareAndSet(false, true); + return invocationOnMock; + }).when(daosObject).open(); + clientField.set(io.getIoManager(), client); + + int numThreads = 50; + ExecutorService es = Executors.newFixedThreadPool(numThreads); + AtomicInteger count = new AtomicInteger(0); + + Runnable r = () -> { + try { + DaosReader reader = io.getDaosReader(shuffleId); + if (reader.getObject() == daosObject && reader.getObject().isOpen()) { + count.incrementAndGet(); + } + } catch (Exception e) { + e.printStackTrace(); } - } catch (Exception e) { - e.printStackTrace(); - } - }; + }; - for (int i = 0; i < numThreads; i++) { - es.submit(r); - } + for (int i = 0; i < numThreads; i++) { + es.submit(r); + } - es.shutdown(); - es.awaitTermination(5, TimeUnit.SECONDS); + es.shutdown(); + es.awaitTermination(5, TimeUnit.SECONDS); - Assert.assertEquals(50, count.intValue()); + Assert.assertEquals(50, count.intValue()); + } finally { + sc.stop(); + } } } diff --git a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleInputStreamTest.java b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleInputStreamTest.java index 8b47c6d8..35be8d53 100644 --- a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleInputStreamTest.java +++ b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosShuffleInputStreamTest.java @@ -129,7 +129,7 @@ public void testReadFromOtherThreadCancelMultipleTimesLongWait() throws Exceptio Map maps = new HashMap<>(); maps.put("2", new AtomicInteger(0)); maps.put("4", new AtomicInteger(0)); - readFromOtherThreadCancelMultipleTimes(maps, 400); + readFromOtherThreadCancelMultipleTimes(maps, 800); } @Test @@ -175,9 +175,9 @@ public void readFromOtherThreadCancelMultipleTimes(Map ma IODataDesc.Entry entry = desc.getEntry(0); String mapId = entry.getKey(); if (maps.containsKey(mapId)) { -// Thread thread = maps.get(mapId); + // Thread thread = maps.get(mapId); if (callerThread != Thread.currentThread()) { -// wait.incrementAndGet(); + // wait.incrementAndGet(); // sleep to cause read timeout System.out.println("sleeping at " + mapId); Thread.sleep(waitDataTimeMs + addWaitTimeMs); @@ -317,8 +317,8 @@ private void read(int maps, Answer answer, Mockito.doAnswer(answer).when(daosObject).fetch(any(IODataDesc.class)); BoundThreadExecutors executors = new BoundThreadExecutors("read_executors", 1, - new DaosReader.ReadThreadFactory()); - DaosReader daosReader = new DaosReader(daosObject, executors); + new DaosReaderSync.ReadThreadFactory()); + DaosReaderSync daosReader = new DaosReaderSync(daosObject, new DaosReader.ReaderConfig(), executors.nextExecutor()); LinkedHashMap, Tuple3> partSizeMap = new LinkedHashMap<>(); int shuffleId = 10; int reduceId = 1; @@ -355,7 +355,7 @@ private void read(int maps, Answer answer, System.out.println("total fetch wait time: " + taskContext.taskMetrics().shuffleReadMetrics()._fetchWaitTime().sum()); } finally { - daosReader.close(); + daosReader.close(true); is.close(true); context.stop(); if (executors != null) { diff --git a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosWriterTest.java b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosWriterTest.java index 60c45269..f25ac74d 100644 --- a/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosWriterTest.java +++ b/shuffle-daos/src/test/java/org/apache/spark/shuffle/daos/DaosWriterTest.java @@ -27,8 +27,12 @@ import io.daos.obj.DaosObject; import io.daos.obj.DaosObjectId; import io.daos.obj.IODataDesc; +import org.apache.hadoop.security.UserGroupInformation; import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.junit.AfterClass; import org.junit.Assert; +import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mockito; @@ -51,16 +55,26 @@ @SuppressStaticInitializationFor("io.daos.obj.DaosObjClient") public class DaosWriterTest { + private static SparkConf testConf = new SparkConf(false); + + private static SparkContext sc; + + @BeforeClass + public static void initialize() { + UserGroupInformation.setLoginUser(UserGroupInformation.createRemoteUser("test")); + sc = new SparkContext("local", "test", testConf); + } + @Test public void testGetLensWithAllEmptyPartitions() { - DaosWriter.WriteConfig writeConfig = DaosShuffleIO.loadWriteConfig(new SparkConf(false)); - DaosWriter.WriteParam param = new DaosWriter.WriteParam(); + DaosWriter.WriterConfig writeConfig = new DaosWriter.WriterConfig(); + DaosWriterSync.WriteParam param = new DaosWriterSync.WriteParam(); int numPart = 10; param.numPartitions(numPart) .shuffleId(1) .mapId(1) .config(writeConfig); - DaosWriter writer = new DaosWriter(param, null, null); + DaosWriterSync writer = new DaosWriterSync(null, param, null); long[] lens = writer.getPartitionLens(numPart); Assert.assertEquals(numPart, lens.length); for (int i = 0; i < numPart; i++) { @@ -81,14 +95,14 @@ public void testGetLensWithPartialEmptyPartitions() throws Exception { Mockito.doNothing().when(daosObject).update(any(IODataDesc.class)); - DaosWriter.WriteConfig writeConfig = DaosShuffleIO.loadWriteConfig(new SparkConf(false)); - DaosWriter.WriteParam param = new DaosWriter.WriteParam(); + DaosWriter.WriterConfig writeConfig = new DaosWriter.WriterConfig(); + DaosWriterSync.WriteParam param = new DaosWriterSync.WriteParam(); int numPart = 10; param.numPartitions(numPart) .shuffleId(1) .mapId(1) .config(writeConfig); - DaosWriter writer = new DaosWriter(param, daosObject, null); + DaosWriterSync writer = new DaosWriterSync(daosObject, param, null); Map expectedLens = new HashMap<>(); Random random = new Random(); for (int i = 0; i < 5; i++) { @@ -133,8 +147,8 @@ public void testWriteTaskFailed() throws Exception { return invoc; }).when(daosObject).update(any(IODataDesc.class)); - DaosWriter.WriteConfig writeConfig = DaosShuffleIO.loadWriteConfig(new SparkConf(false)); - DaosWriter.WriteParam param = new DaosWriter.WriteParam(); + DaosWriter.WriterConfig writeConfig = new DaosWriter.WriterConfig(); + DaosWriterSync.WriteParam param = new DaosWriterSync.WriteParam(); int numPart = 10; param.numPartitions(numPart) .shuffleId(1) @@ -142,8 +156,8 @@ public void testWriteTaskFailed() throws Exception { .config(writeConfig); BoundThreadExecutors executors = new BoundThreadExecutors("read_executors", 1, - new DaosReader.ReadThreadFactory()); - DaosWriter writer = new DaosWriter(param, daosObject, executors.nextExecutor()); + new DaosReaderSync.ReadThreadFactory()); + DaosWriterSync writer = new DaosWriterSync(daosObject, param, executors.nextExecutor()); for (int i = 0; i < numPart; i++) { writer.write(i, new byte[100]); writer.flush(i); @@ -153,4 +167,11 @@ public void testWriteTaskFailed() throws Exception { executors.stop(); } + + @AfterClass + public static void teardown() { + if (sc != null) { + sc.stop(); + } + } } diff --git a/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleReaderSuite.scala b/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleReaderSuite.scala index 6a9591ad..a6975a2b 100644 --- a/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleReaderSuite.scala +++ b/shuffle-daos/src/test/scala/org/apache/spark/shuffle/daos/DaosShuffleReaderSuite.scala @@ -35,21 +35,22 @@ import org.apache.spark._ import org.apache.spark.internal.config import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.shuffle.BaseShuffleHandle +import org.apache.spark.shuffle.daos.DaosReader.ReaderConfig import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class DaosShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { override def beforeAll(): Unit = { super.beforeAll() - logInfo("start executors in DaosReader " + classOf[DaosReader]) + logInfo("start executors in DaosReaderSync " + classOf[DaosReaderSync]) MockitoAnnotations.initMocks(this) } private def mockObjectsForSingleDaosCall(reduceId: Int, numMaps: Int, byteOutputStream: ByteArrayOutputStream): (DaosReader, DaosShuffleIO, DaosObject) = { // mock - val daosReader: DaosReader = Mockito.mock(classOf[DaosReader]) val daosObject = Mockito.mock(classOf[DaosObject]) + val daosReader: DaosReaderSync = new DaosReaderSync(daosObject, new ReaderConfig(), null) val shuffleIO = Mockito.mock(classOf[DaosShuffleIO]) val desc = Mockito.mock(classOf[IODataDesc]) @@ -74,8 +75,10 @@ class DaosShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { (DaosReader, DaosShuffleIO, DaosObject) = { // mock val daosObject = Mockito.mock(classOf[DaosObject]) - val daosReader: DaosReader = - if (executors != null) Mockito.spy(new DaosReader(daosObject, executors)) else Mockito.mock(classOf[DaosReader]) + val daosReader: DaosReaderSync = + if (executors != null) Mockito.spy(new DaosReaderSync(daosObject, new DaosReader.ReaderConfig(), + executors.nextExecutor())) + else new DaosReaderSync(daosObject, new ReaderConfig(), null) val shuffleIO = Mockito.mock(classOf[DaosShuffleIO]) val descList = new util.ArrayList[IODataDesc] @@ -177,7 +180,7 @@ class DaosShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { } when(shuffleIO.getDaosReader(shuffleId)).thenReturn(daosReader) - when(daosReader.getObject).thenReturn(daosObject) + // when(daosReader.getObject).thenReturn(daosObject) val shuffleReader = new DaosShuffleReader[Int, Int]( shuffleHandle, @@ -201,12 +204,12 @@ class DaosShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { } test("test reader daos multiple times") { - testRead(7168, 4, false) + testRead(1024, 4, false) } test("test reader daos multiple times from other thread") { - val executors = new BoundThreadExecutors("read_executors", 1, new DaosReader.ReadThreadFactory) - testRead(7168, 6, false, executors) + val executors = new BoundThreadExecutors("read_executors", 1, new DaosReaderSync.ReadThreadFactory) + testRead(1024, 6, false, executors) executors.stop() } } diff --git a/shuffle-hadoop/src/main/java/org/apache/spark/shuffle/sort/RemoteBypassMergeSortShuffleWriter.java b/shuffle-hadoop/src/main/java/org/apache/spark/shuffle/sort/RemoteBypassMergeSortShuffleWriter.java index 71934d10..8b06b7d2 100644 --- a/shuffle-hadoop/src/main/java/org/apache/spark/shuffle/sort/RemoteBypassMergeSortShuffleWriter.java +++ b/shuffle-hadoop/src/main/java/org/apache/spark/shuffle/sort/RemoteBypassMergeSortShuffleWriter.java @@ -38,7 +38,6 @@ import org.slf4j.LoggerFactory; import org.apache.spark.*; -import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; diff --git a/shuffle-hadoop/src/main/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleWriter.java b/shuffle-hadoop/src/main/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleWriter.java index f87d35d1..0e0ba91c 100644 --- a/shuffle-hadoop/src/main/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleWriter.java +++ b/shuffle-hadoop/src/main/java/org/apache/spark/shuffle/sort/RemoteUnsafeShuffleWriter.java @@ -44,7 +44,6 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.annotation.Private; -import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.internal.config.package$; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$;