From 7ceb7cd3a95d89a4672eedd14cd6c53341948c9c Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Fri, 18 Nov 2022 13:09:57 -0800 Subject: [PATCH] Create translator options (#2145) * Create translator options This creates a new tool for TranslatorOptions based on the characteristic function ((InputClass, OutputClass) -> Translator). Unlike the TranslatorFactory, it doesn't require arguments or a model. So, the TranslatorFactory is mainly used as part of model loading, whereas the TranslatorOptions can be used more broadly. In particular, it connects to both Dataset and Translator. Both of them gain a new function that will return a TranslatorOptions (matchingTranslatorOptions for the Dataset and getExpansions for the translator). Right now, I made both of these optional with a default value of null so that users don't have to define them; defining them is just a possibility. To connect the TranslatorOptions and TranslatorFactory, I created a class ExpansionTranslatorFactory. This is used for the common abstraction where we have a base translator and multiple "expansions" for it that come before it in pre-processing or after it in post-processing. With a way to construct the base translator with the options, it works as a TranslatorFactory. With a starting Translator, it can become the TranslatorOptions. Lastly, and probably worth calling out, the ExpansionTranslatorFactory creates a map of the options. This allows the supported types and the functions to be defined in only one place instead of two. In a follow-up PR, I will make a version of this that creates a BaseImageTranslatorFactory matching the BaseImageTranslator and will handle all of the repeated code for all of the image translator factories. 
* minor javadoc updates Co-authored-by: Frank Liu --- .../ImageClassificationTranslator.java | 7 ++ .../ImageClassificationTranslatorFactory.java | 49 ++++---- .../java/ai/djl/training/dataset/Dataset.java | 11 ++ .../translate/ExpansionTranslatorFactory.java | 105 ++++++++++++++++++ .../java/ai/djl/translate/Translator.java | 9 ++ .../ai/djl/translate/TranslatorOptions.java | 54 +++++++++ .../ImageClassificationDataset.java | 16 +-- .../ai/djl/zero/cv/ImageClassification.java | 3 +- ...pImageClassificationTranslatorFactory.java | 44 ++------ 9 files changed, 225 insertions(+), 73 deletions(-) create mode 100644 api/src/main/java/ai/djl/translate/ExpansionTranslatorFactory.java create mode 100644 api/src/main/java/ai/djl/translate/TranslatorOptions.java diff --git a/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslator.java index 866ba830fef..32db1ddbabd 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslator.java @@ -18,6 +18,7 @@ import ai.djl.ndarray.NDList; import ai.djl.translate.ArgumentsUtil; import ai.djl.translate.TranslatorContext; +import ai.djl.translate.TranslatorOptions; import java.io.IOException; import java.util.List; @@ -62,6 +63,12 @@ public Classifications processOutput(TranslatorContext ctx, NDList list) { return new Classifications(classes, probabilitiesNd, topK); } + /** {@inheritDoc} */ + @Override + public TranslatorOptions getExpansions() { + return new ImageClassificationTranslatorFactory().withTranslator(this); + } + /** * Creates a builder to build a {@code ImageClassificationTranslator}. 
* diff --git a/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslatorFactory.java index 0cdcdea6434..9ee08ceb501 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslatorFactory.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslatorFactory.java @@ -20,6 +20,7 @@ import ai.djl.modality.cv.translator.wrapper.FileTranslator; import ai.djl.modality.cv.translator.wrapper.InputStreamTranslator; import ai.djl.modality.cv.translator.wrapper.UrlTranslator; +import ai.djl.translate.ExpansionTranslatorFactory; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorFactory; import ai.djl.util.Pair; @@ -28,47 +29,39 @@ import java.lang.reflect.Type; import java.net.URL; import java.nio.file.Path; -import java.util.HashSet; import java.util.Map; -import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; /** A {@link TranslatorFactory} that creates an {@link ImageClassificationTranslator}. 
*/ -public class ImageClassificationTranslatorFactory implements TranslatorFactory { +public class ImageClassificationTranslatorFactory + extends ExpansionTranslatorFactory { - private static final Set> SUPPORTED_TYPES = new HashSet<>(); + private static final Map< + Pair, + Function, Translator>> + EXPANSIONS = new ConcurrentHashMap<>(); static { - SUPPORTED_TYPES.add(new Pair<>(Image.class, Classifications.class)); - SUPPORTED_TYPES.add(new Pair<>(Path.class, Classifications.class)); - SUPPORTED_TYPES.add(new Pair<>(URL.class, Classifications.class)); - SUPPORTED_TYPES.add(new Pair<>(InputStream.class, Classifications.class)); - SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); + EXPANSIONS.put(new Pair<>(Image.class, Classifications.class), t -> t); + EXPANSIONS.put(new Pair<>(Path.class, Classifications.class), FileTranslator::new); + EXPANSIONS.put(new Pair<>(URL.class, Classifications.class), UrlTranslator::new); + EXPANSIONS.put( + new Pair<>(InputStream.class, Classifications.class), InputStreamTranslator::new); + EXPANSIONS.put(new Pair<>(Input.class, Output.class), ImageServingTranslator::new); } /** {@inheritDoc} */ @Override - public Set> getSupportedTypes() { - return SUPPORTED_TYPES; + protected Translator buildBaseTranslator( + Model model, Map arguments) { + return ImageClassificationTranslator.builder(arguments).build(); } /** {@inheritDoc} */ @Override - @SuppressWarnings("unchecked") - public Translator newInstance( - Class input, Class output, Model model, Map arguments) { - ImageClassificationTranslator translator = - ImageClassificationTranslator.builder(arguments).build(); - if (input == Image.class && output == Classifications.class) { - return (Translator) translator; - } else if (input == Path.class && output == Classifications.class) { - return (Translator) new FileTranslator<>(translator); - } else if (input == URL.class && output == Classifications.class) { - return (Translator) new UrlTranslator<>(translator); - } else if 
(input == InputStream.class && output == Classifications.class) { - return (Translator) new InputStreamTranslator<>(translator); - } else if (input == Input.class && output == Output.class) { - return (Translator) new ImageServingTranslator(translator); - } - throw new IllegalArgumentException("Unsupported input/output types."); + protected Map, Function, Translator>> + getExpansions() { + return EXPANSIONS; } } diff --git a/api/src/main/java/ai/djl/training/dataset/Dataset.java b/api/src/main/java/ai/djl/training/dataset/Dataset.java index ba98da5ffb3..1d903d5879a 100644 --- a/api/src/main/java/ai/djl/training/dataset/Dataset.java +++ b/api/src/main/java/ai/djl/training/dataset/Dataset.java @@ -14,6 +14,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.translate.TranslateException; +import ai.djl.translate.TranslatorOptions; import ai.djl.util.Progress; import java.io.IOException; @@ -71,6 +72,16 @@ default void prepare() throws IOException, TranslateException { */ void prepare(Progress progress) throws IOException, TranslateException; + /** + * Returns {@link TranslatorOptions} that match the pre-processing and post-processing of this + * dataset. + * + * @return matching translators or null if none defined + */ + default TranslatorOptions matchingTranslatorOptions() { + return null; + } + /** An enum that indicates the mode - training, test or validation. */ enum Usage { TRAIN, diff --git a/api/src/main/java/ai/djl/translate/ExpansionTranslatorFactory.java b/api/src/main/java/ai/djl/translate/ExpansionTranslatorFactory.java new file mode 100644 index 00000000000..952f2901aec --- /dev/null +++ b/api/src/main/java/ai/djl/translate/ExpansionTranslatorFactory.java @@ -0,0 +1,105 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. 
A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.translate; + +import ai.djl.Model; +import ai.djl.util.Pair; + +import java.lang.reflect.Type; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +/** + * A {@link TranslatorFactory} based on a {@link Translator} and its {@link TranslatorOptions}. + * + * @param the input type for the base translator + * @param the output type for the base translator + */ +@SuppressWarnings("PMD.GenericsNaming") +public abstract class ExpansionTranslatorFactory implements TranslatorFactory { + + /** {@inheritDoc} */ + @Override + public Set> getSupportedTypes() { + return getExpansions().keySet(); + } + + /** {@inheritDoc} */ + @Override + public Translator newInstance( + Class input, Class output, Model model, Map arguments) { + Translator baseTranslator = buildBaseTranslator(model, arguments); + return newInstance(input, output, baseTranslator); + } + + @SuppressWarnings("unchecked") + Translator newInstance( + Class input, Class output, Translator translator) { + Function, Translator> expansion = + getExpansions().get(new Pair<>(input, output)); + if (expansion == null) { + throw new IllegalArgumentException("Unsupported expansion input/output types."); + } + return (Translator) expansion.apply(translator); + } + + /** + * Creates a set of {@link TranslatorOptions} based on the expansions of a given translator. 
+ * + * @param translator the translator to expand + * @return the {@link TranslatorOptions} + */ + public ExpandedTranslatorOptions withTranslator(Translator translator) { + return new ExpandedTranslatorOptions(translator); + } + + /** + * Builds the base translator that can be expanded. + * + * @param model the {@link Model} that uses the {@link Translator} + * @param arguments the configurations for a new {@code Translator} instance + * @return a base translator that can be expanded to form the factory options + */ + protected abstract Translator buildBaseTranslator( + Model model, Map arguments); + + /** + * Returns the possible expansions of this factory. + * + * @return the possible expansions of this factory + */ + protected abstract Map, Function, Translator>> + getExpansions(); + + final class ExpandedTranslatorOptions implements TranslatorOptions { + + private Translator translator; + + private ExpandedTranslatorOptions(Translator translator) { + this.translator = translator; + } + + /** {@inheritDoc} */ + @Override + public Set> getOptions() { + return getSupportedTypes(); + } + + /** {@inheritDoc} */ + @Override + public Translator option(Class input, Class output) { + return newInstance(input, output, translator); + } + } +} diff --git a/api/src/main/java/ai/djl/translate/Translator.java b/api/src/main/java/ai/djl/translate/Translator.java index 34e47c85ec2..43022c4b46d 100644 --- a/api/src/main/java/ai/djl/translate/Translator.java +++ b/api/src/main/java/ai/djl/translate/Translator.java @@ -91,4 +91,13 @@ default Batchifier getBatchifier() { */ @SuppressWarnings("PMD.SignatureDeclareThrowsException") default void prepare(TranslatorContext ctx) throws Exception {} + + /** + * Returns possible {@link TranslatorOptions} that can be built using this {@link Translator}. 
+ * + * @return possible options or null if not defined + */ + default TranslatorOptions getExpansions() { + return null; + } } diff --git a/api/src/main/java/ai/djl/translate/TranslatorOptions.java b/api/src/main/java/ai/djl/translate/TranslatorOptions.java new file mode 100644 index 00000000000..b683b943a8f --- /dev/null +++ b/api/src/main/java/ai/djl/translate/TranslatorOptions.java @@ -0,0 +1,54 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.translate; + +import ai.djl.util.Pair; + +import java.lang.reflect.Type; +import java.util.Set; + +/** A set of possible options for {@link Translator}s with different input and output types. */ +public interface TranslatorOptions { + + /** + * Returns the supported wrap types. + * + * @return the supported wrap types + * @see #option(Class, Class) + */ + Set> getOptions(); + + /** + * Returns if the input/output is a supported wrap type. + * + * @param input the input class + * @param output the output class + * @return {@code true} if the input/output type is supported + * @see #option(Class, Class) + */ + default boolean isSupported(Class input, Class output) { + return getOptions().contains(new Pair(input, output)); + } + + /** + * Returns the {@link Translator} option with the matching input and output type. 
+ * + * @param the input data type + * @param the output data type + * @param input the input class + * @param output the output class + * @return a new instance of the {@code Translator} class + * @throws TranslateException if failed to create Translator instance + */ + Translator option(Class input, Class output) throws TranslateException; +} diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageClassificationDataset.java b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageClassificationDataset.java index 33c7c763c47..ce17d113a8a 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageClassificationDataset.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/cv/classification/ImageClassificationDataset.java @@ -13,8 +13,6 @@ package ai.djl.basicdataset.cv.classification; import ai.djl.basicdataset.cv.ImageDataset; -import ai.djl.modality.Classifications; -import ai.djl.modality.cv.Image; import ai.djl.modality.cv.transform.Resize; import ai.djl.modality.cv.transform.ToTensor; import ai.djl.modality.cv.translator.ImageClassificationTranslator; @@ -23,7 +21,7 @@ import ai.djl.training.dataset.RandomAccessDataset; import ai.djl.training.dataset.Record; import ai.djl.translate.Pipeline; -import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorOptions; import java.io.IOException; import java.util.List; @@ -62,12 +60,9 @@ public Record get(NDManager manager, long index) throws IOException { return new Record(data, label); } - /** - * Returns the {@link ImageClassificationTranslator} matching the format of this dataset. 
- * - * @return the {@link ImageClassificationTranslator} matching the format of this dataset - */ - public Translator makeTranslator() { + /** {@inheritDoc} */ + @Override + public TranslatorOptions matchingTranslatorOptions() { Pipeline pipeline = new Pipeline(); // Resize the image if the image size is fixed @@ -81,7 +76,8 @@ public Translator makeTranslator() { return ImageClassificationTranslator.builder() .optSynset(getClasses()) .setPipeline(pipeline) - .build(); + .build() + .getExpansions(); } /** diff --git a/djl-zero/src/main/java/ai/djl/zero/cv/ImageClassification.java b/djl-zero/src/main/java/ai/djl/zero/cv/ImageClassification.java index a887a1ea6f7..993396e134c 100644 --- a/djl-zero/src/main/java/ai/djl/zero/cv/ImageClassification.java +++ b/djl-zero/src/main/java/ai/djl/zero/cv/ImageClassification.java @@ -158,7 +158,8 @@ public static ZooModel train( EasyTrain.fit(trainer, 35, trainDataset, validateDataset); } - Translator translator = dataset.makeTranslator(); + Translator translator = + dataset.matchingTranslatorOptions().option(Image.class, Classifications.class); return new ZooModel<>(model, translator); } diff --git a/engines/paddlepaddle/paddlepaddle-model-zoo/src/main/java/ai/djl/paddlepaddle/zoo/cv/imageclassification/PpImageClassificationTranslatorFactory.java b/engines/paddlepaddle/paddlepaddle-model-zoo/src/main/java/ai/djl/paddlepaddle/zoo/cv/imageclassification/PpImageClassificationTranslatorFactory.java index 9ce71339a60..f95fb3996e7 100644 --- a/engines/paddlepaddle/paddlepaddle-model-zoo/src/main/java/ai/djl/paddlepaddle/zoo/cv/imageclassification/PpImageClassificationTranslatorFactory.java +++ b/engines/paddlepaddle/paddlepaddle-model-zoo/src/main/java/ai/djl/paddlepaddle/zoo/cv/imageclassification/PpImageClassificationTranslatorFactory.java @@ -14,24 +14,15 @@ import ai.djl.Model; import ai.djl.modality.Classifications; -import ai.djl.modality.Input; -import ai.djl.modality.Output; import ai.djl.modality.cv.Image; import 
ai.djl.modality.cv.transform.Normalize; import ai.djl.modality.cv.transform.Resize; import ai.djl.modality.cv.transform.ToTensor; import ai.djl.modality.cv.translator.ImageClassificationTranslator; import ai.djl.modality.cv.translator.ImageClassificationTranslatorFactory; -import ai.djl.modality.cv.translator.ImageServingTranslator; -import ai.djl.modality.cv.translator.wrapper.FileTranslator; -import ai.djl.modality.cv.translator.wrapper.InputStreamTranslator; -import ai.djl.modality.cv.translator.wrapper.UrlTranslator; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorFactory; -import java.io.InputStream; -import java.net.URL; -import java.nio.file.Path; import java.util.Map; /** @@ -42,30 +33,15 @@ public class PpImageClassificationTranslatorFactory extends ImageClassificationT /** {@inheritDoc} */ @Override - @SuppressWarnings("unchecked") - public Translator newInstance( - Class input, Class output, Model model, Map arguments) { - ImageClassificationTranslator translator = - ImageClassificationTranslator.builder() - .addTransform(new Resize(128, 128)) - .addTransform(new ToTensor()) - .addTransform( - new Normalize( - new float[] {0.5f, 0.5f, 0.5f}, - new float[] {1.0f, 1.0f, 1.0f})) - .addTransform(nd -> nd.flip(0)) // RGB -> GBR - .build(); - if (input == Image.class && output == Classifications.class) { - return (Translator) translator; - } else if (input == Path.class && output == Classifications.class) { - return (Translator) new FileTranslator<>(translator); - } else if (input == URL.class && output == Classifications.class) { - return (Translator) new UrlTranslator<>(translator); - } else if (input == InputStream.class && output == Classifications.class) { - return (Translator) new InputStreamTranslator<>(translator); - } else if (input == Input.class && output == Output.class) { - return (Translator) new ImageServingTranslator(translator); - } - throw new IllegalArgumentException("Unsupported input/output types."); + protected 
Translator buildBaseTranslator( + Model model, Map arguments) { + return ImageClassificationTranslator.builder() + .addTransform(new Resize(128, 128)) + .addTransform(new ToTensor()) + .addTransform( + new Normalize( + new float[] {0.5f, 0.5f, 0.5f}, new float[] {1.0f, 1.0f, 1.0f})) + .addTransform(nd -> nd.flip(0)) // RGB -> GBR + .build(); } }