diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/AbstractImageFolder.java b/basicdataset/src/main/java/ai/djl/basicdataset/AbstractImageFolder.java index 0e5c1c1c400..1129a019f3a 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/AbstractImageFolder.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/AbstractImageFolder.java @@ -14,13 +14,8 @@ import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; import ai.djl.repository.Repository; import ai.djl.repository.Resource; -import ai.djl.training.dataset.RandomAccessDataset; -import ai.djl.training.dataset.Record; import ai.djl.translate.TranslateException; import ai.djl.util.Pair; import ai.djl.util.PairList; @@ -38,14 +33,13 @@ import org.slf4j.LoggerFactory; /** A dataset for loading image files stored in a folder structure. */ -public abstract class AbstractImageFolder extends RandomAccessDataset { +public abstract class AbstractImageFolder extends ImageClassificationDataset { private static final Logger logger = LoggerFactory.getLogger(AbstractImageFolder.class); private static final Set EXT = new HashSet<>(Arrays.asList(".jpg", ".jpeg", ".png", ".bmp", ".wbmp", ".gif")); - protected Image.Flag flag; protected List synset; protected PairList items; protected Resource resource; @@ -55,7 +49,6 @@ public abstract class AbstractImageFolder extends RandomAccessDataset { protected AbstractImageFolder(ImageFolderBuilder builder) { super(builder); - this.flag = builder.flag; this.maxDepth = builder.maxDepth; this.synset = new ArrayList<>(); this.items = new PairList<>(); @@ -64,14 +57,18 @@ protected AbstractImageFolder(ImageFolderBuilder builder) { /** {@inheritDoc} */ @Override - public Record get(NDManager manager, long index) throws IOException { + protected Image getImage(long index) throws IOException { + ImageFactory imageFactory = ImageFactory.getInstance(); Pair item = items.get(Math.toIntExact(index)); - Path imagePath = getImagePath(item.getKey()); - NDArray array = ImageFactory.getInstance().fromFile(imagePath).toNDArray(manager, flag); - NDList d = new NDList(array); - NDList l = new NDList(manager.create(item.getValue())); - return new Record(d, l); + return imageFactory.fromFile(imagePath); + } + + /** {@inheritDoc} */ + @Override + protected long getClassNumber(long index) { + Pair item = items.get(Math.toIntExact(index)); + return item.getValue(); } /** {@inheritDoc} */ @@ -140,7 +137,6 @@ public abstract static class ImageFolderBuilder> extends BaseBuilder { Repository repository; - Image.Flag flag; int maxDepth; protected ImageFolderBuilder() { @@ -148,17 +144,6 @@ protected ImageFolderBuilder() { maxDepth = 1; } - /** - * Sets the optional color mode flag. - * - * @param flag the color mode flag - * @return this builder - */ - public T optFlag(Image.Flag flag) { - this.flag = flag; - return self(); - } - /** * Sets the repository containing the image folder. * diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/ImageClassificationDataset.java b/basicdataset/src/main/java/ai/djl/basicdataset/ImageClassificationDataset.java new file mode 100644 index 00000000000..4eca7605383 --- /dev/null +++ b/basicdataset/src/main/java/ai/djl/basicdataset/ImageClassificationDataset.java @@ -0,0 +1,92 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.basicdataset; + +import ai.djl.modality.cv.Image; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.training.dataset.RandomAccessDataset; +import ai.djl.training.dataset.Record; +import java.io.IOException; + +/** + * A helper to create {@link ai.djl.training.dataset.Dataset}s for {@link + * ai.djl.Application.CV#IMAGE_CLASSIFICATION}. + */ +public abstract class ImageClassificationDataset extends RandomAccessDataset { + + private Image.Flag flag; + + /** + * Creates a new instance of {@link RandomAccessDataset} with the given necessary + * configurations. + * + * @param builder a builder with the necessary configurations + */ + public ImageClassificationDataset(BaseBuilder builder) { + super(builder); + this.flag = builder.flag; + } + + /** + * Returns the image at the given index in the dataset. + * + * @param index the index (if the dataset is a list of data items) + * @return the image + * @throws IOException if the image could not be loaded + */ + protected abstract Image getImage(long index) throws IOException; + + /** + * Returns the class of the data item at the given index. + * + * @param index the index (if the dataset is a list of data items) + * @return the class number or the index into the list of classes of the desired class name + */ + protected abstract long getClassNumber(long index); + + /** {@inheritDoc} */ + @Override + public Record get(NDManager manager, long index) throws IOException { + NDList data = new NDList(getImage(index).toNDArray(manager, flag)); + NDList label = new NDList(manager.create(getClassNumber(index))); + return new Record(data, label); + } + + /** + * Used to build an {@link ImageClassificationDataset}. + * + * @param the builder type + */ + @SuppressWarnings("rawtypes") + public abstract static class BaseBuilder> + extends RandomAccessDataset.BaseBuilder { + + Image.Flag flag; + + protected BaseBuilder() { + flag = Image.Flag.COLOR; + } + + /** + * Sets the optional color mode flag. + * + * @param flag the color mode flag + * @return this builder + */ + public T optFlag(Image.Flag flag) { + this.flag = flag; + return self(); + } + } +} diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/ImageFolderTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/ImageFolderTest.java index d28c37d50db..b67b6ce2349 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/ImageFolderTest.java +++ b/basicdataset/src/test/java/ai/djl/basicdataset/ImageFolderTest.java @@ -86,7 +86,7 @@ public void testImageFolder() throws IOException, TranslateException { catBatch.getData().singletonOrThrow(), NDImageUtils.toTensor(NDImageUtils.resize(cat, 100, 100)).expandDims(0)); Assert.assertEquals( - catBatch.getLabels().singletonOrThrow(), manager.create(new int[] {0})); + catBatch.getLabels().singletonOrThrow(), manager.create(new long[] {0})); catBatch.close(); Batch dogBatch = ds.next(); @@ -94,7 +94,7 @@ public void testImageFolder() throws IOException, TranslateException { dogBatch.getData().singletonOrThrow(), NDImageUtils.toTensor(NDImageUtils.resize(dog, 100, 100)).expandDims(0)); Assert.assertEquals( - dogBatch.getLabels().singletonOrThrow(), manager.create(new int[] {1})); + dogBatch.getLabels().singletonOrThrow(), manager.create(new long[] {1})); dogBatch.close(); Batch pikachuBatch = ds.next(); @@ -103,7 +103,8 @@ public void testImageFolder() throws IOException, TranslateException { NDImageUtils.toTensor(NDImageUtils.resize(pikachu, 100, 100)) .expandDims(0)); Assert.assertEquals( - pikachuBatch.getLabels().singletonOrThrow(), manager.create(new int[] {2})); + pikachuBatch.getLabels().singletonOrThrow(), + manager.create(new long[] {2})); pikachuBatch.close(); } } diff --git a/djl-easy/src/main/java/ai/djl/easy/cv/ImageClassification.java b/djl-easy/src/main/java/ai/djl/easy/cv/ImageClassification.java index 416e4009e23..4b6ae36dc55 100644 --- a/djl-easy/src/main/java/ai/djl/easy/cv/ImageClassification.java +++ b/djl-easy/src/main/java/ai/djl/easy/cv/ImageClassification.java @@ -93,6 +93,10 @@ public static ZooModel pretrained( /** * Trains the recommended image classification model on a custom dataset. * + *

In order to train on a custom dataset, you must create a custom {@link + * ai.djl.basicdataset.ImageClassificationDataset} to load your data and a {@link + * ImageClassificationDatasetFactory} to build it. + * * @param datasetFactory to build the datasets to train on and validate against * @param performance to determine the desired model tradeoffs * @return the model as a {@link ZooModel} with the {@link Translator} included diff --git a/djl-easy/src/main/java/ai/djl/easy/cv/ImageClassificationDatasetFactory.java b/djl-easy/src/main/java/ai/djl/easy/cv/ImageClassificationDatasetFactory.java index 0a0ff10f310..b2e5b6bbdf8 100644 --- a/djl-easy/src/main/java/ai/djl/easy/cv/ImageClassificationDatasetFactory.java +++ b/djl-easy/src/main/java/ai/djl/easy/cv/ImageClassificationDatasetFactory.java @@ -12,12 +12,18 @@ */ package ai.djl.easy.cv; +import ai.djl.basicdataset.ImageClassificationDataset; import ai.djl.easy.DatasetFactory; +import ai.djl.training.dataset.Dataset.Usage; import java.util.List; /** A {@link DatasetFactory} for {@link ImageClassification}. */ public interface ImageClassificationDatasetFactory extends DatasetFactory { + /** {@inheritDoc} */ + @Override + ImageClassificationDataset build(Usage usage); + /** * Returns the number of channels in the images in the dataset. *