diff --git a/api/src/main/java/ai/djl/modality/Output.java b/api/src/main/java/ai/djl/modality/Output.java
index 965f12c07c0..be8f942c792 100644
--- a/api/src/main/java/ai/djl/modality/Output.java
+++ b/api/src/main/java/ai/djl/modality/Output.java
@@ -27,26 +27,24 @@ public class Output {
private byte[] content;
/**
- * Constructs a {@code Output} with specified {@code requestId}.
+ * Constructs a {@code Output} with specified {@code requestId}, {@code code} and {@code
+ * message}.
*
- * @param requestId the requestId of the output
+ * @param code the status code of the output
+ * @param message the status message of the output
*/
- public Output(String requestId) {
- this.requestId = requestId;
+ public Output(int code, String message) {
+ this.code = code;
+ this.message = message;
}
/**
- * Constructs a {@code Output} with specified {@code requestId}, {@code code} and {@code
- * message}.
+ * Sets the requestId of the output.
*
* @param requestId the requestId of the output
- * @param code the status code of the output
- * @param message the status message of the output
*/
- public Output(String requestId, int code, String message) {
+ public void setRequestId(String requestId) {
this.requestId = requestId;
- this.code = code;
- this.message = message;
}
/**
diff --git a/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java b/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java
index 2d5b707b3a0..c40e2b9acea 100644
--- a/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java
+++ b/api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java
@@ -189,7 +189,11 @@ private Translator loadDefaultTranslator(Map arguments
return getSsdTranslator(arguments);
}
}
- return new RawTranslator();
+ String batchifier = (String) arguments.get("batchifier");
+ if (batchifier == null) {
+ return new RawTranslator(null);
+ }
+ return new RawTranslator(Batchifier.fromString(batchifier));
}
private Translator getImageClassificationTranslator(Map arguments) {
@@ -239,8 +243,7 @@ public Batchifier getBatchifier() {
/** {@inheritDoc} */
@Override
public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
- Input input = (Input) ctx.getAttachment("input");
- Output output = new Output(input.getRequestId(), 200, "OK");
+ Output output = new Output(200, "OK");
Object obj = translator.processOutput(ctx, list);
if (obj instanceof JsonSerializable) {
output.setContent(((JsonSerializable) obj).toJson() + '\n');
@@ -280,16 +283,21 @@ public void prepare(NDManager manager, Model model) throws IOException {
private static final class RawTranslator implements Translator {
+ private Batchifier batchifier;
+
+ RawTranslator(Batchifier batchifier) {
+ this.batchifier = batchifier;
+ }
+
/** {@inheritDoc} */
@Override
public Batchifier getBatchifier() {
- return null;
+ return batchifier;
}
/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Input input) throws TranslateException {
- ctx.setAttachment("input", input);
PairList inputs = input.getContent();
byte[] data = inputs.get("data");
if (data == null) {
@@ -309,8 +317,7 @@ public NDList processInput(TranslatorContext ctx, Input input) throws TranslateE
/** {@inheritDoc} */
@Override
public Output processOutput(TranslatorContext ctx, NDList list) {
- Input input = (Input) ctx.getAttachment("input");
- Output output = new Output(input.getRequestId(), 200, "OK");
+ Output output = new Output(200, "OK");
output.setContent(list.encode());
output.addProperty("Content-Type", "tensor/ndlist");
return output;
diff --git a/integration/src/main/java/ai/djl/integration/tests/model_zoo/CustomTranslatorTest.java b/integration/src/main/java/ai/djl/integration/tests/model_zoo/CustomTranslatorTest.java
index 26136a7b887..464a0f80982 100644
--- a/integration/src/main/java/ai/djl/integration/tests/model_zoo/CustomTranslatorTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/model_zoo/CustomTranslatorTest.java
@@ -194,7 +194,7 @@ public void testSsdTranslator() throws IOException, ModelException, TranslateExc
Input input = new Input("1");
input.addData(buf);
Output output = predictor.predict(input);
- Assert.assertEquals(output.getRequestId(), "1");
+ Assert.assertEquals(output.getCode(), 200);
String content = new String(output.getContent(), StandardCharsets.UTF_8);
Type type = new TypeToken>() {}.getType();
List result = JsonUtils.GSON.fromJson(content, type);
@@ -217,7 +217,7 @@ private void runImageClassification(Application application, Map
Input input = new Input("1");
input.addData("body", data);
Output output = predictor.predict(input);
- Assert.assertEquals(output.getRequestId(), "1");
+ Assert.assertEquals(output.getCode(), 200);
String content = new String(output.getContent(), StandardCharsets.UTF_8);
Type type = new TypeToken>() {}.getType();
List result = JsonUtils.GSON.fromJson(content, type);
@@ -246,7 +246,7 @@ public void runRawTranslator() throws IOException, ModelException, TranslateExce
Input input = new Input("1");
input.addData(0, list.encode());
Output output = predictor.predict(input);
- Assert.assertEquals(output.getRequestId(), "1");
+ Assert.assertEquals(output.getCode(), 200);
// manually post process
list = NDList.decode(manager, output.getContent());
diff --git a/integration/src/test/translator/MyTranslator.java b/integration/src/test/translator/MyTranslator.java
index 34b745bf5e5..abae0d8cb85 100644
--- a/integration/src/test/translator/MyTranslator.java
+++ b/integration/src/test/translator/MyTranslator.java
@@ -28,7 +28,6 @@ public class MyTranslator implements ServingTranslator {
@Override
public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
- ctx.setAttachment("input", input);
byte[] data = input.getContent().valueAt(0);
ImageFactory factory = ImageFactory.getInstance();
Image image = factory.fromInputStream(new ByteArrayInputStream(data));
@@ -47,8 +46,7 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
probabilitiesNd = probabilitiesNd.softmax(0);
Classifications classifications = new Classifications(classes, probabilitiesNd);
- Input input = (Input) ctx.getAttachment("input");
- Output output = new Output(input.getRequestId(), 200, "OK");
+ Output output = new Output(200, "OK");
output.setContent(classifications.toJson());
return output;
}
diff --git a/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java b/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java
index 088598d04e3..79d24fff0f3 100644
--- a/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java
+++ b/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java
@@ -55,7 +55,7 @@ public class ManagementRequestHandler extends HttpRequestHandler {
/** HTTP Parameter "max_batch_delay". */
private static final String MAX_BATCH_DELAY_PARAMETER = "max_batch_delay";
/** HTTP Parameter "max_idle_time". */
- private static final String MAX_IDLE_TIME__PARAMETER = "max_idle_time";
+ private static final String MAX_IDLE_TIME_PARAMETER = "max_idle_time";
/** HTTP Parameter "max_worker". */
private static final String MAX_WORKER_PARAMETER = "max_worker";
/** HTTP Parameter "min_worker". */
@@ -166,7 +166,7 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec
String engineName = NettyUtils.getParameter(decoder, ENGINE_NAME_PARAMETER, null);
int batchSize = NettyUtils.getIntParameter(decoder, BATCH_SIZE_PARAMETER, 1);
int maxBatchDelay = NettyUtils.getIntParameter(decoder, MAX_BATCH_DELAY_PARAMETER, 100);
- int maxIdleTime = NettyUtils.getIntParameter(decoder, MAX_IDLE_TIME__PARAMETER, 60);
+ int maxIdleTime = NettyUtils.getIntParameter(decoder, MAX_IDLE_TIME_PARAMETER, 60);
int minWorkers = NettyUtils.getIntParameter(decoder, MIN_WORKER_PARAMETER, 1);
int defaultWorkers = ConfigManager.getInstance().getDefaultWorkers();
int maxWorkers = NettyUtils.getIntParameter(decoder, MAX_WORKER_PARAMETER, defaultWorkers);
@@ -239,7 +239,7 @@ private void handleScaleModel(
int maxIdleTime =
NettyUtils.getIntParameter(
- decoder, MAX_IDLE_TIME__PARAMETER, modelInfo.getMaxIdleTime());
+ decoder, MAX_IDLE_TIME_PARAMETER, modelInfo.getMaxIdleTime());
int batchSize =
NettyUtils.getIntParameter(
decoder, BATCH_SIZE_PARAMETER, modelInfo.getBatchSize());
diff --git a/serving/serving/src/main/java/ai/djl/serving/wlm/BatchAggregator.java b/serving/serving/src/main/java/ai/djl/serving/wlm/BatchAggregator.java
index daa9f11a45e..25ed8515d35 100644
--- a/serving/serving/src/main/java/ai/djl/serving/wlm/BatchAggregator.java
+++ b/serving/serving/src/main/java/ai/djl/serving/wlm/BatchAggregator.java
@@ -18,6 +18,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.TimeUnit;
/**
* abstract class for all BatchAggregators. A batch aggregator check working queue and combines
@@ -28,6 +29,7 @@
abstract class BatchAggregator {
protected int batchSize;
+ protected int maxBatchDelay;
protected List jobs;
protected LinkedBlockingDeque jobQueue;
@@ -39,6 +41,7 @@ abstract class BatchAggregator {
*/
public BatchAggregator(ModelInfo model, LinkedBlockingDeque jobQueue) {
this.batchSize = model.getBatchSize();
+ this.maxBatchDelay = model.getMaxBatchDelay();
this.jobQueue = jobQueue;
jobs = new ArrayList<>();
}
@@ -72,11 +75,8 @@ public void sendResponse(List