From 5c27dbac4468e3e171d88860fa2b534f905c8e45 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Thu, 10 Mar 2022 17:37:20 -0800 Subject: [PATCH] [serving] Refactor ModelDefinition class (#67) * [serving] Refactor ModelDefinition class * [serving] Add async model loading on startup * [serving] Start HTTP listener while model is loading --- .../main/java/ai/djl/serving/ModelServer.java | 59 +++-- .../serving/http/InferenceRequestHandler.java | 39 ++-- .../djl/serving/http/ListModelsResponse.java | 19 +- .../http/ManagementRequestHandler.java | 73 +++--- .../ai/djl/serving/models/ModelManager.java | 217 +++++++----------- .../ai/djl/serving/util/ConfigManager.java | 20 ++ .../ai/djl/serving/workflow/Workflow.java | 62 ++--- .../serving/workflow/WorkflowDefinition.java | 209 +++-------------- .../function/ModelWorkflowFunction.java | 2 +- .../java/ai/djl/serving/ModelServerTest.java | 6 +- .../java/ai/djl/serving/WorkflowTest.java | 29 +-- .../test/resources/workflows/criteria.json | 5 +- .../java/ai/djl/serving/wlm/ModelInfo.java | 217 ++++++++++++++++-- .../ai/djl/serving/wlm/WorkLoadManager.java | 19 +- .../java/ai/djl/serving/wlm/WorkerThread.java | 8 +- .../wlm/util/WlmCapacityException.java | 2 +- .../wlm/util/WlmShutdownException.java | 2 +- .../ai/djl/serving/wlm/ModelInfoTest.java | 5 +- 18 files changed, 534 insertions(+), 459 deletions(-) diff --git a/serving/src/main/java/ai/djl/serving/ModelServer.java b/serving/src/main/java/ai/djl/serving/ModelServer.java index 51c38e2f870..5a2f556a411 100644 --- a/serving/src/main/java/ai/djl/serving/ModelServer.java +++ b/serving/src/main/java/ai/djl/serving/ModelServer.java @@ -176,6 +176,11 @@ public List start() futures.add(initializeServer(managementConnector, serverGroup, workerGroup)); } + if (stopped.get()) { + // check if model load failed in wait loading model case + stop(); + } + return futures; } @@ -190,10 +195,7 @@ public boolean isRunning() { /** Stops the model server. 
*/ public void stop() { - if (stopped.get()) { - return; - } - + logger.info("Stopping model server."); stopped.set(true); for (ChannelFuture future : futures) { future.channel().close(); @@ -261,7 +263,6 @@ private ChannelFuture initializeServer( } private void initModelStore() throws IOException { - ModelManager.init(configManager); Set startupModels = ModelManager.getInstance().getStartupModels(); String loadModels = configManager.getLoadModels(); @@ -311,10 +312,10 @@ private void initModelStore() throws IOException { String version = null; String engine = null; String[] devices = {"-1"}; - String modelName; + String workflowName; if (endpoint != null) { String[] tokens = endpoint.split(":", -1); - modelName = tokens[0]; + workflowName = tokens[0]; if (tokens.length > 1) { version = tokens[1].isEmpty() ? null : tokens[1]; } @@ -336,13 +337,12 @@ private void initModelStore() throws IOException { .mapToObj(i -> "nc" + i) .toArray(String[]::new); } - } else if (!tokens[3].isEmpty()) { devices = tokens[3].split(";"); } } } else { - modelName = ModelInfo.inferModelNameFromUrl(modelUrl); + workflowName = ModelInfo.inferModelNameFromUrl(modelUrl); } if (engine == null) { engine = inferEngineFromUrl(modelUrl); @@ -350,6 +350,7 @@ private void initModelStore() throws IOException { for (int i = 0; i < devices.length; ++i) { String modelVersion; + String device = devices[i]; if (devices.length > 1) { if (version == null) { modelVersion = "v" + i; @@ -359,20 +360,40 @@ private void initModelStore() throws IOException { } else { modelVersion = version; } - CompletableFuture future = - modelManager.registerWorkflow( - modelName, - modelVersion, + ModelInfo modelInfo = + new ModelInfo( + workflowName, modelUrl, + modelVersion, engine, - devices[i], - configManager.getBatchSize(), + configManager.getJobQueueSize(), + configManager.getMaxIdleTime(), configManager.getMaxBatchDelay(), - configManager.getMaxIdleTime()); - Workflow workflow = future.join(); - 
modelManager.scaleWorkers(workflow, devices[i], 1, -1); + configManager.getBatchSize()); + Workflow workflow = new Workflow(modelInfo); + + CompletableFuture f = + modelManager + .registerWorkflow(workflow, device) + .thenAccept(v -> modelManager.scaleWorkers(workflow, device, 1, -1)) + .exceptionally( + t -> { + logger.error("Failed register workflow", t); + // delay 3 seconds, allows REST API to send PING + // response (health check) + try { + Thread.sleep(3000); + } catch (InterruptedException ignore) { + // ignore + } + stop(); + return null; + }); + if (configManager.waitModelLoading()) { + f.join(); + } } - startupModels.add(modelName); + startupModels.add(workflowName); } } diff --git a/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java index 56ac7642f10..c2636182c1d 100644 --- a/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java +++ b/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java @@ -20,8 +20,8 @@ import ai.djl.serving.models.ModelManager; import ai.djl.serving.util.ConfigManager; import ai.djl.serving.util.NettyUtils; -import ai.djl.serving.wlm.util.WlmCapacityException; -import ai.djl.serving.wlm.util.WlmShutdownException; +import ai.djl.serving.wlm.ModelInfo; +import ai.djl.serving.wlm.util.WlmException; import ai.djl.serving.workflow.Workflow; import ai.djl.translate.TranslateException; import io.netty.channel.ChannelHandlerContext; @@ -73,16 +73,9 @@ protected void handleRequest( throws ModelException { switch (segments[1]) { case "ping": - // TODO: Check if its OK to send other 2xx errors to ALB for "Partial Healthy" - // and "Unhealthy" ModelManager.getInstance() .workerStatus() - .thenAccept( - response -> - NettyUtils.sendJsonResponse( - ctx, - new StatusResponse(response), - HttpResponseStatus.OK)); + .thenAccept(r -> NettyUtils.sendHttpResponse(ctx, r, true)); break; case "invocations": handleInvocations(ctx, 
req, decoder); @@ -171,18 +164,21 @@ private void predict( String deviceName = input.getProperty("device", "-1"); logger.info("Loading model {} from: {}", workflowName, modelUrl); - - modelManager - .registerWorkflow( + ModelInfo modelInfo = + new ModelInfo( workflowName, - version, modelUrl, + version, engineName, - deviceName, - config.getBatchSize(), + config.getJobQueueSize(), + config.getMaxIdleTime(), config.getMaxBatchDelay(), - config.getMaxIdleTime()) - .thenApply(p -> modelManager.scaleWorkers(p, deviceName, 1, -1)) + config.getBatchSize()); + Workflow wf = new Workflow(modelInfo); + + modelManager + .registerWorkflow(wf, deviceName) + .thenApply(p -> modelManager.scaleWorkers(wf, deviceName, 1, -1)) .thenAccept(p -> runJob(modelManager, ctx, p, input)); return; } @@ -243,11 +239,8 @@ void onException(Throwable t, ChannelHandlerContext ctx) { HttpResponseStatus status; if (t instanceof TranslateException) { status = HttpResponseStatus.BAD_REQUEST; - } else if (t instanceof WlmShutdownException) { - logger.info(t.getMessage()); - status = HttpResponseStatus.SERVICE_UNAVAILABLE; - } else if (t instanceof WlmCapacityException) { - logger.warn(t.getMessage()); + } else if (t instanceof WlmException) { + logger.warn(t.getMessage(), t); status = HttpResponseStatus.SERVICE_UNAVAILABLE; } else { logger.warn("Unexpected error", t); diff --git a/serving/src/main/java/ai/djl/serving/http/ListModelsResponse.java b/serving/src/main/java/ai/djl/serving/http/ListModelsResponse.java index 6a6734224d8..b8c0eed15aa 100644 --- a/serving/src/main/java/ai/djl/serving/http/ListModelsResponse.java +++ b/serving/src/main/java/ai/djl/serving/http/ListModelsResponse.java @@ -59,9 +59,10 @@ public List getModels() { * @param modelName the model name * @param version the mode version * @param modelUrl the model url + * @param status the model loading status */ - public void addModel(String modelName, String version, String modelUrl) { - models.add(new ModelItem(modelName, version, 
modelUrl)); + public void addModel(String modelName, String version, String modelUrl, String status) { + models.add(new ModelItem(modelName, version, modelUrl, status)); } /** A class that holds model name and url. */ @@ -70,6 +71,7 @@ public static final class ModelItem { private String modelName; private String version; private String modelUrl; + private String status; /** Constructs a new {@code ModelItem} instance. */ public ModelItem() {} @@ -80,11 +82,13 @@ public ModelItem() {} * @param modelName the model name * @param version the model version * @param modelUrl the model url + * @param status the model loading status */ - public ModelItem(String modelName, String version, String modelUrl) { + public ModelItem(String modelName, String version, String modelUrl, String status) { this.modelName = modelName; this.version = version; this.modelUrl = modelUrl; + this.status = status; } /** @@ -113,5 +117,14 @@ public String getVersion() { public String getModelUrl() { return modelUrl; } + + /** + * Returns the model loading status. 
+ * + * @return the model loading status + */ + public String getStatus() { + return status; + } } } diff --git a/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java index 41ebe53744c..09bba0c5d42 100644 --- a/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java +++ b/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java @@ -16,6 +16,7 @@ import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.serving.models.Endpoint; import ai.djl.serving.models.ModelManager; +import ai.djl.serving.util.ConfigManager; import ai.djl.serving.util.NettyUtils; import ai.djl.serving.wlm.ModelInfo; import ai.djl.serving.wlm.WorkLoadManager.WorkerPool; @@ -137,9 +138,19 @@ private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder deco } for (int i = pageToken; i < last; ++i) { - String modelName = keys.get(i); - for (Workflow m : endpoints.get(modelName).getWorkflows()) { - list.addModel(modelName, m.getVersion(), m.getUrl()); + String workflowName = keys.get(i); + for (Workflow workflow : endpoints.get(workflowName).getWorkflows()) { + for (ModelInfo m : workflow.getModels()) { + String status = m.getStatus().toString(); + String id = m.getModelId(); + String modelName; + if (workflowName.equals(id)) { + modelName = workflowName; + } else { + modelName = workflowName + ':' + id; + } + list.addModel(modelName, workflow.getVersion(), m.getModelUrl(), status); + } } } @@ -185,40 +196,42 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec Boolean.parseBoolean( NettyUtils.getParameter(decoder, SYNCHRONOUS_PARAMETER, "true")); - final ModelManager modelManager = ModelManager.getInstance(); - CompletableFuture future = - modelManager.registerWorkflow( + ModelInfo modelInfo = + new ModelInfo( modelName, - version, modelUrl, + version, engineName, - deviceName, - batchSize, + 
ConfigManager.getInstance().getJobQueueSize(), + maxIdleTime, maxBatchDelay, - maxIdleTime); + batchSize); + Workflow workflow = new Workflow(modelInfo); + final ModelManager modelManager = ModelManager.getInstance(); CompletableFuture f = - future.thenAccept( - p -> { - for (ModelInfo m : p.getModels()) { - m.configurePool(maxIdleTime) - .configureModelBatch(batchSize, maxBatchDelay); - modelManager.scaleWorkers(m, deviceName, minWorkers, maxWorkers); - } - }); - + modelManager + .registerWorkflow(workflow, deviceName) + .thenAccept( + v -> { + for (ModelInfo m : workflow.getModels()) { + m.configurePool(maxIdleTime) + .configureModelBatch(batchSize, maxBatchDelay); + modelManager.scaleWorkers( + m, deviceName, minWorkers, maxWorkers); + } + }) + .exceptionally( + t -> { + NettyUtils.sendError(ctx, t.getCause()); + return null; + }); if (synchronous) { final String msg = "Model \"" + modelName + "\" registered."; - f = f.thenAccept(m -> NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg))); + f.thenAccept(v -> NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg))); } else { String msg = "Model \"" + modelName + "\" registration scheduled."; NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg), HttpResponseStatus.ACCEPTED); } - - f.exceptionally( - t -> { - NettyUtils.sendError(ctx, t.getCause()); - return null; - }); } private void handleUnregisterModel(ChannelHandlerContext ctx, String modelName, String version) @@ -240,6 +253,14 @@ private void handleScaleModel( if (workflow == null) { throw new ModelNotFoundException("Model not found: " + modelName); } + + // make sure all models are loaded and ready + for (ModelInfo modelInfo : workflow.getModels()) { + if (modelInfo.getStatus() != ModelInfo.Status.READY) { + throw new ServiceUnavailableException("Model is not ready: " + modelName); + } + } + List msgs = new ArrayList<>(); for (ModelInfo modelInfo : workflow.getModels()) { WorkerPool pool = diff --git 
a/serving/src/main/java/ai/djl/serving/models/ModelManager.java b/serving/src/main/java/ai/djl/serving/models/ModelManager.java index 35f54fbcee3..b7d73efdbe8 100644 --- a/serving/src/main/java/ai/djl/serving/models/ModelManager.java +++ b/serving/src/main/java/ai/djl/serving/models/ModelManager.java @@ -12,31 +12,34 @@ */ package ai.djl.serving.models; -import ai.djl.Device; -import ai.djl.ModelException; -import ai.djl.engine.Engine; import ai.djl.modality.Input; import ai.djl.modality.Output; -import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; -import ai.djl.repository.zoo.ZooModel; import ai.djl.serving.http.BadRequestException; import ai.djl.serving.http.DescribeModelResponse; -import ai.djl.serving.plugins.DependencyManager; +import ai.djl.serving.http.StatusResponse; import ai.djl.serving.util.ConfigManager; import ai.djl.serving.wlm.ModelInfo; import ai.djl.serving.wlm.WorkLoadManager; import ai.djl.serving.wlm.WorkLoadManager.WorkerPool; import ai.djl.serving.wlm.WorkerThread; import ai.djl.serving.workflow.Workflow; -import java.io.IOException; +import ai.djl.util.JsonUtils; +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.CharsetUtil; import java.util.ArrayList; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import org.slf4j.Logger; @@ -47,29 +50,18 @@ public final class ModelManager { private static final Logger logger = 
LoggerFactory.getLogger(ModelManager.class); - private static ModelManager modelManager; + private static ModelManager modelManager = new ModelManager(); - private ConfigManager configManager; private WorkLoadManager wlm; private Map endpoints; private Set startupModels; - private ModelManager(ConfigManager configManager) { - this.configManager = configManager; + private ModelManager() { wlm = new WorkLoadManager(); endpoints = new ConcurrentHashMap<>(); startupModels = new HashSet<>(); } - /** - * Initialized the global {@code ModelManager} instance. - * - * @param configManager the configuration - */ - public static void init(ConfigManager configManager) { - modelManager = new ModelManager(configManager); - } - /** * Returns the singleton {@code ModelManager} instance. * @@ -80,103 +72,19 @@ public static ModelManager getInstance() { } /** - * Registers and loads a model. - * - * @param modelName the name of the model for HTTP endpoint - * @param version the model version - * @param modelUrl the model url - * @param engineName the engine to load the model - * @param deviceName the accelerator device id, -1 for auto selection - * @param batchSize the batch size - * @param maxBatchDelay the maximum delay for batching - * @param maxIdleTime the maximum idle time of the worker threads before scaling down. 
- * @return a {@code CompletableFuture} instance - */ - public CompletableFuture registerWorkflow( - final String modelName, - final String version, - final String modelUrl, - final String engineName, - final String deviceName, - final int batchSize, - final int maxBatchDelay, - final int maxIdleTime) { - return CompletableFuture.supplyAsync( - () -> { - try { - if (engineName != null) { - DependencyManager dm = DependencyManager.getInstance(); - dm.installEngine(engineName); - } - Criteria.Builder builder = - Criteria.builder() - .setTypes(Input.class, Output.class) - .optModelUrls(modelUrl) - .optEngine(engineName); - if ("-1".equals(deviceName)) { - Device device; - if (engineName == null) { - device = Device.cpu(); - } else { - device = Engine.getEngine(engineName).defaultDevice(); - } - logger.info("Loading model {} on {}.", modelName, device); - } else if (deviceName.startsWith("nc")) { - logger.info("Loading model {} on {}.", modelName, deviceName); - String ncs = deviceName.substring(2); - builder.optOption("env", "NEURON_RT_VISIBLE_CORES=" + ncs); - } else { - // GPU case - int gpuId = Integer.parseInt(deviceName); - builder.optDevice(Device.gpu(gpuId)); - logger.info( - "Loading model {} on {}.", - modelName, - Device.gpu(gpuId)); - } - if (batchSize > 1) { - builder.optArgument("batchifier", "stack"); - } - - ZooModel model = builder.build().loadModel(); - return new Workflow( - modelName, - version, - modelUrl, - new ModelInfo( - modelName, - version, - model, - configManager.getJobQueueSize(), - maxIdleTime, - maxBatchDelay, - batchSize)); - } catch (ModelException | IOException e) { - throw new CompletionException(e); - } - }) - .thenApply(p -> registerWorkflow(p).join()); - } - - /** - * Registers and loads a workflow. + * Registers and loads a {@link Workflow}. 
* * @param workflow the workflow to register + * @param deviceName the accelerator device id, -1 for auto selection * @return a {@code CompletableFuture} instance */ - public CompletableFuture registerWorkflow(final Workflow workflow) { - return CompletableFuture.supplyAsync( - () -> { - Endpoint endpoint = - endpoints.computeIfAbsent(workflow.getName(), k -> new Endpoint()); - if (!endpoint.add(workflow)) { - // workflow already exists - throw new BadRequestException( - "Workflow " + workflow + " is already registered."); - } - - return workflow; - }); + public CompletableFuture registerWorkflow(Workflow workflow, String deviceName) { + Endpoint endpoint = endpoints.computeIfAbsent(workflow.getName(), k -> new Endpoint()); + if (!endpoint.add(workflow)) { + // workflow already exists + throw new BadRequestException("Workflow " + workflow + " is already registered."); + } + return workflow.load(deviceName); } /** @@ -256,8 +164,7 @@ public Workflow scaleWorkers( */ public ModelInfo scaleWorkers( ModelInfo model, String deviceName, int minWorkers, int maxWorkers) { - String modelName = model.getModelName(); - logger.debug("updateModel: {}", modelName); + logger.debug("updateModel: {}", model); wlm.getWorkerPoolForModel(model).scaleWorkers(deviceName, minWorkers, maxWorkers); return model; } @@ -362,21 +269,26 @@ public List describeWorkflow(String workflowName, String for (Workflow workflow : list) { for (ModelInfo model : workflow.getModels()) { DescribeModelResponse resp = new DescribeModelResponse(); - resp.setModelName(model.getModelName()); - resp.setModelUrl(list.get(0).getUrl()); + resp.setModelName(model.getModelId()); + resp.setModelUrl(model.getModelUrl()); resp.setBatchSize(model.getBatchSize()); resp.setMaxBatchDelay(model.getMaxBatchDelay()); resp.setMaxIdleTime(model.getMaxIdleTime()); resp.setQueueLength(wlm.getQueueLength(model)); - resp.setLoadedAtStartup(startupModels.contains(model.getModelName())); + 
resp.setLoadedAtStartup(startupModels.contains(model.getModelId())); WorkerPool wp = wlm.getWorkerPoolForModel(model); resp.setMaxWorkers(wp.getMaxWorkers()); resp.setMinWorkers(wp.getMinWorkers()); - int activeWorker = wlm.getNumRunningWorkers(model); - int targetWorker = wp.getMinWorkers(); - resp.setStatus(activeWorker >= targetWorker ? "Healthy" : "Unhealthy"); + ModelInfo.Status status = model.getStatus(); + if (status == ModelInfo.Status.READY) { + int activeWorker = wlm.getNumRunningWorkers(model); + int targetWorker = wp.getMinWorkers(); + resp.setStatus(activeWorker >= targetWorker ? "Healthy" : "Unhealthy"); + } else { + resp.setStatus(status.name()); + } List workers = wlm.getWorkers(model); for (WorkerThread worker : workers) { @@ -398,29 +310,66 @@ public List describeWorkflow(String workflowName, String * * @return completableFuture with eventually result in the future after async execution */ - public CompletableFuture workerStatus() { + public CompletableFuture workerStatus() { return CompletableFuture.supplyAsync( () -> { - String response = "Healthy"; - int numWorking = 0; - - int numScaled = 0; + boolean hasFailure = false; + boolean hasPending = false; + Map data = new LinkedHashMap<>(); // NOPMD for (Endpoint endpoint : endpoints.values()) { for (Workflow p : endpoint.getWorkflows()) { + String workflowName = p.getName(); for (ModelInfo m : p.getModels()) { - numScaled += wlm.getWorkerPoolForModel(m).getMinWorkers(); - numWorking += wlm.getNumRunningWorkers(m); + String modelName = m.getModelId(); + if (!modelName.equals(workflowName)) { + modelName = workflowName + ':' + modelName; // NOPMD + } + ModelInfo.Status status = m.getStatus(); + switch (status) { + case FAILED: + data.put(modelName, new StatusResponse(status.name())); + hasFailure = true; + break; + case PENDING: + data.put(modelName, new StatusResponse(status.name())); + hasPending = true; + break; + default: + int min = wlm.getWorkerPoolForModel(m).getMinWorkers(); + int actual = 
wlm.getNumRunningWorkers(m); + if (actual < min) { + data.put(modelName, new StatusResponse("Unhealthy")); + } else { + data.put(modelName, new StatusResponse("Healthy")); + } + break; + } } } } - if ((numWorking > 0) && (numWorking < numScaled)) { - response = "Partial Healthy"; - } else if ((numWorking == 0) && (numScaled > 0)) { - response = "Unhealthy"; + HttpResponseStatus status; + if (hasFailure) { + status = HttpResponseStatus.INTERNAL_SERVER_ERROR; + } else if (hasPending) { + if (ConfigManager.getInstance().allowsMultiStatus()) { + status = HttpResponseStatus.MULTI_STATUS; + } else { + status = HttpResponseStatus.OK; + } + } else { + status = HttpResponseStatus.OK; } - return response; + FullHttpResponse resp = + new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, false); + resp.headers() + .set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); + ByteBuf content = resp.content(); + String body = JsonUtils.GSON_PRETTY.toJson(data); + content.writeCharSequence(body, CharsetUtil.UTF_8); + content.writeByte('\n'); + return resp; }); } } diff --git a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java index 68c582e5bc2..91887f99d1b 100644 --- a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java +++ b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java @@ -48,6 +48,8 @@ public final class ConfigManager { private static final String INFERENCE_ADDRESS = "inference_address"; private static final String MANAGEMENT_ADDRESS = "management_address"; private static final String LOAD_MODELS = "load_models"; + private static final String WAIT_MODEL_LOADING = "wait_model_loading"; + private static final String ALLOW_MULTI_STATUS = "allow_multi_status"; private static final String NUMBER_OF_NETTY_THREADS = "number_of_netty_threads"; private static final String JOB_QUEUE_SIZE = "job_queue_size"; private static final String MAX_IDLE_TIME = "max_idle_time"; @@ 
-219,6 +221,24 @@ public static String getModelServerHome() { return home; } + /** + * Returns if model server should wait for model initialization on startup. + * + * @return true if model server should wait for model initialization on startup + */ + public boolean waitModelLoading() { + return Boolean.parseBoolean(prop.getProperty(WAIT_MODEL_LOADING, "true")); + } + + /** + * Returns if allows return MULTI-STATUS HTTP code. + * + * @return true if allows return MULTI-STATUS HTTP code + */ + public boolean allowsMultiStatus() { + return Boolean.parseBoolean(prop.getProperty(ALLOW_MULTI_STATUS)); + } + /** * Returns the model store location. * diff --git a/serving/src/main/java/ai/djl/serving/workflow/Workflow.java b/serving/src/main/java/ai/djl/serving/workflow/Workflow.java index e1453c7d5be..ef3d2400a4f 100644 --- a/serving/src/main/java/ai/djl/serving/workflow/Workflow.java +++ b/serving/src/main/java/ai/djl/serving/workflow/Workflow.java @@ -12,14 +12,17 @@ */ package ai.djl.serving.workflow; +import ai.djl.ModelException; import ai.djl.modality.Input; import ai.djl.modality.Output; +import ai.djl.serving.plugins.DependencyManager; import ai.djl.serving.wlm.ModelInfo; import ai.djl.serving.wlm.WorkLoadManager; import ai.djl.serving.workflow.WorkflowExpression.Item; import ai.djl.serving.workflow.function.IdentityWF; import ai.djl.serving.workflow.function.ModelWorkflowFunction; import ai.djl.serving.workflow.function.WorkflowFunction; +import java.io.IOException; import java.util.Collection; import java.util.Collections; import java.util.HashSet; @@ -28,6 +31,7 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import org.slf4j.Logger; @@ -49,7 +53,6 @@ public class Workflow implements AutoCloseable { String name; String version; - String url; Map models; Map expressions; Map funcs; 
@@ -57,16 +60,12 @@ public class Workflow implements AutoCloseable { /** * Constructs a workflow containing only a single model. * - * @param name the model/workflow name - * @param version the model/workflow version - * @param url the url the model was laoded from * @param model the model for the workflow */ - public Workflow(String name, String version, String url, ModelInfo model) { + public Workflow(ModelInfo model) { String modelName = "model"; - this.name = name; - this.version = version; - this.url = url; + this.name = model.getModelId(); + this.version = model.getVersion(); models = Collections.singletonMap(modelName, model); expressions = Collections.singletonMap( @@ -79,7 +78,6 @@ public Workflow(String name, String version, String url, ModelInfo model) { * * @param name workflow name * @param version workflow version - * @param url workflow source url * @param models a map of executableNames for a model (how it is referred to in the {@link * WorkflowExpression}s to model * @param expressions a map of names to refer to an expression to the expression @@ -88,25 +86,14 @@ public Workflow(String name, String version, String url, ModelInfo model) { public Workflow( String name, String version, - String url, Map models, Map expressions, Map funcs) { this.name = name; this.version = version; - this.url = url; this.models = models; this.expressions = expressions; this.funcs = funcs; - - // Default name and version - if (this.name == null && this.url != null) { - String[] nameParts = url.split("/"); - this.name = nameParts[nameParts.length - 1].split("\\.")[0]; - } - if (this.version == null) { - this.version = "1.0"; - } } /** @@ -118,6 +105,32 @@ public Collection getModels() { return models.values(); } + /** + * Load all the models in this workflow. 
+ * + * @param device the device to load the models + * @return a {@code CompletableFuture} instance + */ + public CompletableFuture load(String device) { + return CompletableFuture.supplyAsync( + () -> { + try { + for (ModelInfo modelInfo : models.values()) { + String engine = modelInfo.getEngineName(); + if (engine != null) { + DependencyManager dm = DependencyManager.getInstance(); + dm.installEngine(engine); + } + + modelInfo.load(device); + } + } catch (ModelException | IOException e) { + throw new CompletionException(e); + } + return null; + }); + } + /** * Executes a workflow with an input. * @@ -154,15 +167,6 @@ public String getVersion() { return version; } - /** - * Returns the (optional) string url containing the workflow source. - * - * @return the (optional) string url containing the workflow source - */ - public String getUrl() { - return url; - } - /** {@inheritDoc} */ @Override public boolean equals(Object o) { diff --git a/serving/src/main/java/ai/djl/serving/workflow/WorkflowDefinition.java b/serving/src/main/java/ai/djl/serving/workflow/WorkflowDefinition.java index f2626447f32..087079959b2 100644 --- a/serving/src/main/java/ai/djl/serving/workflow/WorkflowDefinition.java +++ b/serving/src/main/java/ai/djl/serving/workflow/WorkflowDefinition.java @@ -12,33 +12,20 @@ */ package ai.djl.serving.workflow; -import ai.djl.Application; -import ai.djl.MalformedModelException; -import ai.djl.modality.Input; -import ai.djl.modality.Output; -import ai.djl.repository.zoo.Criteria; -import ai.djl.repository.zoo.ModelNotFoundException; -import ai.djl.repository.zoo.ModelZoo; -import ai.djl.repository.zoo.ZooModel; import ai.djl.serving.util.ConfigManager; import ai.djl.serving.wlm.ModelInfo; import ai.djl.serving.workflow.WorkflowExpression.Item; import ai.djl.serving.workflow.function.WorkflowFunction; -import ai.djl.translate.ServingTranslator; -import ai.djl.translate.TranslatorFactory; import ai.djl.util.JsonUtils; import com.google.gson.Gson; import 
com.google.gson.JsonArray; import com.google.gson.JsonDeserializationContext; import com.google.gson.JsonDeserializer; import com.google.gson.JsonElement; -import com.google.gson.JsonObject; import com.google.gson.JsonParseException; import com.google.gson.annotations.SerializedName; -import com.google.gson.reflect.TypeToken; import java.io.IOException; import java.io.Reader; -import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Type; import java.nio.file.Files; import java.nio.file.Path; @@ -59,9 +46,8 @@ public class WorkflowDefinition { String name; String version; - transient String url; - Map models; + Map models; @SerializedName("workflow") Map expressions; @@ -69,15 +55,15 @@ public class WorkflowDefinition { @SerializedName("functions") Map funcs; - Integer queueSize; - Integer maxIdleTime; - Integer maxBatchDelay; - Integer batchSize; + int queueSize; + int maxIdleTime; + int maxBatchDelay; + int batchSize; private static final Yaml YAML = new Yaml(); public static final Gson GSON = JsonUtils.builder() - .registerTypeAdapter(ModelDefinition.class, new ModelDefinitionDeserializer()) + .registerTypeAdapter(ModelInfo.class, new ModelDefinitionDeserializer()) .registerTypeAdapter(WorkflowExpression.class, new ExpressionDeserializer()) .registerTypeAdapter(Item.class, new ExpressionItemDeserializer()) .create(); @@ -90,63 +76,48 @@ public class WorkflowDefinition { * @throws IOException if it fails to load the file for parsing */ public static WorkflowDefinition parse(Path path) throws IOException { - WorkflowDefinition wd; try (Reader reader = Files.newBufferedReader(path)) { String fileName = Objects.requireNonNull(path.toString()); if (fileName.endsWith(".yml") || fileName.endsWith(".yaml")) { Object yaml = YAML.load(reader); String asJson = GSON.toJson(yaml); - wd = GSON.fromJson(asJson, WorkflowDefinition.class); + return GSON.fromJson(asJson, WorkflowDefinition.class); } else if (fileName.endsWith(".json")) { - wd = 
GSON.fromJson(reader, WorkflowDefinition.class); + return GSON.fromJson(reader, WorkflowDefinition.class); } else { throw new IllegalArgumentException( "Unexpected file type in workflow file: " + path); } } - wd.url = path.toUri().toString(); - return wd; } /** * Converts the {@link WorkflowDefinition} into a workflow. * * @return a new {@link Workflow} matching this definition - * @throws ModelNotFoundException if the definition contains an unknown model - * @throws MalformedModelException if the definition contains a malformed model - * @throws IOException if it fails to load the definition or resources in it * @throws BadWorkflowException if the workflow could not be parsed successfully */ - public Workflow toWorkflow() - throws ModelNotFoundException, MalformedModelException, IOException, - BadWorkflowException { - Map loadedModels = new ConcurrentHashMap<>(); + public Workflow toWorkflow() throws BadWorkflowException { if (models != null) { - for (Entry emd : models.entrySet()) { - ModelDefinition md = emd.getValue(); - ZooModel model = md.criteria.loadModel(); - - ConfigManager configManager = ConfigManager.getInstance(); - int newQueueSize = - firstNonNull(md.queueSize, queueSize, configManager.getJobQueueSize()); - int newMaxIdleTime = - firstNonNull(md.maxIdleTime, maxIdleTime, configManager.getMaxIdleTime()); - int newMaxBatchDelay = - firstNonNull( - md.maxBatchDelay, maxBatchDelay, configManager.getMaxBatchDelay()); - int newBatchSize = - firstNonNull(md.batchSize, batchSize, configManager.getBatchSize()); - - ModelInfo modelInfo = - new ModelInfo( - model.getName(), - md.version, - model, - newQueueSize, - newMaxIdleTime, - newMaxBatchDelay, - newBatchSize); - loadedModels.put(emd.getKey(), modelInfo); + ConfigManager configManager = ConfigManager.getInstance(); + for (Entry emd : models.entrySet()) { + ModelInfo md = emd.getValue(); + md.setModelId(emd.getKey()); + md.setQueueSize( + firstValid(md.getQueueSize(), queueSize, 
configManager.getJobQueueSize())); + md.setMaxIdleTime( + firstValid( + md.getMaxIdleTime(), maxIdleTime, configManager.getMaxIdleTime())); + md.setMaxBatchDelay( + firstValid( + md.getMaxBatchDelay(), + maxBatchDelay, + configManager.getMaxBatchDelay())); + md.setBatchSize( + firstValid(md.getBatchSize(), batchSize, configManager.getBatchSize())); + if (name == null) { + name = emd.getKey(); + } } } @@ -163,135 +134,31 @@ public Workflow toWorkflow() } } - return new Workflow(name, version, url, loadedModels, expressions, loadedFunctions); + return new Workflow(name, version, models, expressions, loadedFunctions); } - private int firstNonNull(Integer... inputs) { - for (Integer input : inputs) { - if (input != null) { + private int firstValid(int... inputs) { + for (int input : inputs) { + if (input > 0) { return input; } } return 0; } - private static final class ModelDefinition { - - private Criteria criteria; - private String version; - - private Integer queueSize; - private Integer maxIdleTime; - private Integer maxBatchDelay; - private Integer batchSize; - - private ModelDefinition(Criteria criteria) { - this.criteria = criteria; - } - } - - private static final class ModelDefinitionDeserializer - implements JsonDeserializer { + private static final class ModelDefinitionDeserializer implements JsonDeserializer { /** {@inheritDoc} */ @Override - public ModelDefinition deserialize( + public ModelInfo deserialize( JsonElement json, Type typeOfT, JsonDeserializationContext context) { if (json.isJsonObject()) { - JsonObject obj = json.getAsJsonObject(); - ModelDefinition md = new ModelDefinition(readCriteria(obj, context)); - md.version = readStringProperty(obj, "version"); - md.queueSize = readIntegerProperty(obj, "queueSize"); - md.maxIdleTime = readIntegerProperty(obj, "maxIdleTime"); - md.maxBatchDelay = readIntegerProperty(obj, "maxBatchDelay"); - md.batchSize = readIntegerProperty(obj, "batchSize"); - return md; + return JsonUtils.GSON.fromJson(json, 
ModelInfo.class); } else if (json.isJsonPrimitive()) { - return new ModelDefinition( - Criteria.builder() - .setTypes(Input.class, Output.class) - .optModelUrls(json.getAsString()) - .build()); - } else { - throw new JsonParseException( - "Unexpected type of model definition: should be Criteria object or URI string"); - } - } - - private Criteria readCriteria( - JsonObject obj, JsonDeserializationContext context) { - try { - Criteria.Builder criteria = - Criteria.builder().setTypes(Input.class, Output.class); - - if (obj.has("application")) { - criteria.optApplication(Application.of(obj.get("application").getAsString())); - } - if (obj.has("engine")) { - criteria.optEngine(obj.get("engine").getAsString()); - } - if (obj.has("groupId")) { - criteria.optGroupId(obj.get("groupId").getAsString()); - } - if (obj.has("artifactId")) { - criteria.optArtifactId(obj.get("artifactId").getAsString()); - } - if (obj.has("modelUrls")) { - criteria.optModelUrls(obj.get("modelUrls").getAsString()); - } - if (obj.has("modelZoo")) { - criteria.optModelZoo(ModelZoo.getModelZoo(obj.get("modelZoo").getAsString())); - } - if (obj.has("filters")) { - Type tp = new TypeToken>() {}.getType(); - criteria.optFilters(context.deserialize(obj.get("filters"), tp)); - } - if (obj.has("arguments")) { - Type tp = new TypeToken>() {}.getType(); - criteria.optFilters(context.deserialize(obj.get("arguments"), tp)); - } - if (obj.has("options")) { - Type tp = new TypeToken>() {}.getType(); - criteria.optFilters(context.deserialize(obj.get("options"), tp)); - } - if (obj.has("modelName")) { - criteria.optArtifactId(obj.get("modelName").getAsString()); - } - if (obj.has("translatorFactory")) { - Class clazz = - Class.forName(obj.get("translatorFactory").getAsString()) - .asSubclass(TranslatorFactory.class); - criteria.optTranslatorFactory(clazz.getConstructor().newInstance()); - } - if (obj.has("translator")) { - Class clazz = - Class.forName(obj.get("translator").getAsString()) - 
.asSubclass(ServingTranslator.class); - criteria.optTranslator(clazz.getConstructor().newInstance()); - } - - return criteria.build(); - } catch (ClassNotFoundException - | InvocationTargetException - | InstantiationException - | IllegalAccessException - | NoSuchMethodException e) { - throw new JsonParseException("Failed to parse model definition", e); - } - } - - private String readStringProperty(JsonObject obj, String name) { - if (obj.has(name)) { - return obj.get(name).getAsString(); - } - return null; - } - - private Integer readIntegerProperty(JsonObject obj, String name) { - if (obj.has(name)) { - return obj.get(name).getAsInt(); + return new ModelInfo(json.getAsString()); } - return null; + throw new JsonParseException( + "Unexpected type of model definition: should be Criteria object or URI string"); } } diff --git a/serving/src/main/java/ai/djl/serving/workflow/function/ModelWorkflowFunction.java b/serving/src/main/java/ai/djl/serving/workflow/function/ModelWorkflowFunction.java index 83f7ababcd5..5288183077e 100644 --- a/serving/src/main/java/ai/djl/serving/workflow/function/ModelWorkflowFunction.java +++ b/serving/src/main/java/ai/djl/serving/workflow/function/ModelWorkflowFunction.java @@ -41,7 +41,7 @@ public CompletableFuture run( if (args.size() != 1) { throw new IllegalArgumentException( "The model " - + model.getModelName() + + model.getModelId() + " should have one arg, but has " + args.size()); } diff --git a/serving/src/test/java/ai/djl/serving/ModelServerTest.java b/serving/src/test/java/ai/djl/serving/ModelServerTest.java index a4884922cbf..c102bf259f2 100644 --- a/serving/src/test/java/ai/djl/serving/ModelServerTest.java +++ b/serving/src/test/java/ai/djl/serving/ModelServerTest.java @@ -289,9 +289,9 @@ private void testPing(Channel channel) throws InterruptedException { HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/ping"); channel.writeAndFlush(req); latch.await(); - + 
Assert.assertEquals(httpStatus.code(), HttpResponseStatus.OK.code()); StatusResponse resp = JsonUtils.GSON.fromJson(result, StatusResponse.class); - Assert.assertEquals(resp.getStatus(), "Healthy"); + Assert.assertNotNull(resp); Assert.assertTrue(headers.contains("x-request-id")); } @@ -389,7 +389,7 @@ private void testRegisterModelAsync(Channel channel) ListModelsResponse resp = JsonUtils.GSON.fromJson(result, ListModelsResponse.class); for (ListModelsResponse.ModelItem item : resp.getModels()) { Assert.assertNotNull(item.getModelUrl()); - if ("mlp_1".equals(item.getModelName())) { + if ("mlp_1".equals(item.getModelName()) && "READY".equals(item.getStatus())) { modelRegistered = true; break OUTER; } diff --git a/serving/src/test/java/ai/djl/serving/WorkflowTest.java b/serving/src/test/java/ai/djl/serving/WorkflowTest.java index 981ad0f2ccb..02005cda4b5 100644 --- a/serving/src/test/java/ai/djl/serving/WorkflowTest.java +++ b/serving/src/test/java/ai/djl/serving/WorkflowTest.java @@ -12,10 +12,8 @@ */ package ai.djl.serving; -import ai.djl.MalformedModelException; import ai.djl.modality.Input; import ai.djl.modality.Output; -import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.serving.util.ConfigManager; import ai.djl.serving.wlm.ModelInfo; import ai.djl.serving.wlm.WorkLoadManager; @@ -28,6 +26,8 @@ import java.net.URL; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import org.apache.commons.cli.CommandLine; import org.testng.Assert; import org.testng.annotations.BeforeSuite; @@ -52,40 +52,34 @@ public void beforeAll() throws IOException { @Test public void testJson() - throws IOException, ModelNotFoundException, MalformedModelException, - BadWorkflowException { + throws IOException, BadWorkflowException, ExecutionException, InterruptedException { Path workflowFile = Paths.get("src/test/resources/workflows/basic.json"); runWorkflow(workflowFile, 
zeroInput); } @Test public void testYaml() - throws ModelNotFoundException, MalformedModelException, IOException, - BadWorkflowException { + throws IOException, BadWorkflowException, ExecutionException, InterruptedException { Path workflowFile = Paths.get("src/test/resources/workflows/basic.yaml"); runWorkflow(workflowFile, zeroInput); } @Test public void testCriteria() - throws ModelNotFoundException, MalformedModelException, IOException, - BadWorkflowException { + throws IOException, BadWorkflowException, ExecutionException, InterruptedException { Path workflowFile = Paths.get("src/test/resources/workflows/criteria.json"); runWorkflow(workflowFile, zeroInput); } @Test public void testFunctions() - throws ModelNotFoundException, MalformedModelException, IOException, - BadWorkflowException { + throws IOException, BadWorkflowException, ExecutionException, InterruptedException { Path workflowFile = Paths.get("src/test/resources/workflows/functions.json"); runWorkflow(workflowFile, zeroInput); } @Test - public void testGlobalPerf() - throws ModelNotFoundException, MalformedModelException, IOException, - BadWorkflowException { + public void testGlobalPerf() throws IOException, BadWorkflowException { Path workflowFile = Paths.get("src/test/resources/workflows/globalPerf.json"); Workflow workflow = WorkflowDefinition.parse(workflowFile).toWorkflow(); ModelInfo m = workflow.getModels().stream().findFirst().get(); @@ -97,9 +91,7 @@ public void testGlobalPerf() } @Test - public void testLocalPerf() - throws ModelNotFoundException, MalformedModelException, IOException, - BadWorkflowException { + public void testLocalPerf() throws IOException, BadWorkflowException { Path workflowFile = Paths.get("src/test/resources/workflows/localPerf.json"); Workflow workflow = WorkflowDefinition.parse(workflowFile).toWorkflow(); ModelInfo m = workflow.getModels().stream().findFirst().get(); @@ -111,9 +103,10 @@ public void testLocalPerf() } private Input runWorkflow(Path workflowFile, Input 
input) - throws IOException, ModelNotFoundException, MalformedModelException, - BadWorkflowException { + throws IOException, BadWorkflowException, ExecutionException, InterruptedException { Workflow workflow = WorkflowDefinition.parse(workflowFile).toWorkflow(); + CompletableFuture future = workflow.load("-1"); + future.get(); try (WorkLoadManager wlm = new WorkLoadManager()) { for (ModelInfo model : workflow.getModels()) { wlm.getWorkerPoolForModel(model).scaleWorkers("cpu", 1, 1); diff --git a/serving/src/test/resources/workflows/criteria.json b/serving/src/test/resources/workflows/criteria.json index 568336cef88..1b49b6947be 100644 --- a/serving/src/test/resources/workflows/criteria.json +++ b/serving/src/test/resources/workflows/criteria.json @@ -2,8 +2,7 @@ "models": { "m": { "application": "cv/image_classification", - "groupId": "ai.djl.zoo", - "artifactId": "mlp", + "modelUrl": "djl://ai.djl.zoo/mlp/0.0.3/mlp", "filters": { "dataset": "mnist" } @@ -12,4 +11,4 @@ "workflow": { "out": ["m", "in"] } -} \ No newline at end of file +} diff --git a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java index 982c7392341..851d27f54da 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java @@ -12,12 +12,21 @@ */ package ai.djl.serving.wlm; +import ai.djl.Application; +import ai.djl.Device; +import ai.djl.ModelException; +import ai.djl.engine.Engine; import ai.djl.modality.Input; import ai.djl.modality.Output; import ai.djl.repository.FilenameUtils; +import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.ServingTranslator; +import ai.djl.translate.TranslatorFactory; +import java.io.IOException; import java.net.URI; import java.nio.file.Path; +import java.util.Map; import java.util.Objects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -27,44 +36,136 @@ public final class ModelInfo 
implements AutoCloseable { private static final Logger logger = LoggerFactory.getLogger(ModelInfo.class); - private String modelName; + private transient String id; private String version; + private String modelUrl; + private String engineName; private int queueSize; private int batchSize; private int maxBatchDelay; private int maxIdleTime; - private ZooModel model; + private Map filters; + private Map arguments; + private Map options; + private String application; + private String modelName; + private String translatorFactory; + private String translator; + private transient Status status; + + private transient ZooModel model; + + /** + * Constructs a new {@code ModelInfo} instance. + * + * @param modelUrl the model Url + */ + public ModelInfo(String modelUrl) { + this.modelUrl = modelUrl; + } /** * Constructs a new {@code ModelInfo} instance. * - * @param modelName the name of the model that will be used as HTTP endpoint + * @param id the ID of the model that will be used by workflow + * @param modelUrl the model url * @param version the version of the model - * @param model the {@link ZooModel} + * @param engineName the engine to load the model * @param queueSize the maximum request queue size * @param maxIdleTime the initial maximum idle time for workers. * @param maxBatchDelay the initial maximum delay when scaling up before giving up. * @param batchSize the batch size for this model. */ public ModelInfo( - String modelName, + String id, + String modelUrl, String version, - ZooModel model, + String engineName, int queueSize, int maxIdleTime, int maxBatchDelay, int batchSize) { - this.modelName = modelName; + this.id = id; + this.modelUrl = modelUrl; this.version = version; - this.model = model; + this.engineName = engineName; this.maxBatchDelay = maxBatchDelay; this.maxIdleTime = maxIdleTime; // default max idle time 60s this.queueSize = queueSize; this.batchSize = batchSize; } + /** + * Loads the model to the specified device. 
+ * + * @param deviceName the device to load model on + * @throws IOException if failed to read model file + * @throws ModelException if failed to load the specified model + */ + public void load(String deviceName) throws ModelException, IOException { + if (model != null) { + return; + } + Criteria.Builder builder = + Criteria.builder() + .setTypes(Input.class, Output.class) + .optModelUrls(modelUrl) + .optModelName(modelName) + .optEngine(engineName) + .optFilters(filters) + .optArguments(arguments) + .optOptions(options); + if (application != null) { + builder.optApplication(Application.of(application)); + } + try { + if (translator != null) { + Class clazz = + Class.forName(translator).asSubclass(ServingTranslator.class); + builder.optTranslator(clazz.getConstructor().newInstance()); + } + if (translatorFactory != null) { + Class clazz = + Class.forName(translatorFactory).asSubclass(TranslatorFactory.class); + builder.optTranslatorFactory(clazz.getConstructor().newInstance()); + } + } catch (ReflectiveOperationException e) { + throw new ModelException("Invalid criteria", e); + } + if ("-1".equals(deviceName)) { + Device device; + if (engineName == null) { + device = Device.cpu(); + } else { + device = Engine.getEngine(engineName).defaultDevice(); + } + logger.info("Loading model {} on {}.", id, device); + } else if (deviceName.startsWith("nc")) { + logger.info("Loading model {} on {}.", id, deviceName); + String ncs = deviceName.substring(2); + builder.optOption("env", "NEURON_RT_VISIBLE_CORES=" + ncs); + } else { + // GPU case + int gpuId = Integer.parseInt(deviceName); + builder.optDevice(Device.gpu(gpuId)); + logger.info("Loading model {} on {}.", id, Device.gpu(gpuId)); + } + if (batchSize > 1) { + builder.optArgument("batchifier", "stack"); + } + + try { + model = builder.build().loadModel(); + status = Status.READY; + } finally { + if (status == null) { + status = Status.FAILED; + } + } + } + + /** + * Sets a new batchSize and returns a new configured ModelInfo 
object. You have to * triggerUpdates in the {@code ModelManager} using this new model. @@ -99,16 +200,28 @@ public ModelInfo configurePool(int maxIdleTime) { * @return the loaded {@link ZooModel} */ public ZooModel getModel() { + if (model == null) { + throw new IllegalStateException("Model \"" + id + "\" has not been loaded yet."); + } return model; } /** - * Returns the model name. + * Sets the model ID. * - * @return the model name + * @param id the model ID */ - public String getModelName() { - return modelName; + public void setModelId(String id) { + this.id = id; + } + + /** + * Returns the model ID. + * + * @return the model ID + */ + public String getModelId() { + return id; } /** @@ -120,6 +233,33 @@ public String getVersion() { return version; } + /** + * Returns the engine name. + * + * @return the engine name + */ + public String getEngineName() { + return engineName; + } + + /** + * Returns the model url. + * + * @return the model url + */ + public String getModelUrl() { + return modelUrl; + } + + /** + * Returns the model loading status. + * + * @return the model loading status + */ + public Status getStatus() { + return status == null ? Status.PENDING : status; + } + /** * Returns the model cache directory. * @@ -130,7 +270,16 @@ public Path getModelDir() { } /** - * returns the configured maxIdleTime of workers. + * Sets the configured maxIdleTime of workers. + * + * @param maxIdleTime the configured maxIdleTime of workers + */ + public void setMaxIdleTime(int maxIdleTime) { + this.maxIdleTime = maxIdleTime; + } + + /** + * Returns the configured maxIdleTime of workers. * * @return the maxIdleTime */ @@ -138,6 +287,15 @@ public int getMaxIdleTime() { return maxIdleTime; } + /** + * Sets the configured batch size. + * + * @param batchSize the configured batch size + */ + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; + } + /** * Returns the configured batch size. 
* @@ -147,6 +305,15 @@ public int getBatchSize() { return batchSize; } + /** + * Sets the maximum delay in milliseconds to aggregate a batch. + * + * @param maxBatchDelay the maximum delay in milliseconds to aggregate a batch + */ + public void setMaxBatchDelay(int maxBatchDelay) { + this.maxBatchDelay = maxBatchDelay; + } + /** * Returns the maximum delay in milliseconds to aggregate a batch. * @@ -156,6 +323,15 @@ public int getMaxBatchDelay() { return maxBatchDelay; } + /** + * Sets the configured size of the workers queue. + * + * @param queueSize the configured size of the workers queue + */ + public void setQueueSize(int queueSize) { + this.queueSize = queueSize; + } + /** * Returns the configured size of the workers queue. * @@ -211,21 +387,28 @@ public boolean equals(Object o) { return false; } ModelInfo modelInfo = (ModelInfo) o; - return modelName.equals(modelInfo.modelName) && Objects.equals(version, modelInfo.version); + return id.equals(modelInfo.id) && Objects.equals(version, modelInfo.version); } /** {@inheritDoc} */ @Override public int hashCode() { - return Objects.hash(modelName, version); + return Objects.hash(id, version); } /** {@inheritDoc} */ @Override public String toString() { if (version != null) { - return modelName + ':' + version; + return id + ':' + version; } - return modelName; + return id; + } + + /** An enum represents state of a model. 
*/ + public enum Status { + PENDING, + READY, + FAILED } } diff --git a/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java index a2ec7276d65..b91f732db4d 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java @@ -16,6 +16,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.serving.wlm.util.WlmCapacityException; import ai.djl.serving.wlm.util.WlmConfigManager; +import ai.djl.serving.wlm.util.WlmException; import ai.djl.serving.wlm.util.WlmShutdownException; import ai.djl.serving.wlm.util.WorkerJob; import java.util.Collections; @@ -90,20 +91,24 @@ public void unregisterModel(ModelInfo model) { public CompletableFuture runJob(Job job) { CompletableFuture result = new CompletableFuture<>(); ModelInfo modelInfo = job.getModel(); + if (modelInfo.getStatus() != ModelInfo.Status.READY) { + result.completeExceptionally( + new WlmException("Model is not ready: " + modelInfo.getStatus())); + return result; + } + WorkerPool pool = getWorkerPoolForModel(modelInfo); int maxWorkers = pool.getMaxWorkers(); if (maxWorkers == 0) { result.completeExceptionally( - new WlmShutdownException( - "All model workers has been shutdown: " + modelInfo.getModelName())); + new WlmShutdownException("All model workers has been shutdown: " + modelInfo)); return result; } LinkedBlockingDeque queue = pool.getJobQueue(); if (!queue.offer(new WorkerJob(job, result))) { result.completeExceptionally( new WlmCapacityException( - "Worker queue capacity exceeded for model: " - + modelInfo.getModelName())); + "Worker queue capacity exceeded for model: " + modelInfo)); return result; } @@ -247,6 +252,10 @@ public int getMaxWorkers() { */ public WorkerPool scaleWorkers(String deviceName, int newMinWorkers, int newMaxWorkers) { synchronized (model) { + if (model.getStatus() != ModelInfo.Status.READY) { + logger.warn("Cannot scale workers while model is not READY: 
{}", model); + return this; + } NDManager manager = model.getModel().getNDManager(); WlmConfigManager configManager = WlmConfigManager.getInstance(); maxWorkers = configManager.getDefaultWorkers(manager, deviceName, newMaxWorkers); @@ -316,7 +325,7 @@ public void log() { buf.append("-tmpPool\n"); } }); - logger.debug("worker pool for model {}:\n {}", model.getModelName(), buf); + logger.debug("worker pool for model {}:\n {}", model, buf); } } diff --git a/wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java index 7411581b601..3f77c23d0a7 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java @@ -161,11 +161,11 @@ public void shutdown(WorkerState state) { } private String buildWorkerName(ModelInfo model) { - String modelName = model.getModelName(); - if (modelName.length() > 25) { - modelName = modelName.substring(0, 25); + String modelId = model.getModelId(); + if (modelId.length() > 25) { + modelId = modelId.substring(0, 25); } - return "W-" + modelName + '-' + workerId; + return "W-" + modelId + '-' + workerId; } void setState(WorkerState newState) { diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/WlmCapacityException.java b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmCapacityException.java index 22042c77469..f4c5b290244 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/util/WlmCapacityException.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmCapacityException.java @@ -13,7 +13,7 @@ package ai.djl.serving.wlm.util; /** Thrown to throttle when a job is run but the job queue capacity is exceeded. 
*/ -public class WlmCapacityException extends RuntimeException { +public class WlmCapacityException extends WlmException { static final long serialVersionUID = 1L; diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/WlmShutdownException.java b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmShutdownException.java index 3527fa1a7a3..1a5e1654495 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/util/WlmShutdownException.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmShutdownException.java @@ -13,7 +13,7 @@ package ai.djl.serving.wlm.util; /** Thrown when a job is run but all workers are shutdown. */ -public class WlmShutdownException extends RuntimeException { +public class WlmShutdownException extends WlmException { static final long serialVersionUID = 1L; diff --git a/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java index 6b94f331c97..e26a3332a81 100644 --- a/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java +++ b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java @@ -19,7 +19,10 @@ public class ModelInfoTest { @Test public void testQueueSizeIsSet() { - ModelInfo modelInfo = new ModelInfo("", null, null, 4711, 1, 300, 1); + ModelInfo modelInfo = new ModelInfo("", null, null, "MXNet", 4711, 1, 300, 1); Assert.assertEquals(4711, modelInfo.getQueueSize()); + Assert.assertEquals(1, modelInfo.getMaxIdleTime()); + Assert.assertEquals(300, modelInfo.getMaxBatchDelay()); + Assert.assertEquals(1, modelInfo.getBatchSize()); } }