Skip to content

Commit

Permalink
[djl-bench] Add a tool to genreate NDList file (#1155)
Browse files Browse the repository at this point in the history
Change-Id: I37c1fe468a5b4784a83e9be4ce62f9b6f10e2b49
  • Loading branch information
frankfliu authored Aug 12, 2021
1 parent c09b449 commit 520f90f
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import java.time.Duration;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
Expand Down Expand Up @@ -72,7 +71,8 @@ public final boolean runBenchmark(String[] args) {
Options options = Arguments.getOptions();
try {
if (Arguments.hasHelp(args)) {
printHelp("djl-bench [-p MODEL-PATH] -s INPUT-SHAPES [OPTIONS]", options);
Arguments.printHelp(
"usage: djl-bench [-p MODEL-PATH] -s INPUT-SHAPES [OPTIONS]", options);
return true;
}
DefaultParser parser = new DefaultParser();
Expand Down Expand Up @@ -233,7 +233,7 @@ public final boolean runBenchmark(String[] args) {
}
return true;
} catch (ParseException e) {
printHelp(e.getMessage(), options);
Arguments.printHelp(e.getMessage(), options);
} catch (Throwable t) {
logger.error("Unexpected error", t);
}
Expand Down Expand Up @@ -269,13 +269,6 @@ protected ZooModel<Void, float[]> loadModel(Arguments arguments, Metrics metrics
return model;
}

private void printHelp(String msg, Options options) {
HelpFormatter formatter = new HelpFormatter();
formatter.setLeftPadding(1);
formatter.setWidth(120);
formatter.printHelp(msg, options);
}

private static final class BenchmarkTranslator implements Translator<Void, float[]> {

private PairList<DataType, Shape> shapes;
Expand Down
63 changes: 10 additions & 53 deletions extensions/benchmark/src/main/java/ai/djl/benchmark/Arguments.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.OptionGroup;
import org.apache.commons.cli.Options;
Expand Down Expand Up @@ -67,7 +66,6 @@ public class Arguments {

modelName = cmd.getOptionValue("model-name");
outputDir = cmd.getOptionValue("output-dir");
inputShapes = new PairList<>();

if (cmd.hasOption("engine")) {
engine = cmd.getOptionValue("engine");
Expand Down Expand Up @@ -120,56 +118,7 @@ public class Arguments {
}

String shape = cmd.getOptionValue("input-shapes");
if (shape != null) {
if (shape.contains("(")) {
Pattern pattern =
Pattern.compile("\\((\\s*(\\d+)([,\\s]+\\d+)*\\s*)\\)([sdubilBfS]?)");
Matcher matcher = pattern.matcher(shape);
while (matcher.find()) {
String[] tokens = matcher.group(1).split(",");
long[] array = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray();
DataType dataType;
String dataTypeStr = matcher.group(4);
if (dataTypeStr == null || dataTypeStr.isEmpty()) {
dataType = DataType.FLOAT32;
} else {
switch (dataTypeStr) {
case "s":
dataType = DataType.FLOAT16;
break;
case "d":
dataType = DataType.FLOAT64;
break;
case "u":
dataType = DataType.UINT8;
break;
case "b":
dataType = DataType.INT8;
break;
case "i":
dataType = DataType.INT32;
break;
case "l":
dataType = DataType.INT64;
break;
case "B":
dataType = DataType.BOOLEAN;
break;
case "f":
dataType = DataType.FLOAT32;
break;
default:
throw new IllegalArgumentException("Invalid input-shape: " + shape);
}
}
inputShapes.add(dataType, new Shape(array));
}
} else {
String[] tokens = shape.split(",");
long[] shapes = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray();
inputShapes.add(DataType.FLOAT32, new Shape(shapes));
}
}
inputShapes = NDListGenerator.parseShape(shape);
}

static Options getOptions() {
Expand Down Expand Up @@ -265,6 +214,14 @@ static boolean hasHelp(String[] args) {
return list.contains("-h") || list.contains("help");
}

static void printHelp(String msg, Options options) {
HelpFormatter formatter = new HelpFormatter();
formatter.setSyntaxPrefix("");
formatter.setLeftPadding(1);
formatter.setWidth(120);
formatter.printHelp(msg, options);
}

int getDuration() {
return duration;
}
Expand Down
18 changes: 12 additions & 6 deletions extensions/benchmark/src/main/java/ai/djl/benchmark/Benchmark.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,21 @@ public static void main(String[] args) {
return;
}
List<String> list = Arrays.asList(args);
boolean multithreading = list.contains("-t") || list.contains("--threads");
configEngines(multithreading);
boolean success;
if (multithreading) {
success = new MultithreadedBenchmark().runBenchmark(args);
if (!list.isEmpty() && "ndlist-gen".equals(list.get(0))) {
success = NDListGenerator.generate(Arrays.copyOfRange(args, 1, args.length));
} else {
success = new Benchmark().runBenchmark(args);
boolean multithreading = list.contains("-t") || list.contains("--threads");
configEngines(multithreading);
if (multithreading) {
success = new MultithreadedBenchmark().runBenchmark(args);
} else {
success = new Benchmark().runBenchmark(args);
}
}
if (!success) {
System.exit(-1); // NOPMD
}
System.exit(success ? 0 : -1); // NOPMD
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.benchmark;

import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.io.BufferedOutputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** A class generates NDList files. */
final class NDListGenerator {

private static final Logger logger = LoggerFactory.getLogger(NDListGenerator.class);

private NDListGenerator() {}

static boolean generate(String[] args) {
Options options = getOptions();
try {
if (Arguments.hasHelp(args)) {
Arguments.printHelp(
"usage: djl-bench ndlist-gen -s INPUT-SHAPES -o OUTPUT_FILE", options);
return true;
}
DefaultParser parser = new DefaultParser();
CommandLine cmd = parser.parse(options, args, null, false);
String inputShapes = cmd.getOptionValue("input-shapes");
String output = cmd.getOptionValue("output-file");
boolean ones = cmd.hasOption("ones");
Path path = Paths.get(output);

try (NDManager manager = NDManager.newBaseManager()) {
NDList list = new NDList();
for (Pair<DataType, Shape> pair : parseShape(inputShapes)) {
DataType dataType = pair.getKey();
Shape shape = pair.getValue();
if (ones) {
list.add(manager.ones(shape, dataType));
} else {
list.add(manager.zeros(shape, dataType));
}
}
try (OutputStream os = new BufferedOutputStream(Files.newOutputStream(path))) {
list.encode(os);
}
}
logger.info("NDList file created: {}", path.toAbsolutePath());
return true;
} catch (ParseException e) {
Arguments.printHelp(e.getMessage(), options);
} catch (Throwable t) {
logger.error("Unexpected error", t);
}
return false;
}

static PairList<DataType, Shape> parseShape(String shape) {
PairList<DataType, Shape> inputShapes = new PairList<>();
if (shape != null) {
if (shape.contains("(")) {
Pattern pattern =
Pattern.compile("\\((\\s*(\\d+)([,\\s]+\\d+)*\\s*)\\)([sdubilBfS]?)");
Matcher matcher = pattern.matcher(shape);
while (matcher.find()) {
String[] tokens = matcher.group(1).split(",");
long[] array = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray();
DataType dataType;
String dataTypeStr = matcher.group(4);
if (dataTypeStr == null || dataTypeStr.isEmpty()) {
dataType = DataType.FLOAT32;
} else {
switch (dataTypeStr) {
case "s":
dataType = DataType.FLOAT16;
break;
case "d":
dataType = DataType.FLOAT64;
break;
case "u":
dataType = DataType.UINT8;
break;
case "b":
dataType = DataType.INT8;
break;
case "i":
dataType = DataType.INT32;
break;
case "l":
dataType = DataType.INT64;
break;
case "B":
dataType = DataType.BOOLEAN;
break;
case "f":
dataType = DataType.FLOAT32;
break;
default:
throw new IllegalArgumentException("Invalid input-shape: " + shape);
}
}
inputShapes.add(dataType, new Shape(array));
}
} else {
String[] tokens = shape.split(",");
long[] shapes = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray();
inputShapes.add(DataType.FLOAT32, new Shape(shapes));
}
}
return inputShapes;
}

private static Options getOptions() {
Options options = new Options();
options.addOption(
Option.builder("h").longOpt("help").hasArg(false).desc("Print this help.").build());
options.addOption(
Option.builder("s")
.required()
.longOpt("input-shapes")
.hasArg()
.argName("INPUT-SHAPES")
.desc("Input data shapes for the model.")
.build());
options.addOption(
Option.builder("o")
.required()
.longOpt("output-file")
.hasArg()
.argName("OUTPUT-FILE")
.desc("Write output NDList to file.")
.build());
options.addOption(
Option.builder("1")
.longOpt("ones")
.hasArg(false)
.argName("ones")
.desc("Use all ones instead of zeros.")
.build());
return options;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class BenchmarkTest {
@Test
public void testHelp() {
String[] args = {"-h"};
new Benchmark().runBenchmark(args);
Benchmark.main(args);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.benchmark;

import org.testng.Assert;
import org.testng.annotations.Test;

public class NDListGeneratorTest {

@Test
public void testHelp() {
String[] args = {"ndlist-gen", "-h"};
Benchmark.main(args);
}

@Test
public void testMissingOptions() {
String[] args = {"ndlist-gen", "-s"};
boolean success = NDListGenerator.generate(args);
Assert.assertFalse(success);
}

@Test
public void testOnes() {
String[] args = {"ndlist-gen", "-s", "1", "-o", "build/ones.ndlist", "-1"};
boolean success = NDListGenerator.generate(args);
Assert.assertTrue(success);
}

@Test
public void testZeros() {
String[] args = {"ndlist-gen", "-s", "1", "-o", "build/ones.ndlist"};
boolean success = NDListGenerator.generate(args);
Assert.assertTrue(success);
}
}

0 comments on commit 520f90f

Please sign in to comment.