diff --git a/scala-package/examples/scripts/benchmark/run_java_inference_bm.sh b/scala-package/examples/scripts/benchmark/run_java_inference_bm.sh new file mode 100644 index 000000000000..5a468e344829 --- /dev/null +++ b/scala-package/examples/scripts/benchmark/run_java_inference_bm.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e + +hw_type=cpu +if [ "$USE_GPU" = "1" ] +then + hw_type=gpu +fi + +platform=linux-x86_64 + +if [[ $OSTYPE = [darwin]* ]] +then + platform=osx-x86_64 +fi + +MXNET_ROOT=$(cd "$(dirname $0)/../../../.."; pwd) +CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/* + +java -Xmx8G -Dmxnet.traceLeakedObjects=true -cp $CLASS_PATH \ + org.apache.mxnetexamples.javaapi.benchmark.JavaBenchmark $@ + diff --git a/scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh b/scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh index f444a3a59af7..00ed793a7bb5 100755 --- a/scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh +++ b/scala-package/examples/scripts/infer/objectdetector/run_ssd_java_example.sh @@ -41,7 +41,7 @@ INPUT_IMG=$2 INPUT_DIR=$3 java -Xmx8G -cp $CLASS_PATH \ - org.apache.mxnetexamples.infer.javapi.objectdetector.SSDClassifierExample \ + org.apache.mxnetexamples.javaapi.infer.objectdetector.SSDClassifierExample \ --model-path-prefix $MODEL_DIR \ --input-image $INPUT_IMG \ --input-dir $INPUT_DIR diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/InferBase.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/InferBase.java new file mode 100644 index 000000000000..fdcde6b4152c --- /dev/null +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/InferBase.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnetexamples.javaapi.benchmark; + +import org.apache.mxnet.javaapi.Context; +import org.kohsuke.args4j.Option; + +import java.util.List; + +abstract class InferBase { + @Option(name = "--num-runs", usage = "Number of runs") + public int numRun = 1; + @Option(name = "--model-name", usage = "Name of the model") + public String modelName = ""; + @Option(name = "--batchsize", usage = "Size of the batch") + public int batchSize = 1; + + public abstract void preProcessModel(List context); + public abstract void runSingleInference(); + public abstract void runBatchInference(); +} diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/JavaBenchmark.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/JavaBenchmark.java new file mode 100644 index 000000000000..1baca20fbe6d --- /dev/null +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/JavaBenchmark.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnetexamples.javaapi.benchmark; + +import org.apache.mxnet.javaapi.Context; +import org.kohsuke.args4j.CmdLineParser; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class JavaBenchmark { + + private static boolean runBatch = false; + + private static void parse(Object inst, String[] args) { + CmdLineParser parser = new CmdLineParser(inst); + try { + parser.parseArgument(args); + } catch (Exception e) { + System.err.println(e.getMessage() + e); + parser.printUsage(System.err); + System.exit(1); + } + } + + private static long percentile(int p, long[] seq) { + Arrays.sort(seq); + int k = (int) Math.ceil((seq.length - 1) * (p / 100.0)); + return seq[k]; + } + + private static void printStatistics(long[] inferenceTimesRaw, String metricsPrefix) { + long[] inferenceTimes = inferenceTimesRaw; + // remove head and tail + if (inferenceTimes.length > 2) { + inferenceTimes = Arrays.copyOfRange(inferenceTimesRaw, + 1, inferenceTimesRaw.length - 1); + } + double p50 = percentile(50, inferenceTimes) / 1.0e6; + double p99 = percentile(99, inferenceTimes) / 1.0e6; + double p90 = percentile(90, inferenceTimes) / 1.0e6; + long sum = 0; + for (long time: inferenceTimes) sum += time; + double average = sum / (inferenceTimes.length * 1.0e6); + + System.out.println( + String.format("\n%s_p99 %fms\n%s_p90 %fms\n%s_p50 %fms\n%s_average %1.2fms", + metricsPrefix, p99, metricsPrefix, p90, + metricsPrefix, p50, metricsPrefix, average) + ); + + } + + private static List getContext() { + List context = new ArrayList(); + if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && + Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) { + context.add(Context.gpu()); + } else { + context.add(Context.cpu()); + } + return context; + } + + public static void main(String[] args) { + if (args.length < 2) { + StringBuilder sb = new StringBuilder(); + sb.append("Please follow the format:"); + sb.append("\n --model-name "); + sb.append("\n --num-runs "); + sb.append("\n --batchsize "); + System.out.println(sb.toString()); + return; + } + String modelName = args[1]; + InferBase model = null; + switch(modelName) { + case "ObjectDetection": + runBatch = true; + ObjectDetectionBenchmark inst = new ObjectDetectionBenchmark(); + parse(inst, args); + model = inst; + default: + System.err.println("Model name not found! " + modelName); + System.exit(1); + } + List context = getContext(); + if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && + Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) { + context.add(Context.gpu()); + } else { + context.add(Context.cpu()); + } + + long[] result = new long[model.numRun]; + model.preProcessModel(context); + if (runBatch) { + for (int i =0;i < model.numRun; i++) { + long currTime = System.nanoTime(); + model.runBatchInference(); + result[i] = System.nanoTime() - currTime; + } + System.out.println("Batchsize: " + model.batchSize); + System.out.println("Num of runs: " + model.numRun); + printStatistics(result, modelName +"batch_inference"); + } + + model.batchSize = 1; + model.preProcessModel(context); + result = new long[model.numRun]; + for (int i = 0; i < model.numRun; i++) { + long currTime = System.nanoTime(); + model.runSingleInference(); + result[i] = System.nanoTime() - currTime; + } + System.out.println("Num of runs: " + model.numRun); + printStatistics(result, modelName + "single_inference"); + } +} diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/ObjectDetectionBenchmark.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/ObjectDetectionBenchmark.java new file mode 100644 index 000000000000..485e0afa3e46 --- /dev/null +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/ObjectDetectionBenchmark.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnetexamples.javaapi.benchmark; + +import org.apache.mxnet.infer.javaapi.ObjectDetector; +import org.apache.mxnet.javaapi.*; +import org.kohsuke.args4j.Option; + +import java.util.ArrayList; +import java.util.List; + +class ObjectDetectionBenchmark extends InferBase { + @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model") + public String modelPathPrefix = "/model/ssd_resnet50_512"; + @Option(name = "--input-image", usage = "the input image") + public String inputImagePath = "/images/dog.jpg"; + + private ObjectDetector objDet; + private NDArray img; + private NDArray$ NDArray = NDArray$.MODULE$; + + public void preProcessModel(List context) { + Shape inputShape = new Shape(new int[] {this.batchSize, 3, 512, 512}); + List inputDescriptors = new ArrayList<>(); + inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW")); + objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0); + img = ObjectDetector.bufferedImageToPixels( + ObjectDetector.reshapeImage( + ObjectDetector.loadImageFromFile(inputImagePath), 512, 512 + ), + new Shape(new int[] {1, 3, 512, 512}) + ); + } + + public void runSingleInference() { + List nd = new ArrayList<>(); + nd.add(img); + objDet.objectDetectWithNDArray(nd, 3); + } + + public void runBatchInference() { + List nd = new ArrayList<>(); + NDArray[] temp = new NDArray[batchSize]; + for (int i = 0; i < batchSize; i++) temp[i] = img.copy(); + NDArray batched = NDArray.concat(temp, batchSize).setdim(0).invoke().get(); + nd.add(batched); + objDet.objectDetectWithNDArray(nd, 3); + } +} diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md similarity index 100% rename from scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/README.md rename to scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/SSDClassifierExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java similarity index 99% rename from scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/SSDClassifierExample.java rename to scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java index 13f9d2d9a3e5..4befc8edde6b 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/infer/javapi/objectdetector/SSDClassifierExample.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.mxnetexamples.infer.javapi.objectdetector; +package org.apache.mxnetexamples.javaapi.infer.objectdetector; import org.apache.mxnet.infer.javaapi.ObjectDetectorOutput; import org.kohsuke.args4j.CmdLineParser; diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala index f48375ffe4a7..447518b5a89c 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetector.scala @@ -21,7 +21,7 @@ package org.apache.mxnet.infer.javaapi import java.awt.image.BufferedImage // scalastyle:on -import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray} +import org.apache.mxnet.javaapi.{Context, DataDesc, NDArray, Shape} import scala.collection.JavaConverters import scala.collection.JavaConverters._ @@ -113,6 +113,14 @@ object ObjectDetector { org.apache.mxnet.infer.ImageClassifier.loadImageFromFile(inputImagePath) } + def reshapeImage(img : BufferedImage, newWidth: Int, newHeight: Int): BufferedImage = { + org.apache.mxnet.infer.ImageClassifier.reshapeImage(img, newWidth, newHeight) + } + + def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape): NDArray = { + org.apache.mxnet.infer.ImageClassifier.bufferedImageToPixels(resizedImage, inputImageShape) + } + def loadInputBatch(inputImagePaths: java.util.List[String]): java.util.List[BufferedImage] = { org.apache.mxnet.infer.ImageClassifier .loadInputBatch(inputImagePaths.asScala.toList).toList.asJava