From 154ae6e2f8e4b3c253ade3ad3be60055eb8d038d Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Tue, 20 Jul 2021 15:05:41 -0700 Subject: [PATCH] Move version into MRL (#1114) Change-Id: Iafa45f59cf966d8134b2cc127d1243767513132b --- .../java/ai/djl/repository/JarRepository.java | 6 +- .../ai/djl/repository/LocalRepository.java | 9 +- api/src/main/java/ai/djl/repository/MRL.java | 153 ++++++++++++++++-- .../ai/djl/repository/RemoteRepository.java | 5 +- .../java/ai/djl/repository/Repository.java | 54 ++++++- .../main/java/ai/djl/repository/Resource.java | 152 ----------------- .../ai/djl/repository/SimpleRepository.java | 7 +- .../djl/repository/SimpleUrlRepository.java | 7 +- .../djl/repository/zoo/BaseModelLoader.java | 28 ++-- .../djl/repository/zoo/DefaultModelZoo.java | 2 +- .../ai/djl/repository/JarRepositoryTest.java | 2 +- .../djl/basicdataset/cv/BananaDetection.java | 16 +- .../ai/djl/basicdataset/cv/CocoDetection.java | 19 ++- .../ai/djl/basicdataset/cv/ImageDataset.java | 1 + .../djl/basicdataset/cv/PikachuDetection.java | 16 +- .../classification/AbstractImageFolder.java | 8 +- .../cv/classification/CaptchaDataset.java | 24 +-- .../cv/classification/Cifar10.java | 20 ++- .../cv/classification/FashionMnist.java | 23 +-- .../cv/classification/ImageFolder.java | 6 +- .../cv/classification/ImageNet.java | 4 +- .../basicdataset/cv/classification/Mnist.java | 23 +-- .../ai/djl/basicdataset/nlp/AmazonReview.java | 18 ++- .../nlp/CookingStackExchange.java | 17 +- .../basicdataset/nlp/StanfordMovieReview.java | 14 +- .../nlp/TatoebaEnglishFrenchDataset.java | 14 +- .../ai/djl/basicdataset/nlp/TextDataset.java | 4 +- .../tabular/AirfoilRandomAccess.java | 19 ++- .../tabular/AmesRandomAccess.java | 17 +- .../main/java/ai/djl/dlr/zoo/DlrModelZoo.java | 4 +- .../main/java/ai/djl/aws/s3/S3Repository.java | 7 +- .../java/ai/djl/aws/s3/S3RepositoryTest.java | 6 +- .../TextClassificationModelLoader.java | 9 +- .../ai/djl/hadoop/hdfs/HdfsRepository.java | 7 +- .../djl/hadoop/hdfs/HdfsRepositoryTest.java | 8 +- .../ai/djl/basicmodelzoo/BasicModelZoo.java | 12 +- .../java/ai/djl/mxnet/zoo/MxModelZoo.java | 86 +++++----- .../ai/djl/onnxruntime/zoo/OrtModelZoo.java | 5 +- .../ai/djl/paddlepaddle/zoo/PpModelZoo.java | 25 +-- .../java/ai/djl/pytorch/zoo/PtModelZoo.java | 21 +-- .../ai/djl/tensorflow/zoo/TfModelZoo.java | 12 +- .../ai/djl/tflite/zoo/TfLiteModelZoo.java | 4 +- 42 files changed, 481 insertions(+), 413 deletions(-) delete mode 100644 api/src/main/java/ai/djl/repository/Resource.java diff --git a/api/src/main/java/ai/djl/repository/JarRepository.java b/api/src/main/java/ai/djl/repository/JarRepository.java index 9e139dfaa5c..f4be865286d 100644 --- a/api/src/main/java/ai/djl/repository/JarRepository.java +++ b/api/src/main/java/ai/djl/repository/JarRepository.java @@ -76,7 +76,7 @@ public Metadata locate(MRL mrl) { /** {@inheritDoc} */ @Override - public Artifact resolve(MRL mrl, String version, Map filter) { + public Artifact resolve(MRL mrl, Map filter) { List artifacts = locate(mrl).getArtifacts(); if (artifacts.isEmpty()) { return null; @@ -89,7 +89,7 @@ public Artifact resolve(MRL mrl, String version, Map filter) { public List getResources() { Metadata m = getMetadata(); if (m != null && !m.getArtifacts().isEmpty()) { - MRL mrl = MRL.undefined(m.getGroupId(), m.getArtifactId()); + MRL mrl = MRL.undefined(this, m.getGroupId(), m.getArtifactId()); return Collections.singletonList(mrl); } return Collections.emptyList(); @@ -128,7 +128,7 @@ private synchronized Metadata getMetadata() { metadata.setArtifactId(artifactId); metadata.setArtifacts(Collections.singletonList(artifact)); String hash = md5hash(uri.toString()); - MRL mrl = MRL.model(Application.UNDEFINED, DefaultModelZoo.GROUP_ID, hash); + MRL mrl = model(Application.UNDEFINED, DefaultModelZoo.GROUP_ID, hash); metadata.setRepositoryUri(mrl.toURI()); return metadata; diff --git a/api/src/main/java/ai/djl/repository/LocalRepository.java b/api/src/main/java/ai/djl/repository/LocalRepository.java index c7ab66a3b59..432453928f0 100644 --- a/api/src/main/java/ai/djl/repository/LocalRepository.java +++ b/api/src/main/java/ai/djl/repository/LocalRepository.java @@ -87,10 +87,9 @@ public Metadata locate(MRL mrl) throws IOException { /** {@inheritDoc} */ @Override - public Artifact resolve(MRL mrl, String version, Map filter) - throws IOException { + public Artifact resolve(MRL mrl, Map filter) throws IOException { Metadata metadata = locate(mrl); - VersionRange range = VersionRange.parse(version); + VersionRange range = VersionRange.parse(mrl.getVersion()); List artifacts = metadata.search(range, filter); if (artifacts.isEmpty()) { return null; @@ -117,9 +116,9 @@ public List getResources() { String groupId = metadata.getGroupId(); String artifactId = metadata.getArtifactId(); if ("dataset".equals(type)) { - list.add(MRL.dataset(application, groupId, artifactId)); + list.add(dataset(application, groupId, artifactId)); } else if ("model".equals(type)) { - list.add(MRL.model(application, groupId, artifactId)); + list.add(model(application, groupId, artifactId)); } } catch (IOException e) { logger.warn("Failed to read metadata.json", e); diff --git a/api/src/main/java/ai/djl/repository/MRL.java b/api/src/main/java/ai/djl/repository/MRL.java index 3c7e7795dd6..16aadb85c08 100644 --- a/api/src/main/java/ai/djl/repository/MRL.java +++ b/api/src/main/java/ai/djl/repository/MRL.java @@ -13,7 +13,13 @@ package ai.djl.repository; import ai.djl.Application; +import ai.djl.util.Progress; +import java.io.IOException; import java.net.URI; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * The {@code MRL} (Machine learning Resource Locator) is a pointer to a {@link Metadata} "resource" @@ -36,59 +42,89 @@ */ public final class MRL { + private static final Logger logger = LoggerFactory.getLogger(MRL.class); + private String type; private Application application; private String groupId; private String artifactId; + private String version; + private Repository repository; + private Metadata metadata; /** * Constructs an MRL. * + * @param repository the {@link Repository} * @param type the resource type * @param application the resource application * @param groupId the desired groupId * @param artifactId the desired artifactId + * @param version the resource version */ - private MRL(String type, Application application, String groupId, String artifactId) { + private MRL( + Repository repository, + String type, + Application application, + String groupId, + String artifactId, + String version) { + this.repository = repository; this.type = type; this.application = application; this.groupId = groupId; this.artifactId = artifactId; + this.version = version; } /** * Creates a model {@code MRL} with specified application. * + * @param repository the {@link Repository} * @param application the desired application * @param groupId the desired groupId * @param artifactId the desired artifactId + * @param version the resource version * @return a model {@code MRL} */ - public static MRL model(Application application, String groupId, String artifactId) { - return new MRL("model", application, groupId, artifactId); + public static MRL model( + Repository repository, + Application application, + String groupId, + String artifactId, + String version) { + return new MRL(repository, "model", application, groupId, artifactId, version); } /** * Creates a dataset {@code MRL} with specified application. * + * @param repository the {@link Repository} * @param application the desired application * @param groupId the desired groupId * @param artifactId the desired artifactId + * @param version the resource version * @return a dataset {@code MRL} */ - public static MRL dataset(Application application, String groupId, String artifactId) { - return new MRL("dataset", application, groupId, artifactId); + public static MRL dataset( + Repository repository, + Application application, + String groupId, + String artifactId, + String version) { + return new MRL(repository, "dataset", application, groupId, artifactId, version); } /** * Creates a dataset {@code MRL} with specified application. * + * @param repository the {@link Repository} * @param groupId the desired groupId * @param artifactId the desired artifactId * @return a dataset {@code MRL} */ - public static MRL undefined(String groupId, String artifactId) { - return new MRL("", Application.UNDEFINED, groupId, artifactId); + public static MRL undefined(Repository repository, String groupId, String artifactId) { + return new MRL(repository, "", Application.UNDEFINED, groupId, artifactId, null); } /** @@ -112,9 +148,18 @@ public URI toURI() { } /** - * Returns the resource application. + * Returns the repository. + * + * @return the repository + */ + public Repository getRepository() { + return repository; + } + + /** + * Returns the application. * - * @return the resource application + * @return the application */ public Application getApplication() { return application; @@ -138,6 +183,96 @@ public String getArtifactId() { return artifactId; } + /** + * Returns the version. + * + * @return the version + */ + public String getVersion() { + return version; + } + + /** + * Returns the default artifact. + * + * @return the default artifact + * @throws IOException for various exceptions depending on the specific dataset + */ + public Artifact getDefaultArtifact() throws IOException { + return repository.resolve(this, null); + } + + /** + * Returns the first artifact that matches a given criteria. + * + * @param criteria the criteria to match against + * @return the first artifact that matches the criteria. Null will be returned if no artifact + * matches + * @throws IOException for errors while loading the model + */ + public Artifact match(Map criteria) throws IOException { + List list = search(criteria); + if (list.isEmpty()) { + return null; + } + return list.get(0); + } + + /** + * Returns a list of artifacts in this resource. + * + * @return a list of artifacts in this resource + * @throws IOException for errors while loading the model + */ + public List listArtifacts() throws IOException { + return getMetadata().getArtifacts(); + } + + /** + * Prepares the artifact for use. + * + * @param artifact the artifact to prepare + * @throws IOException if it failed to prepare + */ + public void prepare(Artifact artifact) throws IOException { + prepare(artifact, null); + } + + /** + * Prepares the artifact for use with progress tracking. + * + * @param artifact the artifact to prepare + * @param progress the progress tracker + * @throws IOException if it failed to prepare + */ + public void prepare(Artifact artifact, Progress progress) throws IOException { + if (artifact != null) { + logger.debug("Preparing artifact: {}, {}", repository.getName(), artifact); + repository.prepare(artifact, progress); + } + } + + /** + * Returns all the artifacts that match a given criteria. + * + * @param criteria the criteria to match against + * @return all the artifacts that match a given criteria + * @throws IOException for errors while loading the model + */ + private List search(Map criteria) throws IOException { + return getMetadata().search(VersionRange.parse(version), criteria); + } + + private Metadata getMetadata() throws IOException { + if (metadata == null) { + metadata = repository.locate(this); + if (metadata == null) { + throw new IOException(this + " resource not found."); + } + } + return metadata; + } + /** {@inheritDoc} */ @Override public String toString() { diff --git a/api/src/main/java/ai/djl/repository/RemoteRepository.java b/api/src/main/java/ai/djl/repository/RemoteRepository.java index 8ae28de168e..6f78c292221 100644 --- a/api/src/main/java/ai/djl/repository/RemoteRepository.java +++ b/api/src/main/java/ai/djl/repository/RemoteRepository.java @@ -111,10 +111,9 @@ public Metadata locate(MRL mrl) throws IOException { /** {@inheritDoc} */ @Override - public Artifact resolve(MRL mrl, String version, Map filter) - throws IOException { + public Artifact resolve(MRL mrl, Map filter) throws IOException { Metadata metadata = locate(mrl); - VersionRange range = VersionRange.parse(version); + VersionRange range = VersionRange.parse(mrl.getVersion()); List artifacts = metadata.search(range, filter); if (artifacts.isEmpty()) { return null; diff --git a/api/src/main/java/ai/djl/repository/Repository.java b/api/src/main/java/ai/djl/repository/Repository.java index e2fcf482ae6..c3e9398ac7e 100644 --- a/api/src/main/java/ai/djl/repository/Repository.java +++ b/api/src/main/java/ai/djl/repository/Repository.java @@ -98,6 +98,57 @@ static void registerRepositoryFactory(RepositoryFactory factory) { RepositoryFactoryImpl.registerRepositoryFactory(factory); } + /** + * Creates a model {@code MRL} with specified application. + * + * @param application the desired application + * @param groupId the desired groupId + * @param artifactId the desired artifactId + * @return a model {@code MRL} + */ + default MRL model(Application application, String groupId, String artifactId) { + return model(application, groupId, artifactId, null); + } + + /** + * Creates a model {@code MRL} with specified application. + * + * @param application the desired application + * @param groupId the desired groupId + * @param artifactId the desired artifactId + * @param version the resource version + * @return a model {@code MRL} + */ + default MRL model(Application application, String groupId, String artifactId, String version) { + return MRL.model(this, application, groupId, artifactId, version); + } + + /** + * Creates a dataset {@code MRL} with specified application. + * + * @param application the desired application + * @param groupId the desired groupId + * @param artifactId the desired artifactId + * @return a dataset {@code MRL} + */ + default MRL dataset(Application application, String groupId, String artifactId) { + return dataset(application, groupId, artifactId, null); + } + + /** + * Creates a dataset {@code MRL} with specified application. + * + * @param application the desired application + * @param groupId the desired groupId + * @param artifactId the desired artifactId + * @param version the resource version + * @return a dataset {@code MRL} + */ + default MRL dataset( + Application application, String groupId, String artifactId, String version) { + return MRL.dataset(this, application, groupId, artifactId, version); + } + /** * Returns whether the repository is remote repository. * @@ -132,12 +183,11 @@ static void registerRepositoryFactory(RepositoryFactory factory) { * Returns the artifact matching a mrl, version, and property filter. * * @param mrl the mrl to match the artifact against - * @param version the version of the artifact * @param filter the property filter * @return the matched artifact * @throws IOException if it failed to load the artifact */ - Artifact resolve(MRL mrl, String version, Map filter) throws IOException; + Artifact resolve(MRL mrl, Map filter) throws IOException; /** * Returns an {@link InputStream} for an item in a repository. diff --git a/api/src/main/java/ai/djl/repository/Resource.java b/api/src/main/java/ai/djl/repository/Resource.java deleted file mode 100644 index 8d4ce51107e..00000000000 --- a/api/src/main/java/ai/djl/repository/Resource.java +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.repository; - -import ai.djl.util.Progress; -import java.io.IOException; -import java.util.List; -import java.util.Map; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** A class represents a resource in a {@link Repository}. */ -public class Resource { - - private static final Logger logger = LoggerFactory.getLogger(Resource.class); - - private Repository repository; - private MRL mrl; - private String version; - private Metadata metadata; - - /** - * Constructs a {@code Resource} instance. - * - * @param repository the {@link Repository} - * @param mrl the resource locator - * @param version the version of the resource - */ - public Resource(Repository repository, MRL mrl, String version) { - this.repository = repository; - this.mrl = mrl; - this.version = version; - } - - /** - * Returns the {@link Repository} of the resource. - * - * @return the {@link Repository} of the resource - */ - public Repository getRepository() { - return repository; - } - - /** - * Returns the {@link MRL} of the resource. - * - * @return the {@link MRL} of the resource - */ - public MRL getMrl() { - return mrl; - } - - /** - * Returns the version of the resource. - * - * @return the version of the resource - */ - public String getVersion() { - return version; - } - - /** - * Returns the default artifact. - * - * @return the default artifact - * @throws IOException for various exceptions depending on the specific dataset - */ - public Artifact getDefaultArtifact() throws IOException { - return repository.resolve(mrl, version, null); - } - - /** - * Returns the first artifact that matches a given criteria. - * - * @param criteria the criteria to match against - * @return the first artifact that matches the criteria. Null will be returned if no artifact - * matches - * @throws IOException for errors while loading the model - */ - public Artifact match(Map criteria) throws IOException { - List list = search(criteria); - if (list.isEmpty()) { - return null; - } - return list.get(0); - } - - /** - * Returns a list of artifacts in this resource. - * - * @return a list of artifacts in this resource - * @throws IOException for errors while loading the model - */ - public List listArtifacts() throws IOException { - return getMetadata().getArtifacts(); - } - - /** - * Prepares the artifact for use. - * - * @param artifact the artifact to prepare - * @throws IOException if it failed to prepare - */ - public void prepare(Artifact artifact) throws IOException { - prepare(artifact, null); - } - - /** - * Prepares the artifact for use with progress tracking. - * - * @param artifact the artifact to prepare - * @param progress the progress tracker - * @throws IOException if it failed to prepare - */ - public void prepare(Artifact artifact, Progress progress) throws IOException { - if (artifact != null) { - logger.debug("Preparing artifact: {}, {}", repository.getName(), artifact); - repository.prepare(artifact, progress); - } - } - - /** - * Returns all the artifacts that match a given criteria. - * - * @param criteria the criteria to match against - * @return all the artifacts that match a given criteria - * @throws IOException for errors while loading the model - */ - private List search(Map criteria) throws IOException { - return getMetadata().search(VersionRange.parse(version), criteria); - } - - private Metadata getMetadata() throws IOException { - if (metadata == null) { - metadata = repository.locate(mrl); - if (metadata == null) { - throw new IOException("MRL: " + mrl + " resource not found."); - } - } - return metadata; - } -} diff --git a/api/src/main/java/ai/djl/repository/SimpleRepository.java b/api/src/main/java/ai/djl/repository/SimpleRepository.java index 9b3c5b497d5..1c66566734f 100644 --- a/api/src/main/java/ai/djl/repository/SimpleRepository.java +++ b/api/src/main/java/ai/djl/repository/SimpleRepository.java @@ -92,8 +92,7 @@ public Metadata locate(MRL mrl) throws IOException { /** {@inheritDoc} */ @Override - public Artifact resolve(MRL mrl, String version, Map filter) - throws IOException { + public Artifact resolve(MRL mrl, Map filter) throws IOException { List artifacts = locate(mrl).getArtifacts(); if (artifacts.isEmpty()) { return null; @@ -156,7 +155,7 @@ public List getResources() { return Collections.emptyList(); } - MRL mrl = MRL.undefined(DefaultModelZoo.GROUP_ID, artifactId); + MRL mrl = MRL.undefined(this, DefaultModelZoo.GROUP_ID, artifactId); return Collections.singletonList(mrl); } @@ -190,7 +189,7 @@ private synchronized Metadata getMetadata() throws IOException { artifact.setName(modelName); String hash = md5hash(uri + "artifact_id=" + artifactId + "&model_name=" + modelName); - MRL mrl = MRL.model(Application.UNDEFINED, DefaultModelZoo.GROUP_ID, hash); + MRL mrl = model(Application.UNDEFINED, DefaultModelZoo.GROUP_ID, hash); metadata.setRepositoryUri(mrl.toURI()); } else { if (Files.isDirectory(path)) { diff --git a/api/src/main/java/ai/djl/repository/SimpleUrlRepository.java b/api/src/main/java/ai/djl/repository/SimpleUrlRepository.java index 90c9c884828..1b381913279 100644 --- a/api/src/main/java/ai/djl/repository/SimpleUrlRepository.java +++ b/api/src/main/java/ai/djl/repository/SimpleUrlRepository.java @@ -77,8 +77,7 @@ public Metadata locate(MRL mrl) throws IOException { /** {@inheritDoc} */ @Override - public Artifact resolve(MRL mrl, String version, Map filter) - throws IOException { + public Artifact resolve(MRL mrl, Map filter) throws IOException { List artifacts = locate(mrl).getArtifacts(); if (artifacts.isEmpty()) { return null; @@ -92,7 +91,7 @@ public List getResources() { try { Metadata m = getMetadata(); if (m != null && !m.getArtifacts().isEmpty()) { - MRL mrl = MRL.undefined(m.getGroupId(), m.getArtifactId()); + MRL mrl = MRL.undefined(this, m.getGroupId(), m.getArtifactId()); return Collections.singletonList(mrl); } } catch (IOException e) { @@ -133,7 +132,7 @@ private synchronized Metadata getMetadata() throws IOException { metadata.setArtifactId(artifactId); metadata.setArtifacts(Collections.singletonList(artifact)); String hash = md5hash(uri.toString()); - MRL mrl = MRL.model(Application.UNDEFINED, DefaultModelZoo.GROUP_ID, hash); + MRL mrl = model(Application.UNDEFINED, DefaultModelZoo.GROUP_ID, hash); metadata.setRepositoryUri(mrl.toURI()); return metadata; } diff --git a/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java b/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java index b726293e37c..2fb3d7835fb 100644 --- a/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java +++ b/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java @@ -21,8 +21,6 @@ import ai.djl.nn.BlockFactory; import ai.djl.repository.Artifact; import ai.djl.repository.MRL; -import ai.djl.repository.Repository; -import ai.djl.repository.Resource; import ai.djl.translate.DefaultTranslatorFactory; import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; @@ -44,19 +42,17 @@ public class BaseModelLoader implements ModelLoader { protected ModelZoo modelZoo; - protected Resource resource; + protected MRL mrl; protected TranslatorFactory defaultFactory; /** * Constructs a {@link ModelLoader} given the repository, mrl, and version. * - * @param repository the repository to load the model from * @param mrl the mrl of the model to load - * @param version the version of the model to load * @param modelZoo the modelZoo type that is being used to get supported engine types */ - public BaseModelLoader(Repository repository, MRL mrl, String version, ModelZoo modelZoo) { - this.resource = new Resource(repository, mrl, version); + public BaseModelLoader(MRL mrl, ModelZoo modelZoo) { + this.mrl = mrl; this.modelZoo = modelZoo; defaultFactory = new DefaultTranslatorFactory(); } @@ -64,13 +60,13 @@ public BaseModelLoader(Repository repository, MRL mrl, String version, ModelZoo /** {@inheritDoc} */ @Override public String getArtifactId() { - return resource.getMrl().getArtifactId(); + return mrl.getArtifactId(); } /** {@inheritDoc} */ @Override public Application getApplication() { - return resource.getMrl().getApplication(); + return mrl.getApplication(); } /** {@inheritDoc} */ @@ -78,7 +74,7 @@ public Application getApplication() { @SuppressWarnings("unchecked") public ZooModel loadModel(Criteria criteria) throws IOException, ModelNotFoundException, MalformedModelException { - Artifact artifact = resource.match(criteria.getFilters()); + Artifact artifact = mrl.match(criteria.getFilters()); if (artifact == null) { throw new ModelNotFoundException("No matching filter found"); } @@ -98,13 +94,13 @@ public ZooModel loadModel(Criteria criteria) } } - resource.prepare(artifact, progress); + mrl.prepare(artifact, progress); if (progress != null) { progress.reset("Loading", 2); progress.update(1); } - Path modelPath = resource.getRepository().getResourceDirectory(artifact); + Path modelPath = mrl.getRepository().getResourceDirectory(artifact); loadServingProperties(modelPath, arguments); Application application = criteria.getApplication(); @@ -169,8 +165,8 @@ public ZooModel loadModel(Criteria criteria) /** {@inheritDoc} */ @Override public List listModels() throws IOException { - List list = resource.listArtifacts(); - String version = resource.getVersion(); + List list = mrl.listArtifacts(); + String version = mrl.getVersion(); return list.stream() .filter(a -> version == null || version.equals(a.getVersion())) .collect(Collectors.toList()); @@ -202,9 +198,9 @@ protected Model createModel( @Override public String toString() { StringBuilder sb = new StringBuilder(200); - sb.append(resource.getMrl().getGroupId()) + sb.append(mrl.getGroupId()) .append(':') - .append(resource.getMrl().getArtifactId()) + .append(mrl.getArtifactId()) .append(' ') .append(getApplication()) .append(" [\n"); diff --git a/api/src/main/java/ai/djl/repository/zoo/DefaultModelZoo.java b/api/src/main/java/ai/djl/repository/zoo/DefaultModelZoo.java index a3d517d9f07..bfe15b1c128 100644 --- a/api/src/main/java/ai/djl/repository/zoo/DefaultModelZoo.java +++ b/api/src/main/java/ai/djl/repository/zoo/DefaultModelZoo.java @@ -44,7 +44,7 @@ public DefaultModelZoo(String locations) { logger.debug("Scanning models in repo: {}, {}", repo.getClass(), url); List mrls = repo.getResources(); for (MRL mrl : mrls) { - modelLoaders.add(new BaseModelLoader(repo, mrl, null, null)); + modelLoaders.add(new BaseModelLoader(mrl, null)); } } else { logger.warn("Model location is empty."); diff --git a/api/src/test/java/ai/djl/repository/JarRepositoryTest.java b/api/src/test/java/ai/djl/repository/JarRepositoryTest.java index 110ac944534..871f1ed34ab 100644 --- a/api/src/test/java/ai/djl/repository/JarRepositoryTest.java +++ b/api/src/test/java/ai/djl/repository/JarRepositoryTest.java @@ -43,7 +43,7 @@ public void testResource() throws IOException { List list = repo.getResources(); Assert.assertEquals(list.size(), 1); - Artifact artifact = repo.resolve(list.get(0), null, null); + Artifact artifact = repo.resolve(list.get(0), null); repo.prepare(artifact); Assert.assertEquals(1, artifact.getFiles().size()); } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/BananaDetection.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/BananaDetection.java index f09a8dd72f7..3cef51cae8c 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/BananaDetection.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/BananaDetection.java @@ -22,7 +22,6 @@ import ai.djl.repository.Artifact; import ai.djl.repository.MRL; import ai.djl.repository.Repository; -import ai.djl.repository.Resource; import ai.djl.training.dataset.RandomAccessDataset; import ai.djl.translate.Pipeline; import ai.djl.translate.TranslateException; @@ -55,7 +54,7 @@ public class BananaDetection extends ObjectDetectionDataset { private final List imagePaths; private final PairList labels; - private final Resource resource; + private final MRL mrl; private boolean prepared; /** @@ -67,10 +66,9 @@ public class BananaDetection extends ObjectDetectionDataset { public BananaDetection(Builder builder) { super(builder); usage = builder.usage; + mrl = builder.getMrl(); imagePaths = new ArrayList<>(); labels = new PairList<>(); - MRL mrl = MRL.dataset(Application.CV.ANY, builder.groupId, builder.artifactId); - resource = new Resource(builder.repository, mrl, VERSION); } /** @@ -100,10 +98,10 @@ public void prepare(Progress progress) throws IOException, TranslateException { return; } - Artifact artifact = resource.getDefaultArtifact(); - resource.prepare(artifact, progress); + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); - Path root = resource.getRepository().getResourceDirectory(artifact); + Path root = mrl.getRepository().getResourceDirectory(artifact); Path usagePath; switch (usage) { case TRAIN: @@ -235,5 +233,9 @@ public BananaDetection build() { } return new BananaDetection(this); } + + MRL getMrl() { + return repository.dataset(Application.CV.ANY, groupId, artifactId, VERSION); + } } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/CocoDetection.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/CocoDetection.java index b2e12c473db..c70b261264b 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/CocoDetection.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/CocoDetection.java @@ -12,7 +12,7 @@ */ package ai.djl.basicdataset.cv; -import ai.djl.Application.CV; +import ai.djl.Application; import ai.djl.basicdataset.BasicDatasets; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; @@ -22,7 +22,6 @@ import ai.djl.repository.Artifact; import ai.djl.repository.MRL; import ai.djl.repository.Repository; -import ai.djl.repository.Resource; import ai.djl.translate.Pipeline; import ai.djl.util.PairList; import ai.djl.util.Progress; @@ -40,21 +39,21 @@ public class CocoDetection extends ObjectDetectionDataset { private static final String ARTIFACT_ID = "coco"; + private static final String VERSION = "1.0"; private Usage usage; private List imagePaths; private List> labels; - private Resource resource; + private MRL mrl; private boolean prepared; CocoDetection(Builder builder) { super(builder); usage = builder.usage; + mrl = builder.getMrl(); imagePaths = new ArrayList<>(); labels = new ArrayList<>(); - MRL mrl = MRL.dataset(CV.ANY, builder.groupId, builder.artifactId); - resource = new Resource(builder.repository, mrl, "1.0"); } /** @@ -78,9 +77,9 @@ public void prepare(Progress progress) throws IOException { return; } - Artifact artifact = resource.getDefaultArtifact(); - resource.prepare(artifact, progress); - Path root = resource.getRepository().getResourceDirectory(artifact); + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); + Path root = mrl.getRepository().getResourceDirectory(artifact); Path jsonFile; switch (usage) { @@ -233,5 +232,9 @@ public CocoDetection build() { } return new CocoDetection(this); } + + MRL getMrl() { + return repository.dataset(Application.CV.ANY, groupId, artifactId, VERSION); + } } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/ImageDataset.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/ImageDataset.java index 10c3411ab7d..97f25fb5d01 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/ImageDataset.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/ImageDataset.java @@ -26,6 +26,7 @@ * image. */ public abstract class ImageDataset extends RandomAccessDataset { + protected Image.Flag flag; /** diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/PikachuDetection.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/PikachuDetection.java index fa0f9f087db..fb5d676bf04 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/PikachuDetection.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/PikachuDetection.java @@ -22,7 +22,6 @@ import ai.djl.repository.Artifact; import ai.djl.repository.MRL; import ai.djl.repository.Repository; -import ai.djl.repository.Resource; import ai.djl.translate.Pipeline; import ai.djl.util.JsonUtils; import ai.djl.util.PairList; @@ -50,16 +49,15 @@ public class PikachuDetection extends ObjectDetectionDataset { private List imagePaths; private PairList labels; - private Resource resource; + private MRL mrl; private boolean prepared; protected PikachuDetection(Builder builder) { super(builder); usage = builder.usage; + mrl = builder.getMrl(); imagePaths = new ArrayList<>(); labels = new PairList<>(); - MRL mrl = MRL.dataset(CV.ANY, builder.groupId, builder.artifactId); - resource = new Resource(builder.repository, mrl, VERSION); } /** @@ -78,10 +76,10 @@ public void prepare(Progress progress) throws IOException { return; } - Artifact artifact = resource.getDefaultArtifact(); - resource.prepare(artifact, progress); + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); - Path root = resource.getRepository().getResourceDirectory(artifact); + Path root = mrl.getRepository().getResourceDirectory(artifact); Path usagePath; switch (usage) { case TRAIN: @@ -224,5 +222,9 @@ public PikachuDetection build() { } return new PikachuDetection(this); } + + MRL getMrl() { + return repository.dataset(CV.ANY, groupId, artifactId, VERSION); + } } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/AbstractImageFolder.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/AbstractImageFolder.java index 71ab711a4d1..fe8ccad10ca 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/AbstractImageFolder.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/AbstractImageFolder.java @@ -14,8 +14,9 @@ import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; +import ai.djl.repository.MRL; import ai.djl.repository.Repository; -import ai.djl.repository.Resource; +import ai.djl.repository.zoo.DefaultModelZoo; import ai.djl.translate.TranslateException; import ai.djl.util.Pair; import ai.djl.util.PairList; @@ -43,7 +44,7 @@ public abstract class AbstractImageFolder extends ImageClassificationDataset { protected List synset; protected PairList items; - protected Resource resource; + protected MRL mrl; protected boolean prepared; private int maxDepth; @@ -57,7 +58,8 @@ protected AbstractImageFolder(ImageFolderBuilder builder) { this.imageHeight = builder.imageHeight; this.synset = new ArrayList<>(); this.items = new PairList<>(); - this.resource = new Resource(builder.repository, null, "1.0"); + String path = builder.repository.getBaseUri().toString(); + mrl = MRL.undefined(builder.repository, DefaultModelZoo.GROUP_ID, path); } /** {@inheritDoc} */ diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/CaptchaDataset.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/CaptchaDataset.java index 75ee274b992..a6782324844 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/CaptchaDataset.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/CaptchaDataset.java @@ -23,7 +23,6 @@ import ai.djl.repository.Artifact; import ai.djl.repository.MRL; import ai.djl.repository.Repository; -import ai.djl.repository.Resource; import ai.djl.training.dataset.Dataset; import ai.djl.training.dataset.RandomAccessDataset; import ai.djl.training.dataset.Record; @@ -43,19 +42,20 @@ */ public class CaptchaDataset extends RandomAccessDataset { + private static final String ARTIFACT_ID = "captcha"; + private static final String VERSION = "1.1"; + public static final int IMAGE_WIDTH = 160; public static final int IMAGE_HEIGHT = 60; public static final int CAPTCHA_LENGTH = 6; public static final int CAPTCHA_OPTIONS = 11; - private static final String ARTIFACT_ID = "captcha"; - private Usage usage; private List items; private Artifact.Item dataItem; private String pathPrefix; - private Resource resource; + private MRL mrl; private boolean prepared; /** @@ -66,8 +66,7 @@ public class CaptchaDataset extends RandomAccessDataset { public CaptchaDataset(Builder builder) { super(builder); this.usage = builder.usage; - MRL mrl = MRL.dataset(CV.ANY, builder.groupId, builder.artifactId); - resource = new Resource(builder.repository, mrl, "1.1"); + mrl = builder.getMrl(); } /** @@ -83,8 +82,7 @@ public static CaptchaDataset.Builder builder() { @Override public Record get(NDManager manager, long index) throws IOException { String item = items.get(Math.toIntExact(index)); - Path imagePath = - resource.getRepository().getFile(dataItem, pathPrefix + '/' + item + ".jpeg"); + Path imagePath = mrl.getRepository().getFile(dataItem, pathPrefix + '/' + item + ".jpeg"); NDArray imageArray = ImageFactory.getInstance() .fromFile(imagePath) @@ -118,14 +116,14 @@ public void prepare(Progress progress) throws IOException { return; } - Artifact artifact = resource.getDefaultArtifact(); - resource.prepare(artifact, progress); + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); dataItem = artifact.getFiles().get("data"); pathPrefix = getUsagePath(); items = new ArrayList<>(); for (String filenameWithExtension : - resource.getRepository().listDirectory(dataItem, pathPrefix)) { + mrl.getRepository().listDirectory(dataItem, pathPrefix)) { String captchaFilename = filenameWithExtension.substring(0, filenameWithExtension.lastIndexOf('.')); items.add(captchaFilename); @@ -227,5 +225,9 @@ public Builder optUsage(Usage usage) { public CaptchaDataset build() { return new CaptchaDataset(this); } + + MRL getMrl() { + return repository.dataset(CV.ANY, groupId, artifactId, VERSION); + } } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Cifar10.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Cifar10.java index f2909650eb3..21d7a5bc0f6 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Cifar10.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Cifar10.java @@ -23,7 +23,6 @@ import ai.djl.repository.Artifact; import ai.djl.repository.MRL; import ai.djl.repository.Repository; -import ai.djl.repository.Resource; import ai.djl.training.dataset.ArrayDataset; import ai.djl.translate.Pipeline; import ai.djl.util.Progress; @@ -39,20 +38,22 @@ */ public final class Cifar10 extends ArrayDataset { + private static final String ARTIFACT_ID = "cifar10"; + private static final String VERSION = "1.0"; + public static final int IMAGE_WIDTH = 32; public static final int IMAGE_HEIGHT = 32; public static final float[] NORMALIZE_MEAN = {0.4914f, 0.4822f, 0.4465f}; public static final float[] NORMALIZE_STD = {0.2023f, 0.1994f, 0.2010f}; - private static final String ARTIFACT_ID = "cifar10"; // 3072 = 32 * 32 * 3, i.e. one image size, +1 here is label private static final int DATA_AND_LABEL_SIZE = IMAGE_HEIGHT * IMAGE_WIDTH * 3 + 1; private NDManager manager; private Usage usage; - private Resource resource; + private MRL mrl; private boolean prepared; Cifar10(Builder builder) { @@ -60,8 +61,7 @@ public final class Cifar10 extends ArrayDataset { this.manager = builder.manager; this.manager.setName("cifar10"); this.usage = builder.usage; - MRL mrl = MRL.dataset(CV.ANY, builder.groupId, builder.artifactId); - resource = new Resource(builder.repository, mrl, "1.0"); + mrl = builder.getMrl(); } /** @@ -80,8 +80,8 @@ public void prepare(Progress progress) throws IOException { return; } - Artifact artifact = resource.getDefaultArtifact(); - resource.prepare(artifact, progress); + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); Map map = artifact.getFiles(); Artifact.Item item; @@ -117,7 +117,7 @@ public void prepare(Progress progress) throws IOException { } private NDArray readData(Artifact.Item item) throws IOException { - try (InputStream is = resource.getRepository().openStream(item, null)) { + try (InputStream is = mrl.getRepository().openStream(item, null)) { byte[] buf = Utils.toByteArray(is); int length = buf.length / DATA_AND_LABEL_SIZE; try (NDArray array = @@ -223,5 +223,9 @@ public Builder optUsage(Usage usage) { public Cifar10 build() { return new Cifar10(this); } + + MRL getMrl() { + return repository.dataset(CV.ANY, groupId, artifactId, VERSION); + } } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java index 785eed16ded..abbad625b92 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/FashionMnist.java @@ -23,7 +23,6 @@ import ai.djl.repository.Artifact; import ai.djl.repository.MRL; import ai.djl.repository.Repository; -import ai.djl.repository.Resource; import ai.djl.training.dataset.ArrayDataset; import ai.djl.translate.Pipeline; import ai.djl.util.Progress; @@ -40,16 +39,17 @@ */ public final class FashionMnist extends ArrayDataset { + private static final String ARTIFACT_ID = "fashmnist"; + private static final String VERSION = "1.0"; + public static final int IMAGE_WIDTH = 28; public static final int IMAGE_HEIGHT = 28; public static final int NUM_CLASSES = 10; - private static final String ARTIFACT_ID = "fashmnist"; - private final NDManager manager; private final Usage usage; - private Resource resource; + private MRL mrl; private boolean prepared; /** @@ -62,8 +62,7 @@ private FashionMnist(FashionMnist.Builder builder) { this.manager = builder.manager; this.manager.setName("fashionmnist"); this.usage = builder.usage; - MRL mrl = MRL.dataset(CV.ANY, builder.groupId, builder.artifactId); - resource = new Resource(builder.repository, mrl, "1.0"); + mrl = builder.getMrl(); } /** @@ -82,8 +81,8 @@ public void prepare(Progress progress) throws IOException { return; } - Artifact artifact = resource.getDefaultArtifact(); - resource.prepare(artifact, progress); + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); Map map = artifact.getFiles(); Artifact.Item imageItem; @@ -108,7 +107,7 @@ public void prepare(Progress progress) throws IOException { } private NDArray readData(Artifact.Item item, long length) throws IOException { - try (InputStream is = resource.getRepository().openStream(item, null)) { + try (InputStream is = mrl.getRepository().openStream(item, null)) { if (is.skip(16) != 16) { throw new AssertionError("Failed skip data."); } @@ -124,7 +123,7 @@ private NDArray readData(Artifact.Item item, long length) throws IOException { } private NDArray readLabel(Artifact.Item item) throws IOException { - try (InputStream is = resource.getRepository().openStream(item, null)) { + try (InputStream is = mrl.getRepository().openStream(item, null)) { if (is.skip(8) != 8) { throw new AssertionError("Failed skip data."); } @@ -234,5 +233,9 @@ public FashionMnist build() { } return new FashionMnist(this); } + + MRL getMrl() { + return repository.dataset(CV.ANY, groupId, artifactId, VERSION); + } } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageFolder.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageFolder.java index 14fb448bc51..fa2251477d9 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageFolder.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageFolder.java @@ -63,9 +63,9 @@ protected Path getImagePath(String key) { @Override public void prepare(Progress progress) throws IOException { if (!prepared) { - resource.prepare(null, progress); + mrl.prepare(null, progress); loadSynset(); - Path root = Paths.get(resource.getRepository().getBaseUri()); + Path root = Paths.get(mrl.getRepository().getBaseUri()); if (progress != null) { progress.reset("Preparing", 2); progress.start(0); @@ -79,7 +79,7 @@ public void prepare(Progress progress) throws IOException { } private void loadSynset() { - File root = new File(resource.getRepository().getBaseUri()); + File root = new File(mrl.getRepository().getBaseUri()); File[] dir = root.listFiles(f -> f.isDirectory() && !f.getName().startsWith(".")); if (dir == null || dir.length == 0) { throw new IllegalArgumentException(root + " not found or didn't have any file in it"); diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageNet.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageNet.java index 10035c342e1..ef828033ad4 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageNet.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageNet.java @@ -42,7 +42,7 @@ public class ImageNet extends AbstractImageFolder { ImageNet(Builder builder) { super(builder); String usagePath = getUsagePath(builder.usage); - root = Paths.get(resource.getRepository().getBaseUri()).resolve(usagePath); + root = Paths.get(mrl.getRepository().getBaseUri()).resolve(usagePath); } /** @@ -85,7 +85,7 @@ public String[] getClassFull() { @Override public void prepare(Progress progress) throws IOException { if (!prepared) { - resource.prepare(null, progress); + mrl.prepare(null, progress); if (progress != null) { progress.reset("Preparing", 2); diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java index ac15f16157f..257fd6bc44d 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/Mnist.java @@ -23,7 +23,6 @@ import ai.djl.repository.Artifact; import ai.djl.repository.MRL; import ai.djl.repository.Repository; -import ai.djl.repository.Resource; import ai.djl.training.dataset.ArrayDataset; import ai.djl.translate.Pipeline; import ai.djl.util.Progress; @@ -39,16 +38,17 @@ */ public final class Mnist extends ArrayDataset { + private static final String ARTIFACT_ID = "mnist"; + private static final String VERSION = "1.0"; + public static final int IMAGE_WIDTH = 28; public static final int IMAGE_HEIGHT = 28; public static final int NUM_CLASSES = 10; - private static final String ARTIFACT_ID = "mnist"; - private NDManager manager; private Usage usage; - private Resource resource; + private MRL mrl; private boolean prepared; private Mnist(Builder builder) { @@ -56,8 +56,7 @@ private Mnist(Builder builder) { this.manager = builder.manager; this.manager.setName("mnist"); this.usage = builder.usage; - MRL mrl = MRL.dataset(CV.ANY, builder.groupId, builder.artifactId); - resource = new Resource(builder.repository, mrl, "1.0"); + mrl = builder.getMrl(); } /** @@ -76,8 +75,8 @@ public void prepare(Progress progress) throws IOException { return; } - Artifact artifact = resource.getDefaultArtifact(); - resource.prepare(artifact, progress); + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); Map map = artifact.getFiles(); Artifact.Item imageItem; @@ -101,7 +100,7 @@ public void prepare(Progress progress) throws IOException { } private NDArray readData(Artifact.Item item, long length) throws IOException { - try (InputStream is = resource.getRepository().openStream(item, null)) { + try (InputStream is = mrl.getRepository().openStream(item, null)) { if (is.skip(16) != 16) { throw new AssertionError("Failed skip data."); } @@ -115,7 +114,7 @@ private NDArray readData(Artifact.Item item, long length) throws IOException { } private NDArray readLabel(Artifact.Item item) throws IOException { - try (InputStream is = resource.getRepository().openStream(item, null)) { + try (InputStream is = mrl.getRepository().openStream(item, null)) { if (is.skip(8) != 8) { throw new AssertionError("Failed skip data."); } @@ -223,5 +222,9 @@ public Builder optUsage(Usage usage) { public Mnist build() { return new Mnist(this); } + + MRL getMrl() { + return repository.dataset(CV.ANY, groupId, artifactId, VERSION); + } } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/AmazonReview.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/AmazonReview.java index f16c6361f0c..07ca83598fa 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/AmazonReview.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/AmazonReview.java @@ -12,13 +12,12 @@ */ package ai.djl.basicdataset.nlp; -import ai.djl.Application; +import ai.djl.Application.NLP; import ai.djl.basicdataset.BasicDatasets; import ai.djl.basicdataset.tabular.CsvDataset; import ai.djl.repository.Artifact; import ai.djl.repository.MRL; import ai.djl.repository.Repository; -import ai.djl.repository.Resource; import ai.djl.util.Progress; import java.io.IOException; import java.nio.file.Path; @@ -35,7 +34,7 @@ public class AmazonReview extends CsvDataset { private static final String VERSION = "1.0"; private static final String ARTIFACT_ID = "amazon_reviews"; - private Resource resource; + private MRL mrl; private String datasetName; private boolean prepared; @@ -46,8 +45,7 @@ public class AmazonReview extends CsvDataset { */ protected AmazonReview(Builder builder) { super(builder); - MRL mrl = MRL.dataset(Application.NLP.ANY, builder.groupId, builder.artifactId); - resource = new Resource(builder.repository, mrl, VERSION); + mrl = builder.getMrl(); datasetName = builder.datasetName; } @@ -60,10 +58,10 @@ public void prepare(Progress progress) throws IOException { Map filter = new ConcurrentHashMap<>(); filter.put("dataset", datasetName); - Artifact artifact = resource.match(filter); - resource.prepare(artifact, progress); + Artifact artifact = mrl.match(filter); + mrl.prepare(artifact, progress); - Path dir = resource.getRepository().getResourceDirectory(artifact); + Path dir = mrl.getRepository().getResourceDirectory(artifact); Path csvFile = dir.resolve(artifact.getFiles().values().iterator().next().getName()); csvUrl = csvFile.toUri().toURL(); super.prepare(progress); @@ -163,5 +161,9 @@ public AmazonReview build() { } return new AmazonReview(this); } + + MRL getMrl() { + return repository.dataset(NLP.ANY, groupId, artifactId, VERSION); + } } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/CookingStackExchange.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/CookingStackExchange.java index b49d8942349..ef394706bde 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/CookingStackExchange.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/CookingStackExchange.java @@ -19,7 +19,6 @@ import ai.djl.repository.Artifact; import ai.djl.repository.MRL; import ai.djl.repository.Repository; -import ai.djl.repository.Resource; import ai.djl.training.dataset.Batch; import ai.djl.training.dataset.Dataset; import ai.djl.util.Progress; @@ -33,17 +32,17 @@ public class CookingStackExchange implements RawDataset { private static final String ARTIFACT_ID = "cooking_stackexchange"; + private static final String VERSION = "1.0"; private Dataset.Usage usage; private Path root; - private Resource resource; + private MRL mrl; private boolean prepared; CookingStackExchange(Builder builder) { this.usage = builder.usage; - MRL mrl = MRL.dataset(NLP.ANY, builder.groupId, builder.artifactId); - resource = new Resource(builder.repository, mrl, "1.0"); + mrl = builder.getMrl(); } /** {@inheritDoc} */ @@ -66,8 +65,8 @@ public void prepare(Progress progress) throws IOException { return; } - Artifact artifact = resource.getDefaultArtifact(); - resource.prepare(artifact, progress); + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); Artifact.Item item; switch (usage) { @@ -81,7 +80,7 @@ public void prepare(Progress progress) throws IOException { default: throw new IOException("Only training and testing dataset supported."); } - root = resource.getRepository().getFile(item, "").toAbsolutePath(); + root = mrl.getRepository().getFile(item, "").toAbsolutePath(); prepared = true; } @@ -168,5 +167,9 @@ public Builder optUsage(Dataset.Usage usage) { public CookingStackExchange build() { return new CookingStackExchange(this); } + + MRL getMrl() { + return repository.dataset(NLP.ANY, groupId, artifactId, VERSION); + } } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordMovieReview.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordMovieReview.java index 4d092cff40c..273d76b4534 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordMovieReview.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/StanfordMovieReview.java @@ -19,7 +19,6 @@ import ai.djl.ndarray.types.DataType; import ai.djl.repository.Artifact; import ai.djl.repository.MRL; -import ai.djl.repository.Resource; import ai.djl.training.dataset.Record; import ai.djl.util.Progress; import java.io.File; @@ -56,8 +55,7 @@ public class StanfordMovieReview extends TextDataset { protected StanfordMovieReview(Builder builder) { super(builder); this.usage = builder.usage; - MRL mrl = MRL.dataset(NLP.ANY, builder.groupId, builder.artifactId); - resource = new Resource(builder.repository, mrl, VERSION); + mrl = builder.getMrl(); } /** @@ -75,9 +73,9 @@ public void prepare(Progress progress) throws IOException, EmbeddingException { if (prepared) { return; } - Artifact artifact = resource.getDefaultArtifact(); - resource.prepare(artifact, progress); - Path cacheDir = resource.getRepository().getCacheDirectory(); + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); + Path cacheDir = mrl.getRepository().getCacheDirectory(); URI resourceUri = artifact.getResourceUri(); Path root = cacheDir.resolve(resourceUri.getPath()).resolve("aclImdb").resolve("aclImdb"); @@ -167,5 +165,9 @@ protected Builder self() { public StanfordMovieReview build() { return new StanfordMovieReview(this); } + + MRL getMrl() { + return repository.dataset(NLP.ANY, groupId, artifactId, VERSION); + } } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.java index dbc96114969..e59f4bbf177 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.java @@ -18,7 +18,6 @@ import ai.djl.ndarray.NDManager; import ai.djl.repository.Artifact; import ai.djl.repository.MRL; -import ai.djl.repository.Resource; import ai.djl.training.dataset.Record; import ai.djl.util.Progress; import java.io.BufferedReader; @@ -46,8 +45,7 @@ public class TatoebaEnglishFrenchDataset extends TextDataset { protected TatoebaEnglishFrenchDataset(Builder builder) { super(builder); this.usage = builder.usage; - MRL mrl = MRL.dataset(NLP.ANY, builder.groupId, builder.artifactId); - resource = new Resource(builder.repository, mrl, VERSION); + mrl = builder.getMrl(); } /** @@ -66,9 +64,9 @@ public void prepare(Progress progress) throws IOException, EmbeddingException { return; } - Artifact artifact = resource.getDefaultArtifact(); - resource.prepare(artifact, progress); - Path root = resource.getRepository().getResourceDirectory(artifact); + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); + Path root = mrl.getRepository().getResourceDirectory(artifact); Path usagePath; switch (usage) { @@ -139,5 +137,9 @@ public Builder self() { public TatoebaEnglishFrenchDataset build() { return new TatoebaEnglishFrenchDataset(this); } + + MRL getMrl() { + return repository.dataset(NLP.ANY, groupId, artifactId, VERSION); + } } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TextDataset.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TextDataset.java index a97f459a3d9..196900befb4 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TextDataset.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TextDataset.java @@ -22,8 +22,8 @@ import ai.djl.modality.nlp.embedding.TextEmbedding; import ai.djl.modality.nlp.embedding.TrainableWordEmbedding; import ai.djl.ndarray.NDManager; +import ai.djl.repository.MRL; import ai.djl.repository.Repository; -import ai.djl.repository.Resource; import ai.djl.training.dataset.RandomAccessDataset; import java.util.ArrayList; import java.util.Comparator; @@ -46,7 +46,7 @@ public abstract class TextDataset extends RandomAccessDataset { protected NDManager manager; protected Usage usage; - protected Resource resource; + protected MRL mrl; protected boolean prepared; protected List samples; diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/tabular/AirfoilRandomAccess.java b/basicdataset/src/main/java/ai/djl/basicdataset/tabular/AirfoilRandomAccess.java index 482255d9a97..796e1ac2709 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/tabular/AirfoilRandomAccess.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/tabular/AirfoilRandomAccess.java @@ -12,7 +12,7 @@ */ package ai.djl.basicdataset.tabular; -import ai.djl.Application; +import ai.djl.Application.Tabular; import ai.djl.basicdataset.BasicDatasets; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; @@ -20,7 +20,6 @@ import ai.djl.repository.Artifact; import ai.djl.repository.MRL; import ai.djl.repository.Repository; -import ai.djl.repository.Resource; import ai.djl.util.Progress; import java.io.IOException; import java.nio.ByteBuffer; @@ -41,12 +40,13 @@ public final class AirfoilRandomAccess extends CsvDataset { private static final String ARTIFACT_ID = "airfoil"; + private static final String VERSION = "1.0"; private static final String[] COLUMNS = { "freq", "aoa", "chordlen", "freestreamvel", "ssdt", "ssoundpres" }; - private Resource resource; + private MRL mrl; private Usage usage; private boolean prepared; @@ -61,9 +61,8 @@ public final class AirfoilRandomAccess extends CsvDataset { */ AirfoilRandomAccess(Builder builder) { super(builder); - MRL mrl = MRL.dataset(Application.Tabular.ANY, builder.groupId, builder.artifactId); - resource = new Resource(builder.repository, mrl, "1.0"); usage = builder.usage; + mrl = builder.getMrl(); normalize = builder.normalize; } @@ -74,10 +73,10 @@ public void prepare(Progress progress) throws IOException { return; } - Artifact artifact = resource.getDefaultArtifact(); - resource.prepare(artifact); + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact); - Path root = resource.getRepository().getResourceDirectory(artifact); + Path root = mrl.getRepository().getResourceDirectory(artifact); Path csvFile; switch (usage) { case TRAIN: @@ -282,5 +281,9 @@ public AirfoilRandomAccess build() { } return new AirfoilRandomAccess(this); } + + MRL getMrl() { + return repository.dataset(Tabular.ANY, groupId, artifactId, VERSION); + } } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/tabular/AmesRandomAccess.java b/basicdataset/src/main/java/ai/djl/basicdataset/tabular/AmesRandomAccess.java index 43b6d76bd76..ded11c35856 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/tabular/AmesRandomAccess.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/tabular/AmesRandomAccess.java @@ -17,7 +17,6 @@ import ai.djl.repository.Artifact; import ai.djl.repository.MRL; import ai.djl.repository.Repository; -import ai.djl.repository.Resource; import ai.djl.util.JsonUtils; import ai.djl.util.Progress; import java.io.IOException; @@ -46,16 +45,16 @@ public class AmesRandomAccess extends CsvDataset { private static final String ARTIFACT_ID = "ames"; + private static final String VERSION = "1.0"; private Usage usage; - private Resource resource; + private MRL mrl; private boolean prepared; AmesRandomAccess(Builder builder) { super(builder); usage = builder.usage; - MRL mrl = MRL.dataset(Tabular.ANY, builder.groupId, builder.artifactId); - resource = new Resource(builder.repository, mrl, "1.0"); + mrl = builder.getMrl(); } /** {@inheritDoc} */ @@ -65,10 +64,10 @@ public void prepare(Progress progress) throws IOException { return; } - Artifact artifact = resource.getDefaultArtifact(); - resource.prepare(artifact, progress); + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); - Path dir = resource.getRepository().getResourceDirectory(artifact); + Path dir = mrl.getRepository().getResourceDirectory(artifact); Path root = dir.resolve("house-prices-advanced-regression-techniques"); Path csvFile; switch (usage) { @@ -240,6 +239,10 @@ private void parseFeatures() { } } } + + MRL getMrl() { + return repository.dataset(Tabular.ANY, groupId, artifactId, VERSION); + } } private static final class AmesFeatures { diff --git a/dlr/dlr-engine/src/main/java/ai/djl/dlr/zoo/DlrModelZoo.java b/dlr/dlr-engine/src/main/java/ai/djl/dlr/zoo/DlrModelZoo.java index d4f4f25760c..21366936943 100644 --- a/dlr/dlr-engine/src/main/java/ai/djl/dlr/zoo/DlrModelZoo.java +++ b/dlr/dlr-engine/src/main/java/ai/djl/dlr/zoo/DlrModelZoo.java @@ -35,8 +35,8 @@ public class DlrModelZoo implements ModelZoo { private static final List MODEL_LOADERS = new ArrayList<>(); static { - MRL resnet = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, resnet, "0.0.1", ZOO)); + MRL resnet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(resnet, ZOO)); } /** {@inheritDoc} */ diff --git a/extensions/aws-ai/src/main/java/ai/djl/aws/s3/S3Repository.java b/extensions/aws-ai/src/main/java/ai/djl/aws/s3/S3Repository.java index 644ea606753..3d623bd1d8c 100644 --- a/extensions/aws-ai/src/main/java/ai/djl/aws/s3/S3Repository.java +++ b/extensions/aws-ai/src/main/java/ai/djl/aws/s3/S3Repository.java @@ -102,8 +102,7 @@ public Metadata locate(MRL mrl) throws IOException { /** {@inheritDoc} */ @Override - public Artifact resolve(MRL mrl, String version, Map filter) - throws IOException { + public Artifact resolve(MRL mrl, Map filter) throws IOException { Metadata m = locate(mrl); if (m == null) { return null; @@ -133,7 +132,7 @@ public List getResources() { try { Metadata m = getMetadata(); if (m != null && !m.getArtifacts().isEmpty()) { - MRL mrl = MRL.model(Application.UNDEFINED, m.getGroupId(), m.getArtifactId()); + MRL mrl = model(Application.UNDEFINED, m.getGroupId(), m.getArtifactId()); return Collections.singletonList(mrl); } } catch (IOException e) { @@ -157,7 +156,7 @@ private synchronized Metadata getMetadata() throws IOException { metadata = new Metadata.MatchAllMetadata(); String hash = md5hash("s3://" + bucket + '/' + prefix); - MRL mrl = MRL.model(Application.UNDEFINED, DefaultModelZoo.GROUP_ID, hash); + MRL mrl = model(Application.UNDEFINED, DefaultModelZoo.GROUP_ID, hash); metadata.setRepositoryUri(mrl.toURI()); metadata.setArtifactId(artifactId); metadata.setArtifacts(Collections.singletonList(artifact)); diff --git a/extensions/aws-ai/src/test/java/ai/djl/aws/s3/S3RepositoryTest.java b/extensions/aws-ai/src/test/java/ai/djl/aws/s3/S3RepositoryTest.java index 04828bd15a4..6d08baf3693 100644 --- a/extensions/aws-ai/src/test/java/ai/djl/aws/s3/S3RepositoryTest.java +++ b/extensions/aws-ai/src/test/java/ai/djl/aws/s3/S3RepositoryTest.java @@ -107,12 +107,12 @@ public void testS3Repository() throws IOException { Assert.assertEquals( url.toString(), "mlrepo/model/cv/image_classification/ai/djl/mxnet/mlp/0.0.1/"); - MRL mrl = MRL.model(Application.UNDEFINED, "ai.djl.localmodelzoo", "mlp"); - Artifact artifact = repository.resolve(mrl, "0.0.1", null); + MRL mrl = repository.model(Application.UNDEFINED, "ai.djl.localmodelzoo", "mlp"); + Artifact artifact = repository.resolve(mrl, null); Assert.assertNotNull(artifact); repository = Repository.newInstance("s3", "s3://djl-ai/non-exists"); - artifact = repository.resolve(mrl, "0.0.1", null); + artifact = repository.resolve(mrl, null); Assert.assertNull(artifact); } diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java index e2b527f04b3..95aa87cb9ce 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java @@ -18,7 +18,6 @@ import ai.djl.fasttext.FtModel; import ai.djl.fasttext.zoo.FtModelZoo; import ai.djl.repository.Artifact; -import ai.djl.repository.MRL; import ai.djl.repository.Repository; import ai.djl.repository.zoo.BaseModelLoader; import ai.djl.repository.zoo.Criteria; @@ -42,20 +41,20 @@ public class TextClassificationModelLoader extends BaseModelLoader { * @param repository the repository to load the model from */ public TextClassificationModelLoader(Repository repository) { - super(repository, MRL.model(APPLICATION, GROUP_ID, ARTIFACT_ID), VERSION, new FtModelZoo()); + super(repository.model(APPLICATION, GROUP_ID, ARTIFACT_ID, VERSION), new FtModelZoo()); } /** {@inheritDoc} */ @Override public ZooModel loadModel(Criteria criteria) throws ModelNotFoundException, IOException, MalformedModelException { - Artifact artifact = resource.match(criteria.getFilters()); + Artifact artifact = mrl.match(criteria.getFilters()); if (artifact == null) { throw new ModelNotFoundException("No matching filter found"); } Progress progress = criteria.getProgress(); - resource.prepare(artifact, progress); + mrl.prepare(artifact, progress); if (progress != null) { progress.reset("Loading", 2); progress.update(1); @@ -65,7 +64,7 @@ public ZooModel loadModel(Criteria criteria) modelName = artifact.getName(); } Model model = new FtModel(modelName); - Path modelPath = resource.getRepository().getResourceDirectory(artifact); + Path modelPath = mrl.getRepository().getResourceDirectory(artifact); model.load(modelPath); return new ZooModel<>(model, null); } diff --git a/extensions/hadoop/src/main/java/ai/djl/hadoop/hdfs/HdfsRepository.java b/extensions/hadoop/src/main/java/ai/djl/hadoop/hdfs/HdfsRepository.java index a1cade88e92..cf3c9f295a9 100644 --- a/extensions/hadoop/src/main/java/ai/djl/hadoop/hdfs/HdfsRepository.java +++ b/extensions/hadoop/src/main/java/ai/djl/hadoop/hdfs/HdfsRepository.java @@ -98,8 +98,7 @@ public Metadata locate(MRL mrl) throws IOException { /** {@inheritDoc} */ @Override - public Artifact resolve(MRL mrl, String version, Map filter) - throws IOException { + public Artifact resolve(MRL mrl, Map filter) throws IOException { Metadata m = locate(mrl); if (m == null) { return null; @@ -129,7 +128,7 @@ public List getResources() { try { Metadata m = getMetadata(); if (m != null && !m.getArtifacts().isEmpty()) { - MRL mrl = MRL.model(Application.UNDEFINED, m.getGroupId(), m.getArtifactId()); + MRL mrl = model(Application.UNDEFINED, m.getGroupId(), m.getArtifactId()); return Collections.singletonList(mrl); } } catch (IOException e) { @@ -152,7 +151,7 @@ private synchronized Metadata getMetadata() throws IOException { metadata = new Metadata.MatchAllMetadata(); String hash = md5hash(uri.resolve(prefix).toString()); - MRL mrl = MRL.model(Application.UNDEFINED, DefaultModelZoo.GROUP_ID, hash); + MRL mrl = model(Application.UNDEFINED, DefaultModelZoo.GROUP_ID, hash); metadata.setRepositoryUri(mrl.toURI()); metadata.setArtifactId(artifactId); metadata.setArtifacts(Collections.singletonList(artifact)); diff --git a/extensions/hadoop/src/test/java/ai/djl/hadoop/hdfs/HdfsRepositoryTest.java b/extensions/hadoop/src/test/java/ai/djl/hadoop/hdfs/HdfsRepositoryTest.java index 380549edc82..14046e67c96 100644 --- a/extensions/hadoop/src/test/java/ai/djl/hadoop/hdfs/HdfsRepositoryTest.java +++ b/extensions/hadoop/src/test/java/ai/djl/hadoop/hdfs/HdfsRepositoryTest.java @@ -87,7 +87,7 @@ public void testZipFile() throws IOException { List list = repo.getResources(); Assert.assertFalse(list.isEmpty()); - Artifact artifact = repo.resolve(list.get(0), "1.0", null); + Artifact artifact = repo.resolve(list.get(0), null); repo.prepare(artifact); } @@ -98,7 +98,7 @@ public void testDir() throws IOException { List list = repo.getResources(); Assert.assertFalse(list.isEmpty()); - Artifact artifact = repo.resolve(list.get(0), "1.0", null); + Artifact artifact = repo.resolve(list.get(0), null); repo.prepare(artifact); Assert.assertTrue(repo.isRemote()); @@ -119,8 +119,8 @@ public void testAccessDeny() throws IOException { List list = repo.getResources(); Assert.assertTrue(list.isEmpty()); - MRL mrl = MRL.model(Application.UNDEFINED, "ai.djl.localmodelzoo", "mlp"); - Artifact artifact = repo.resolve(mrl, "0.0.1", null); + MRL mrl = repo.model(Application.UNDEFINED, "ai.djl.localmodelzoo", "mlp"); + Artifact artifact = repo.resolve(mrl, null); Assert.assertNull(artifact); } diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java index 0cb559205bb..70a5fba4e84 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/BasicModelZoo.java @@ -34,14 +34,14 @@ public class BasicModelZoo implements ModelZoo { private static final List MODEL_LOADERS = new ArrayList<>(); static { - MRL mlp = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "mlp"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, mlp, "0.0.3", ZOO)); + MRL mlp = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "mlp", "0.0.3"); + MODEL_LOADERS.add(new BaseModelLoader(mlp, ZOO)); - MRL resnet = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, resnet, "0.0.2", ZOO)); + MRL resnet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.2"); + MODEL_LOADERS.add(new BaseModelLoader(resnet, ZOO)); - MRL ssd = MRL.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, ssd, "0.0.2", ZOO)); + MRL ssd = REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd", "0.0.2"); + MODEL_LOADERS.add(new BaseModelLoader(ssd, ZOO)); } /** {@inheritDoc} */ diff --git a/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/MxModelZoo.java b/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/MxModelZoo.java index 7f58ab8c7e2..21dc015eb22 100644 --- a/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/MxModelZoo.java +++ b/mxnet/mxnet-model-zoo/src/main/java/ai/djl/mxnet/zoo/MxModelZoo.java @@ -39,68 +39,70 @@ public class MxModelZoo implements ModelZoo { private static final List MODEL_LOADERS = new ArrayList<>(); static { - MRL ssd = MRL.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, ssd, "0.0.1", ZOO)); + MRL ssd = REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(ssd, ZOO)); - MRL yolo = MRL.model(CV.OBJECT_DETECTION, GROUP_ID, "yolo"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, yolo, "0.0.1", ZOO)); + MRL yolo = REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolo", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(yolo, ZOO)); - MRL alexnet = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "alexnet"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, alexnet, "0.0.1", ZOO)); + MRL alexnet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "alexnet", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(alexnet, ZOO)); - MRL darknet = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "darknet"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, darknet, "0.0.1", ZOO)); + MRL darknet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "darknet", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(darknet, ZOO)); - MRL densenet = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "densenet"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, densenet, "0.0.1", ZOO)); + MRL densenet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "densenet", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(densenet, ZOO)); - MRL googlenet = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "googlenet"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, googlenet, "0.0.1", ZOO)); + MRL googlenet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "googlenet", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(googlenet, ZOO)); - MRL inceptionv3 = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "inceptionv3"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, inceptionv3, "0.0.1", ZOO)); + MRL inceptionv3 = + REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "inceptionv3", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(inceptionv3, ZOO)); - MRL mlp = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "mlp"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, mlp, "0.0.1", ZOO)); + MRL mlp = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "mlp", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(mlp, ZOO)); - MRL mobilenet = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "mobilenet"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, mobilenet, "0.0.1", ZOO)); + MRL mobilenet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "mobilenet", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(mobilenet, ZOO)); - MRL resnest = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnest"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, resnest, "0.0.1", ZOO)); + MRL resnest = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnest", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(resnest, ZOO)); - MRL resnet = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, resnet, "0.0.1", ZOO)); + MRL resnet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(resnet, ZOO)); - MRL senet = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "senet"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, senet, "0.0.1", ZOO)); + MRL senet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "senet", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(senet, ZOO)); - MRL seresnext = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "se_resnext"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, seresnext, "0.0.1", ZOO)); + MRL seresnext = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "se_resnext", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(seresnext, ZOO)); - MRL squeezenet = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "squeezenet"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, squeezenet, "0.0.1", ZOO)); + MRL squeezenet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "squeezenet", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(squeezenet, ZOO)); - MRL vgg = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "vgg"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, vgg, "0.0.1", ZOO)); + MRL vgg = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "vgg", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(vgg, ZOO)); - MRL xception = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "xception"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, xception, "0.0.1", ZOO)); + MRL xception = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "xception", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(xception, ZOO)); - MRL simplePose = MRL.model(CV.POSE_ESTIMATION, GROUP_ID, "simple_pose"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, simplePose, "0.0.1", ZOO)); + MRL simplePose = REPOSITORY.model(CV.POSE_ESTIMATION, GROUP_ID, "simple_pose", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(simplePose, ZOO)); - MRL maskrcnn = MRL.model(CV.INSTANCE_SEGMENTATION, GROUP_ID, "mask_rcnn"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, maskrcnn, "0.0.1", ZOO)); + MRL maskrcnn = REPOSITORY.model(CV.INSTANCE_SEGMENTATION, GROUP_ID, "mask_rcnn", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(maskrcnn, ZOO)); - MRL actionRecognition = MRL.model(CV.ACTION_RECOGNITION, GROUP_ID, "action_recognition"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, actionRecognition, "0.0.1", ZOO)); + MRL actionRecognition = + REPOSITORY.model(CV.ACTION_RECOGNITION, GROUP_ID, "action_recognition", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(actionRecognition, ZOO)); - MRL bertQa = MRL.model(NLP.QUESTION_ANSWER, GROUP_ID, "bertqa"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, bertQa, "0.0.1", ZOO)); + MRL bertQa = REPOSITORY.model(NLP.QUESTION_ANSWER, GROUP_ID, "bertqa", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(bertQa, ZOO)); - MRL glove = MRL.model(NLP.WORD_EMBEDDING, GROUP_ID, "glove"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, glove, "0.0.2", ZOO)); + MRL glove = REPOSITORY.model(NLP.WORD_EMBEDDING, GROUP_ID, "glove", "0.0.2"); + MODEL_LOADERS.add(new BaseModelLoader(glove, ZOO)); } /** {@inheritDoc} */ diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java index 109284832f4..a5aff0cc714 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java @@ -35,8 +35,9 @@ public class OrtModelZoo implements ModelZoo { private static final List MODEL_LOADERS = new ArrayList<>(); static { - MRL irisFlower = MRL.model(Tabular.SOFTMAX_REGRESSION, GROUP_ID, "iris_flowers"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, irisFlower, "0.0.1", ZOO)); + MRL irisFlower = + REPOSITORY.model(Tabular.SOFTMAX_REGRESSION, GROUP_ID, "iris_flowers", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(irisFlower, ZOO)); } /** {@inheritDoc} */ diff --git a/paddlepaddle/paddlepaddle-model-zoo/src/main/java/ai/djl/paddlepaddle/zoo/PpModelZoo.java b/paddlepaddle/paddlepaddle-model-zoo/src/main/java/ai/djl/paddlepaddle/zoo/PpModelZoo.java index f329112a60d..ffbcdf37b2a 100644 --- a/paddlepaddle/paddlepaddle-model-zoo/src/main/java/ai/djl/paddlepaddle/zoo/PpModelZoo.java +++ b/paddlepaddle/paddlepaddle-model-zoo/src/main/java/ai/djl/paddlepaddle/zoo/PpModelZoo.java @@ -35,20 +35,25 @@ public class PpModelZoo implements ModelZoo { private static final List MODEL_LOADERS = new ArrayList<>(); static { - MRL maskDetection = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "mask_classification"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, maskDetection, "0.0.1", ZOO)); + MRL maskDetection = + REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "mask_classification", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(maskDetection, ZOO)); - MRL wordRotation = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "word_rotation"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, wordRotation, "0.0.1", ZOO)); + MRL wordRotation = + REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "word_rotation", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(wordRotation, ZOO)); - MRL faceDetection = MRL.model(CV.OBJECT_DETECTION, GROUP_ID, "face_detection"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, faceDetection, "0.0.1", ZOO)); + MRL faceDetection = + REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "face_detection", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(faceDetection, ZOO)); - MRL wordDetection = MRL.model(CV.OBJECT_DETECTION, GROUP_ID, "word_detection"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, wordDetection, "0.0.1", ZOO)); + MRL wordDetection = + REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "word_detection", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(wordDetection, ZOO)); - MRL wordRecognition = MRL.model(CV.WORD_RECOGNITION, GROUP_ID, "word_recognition"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, wordRecognition, "0.0.1", ZOO)); + MRL wordRecognition = + REPOSITORY.model(CV.WORD_RECOGNITION, GROUP_ID, "word_recognition", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(wordRecognition, ZOO)); } /** {@inheritDoc} */ diff --git a/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java b/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java index 41c91c15c05..3f55c8ba6a1 100644 --- a/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java +++ b/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java @@ -39,20 +39,21 @@ public class PtModelZoo implements ModelZoo { private static final List MODEL_LOADERS = new ArrayList<>(); static { - MRL resnet = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, resnet, "0.0.1", ZOO)); + MRL resnet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(resnet, ZOO)); - MRL ssd = MRL.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, ssd, "0.0.1", ZOO)); + MRL ssd = REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(ssd, ZOO)); - MRL bertQa = MRL.model(NLP.QUESTION_ANSWER, GROUP_ID, "bertqa"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, bertQa, "0.0.1", ZOO)); + MRL bertQa = REPOSITORY.model(NLP.QUESTION_ANSWER, GROUP_ID, "bertqa", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(bertQa, ZOO)); - MRL sentimentAnalysis = MRL.model(NLP.SENTIMENT_ANALYSIS, GROUP_ID, "distilbert"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, sentimentAnalysis, "0.0.1", ZOO)); + MRL sentimentAnalysis = + REPOSITORY.model(NLP.SENTIMENT_ANALYSIS, GROUP_ID, "distilbert", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(sentimentAnalysis, ZOO)); - MRL bigGan = MRL.model(CV.IMAGE_GENERATION, GROUP_ID, "biggan-deep"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, bigGan, "0.0.1", ZOO)); + MRL bigGan = REPOSITORY.model(CV.IMAGE_GENERATION, GROUP_ID, "biggan-deep", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(bigGan, ZOO)); } /** {@inheritDoc} */ diff --git a/tensorflow/tensorflow-model-zoo/src/main/java/ai/djl/tensorflow/zoo/TfModelZoo.java b/tensorflow/tensorflow-model-zoo/src/main/java/ai/djl/tensorflow/zoo/TfModelZoo.java index 2b104af7b4d..e3afad60af9 100644 --- a/tensorflow/tensorflow-model-zoo/src/main/java/ai/djl/tensorflow/zoo/TfModelZoo.java +++ b/tensorflow/tensorflow-model-zoo/src/main/java/ai/djl/tensorflow/zoo/TfModelZoo.java @@ -36,14 +36,14 @@ public class TfModelZoo implements ModelZoo { private static final List MODEL_LOADERS = new ArrayList<>(); static { - MRL resnet = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, resnet, "0.0.1", ZOO)); + MRL resnet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(resnet, ZOO)); - MRL mobilenet = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "mobilenet"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, mobilenet, "0.0.1", ZOO)); + MRL mobilenet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "mobilenet", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(mobilenet, ZOO)); - MRL ssd = MRL.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, ssd, "0.0.1", ZOO)); + MRL ssd = REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(ssd, ZOO)); } /** {@inheritDoc} */ diff --git a/tflite/tflite-engine/src/main/java/ai/djl/tflite/zoo/TfLiteModelZoo.java b/tflite/tflite-engine/src/main/java/ai/djl/tflite/zoo/TfLiteModelZoo.java index 7a956fa5689..9f3f7e638ba 100644 --- a/tflite/tflite-engine/src/main/java/ai/djl/tflite/zoo/TfLiteModelZoo.java +++ b/tflite/tflite-engine/src/main/java/ai/djl/tflite/zoo/TfLiteModelZoo.java @@ -35,8 +35,8 @@ public class TfLiteModelZoo implements ModelZoo { private static final List MODEL_LOADERS = new ArrayList<>(); static { - MRL mobilenet = MRL.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "mobilenet"); - MODEL_LOADERS.add(new BaseModelLoader(REPOSITORY, mobilenet, "0.0.1", ZOO)); + MRL mobilenet = REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "mobilenet", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(mobilenet, ZOO)); } /** {@inheritDoc} */