Skip to content

Commit

Permalink
[serving] make input header key case-insensitive (#1134)
Browse files
1. make it easy to get content-type header
2. improve error handling

Change-Id: I12a1ea9f12e3565649d0f218870fff6d111bc2fe
  • Loading branch information
frankfliu authored Aug 2, 2021
1 parent 91cce8e commit 14bf240
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 7 deletions.
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/modality/Input.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.modality;

import ai.djl.util.PairList;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

Expand Down Expand Up @@ -82,7 +83,7 @@ public void addProperty(String key, String value) {
* @return the value to which the specified key is mapped
*/
public String getProperty(String key, String defaultValue) {
return properties.getOrDefault(key, defaultValue);
return properties.getOrDefault(key.toLowerCase(Locale.ROOT), defaultValue);
}

/**
Expand Down
10 changes: 8 additions & 2 deletions api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ public Output processOutput(TranslatorContext ctx, NDList list) throws Exception
} else {
output.setContent(JsonUtils.GSON_PRETTY.toJson(obj) + '\n');
}
output.addProperty("Content-Type", "application/json");
return output;
}

Expand All @@ -262,8 +263,12 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
if (data == null) {
data = input.getContent().valueAt(0);
}
Image image = factory.fromInputStream(new ByteArrayInputStream(data));
return translator.processInput(ctx, image);
try {
Image image = factory.fromInputStream(new ByteArrayInputStream(data));
return translator.processInput(ctx, image);
} catch (IOException e) {
throw new TranslateException("Input is not an Image data type", e);
}
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -307,6 +312,7 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
Input input = (Input) ctx.getAttachment("input");
Output output = new Output(input.getRequestId(), 200, "OK");
output.setContent(list.encode());
output.addProperty("Content-Type", "tensor/ndlist");
return output;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -58,9 +59,13 @@ public Input parseRequest(
}
}

CharSequence contentType = HttpUtil.getMimeType(req);
for (Map.Entry<String, String> entry : req.headers().entries()) {
input.addProperty(entry.getKey(), entry.getValue());
input.addProperty(entry.getKey().toLowerCase(Locale.ROOT), entry.getValue());
}
CharSequence contentType = HttpUtil.getMimeType(req);
if (contentType != null) {
// use normalized content type
input.addProperty("content-type", contentType.toString());
}

if (HttpPostRequestDecoder.isMultipart(req)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
*/
package ai.djl.serving.wlm;

import ai.djl.engine.EngineException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
Expand Down Expand Up @@ -67,13 +68,17 @@ public void run() {
currentThread.set(thread);
this.state = WorkerState.WORKER_STARTED;
List<Input> req = null;
String errorMessage = "Worker shutting down";
try {
while (isRunning() && !aggregator.isFinished()) {
req = aggregator.getRequest();
if (req != null && !req.isEmpty()) {
try {
List<Output> reply = predictor.batchPredict(req);
aggregator.sendResponse(reply);
} catch (EngineException e) {
logger.warn("Failed to predict", e);
aggregator.sendError(HttpResponseStatus.INTERNAL_SERVER_ERROR, e);
} catch (TranslateException e) {
logger.warn("Failed to predict", e);
aggregator.sendError(HttpResponseStatus.BAD_REQUEST, e);
Expand All @@ -85,12 +90,13 @@ public void run() {
logger.debug("Shutting down the thread .. Scaling down.");
} catch (Throwable t) {
logger.error("Server error", t);
errorMessage = t.getMessage();
} finally {
logger.debug("Shutting down worker thread .. {}", currentThread.get().getName());
currentThread.set(null);
shutdown(WorkerState.WORKER_STOPPED);
if (req != null) {
Exception e = new InternalServerException("Server shutting down");
Exception e = new InternalServerException(errorMessage);
aggregator.sendError(HttpResponseStatus.INTERNAL_SERVER_ERROR, e);
}
}
Expand Down Expand Up @@ -122,7 +128,7 @@ public void shutdown(WorkerState state) {
Thread thread = currentThread.getAndSet(null);
if (thread != null) {
thread.interrupt();
Exception e = new InternalServerException("Server shutting down");
Exception e = new InternalServerException("Worker shutting down");
aggregator.sendError(HttpResponseStatus.INTERNAL_SERVER_ERROR, e);
}
predictor.close();
Expand Down

0 comments on commit 14bf240

Please sign in to comment.