From a7ce5a00dbef172700e0e0822a5c114d74a3569d Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Fri, 17 Sep 2021 23:44:12 -0700
Subject: [PATCH] Add tfhub url support

Fixes #1122

Change-Id: Ia1a2fafc502cb07878ed23dea66f1914b8b3159a
---
 .../djl/repository/RepositoryFactoryImpl.java | 18 +++++++
 .../djl/repository/TfhubRepositoryTest.java   | 28 ++++++++++
 extensions/benchmark/README.md                | 51 +++++++++++++------
 3 files changed, 82 insertions(+), 15 deletions(-)
 create mode 100644 api/src/test/java/ai/djl/repository/TfhubRepositoryTest.java

diff --git a/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java b/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java
index 39feafde132..404cdb21e3e 100644
--- a/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java
+++ b/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java
@@ -23,6 +23,7 @@
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.Collections;
+import java.util.Locale;
 import java.util.Map;
 import java.util.ServiceLoader;
 import java.util.Set;
@@ -59,6 +60,23 @@ public Repository newInstance(String name, URI uri) {
             throw new IllegalArgumentException("Malformed URL: " + uri, e);
         }
 
+        if ("tfhub.dev".equals(uri.getHost().toLowerCase(Locale.ROOT))) {
+            // Handle tfhub case
+            String path = uri.getPath();
+            if (path.endsWith("/")) {
+                path = path.substring(0, path.length() - 1);
+            }
+            path = "/tfhub-modules" + path + ".tar.gz";
+            try {
+                uri = new URI("https", null, "storage.googleapis.com", -1, path, null, null);
+            } catch (URISyntaxException e) {
+                throw new IllegalArgumentException("Failed to rewrite tfhub.dev URL: " + uri, e);
+            }
+            String[] tokens = path.split("/");
+            String modelName = tokens[tokens.length - 2];
+            return new SimpleUrlRepository(name, uri, modelName);
+        }
+
         Path path = parseFilePath(uri);
         String fileName = path.toFile().getName();
         if (FilenameUtils.isArchiveFile(fileName)) {
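To make the rewrite above easy to check by hand, here is a standalone sketch that traces the same steps on a concrete URL. It is illustrative only and not part of the patch; the class name `TfhubUrlMapping` is made up.

```java
import java.net.URI;
import java.net.URISyntaxException;

// Illustrative sketch of the tfhub.dev -> storage.googleapis.com mapping above;
// not part of this patch.
public class TfhubUrlMapping {

    public static void main(String[] args) throws URISyntaxException {
        URI uri = new URI("https://tfhub.dev/tensorflow/resnet_50/classification/1/");

        String path = uri.getPath(); // "/tensorflow/resnet_50/classification/1/"
        if (path.endsWith("/")) {
            path = path.substring(0, path.length() - 1); // strip the trailing slash
        }
        path = "/tfhub-modules" + path + ".tar.gz";

        // tfhub.dev serves its model archives from storage.googleapis.com
        URI resolved = new URI("https", null, "storage.googleapis.com", -1, path, null, null);

        // The second-to-last path segment becomes the model name
        String[] tokens = path.split("/");
        String modelName = tokens[tokens.length - 2];

        // Prints:
        // https://storage.googleapis.com/tfhub-modules/tensorflow/resnet_50/classification/1.tar.gz
        // classification
        System.out.println(resolved);
        System.out.println(modelName);
    }
}
```

Note the design choice in `tokens[tokens.length - 2]`: taking the second-to-last segment skips the trailing version component (`1.tar.gz`), so the resulting repository is named after the model rather than its version.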
diff --git a/api/src/test/java/ai/djl/repository/TfhubRepositoryTest.java b/api/src/test/java/ai/djl/repository/TfhubRepositoryTest.java
new file mode 100644
index 00000000000..42582a19b0a
--- /dev/null
+++ b/api/src/test/java/ai/djl/repository/TfhubRepositoryTest.java
@@ -0,0 +1,28 @@
+/*
+ * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.repository;
+
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+public class TfhubRepositoryTest {
+
+    @Test
+    public void testResource() {
+        Repository repo =
+                Repository.newInstance(
+                        "tfhub",
+                        "https://tfhub.dev/tensorflow/faster_rcnn/inception_resnet_v2_640x640/1/");
+        Assert.assertEquals(repo.getResources().size(), 1);
+    }
+}
diff --git a/extensions/benchmark/README.md b/extensions/benchmark/README.md
index 6b60c7659c1..b00362db4fd 100644
--- a/extensions/benchmark/README.md
+++ b/extensions/benchmark/README.md
@@ -21,6 +21,7 @@ djl-bench currently supports benchmarking the following types of models:
 - ONNX model
 - PaddlePaddle model
 - TFLite model
+- TensorRT model
 - Neo DLR (TVM) model
 - XGBoost model
@@ -50,12 +51,24 @@ curl -O https://publish.djl.ai/djl-bench/0.12.0/djl-bench_0.12.0-1_all.deb
 sudo dpkg -i djl-bench_0.12.0-1_all.deb
 ```
 
+For CentOS or Amazon Linux 2
+
+You can download the djl-bench zip file from [here](https://publish.djl.ai/djl-bench/0.12.0/benchmark-0.12.0.zip).
+
+```
+curl -O https://publish.djl.ai/djl-bench/0.12.0/benchmark-0.12.0.zip
+unzip benchmark-0.12.0.zip
+rm benchmark-0.12.0.zip
+sudo ln -s $PWD/benchmark-0.12.0/bin/benchmark /usr/bin/djl-bench
+```
+
 For Windows
 
 We are considering creating a `chocolatey` package for Windows. For the time being, you can
 download the djl-bench zip file from [here](https://publish.djl.ai/djl-bench/0.12.0/benchmark-0.12.0.zip).
 
 Or you can run the benchmark using gradle:
+
 ```
 cd djl
@@ -87,10 +100,10 @@ they have different CUDA versions to support. Please check the individual Engine
 Here are a few sample benchmark scripts for reference. You can also skip this and directly
 follow the 4-step instructions for your own model.
 
-Benchmark on a Tensorflow model from http url with all-ones NDArray input for 10 times:
+Benchmark a TensorFlow model from a [tfhub](https://tfhub.dev/) URL with all-zeros NDArray input for 10 iterations:
 
 ```
-djl-bench -e TensorFlow -u https://storage.googleapis.com/tfhub-modules/tensorflow/resnet_50/classification/1.tar.gz -c 10 -s 1,224,224,3
+djl-bench -e TensorFlow -u https://tfhub.dev/tensorflow/resnet_50/classification/1 -c 10 -s 1,224,224,3
 ```
 
 Similarly, this is for PyTorch
 
 ```
 djl-bench -e PyTorch -u https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip -n traced_resnet18 -c 10 -s 1,3,224,224
 ```
 
+Benchmark a model from the [ONNX Model Zoo](https://github.com/onnx/models):
+
+```
+djl-bench -e OnnxRuntime -u https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet18v1/resnet18v1.tar.gz -s 1,3,224,224 -n resnet18v1/resnet18v1 -c 10
+```
+
 ### Benchmark from ModelZoo (Only available in 0.13.0+)
 
 #### MXNet
 
 Resnet50 image classification model:
 
 ```
 djl-bench -e MXNet -c 2 -s 1,3,224,224 -u djl://ai.djl.mxnet/resnet?layers=50
 ```
 
 #### PyTorch
 
 SSD object detection model:
 
 ```
 djl-bench -e PyTorch -c 2 -s 1,3,300,300 -u djl://ai.djl.pytorch/ssd/0.0.1/ssd_300_resnet50
 ```
-
 ## Configuration of Benchmark script
 
 To start your benchmarking, make sure you provide the following information.
 
 This will print out the possible arguments to pass in:
 
 ```
 usage: djl-bench [-p MODEL-PATH] -s INPUT-SHAPES [OPTIONS]
- -c,--iteration        Number of total iterations (per thread).
- -d,--duration         Duration of the test in minutes.
- -e,--engine           Choose an Engine for the benchmark.
- -g,--gpus             Number of GPUS to run multithreading inference.
- -h,--help             Print this help.
- -l,--delay            Delay of incremental threads.
- -n,--model-name       Specify model file name.
- -o,--output-dir       Directory for output logs.
- -p,--model-path       Model directory file path.
- -s,--input-shapes     Input data shapes for the model.
- -t,--threads          Number of inference threads.
- -u,--model-url        Model archive file URL.
+ -c,--iteration        Number of total iterations.
+ -d,--duration         Duration of the test in minutes.
+ -e,--engine           Choose an Engine for the benchmark.
+ -g,--gpus             Number of GPUs to run multithreading inference.
+ -h,--help             Print this help.
+ -l,--delay            Delay of incremental threads.
+    --model-arguments  Specify model loading arguments.
+    --model-options    Specify model loading options.
+ -n,--model-name       Specify model file name.
+ -o,--output-dir       Directory for output logs.
+ -p,--model-path       Model directory file path.
+ -s,--input-shapes     Input data shapes for the model.
+ -t,--threads          Number of inference threads.
+ -u,--model-url        Model archive file URL.
 ```
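To connect these flags to the underlying API, the sketch below shows roughly what a `djl-bench -e TensorFlow -u https://tfhub.dev/... -s 1,224,224,3 -c 10` run amounts to in plain DJL code. This is a hedged illustration against the public `Criteria` API (DJL 0.12+); the class name `TfhubBenchSketch` is hypothetical, and the real djl-bench implementation adds timing, warmup, and threading on top.

```java
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;

// Hypothetical illustration of what the -e/-u/-s flags boil down to;
// not how djl-bench is actually implemented.
public class TfhubBenchSketch {

    public static void main(String[] args) throws Exception {
        Criteria<NDList, NDList> criteria =
                Criteria.builder()
                        .setTypes(NDList.class, NDList.class)
                        .optModelUrls("https://tfhub.dev/tensorflow/resnet_50/classification/1") // -u
                        .optEngine("TensorFlow")                                                 // -e
                        .build();

        try (ZooModel<NDList, NDList> model = criteria.loadModel();
                Predictor<NDList, NDList> predictor = model.newPredictor()) {
            // -s 1,224,224,3: an all-zeros input of the given shape
            NDList input = new NDList(model.getNDManager().zeros(new Shape(1, 224, 224, 3)));
            NDList output = predictor.predict(input); // one iteration, in djl-bench terms
            System.out.println(output);
        }
    }
}
```

Loading a tfhub.dev URL through `optModelUrls` should exercise the same `RepositoryFactoryImpl` rewrite added earlier in this patch.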
 
 ### Step 1: Pick your deep learning engine
 
@@ -165,6 +185,7 @@ By default, the above script will use MXNet as the default Engine, but you can a
 -e PaddlePaddle   # PaddlePaddle
 -e OnnxRuntime    # OnnxRuntime
 -e TFLite         # TFLite
+-e TensorRT       # TensorRT
 -e DLR            # Neo DLR
 -e XGBoost        # XGBoost
 ```
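One last note on `-e`: engine selection ultimately goes through DJL's `Engine` registry, so a quick way to verify that an engine is actually available on the classpath before benchmarking is a snippet like the following (illustrative only; `EngineCheck` is a made-up name):

```java
import ai.djl.engine.Engine;

// Illustrative: verify that an engine (e.g. the one passed to -e) is available.
public class EngineCheck {

    public static void main(String[] args) {
        Engine engine = Engine.getEngine("TensorFlow"); // throws if the engine is not found
        System.out.println(engine.getEngineName() + " " + engine.getVersion());
        System.out.println("Default engine: " + Engine.getInstance().getEngineName());
    }
}
```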