Skip to content

Commit

Permalink
[djl-serving] Add dynamic batch feature back
Browse files Browse the repository at this point in the history
Change-Id: I5a70eb894e749f02377a534509881ce25211c2e8
  • Loading branch information
frankfliu committed Aug 9, 2021
1 parent 0df8265 commit f58f351
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 33 deletions.
20 changes: 9 additions & 11 deletions api/src/main/java/ai/djl/modality/Output.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,24 @@ public class Output {
private byte[] content;

/**
* Constructs a {@code Output} with specified {@code requestId}.
* Constructs a {@code Output} with specified {@code code} and {@code message}.
*
* @param requestId the requestId of the output
* @param code the status code of the output
* @param message the status message of the output
*/
public Output(String requestId) {
this.requestId = requestId;
public Output(int code, String message) {
this.code = code;
this.message = message;
}

/**
* Constructs a {@code Output} with specified {@code requestId}, {@code code} and {@code
* message}.
* Sets the requestId of the output.
*
* @param requestId the requestId of the output
* @param code the status code of the output
* @param message the status message of the output
*/
public Output(String requestId, int code, String message) {
public void setRequestId(String requestId) {
this.requestId = requestId;
this.code = code;
this.message = message;
}

/**
Expand Down
21 changes: 14 additions & 7 deletions api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,11 @@ private Translator<Input, Output> loadDefaultTranslator(Map<String, ?> arguments
return getSsdTranslator(arguments);
}
}
return new RawTranslator();
String batchifier = (String) arguments.get("batchifier");
if (batchifier == null) {
return new RawTranslator(null);
}
return new RawTranslator(Batchifier.fromString(batchifier));
}

private Translator<Input, Output> getImageClassificationTranslator(Map<String, ?> arguments) {
Expand Down Expand Up @@ -239,8 +243,7 @@ public Batchifier getBatchifier() {
/** {@inheritDoc} */
@Override
public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
Input input = (Input) ctx.getAttachment("input");
Output output = new Output(input.getRequestId(), 200, "OK");
Output output = new Output(200, "OK");
Object obj = translator.processOutput(ctx, list);
if (obj instanceof JsonSerializable) {
output.setContent(((JsonSerializable) obj).toJson() + '\n');
Expand Down Expand Up @@ -280,16 +283,21 @@ public void prepare(NDManager manager, Model model) throws IOException {

private static final class RawTranslator implements Translator<Input, Output> {

private Batchifier batchifier;

RawTranslator(Batchifier batchifier) {
this.batchifier = batchifier;
}

/** {@inheritDoc} */
@Override
public Batchifier getBatchifier() {
return null;
return batchifier;
}

/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Input input) throws TranslateException {
ctx.setAttachment("input", input);
PairList<String, byte[]> inputs = input.getContent();
byte[] data = inputs.get("data");
if (data == null) {
Expand All @@ -309,8 +317,7 @@ public NDList processInput(TranslatorContext ctx, Input input) throws TranslateE
/** {@inheritDoc} */
@Override
public Output processOutput(TranslatorContext ctx, NDList list) {
Input input = (Input) ctx.getAttachment("input");
Output output = new Output(input.getRequestId(), 200, "OK");
Output output = new Output(200, "OK");
output.setContent(list.encode());
output.addProperty("Content-Type", "tensor/ndlist");
return output;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public void testSsdTranslator() throws IOException, ModelException, TranslateExc
Input input = new Input("1");
input.addData(buf);
Output output = predictor.predict(input);
Assert.assertEquals(output.getRequestId(), "1");
Assert.assertEquals(output.getCode(), 200);
String content = new String(output.getContent(), StandardCharsets.UTF_8);
Type type = new TypeToken<List<Classification>>() {}.getType();
List<Classification> result = JsonUtils.GSON.fromJson(content, type);
Expand All @@ -217,7 +217,7 @@ private void runImageClassification(Application application, Map<String, Object>
Input input = new Input("1");
input.addData("body", data);
Output output = predictor.predict(input);
Assert.assertEquals(output.getRequestId(), "1");
Assert.assertEquals(output.getCode(), 200);
String content = new String(output.getContent(), StandardCharsets.UTF_8);
Type type = new TypeToken<List<Classification>>() {}.getType();
List<Classification> result = JsonUtils.GSON.fromJson(content, type);
Expand Down Expand Up @@ -246,7 +246,7 @@ public void runRawTranslator() throws IOException, ModelException, TranslateExce
Input input = new Input("1");
input.addData(0, list.encode());
Output output = predictor.predict(input);
Assert.assertEquals(output.getRequestId(), "1");
Assert.assertEquals(output.getCode(), 200);

// manually post process
list = NDList.decode(manager, output.getContent());
Expand Down
4 changes: 1 addition & 3 deletions integration/src/test/translator/MyTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ public class MyTranslator implements ServingTranslator {

@Override
public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
ctx.setAttachment("input", input);
byte[] data = input.getContent().valueAt(0);
ImageFactory factory = ImageFactory.getInstance();
Image image = factory.fromInputStream(new ByteArrayInputStream(data));
Expand All @@ -47,8 +46,7 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
probabilitiesNd = probabilitiesNd.softmax(0);
Classifications classifications = new Classifications(classes, probabilitiesNd);

Input input = (Input) ctx.getAttachment("input");
Output output = new Output(input.getRequestId(), 200, "OK");
Output output = new Output(200, "OK");
output.setContent(classifications.toJson());
return output;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public class ManagementRequestHandler extends HttpRequestHandler {
/** HTTP Parameter "max_batch_delay". */
private static final String MAX_BATCH_DELAY_PARAMETER = "max_batch_delay";
/** HTTP Parameter "max_idle_time". */
private static final String MAX_IDLE_TIME__PARAMETER = "max_idle_time";
private static final String MAX_IDLE_TIME_PARAMETER = "max_idle_time";
/** HTTP Parameter "max_worker". */
private static final String MAX_WORKER_PARAMETER = "max_worker";
/** HTTP Parameter "min_worker". */
Expand Down Expand Up @@ -166,7 +166,7 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec
String engineName = NettyUtils.getParameter(decoder, ENGINE_NAME_PARAMETER, null);
int batchSize = NettyUtils.getIntParameter(decoder, BATCH_SIZE_PARAMETER, 1);
int maxBatchDelay = NettyUtils.getIntParameter(decoder, MAX_BATCH_DELAY_PARAMETER, 100);
int maxIdleTime = NettyUtils.getIntParameter(decoder, MAX_IDLE_TIME__PARAMETER, 60);
int maxIdleTime = NettyUtils.getIntParameter(decoder, MAX_IDLE_TIME_PARAMETER, 60);
int minWorkers = NettyUtils.getIntParameter(decoder, MIN_WORKER_PARAMETER, 1);
int defaultWorkers = ConfigManager.getInstance().getDefaultWorkers();
int maxWorkers = NettyUtils.getIntParameter(decoder, MAX_WORKER_PARAMETER, defaultWorkers);
Expand Down Expand Up @@ -239,7 +239,7 @@ private void handleScaleModel(

int maxIdleTime =
NettyUtils.getIntParameter(
decoder, MAX_IDLE_TIME__PARAMETER, modelInfo.getMaxIdleTime());
decoder, MAX_IDLE_TIME_PARAMETER, modelInfo.getMaxIdleTime());
int batchSize =
NettyUtils.getIntParameter(
decoder, BATCH_SIZE_PARAMETER, modelInfo.getBatchSize());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;

/**
* Abstract class for all BatchAggregators. A batch aggregator checks the working queue and combines
Expand All @@ -28,6 +29,7 @@
abstract class BatchAggregator {

protected int batchSize;
protected int maxBatchDelay;
protected List<Job> jobs;
protected LinkedBlockingDeque<Job> jobQueue;

Expand All @@ -39,6 +41,7 @@ abstract class BatchAggregator {
*/
public BatchAggregator(ModelInfo model, LinkedBlockingDeque<Job> jobQueue) {
this.batchSize = model.getBatchSize();
this.maxBatchDelay = model.getMaxBatchDelay();
this.jobQueue = jobQueue;
jobs = new ArrayList<>();
}
Expand Down Expand Up @@ -72,11 +75,8 @@ public void sendResponse(List<Output> outputs) {

int i = 0;
for (Output output : outputs) {
String requestId = output.getRequestId();
Job job = jobs.get(i++);
if (!job.getRequestId().equals(requestId)) {
throw new IllegalStateException("Request response mismatched.");
}
output.setRequestId(job.getRequestId());
job.sendOutput(output);
}
jobs.clear();
Expand Down Expand Up @@ -111,4 +111,23 @@ public void sendError(HttpResponseStatus status, Throwable error) {
* temporary batch aggregator.
*/
public abstract boolean isFinished();

/**
 * Tops up {@code list} with jobs from the queue, waiting at most {@code maxDelay} milliseconds
 * in total for jobs that have not arrived yet.
 *
 * <p>Up to {@code batchSize - 1} jobs that are already queued are taken without blocking. The
 * method then polls the queue until {@code list} holds {@code batchSize} jobs, a poll times
 * out, or the delay budget is used up.
 *
 * @param list the batch under construction; polled jobs are appended to it
 * @param maxDelay the total time to wait for additional jobs, in milliseconds
 * @throws InterruptedException if the thread is interrupted while waiting on the queue
 */
protected void drainTo(List<Job> list, int maxDelay) throws InterruptedException {
    // Track the budget with System.nanoTime(): it is monotonic, so a wall-clock adjustment
    // (NTP step, manual change) cannot stretch or cut the wait the way currentTimeMillis() can.
    long waitNanos = TimeUnit.MILLISECONDS.toNanos(maxDelay);
    long deadline = System.nanoTime() + waitNanos;

    jobQueue.drainTo(list, batchSize - 1);
    int remain = batchSize - list.size();
    for (int i = 0; i < remain; ++i) {
        Job job = jobQueue.poll(waitNanos, TimeUnit.NANOSECONDS);
        if (job == null) {
            // Timed out: send whatever has been collected so far.
            break;
        }
        list.add(job);
        waitNanos = deadline - System.nanoTime();
        if (waitNanos <= 0) {
            break;
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ public CompletableFuture<ModelInfo> registerModel(
} else {
logger.info("Loading model {} on {}.", modelName, Device.cpu());
}
if (batchSize > 1) {
builder.optArgument("batchifier", "stack");
}

ZooModel<Input, Output> model = builder.build().loadModel();
ModelInfo modelInfo =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ protected List<Job> pollBatch() throws InterruptedException {
List<Job> list = new ArrayList<>(batchSize);
Job job = jobQueue.take();
list.add(job);
jobQueue.drainTo(list, batchSize - 1);
logger.trace("get first job: {}", job.getRequestId());
drainTo(list, maxBatchDelay);
logger.trace("sending jobs, size: {}", list.size());
return list;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ protected List<Job> pollBatch() throws InterruptedException {
Job job = jobQueue.poll(maxIdleTime, TimeUnit.SECONDS);
if (job != null) {
list.add(job);
jobQueue.drainTo(list, batchSize - 1);
drainTo(list, maxBatchDelay);
logger.trace("sending jobs, size: {}", list.size());
idleSince = System.currentTimeMillis();
}
Expand Down

0 comments on commit f58f351

Please sign in to comment.