Skip to content

Commit

Permalink
[djl-serving] Update auto scale algorithm (#1149)
Browse files Browse the repository at this point in the history
Change-Id: Ib7b8d1dd4f2df95e928967f417dc872040e7e81e
  • Loading branch information
frankfliu authored Aug 6, 2021
1 parent 2cc65c1 commit 0df8265
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public class DescribeModelResponse {
private int batchSize;
private int maxBatchDelay;
private int maxIdleTime;
private int queueLength;
private String status;
private boolean loadedAtStartup;

Expand Down Expand Up @@ -162,6 +163,24 @@ public void setMaxBatchDelay(int maxBatchDelay) {
this.maxBatchDelay = maxBatchDelay;
}

/**
 * Returns the number of requests currently waiting in the model's job queue.
 *
 * @return the number of requests in the queue
 */
public int getQueueLength() {
return queueLength;
}

/**
 * Sets the number of requests currently waiting in the model's job queue.
 *
 * @param queueLength the number of requests in the queue
 */
public void setQueueLength(int queueLength) {
this.queueLength = queueLength;
}

/**
* Returns the model's status.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ public DescribeModelResponse describeModel(String modelName, String version)
resp.setMaxWorkers(model.getMaxWorkers());
resp.setMinWorkers(model.getMinWorkers());
resp.setMaxIdleTime(model.getMaxIdleTime());
resp.setQueueLength(wlm.getQueueLength(model));
resp.setLoadedAtStartup(startupModels.contains(modelName));

int activeWorker = wlm.getNumRunningWorkers(model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -71,42 +70,28 @@ public List<WorkerThread> getWorkers(ModelInfo modelInfo) {
* @return {@code true} if submit success, false otherwise.
*/
public boolean addJob(Job job) {
boolean accepted = false;
ModelInfo modelInfo = job.getModel();
int maxWorkers = modelInfo.getMaxWorkers();
if (maxWorkers == 0) {
logger.info("All model workers has been shutdown: {}", modelInfo.getModelName());
return false;
}
WorkerPool pool = getWorkerPoolForModel(modelInfo);
if (getNumRunningWorkers(modelInfo) > 0) {
try {
accepted = pool.getJobQueue().offer(job);
if (!accepted) {
synchronized (modelInfo.getModel()) {
scaleUpWorkers(modelInfo, pool);
accepted =
pool.getJobQueue()
.offer(
job,
modelInfo.getMaxBatchDelay(),
TimeUnit.MILLISECONDS);
}
}

} catch (InterruptedException e) {
logger.info(
"Worker Queue Capacity Exceeded. cannot add to worker queue in appropriate time. You can configure max batch delay time for this model.");
}
LinkedBlockingDeque<Job> queue = pool.getJobQueue();
if (!queue.offer(job)) {
logger.warn("Worker queue capacity exceeded for model: {}", modelInfo.getModelName());
return false;
}
return accepted;
}

private void scaleUpWorkers(ModelInfo modelInfo, WorkerPool pool) {
int currentWorkers = getNumRunningWorkers(modelInfo);
if (currentWorkers < modelInfo.getMaxWorkers()) {
logger.debug("scaling up workers for model {} to {} ", modelInfo, currentWorkers + 1);
addThreads(pool.getWorkers(), modelInfo, 1, false);
} else {
logger.warn(
"scale up capacity of {} workers reached. Unable to scale up worker pool.",
modelInfo.getMaxWorkers());
if (currentWorkers == 0
|| currentWorkers < maxWorkers && queue.size() > modelInfo.getBatchSize() * 2) {
logger.info("Scaling up workers for model {} to {} ", modelInfo, currentWorkers + 1);
synchronized (modelInfo.getModel()) {
addThreads(pool.getWorkers(), modelInfo, 1, false);
}
}
return true;
}

/**
Expand Down Expand Up @@ -177,6 +162,17 @@ public void modelChanged(ModelInfo modelInfo) {
}
}

/**
 * Returns the current number of requests waiting in the queue for the given model.
 *
 * @param modelInfo the model
 * @return the current number of requests in the queue
 */
public int getQueueLength(ModelInfo modelInfo) {
WorkerPool pool = getWorkerPoolForModel(modelInfo);
return pool.getJobQueue().size();
}

// Lazily creates the pool for a model on first access; thread-safe via
// the concurrent map's atomic computeIfAbsent.
private WorkerPool getWorkerPoolForModel(ModelInfo modelInfo) {
    return workerPools.computeIfAbsent(modelInfo, WorkerPool::new);
}
Expand Down

0 comments on commit 0df8265

Please sign in to comment.