diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java new file mode 100644 index 00000000000..1ef29f29873 --- /dev/null +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java @@ -0,0 +1,146 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.basicdataset.nlp; + +import ai.djl.Application; +import ai.djl.basicdataset.BasicDatasets; +import ai.djl.modality.nlp.embedding.EmbeddingException; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.repository.Artifact; +import ai.djl.repository.MRL; +import ai.djl.training.dataset.Dataset; +import ai.djl.training.dataset.Record; +import ai.djl.util.Progress; +import java.io.BufferedReader; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; + +/** + * The Penn Treebank (PTB) project selected 2,499 stories from a three year Wall Street Journal + * (WSJ) collection of 98,732 stories for syntactic annotation (see here for details). + */ +public class PennTreebankText extends TextDataset { + + private static final String VERSION = "1.0"; + private static final String ARTIFACT_ID = "penntreebank-unlabeled-processed"; + + /** + * Creates a new instance of {@link PennTreebankText} with the given necessary configurations. + * + * @param builder a builder with the necessary configurations + */ + PennTreebankText(Builder builder) { + super(builder); + this.usage = builder.usage; + mrl = builder.getMrl(); + } + + /** + * Creates a builder to build a {@link PennTreebankText}. + * + * @return a new {@link PennTreebankText.Builder} object + */ + public static Builder builder() { + return new Builder(); + } + + /** {@inheritDoc} */ + @Override + public Record get(NDManager manager, long index) throws IOException { + NDList data = new NDList(); + NDList labels = null; + data.add(sourceTextData.getEmbedding(manager, index)); + return new Record(data, labels); + } + + /** {@inheritDoc} */ + @Override + protected long availableSize() { + return sourceTextData.getSize(); + } + + /** + * Prepares the dataset for use with tracked progress. + * + * @param progress the progress tracker + * @throws IOException for various exceptions depending on the dataset + */ + @Override + public void prepare(Progress progress) throws IOException, EmbeddingException { + if (prepared) { + return; + } + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); + Artifact.Item item; + switch (usage) { + case TRAIN: + item = artifact.getFiles().get("train"); + break; + case TEST: + item = artifact.getFiles().get("test"); + break; + case VALIDATION: + item = artifact.getFiles().get("valid"); + break; + default: + throw new UnsupportedOperationException("Unsupported usage type."); + } + Path path = mrl.getRepository().getFile(item, "").toAbsolutePath(); + List lineArray = new ArrayList<>(); + try (BufferedReader reader = Files.newBufferedReader(path)) { + String row; + while ((row = reader.readLine()) != null) { + lineArray.add(row); + } + } + preprocess(lineArray, true); + prepared = true; + } + + /** A builder to construct a {@link PennTreebankText} . */ + public static class Builder extends TextDataset.Builder { + + /** Constructs a new builder. */ + public Builder() { + repository = BasicDatasets.REPOSITORY; + groupId = BasicDatasets.GROUP_ID; + artifactId = ARTIFACT_ID; + usage = Dataset.Usage.TRAIN; + } + + /** + * Builds a new {@link PennTreebankText} object. + * + * @return the new {@link PennTreebankText} object + */ + public PennTreebankText build() { + return new PennTreebankText(this); + } + + MRL getMrl() { + return repository.dataset(Application.NLP.ANY, groupId, artifactId, VERSION); + } + + /** {@inheritDoc} */ + @Override + protected Builder self() { + return this; + } + } +} diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java new file mode 100644 index 00000000000..24d7c42bc6d --- /dev/null +++ b/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.basicdataset; + +import ai.djl.basicdataset.nlp.PennTreebankText; +import ai.djl.basicdataset.utils.TextData.Configuration; +import ai.djl.ndarray.NDManager; +import ai.djl.training.dataset.Dataset; +import ai.djl.training.dataset.Record; +import ai.djl.translate.TranslateException; +import java.io.IOException; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class PennTreebankTextTest { + + private static final int EMBEDDING_SIZE = 15; + + @Test + public void testPennTreebankText() throws IOException, TranslateException { + for (Dataset.Usage usage : + new Dataset.Usage[] { + Dataset.Usage.TRAIN, Dataset.Usage.VALIDATION, Dataset.Usage.TEST + }) { + try (NDManager manager = NDManager.newBaseManager()) { + PennTreebankText dataset = + PennTreebankText.builder() + .setSourceConfiguration( + new Configuration() + .setTextEmbedding( + TestUtils.getTextEmbedding( + manager, EMBEDDING_SIZE)) + .setEmbeddingSize(EMBEDDING_SIZE)) + .setTargetConfiguration( + new Configuration() + .setTextEmbedding( + TestUtils.getTextEmbedding( + manager, EMBEDDING_SIZE)) + .setEmbeddingSize(EMBEDDING_SIZE)) + .setSampling(32, true) + .optLimit(100) + .optUsage(usage) + .build(); + dataset.prepare(); + Record record = dataset.get(manager, 0); + Assert.assertEquals(record.getData().get(0).getShape().get(1), 15); + Assert.assertNull(record.getLabels()); + } + } + } +} diff --git a/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank-unlabeled-processed/metadata.json b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank-unlabeled-processed/metadata.json new file mode 100644 index 00000000000..8071911c026 --- /dev/null +++ b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank-unlabeled-processed/metadata.json @@ -0,0 +1,40 @@ +{ + "metadataVersion": "0.2", + "resourceType": "dataset", + "application": "nlp", + "groupId": "ai.djl.basicdataset", + "artifactId": "penntreebank-unlabeled-processed", + "name": "penntreebank-unlabeled-processed", + "description": "The Penn Treebank (PTB) project selected 2,499 stories from a three year Wall Street Journal (WSJ) collection of 98,732 stories for syntactic annotation.", + "website": "https://catalog.ldc.upenn.edu/docs/LDC95T7/cl93.html", + "licenses": { + "license": { + "name": "LDC User Agreement for Non-Members", + "url": "https://catalog.ldc.upenn.edu/license/ldc-non-members-agreement.pdf" + } + }, + "artifacts": [ + { + "version": "1.0", + "snapshot": false, + "name": "penntreebank-unlabeled-processed", + "files": { + "train":{ + "uri" : "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt", + "sha1Hash": "f9ffb014fa33bd5730e5029697ad245184f3a678", + "size": 5101618 + }, + "test":{ + "uri" : "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt", + "sha1Hash": "5c15c548b42d80bce9332b788514e6635fb0226e", + "size": 449945 + }, + "valid":{ + "uri" : "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt", + "sha1Hash": "d9f5fed6afa5e1b82cd1e3e5f5040f6852940228", + "size": 399782 + } + } + } + ] +} \ No newline at end of file