Config number worker threads for GPU inference #1153

Merged: 1 commit, Aug 9, 2021
@@ -96,8 +96,8 @@ public final boolean runBenchmark(String[] args) {
                     engine.defaultDevice(),
                     duration.toMinutes());
         } else {
-            logger.info(
-                    "Running {} on: {}.", getClass().getSimpleName(), engine.defaultDevice());
+            Device[] devices = engine.getDevices(arguments.getMaxGpus());
+            logger.info("Running {} on: {}.", getClass().getSimpleName(), devices);
         }
         int numOfThreads = arguments.getThreads();
         int iteration = arguments.getIteration();
Arguments.java (package ai.djl.benchmark):

@@ -12,6 +12,7 @@
  */
 package ai.djl.benchmark;
 
+import ai.djl.Device;
 import ai.djl.engine.Engine;
 import ai.djl.ndarray.types.DataType;
 import ai.djl.ndarray.types.Shape;
@@ -27,10 +28,14 @@
 import org.apache.commons.cli.Option;
 import org.apache.commons.cli.OptionGroup;
 import org.apache.commons.cli.Options;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /** A class represents parsed command line arguments. */
 public class Arguments {
 
+    private static final Logger logger = LoggerFactory.getLogger(Arguments.class);
+
     private String modelUrl;
     private String modelName;
     private String engine;
@@ -77,12 +82,6 @@ public class Arguments {
         if (cmd.hasOption("iteration")) {
             iteration = Integer.parseInt(cmd.getOptionValue("iteration"));
         }
-        if (cmd.hasOption("threads")) {
-            threads = Integer.parseInt(cmd.getOptionValue("threads"));
-            if (threads <= 0) {
-                threads = Runtime.getRuntime().availableProcessors();
-            }
-        }
         if (cmd.hasOption("gpus")) {
             maxGpus = Integer.parseInt(cmd.getOptionValue("gpus"));
             if (maxGpus < 0) {
@@ -91,6 +90,31 @@
         } else {
             maxGpus = Integer.MAX_VALUE;
         }
+        if (cmd.hasOption("threads")) {
+            threads = Integer.parseInt(cmd.getOptionValue("threads"));
+            Engine eng = Engine.getEngine(engine);
+            Device[] devices = eng.getDevices(maxGpus);
+            String deviceType = devices[0].getDeviceType();
+            if (Device.Type.GPU.equals(deviceType)) {
+                // one thread per GPU
+                if (threads <= 0) {
+                    threads = devices.length;
+                } else if (threads < devices.length) {
+                    threads = devices.length;
+                    logger.warn(
+                            "Number of threads is less than GPU count, adjust to: {}",
+                            devices.length);
+                } else if ("MXNet".equals(engine) && threads > devices.length) {
+                    threads = devices.length;
+                    logger.warn("MXNet inference can only have one worker per GPU.");
+                } else if (threads % devices.length != 0) {
+                    threads = threads / devices.length * devices.length;
+                    logger.warn("threads should be multiple of GPU count, change to: {}", threads);
+                }
+            } else if (threads <= 0) {
+                threads = Runtime.getRuntime().availableProcessors();
+            }
+        }
         if (cmd.hasOption("delay")) {
             delay = Integer.parseInt(cmd.getOptionValue("delay"));
         }
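Note: the threads option is validated here rather than in the multithreaded benchmark (next file), since the adjustment now needs the engine name and GPU count. Below is a standalone restatement of the GPU branch above; the class and method names are illustrative, not part of this PR:

// Illustrative only: the GPU thread-count normalization, restated standalone.
final class ThreadNormalizationSketch {

    static int normalizeThreads(int threads, int gpuCount, boolean isMxnet) {
        if (threads <= 0 || threads < gpuCount) {
            return gpuCount;                      // at least one thread per GPU
        }
        if (isMxnet) {
            return gpuCount;                      // MXNet: exactly one worker per GPU
        }
        if (threads % gpuCount != 0) {
            return threads / gpuCount * gpuCount; // round down to a multiple of gpuCount
        }
        return threads;
    }

    public static void main(String[] args) {
        System.out.println(normalizeThreads(10, 4, false)); // 8
        System.out.println(normalizeThreads(10, 4, true));  // 4
        System.out.println(normalizeThreads(0, 4, false));  // 4
    }
}

On CPU the option keeps its old meaning: a non-positive value falls back to Runtime.getRuntime().availableProcessors().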
@@ -46,12 +46,6 @@ public float[] predict(Arguments arguments, Metrics metrics, int iteration)
         Engine engine = Engine.getEngine(arguments.getEngine());
         Device[] devices = engine.getDevices(arguments.getMaxGpus());
         int numOfThreads = arguments.getThreads();
-        if (numOfThreads < devices.length) {
-            logger.warn("Number of threads is less than GPU count, adjust to: {}", devices.length);
-        } else if (numOfThreads % devices.length != 0) {
-            numOfThreads = numOfThreads / devices.length * devices.length;
-            logger.warn("Number of threads should be multiple of GPU count.");
-        }
         int delay = arguments.getDelay();
         AtomicInteger counter = new AtomicInteger(iteration);
         logger.info("Multithreading inference with {} threads.", numOfThreads);
@@ -337,7 +337,6 @@ private void initModelStore() throws IOException {
             modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
         }
 
-        int workers = configManager.getDefaultWorkers();
         for (int i = 0; i < gpuIds.length; ++i) {
             String modelVersion;
             if (gpuIds.length > 1) {
@@ -360,7 +359,7 @@ private void initModelStore() throws IOException {
                             configManager.getMaxBatchDelay(),
                             configManager.getMaxIdleTime());
             ModelInfo modelInfo = future.join();
-            modelManager.triggerModelUpdated(modelInfo.scaleWorkers(1, workers));
+            modelManager.triggerModelUpdated(modelInfo.scaleWorkers(1, -1));
         }
         startupModels.add(modelName);
     }
@@ -178,7 +178,7 @@ private void predict(
                         ConfigManager.getInstance().getBatchSize(),
                         ConfigManager.getInstance().getMaxBatchDelay(),
                         ConfigManager.getInstance().getMaxIdleTime())
-                .thenApply(m -> modelManager.triggerModelUpdated(m.scaleWorkers(1, 1)))
+                .thenApply(m -> modelManager.triggerModelUpdated(m.scaleWorkers(1, -1)))
                 .thenAccept(
                         m -> {
                             try {
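Note: both server call sites, model-store startup above and inference-triggered registration here, now pass -1 as the worker target instead of a precomputed count. The sentinel defers the decision to ConfigManager.getDefaultWorkers, which can inspect the model's actual device (see the ConfigManager diff below). A hedged sketch of what the sentinel resolves to, derived from that diff:

// Illustrative only: what scaleWorkers(1, -1) resolves to after this PR.
// The chain is ModelInfo.scaleWorkers -> ConfigManager.getDefaultWorkers.
//   GPU model, MXNet engine  -> maxWorkers = 1 (MXNet GPU is single-worker)
//   GPU model, other engine  -> maxWorkers = 2 (the default per-GPU cap)
//   CPU model                -> maxWorkers = Runtime.getRuntime().availableProcessors()
modelManager.triggerModelUpdated(modelInfo.scaleWorkers(1, -1));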
@@ -14,7 +14,6 @@
 
 import ai.djl.ModelException;
 import ai.djl.repository.zoo.ModelNotFoundException;
-import ai.djl.serving.util.ConfigManager;
 import ai.djl.serving.util.NettyUtils;
 import ai.djl.serving.wlm.Endpoint;
 import ai.djl.serving.wlm.ModelInfo;
@@ -168,8 +167,7 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec
         int maxBatchDelay = NettyUtils.getIntParameter(decoder, MAX_BATCH_DELAY_PARAMETER, 100);
         int maxIdleTime = NettyUtils.getIntParameter(decoder, MAX_IDLE_TIME__PARAMETER, 60);
         int minWorkers = NettyUtils.getIntParameter(decoder, MIN_WORKER_PARAMETER, 1);
-        int defaultWorkers = ConfigManager.getInstance().getDefaultWorkers();
-        int maxWorkers = NettyUtils.getIntParameter(decoder, MAX_WORKER_PARAMETER, defaultWorkers);
+        int maxWorkers = NettyUtils.getIntParameter(decoder, MAX_WORKER_PARAMETER, -1);
         boolean synchronous =
                 Boolean.parseBoolean(
                         NettyUtils.getParameter(decoder, SYNCHRONOUS_PARAMETER, "true"));
@@ -247,19 +245,30 @@ private void handleScaleModel(
                 NettyUtils.getIntParameter(
                         decoder, MAX_BATCH_DELAY_PARAMETER, modelInfo.getMaxBatchDelay());
 
-        modelInfo
-                .scaleWorkers(minWorkers, maxWorkers)
-                .configurePool(maxIdleTime)
-                .configureModelBatch(batchSize, maxBatchDelay);
-        modelManager.triggerModelUpdated(modelInfo);
+        if (version == null) {
+            // scale all versions
+            Endpoint endpoint = modelManager.getEndpoints().get(modelName);
+            for (ModelInfo model : endpoint.getModels()) {
+                model.scaleWorkers(minWorkers, maxWorkers)
+                        .configurePool(maxIdleTime)
+                        .configureModelBatch(batchSize, maxBatchDelay);
+                modelManager.triggerModelUpdated(model);
+            }
+        } else {
+            modelInfo
+                    .scaleWorkers(minWorkers, maxWorkers)
+                    .configurePool(maxIdleTime)
+                    .configureModelBatch(batchSize, maxBatchDelay);
+            modelManager.triggerModelUpdated(modelInfo);
+        }
 
         String msg =
                 "Model \""
                         + modelName
                         + "\" worker scaled. New Worker configuration min workers:"
-                        + minWorkers
+                        + modelInfo.getMinWorkers()
                         + " max workers:"
-                        + maxWorkers;
+                        + modelInfo.getMaxWorkers();
         NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg));
     } catch (NumberFormatException ex) {
         throw new BadRequestException("parameter is invalid number." + ex.getMessage(), ex);
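Note: two behavior changes in this handler. When no version is given, scaling now applies to every version under the endpoint, and the confirmation message reports modelInfo.getMinWorkers()/getMaxWorkers() rather than echoing the request parameters, so the client sees the values that actually took effect after the device-aware adjustment. A hedged illustration (the request value is hypothetical):

// Illustrative only: the reported numbers can differ from the request.
// Suppose a client asks to scale an MXNet GPU model to maxWorkers = 8:
//   scaleWorkers(minWorkers, 8) -> getDefaultWorkers(manager, 8) -> 1 (MXNet GPU cap)
// The old message would echo "max workers:8"; the new one reports
// modelInfo.getMaxWorkers(), i.e. "max workers:1".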
ConfigManager.java (package ai.djl.serving.util):

@@ -12,6 +12,8 @@
  */
 package ai.djl.serving.util;
 
+import ai.djl.Device;
+import ai.djl.ndarray.NDManager;
 import ai.djl.serving.Arguments;
 import ai.djl.util.Utils;
 import io.netty.handler.ssl.SslContext;
@@ -46,7 +48,6 @@ public final class ConfigManager {
     private static final String INFERENCE_ADDRESS = "inference_address";
     private static final String MANAGEMENT_ADDRESS = "management_address";
     private static final String LOAD_MODELS = "load_models";
-    private static final String DEFAULT_WORKERS_PER_MODEL = "default_workers_per_model";
     private static final String NUMBER_OF_NETTY_THREADS = "number_of_netty_threads";
     private static final String JOB_QUEUE_SIZE = "job_queue_size";
     private static final String MAX_IDLE_TIME = "max_idle_time";
@@ -189,18 +190,30 @@ public int getMaxBatchDelay() {
     /**
      * Returns the default number of workers for a new registered model.
      *
+     * @param manager the {@code NDManager} the model uses
+     * @param target the target number of worker
      * @return the default number of workers for a new registered model
      */
-    public int getDefaultWorkers() {
-        if (isDebug()) {
+    public int getDefaultWorkers(NDManager manager, int target) {
+        if (target == 0) {
+            return 0;
+        } else if (target == -1 && isDebug()) {
             return 1;
         }
+        if (Device.Type.GPU.equals(manager.getDevice().getDeviceType())) {
+            if ("MXNet".equals(manager.getEngine().getEngineName())) {
+                // FIXME: MXNet GPU Model doesn't support multi-threading
+                return 1;
+            } else if (target == -1) {
+                target = 2; // default to max 2 workers per GPU
+            }
+            return target;
+        }
 
-        int workers = getIntProperty(DEFAULT_WORKERS_PER_MODEL, 0);
-        if (workers == 0) {
-            workers = Runtime.getRuntime().availableProcessors();
+        if (target > 0) {
+            return target;
         }
-        return workers;
+        return Runtime.getRuntime().availableProcessors();
     }
 
     /**
@@ -391,8 +404,6 @@ public String dumpConfigurations() {
                 + (getLoadModels() == null ? "N/A" : getLoadModels())
                 + "\nNetty threads: "
                 + getNettyThreads()
-                + "\nDefault workers per model: "
-                + getDefaultWorkers()
                 + "\nMaximum Request Size: "
                 + prop.getProperty(MAX_REQUEST_SIZE, "6553500");
     }
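Note: getDefaultWorkers replaces the flat default_workers_per_model property with a device- and engine-aware policy (the property constant and its line in dumpConfigurations are removed accordingly). A standalone restatement of the decision order, with debug mode, device type, and engine name passed in explicitly so the sketch compiles on its own; class and method names here are illustrative:

// Illustrative re-statement of getDefaultWorkers(manager, target).
final class WorkerPolicySketch {

    static int defaultWorkers(boolean debug, boolean gpu, String engineName, int target) {
        if (target == 0) {
            return 0; // explicit 0: no workers
        } else if (target == -1 && debug) {
            return 1; // debug mode: single worker
        }
        if (gpu) {
            if ("MXNet".equals(engineName)) {
                return 1; // MXNet GPU models don't support multi-threading
            } else if (target == -1) {
                target = 2; // default: max 2 workers per GPU
            }
            return target;
        }
        if (target > 0) {
            return target; // CPU with an explicit target
        }
        return Runtime.getRuntime().availableProcessors(); // CPU default
    }

    public static void main(String[] args) {
        System.out.println(defaultWorkers(false, true, "MXNet", 8));     // 1
        System.out.println(defaultWorkers(false, true, "PyTorch", -1));  // 2
        System.out.println(defaultWorkers(false, false, "PyTorch", -1)); // CPU count
    }
}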
ModelInfo.java (package ai.djl.serving.wlm):

@@ -14,8 +14,10 @@
 
 import ai.djl.modality.Input;
 import ai.djl.modality.Output;
+import ai.djl.ndarray.NDManager;
 import ai.djl.repository.FilenameUtils;
 import ai.djl.repository.zoo.ZooModel;
+import ai.djl.serving.util.ConfigManager;
 import java.net.URI;
 import java.nio.file.Path;
 import java.util.Objects;
@@ -95,8 +97,10 @@ public ModelInfo configureModelBatch(int batchSize, int maxBatchDelay) {
      * @return new configured ModelInfo.
      */
     public ModelInfo scaleWorkers(int minWorkers, int maxWorkers) {
-        this.minWorkers = minWorkers;
-        this.maxWorkers = maxWorkers;
+        NDManager manager = model.getNDManager();
+        ConfigManager configManager = ConfigManager.getInstance();
+        this.maxWorkers = configManager.getDefaultWorkers(manager, maxWorkers);
+        this.minWorkers = Math.min(minWorkers, this.maxWorkers);
         return this;
     }
 
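Note: scaleWorkers now resolves the effective maximum first and clamps the minimum to it, so an aggressive request can no longer leave minWorkers above the device-imposed maxWorkers. A hedged fragment, where info stands for an already-loaded ModelInfo backed by an MXNet GPU model:

// Illustrative only: clamping in the new scaleWorkers; "info" is assumed to be
// a loaded ai.djl.serving.wlm.ModelInfo for an MXNet GPU model.
ModelInfo scaled = info.scaleWorkers(4, -1);
// getDefaultWorkers(manager, -1) returns 1 for MXNet on GPU, so
// maxWorkers == 1 and minWorkers == Math.min(4, 1) == 1.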
@@ -12,10 +12,10 @@
  */
 package ai.djl.serving.wlm;
 
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.LinkedBlockingDeque;
@@ -86,9 +86,15 @@ public boolean addJob(Job job) {
         int currentWorkers = getNumRunningWorkers(modelInfo);
         if (currentWorkers == 0
                 || currentWorkers < maxWorkers && queue.size() > modelInfo.getBatchSize() * 2) {
-            logger.info("Scaling up workers for model {} to {} ", modelInfo, currentWorkers + 1);
-            addThreads(pool.getWorkers(), modelInfo, 1, false);
+            synchronized (modelInfo.getModel()) {
+                currentWorkers = getNumRunningWorkers(modelInfo); // check again
+                if (currentWorkers < maxWorkers) {
+                    logger.info(
+                            "Scaling up workers for model {} to {} ",
+                            modelInfo,
+                            currentWorkers + 1);
+                    addThreads(pool.getWorkers(), modelInfo, 1, false);
+                }
+            }
         }
         return true;
@@ -211,7 +217,7 @@ private static final class WorkerPool {
          * @param model the model this WorkerPool belongs to.
          */
         public WorkerPool(ModelInfo model) {
-            workers = Collections.synchronizedList(new ArrayList<>());
+            workers = new CopyOnWriteArrayList<>();
             jobQueue = new LinkedBlockingDeque<>(model.getQueueSize());
             modelName = model.getModelName();
         }
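Note: two thread-safety fixes land in this file. addJob re-checks the running-worker count inside a synchronized block, so concurrent callers cannot each trigger a scale-up past maxWorkers, and the worker list becomes a CopyOnWriteArrayList, so iterating readers no longer race with writers. A minimal generic sketch of the check-then-act pattern; all names below are illustrative, not from the PR:

import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

// Illustrative only: the double-checked scale-up pattern used in addJob.
final class ScaleUpSketch {
    private final List<Object> workers = new CopyOnWriteArrayList<>();
    private final int maxWorkers = 4;
    private final Object lock = new Object();

    void maybeAddWorker() {
        if (workers.size() < maxWorkers) {         // cheap unsynchronized check
            synchronized (lock) {
                if (workers.size() < maxWorkers) { // re-check under the lock
                    workers.add(new Object());     // at most maxWorkers ever added
                }
            }
        }
    }
}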