diff --git a/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslatorFactory.java index 6f9a6b2e709..9ee08ceb501 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslatorFactory.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/ImageClassificationTranslatorFactory.java @@ -51,12 +51,14 @@ public class ImageClassificationTranslatorFactory EXPANSIONS.put(new Pair<>(Input.class, Output.class), ImageServingTranslator::new); } + /** {@inheritDoc} */ @Override protected Translator buildBaseTranslator( Model model, Map arguments) { return ImageClassificationTranslator.builder(arguments).build(); } + /** {@inheritDoc} */ @Override protected Map, Function, Translator>> getExpansions() { diff --git a/api/src/main/java/ai/djl/translate/ExpansionTranslatorFactory.java b/api/src/main/java/ai/djl/translate/ExpansionTranslatorFactory.java index 4631adf982d..952f2901aec 100644 --- a/api/src/main/java/ai/djl/translate/ExpansionTranslatorFactory.java +++ b/api/src/main/java/ai/djl/translate/ExpansionTranslatorFactory.java @@ -43,9 +43,8 @@ public Translator newInstance( return newInstance(input, output, baseTranslator); } - /** {@inheritDoc} */ @SuppressWarnings("unchecked") - public Translator newInstance( + Translator newInstance( Class input, Class output, Translator translator) { Function, Translator> expansion = getExpansions().get(new Pair<>(input, output)); @@ -84,17 +83,20 @@ protected abstract Translator buildBaseTranslator( getExpansions(); final class ExpandedTranslatorOptions implements TranslatorOptions { + private Translator translator; private ExpandedTranslatorOptions(Translator translator) { this.translator = translator; } + /** {@inheritDoc} */ @Override public Set> getOptions() { return getSupportedTypes(); } + /** {@inheritDoc} */ @Override public Translator option(Class input, Class output) { return newInstance(input, output, translator);