Skip to content

Commit

Permalink
Add ImageClassificationDataset for clearer requirements (#517)
Browse files Browse the repository at this point in the history
* Add ImageClassificationDataset for clearer requirements

This dataset adds some help to creating image classification datasets,
especially for users less familiar with DJL. It creates simpler APIs to fulfill.

Change-Id: I659f5fd0ef02f23785a6bbc6ce7a1bf1128d6d64

* Address comments

Change-Id: I878d0b15fb81573f877f24c423877ca706713f43
  • Loading branch information
zachgk authored Jan 13, 2021
1 parent 4b51619 commit 5b59336
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,8 @@

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.Repository;
import ai.djl.repository.Resource;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.translate.TranslateException;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
Expand All @@ -38,14 +33,13 @@
import org.slf4j.LoggerFactory;

/** A dataset for loading image files stored in a folder structure. */
public abstract class AbstractImageFolder extends RandomAccessDataset {
public abstract class AbstractImageFolder extends ImageClassificationDataset {

private static final Logger logger = LoggerFactory.getLogger(AbstractImageFolder.class);

private static final Set<String> EXT =
new HashSet<>(Arrays.asList(".jpg", ".jpeg", ".png", ".bmp", ".wbmp", ".gif"));

protected Image.Flag flag;
protected List<String> synset;
protected PairList<String, Integer> items;
protected Resource resource;
Expand All @@ -55,7 +49,6 @@ public abstract class AbstractImageFolder extends RandomAccessDataset {

protected AbstractImageFolder(ImageFolderBuilder<?> builder) {
super(builder);
this.flag = builder.flag;
this.maxDepth = builder.maxDepth;
this.synset = new ArrayList<>();
this.items = new PairList<>();
Expand All @@ -64,14 +57,18 @@ protected AbstractImageFolder(ImageFolderBuilder<?> builder) {

/** {@inheritDoc} */
@Override
public Record get(NDManager manager, long index) throws IOException {
protected Image getImage(long index) throws IOException {
ImageFactory imageFactory = ImageFactory.getInstance();
Pair<String, Integer> item = items.get(Math.toIntExact(index));

Path imagePath = getImagePath(item.getKey());
NDArray array = ImageFactory.getInstance().fromFile(imagePath).toNDArray(manager, flag);
NDList d = new NDList(array);
NDList l = new NDList(manager.create(item.getValue()));
return new Record(d, l);
return imageFactory.fromFile(imagePath);
}

/** {@inheritDoc} */
@Override
protected long getClassNumber(long index) {
Pair<String, Integer> item = items.get(Math.toIntExact(index));
return item.getValue();
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -140,25 +137,13 @@ public abstract static class ImageFolderBuilder<T extends ImageFolderBuilder<T>>
extends BaseBuilder<T> {

Repository repository;
Image.Flag flag;
int maxDepth;

protected ImageFolderBuilder() {
flag = Image.Flag.COLOR;
maxDepth = 1;
}

/**
* Sets the optional color mode flag.
*
* @param flag the color mode flag
* @return this builder
*/
public T optFlag(Image.Flag flag) {
this.flag = flag;
return self();
}

/**
* Sets the repository containing the image folder.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.basicdataset;

import ai.djl.modality.cv.Image;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import java.io.IOException;

/**
* A helper to create {@link ai.djl.training.dataset.Dataset}s for {@link
* ai.djl.Application.CV#IMAGE_CLASSIFICATION}.
*/
public abstract class ImageClassificationDataset extends RandomAccessDataset {

private Image.Flag flag;

/**
* Creates a new instance of {@link RandomAccessDataset} with the given necessary
* configurations.
*
* @param builder a builder with the necessary configurations
*/
public ImageClassificationDataset(BaseBuilder<?> builder) {
super(builder);
this.flag = builder.flag;
}

/**
* Returns the image at the given index in the dataset.
*
* @param index the index (if the dataset is a list of data items)
* @return the image
* @throws IOException if the image could not be loaded
*/
protected abstract Image getImage(long index) throws IOException;

/**
* Returns the class of the data item at the given index.
*
* @param index the index (if the dataset is a list of data items)
* @return the class number or the index into the list of classes of the desired class name
*/
protected abstract long getClassNumber(long index);

/** {@inheritDoc} */
@Override
public Record get(NDManager manager, long index) throws IOException {
NDList data = new NDList(getImage(index).toNDArray(manager, flag));
NDList label = new NDList(manager.create(getClassNumber(index)));
return new Record(data, label);
}

/**
* Used to build an {@link ImageClassificationDataset}.
*
* @param <T> the builder type
*/
@SuppressWarnings("rawtypes")
public abstract static class BaseBuilder<T extends BaseBuilder<T>>
extends RandomAccessDataset.BaseBuilder<T> {

Image.Flag flag;

protected BaseBuilder() {
flag = Image.Flag.COLOR;
}

/**
* Sets the optional color mode flag.
*
* @param flag the color mode flag
* @return this builder
*/
public T optFlag(Image.Flag flag) {
this.flag = flag;
return self();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ public void testImageFolder() throws IOException, TranslateException {
catBatch.getData().singletonOrThrow(),
NDImageUtils.toTensor(NDImageUtils.resize(cat, 100, 100)).expandDims(0));
Assert.assertEquals(
catBatch.getLabels().singletonOrThrow(), manager.create(new int[] {0}));
catBatch.getLabels().singletonOrThrow(), manager.create(new long[] {0}));
catBatch.close();

Batch dogBatch = ds.next();
Assertions.assertAlmostEquals(
dogBatch.getData().singletonOrThrow(),
NDImageUtils.toTensor(NDImageUtils.resize(dog, 100, 100)).expandDims(0));
Assert.assertEquals(
dogBatch.getLabels().singletonOrThrow(), manager.create(new int[] {1}));
dogBatch.getLabels().singletonOrThrow(), manager.create(new long[] {1}));
dogBatch.close();

Batch pikachuBatch = ds.next();
Expand All @@ -103,7 +103,8 @@ public void testImageFolder() throws IOException, TranslateException {
NDImageUtils.toTensor(NDImageUtils.resize(pikachu, 100, 100))
.expandDims(0));
Assert.assertEquals(
pikachuBatch.getLabels().singletonOrThrow(), manager.create(new int[] {2}));
pikachuBatch.getLabels().singletonOrThrow(),
manager.create(new long[] {2}));
pikachuBatch.close();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ public static <I> ZooModel<I, Classifications> pretrained(
/**
* Trains the recommended image classification model on a custom dataset.
*
* <p>In order to train on a custom dataset, you must create a custom {@link
* ai.djl.basicdataset.ImageClassificationDataset} to load your data and a {@link
* ImageClassificationDatasetFactory} to build it.
*
* @param datasetFactory to build the datasets to train on and validate against
* @param performance to determine the desired model tradeoffs
* @return the model as a {@link ZooModel} with the {@link Translator} included
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@
*/
package ai.djl.easy.cv;

import ai.djl.basicdataset.ImageClassificationDataset;
import ai.djl.easy.DatasetFactory;
import ai.djl.training.dataset.Dataset.Usage;
import java.util.List;

/** A {@link DatasetFactory} for {@link ImageClassification}. */
public interface ImageClassificationDatasetFactory extends DatasetFactory {

/** {@inheritDoc} */
@Override
ImageClassificationDataset build(Usage usage);

/**
* Returns the number of channels in the images in the dataset.
*
Expand Down

0 comments on commit 5b59336

Please sign in to comment.