Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-1041] Add Java benchmark (#13095)
Browse files Browse the repository at this point in the history
* add java benchmark

* applied changes based on Piyush comments

* applies Andrew's change

* fix clojure test issue

* update the statistic names

* follow Naveen's instruction
  • Loading branch information
lanking520 authored and nswamy committed Nov 13, 2018
1 parent 3664a7c commit 1bb5b7f
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 3 deletions.
40 changes: 40 additions & 0 deletions scala-package/examples/scripts/benchmark/run_java_inference_bm.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/bin/bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

set -e

hw_type=cpu
if [ "$USE_GPU" = "1" ]
then
hw_type=gpu
fi

platform=linux-x86_64

if [[ $OSTYPE = [darwin]* ]]
then
platform=osx-x86_64
fi

MXNET_ROOT=$(cd "$(dirname $0)/../../../.."; pwd)
CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*

java -Xmx8G -Dmxnet.traceLeakedObjects=true -cp $CLASS_PATH \
org.apache.mxnetexamples.javaapi.benchmark.JavaBenchmark $@

Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ INPUT_IMG=$2
INPUT_DIR=$3

java -Xmx8G -cp $CLASS_PATH \
org.apache.mxnetexamples.infer.javapi.objectdetector.SSDClassifierExample \
org.apache.mxnetexamples.javaapi.infer.objectdetector.SSDClassifierExample \
--model-path-prefix $MODEL_DIR \
--input-image $INPUT_IMG \
--input-dir $INPUT_DIR
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mxnetexamples.javaapi.benchmark;

import org.apache.mxnet.javaapi.Context;
import org.kohsuke.args4j.Option;

import java.util.List;

abstract class InferBase {
@Option(name = "--num-runs", usage = "Number of runs")
public int numRun = 1;
@Option(name = "--model-name", usage = "Name of the model")
public String modelName = "";
@Option(name = "--batchsize", usage = "Size of the batch")
public int batchSize = 1;

public abstract void preProcessModel(List<Context> context);
public abstract void runSingleInference();
public abstract void runBatchInference();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnetexamples.javaapi.benchmark;

import org.apache.mxnet.javaapi.Context;
import org.kohsuke.args4j.CmdLineParser;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class JavaBenchmark {

private static boolean runBatch = false;

private static void parse(Object inst, String[] args) {
CmdLineParser parser = new CmdLineParser(inst);
try {
parser.parseArgument(args);
} catch (Exception e) {
System.err.println(e.getMessage() + e);
parser.printUsage(System.err);
System.exit(1);
}
}

private static long percentile(int p, long[] seq) {
Arrays.sort(seq);
int k = (int) Math.ceil((seq.length - 1) * (p / 100.0));
return seq[k];
}

private static void printStatistics(long[] inferenceTimesRaw, String metricsPrefix) {
long[] inferenceTimes = inferenceTimesRaw;
// remove head and tail
if (inferenceTimes.length > 2) {
inferenceTimes = Arrays.copyOfRange(inferenceTimesRaw,
1, inferenceTimesRaw.length - 1);
}
double p50 = percentile(50, inferenceTimes) / 1.0e6;
double p99 = percentile(99, inferenceTimes) / 1.0e6;
double p90 = percentile(90, inferenceTimes) / 1.0e6;
long sum = 0;
for (long time: inferenceTimes) sum += time;
double average = sum / (inferenceTimes.length * 1.0e6);

System.out.println(
String.format("\n%s_p99 %fms\n%s_p90 %fms\n%s_p50 %fms\n%s_average %1.2fms",
metricsPrefix, p99, metricsPrefix, p90,
metricsPrefix, p50, metricsPrefix, average)
);

}

private static List<Context> getContext() {
List<Context> context = new ArrayList<Context>();
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
context.add(Context.gpu());
} else {
context.add(Context.cpu());
}
return context;
}

public static void main(String[] args) {
if (args.length < 2) {
StringBuilder sb = new StringBuilder();
sb.append("Please follow the format:");
sb.append("\n --model-name <model-name>");
sb.append("\n --num-runs <number of runs>");
sb.append("\n --batchsize <batch size>");
System.out.println(sb.toString());
return;
}
String modelName = args[1];
InferBase model = null;
switch(modelName) {
case "ObjectDetection":
runBatch = true;
ObjectDetectionBenchmark inst = new ObjectDetectionBenchmark();
parse(inst, args);
model = inst;
default:
System.err.println("Model name not found! " + modelName);
System.exit(1);
}
List<Context> context = getContext();
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
context.add(Context.gpu());
} else {
context.add(Context.cpu());
}

long[] result = new long[model.numRun];
model.preProcessModel(context);
if (runBatch) {
for (int i =0;i < model.numRun; i++) {
long currTime = System.nanoTime();
model.runBatchInference();
result[i] = System.nanoTime() - currTime;
}
System.out.println("Batchsize: " + model.batchSize);
System.out.println("Num of runs: " + model.numRun);
printStatistics(result, modelName +"batch_inference");
}

model.batchSize = 1;
model.preProcessModel(context);
result = new long[model.numRun];
for (int i = 0; i < model.numRun; i++) {
long currTime = System.nanoTime();
model.runSingleInference();
result[i] = System.nanoTime() - currTime;
}
System.out.println("Num of runs: " + model.numRun);
printStatistics(result, modelName + "single_inference");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnetexamples.javaapi.benchmark;

import org.apache.mxnet.infer.javaapi.ObjectDetector;
import org.apache.mxnet.javaapi.*;
import org.kohsuke.args4j.Option;

import java.util.ArrayList;
import java.util.List;

class ObjectDetectionBenchmark extends InferBase {
@Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model")
public String modelPathPrefix = "/model/ssd_resnet50_512";
@Option(name = "--input-image", usage = "the input image")
public String inputImagePath = "/images/dog.jpg";

private ObjectDetector objDet;
private NDArray img;
private NDArray$ NDArray = NDArray$.MODULE$;

public void preProcessModel(List<Context> context) {
Shape inputShape = new Shape(new int[] {this.batchSize, 3, 512, 512});
List<DataDesc> inputDescriptors = new ArrayList<>();
inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
img = ObjectDetector.bufferedImageToPixels(
ObjectDetector.reshapeImage(
ObjectDetector.loadImageFromFile(inputImagePath), 512, 512
),
new Shape(new int[] {1, 3, 512, 512})
);
}

public void runSingleInference() {
List<NDArray> nd = new ArrayList<>();
nd.add(img);
objDet.objectDetectWithNDArray(nd, 3);
}

public void runBatchInference() {
List<NDArray> nd = new ArrayList<>();
NDArray[] temp = new NDArray[batchSize];
for (int i = 0; i < batchSize; i++) temp[i] = img.copy();
NDArray batched = NDArray.concat(temp, batchSize).setdim(0).invoke().get();
nd.add(batched);
objDet.objectDetectWithNDArray(nd, 3);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.mxnetexamples.infer.javapi.objectdetector;
package org.apache.mxnetexamples.javaapi.infer.objectdetector;

import org.apache.mxnet.infer.javaapi.ObjectDetectorOutput;
import org.kohsuke.args4j.CmdLineParser;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package org.apache.mxnet.infer.javaapi
import java.awt.image.BufferedImage
// scalastyle:on

import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray}
import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray, Shape}

import scala.collection.JavaConverters
import scala.collection.JavaConverters._
Expand Down Expand Up @@ -113,6 +113,14 @@ object ObjectDetector {
org.apache.mxnet.infer.ImageClassifier.loadImageFromFile(inputImagePath)
}

def reshapeImage(img : BufferedImage, newWidth: Int, newHeight: Int): BufferedImage = {
org.apache.mxnet.infer.ImageClassifier.reshapeImage(img, newWidth, newHeight)
}

def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape): NDArray = {
org.apache.mxnet.infer.ImageClassifier.bufferedImageToPixels(resizedImage, inputImageShape)
}

def loadInputBatch(inputImagePaths: java.util.List[String]): java.util.List[BufferedImage] = {
org.apache.mxnet.infer.ImageClassifier
.loadInputBatch(inputImagePaths.asScala.toList).toList.asJava
Expand Down

0 comments on commit 1bb5b7f

Please sign in to comment.