Skip to content

Commit

Permalink
Support benchmark onnx on GPU machine
Browse files Browse the repository at this point in the history
Change-Id: I5c2b18b06202941fe6cac3b9d538ba74405a9677
  • Loading branch information
frankfliu committed Aug 6, 2021
1 parent eea8c6c commit f0e58c0
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 6 deletions.
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/repository/zoo/Criteria.java
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public ZooModel<I, O> loadModel()
Set<String> supportedEngine = modelZoo.getSupportedEngines();
if (engine != null && !supportedEngine.contains(engine)) {
throw new ModelNotFoundException(
"ModelZoo doesn't support specified with engine: " + engine);
"ModelZoo doesn't support specified engine: " + engine);
}
list.add(modelZoo);
} else {
Expand Down
17 changes: 16 additions & 1 deletion extensions/benchmark/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,22 @@ dependencies {
runtimeOnly project(":paddlepaddle:paddlepaddle-model-zoo")
runtimeOnly "ai.djl.paddlepaddle:paddlepaddle-native-auto:${paddlepaddle_version}"

runtimeOnly project(":onnxruntime:onnxruntime-engine")
ProcessBuilder pb = new ProcessBuilder("nvidia-smi", "-L")
def hasGPU = false;
try {
Process process = pb.start()
hasGPU = process.waitFor() == 0
} catch (IOException ignore) {
}

if (hasGPU) {
runtimeOnly(project(":onnxruntime:onnxruntime-engine")) {
exclude group: "com.microsoft.onnxruntime", module: "onnxruntime"
}
runtimeOnly "com.microsoft.onnxruntime:onnxruntime_gpu:${onnxruntime_version}"
} else {
runtimeOnly project(":onnxruntime:onnxruntime-engine")
}

runtimeOnly project(":dlr:dlr-engine")
runtimeOnly "ai.djl.dlr:dlr-native-auto:${dlr_version}"
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/onnxruntime-engine/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ Maven:
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime_gpu</artifactId>
<version>1.7.0</version>
<version>1.8.1</version>
<scope>runtime</scope>
</dependency>
```
Expand All @@ -83,5 +83,5 @@ Gradle:
implementation("ai.djl.onnxruntime:onnxruntime-engine:0.12.0") {
exclude group: "com.microsoft.onnxruntime", module: "onnxruntime"
}
implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.7.0"
implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.8.1"
```
17 changes: 17 additions & 0 deletions serving/serving/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,23 @@ dependencies {
runtimeOnly project(":tensorflow:tensorflow-model-zoo")
runtimeOnly project(":pytorch:pytorch-model-zoo")

ProcessBuilder pb = new ProcessBuilder("nvidia-smi", "-L")
def hasGPU = false;
try {
Process process = pb.start()
hasGPU = process.waitFor() == 0
} catch (IOException ignore) {
}

if (hasGPU) {
runtimeOnly(project(":onnxruntime:onnxruntime-engine")) {
exclude group: "com.microsoft.onnxruntime", module: "onnxruntime"
}
runtimeOnly "com.microsoft.onnxruntime:onnxruntime_gpu:${onnxruntime_version}"
} else {
runtimeOnly project(":onnxruntime:onnxruntime-engine")
}

runtimeOnly "ai.djl.mxnet:mxnet-native-auto:${mxnet_version}"
runtimeOnly "ai.djl.pytorch:pytorch-native-auto:${pytorch_version}-SNAPSHOT"
runtimeOnly "ai.djl.tensorflow:tensorflow-native-auto:${tensorflow_version}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ public CompletableFuture<ModelInfo> registerModel(
.optEngine(engineName);
if (gpuId != -1) {
builder.optDevice(Device.gpu(gpuId));
logger.info("Loading model {} on {}.", modelName, Device.cpu());
} else {
logger.info("Loading model {} on {}.", modelName, Device.gpu(gpuId));
} else {
logger.info("Loading model {} on {}.", modelName, Device.cpu());
}

ZooModel<Input, Output> model = builder.build().loadModel();
Expand Down

0 comments on commit f0e58c0

Please sign in to comment.