Skip to content

Commit

Permalink
Add tfhub url support
Browse files Browse the repository at this point in the history
Fixes deepjavalibrary#1122

Change-Id: Ia1a2fafc502cb07878ed23dea66f1914b8b3159a
  • Loading branch information
frankfliu committed Sep 18, 2021
1 parent 6defabb commit adc647e
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
18 changes: 18 additions & 0 deletions api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
Expand Down Expand Up @@ -59,6 +60,23 @@ public Repository newInstance(String name, URI uri) {
throw new IllegalArgumentException("Malformed URL: " + uri, e);
}

if ("tfhub.dev".equals(uri.getHost().toLowerCase(Locale.ROOT))) {
// Handle tfhub case
String path = uri.getPath();
if (path.endsWith("/")) {
path = path.substring(0, path.length() - 1);
}
path = "/tfhub-modules" + path + ".tar.gz";
try {
uri = new URI("https", null, "storage.googleapis.com", -1, path, null, null);
} catch (URISyntaxException e) {
throw new IllegalArgumentException("Failed to append query string: " + uri, e);
}
String[] tokens = path.split("/");
String modelName = tokens[tokens.length - 2];
return new SimpleUrlRepository(name, uri, modelName);
}

Path path = parseFilePath(uri);
String fileName = path.toFile().getName();
if (FilenameUtils.isArchiveFile(fileName)) {
Expand Down
28 changes: 28 additions & 0 deletions api/src/test/java/ai/djl/repository/TfhubRepositoryTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.repository;

import org.testng.Assert;
import org.testng.annotations.Test;

public class TfhubRepositoryTest {

@Test
public void testResource() {
Repository repo =
Repository.newInstance(
"tfhub",
"https://tfhub.dev/tensorflow/faster_rcnn/inception_resnet_v2_640x640/1/");
Assert.assertEquals(repo.getResources().size(), 1);
}
}
4 changes: 2 additions & 2 deletions extensions/benchmark/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ they have different CUDA version to support. Please check the individual Engine
Here is a few sample benchmark script for you to refer. You can also skip this and directly follow
the 4-step instructions for your own model.

Benchmark on a Tensorflow model from http url with all-ones NDArray input for 10 times:
Benchmark on a Tensorflow model from [tfhub](https://tfhub.dev/) url with all-zeros NDArray input for 10 times:

```
djl-bench -e TensorFlow -u https://storage.googleapis.com/tfhub-modules/tensorflow/resnet_50/classification/1.tar.gz -c 10 -s 1,224,224,3
djl-bench -e TensorFlow -u https://tfhub.dev/tensorflow/resnet_50/classification/1 -c 10 -s 1,224,224,3
```

Similarly, this is for PyTorch
Expand Down

0 comments on commit adc647e

Please sign in to comment.