-
Notifications
You must be signed in to change notification settings - Fork 654
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This creates a new tool for TranslatorOptions based on the characteristic function ((InputClass, OutputClass) -> Translator). Unlike the TranslatorFactory, it doesn't require arguments or a model. So, the TranslatorFactory is mainly used as part of model loading where the TranslatorOptions can be more broad. In particular, it connects to both Dataset and Translator. Both of them gain a new function that will return a TranslatorOptions (matchingTranslatorOptions for the Dataset and getExpansions for the translator). Right now, I made both of these optional with a default value of null so that users don't have to define them, it is just a possibility. To connect the TranslatorOptions and TranslatorFactory, I created a class ExpansionTranslatorFactory. This is used for the common abstraction where we have a base translator and multiple "expansions" for it that come before it in pre-processing or after it in post-processing. With a way to construct the base translator with the options, it works as a TranslatorFactory. With a starting Translator, it can become the TranslatorOptions. Lastly and probably worth calling out, the ExpansionTranslatorFactory creates a map of the options. This allows the supported types and the functions to be defined in only one place instead of two. In a followup PR, I will make a version of this that creates a BaseImageTranslatorFactory matching the BaseImageTranslator and will handle all of the repeated code for all of the image translator factories.
- Loading branch information
Showing
9 changed files
with
214 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
103 changes: 103 additions & 0 deletions
103
api/src/main/java/ai/djl/translate/ExpansionTranslatorFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
/* | ||
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.translate; | ||
|
||
import ai.djl.Model; | ||
import ai.djl.util.Pair; | ||
|
||
import java.lang.reflect.Type; | ||
import java.util.Map; | ||
import java.util.Set; | ||
import java.util.function.Function; | ||
|
||
/** | ||
* A {@link TranslatorFactory} based on a {@link Translator} and it's {@link TranslatorOptions}. | ||
* | ||
* @param <IbaseT> the input type for the base translator | ||
* @param <ObaseT> the output type for the base translator | ||
*/ | ||
@SuppressWarnings("PMD.GenericsNaming") | ||
public abstract class ExpansionTranslatorFactory<IbaseT, ObaseT> implements TranslatorFactory { | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public Set<Pair<Type, Type>> getSupportedTypes() { | ||
return getExpansions().keySet(); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public <I, O> Translator<I, O> newInstance( | ||
Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) { | ||
Translator<IbaseT, ObaseT> baseTranslator = buildBaseTranslator(model, arguments); | ||
return newInstance(input, output, baseTranslator); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@SuppressWarnings("unchecked") | ||
public <I, O> Translator<I, O> newInstance( | ||
Class<I> input, Class<O> output, Translator<IbaseT, ObaseT> translator) { | ||
Function<Translator<IbaseT, ObaseT>, Translator<?, ?>> expansion = | ||
getExpansions().get(new Pair<>(input, output)); | ||
if (expansion == null) { | ||
throw new IllegalArgumentException("Unsupported expansion input/output types."); | ||
} | ||
return (Translator<I, O>) expansion.apply(translator); | ||
} | ||
|
||
/** | ||
* Creates a set of {@link TranslatorOptions} based on the expansions of a given translator. | ||
* | ||
* @param translator the translator to expand | ||
* @return the {@link TranslatorOptions} | ||
*/ | ||
public ExpandedTranslatorOptions withTranslator(Translator<IbaseT, ObaseT> translator) { | ||
return new ExpandedTranslatorOptions(translator); | ||
} | ||
|
||
/** | ||
* Builds the base translator that can be expanded. | ||
* | ||
* @param model the {@link Model} that uses the {@link Translator} | ||
* @param arguments the configurations for a new {@code Translator} instance | ||
* @return a base translator that can be expanded to form the factory options | ||
*/ | ||
protected abstract Translator<IbaseT, ObaseT> buildBaseTranslator( | ||
Model model, Map<String, ?> arguments); | ||
|
||
/** | ||
* Returns the possible expansions of this factory. | ||
* | ||
* @return the possible expansions of this factory | ||
*/ | ||
protected abstract Map<Pair<Type, Type>, Function<Translator<IbaseT, ObaseT>, Translator<?, ?>>> | ||
getExpansions(); | ||
|
||
final class ExpandedTranslatorOptions implements TranslatorOptions { | ||
private Translator<IbaseT, ObaseT> translator; | ||
|
||
private ExpandedTranslatorOptions(Translator<IbaseT, ObaseT> translator) { | ||
this.translator = translator; | ||
} | ||
|
||
@Override | ||
public Set<Pair<Type, Type>> getOptions() { | ||
return getSupportedTypes(); | ||
} | ||
|
||
@Override | ||
public <I, O> Translator<I, O> option(Class<I> input, Class<O> output) { | ||
return newInstance(input, output, translator); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* | ||
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.translate; | ||
|
||
import ai.djl.util.Pair; | ||
|
||
import java.lang.reflect.Type; | ||
import java.util.Set; | ||
|
||
/** A set of possible options for {@link Translator}s with different input and output types. */ | ||
public interface TranslatorOptions { | ||
|
||
/** | ||
* Returns the supported wrap types. | ||
* | ||
* @return the supported wrap types | ||
* @see #option(Class, Class) | ||
*/ | ||
Set<Pair<Type, Type>> getOptions(); | ||
|
||
/** | ||
* Returns if the input/output is a supported wrap type. | ||
* | ||
* @param input the input class | ||
* @param output the output class | ||
* @return {@code true} if the input/output type is supported | ||
* @see #option(Class, Class) | ||
*/ | ||
default boolean isSupported(Class<?> input, Class<?> output) { | ||
return getOptions().contains(new Pair<Type, Type>(input, output)); | ||
} | ||
|
||
/** | ||
* Returns the {@link Translator} option with the matching input and output type. | ||
* | ||
* @param <I> the input data type | ||
* @param <O> the output data type | ||
* @param input the input class | ||
* @param output the output class | ||
* @return a new instance of the {@code Translator} class | ||
* @throws TranslateException if failed to create Translator instance | ||
*/ | ||
<I, O> Translator<I, O> option(Class<I> input, Class<O> output) throws TranslateException; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters