From d8894f11dd29e12a8474ba34fdbf6a1c4b70d473 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sat, 7 Aug 2021 23:08:31 -0700 Subject: [PATCH] [djl-serving] Add dynamic batch feature back Change-Id: I5a70eb894e749f02377a534509881ce25211c2e8 --- api/src/main/java/ai/djl/modality/Output.java | 20 +++++++------- .../translate/ServingTranslatorFactory.java | 21 ++++++++++----- .../http/ManagementRequestHandler.java | 6 ++--- .../ai/djl/serving/wlm/BatchAggregator.java | 27 ++++++++++++++++--- .../java/ai/djl/serving/wlm/ModelManager.java | 3 +++ .../serving/wlm/PermanentBatchAggregator.java | 3 ++- .../serving/wlm/TemporaryBatchAggregator.java | 2 +- 7 files changed, 55 insertions(+), 27 deletions(-) diff --git a/api/src/main/java/ai/djl/modality/Output.java b/api/src/main/java/ai/djl/modality/Output.java index 965f12c07c0..be8f942c792 100644 --- a/api/src/main/java/ai/djl/modality/Output.java +++ b/api/src/main/java/ai/djl/modality/Output.java @@ -27,26 +27,24 @@ public class Output { private byte[] content; /** - * Constructs a {@code Output} with specified {@code requestId}. + * Constructs a {@code Output} with specified {@code requestId}, {@code code} and {@code + * message}. * - * @param requestId the requestId of the output + * @param code the status code of the output + * @param message the status message of the output */ - public Output(String requestId) { - this.requestId = requestId; + public Output(int code, String message) { + this.code = code; + this.message = message; } /** - * Constructs a {@code Output} with specified {@code requestId}, {@code code} and {@code - * message}. + * Sets the requestId of the output. * * @param requestId the requestId of the output - * @param code the status code of the output - * @param message the status message of the output */ - public Output(String requestId, int code, String message) { + public void setRequestId(String requestId) { this.requestId = requestId; - this.code = code; - this.message = message; } /** diff --git a/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java b/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java index 2d5b707b3a0..c40e2b9acea 100644 --- a/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java +++ b/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java @@ -189,7 +189,11 @@ private Translator loadDefaultTranslator(Map arguments return getSsdTranslator(arguments); } } - return new RawTranslator(); + String batchifier = (String) arguments.get("batchifier"); + if (batchifier == null) { + return new RawTranslator(null); + } + return new RawTranslator(Batchifier.fromString(batchifier)); } private Translator getImageClassificationTranslator(Map arguments) { @@ -239,8 +243,7 @@ public Batchifier getBatchifier() { /** {@inheritDoc} */ @Override public Output processOutput(TranslatorContext ctx, NDList list) throws Exception { - Input input = (Input) ctx.getAttachment("input"); - Output output = new Output(input.getRequestId(), 200, "OK"); + Output output = new Output(200, "OK"); Object obj = translator.processOutput(ctx, list); if (obj instanceof JsonSerializable) { output.setContent(((JsonSerializable) obj).toJson() + '\n'); @@ -280,16 +283,21 @@ public void prepare(NDManager manager, Model model) throws IOException { private static final class RawTranslator implements Translator { + private Batchifier batchifier; + + RawTranslator(Batchifier batchifier) { + this.batchifier = batchifier; + } + /** {@inheritDoc} */ @Override public Batchifier getBatchifier() { - return null; + return batchifier; } /** {@inheritDoc} */ @Override public NDList processInput(TranslatorContext ctx, Input input) throws TranslateException { - ctx.setAttachment("input", input); PairList inputs = input.getContent(); byte[] data = inputs.get("data"); if (data == null) { @@ -309,8 +317,7 @@ public NDList processInput(TranslatorContext ctx, Input input) throws TranslateE /** {@inheritDoc} */ @Override public Output processOutput(TranslatorContext ctx, NDList list) { - Input input = (Input) ctx.getAttachment("input"); - Output output = new Output(input.getRequestId(), 200, "OK"); + Output output = new Output(200, "OK"); output.setContent(list.encode()); output.addProperty("Content-Type", "tensor/ndlist"); return output; diff --git a/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java b/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java index 088598d04e3..79d24fff0f3 100644 --- a/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java +++ b/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java @@ -55,7 +55,7 @@ public class ManagementRequestHandler extends HttpRequestHandler { /** HTTP Parameter "max_batch_delay". */ private static final String MAX_BATCH_DELAY_PARAMETER = "max_batch_delay"; /** HTTP Parameter "max_idle_time". */ - private static final String MAX_IDLE_TIME__PARAMETER = "max_idle_time"; + private static final String MAX_IDLE_TIME_PARAMETER = "max_idle_time"; /** HTTP Parameter "max_worker". */ private static final String MAX_WORKER_PARAMETER = "max_worker"; /** HTTP Parameter "min_worker". */ @@ -166,7 +166,7 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec String engineName = NettyUtils.getParameter(decoder, ENGINE_NAME_PARAMETER, null); int batchSize = NettyUtils.getIntParameter(decoder, BATCH_SIZE_PARAMETER, 1); int maxBatchDelay = NettyUtils.getIntParameter(decoder, MAX_BATCH_DELAY_PARAMETER, 100); - int maxIdleTime = NettyUtils.getIntParameter(decoder, MAX_IDLE_TIME__PARAMETER, 60); + int maxIdleTime = NettyUtils.getIntParameter(decoder, MAX_IDLE_TIME_PARAMETER, 60); int minWorkers = NettyUtils.getIntParameter(decoder, MIN_WORKER_PARAMETER, 1); int defaultWorkers = ConfigManager.getInstance().getDefaultWorkers(); int maxWorkers = NettyUtils.getIntParameter(decoder, MAX_WORKER_PARAMETER, defaultWorkers); @@ -239,7 +239,7 @@ private void handleScaleModel( int maxIdleTime = NettyUtils.getIntParameter( - decoder, MAX_IDLE_TIME__PARAMETER, modelInfo.getMaxIdleTime()); + decoder, MAX_IDLE_TIME_PARAMETER, modelInfo.getMaxIdleTime()); int batchSize = NettyUtils.getIntParameter( decoder, BATCH_SIZE_PARAMETER, modelInfo.getBatchSize()); diff --git a/serving/serving/src/main/java/ai/djl/serving/wlm/BatchAggregator.java b/serving/serving/src/main/java/ai/djl/serving/wlm/BatchAggregator.java index daa9f11a45e..25ed8515d35 100644 --- a/serving/serving/src/main/java/ai/djl/serving/wlm/BatchAggregator.java +++ b/serving/serving/src/main/java/ai/djl/serving/wlm/BatchAggregator.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; /** * abstract class for all BatchAggregators. A batch aggregator check working queue and combines @@ -28,6 +29,7 @@ abstract class BatchAggregator { protected int batchSize; + protected int maxBatchDelay; protected List jobs; protected LinkedBlockingDeque jobQueue; @@ -39,6 +41,7 @@ abstract class BatchAggregator { */ public BatchAggregator(ModelInfo model, LinkedBlockingDeque jobQueue) { this.batchSize = model.getBatchSize(); + this.maxBatchDelay = model.getMaxBatchDelay(); this.jobQueue = jobQueue; jobs = new ArrayList<>(); } @@ -72,11 +75,8 @@ public void sendResponse(List outputs) { int i = 0; for (Output output : outputs) { - String requestId = output.getRequestId(); Job job = jobs.get(i++); - if (!job.getRequestId().equals(requestId)) { - throw new IllegalStateException("Request response mismatched."); - } + output.setRequestId(job.getRequestId()); job.sendOutput(output); } jobs.clear(); @@ -111,4 +111,23 @@ public void sendError(HttpResponseStatus status, Throwable error) { * temporary batch aggregator. */ public abstract boolean isFinished(); + + protected void drainTo(List list, int maxDelay) throws InterruptedException { + long begin = System.currentTimeMillis(); + jobQueue.drainTo(list, batchSize - 1); + int remain = batchSize - list.size(); + for (int i = 0; i < remain; ++i) { + Job job = jobQueue.poll(maxDelay, TimeUnit.MILLISECONDS); + if (job == null) { + break; + } + long end = System.currentTimeMillis(); + maxDelay -= end - begin; + begin = end; + list.add(job); + if (maxDelay <= 0) { + break; + } + } + } } diff --git a/serving/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java b/serving/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java index 470ce0f1802..539d964d77e 100644 --- a/serving/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java +++ b/serving/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java @@ -106,6 +106,9 @@ public CompletableFuture registerModel( } else { logger.info("Loading model {} on {}.", modelName, Device.cpu()); } + if (batchSize > 1) { + builder.optArgument("batchifier", "stack"); + } ZooModel model = builder.build().loadModel(); ModelInfo modelInfo = diff --git a/serving/serving/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java b/serving/serving/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java index 0321cc8cf62..b7d5637fedc 100644 --- a/serving/serving/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java +++ b/serving/serving/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java @@ -44,7 +44,8 @@ protected List pollBatch() throws InterruptedException { List list = new ArrayList<>(batchSize); Job job = jobQueue.take(); list.add(job); - jobQueue.drainTo(list, batchSize - 1); + logger.trace("get first job: {}", job.getRequestId()); + drainTo(list, maxBatchDelay); logger.trace("sending jobs, size: {}", list.size()); return list; } diff --git a/serving/serving/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java b/serving/serving/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java index 6e990a72e70..e3dc134859f 100644 --- a/serving/serving/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java +++ b/serving/serving/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java @@ -50,7 +50,7 @@ protected List pollBatch() throws InterruptedException { Job job = jobQueue.poll(maxIdleTime, TimeUnit.SECONDS); if (job != null) { list.add(job); - jobQueue.drainTo(list, batchSize - 1); + drainTo(list, maxBatchDelay); logger.trace("sending jobs, size: {}", list.size()); idleSince = System.currentTimeMillis(); }