Skip to content

Commit

Permalink
Create translator options (#2145)
Browse files Browse the repository at this point in the history
* Create translator options

This creates a new tool for TranslatorOptions based on the characteristic
function ((InputClass, OutputClass) -> Translator). Unlike the
TranslatorFactory, it doesn't require arguments or a model. So, the
TranslatorFactory is mainly used as part of model loading, whereas the
TranslatorOptions can be used more broadly.

In particular, it connects to both Dataset and Translator. Both of them gain a
new function that will return a TranslatorOptions (matchingTranslatorOptions for
the Dataset and getExpansions for the translator). For now, both of these are
optional and return null by default, so users don't have to define them; they
are simply available when needed.

To connect the TranslatorOptions and TranslatorFactory, I created a class
ExpansionTranslatorFactory. This is used for the common abstraction where we
have a base translator and multiple "expansions" for it that come before it in
pre-processing or after it in post-processing. With a way to construct the base
translator with the options, it works as a TranslatorFactory. With a starting
Translator, it can become the TranslatorOptions.

Lastly and probably worth calling out, the ExpansionTranslatorFactory creates a
map of the options. This allows the supported types and the functions to be
defined in only one place instead of two. In a followup PR, I will make a
version of this that creates a BaseImageTranslatorFactory matching the
BaseImageTranslator and will handle all of the repeated code for all of the
image translator factories.

* minor javadoc updates

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
  • Loading branch information
zachgk and frankfliu authored Nov 18, 2022
1 parent 28209d6 commit 7ceb7cd
Show file tree
Hide file tree
Showing 9 changed files with 225 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import ai.djl.ndarray.NDList;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.TranslatorContext;
import ai.djl.translate.TranslatorOptions;

import java.io.IOException;
import java.util.List;
Expand Down Expand Up @@ -62,6 +63,12 @@ public Classifications processOutput(TranslatorContext ctx, NDList list) {
return new Classifications(classes, probabilitiesNd, topK);
}

/** {@inheritDoc} */
@Override
public TranslatorOptions getExpansions() {
    // The factory supplies the expansions; this translator acts as the base they wrap.
    ImageClassificationTranslatorFactory factory = new ImageClassificationTranslatorFactory();
    return factory.withTranslator(this);
}

/**
* Creates a builder to build a {@code ImageClassificationTranslator}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import ai.djl.modality.cv.translator.wrapper.FileTranslator;
import ai.djl.modality.cv.translator.wrapper.InputStreamTranslator;
import ai.djl.modality.cv.translator.wrapper.UrlTranslator;
import ai.djl.translate.ExpansionTranslatorFactory;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
Expand All @@ -28,47 +29,39 @@
import java.lang.reflect.Type;
import java.net.URL;
import java.nio.file.Path;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

/** A {@link TranslatorFactory} that creates an {@link ImageClassificationTranslator}. */
public class ImageClassificationTranslatorFactory implements TranslatorFactory {
public class ImageClassificationTranslatorFactory
extends ExpansionTranslatorFactory<Image, Classifications> {

private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet<>();
private static final Map<
Pair<Type, Type>,
Function<Translator<Image, Classifications>, Translator<?, ?>>>
EXPANSIONS = new ConcurrentHashMap<>();

static {
SUPPORTED_TYPES.add(new Pair<>(Image.class, Classifications.class));
SUPPORTED_TYPES.add(new Pair<>(Path.class, Classifications.class));
SUPPORTED_TYPES.add(new Pair<>(URL.class, Classifications.class));
SUPPORTED_TYPES.add(new Pair<>(InputStream.class, Classifications.class));
SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class));
EXPANSIONS.put(new Pair<>(Image.class, Classifications.class), t -> t);
EXPANSIONS.put(new Pair<>(Path.class, Classifications.class), FileTranslator::new);
EXPANSIONS.put(new Pair<>(URL.class, Classifications.class), UrlTranslator::new);
EXPANSIONS.put(
new Pair<>(InputStream.class, Classifications.class), InputStreamTranslator::new);
EXPANSIONS.put(new Pair<>(Input.class, Output.class), ImageServingTranslator::new);
}

/** {@inheritDoc} */
@Override
public Set<Pair<Type, Type>> getSupportedTypes() {
return SUPPORTED_TYPES;
protected Translator<Image, Classifications> buildBaseTranslator(
Model model, Map<String, ?> arguments) {
return ImageClassificationTranslator.builder(arguments).build();
}

/** {@inheritDoc} */
@Override
@SuppressWarnings("unchecked")
public <I, O> Translator<I, O> newInstance(
Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) {
ImageClassificationTranslator translator =
ImageClassificationTranslator.builder(arguments).build();
if (input == Image.class && output == Classifications.class) {
return (Translator<I, O>) translator;
} else if (input == Path.class && output == Classifications.class) {
return (Translator<I, O>) new FileTranslator<>(translator);
} else if (input == URL.class && output == Classifications.class) {
return (Translator<I, O>) new UrlTranslator<>(translator);
} else if (input == InputStream.class && output == Classifications.class) {
return (Translator<I, O>) new InputStreamTranslator<>(translator);
} else if (input == Input.class && output == Output.class) {
return (Translator<I, O>) new ImageServingTranslator(translator);
}
throw new IllegalArgumentException("Unsupported input/output types.");
protected Map<Pair<Type, Type>, Function<Translator<Image, Classifications>, Translator<?, ?>>>
getExpansions() {
return EXPANSIONS;
}
}
11 changes: 11 additions & 0 deletions api/src/main/java/ai/djl/training/dataset/Dataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.ndarray.NDManager;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorOptions;
import ai.djl.util.Progress;

import java.io.IOException;
Expand Down Expand Up @@ -71,6 +72,16 @@ default void prepare() throws IOException, TranslateException {
*/
void prepare(Progress progress) throws IOException, TranslateException;

/**
* Returns {@link TranslatorOptions} that match the pre-processing and post-processing of this
* dataset.
*
* <p>This default implementation returns {@code null}. Callers should check for {@code null}
* before using the result unless the concrete dataset overrides this method.
*
* @return matching translators or null if none defined
*/
default TranslatorOptions matchingTranslatorOptions() {
return null;
}

/** An enum that indicates the mode - training, test or validation. */
enum Usage {
TRAIN,
Expand Down
105 changes: 105 additions & 0 deletions api/src/main/java/ai/djl/translate/ExpansionTranslatorFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.translate;

import ai.djl.Model;
import ai.djl.util.Pair;

import java.lang.reflect.Type;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

/**
 * A {@link TranslatorFactory} based on a {@link Translator} and its {@link TranslatorOptions}.
 *
 * <p>Subclasses provide a base translator (see {@link #buildBaseTranslator(Model, Map)}) and a map
 * of expansions (see {@link #getExpansions()}). Each expansion is keyed by an input/output type
 * pair and holds a function that turns the base translator into a translator for that pair. The
 * same map drives both the factory view ({@link #newInstance(Class, Class, Model, Map)}) and the
 * {@link TranslatorOptions} view ({@link #withTranslator(Translator)}).
 *
 * @param <IbaseT> the input type for the base translator
 * @param <ObaseT> the output type for the base translator
 */
@SuppressWarnings("PMD.GenericsNaming")
public abstract class ExpansionTranslatorFactory<IbaseT, ObaseT> implements TranslatorFactory {

    /** {@inheritDoc} */
    @Override
    public Set<Pair<Type, Type>> getSupportedTypes() {
        // The supported types are exactly the keys of the expansion map.
        return getExpansions().keySet();
    }

    /** {@inheritDoc} */
    @Override
    public <I, O> Translator<I, O> newInstance(
            Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) {
        Translator<IbaseT, ObaseT> baseTranslator = buildBaseTranslator(model, arguments);
        return newInstance(input, output, baseTranslator);
    }

    /**
     * Expands an existing base translator into a translator for the requested types.
     *
     * @param <I> the input data type
     * @param <O> the output data type
     * @param input the input class
     * @param output the output class
     * @param translator the base translator to expand
     * @return the translator produced by the expansion registered for the type pair
     * @throws IllegalArgumentException if no expansion is registered for the type pair
     */
    // The cast cannot be checked by the compiler: it is only correct when the function stored
    // under (input, output) really produces a Translator<I, O>.
    @SuppressWarnings("unchecked")
    <I, O> Translator<I, O> newInstance(
            Class<I> input, Class<O> output, Translator<IbaseT, ObaseT> translator) {
        Function<Translator<IbaseT, ObaseT>, Translator<?, ?>> expansion =
                getExpansions().get(new Pair<>(input, output));
        if (expansion == null) {
            throw new IllegalArgumentException(
                    "Unsupported expansion input/output types: " + input + " -> " + output);
        }
        return (Translator<I, O>) expansion.apply(translator);
    }

    /**
     * Creates a set of {@link TranslatorOptions} based on the expansions of a given translator.
     *
     * @param translator the translator to expand
     * @return the {@link TranslatorOptions}
     */
    public ExpandedTranslatorOptions withTranslator(Translator<IbaseT, ObaseT> translator) {
        return new ExpandedTranslatorOptions(translator);
    }

    /**
     * Builds the base translator that can be expanded.
     *
     * @param model the {@link Model} that uses the {@link Translator}
     * @param arguments the configurations for a new {@code Translator} instance
     * @return a base translator that can be expanded to form the factory options
     */
    protected abstract Translator<IbaseT, ObaseT> buildBaseTranslator(
            Model model, Map<String, ?> arguments);

    /**
     * Returns the possible expansions of this factory.
     *
     * @return the possible expansions of this factory
     */
    protected abstract Map<Pair<Type, Type>, Function<Translator<IbaseT, ObaseT>, Translator<?, ?>>>
            getExpansions();

    /**
     * {@link TranslatorOptions} backed by the enclosing factory's expansions and a fixed base
     * translator.
     */
    final class ExpandedTranslatorOptions implements TranslatorOptions {

        private final Translator<IbaseT, ObaseT> translator;

        private ExpandedTranslatorOptions(Translator<IbaseT, ObaseT> translator) {
            this.translator = translator;
        }

        /** {@inheritDoc} */
        @Override
        public Set<Pair<Type, Type>> getOptions() {
            return getSupportedTypes();
        }

        /** {@inheritDoc} */
        @Override
        public <I, O> Translator<I, O> option(Class<I> input, Class<O> output) {
            return newInstance(input, output, translator);
        }
    }
}
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/translate/Translator.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,13 @@ default Batchifier getBatchifier() {
*/
@SuppressWarnings("PMD.SignatureDeclareThrowsException")
default void prepare(TranslatorContext ctx) throws Exception {}

/**
* Returns possible {@link TranslatorOptions} that can be built using this {@link Translator}.
*
* <p>This default implementation returns {@code null}. Callers should check for {@code null}
* before using the result unless the concrete translator overrides this method.
*
* @return possible options or null if not defined
*/
default TranslatorOptions getExpansions() {
return null;
}
}
54 changes: 54 additions & 0 deletions api/src/main/java/ai/djl/translate/TranslatorOptions.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.translate;

import ai.djl.util.Pair;

import java.lang.reflect.Type;
import java.util.Set;

/** A set of possible options for {@link Translator}s with different input and output types. */
public interface TranslatorOptions {

    /**
     * Returns the input/output type pairs for which an option is available.
     *
     * @return the supported wrap types
     * @see #option(Class, Class)
     */
    Set<Pair<Type, Type>> getOptions();

    /**
     * Checks whether an option exists for the given input/output type pair.
     *
     * @param input the input class
     * @param output the output class
     * @return {@code true} if the input/output type is supported
     * @see #option(Class, Class)
     */
    default boolean isSupported(Class<?> input, Class<?> output) {
        Set<Pair<Type, Type>> available = getOptions();
        Pair<Type, Type> requested = new Pair<>(input, output);
        return available.contains(requested);
    }

    /**
     * Returns the {@link Translator} option with the matching input and output type.
     *
     * @param <I> the input data type
     * @param <O> the output data type
     * @param input the input class
     * @param output the output class
     * @return a new instance of the {@code Translator} class
     * @throws TranslateException if failed to create Translator instance
     */
    <I, O> Translator<I, O> option(Class<I> input, Class<O> output) throws TranslateException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
package ai.djl.basicdataset.cv.classification;

import ai.djl.basicdataset.cv.ImageDataset;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
Expand All @@ -23,7 +21,7 @@
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.translate.Pipeline;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorOptions;

import java.io.IOException;
import java.util.List;
Expand Down Expand Up @@ -62,12 +60,9 @@ public Record get(NDManager manager, long index) throws IOException {
return new Record(data, label);
}

/**
* Returns the {@link ImageClassificationTranslator} matching the format of this dataset.
*
* @return the {@link ImageClassificationTranslator} matching the format of this dataset
*/
public Translator<Image, Classifications> makeTranslator() {
/** {@inheritDoc} */
@Override
public TranslatorOptions matchingTranslatorOptions() {
Pipeline pipeline = new Pipeline();

// Resize the image if the image size is fixed
Expand All @@ -81,7 +76,8 @@ public Translator<Image, Classifications> makeTranslator() {
return ImageClassificationTranslator.builder()
.optSynset(getClasses())
.setPipeline(pipeline)
.build();
.build()
.getExpansions();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ public static ZooModel<Image, Classifications> train(
EasyTrain.fit(trainer, 35, trainDataset, validateDataset);
}

Translator<Image, Classifications> translator = dataset.makeTranslator();
Translator<Image, Classifications> translator =
dataset.matchingTranslatorOptions().option(Image.class, Classifications.class);
return new ZooModel<>(model, translator);
}

Expand Down
Loading

0 comments on commit 7ceb7cd

Please sign in to comment.