Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[djl-serving] Add dynamic batch feature back #1154

Merged
merged 1 commit into from
Aug 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions api/src/main/java/ai/djl/modality/Output.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,24 @@ public class Output {
private byte[] content;

/**
* Constructs a {@code Output} with specified {@code requestId}.
* Constructs a {@code Output} with specified {@code requestId}, {@code code} and {@code
* message}.
*
* @param requestId the requestId of the output
* @param code the status code of the output
* @param message the status message of the output
*/
public Output(String requestId) {
this.requestId = requestId;
public Output(int code, String message) {
this.code = code;
this.message = message;
}

/**
* Constructs a {@code Output} with specified {@code requestId}, {@code code} and {@code
* message}.
* Sets the requestId of the output.
*
* @param requestId the requestId of the output
* @param code the status code of the output
* @param message the status message of the output
*/
public Output(String requestId, int code, String message) {
public void setRequestId(String requestId) {
this.requestId = requestId;
this.code = code;
this.message = message;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,11 @@ private Translator<Input, Output> loadDefaultTranslator(Map<String, ?> arguments
return getSsdTranslator(arguments);
}
}
return new RawTranslator();
String batchifier = (String) arguments.get("batchifier");
if (batchifier == null) {
return new RawTranslator(null);
}
return new RawTranslator(Batchifier.fromString(batchifier));
}

private Translator<Input, Output> getImageClassificationTranslator(Map<String, ?> arguments) {
Expand Down Expand Up @@ -239,8 +243,7 @@ public Batchifier getBatchifier() {
/** {@inheritDoc} */
@Override
public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
Input input = (Input) ctx.getAttachment("input");
Output output = new Output(input.getRequestId(), 200, "OK");
Output output = new Output(200, "OK");
Object obj = translator.processOutput(ctx, list);
if (obj instanceof JsonSerializable) {
output.setContent(((JsonSerializable) obj).toJson() + '\n');
Expand Down Expand Up @@ -280,16 +283,21 @@ public void prepare(NDManager manager, Model model) throws IOException {

private static final class RawTranslator implements Translator<Input, Output> {

private Batchifier batchifier;

/**
 * Constructs a {@code RawTranslator} that reports the given batchifier.
 *
 * @param batchifier the batchifier to store for this translator; may be {@code null}
 *     (the default-translator loader passes {@code null} when no "batchifier" argument
 *     is supplied)
 */
RawTranslator(Batchifier batchifier) {
this.batchifier = batchifier;
}

/** {@inheritDoc} */
@Override
public Batchifier getBatchifier() {
return null;
return batchifier;
}

/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Input input) throws TranslateException {
ctx.setAttachment("input", input);
PairList<String, byte[]> inputs = input.getContent();
byte[] data = inputs.get("data");
if (data == null) {
Expand All @@ -309,8 +317,7 @@ public NDList processInput(TranslatorContext ctx, Input input) throws TranslateE
/** {@inheritDoc} */
@Override
public Output processOutput(TranslatorContext ctx, NDList list) {
Input input = (Input) ctx.getAttachment("input");
Output output = new Output(input.getRequestId(), 200, "OK");
Output output = new Output(200, "OK");
output.setContent(list.encode());
output.addProperty("Content-Type", "tensor/ndlist");
return output;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public void testSsdTranslator() throws IOException, ModelException, TranslateExc
Input input = new Input("1");
input.addData(buf);
Output output = predictor.predict(input);
Assert.assertEquals(output.getRequestId(), "1");
Assert.assertEquals(output.getCode(), 200);
String content = new String(output.getContent(), StandardCharsets.UTF_8);
Type type = new TypeToken<List<Classification>>() {}.getType();
List<Classification> result = JsonUtils.GSON.fromJson(content, type);
Expand All @@ -217,7 +217,7 @@ private void runImageClassification(Application application, Map<String, Object>
Input input = new Input("1");
input.addData("body", data);
Output output = predictor.predict(input);
Assert.assertEquals(output.getRequestId(), "1");
Assert.assertEquals(output.getCode(), 200);
String content = new String(output.getContent(), StandardCharsets.UTF_8);
Type type = new TypeToken<List<Classification>>() {}.getType();
List<Classification> result = JsonUtils.GSON.fromJson(content, type);
Expand Down Expand Up @@ -246,7 +246,7 @@ public void runRawTranslator() throws IOException, ModelException, TranslateExce
Input input = new Input("1");
input.addData(0, list.encode());
Output output = predictor.predict(input);
Assert.assertEquals(output.getRequestId(), "1");
Assert.assertEquals(output.getCode(), 200);

// manually post process
list = NDList.decode(manager, output.getContent());
Expand Down
4 changes: 1 addition & 3 deletions integration/src/test/translator/MyTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ public class MyTranslator implements ServingTranslator {

@Override
public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
ctx.setAttachment("input", input);
byte[] data = input.getContent().valueAt(0);
ImageFactory factory = ImageFactory.getInstance();
Image image = factory.fromInputStream(new ByteArrayInputStream(data));
Expand All @@ -47,8 +46,7 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
probabilitiesNd = probabilitiesNd.softmax(0);
Classifications classifications = new Classifications(classes, probabilitiesNd);

Input input = (Input) ctx.getAttachment("input");
Output output = new Output(input.getRequestId(), 200, "OK");
Output output = new Output(200, "OK");
output.setContent(classifications.toJson());
return output;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public class ManagementRequestHandler extends HttpRequestHandler {
/** HTTP Parameter "max_batch_delay". */
private static final String MAX_BATCH_DELAY_PARAMETER = "max_batch_delay";
/** HTTP Parameter "max_idle_time". */
private static final String MAX_IDLE_TIME__PARAMETER = "max_idle_time";
private static final String MAX_IDLE_TIME_PARAMETER = "max_idle_time";
/** HTTP Parameter "max_worker". */
private static final String MAX_WORKER_PARAMETER = "max_worker";
/** HTTP Parameter "min_worker". */
Expand Down Expand Up @@ -166,7 +166,7 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec
String engineName = NettyUtils.getParameter(decoder, ENGINE_NAME_PARAMETER, null);
int batchSize = NettyUtils.getIntParameter(decoder, BATCH_SIZE_PARAMETER, 1);
int maxBatchDelay = NettyUtils.getIntParameter(decoder, MAX_BATCH_DELAY_PARAMETER, 100);
int maxIdleTime = NettyUtils.getIntParameter(decoder, MAX_IDLE_TIME__PARAMETER, 60);
int maxIdleTime = NettyUtils.getIntParameter(decoder, MAX_IDLE_TIME_PARAMETER, 60);
int minWorkers = NettyUtils.getIntParameter(decoder, MIN_WORKER_PARAMETER, 1);
int defaultWorkers = ConfigManager.getInstance().getDefaultWorkers();
int maxWorkers = NettyUtils.getIntParameter(decoder, MAX_WORKER_PARAMETER, defaultWorkers);
Expand Down Expand Up @@ -239,7 +239,7 @@ private void handleScaleModel(

int maxIdleTime =
NettyUtils.getIntParameter(
decoder, MAX_IDLE_TIME__PARAMETER, modelInfo.getMaxIdleTime());
decoder, MAX_IDLE_TIME_PARAMETER, modelInfo.getMaxIdleTime());
int batchSize =
NettyUtils.getIntParameter(
decoder, BATCH_SIZE_PARAMETER, modelInfo.getBatchSize());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;

/**
* Abstract class for all BatchAggregators. A batch aggregator checks the working queue and combines
Expand All @@ -28,6 +29,7 @@
abstract class BatchAggregator {

protected int batchSize;
protected int maxBatchDelay;
protected List<Job> jobs;
protected LinkedBlockingDeque<Job> jobQueue;

Expand All @@ -39,6 +41,7 @@ abstract class BatchAggregator {
*/
public BatchAggregator(ModelInfo model, LinkedBlockingDeque<Job> jobQueue) {
this.batchSize = model.getBatchSize();
this.maxBatchDelay = model.getMaxBatchDelay();
this.jobQueue = jobQueue;
jobs = new ArrayList<>();
}
Expand Down Expand Up @@ -72,11 +75,8 @@ public void sendResponse(List<Output> outputs) {

int i = 0;
for (Output output : outputs) {
String requestId = output.getRequestId();
Job job = jobs.get(i++);
if (!job.getRequestId().equals(requestId)) {
throw new IllegalStateException("Request response mismatched.");
}
output.setRequestId(job.getRequestId());
job.sendOutput(output);
}
jobs.clear();
Expand Down Expand Up @@ -111,4 +111,23 @@ public void sendError(HttpResponseStatus status, Throwable error) {
* temporary batch aggregator.
*/
public abstract boolean isFinished();

/**
 * Fills {@code list} with jobs from the job queue until the batch is full or the delay
 * budget is used up.
 *
 * <p>Jobs that are already queued are moved over immediately. After that, the method keeps
 * polling the queue, blocking for at most the time that is still left of {@code maxDelay},
 * and stops as soon as the list holds {@code batchSize} jobs, a poll times out, or the
 * delay budget is exhausted.
 *
 * @param list the list to append jobs to; jobs already present count towards the batch size
 * @param maxDelay the maximum total time to wait for additional jobs, in milliseconds
 * @throws InterruptedException if the thread is interrupted while waiting on the queue
 */
protected void drainTo(List<Job> list, int maxDelay) throws InterruptedException {
    // Track the budget with the monotonic clock: wall-clock time can jump (NTP, manual
    // adjustment), which would stretch or collapse the batching delay.
    long remainingNanos = TimeUnit.MILLISECONDS.toNanos(maxDelay);
    long last = System.nanoTime();

    // Take whatever is queued right now, limited to the free capacity of the batch rather
    // than assuming the list contains exactly one job.
    jobQueue.drainTo(list, batchSize - list.size());

    while (list.size() < batchSize) {
        // A non-positive timeout makes poll() return immediately.
        Job job = jobQueue.poll(remainingNanos, TimeUnit.NANOSECONDS);
        if (job == null) {
            // Timed out: no further job arrived within the remaining budget.
            break;
        }
        list.add(job);

        long now = System.nanoTime();
        remainingNanos -= now - last;
        last = now;
        if (remainingNanos <= 0) {
            break;
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ public CompletableFuture<ModelInfo> registerModel(
} else {
logger.info("Loading model {} on {}.", modelName, Device.cpu());
}
if (batchSize > 1) {
builder.optArgument("batchifier", "stack");
}

ZooModel<Input, Output> model = builder.build().loadModel();
ModelInfo modelInfo =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ protected List<Job> pollBatch() throws InterruptedException {
List<Job> list = new ArrayList<>(batchSize);
Job job = jobQueue.take();
list.add(job);
jobQueue.drainTo(list, batchSize - 1);
logger.trace("get first job: {}", job.getRequestId());
drainTo(list, maxBatchDelay);
logger.trace("sending jobs, size: {}", list.size());
return list;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ protected List<Job> pollBatch() throws InterruptedException {
Job job = jobQueue.poll(maxIdleTime, TimeUnit.SECONDS);
if (job != null) {
list.add(job);
jobQueue.drainTo(list, batchSize - 1);
drainTo(list, maxBatchDelay);
logger.trace("sending jobs, size: {}", list.size());
idleSince = System.currentTimeMillis();
}
Expand Down