Skip to content

Commit

Permalink
[serving] auto detect number of GPUs
Browse files Browse the repository at this point in the history
Change-Id: Ief11f4df5d7f1622d42385842dacadefc118e7d4
  • Loading branch information
frankfliu committed Jul 31, 2021
1 parent 21bf5fe commit 6ec6c1e
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 30 deletions.
7 changes: 7 additions & 0 deletions api/src/main/java/ai/djl/repository/Artifact.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
*/
package ai.djl.repository;

import ai.djl.Application;
import ai.djl.util.JsonUtils;
import java.io.Serializable;
import java.net.URI;
Expand Down Expand Up @@ -154,6 +155,12 @@ public Map<String, Object> getArguments(Map<String, Object> override) {
if (override != null) {
map.putAll(override);
}
if (!map.containsKey("application") && metadata != null) {
Application application = metadata.getApplication();
if (application != null && Application.UNDEFINED != application) {
map.put("application", application);
}
}
return map;
}

Expand Down
2 changes: 0 additions & 2 deletions api/src/main/java/ai/djl/repository/JarRepository.java
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ private synchronized Metadata getMetadata() {
artifact.setName(modelName);

metadata = new Metadata.MatchAllMetadata();
metadata.setApplication(Application.UNDEFINED);
metadata.setGroupId(DefaultModelZoo.GROUP_ID);
metadata.setArtifactId(artifactId);
metadata.setArtifacts(Collections.singletonList(artifact));
String hash = md5hash(uri.toString());
Expand Down
12 changes: 2 additions & 10 deletions api/src/main/java/ai/djl/repository/Metadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -208,20 +208,11 @@ public Application getApplication() {
*
* @param application {@link Application}
*/
public void setApplication(Application application) {
public final void setApplication(Application application) {
    // Resolve the path before touching any field: if application is null the
    // NullPointerException is raised on this line, while both fields still hold
    // their previous, mutually consistent values.
    String path = application.getPath();
    this.applicationClass = application;
    this.application = path;
}

/**
 * Returns the {@link Application} name.
 *
 * <p>The value is the string held in the {@code application} field;
 * {@code setApplication} stores {@link Application#getPath()} in that field.
 *
 * @return the {@link Application} name
 */
public String getApplicationName() {
return application;
}

/**
* Returns the {@link License}.
*
Expand Down Expand Up @@ -341,6 +332,7 @@ public static final class MatchAllMetadata extends Metadata {
public MatchAllMetadata() {
    // Start from the undefined application; setApplication records both the
    // Application object and its path string.
    setApplication(Application.UNDEFINED);
    // No artifacts until a repository supplies them.
    this.artifacts = Collections.emptyList();
    this.groupId = DefaultModelZoo.GROUP_ID;
}

/** {@inheritDoc} */
Expand Down
1 change: 0 additions & 1 deletion api/src/main/java/ai/djl/repository/SimpleRepository.java
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ private synchronized Metadata getMetadata() throws IOException {
resolved = true;
metadata = new Metadata.MatchAllMetadata();
metadata.setRepositoryUri(URI.create(""));
metadata.setApplication(Application.UNDEFINED);
metadata.setArtifactId(artifactId);
if (!Files.exists(path)) {
logger.debug("Specified path doesn't exists: {}", path.toAbsolutePath());
Expand Down
2 changes: 0 additions & 2 deletions api/src/main/java/ai/djl/repository/SimpleUrlRepository.java
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ private synchronized Metadata getMetadata() throws IOException {
artifact.setName(modelName);

metadata = new Metadata.MatchAllMetadata();
metadata.setApplication(Application.UNDEFINED);
metadata.setGroupId(DefaultModelZoo.GROUP_ID);
metadata.setArtifactId(artifactId);
metadata.setArtifacts(Collections.singletonList(artifact));
String hash = md5hash(uri.toString());
Expand Down
51 changes: 36 additions & 15 deletions serving/serving/src/main/java/ai/djl/serving/ModelServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.serving.util.ServerGroups;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.ModelManager;
import ai.djl.util.cuda.CudaUtils;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
Expand All @@ -43,6 +44,7 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
Expand All @@ -56,7 +58,7 @@ public class ModelServer {

private static final Logger logger = LoggerFactory.getLogger(ModelServer.class);

private static final Pattern MODEL_STORE_PATTERN = Pattern.compile("(\\[(.+)]=)?(.+)");
// Matches an optional "[prefix]=" (brackets optional, i.e. also "prefix=") followed by
// the model URL. Group 3 is the model URL; group 2 is the prefix text (presumably the
// endpoint spec "name:version:engine:gpu" -- confirm against the caller).
// Group 2 is reluctant so that a closing ']' is consumed by "]?" rather than being
// captured as part of the prefix, and it excludes '?' so that an '=' inside a URL
// query string is not mistaken for the prefix separator.
private static final Pattern MODEL_STORE_PATTERN = Pattern.compile("(\\[?([^?]+?)]?=)?(.+)");

private ServerGroups serverGroups;
private List<ChannelFuture> futures = new ArrayList<>(2);
Expand Down Expand Up @@ -310,7 +312,7 @@ private void initModelStore() throws IOException {
String modelUrl = matcher.group(3);
String version = null;
String engine = null;
int gpuId = -1;
int[] gpuIds = {-1};
String modelName;
if (endpoint != null) {
String[] tokens = endpoint.split(":", -1);
Expand All @@ -322,25 +324,44 @@ private void initModelStore() throws IOException {
engine = tokens[2].isEmpty() ? null : tokens[2];
}
if (tokens.length > 3) {
gpuId = tokens[3].isEmpty() ? -1 : Integer.parseInt(tokens[3]);
if ("*".equals(tokens[3])) {
int gpuCount = CudaUtils.getGpuCount();
if (gpuCount > 0) {
gpuIds = IntStream.range(0, gpuCount).toArray();
}
} else if (!tokens[3].isEmpty()) {
gpuIds[0] = Integer.parseInt(tokens[3]);
}
}
} else {
modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
}

int workers = configManager.getDefaultWorkers();
CompletableFuture<ModelInfo> future =
modelManager.registerModel(
modelName,
version,
modelUrl,
engine,
gpuId,
configManager.getBatchSize(),
configManager.getMaxBatchDelay(),
configManager.getMaxIdleTime());
ModelInfo modelInfo = future.join();
modelManager.triggerModelUpdated(modelInfo.scaleWorkers(1, workers));
for (int i = 0; i < gpuIds.length; ++i) {
String modelVersion;
if (gpuIds.length > 1) {
if (version == null) {
modelVersion = "v" + i;
} else {
modelVersion = version + i;
}
} else {
modelVersion = version;
}
CompletableFuture<ModelInfo> future =
modelManager.registerModel(
modelName,
modelVersion,
modelUrl,
engine,
gpuIds[i],
configManager.getBatchSize(),
configManager.getMaxBatchDelay(),
configManager.getMaxIdleTime());
ModelInfo modelInfo = future.join();
modelManager.triggerModelUpdated(modelInfo.scaleWorkers(1, workers));
}
startupModels.add(modelName);
}
}
Expand Down

0 comments on commit 6ec6c1e

Please sign in to comment.