From 9fcef5cbf909121000a70ca8dfe00d4424515cb2 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sun, 8 Aug 2021 14:36:00 -0700 Subject: [PATCH] [djl-bench] Add a tool to genreate NDList file Change-Id: I37c1fe468a5b4784a83e9be4ce62f9b6f10e2b49 --- .../ai/djl/benchmark/AbstractBenchmark.java | 13 +- .../main/java/ai/djl/benchmark/Arguments.java | 63 ++----- .../main/java/ai/djl/benchmark/Benchmark.java | 18 +- .../ai/djl/benchmark/NDListGenerator.java | 168 ++++++++++++++++++ .../java/ai/djl/benchmark/BenchmarkTest.java | 2 +- .../ai/djl/benchmark/NDListGeneratorTest.java | 46 +++++ 6 files changed, 240 insertions(+), 70 deletions(-) create mode 100644 extensions/benchmark/src/main/java/ai/djl/benchmark/NDListGenerator.java create mode 100644 extensions/benchmark/src/test/java/ai/djl/benchmark/NDListGeneratorTest.java diff --git a/extensions/benchmark/src/main/java/ai/djl/benchmark/AbstractBenchmark.java b/extensions/benchmark/src/main/java/ai/djl/benchmark/AbstractBenchmark.java index 0520ae7368f..9c76f523402 100644 --- a/extensions/benchmark/src/main/java/ai/djl/benchmark/AbstractBenchmark.java +++ b/extensions/benchmark/src/main/java/ai/djl/benchmark/AbstractBenchmark.java @@ -34,7 +34,6 @@ import java.time.Duration; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.DefaultParser; -import org.apache.commons.cli.HelpFormatter; import org.apache.commons.cli.Options; import org.apache.commons.cli.ParseException; import org.slf4j.Logger; @@ -72,7 +71,8 @@ public final boolean runBenchmark(String[] args) { Options options = Arguments.getOptions(); try { if (Arguments.hasHelp(args)) { - printHelp("djl-bench [-p MODEL-PATH] -s INPUT-SHAPES [OPTIONS]", options); + Arguments.printHelp( + "usage: djl-bench [-p MODEL-PATH] -s INPUT-SHAPES [OPTIONS]", options); return true; } DefaultParser parser = new DefaultParser(); @@ -235,7 +235,7 @@ public final boolean runBenchmark(String[] args) { } return true; } catch (ParseException e) { - printHelp(e.getMessage(), options); + Arguments.printHelp(e.getMessage(), options); } catch (Throwable t) { logger.error("Unexpected error", t); } @@ -271,13 +271,6 @@ protected ZooModel loadModel(Arguments arguments, Metrics metrics return model; } - private void printHelp(String msg, Options options) { - HelpFormatter formatter = new HelpFormatter(); - formatter.setLeftPadding(1); - formatter.setWidth(120); - formatter.printHelp(msg, options); - } - private static final class BenchmarkTranslator implements Translator { private PairList shapes; diff --git a/extensions/benchmark/src/main/java/ai/djl/benchmark/Arguments.java b/extensions/benchmark/src/main/java/ai/djl/benchmark/Arguments.java index 059dcafe2e4..6ac7ee453f6 100644 --- a/extensions/benchmark/src/main/java/ai/djl/benchmark/Arguments.java +++ b/extensions/benchmark/src/main/java/ai/djl/benchmark/Arguments.java @@ -21,9 +21,8 @@ import java.nio.file.Paths; import java.util.Arrays; import java.util.List; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.HelpFormatter; import org.apache.commons.cli.Option; import org.apache.commons.cli.OptionGroup; import org.apache.commons.cli.Options; @@ -62,7 +61,6 @@ public class Arguments { modelName = cmd.getOptionValue("model-name"); outputDir = cmd.getOptionValue("output-dir"); - inputShapes = new PairList<>(); if (cmd.hasOption("engine")) { engine = cmd.getOptionValue("engine"); @@ -96,56 +94,7 @@ public class Arguments { } String shape = cmd.getOptionValue("input-shapes"); - if (shape != null) { - if (shape.contains("(")) { - Pattern pattern = - Pattern.compile("\\((\\s*(\\d+)([,\\s]+\\d+)*\\s*)\\)([sdubilBfS]?)"); - Matcher matcher = pattern.matcher(shape); - while (matcher.find()) { - String[] tokens = matcher.group(1).split(","); - long[] array = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray(); - DataType dataType; - String dataTypeStr = matcher.group(4); - if (dataTypeStr == null || dataTypeStr.isEmpty()) { - dataType = DataType.FLOAT32; - } else { - switch (dataTypeStr) { - case "s": - dataType = DataType.FLOAT16; - break; - case "d": - dataType = DataType.FLOAT64; - break; - case "u": - dataType = DataType.UINT8; - break; - case "b": - dataType = DataType.INT8; - break; - case "i": - dataType = DataType.INT32; - break; - case "l": - dataType = DataType.INT64; - break; - case "B": - dataType = DataType.BOOLEAN; - break; - case "f": - dataType = DataType.FLOAT32; - break; - default: - throw new IllegalArgumentException("Invalid input-shape: " + shape); - } - } - inputShapes.add(dataType, new Shape(array)); - } - } else { - String[] tokens = shape.split(","); - long[] shapes = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray(); - inputShapes.add(DataType.FLOAT32, new Shape(shapes)); - } - } + inputShapes = NDListGenerator.parseShape(shape); } static Options getOptions() { @@ -241,6 +190,14 @@ static boolean hasHelp(String[] args) { return list.contains("-h") || list.contains("help"); } + static void printHelp(String msg, Options options) { + HelpFormatter formatter = new HelpFormatter(); + formatter.setSyntaxPrefix(""); + formatter.setLeftPadding(1); + formatter.setWidth(120); + formatter.printHelp(msg, options); + } + int getDuration() { return duration; } diff --git a/extensions/benchmark/src/main/java/ai/djl/benchmark/Benchmark.java b/extensions/benchmark/src/main/java/ai/djl/benchmark/Benchmark.java index 91cebbcd30c..4b906c1e963 100644 --- a/extensions/benchmark/src/main/java/ai/djl/benchmark/Benchmark.java +++ b/extensions/benchmark/src/main/java/ai/djl/benchmark/Benchmark.java @@ -43,15 +43,21 @@ public static void main(String[] args) { return; } List list = Arrays.asList(args); - boolean multithreading = list.contains("-t") || list.contains("--threads"); - configEngines(multithreading); boolean success; - if (multithreading) { - success = new MultithreadedBenchmark().runBenchmark(args); + if (!list.isEmpty() && "ndlist-gen".equals(list.get(0))) { + success = NDListGenerator.generate(Arrays.copyOfRange(args, 1, args.length)); } else { - success = new Benchmark().runBenchmark(args); + boolean multithreading = list.contains("-t") || list.contains("--threads"); + configEngines(multithreading); + if (multithreading) { + success = new MultithreadedBenchmark().runBenchmark(args); + } else { + success = new Benchmark().runBenchmark(args); + } + } + if (!success) { + System.exit(-1); // NOPMD } - System.exit(success ? 0 : -1); // NOPMD } /** {@inheritDoc} */ diff --git a/extensions/benchmark/src/main/java/ai/djl/benchmark/NDListGenerator.java b/extensions/benchmark/src/main/java/ai/djl/benchmark/NDListGenerator.java new file mode 100644 index 00000000000..cd460c34c9e --- /dev/null +++ b/extensions/benchmark/src/main/java/ai/djl/benchmark/NDListGenerator.java @@ -0,0 +1,168 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.benchmark; + +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.util.Pair; +import ai.djl.util.PairList; +import java.io.BufferedOutputStream; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.DefaultParser; +import org.apache.commons.cli.Option; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** A class generates NDList files. */ +final class NDListGenerator { + + private static final Logger logger = LoggerFactory.getLogger(NDListGenerator.class); + + private NDListGenerator() {} + + static boolean generate(String[] args) { + Options options = getOptions(); + try { + if (Arguments.hasHelp(args)) { + Arguments.printHelp( + "usage: djl-bench ndlist-gen -s INPUT-SHAPES -o OUTPUT_FILE", options); + return true; + } + DefaultParser parser = new DefaultParser(); + CommandLine cmd = parser.parse(options, args, null, false); + String inputShapes = cmd.getOptionValue("input-shapes"); + String output = cmd.getOptionValue("output-file"); + boolean ones = cmd.hasOption("ones"); + Path path = Paths.get(output); + + try (NDManager manager = NDManager.newBaseManager()) { + NDList list = new NDList(); + for (Pair pair : parseShape(inputShapes)) { + DataType dataType = pair.getKey(); + Shape shape = pair.getValue(); + if (ones) { + list.add(manager.ones(shape, dataType)); + } else { + list.add(manager.zeros(shape, dataType)); + } + } + try (OutputStream os = new BufferedOutputStream(Files.newOutputStream(path))) { + list.encode(os); + } + } + logger.info("NDList file created: {}", path.toAbsolutePath()); + return true; + } catch (ParseException e) { + Arguments.printHelp(e.getMessage(), options); + } catch (Throwable t) { + logger.error("Unexpected error", t); + } + return false; + } + + static PairList parseShape(String shape) { + PairList inputShapes = new PairList<>(); + if (shape != null) { + if (shape.contains("(")) { + Pattern pattern = + Pattern.compile("\\((\\s*(\\d+)([,\\s]+\\d+)*\\s*)\\)([sdubilBfS]?)"); + Matcher matcher = pattern.matcher(shape); + while (matcher.find()) { + String[] tokens = matcher.group(1).split(","); + long[] array = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray(); + DataType dataType; + String dataTypeStr = matcher.group(4); + if (dataTypeStr == null || dataTypeStr.isEmpty()) { + dataType = DataType.FLOAT32; + } else { + switch (dataTypeStr) { + case "s": + dataType = DataType.FLOAT16; + break; + case "d": + dataType = DataType.FLOAT64; + break; + case "u": + dataType = DataType.UINT8; + break; + case "b": + dataType = DataType.INT8; + break; + case "i": + dataType = DataType.INT32; + break; + case "l": + dataType = DataType.INT64; + break; + case "B": + dataType = DataType.BOOLEAN; + break; + case "f": + dataType = DataType.FLOAT32; + break; + default: + throw new IllegalArgumentException("Invalid input-shape: " + shape); + } + } + inputShapes.add(dataType, new Shape(array)); + } + } else { + String[] tokens = shape.split(","); + long[] shapes = Arrays.stream(tokens).mapToLong(Long::parseLong).toArray(); + inputShapes.add(DataType.FLOAT32, new Shape(shapes)); + } + } + return inputShapes; + } + + private static Options getOptions() { + Options options = new Options(); + options.addOption( + Option.builder("h").longOpt("help").hasArg(false).desc("Print this help.").build()); + options.addOption( + Option.builder("s") + .required() + .longOpt("input-shapes") + .hasArg() + .argName("INPUT-SHAPES") + .desc("Input data shapes for the model.") + .build()); + options.addOption( + Option.builder("o") + .required() + .longOpt("output-file") + .hasArg() + .argName("OUTPUT-FILE") + .desc("Write output NDList to file.") + .build()); + options.addOption( + Option.builder("1") + .longOpt("ones") + .hasArg(false) + .argName("ones") + .desc("Use all ones instead of zeros.") + .build()); + return options; + } +} diff --git a/extensions/benchmark/src/test/java/ai/djl/benchmark/BenchmarkTest.java b/extensions/benchmark/src/test/java/ai/djl/benchmark/BenchmarkTest.java index 21ac5721286..ebe5989e2d1 100644 --- a/extensions/benchmark/src/test/java/ai/djl/benchmark/BenchmarkTest.java +++ b/extensions/benchmark/src/test/java/ai/djl/benchmark/BenchmarkTest.java @@ -27,7 +27,7 @@ public class BenchmarkTest { @Test public void testHelp() { String[] args = {"-h"}; - new Benchmark().runBenchmark(args); + Benchmark.main(args); } @Test diff --git a/extensions/benchmark/src/test/java/ai/djl/benchmark/NDListGeneratorTest.java b/extensions/benchmark/src/test/java/ai/djl/benchmark/NDListGeneratorTest.java new file mode 100644 index 00000000000..50964a9a1f5 --- /dev/null +++ b/extensions/benchmark/src/test/java/ai/djl/benchmark/NDListGeneratorTest.java @@ -0,0 +1,46 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.benchmark; + +import org.testng.Assert; +import org.testng.annotations.Test; + +public class NDListGeneratorTest { + + @Test + public void testHelp() { + String[] args = {"ndlist-gen", "-h"}; + Benchmark.main(args); + } + + @Test + public void testMissingOptions() { + String[] args = {"ndlist-gen", "-s"}; + boolean success = NDListGenerator.generate(args); + Assert.assertFalse(success); + } + + @Test + public void testOnes() { + String[] args = {"ndlist-gen", "-s", "1", "-o", "build/ones.ndlist", "-1"}; + boolean success = NDListGenerator.generate(args); + Assert.assertTrue(success); + } + + @Test + public void testZeros() { + String[] args = {"ndlist-gen", "-s", "1", "-o", "build/ones.ndlist"}; + boolean success = NDListGenerator.generate(args); + Assert.assertTrue(success); + } +}