Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[serving] auto detect number of GPUs #1132

Merged
merged 1 commit into from
Aug 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions api/src/main/java/ai/djl/repository/Artifact.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
*/
package ai.djl.repository;

import ai.djl.Application;
import ai.djl.util.JsonUtils;
import java.io.Serializable;
import java.net.URI;
Expand Down Expand Up @@ -154,6 +155,12 @@ public Map<String, Object> getArguments(Map<String, Object> override) {
if (override != null) {
map.putAll(override);
}
if (!map.containsKey("application") && metadata != null) {
Application application = metadata.getApplication();
if (application != null && Application.UNDEFINED != application) {
map.put("application", application.getPath());
}
}
return map;
}

Expand Down
2 changes: 0 additions & 2 deletions api/src/main/java/ai/djl/repository/JarRepository.java
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ private synchronized Metadata getMetadata() {
artifact.setName(modelName);

metadata = new Metadata.MatchAllMetadata();
metadata.setApplication(Application.UNDEFINED);
metadata.setGroupId(DefaultModelZoo.GROUP_ID);
metadata.setArtifactId(artifactId);
metadata.setArtifacts(Collections.singletonList(artifact));
String hash = md5hash(uri.toString());
Expand Down
12 changes: 2 additions & 10 deletions api/src/main/java/ai/djl/repository/Metadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -208,20 +208,11 @@ public Application getApplication() {
*
* @param application {@link Application}
*/
public void setApplication(Application application) {
public final void setApplication(Application application) {
this.applicationClass = application;
this.application = application.getPath();
}

/**
* Returns the {@link Application} name.
*
* @return the {@link Application} name
*/
public String getApplicationName() {
return application;
}

/**
* Returns the {@link License}.
*
Expand Down Expand Up @@ -341,6 +332,7 @@ public static final class MatchAllMetadata extends Metadata {
public MatchAllMetadata() {
groupId = DefaultModelZoo.GROUP_ID;
artifacts = Collections.emptyList();
setApplication(Application.UNDEFINED);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ private synchronized Metadata getMetadata() throws IOException {
resolved = true;
metadata = new Metadata.MatchAllMetadata();
metadata.setRepositoryUri(URI.create(""));
metadata.setApplication(Application.UNDEFINED);
metadata.setArtifactId(artifactId);
if (!Files.exists(path)) {
logger.debug("Specified path doesn't exists: {}", path.toAbsolutePath());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ private synchronized Metadata getMetadata() throws IOException {
artifact.setName(modelName);

metadata = new Metadata.MatchAllMetadata();
metadata.setApplication(Application.UNDEFINED);
metadata.setGroupId(DefaultModelZoo.GROUP_ID);
metadata.setArtifactId(artifactId);
metadata.setArtifacts(Collections.singletonList(artifact));
String hash = md5hash(uri.toString());
Expand Down
51 changes: 36 additions & 15 deletions serving/serving/src/main/java/ai/djl/serving/ModelServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.serving.util.ServerGroups;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.ModelManager;
import ai.djl.util.cuda.CudaUtils;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
Expand All @@ -43,6 +44,7 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
Expand All @@ -56,7 +58,7 @@ public class ModelServer {

private static final Logger logger = LoggerFactory.getLogger(ModelServer.class);

private static final Pattern MODEL_STORE_PATTERN = Pattern.compile("(\\[(.+)]=)?(.+)");
private static final Pattern MODEL_STORE_PATTERN = Pattern.compile("(\\[?(.+?)]?=)?(.+)");

private ServerGroups serverGroups;
private List<ChannelFuture> futures = new ArrayList<>(2);
Expand Down Expand Up @@ -310,7 +312,7 @@ private void initModelStore() throws IOException {
String modelUrl = matcher.group(3);
String version = null;
String engine = null;
int gpuId = -1;
int[] gpuIds = {-1};
String modelName;
if (endpoint != null) {
String[] tokens = endpoint.split(":", -1);
Expand All @@ -322,25 +324,44 @@ private void initModelStore() throws IOException {
engine = tokens[2].isEmpty() ? null : tokens[2];
}
if (tokens.length > 3) {
gpuId = tokens[3].isEmpty() ? -1 : Integer.parseInt(tokens[3]);
if ("*".equals(tokens[3])) {
int gpuCount = CudaUtils.getGpuCount();
if (gpuCount > 0) {
gpuIds = IntStream.range(0, gpuCount).toArray();
}
} else if (!tokens[3].isEmpty()) {
gpuIds[0] = Integer.parseInt(tokens[3]);
}
}
} else {
modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
}

int workers = configManager.getDefaultWorkers();
CompletableFuture<ModelInfo> future =
modelManager.registerModel(
modelName,
version,
modelUrl,
engine,
gpuId,
configManager.getBatchSize(),
configManager.getMaxBatchDelay(),
configManager.getMaxIdleTime());
ModelInfo modelInfo = future.join();
modelManager.triggerModelUpdated(modelInfo.scaleWorkers(1, workers));
for (int i = 0; i < gpuIds.length; ++i) {
String modelVersion;
if (gpuIds.length > 1) {
if (version == null) {
modelVersion = "v" + i;
} else {
modelVersion = version + i;
}
} else {
modelVersion = version;
}
CompletableFuture<ModelInfo> future =
modelManager.registerModel(
modelName,
modelVersion,
modelUrl,
engine,
gpuIds[i],
configManager.getBatchSize(),
configManager.getMaxBatchDelay(),
configManager.getMaxIdleTime());
ModelInfo modelInfo = future.join();
modelManager.triggerModelUpdated(modelInfo.scaleWorkers(1, workers));
}
startupModels.add(modelName);
}
}
Expand Down
2 changes: 1 addition & 1 deletion serving/serving/src/test/resources/config.properties
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ inference_address=https://127.0.0.1:8443
management_address=https://127.0.0.1:8443
# management_address=unix:/tmp/management.sock
# model_store=models
load_models=https://resources.djl.ai/test-models/mlp.tar.gz,[mlp:v1:MXNet:]=https://resources.djl.ai/test-models/mlp.tar.gz
load_models=https://resources.djl.ai/test-models/mlp.tar.gz,[mlp:v1:MXNet:*]=https://resources.djl.ai/test-models/mlp.tar.gz
# model_url_pattern=.*
# number_of_netty_threads=0
# default_workers_per_model=0
Expand Down