From cbd922a57bec56fb96611006808055aee70b90a2 Mon Sep 17 00:00:00 2001 From: ayoola adedeji Date: Wed, 24 May 2023 16:20:36 +0100 Subject: [PATCH 01/14] Text splitter implementations --- .../textsplitter/CharacterTextSplitter.java | 30 +++ .../RecursiveCharacterTextSplitter.java | 63 +++++ .../parser/textsplitter/TextSplitter.java | 137 +++++++++++ .../langtorch/parser/TextSplitterTest.java | 226 ++++++++++++++++++ 4 files changed, 456 insertions(+) create mode 100644 src/main/java/ai/knowly/langtorch/parser/textsplitter/CharacterTextSplitter.java create mode 100644 src/main/java/ai/knowly/langtorch/parser/textsplitter/RecursiveCharacterTextSplitter.java create mode 100644 src/main/java/ai/knowly/langtorch/parser/textsplitter/TextSplitter.java create mode 100644 src/test/java/ai/knowly/langtorch/parser/TextSplitterTest.java diff --git a/src/main/java/ai/knowly/langtorch/parser/textsplitter/CharacterTextSplitter.java b/src/main/java/ai/knowly/langtorch/parser/textsplitter/CharacterTextSplitter.java new file mode 100644 index 00000000..a923bc1f --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/parser/textsplitter/CharacterTextSplitter.java @@ -0,0 +1,30 @@ +package ai.knowly.langtorch.parser.textsplitter; + +import javax.annotation.Nullable; +import java.util.Arrays; +import java.util.List; + +public class CharacterTextSplitter extends TextSplitter { + + public String separator = "\n\n"; + + public CharacterTextSplitter(@Nullable String separator, int chunkSize, int chunkOverlap) { + super(chunkSize, chunkOverlap); + if (separator != null) { + this.separator = separator; + } + } + + @Override + public List splitText(String text) { + List splits; + + if (this.separator != null) { + splits = Arrays.asList(text.split(this.separator)); + } else { + splits = Arrays.asList(text.split("")); + } + + return mergeSplits(splits, this.separator); + } +} diff --git a/src/main/java/ai/knowly/langtorch/parser/textsplitter/RecursiveCharacterTextSplitter.java b/src/main/java/ai/knowly/langtorch/parser/textsplitter/RecursiveCharacterTextSplitter.java new file mode 100644 index 00000000..96772746 --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/parser/textsplitter/RecursiveCharacterTextSplitter.java @@ -0,0 +1,63 @@ +package ai.knowly.langtorch.parser.textsplitter; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class RecursiveCharacterTextSplitter extends TextSplitter { + + private List separators = Arrays.asList("\n\n", "\n", " ", ""); + + public RecursiveCharacterTextSplitter(@Nullable List separators, int chunkSize, int chunkOverlap) { + super(chunkSize, chunkOverlap); + if (separators != null) { + this.separators = separators; + } + } + + @Override + public List splitText(String text) { + List finalChunks = new ArrayList<>(); + + // Get appropriate separator to use + String separator = separators.get(separators.size() - 1); + for (String s : separators) { + if (s.isEmpty() || text.contains(s)) { + separator = s; + break; + } + } + + // Now that we have the separator, split the text + String[] splits; + if (!separator.isEmpty()) { + splits = text.split(separator); + } else { + splits = text.split(""); + } + + // Now go merging things, recursively splitting longer texts + List goodSplits = new ArrayList<>(); + for (String s : splits) { + if (s.length() < chunkSize) { + goodSplits.add(s); + } else { + if (!goodSplits.isEmpty()) { + List mergedText = mergeSplits(goodSplits, separator); + finalChunks.addAll(mergedText); + goodSplits.clear(); + } + List otherInfo = splitText(s); + finalChunks.addAll(otherInfo); + + } + } + if (!goodSplits.isEmpty()) { + List mergedText = mergeSplits(goodSplits, separator); + finalChunks.addAll(mergedText); + } + return finalChunks; + + } +} diff --git a/src/main/java/ai/knowly/langtorch/parser/textsplitter/TextSplitter.java b/src/main/java/ai/knowly/langtorch/parser/textsplitter/TextSplitter.java new file mode 100644 index 00000000..79f4182e --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/parser/textsplitter/TextSplitter.java @@ -0,0 +1,137 @@ +package ai.knowly.langtorch.parser.textsplitter; + +import ai.knowly.langtorch.schema.io.DomainDocument; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public abstract class TextSplitter { + + public int chunkSize; + + public int chunkOverlap; + + public TextSplitter(int chunkSize, int chunkOverlap) { + this.chunkSize = chunkSize; + this.chunkOverlap = chunkOverlap; + if (this.chunkOverlap >= this.chunkSize) { + throw new IllegalArgumentException("chunkOverlap cannot be equal to or greater than chunkSize"); + } + } + + abstract public List splitText(String text); + + public List createDocuments(List texts, @Nullable List> metaDatas) { + List> _metadatas; + + if (metaDatas != null) { + _metadatas = metaDatas.size() > 0 ? metaDatas : new ArrayList<>(); + } else { + _metadatas = new ArrayList<>(); + } + ArrayList documents = new ArrayList<>(); + + for (int i = 0; i < texts.size(); i += 1) { + String text = texts.get(i); + int lineCounterIndex = 1; + String prevChunk = null; + + for (String chunk : splitText(text)) { + int numberOfIntermediateNewLines = 0; + if (prevChunk != null) { + int indexChunk = text.indexOf(chunk); + int indexEndPrevChunk = text.indexOf(prevChunk) + prevChunk.length(); + String removedNewlinesFromSplittingText = text.substring(indexChunk, indexEndPrevChunk); + numberOfIntermediateNewLines = removedNewlinesFromSplittingText.split("\n").length - 1; + } + lineCounterIndex += numberOfIntermediateNewLines; + int newLinesCount = chunk.split("\n").length - 1; + + Map loc; + //todo should we also check what type of object is "loc"? + if (_metadatas.get(i) != null) { + if (!_metadatas.get(i).isEmpty() && _metadatas.get(i).get("loc") != null) { + loc = new HashMap<>(_metadatas.get(i)); + } else { + loc = new HashMap<>(); + } + } else { + loc = new HashMap<>(); + } + + loc.put("from", String.valueOf(lineCounterIndex)); + loc.put("to", String.valueOf(lineCounterIndex + newLinesCount)); + + Map metadataWithLinesNumber = new HashMap<>(); + if (_metadatas.get(i) != null) { + metadataWithLinesNumber.putAll(_metadatas.get(i)); + } + metadataWithLinesNumber.putAll(loc); + + documents.add(new DomainDocument(chunk, metadataWithLinesNumber)); + lineCounterIndex += newLinesCount; + prevChunk = chunk; + } + } + return documents; + } + + public List splitDocuments(List documents) { + + List selectedDocs = documents.stream().filter(doc -> doc.getPageContent() != null).collect(Collectors.toList()); + + List texts = selectedDocs.stream().map(DomainDocument::getPageContent).collect(Collectors.toList()); + List> metaDatas = selectedDocs.stream().map(DomainDocument::getMetadata).collect(Collectors.toList()); + + return this.createDocuments(texts, metaDatas); + } + + @Nullable + private String joinDocs(List docs, String separator) { + String text = String.join(separator, docs); + return text.equals("") ? null : text; + } + + public List mergeSplits(List splits, String separator) { + List docs = new ArrayList<>(); + List currentDoc = new ArrayList<>(); + int total = 0; + + for (String d : splits) { + int _len = d.length(); + + if (total + _len + (currentDoc.size() > 0 ? separator.length() : 0) > this.chunkSize) { + if (total > this.chunkSize) { + System.out.println("Created a chunk of size " + total + ", which is longer than the specified " + this.chunkSize); + } + + if (currentDoc.size() > 0) { + String doc = joinDocs(currentDoc, separator); + if (doc != null) { + docs.add(doc); + } + + while (total > this.chunkOverlap || (total + _len > this.chunkSize && total > 0)) { + total -= currentDoc.get(0).length(); + currentDoc.remove(0); + } + } + } + + currentDoc.add(d); + total += _len; + } + + String doc = joinDocs(currentDoc, separator); + if (doc != null) { + docs.add(doc); + } + + return docs; + } + +} diff --git a/src/test/java/ai/knowly/langtorch/parser/TextSplitterTest.java b/src/test/java/ai/knowly/langtorch/parser/TextSplitterTest.java new file mode 100644 index 00000000..d80a451f --- /dev/null +++ b/src/test/java/ai/knowly/langtorch/parser/TextSplitterTest.java @@ -0,0 +1,226 @@ +package ai.knowly.langtorch.parser; + +import ai.knowly.langtorch.parser.textsplitter.CharacterTextSplitter; +import ai.knowly.langtorch.parser.textsplitter.RecursiveCharacterTextSplitter; +import ai.knowly.langtorch.parser.textsplitter.TextSplitter; +import ai.knowly.langtorch.schema.io.DomainDocument; +import org.junit.Assert; +import org.junit.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.*; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class TextSplitterTest { + + @Test + public void testCharacterTextSplitter_splitByCharacterCount(){ + // Arrange + String text = "foo bar baz 123"; + CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 7, 3); + + // Act + List result = splitter.splitText(text); + + // Assert + List expected = new ArrayList<>(Arrays.asList("foo bar", "bar baz", "baz 123")); + + assertEquals(expected, result); + } + + @Test + public void testCharacterTextSplitter_splitByCharacterCountWithNoEmptyDocuments() { + // Arrange + String text = "foo bar"; + CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 2, 0); + + // Act + List result = splitter.splitText(text); + + // Assert + List expected = new ArrayList<>(Arrays.asList("foo", "bar")); + + assertEquals(expected, result); + } + + @Test + public void testCharacterTextSplitter_splitByCharacterCountLongWords() { + // Arrange + String text = "foo bar baz a a"; + CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 3, 1); + + // Act + List result = splitter.splitText(text); + + // Assert + List expected = new ArrayList<>(Arrays.asList("foo", "bar", "baz", "a a")); + + assertEquals(expected, result); + } + + @Test + public void testCharacterTextSplitter_splitByCharacterCountShorterWordsFirst() { + // Arrange + String text = "a a foo bar baz"; + CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 3, 1); + + // Act + List result = splitter.splitText(text); + + // Assert + List expected = new ArrayList<>(Arrays.asList("a a", "foo", "bar", "baz")); + + assertEquals(expected, result); + } + + @Test + public void testCharacterTextSplitter_splitByCharactersSplitsNotFoundEasily() { + // Arrange + String text = "foo bar baz 123"; + CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 1, 0); + + // Act + List result = splitter.splitText(text); + + // Assert + List expected = new ArrayList<>(Arrays.asList("foo", "bar", "baz", "123")); + + assertEquals(expected, result); + } + + @Test(expected = IllegalArgumentException.class) + public void testCharacterTextSplitter_invalidArguments() { + // Arrange + int chunkSize = 2; + int chunkOverlap = 4; + + // Act + new CharacterTextSplitter(null, chunkSize, chunkOverlap); + + // Expect IllegalArgumentException to be thrown + } + + //TODO, this unit test will need improving. atm it only checks that length of our list of documents, it does not check the contents + @Test + public void testCharacterTextSplitter_createDocuments() { + // Arrange + List texts = Arrays.asList("foo bar", "baz"); + CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 3, 0); + Map metadata = new HashMap<>(); + + Map loc = new HashMap<>(); + loc.put("from", String.valueOf(1)); + loc.put("to", String.valueOf(1)); + + // Act + List docs = splitter.createDocuments(texts, Arrays.asList(metadata, metadata)); + + // Assert + List expectedDocs = Arrays.asList( + new DomainDocument("foo", metadata), + new DomainDocument("bar", metadata), + new DomainDocument("baz", metadata) + ); + + assertEquals(expectedDocs.size(), docs.size()); + } + + //TODO, this unit test will need improving. atm it only checks that length of our list of documents, it does not check the contents + @Test + public void testCharacterTextSplitter_createDocumentsWithMetadata() { + // Arrange + List texts = Arrays.asList("foo bar", "baz"); + CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 3, 0); + List> metadataList = Arrays.asList( + new HashMap() {{ + put("source", "1"); + }}, + new HashMap() {{ + put("source", "2"); + }} + ); + + // Act + List docs = splitter.createDocuments(texts, metadataList); + + // Assert + List expectedDocs = Arrays.asList( + new DomainDocument("foo", new HashMap() {{ + put("source", "1"); + put("from", String.valueOf(1)); + put("to", String.valueOf(1)); + }}), + new DomainDocument("bar", new HashMap() {{ + put("source", "1"); + put("from", String.valueOf(1)); + put("to", String.valueOf(1)); + }}), + new DomainDocument("baz", new HashMap() {{ + put("source", "2"); + put("from", String.valueOf(1)); + put("to", String.valueOf(1)); + }}) + ); + + assertEquals(expectedDocs.size(), docs.size()); + } + + + @Test + public void testRecursiveCharacterTextSplitter_iterativeTextSplitter() { + // Arrange + String text = "Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.\nThis is a weird text to write, but gotta test the splittingggg some how.\n\nBye!\n\n-H."; + RecursiveCharacterTextSplitter splitter = new RecursiveCharacterTextSplitter(null, 10, 1); + + // Act + List output = splitter.splitText(text); + + // Assert + List expectedOutput = Arrays.asList( + "Hi.", + "I'm", + "Harrison.", + "How? Are?", + "You?", + "Okay then f", + "f f f f.", + "This is a", + "a weird", + "text to", + "write, but", + "gotta test", + "the", + "splittingg", + "ggg", + "some how.", + "Bye!\n\n-H." + ); + + assertEquals(expectedOutput, output); + } + + @Test + public void testTextSplitter_iterativeTextSplitter_linesLoc() { + // Arrange + String text = "Hi.\nI'm Harrison.\n\nHow?\na\nb"; + RecursiveCharacterTextSplitter splitter = new RecursiveCharacterTextSplitter(null, 20, 1); + + // Act + List docs = splitter.createDocuments(Collections.singletonList(text), null); + + // Assert + DomainDocument doc1 = new DomainDocument("Hi.\nI'm Harrison.", null); + DomainDocument doc2 = new DomainDocument("How?\na\nb", null); + List expectedDocs = Arrays.asList(doc1, doc2); + + assertEquals(expectedDocs.size(), docs.size()); + } + + + + + + +} From d832f5fd694f27c2ceefb9e597f82909757c5dc2 Mon Sep 17 00:00:00 2001 From: ayoola adedeji Date: Wed, 24 May 2023 16:29:28 +0100 Subject: [PATCH 02/14] Added DomainDocument object --- .../langtorch/schema/io/DomainDocument.java | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 src/main/java/ai/knowly/langtorch/schema/io/DomainDocument.java diff --git a/src/main/java/ai/knowly/langtorch/schema/io/DomainDocument.java b/src/main/java/ai/knowly/langtorch/schema/io/DomainDocument.java new file mode 100644 index 00000000..b670e181 --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/schema/io/DomainDocument.java @@ -0,0 +1,25 @@ +package ai.knowly.langtorch.schema.io; + +import javax.annotation.Nullable; +import java.util.Map; + +public class DomainDocument implements Input, Output { + + private final String pageContent; + + @Nullable + private final Map metadata; + + public DomainDocument(String pageContent, @Nullable Map metadata) { + this.pageContent = pageContent; + this.metadata = metadata; + } + + public String getPageContent() { + return pageContent; + } + + public Map getMetadata() { + return metadata; + } +} From e5c2882b8c405105702caef6da976b2a9356a758 Mon Sep 17 00:00:00 2001 From: ayoola adedeji Date: Thu, 25 May 2023 11:50:54 +0100 Subject: [PATCH 03/14] Resolved code review comments --- build.gradle | 3 + .../text/ChatCompletionLLMCapability.java | 2 +- .../text/TextCompletionTextLLMCapability.java | 2 +- .../openai/PromptTemplateTextCapability.java | 4 +- .../module/openai/SimpleChatCapability.java | 4 +- .../module/openai/SimpleTextCapability.java | 4 +- .../textsplitter/CharacterTextSplitter.java | 30 ------- .../parser/ChatMessageToStringParser.java | 2 +- .../{ => preprocessing}/parser/Parser.java | 2 +- .../PromptTemplateToSingleTextParser.java | 2 +- .../parser/SingleTextToStringParser.java | 2 +- .../StringToMultiChatMessageParser.java | 2 +- .../parser/StringToSingleTextParser.java | 2 +- .../splitter/text/CharacterTextSplitter.java | 48 ++++++++++++ .../text}/RecursiveCharacterTextSplitter.java | 24 +++++- .../splitter/text}/TextSplitter.java | 45 +++++------ .../knowly/langtorch/schema/io/Metadatas.java | 17 ++++ .../embeddings/OpenAIEmbeddingTest.java | 2 +- .../PromptTemplateToSingleTextParserTest.java | 3 +- .../splitter/text}/TextSplitterTest.java | 78 ++++++++++--------- 20 files changed, 168 insertions(+), 110 deletions(-) delete mode 100644 src/main/java/ai/knowly/langtorch/parser/textsplitter/CharacterTextSplitter.java rename src/main/java/ai/knowly/langtorch/{ => preprocessing}/parser/ChatMessageToStringParser.java (88%) rename src/main/java/ai/knowly/langtorch/{ => preprocessing}/parser/Parser.java (60%) rename src/main/java/ai/knowly/langtorch/{ => preprocessing}/parser/PromptTemplateToSingleTextParser.java (90%) rename src/main/java/ai/knowly/langtorch/{ => preprocessing}/parser/SingleTextToStringParser.java (88%) rename src/main/java/ai/knowly/langtorch/{ => preprocessing}/parser/StringToMultiChatMessageParser.java (91%) rename src/main/java/ai/knowly/langtorch/{ => preprocessing}/parser/StringToSingleTextParser.java (88%) create mode 100644 src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/CharacterTextSplitter.java rename src/main/java/ai/knowly/langtorch/{parser/textsplitter => preprocessing/splitter/text}/RecursiveCharacterTextSplitter.java (56%) rename src/main/java/ai/knowly/langtorch/{parser/textsplitter => preprocessing/splitter/text}/TextSplitter.java (71%) create mode 100644 src/main/java/ai/knowly/langtorch/schema/io/Metadatas.java rename src/test/java/ai/knowly/{langtoch => langtorch}/embeddings/OpenAIEmbeddingTest.java (97%) rename src/test/java/ai/knowly/langtorch/{ => preprocessing}/parser/PromptTemplateToSingleTextParserTest.java (91%) rename src/test/java/ai/knowly/langtorch/{parser => preprocessing/splitter/text}/TextSplitterTest.java (85%) diff --git a/build.gradle b/build.gradle index 348afb78..a2943e63 100644 --- a/build.gradle +++ b/build.gradle @@ -116,6 +116,9 @@ dependencies { testImplementation 'com.squareup.retrofit2:retrofit-mock:2.9.0' implementation 'com.squareup.okhttp3:logging-interceptor:4.9.2' implementation "com.squareup.retrofit2:converter-gson:2.9.0" + // Apache commons lang + implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.0' + } // Testing related dependencies diff --git a/src/main/java/ai/knowly/langtorch/capability/modality/text/ChatCompletionLLMCapability.java b/src/main/java/ai/knowly/langtorch/capability/modality/text/ChatCompletionLLMCapability.java index 6a261bc4..f7b4089c 100644 --- a/src/main/java/ai/knowly/langtorch/capability/modality/text/ChatCompletionLLMCapability.java +++ b/src/main/java/ai/knowly/langtorch/capability/modality/text/ChatCompletionLLMCapability.java @@ -5,7 +5,7 @@ import ai.knowly.langtorch.processor.module.Processor; import ai.knowly.langtorch.store.memory.Memory; -import ai.knowly.langtorch.parser.Parser; +import ai.knowly.langtorch.preprocessing.parser.Parser; import ai.knowly.langtorch.schema.chat.ChatMessage; import ai.knowly.langtorch.schema.text.MultiChatMessage; import ai.knowly.langtorch.schema.memory.MemoryKey; diff --git a/src/main/java/ai/knowly/langtorch/capability/modality/text/TextCompletionTextLLMCapability.java b/src/main/java/ai/knowly/langtorch/capability/modality/text/TextCompletionTextLLMCapability.java index 9ce4bde1..fd917f9e 100644 --- a/src/main/java/ai/knowly/langtorch/capability/modality/text/TextCompletionTextLLMCapability.java +++ b/src/main/java/ai/knowly/langtorch/capability/modality/text/TextCompletionTextLLMCapability.java @@ -5,7 +5,7 @@ import ai.knowly.langtorch.processor.module.Processor; import ai.knowly.langtorch.store.memory.Memory; -import ai.knowly.langtorch.parser.Parser; +import ai.knowly.langtorch.preprocessing.parser.Parser; import ai.knowly.langtorch.schema.text.SingleText; import ai.knowly.langtorch.schema.memory.MemoryKey; import ai.knowly.langtorch.schema.memory.MemoryValue; diff --git a/src/main/java/ai/knowly/langtorch/capability/module/openai/PromptTemplateTextCapability.java b/src/main/java/ai/knowly/langtorch/capability/module/openai/PromptTemplateTextCapability.java index 68a470b1..6c69a1df 100644 --- a/src/main/java/ai/knowly/langtorch/capability/module/openai/PromptTemplateTextCapability.java +++ b/src/main/java/ai/knowly/langtorch/capability/module/openai/PromptTemplateTextCapability.java @@ -2,8 +2,8 @@ import ai.knowly.langtorch.capability.modality.text.TextCompletionTextLLMCapability; import ai.knowly.langtorch.processor.module.openai.text.OpenAITextProcessor; -import ai.knowly.langtorch.parser.SingleTextToStringParser; -import ai.knowly.langtorch.parser.StringToSingleTextParser; +import ai.knowly.langtorch.preprocessing.parser.SingleTextToStringParser; +import ai.knowly.langtorch.preprocessing.parser.StringToSingleTextParser; import ai.knowly.langtorch.prompt.template.PromptTemplate; import java.util.Map; import java.util.Optional; diff --git a/src/main/java/ai/knowly/langtorch/capability/module/openai/SimpleChatCapability.java b/src/main/java/ai/knowly/langtorch/capability/module/openai/SimpleChatCapability.java index dd5c2af9..22b6b640 100644 --- a/src/main/java/ai/knowly/langtorch/capability/module/openai/SimpleChatCapability.java +++ b/src/main/java/ai/knowly/langtorch/capability/module/openai/SimpleChatCapability.java @@ -2,8 +2,8 @@ import ai.knowly.langtorch.capability.modality.text.ChatCompletionLLMCapability; import ai.knowly.langtorch.processor.module.openai.chat.OpenAIChatProcessor; -import ai.knowly.langtorch.parser.ChatMessageToStringParser; -import ai.knowly.langtorch.parser.StringToMultiChatMessageParser; +import ai.knowly.langtorch.preprocessing.parser.ChatMessageToStringParser; +import ai.knowly.langtorch.preprocessing.parser.StringToMultiChatMessageParser; import java.util.Optional; /** A simple chat capability unit that leverages openai api to generate response */ diff --git a/src/main/java/ai/knowly/langtorch/capability/module/openai/SimpleTextCapability.java b/src/main/java/ai/knowly/langtorch/capability/module/openai/SimpleTextCapability.java index 60ae634b..0d3ca0d0 100644 --- a/src/main/java/ai/knowly/langtorch/capability/module/openai/SimpleTextCapability.java +++ b/src/main/java/ai/knowly/langtorch/capability/module/openai/SimpleTextCapability.java @@ -2,8 +2,8 @@ import ai.knowly.langtorch.capability.modality.text.TextCompletionTextLLMCapability; import ai.knowly.langtorch.processor.module.openai.text.OpenAITextProcessor; -import ai.knowly.langtorch.parser.SingleTextToStringParser; -import ai.knowly.langtorch.parser.StringToSingleTextParser; +import ai.knowly.langtorch.preprocessing.parser.SingleTextToStringParser; +import ai.knowly.langtorch.preprocessing.parser.StringToSingleTextParser; /** A simple text capability unit that leverages openai api to generate response */ public class SimpleTextCapability extends TextCompletionTextLLMCapability { diff --git a/src/main/java/ai/knowly/langtorch/parser/textsplitter/CharacterTextSplitter.java b/src/main/java/ai/knowly/langtorch/parser/textsplitter/CharacterTextSplitter.java deleted file mode 100644 index a923bc1f..00000000 --- a/src/main/java/ai/knowly/langtorch/parser/textsplitter/CharacterTextSplitter.java +++ /dev/null @@ -1,30 +0,0 @@ -package ai.knowly.langtorch.parser.textsplitter; - -import javax.annotation.Nullable; -import java.util.Arrays; -import java.util.List; - -public class CharacterTextSplitter extends TextSplitter { - - public String separator = "\n\n"; - - public CharacterTextSplitter(@Nullable String separator, int chunkSize, int chunkOverlap) { - super(chunkSize, chunkOverlap); - if (separator != null) { - this.separator = separator; - } - } - - @Override - public List splitText(String text) { - List splits; - - if (this.separator != null) { - splits = Arrays.asList(text.split(this.separator)); - } else { - splits = Arrays.asList(text.split("")); - } - - return mergeSplits(splits, this.separator); - } -} diff --git a/src/main/java/ai/knowly/langtorch/parser/ChatMessageToStringParser.java b/src/main/java/ai/knowly/langtorch/preprocessing/parser/ChatMessageToStringParser.java similarity index 88% rename from src/main/java/ai/knowly/langtorch/parser/ChatMessageToStringParser.java rename to src/main/java/ai/knowly/langtorch/preprocessing/parser/ChatMessageToStringParser.java index cf53e6ac..f8000de3 100644 --- a/src/main/java/ai/knowly/langtorch/parser/ChatMessageToStringParser.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/parser/ChatMessageToStringParser.java @@ -1,4 +1,4 @@ -package ai.knowly.langtorch.parser; +package ai.knowly.langtorch.preprocessing.parser; import ai.knowly.langtorch.schema.chat.ChatMessage; diff --git a/src/main/java/ai/knowly/langtorch/parser/Parser.java b/src/main/java/ai/knowly/langtorch/preprocessing/parser/Parser.java similarity index 60% rename from src/main/java/ai/knowly/langtorch/parser/Parser.java rename to src/main/java/ai/knowly/langtorch/preprocessing/parser/Parser.java index a208fd07..9712e944 100644 --- a/src/main/java/ai/knowly/langtorch/parser/Parser.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/parser/Parser.java @@ -1,4 +1,4 @@ -package ai.knowly.langtorch.parser; +package ai.knowly.langtorch.preprocessing.parser; @FunctionalInterface public interface Parser { diff --git a/src/main/java/ai/knowly/langtorch/parser/PromptTemplateToSingleTextParser.java b/src/main/java/ai/knowly/langtorch/preprocessing/parser/PromptTemplateToSingleTextParser.java similarity index 90% rename from src/main/java/ai/knowly/langtorch/parser/PromptTemplateToSingleTextParser.java rename to src/main/java/ai/knowly/langtorch/preprocessing/parser/PromptTemplateToSingleTextParser.java index f926cfe3..d1ac01dc 100644 --- a/src/main/java/ai/knowly/langtorch/parser/PromptTemplateToSingleTextParser.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/parser/PromptTemplateToSingleTextParser.java @@ -1,4 +1,4 @@ -package ai.knowly.langtorch.parser; +package ai.knowly.langtorch.preprocessing.parser; import ai.knowly.langtorch.schema.text.SingleText; import ai.knowly.langtorch.prompt.template.PromptTemplate; diff --git a/src/main/java/ai/knowly/langtorch/parser/SingleTextToStringParser.java b/src/main/java/ai/knowly/langtorch/preprocessing/parser/SingleTextToStringParser.java similarity index 88% rename from src/main/java/ai/knowly/langtorch/parser/SingleTextToStringParser.java rename to src/main/java/ai/knowly/langtorch/preprocessing/parser/SingleTextToStringParser.java index 7143227c..c0b87fdd 100644 --- a/src/main/java/ai/knowly/langtorch/parser/SingleTextToStringParser.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/parser/SingleTextToStringParser.java @@ -1,4 +1,4 @@ -package ai.knowly.langtorch.parser; +package ai.knowly.langtorch.preprocessing.parser; import ai.knowly.langtorch.schema.text.SingleText; diff --git a/src/main/java/ai/knowly/langtorch/parser/StringToMultiChatMessageParser.java b/src/main/java/ai/knowly/langtorch/preprocessing/parser/StringToMultiChatMessageParser.java similarity index 91% rename from src/main/java/ai/knowly/langtorch/parser/StringToMultiChatMessageParser.java rename to src/main/java/ai/knowly/langtorch/preprocessing/parser/StringToMultiChatMessageParser.java index 15e98b0b..c0bf8503 100644 --- a/src/main/java/ai/knowly/langtorch/parser/StringToMultiChatMessageParser.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/parser/StringToMultiChatMessageParser.java @@ -1,4 +1,4 @@ -package ai.knowly.langtorch.parser; +package ai.knowly.langtorch.preprocessing.parser; import static ai.knowly.langtorch.schema.chat.Role.USER; diff --git a/src/main/java/ai/knowly/langtorch/parser/StringToSingleTextParser.java b/src/main/java/ai/knowly/langtorch/preprocessing/parser/StringToSingleTextParser.java similarity index 88% rename from src/main/java/ai/knowly/langtorch/parser/StringToSingleTextParser.java rename to src/main/java/ai/knowly/langtorch/preprocessing/parser/StringToSingleTextParser.java index 80390580..6d30da37 100644 --- a/src/main/java/ai/knowly/langtorch/parser/StringToSingleTextParser.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/parser/StringToSingleTextParser.java @@ -1,4 +1,4 @@ -package ai.knowly.langtorch.parser; +package ai.knowly.langtorch.preprocessing.parser; import ai.knowly.langtorch.schema.text.SingleText; diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/CharacterTextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/CharacterTextSplitter.java new file mode 100644 index 00000000..be17172d --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/CharacterTextSplitter.java @@ -0,0 +1,48 @@ +package ai.knowly.langtorch.preprocessing.splitter.text; + +import org.apache.commons.lang3.StringUtils; + +import javax.annotation.Nullable; +import java.util.Arrays; +import java.util.List; +/** + The CharacterTextSplitter class is a concrete implementation of the TextSplitter abstract class + that splits text into chunks based on a specified separator. + */ + +public class CharacterTextSplitter extends TextSplitter { + + public String separator = "\n\n"; + + /** + + Constructs a CharacterTextSplitter object with the given separator, chunk size, and chunk overlap. + If the separator is null, the default separator "\n\n" is used. + @param separator The separator used for splitting the text into chunks. + @param chunkSize The size of each chunk. + @param chunkOverlap The amount of overlap between adjacent chunks. + */ + public CharacterTextSplitter(@Nullable String separator, int chunkSize, int chunkOverlap) { + super(chunkSize, chunkOverlap); + if (separator != null) { + this.separator = separator; + } + } + + /** + Splits the given text into chunks based on the specified separator. + @param text The text to be split into chunks. + @return A list of strings representing the chunks of the text. + */ + @Override + public List splitText(String text) { + List splits; + + if (StringUtils.isNotEmpty(this.separator)) { + splits = Arrays.asList(StringUtils.splitByWholeSeparatorPreserveAllTokens(text, this.separator)); + } else { + splits = Arrays.asList(StringUtils.splitByWholeSeparatorPreserveAllTokens(text, "")); + } + return mergeSplits(splits, this.separator); + } +} diff --git a/src/main/java/ai/knowly/langtorch/parser/textsplitter/RecursiveCharacterTextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveCharacterTextSplitter.java similarity index 56% rename from src/main/java/ai/knowly/langtorch/parser/textsplitter/RecursiveCharacterTextSplitter.java rename to src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveCharacterTextSplitter.java index 96772746..76252968 100644 --- a/src/main/java/ai/knowly/langtorch/parser/textsplitter/RecursiveCharacterTextSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveCharacterTextSplitter.java @@ -1,14 +1,27 @@ -package ai.knowly.langtorch.parser.textsplitter; +package ai.knowly.langtorch.preprocessing.splitter.text; import javax.annotation.Nullable; import java.util.ArrayList; import java.util.Arrays; import java.util.List; - +/** + The RecursiveCharacterTextSplitter class is a concrete implementation of the TextSplitter abstract class + that recursively splits text into chunks using a set of separators. + It applies a recursive splitting approach to handle longer texts by examining the text and selecting the appropriate + separator from the list of separators based on their presence in the text. + If the text is longer than the chunk size, it recursively splits the longer portions into smaller chunks. This recursive process continues until the chunks reach a size smaller than the specified chunk size. + */ public class RecursiveCharacterTextSplitter extends TextSplitter { private List separators = Arrays.asList("\n\n", "\n", " ", ""); + /** + Constructs a RecursiveCharacterTextSplitter object with the given list of separators, chunk size, and chunk overlap. + If the separators list is null, the default list containing separators "\n\n", "\n", " ", and "" is used. + @param separators The list of separators used for splitting the text into chunks. + @param chunkSize The size of each chunk. + @param chunkOverlap The amount of overlap between adjacent chunks. + */ public RecursiveCharacterTextSplitter(@Nullable List separators, int chunkSize, int chunkOverlap) { super(chunkSize, chunkOverlap); if (separators != null) { @@ -16,6 +29,13 @@ public RecursiveCharacterTextSplitter(@Nullable List separators, int chu } } + /** + Splits the given text into chunks using a recursive splitting approach. + It selects an appropriate separator from the list of separators based on the presence of each separator in the text. + It recursively splits longer pieces of text into smaller chunks. + @param text The text to be split into chunks. + @return A list of strings representing the chunks of the text. + */ @Override public List splitText(String text) { List finalChunks = new ArrayList<>(); diff --git a/src/main/java/ai/knowly/langtorch/parser/textsplitter/TextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java similarity index 71% rename from src/main/java/ai/knowly/langtorch/parser/textsplitter/TextSplitter.java rename to src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java index 79f4182e..c550e042 100644 --- a/src/main/java/ai/knowly/langtorch/parser/textsplitter/TextSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java @@ -1,14 +1,16 @@ -package ai.knowly.langtorch.parser.textsplitter; +package ai.knowly.langtorch.preprocessing.splitter.text; import ai.knowly.langtorch.schema.io.DomainDocument; +import ai.knowly.langtorch.schema.io.Metadatas; +import org.apache.commons.lang3.StringUtils; import javax.annotation.Nullable; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.stream.Collectors; +/** + * The TextSplitter class provides functionality for splitting text into chunks. + */ public abstract class TextSplitter { public int chunkSize; @@ -25,14 +27,10 @@ public TextSplitter(int chunkSize, int chunkOverlap) { abstract public List splitText(String text); - public List createDocuments(List texts, @Nullable List> metaDatas) { - List> _metadatas; + public List createDocuments(List texts, Optional docMetadatas) { + Metadatas metadatas; - if (metaDatas != null) { - _metadatas = metaDatas.size() > 0 ? metaDatas : new ArrayList<>(); - } else { - _metadatas = new ArrayList<>(); - } + metadatas = docMetadatas.filter(value -> value.getValues().size() > 0).orElseGet(() -> new Metadatas(new ArrayList<>())); ArrayList documents = new ArrayList<>(); for (int i = 0; i < texts.size(); i += 1) { @@ -43,19 +41,18 @@ public List createDocuments(List texts, @Nullable List loc; - //todo should we also check what type of object is "loc"? - if (_metadatas.get(i) != null) { - if (!_metadatas.get(i).isEmpty() && _metadatas.get(i).get("loc") != null) { - loc = new HashMap<>(_metadatas.get(i)); + if (i < metadatas.getValues().size() && metadatas.getValues().get(i) != null) { + if (!metadatas.getValues().get(i).isEmpty() && metadatas.getValues().get(i).get("loc") != null) { + loc = new HashMap<>(metadatas.getValues().get(i)); } else { loc = new HashMap<>(); } @@ -67,8 +64,8 @@ public List createDocuments(List texts, @Nullable List metadataWithLinesNumber = new HashMap<>(); - if (_metadatas.get(i) != null) { - metadataWithLinesNumber.putAll(_metadatas.get(i)); + if (i < metadatas.getValues().size() && metadatas.getValues().get(i) != null) { + metadataWithLinesNumber.putAll(metadatas.getValues().get(i)); } metadataWithLinesNumber.putAll(loc); @@ -87,7 +84,7 @@ public List splitDocuments(List documents) { List texts = selectedDocs.stream().map(DomainDocument::getPageContent).collect(Collectors.toList()); List> metaDatas = selectedDocs.stream().map(DomainDocument::getMetadata).collect(Collectors.toList()); - return this.createDocuments(texts, metaDatas); + return this.createDocuments(texts, Optional.of(new Metadatas(metaDatas))); } @Nullable diff --git a/src/main/java/ai/knowly/langtorch/schema/io/Metadatas.java b/src/main/java/ai/knowly/langtorch/schema/io/Metadatas.java new file mode 100644 index 00000000..41e8002f --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/schema/io/Metadatas.java @@ -0,0 +1,17 @@ +package ai.knowly.langtorch.schema.io; + +import java.util.List; +import java.util.Map; + +public class Metadatas { + + private final List> values; + + public Metadatas(List> values) { + this.values = values; + } + + public List> getValues() { + return values; + } +} diff --git a/src/test/java/ai/knowly/langtoch/embeddings/OpenAIEmbeddingTest.java b/src/test/java/ai/knowly/langtorch/embeddings/OpenAIEmbeddingTest.java similarity index 97% rename from src/test/java/ai/knowly/langtoch/embeddings/OpenAIEmbeddingTest.java rename to src/test/java/ai/knowly/langtorch/embeddings/OpenAIEmbeddingTest.java index 0e6483b0..85316c7f 100644 --- a/src/test/java/ai/knowly/langtoch/embeddings/OpenAIEmbeddingTest.java +++ b/src/test/java/ai/knowly/langtorch/embeddings/OpenAIEmbeddingTest.java @@ -1,4 +1,4 @@ -package ai.knowly.langtoch.embeddings; +package ai.knowly.langtorch.embeddings; import ai.knowly.langtorch.processor.llm.openai.service.OpenAIService; import ai.knowly.langtorch.processor.module.openai.embeddings.OpenAIEmbeddingsProcessor; diff --git a/src/test/java/ai/knowly/langtorch/parser/PromptTemplateToSingleTextParserTest.java b/src/test/java/ai/knowly/langtorch/preprocessing/parser/PromptTemplateToSingleTextParserTest.java similarity index 91% rename from src/test/java/ai/knowly/langtorch/parser/PromptTemplateToSingleTextParserTest.java rename to src/test/java/ai/knowly/langtorch/preprocessing/parser/PromptTemplateToSingleTextParserTest.java index 1972cb07..e6d67875 100644 --- a/src/test/java/ai/knowly/langtorch/parser/PromptTemplateToSingleTextParserTest.java +++ b/src/test/java/ai/knowly/langtorch/preprocessing/parser/PromptTemplateToSingleTextParserTest.java @@ -1,5 +1,6 @@ -package ai.knowly.langtorch.parser; +package ai.knowly.langtorch.preprocessing.parser; +import ai.knowly.langtorch.preprocessing.parser.PromptTemplateToSingleTextParser; import ai.knowly.langtorch.prompt.template.PromptTemplate; import ai.knowly.langtorch.schema.text.SingleText; import com.google.common.truth.Truth; diff --git a/src/test/java/ai/knowly/langtorch/parser/TextSplitterTest.java b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java similarity index 85% rename from src/test/java/ai/knowly/langtorch/parser/TextSplitterTest.java rename to src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java index d80a451f..e8fc866e 100644 --- a/src/test/java/ai/knowly/langtorch/parser/TextSplitterTest.java +++ b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java @@ -1,13 +1,10 @@ -package ai.knowly.langtorch.parser; +package ai.knowly.langtorch.preprocessing.splitter.text; -import ai.knowly.langtorch.parser.textsplitter.CharacterTextSplitter; -import ai.knowly.langtorch.parser.textsplitter.RecursiveCharacterTextSplitter; -import ai.knowly.langtorch.parser.textsplitter.TextSplitter; +import ai.knowly.langtorch.preprocessing.splitter.text.CharacterTextSplitter; +import ai.knowly.langtorch.preprocessing.splitter.text.RecursiveCharacterTextSplitter; import ai.knowly.langtorch.schema.io.DomainDocument; -import org.junit.Assert; +import ai.knowly.langtorch.schema.io.Metadatas; import org.junit.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.junit.jupiter.MockitoExtension; import java.util.*; @@ -17,14 +14,14 @@ public class TextSplitterTest { @Test public void testCharacterTextSplitter_splitByCharacterCount(){ - // Arrange + // Arrange. String text = "foo bar baz 123"; CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 7, 3); - // Act + // Act. List result = splitter.splitText(text); - // Assert + // Assert. List expected = new ArrayList<>(Arrays.asList("foo bar", "bar baz", "baz 123")); assertEquals(expected, result); @@ -32,14 +29,14 @@ public void testCharacterTextSplitter_splitByCharacterCount(){ @Test public void testCharacterTextSplitter_splitByCharacterCountWithNoEmptyDocuments() { - // Arrange + // Arrange. String text = "foo bar"; CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 2, 0); - // Act + // Act. List result = splitter.splitText(text); - // Assert + // Assert. List expected = new ArrayList<>(Arrays.asList("foo", "bar")); assertEquals(expected, result); @@ -47,14 +44,14 @@ public void testCharacterTextSplitter_splitByCharacterCountWithNoEmptyDocuments( @Test public void testCharacterTextSplitter_splitByCharacterCountLongWords() { - // Arrange + // Arrange. String text = "foo bar baz a a"; CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 3, 1); - // Act + // Act. List result = splitter.splitText(text); - // Assert + // Assert. List expected = new ArrayList<>(Arrays.asList("foo", "bar", "baz", "a a")); assertEquals(expected, result); @@ -62,14 +59,14 @@ public void testCharacterTextSplitter_splitByCharacterCountLongWords() { @Test public void testCharacterTextSplitter_splitByCharacterCountShorterWordsFirst() { - // Arrange + // Arrange. String text = "a a foo bar baz"; CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 3, 1); - // Act + // Act. List result = splitter.splitText(text); - // Assert + // Assert. List expected = new ArrayList<>(Arrays.asList("a a", "foo", "bar", "baz")); assertEquals(expected, result); @@ -77,14 +74,14 @@ public void testCharacterTextSplitter_splitByCharacterCountShorterWordsFirst() { @Test public void testCharacterTextSplitter_splitByCharactersSplitsNotFoundEasily() { - // Arrange + // Arrange. String text = "foo bar baz 123"; CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 1, 0); - // Act + // Act. List result = splitter.splitText(text); - // Assert + // Assert. List expected = new ArrayList<>(Arrays.asList("foo", "bar", "baz", "123")); assertEquals(expected, result); @@ -92,7 +89,7 @@ public void testCharacterTextSplitter_splitByCharactersSplitsNotFoundEasily() { @Test(expected = IllegalArgumentException.class) public void testCharacterTextSplitter_invalidArguments() { - // Arrange + // Arrange. int chunkSize = 2; int chunkOverlap = 4; @@ -105,7 +102,7 @@ public void testCharacterTextSplitter_invalidArguments() { //TODO, this unit test will need improving. atm it only checks that length of our list of documents, it does not check the contents @Test public void testCharacterTextSplitter_createDocuments() { - // Arrange + // Arrange. List texts = Arrays.asList("foo bar", "baz"); CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 3, 0); Map metadata = new HashMap<>(); @@ -114,10 +111,11 @@ public void testCharacterTextSplitter_createDocuments() { loc.put("from", String.valueOf(1)); loc.put("to", String.valueOf(1)); - // Act - List docs = splitter.createDocuments(texts, Arrays.asList(metadata, metadata)); + Optional metadatas = Optional.of(new Metadatas(Arrays.asList(metadata, metadata))); + // Act. + List docs = splitter.createDocuments(texts, metadatas); - // Assert + // Assert. List expectedDocs = Arrays.asList( new DomainDocument("foo", metadata), new DomainDocument("bar", metadata), @@ -130,7 +128,7 @@ public void testCharacterTextSplitter_createDocuments() { //TODO, this unit test will need improving. atm it only checks that length of our list of documents, it does not check the contents @Test public void testCharacterTextSplitter_createDocumentsWithMetadata() { - // Arrange + // Arrange. List texts = Arrays.asList("foo bar", "baz"); CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 3, 0); List> metadataList = Arrays.asList( @@ -142,10 +140,13 @@ public void testCharacterTextSplitter_createDocumentsWithMetadata() { }} ); - // Act - List docs = splitter.createDocuments(texts, metadataList); + Optional metadatas = Optional.of(new Metadatas(metadataList)); + + + // Act. + List docs = splitter.createDocuments(texts, metadatas); - // Assert + // Assert. List expectedDocs = Arrays.asList( new DomainDocument("foo", new HashMap() {{ put("source", "1"); @@ -170,14 +171,14 @@ public void testCharacterTextSplitter_createDocumentsWithMetadata() { @Test public void testRecursiveCharacterTextSplitter_iterativeTextSplitter() { - // Arrange + // Arrange. String text = "Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.\nThis is a weird text to write, but gotta test the splittingggg some how.\n\nBye!\n\n-H."; RecursiveCharacterTextSplitter splitter = new RecursiveCharacterTextSplitter(null, 10, 1); - // Act + // Act. List output = splitter.splitText(text); - // Assert + // Assert. List expectedOutput = Arrays.asList( "Hi.", "I'm", @@ -203,14 +204,15 @@ public void testRecursiveCharacterTextSplitter_iterativeTextSplitter() { @Test public void testTextSplitter_iterativeTextSplitter_linesLoc() { - // Arrange + // Arrange. String text = "Hi.\nI'm Harrison.\n\nHow?\na\nb"; RecursiveCharacterTextSplitter splitter = new RecursiveCharacterTextSplitter(null, 20, 1); - // Act - List docs = splitter.createDocuments(Collections.singletonList(text), null); + Optional metadatas = Optional.ofNullable(null); + // Act. + List docs = splitter.createDocuments(Collections.singletonList(text), metadatas); - // Assert + // Assert. DomainDocument doc1 = new DomainDocument("Hi.\nI'm Harrison.", null); DomainDocument doc2 = new DomainDocument("How?\na\nb", null); List expectedDocs = Arrays.asList(doc1, doc2); From 83b67700794ed7556aaa77d367992c100399a092 Mon Sep 17 00:00:00 2001 From: ayoola adedeji Date: Thu, 25 May 2023 19:34:29 +0100 Subject: [PATCH 04/14] WIP: responding to code review comments --- build.gradle | 2 +- .../splitter/text/CharacterTextSplitter.java | 11 +++-------- .../text/RecursiveCharacterTextSplitter.java | 2 +- .../preprocessing/splitter/text/TextSplitter.java | 12 ++++++------ .../splitter/text/TextSplitterTest.java | 4 ++-- 5 files changed, 13 insertions(+), 18 deletions(-) diff --git a/build.gradle b/build.gradle index a2943e63..6f3b8b51 100644 --- a/build.gradle +++ b/build.gradle @@ -117,7 +117,7 @@ dependencies { implementation 'com.squareup.okhttp3:logging-interceptor:4.9.2' implementation "com.squareup.retrofit2:converter-gson:2.9.0" // Apache commons lang - implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.0' + implementation 'org.apache.commons:commons-lang3:3.0' } diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/CharacterTextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/CharacterTextSplitter.java index be17172d..21e4b593 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/CharacterTextSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/CharacterTextSplitter.java @@ -12,7 +12,7 @@ public class CharacterTextSplitter extends TextSplitter { - public String separator = "\n\n"; + private static String separator = "\n\n"; /** @@ -36,13 +36,8 @@ public CharacterTextSplitter(@Nullable String separator, int chunkSize, int chun */ @Override public List splitText(String text) { - List splits; - - if (StringUtils.isNotEmpty(this.separator)) { - splits = Arrays.asList(StringUtils.splitByWholeSeparatorPreserveAllTokens(text, this.separator)); - } else { - splits = Arrays.asList(StringUtils.splitByWholeSeparatorPreserveAllTokens(text, "")); - } + List splits = + Arrays.asList(StringUtils.splitByWholeSeparatorPreserveAllTokens(text, this.separator.isEmpty() ? "" : this.separator)); return mergeSplits(splits, this.separator); } } diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveCharacterTextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveCharacterTextSplitter.java index 76252968..694ffc77 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveCharacterTextSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveCharacterTextSplitter.java @@ -13,7 +13,7 @@ */ public class RecursiveCharacterTextSplitter extends TextSplitter { - private List separators = Arrays.asList("\n\n", "\n", " ", ""); + private static List separators = Arrays.asList("\n\n", "\n", " ", ""); /** Constructs a RecursiveCharacterTextSplitter object with the given list of separators, chunk size, and chunk overlap. diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java index c550e042..71d52db0 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java @@ -13,9 +13,9 @@ */ public abstract class TextSplitter { - public int chunkSize; + public final int chunkSize; - public int chunkOverlap; + public final int chunkOverlap; public TextSplitter(int chunkSize, int chunkOverlap) { this.chunkSize = chunkSize; @@ -99,9 +99,9 @@ public List mergeSplits(List splits, String separator) { int total = 0; for (String d : splits) { - int _len = d.length(); + int length = d.length(); - if (total + _len + (currentDoc.size() > 0 ? separator.length() : 0) > this.chunkSize) { + if (total + length + (currentDoc.size() > 0 ? separator.length() : 0) > this.chunkSize) { if (total > this.chunkSize) { System.out.println("Created a chunk of size " + total + ", which is longer than the specified " + this.chunkSize); } @@ -112,7 +112,7 @@ public List mergeSplits(List splits, String separator) { docs.add(doc); } - while (total > this.chunkOverlap || (total + _len > this.chunkSize && total > 0)) { + while (total > this.chunkOverlap || (total + length > this.chunkSize && total > 0)) { total -= currentDoc.get(0).length(); currentDoc.remove(0); } @@ -120,7 +120,7 @@ public List mergeSplits(List splits, String separator) { } currentDoc.add(d); - total += _len; + total += length; } String doc = joinDocs(currentDoc, separator); diff --git a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java index e8fc866e..52f66e9a 100644 --- a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java +++ b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java @@ -93,10 +93,10 @@ public void testCharacterTextSplitter_invalidArguments() { int chunkSize = 2; int chunkOverlap = 4; - // Act + // Act. new CharacterTextSplitter(null, chunkSize, chunkOverlap); - // Expect IllegalArgumentException to be thrown + // Expect IllegalArgumentException to be thrown. } //TODO, this unit test will need improving. atm it only checks that length of our list of documents, it does not check the contents From 6a0e579bf4fe38a3978c8f311601bce3c19e13f0 Mon Sep 17 00:00:00 2001 From: ayoola adedeji Date: Fri, 26 May 2023 09:47:36 +0100 Subject: [PATCH 05/14] Responded to code review comments --- build.gradle | 1 + .../splitter/text/TextSplitter.java | 42 ++++---- .../langtorch/schema/io/DomainDocument.java | 10 +- .../knowly/langtorch/schema/io/Metadata.java | 19 ++++ .../knowly/langtorch/schema/io/Metadatas.java | 17 ---- .../PromptTemplateToSingleTextParserTest.java | 1 - .../splitter/text/TextSplitterTest.java | 95 ++++++++----------- 7 files changed, 85 insertions(+), 100 deletions(-) create mode 100644 src/main/java/ai/knowly/langtorch/schema/io/Metadata.java delete mode 100644 src/main/java/ai/knowly/langtorch/schema/io/Metadatas.java diff --git a/build.gradle b/build.gradle index 6f3b8b51..4b2a162c 100644 --- a/build.gradle +++ b/build.gradle @@ -118,6 +118,7 @@ dependencies { implementation "com.squareup.retrofit2:converter-gson:2.9.0" // Apache commons lang implementation 'org.apache.commons:commons-lang3:3.0' + implementation 'org.apache.commons:commons-collections4:4.4' } diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java index 71d52db0..589ad0bb 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java @@ -1,7 +1,8 @@ package ai.knowly.langtorch.preprocessing.splitter.text; import ai.knowly.langtorch.schema.io.DomainDocument; -import ai.knowly.langtorch.schema.io.Metadatas; +import ai.knowly.langtorch.schema.io.Metadata; +import org.apache.commons.collections4.map.MultiKeyMap; import org.apache.commons.lang3.StringUtils; import javax.annotation.Nullable; @@ -27,10 +28,11 @@ public TextSplitter(int chunkSize, int chunkOverlap) { abstract public List splitText(String text); - public List createDocuments(List texts, Optional docMetadatas) { - Metadatas metadatas; + public List createDocuments(List texts, Optional> docMetadatas) { + List metadatas = + (docMetadatas.isPresent() && docMetadatas.get().size() > 0) ? + docMetadatas.get() : Collections.nCopies(texts.size(), Metadata.createEmpty()); - metadatas = docMetadatas.filter(value -> value.getValues().size() > 0).orElseGet(() -> new Metadatas(new ArrayList<>())); ArrayList documents = new ArrayList<>(); for (int i = 0; i < texts.size(); i += 1) { @@ -49,27 +51,23 @@ public List createDocuments(List texts, Optional loc; - if (i < metadatas.getValues().size() && metadatas.getValues().get(i) != null) { - if (!metadatas.getValues().get(i).isEmpty() && metadatas.getValues().get(i).get("loc") != null) { - loc = new HashMap<>(metadatas.getValues().get(i)); - } else { - loc = new HashMap<>(); - } + MultiKeyMap loc; + if (metadatas.get(i).getValue().containsKey("loc")) { + loc = metadatas.get(i).getValue(); } else { - loc = new HashMap<>(); + loc = new MultiKeyMap<>(); } - loc.put("from", String.valueOf(lineCounterIndex)); - loc.put("to", String.valueOf(lineCounterIndex + newLinesCount)); + loc.put("loc", "from", String.valueOf(lineCounterIndex)); + loc.put("loc", "to", String.valueOf(lineCounterIndex + newLinesCount)); - Map metadataWithLinesNumber = new HashMap<>(); - if (i < metadatas.getValues().size() && metadatas.getValues().get(i) != null) { - metadataWithLinesNumber.putAll(metadatas.getValues().get(i)); + Metadata metadataWithLinesNumber = Metadata.createEmpty(); + if (metadatas.get(i) != null && !metadatas.get(i).getValue().isEmpty()) { + metadataWithLinesNumber.getValue().putAll(metadatas.get(i).getValue()); } - metadataWithLinesNumber.putAll(loc); + metadataWithLinesNumber.getValue().putAll(loc); - documents.add(new DomainDocument(chunk, metadataWithLinesNumber)); + documents.add(new DomainDocument(chunk, Optional.of(metadataWithLinesNumber))); lineCounterIndex += newLinesCount; prevChunk = chunk; } @@ -82,9 +80,11 @@ public List splitDocuments(List documents) { List selectedDocs = documents.stream().filter(doc -> doc.getPageContent() != null).collect(Collectors.toList()); List texts = selectedDocs.stream().map(DomainDocument::getPageContent).collect(Collectors.toList()); - List> metaDatas = selectedDocs.stream().map(DomainDocument::getMetadata).collect(Collectors.toList()); + List metadatas = + selectedDocs.stream().map(doc -> doc.getMetadata().isPresent() ? + doc.getMetadata().get() : Metadata.createEmpty()).collect(Collectors.toList()); - return this.createDocuments(texts, Optional.of(new Metadatas(metaDatas))); + return this.createDocuments(texts, Optional.of(metadatas)); } @Nullable diff --git a/src/main/java/ai/knowly/langtorch/schema/io/DomainDocument.java b/src/main/java/ai/knowly/langtorch/schema/io/DomainDocument.java index b670e181..aa2878f6 100644 --- a/src/main/java/ai/knowly/langtorch/schema/io/DomainDocument.java +++ b/src/main/java/ai/knowly/langtorch/schema/io/DomainDocument.java @@ -1,16 +1,14 @@ package ai.knowly.langtorch.schema.io; -import javax.annotation.Nullable; -import java.util.Map; +import java.util.Optional; public class DomainDocument implements Input, Output { private final String pageContent; - @Nullable - private final Map metadata; + private final Optional metadata; - public DomainDocument(String pageContent, @Nullable Map metadata) { + public DomainDocument(String pageContent, Optional metadata) { this.pageContent = pageContent; this.metadata = metadata; } @@ -19,7 +17,7 @@ public String getPageContent() { return pageContent; } - public Map getMetadata() { + public Optional getMetadata() { return metadata; } } diff --git a/src/main/java/ai/knowly/langtorch/schema/io/Metadata.java b/src/main/java/ai/knowly/langtorch/schema/io/Metadata.java new file mode 100644 index 00000000..6d17612a --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/schema/io/Metadata.java @@ -0,0 +1,19 @@ +package ai.knowly.langtorch.schema.io; + +import org.apache.commons.collections4.map.MultiKeyMap; + +public class Metadata { + private final MultiKeyMap value; + + public Metadata(MultiKeyMap values) { + this.value = values; + } + + public MultiKeyMap getValue() { + return value; + } + + public static Metadata createEmpty(){ + return new Metadata(new MultiKeyMap<>()); + } +} diff --git a/src/main/java/ai/knowly/langtorch/schema/io/Metadatas.java b/src/main/java/ai/knowly/langtorch/schema/io/Metadatas.java deleted file mode 100644 index 41e8002f..00000000 --- a/src/main/java/ai/knowly/langtorch/schema/io/Metadatas.java +++ /dev/null @@ -1,17 +0,0 @@ -package ai.knowly.langtorch.schema.io; - -import java.util.List; -import java.util.Map; - -public class Metadatas { - - private final List> values; - - public Metadatas(List> values) { - this.values = values; - } - - public List> getValues() { - return values; - } -} diff --git a/src/test/java/ai/knowly/langtorch/preprocessing/parser/PromptTemplateToSingleTextParserTest.java b/src/test/java/ai/knowly/langtorch/preprocessing/parser/PromptTemplateToSingleTextParserTest.java index e6d67875..c6568248 100644 --- a/src/test/java/ai/knowly/langtorch/preprocessing/parser/PromptTemplateToSingleTextParserTest.java +++ b/src/test/java/ai/knowly/langtorch/preprocessing/parser/PromptTemplateToSingleTextParserTest.java @@ -1,6 +1,5 @@ package ai.knowly.langtorch.preprocessing.parser; -import ai.knowly.langtorch.preprocessing.parser.PromptTemplateToSingleTextParser; import ai.knowly.langtorch.prompt.template.PromptTemplate; import ai.knowly.langtorch.schema.text.SingleText; import com.google.common.truth.Truth; diff --git a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java index 52f66e9a..522eab6e 100644 --- a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java +++ b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java @@ -1,19 +1,17 @@ package ai.knowly.langtorch.preprocessing.splitter.text; -import ai.knowly.langtorch.preprocessing.splitter.text.CharacterTextSplitter; -import ai.knowly.langtorch.preprocessing.splitter.text.RecursiveCharacterTextSplitter; import ai.knowly.langtorch.schema.io.DomainDocument; -import ai.knowly.langtorch.schema.io.Metadatas; +import ai.knowly.langtorch.schema.io.Metadata; +import com.google.common.truth.Truth; import org.junit.Test; import java.util.*; -import static org.junit.jupiter.api.Assertions.assertEquals; public class TextSplitterTest { @Test - public void testCharacterTextSplitter_splitByCharacterCount(){ + public void testCharacterTextSplitter_splitByCharacterCount() { // Arrange. String text = "foo bar baz 123"; CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 7, 3); @@ -22,9 +20,9 @@ public void testCharacterTextSplitter_splitByCharacterCount(){ List result = splitter.splitText(text); // Assert. - List expected = new ArrayList<>(Arrays.asList("foo bar", "bar baz", "baz 123")); + List expected = new ArrayList<>(Arrays.asList("foo bar", "bar baz", "baz 123")); - assertEquals(expected, result); + Truth.assertThat(Objects.equals(expected, result)); } @Test @@ -39,7 +37,7 @@ public void testCharacterTextSplitter_splitByCharacterCountWithNoEmptyDocuments( // Assert. List expected = new ArrayList<>(Arrays.asList("foo", "bar")); - assertEquals(expected, result); + Truth.assertThat(Objects.equals(expected, result)); } @Test @@ -54,7 +52,7 @@ public void testCharacterTextSplitter_splitByCharacterCountLongWords() { // Assert. List expected = new ArrayList<>(Arrays.asList("foo", "bar", "baz", "a a")); - assertEquals(expected, result); + Truth.assertThat(Objects.equals(expected, result)); } @Test @@ -69,7 +67,7 @@ public void testCharacterTextSplitter_splitByCharacterCountShorterWordsFirst() { // Assert. List expected = new ArrayList<>(Arrays.asList("a a", "foo", "bar", "baz")); - assertEquals(expected, result); + Truth.assertThat(Objects.equals(expected, result)); } @Test @@ -84,7 +82,7 @@ public void testCharacterTextSplitter_splitByCharactersSplitsNotFoundEasily() { // Assert. List expected = new ArrayList<>(Arrays.asList("foo", "bar", "baz", "123")); - assertEquals(expected, result); + Truth.assertThat(Objects.equals(expected, result)); } @Test(expected = IllegalArgumentException.class) @@ -99,73 +97,62 @@ public void testCharacterTextSplitter_invalidArguments() { // Expect IllegalArgumentException to be thrown. } - //TODO, this unit test will need improving. atm it only checks that length of our list of documents, it does not check the contents @Test public void testCharacterTextSplitter_createDocuments() { // Arrange. List texts = Arrays.asList("foo bar", "baz"); CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 3, 0); - Map metadata = new HashMap<>(); + Metadata metadata = Metadata.createEmpty(); - Map loc = new HashMap<>(); - loc.put("from", String.valueOf(1)); - loc.put("to", String.valueOf(1)); + List metadatas = Arrays.asList(metadata, metadata); - Optional metadatas = Optional.of(new Metadatas(Arrays.asList(metadata, metadata))); // Act. - List docs = splitter.createDocuments(texts, metadatas); + List docs = splitter.createDocuments(texts, Optional.of(metadatas)); // Assert. List expectedDocs = Arrays.asList( - new DomainDocument("foo", metadata), - new DomainDocument("bar", metadata), - new DomainDocument("baz", metadata) + new DomainDocument("foo", Optional.of(metadata)), + new DomainDocument("bar", Optional.of(metadata)), + new DomainDocument("baz", Optional.of(metadata)) ); - assertEquals(expectedDocs.size(), docs.size()); + Truth.assertThat(expectedDocs.size() == docs.size()); + for (int i = 0; i < docs.size(); i++) { + Truth.assertThat(Objects.equals(docs.get(i).getPageContent(), expectedDocs.get(i).getPageContent())); + } } - //TODO, this unit test will need improving. atm it only checks that length of our list of documents, it does not check the contents @Test public void testCharacterTextSplitter_createDocumentsWithMetadata() { // Arrange. List texts = Arrays.asList("foo bar", "baz"); CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 3, 0); - List> metadataList = Arrays.asList( - new HashMap() {{ - put("source", "1"); - }}, - new HashMap() {{ - put("source", "2"); - }} - ); - Optional metadatas = Optional.of(new Metadatas(metadataList)); + Metadata metadata = Metadata.createEmpty(); + + metadata.getValue().put("source", "doc", "1"); + metadata.getValue().put("loc", "from", "1"); + metadata.getValue().put("loc", "to", "1"); + + List metadataList = Arrays.asList(metadata, metadata); + + Optional> metadatas = Optional.of(metadataList); // Act. List docs = splitter.createDocuments(texts, metadatas); // Assert. List expectedDocs = Arrays.asList( - new DomainDocument("foo", new HashMap() {{ - put("source", "1"); - put("from", String.valueOf(1)); - put("to", String.valueOf(1)); - }}), - new DomainDocument("bar", new HashMap() {{ - put("source", "1"); - put("from", String.valueOf(1)); - put("to", String.valueOf(1)); - }}), - new DomainDocument("baz", new HashMap() {{ - put("source", "2"); - put("from", String.valueOf(1)); - put("to", String.valueOf(1)); - }}) + new DomainDocument("foo", Optional.of(metadata)), + new DomainDocument("bar", Optional.of(metadata)), + new DomainDocument("baz", Optional.of(metadata)) ); - assertEquals(expectedDocs.size(), docs.size()); + Truth.assertThat(expectedDocs.size() == docs.size()); + for (int i = 0; i < docs.size(); i++) { + Truth.assertThat(Objects.equals(docs.get(i).getPageContent(), expectedDocs.get(i).getPageContent())); + } } @@ -199,7 +186,7 @@ public void testRecursiveCharacterTextSplitter_iterativeTextSplitter() { "Bye!\n\n-H." ); - assertEquals(expectedOutput, output); + Truth.assertThat(Objects.equals(expectedOutput, output)); } @Test @@ -208,7 +195,7 @@ public void testTextSplitter_iterativeTextSplitter_linesLoc() { String text = "Hi.\nI'm Harrison.\n\nHow?\na\nb"; RecursiveCharacterTextSplitter splitter = new RecursiveCharacterTextSplitter(null, 20, 1); - Optional metadatas = Optional.ofNullable(null); + Optional> metadatas = Optional.ofNullable(null); // Act. List docs = splitter.createDocuments(Collections.singletonList(text), metadatas); @@ -217,12 +204,10 @@ public void testTextSplitter_iterativeTextSplitter_linesLoc() { DomainDocument doc2 = new DomainDocument("How?\na\nb", null); List expectedDocs = Arrays.asList(doc1, doc2); - assertEquals(expectedDocs.size(), docs.size()); + Truth.assertThat(expectedDocs.size() == docs.size()); + Truth.assertThat(Objects.equals(expectedDocs.get(0).getPageContent(), docs.get(0).getPageContent())); + Truth.assertThat(Objects.equals(expectedDocs.get(1).getPageContent(), docs.get(1).getPageContent())); } - - - - } From a61c4744f838438fa938e7903d41e22bcf8ebc53 Mon Sep 17 00:00:00 2001 From: ayoola adedeji Date: Fri, 26 May 2023 20:36:41 +0100 Subject: [PATCH 06/14] Fixed sonar --- .../langtorch/preprocessing/splitter/text/TextSplitter.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java index 589ad0bb..c9946051 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java @@ -2,6 +2,7 @@ import ai.knowly.langtorch.schema.io.DomainDocument; import ai.knowly.langtorch.schema.io.Metadata; +import org.apache.commons.collections4.keyvalue.MultiKey; import org.apache.commons.collections4.map.MultiKeyMap; import org.apache.commons.lang3.StringUtils; @@ -52,7 +53,8 @@ public List createDocuments(List texts, Optional loc; - if (metadatas.get(i).getValue().containsKey("loc")) { + //TODO: need to end to end test how metadata is being passed back and forth + if (metadatas.get(i).getValue().containsKey("loc", "")) { loc = metadatas.get(i).getValue(); } else { loc = new MultiKeyMap<>(); From bfcdd04b9f55658d786ebb188d334d3d0d035e83 Mon Sep 17 00:00:00 2001 From: ayoola adedeji Date: Sat, 27 May 2023 07:31:17 +0100 Subject: [PATCH 07/14] WIP --- ...er.java => RecursiveWordTextSplitter.java} | 10 +++---- .../splitter/text/TextSplitter.java | 27 +++++++++++-------- ...terTextSplitter.java => WordSplitter.java} | 4 +-- .../splitter/text/TextSplitterTest.java | 22 +++++++-------- 4 files changed, 34 insertions(+), 29 deletions(-) rename src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/{RecursiveCharacterTextSplitter.java => RecursiveWordTextSplitter.java} (86%) rename src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/{CharacterTextSplitter.java => WordSplitter.java} (90%) diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveCharacterTextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveWordTextSplitter.java similarity index 86% rename from src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveCharacterTextSplitter.java rename to src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveWordTextSplitter.java index 694ffc77..115bfaf4 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveCharacterTextSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveWordTextSplitter.java @@ -5,24 +5,24 @@ import java.util.Arrays; import java.util.List; /** - The RecursiveCharacterTextSplitter class is a concrete implementation of the TextSplitter abstract class + The RecursiveWordTextSplitter class is a concrete implementation of the TextSplitter abstract class that recursively splits text into chunks using a set of separators. It applies a recursive splitting approach to handle longer texts by examining the text and selecting the appropriate separator from the list of separators based on their presence in the text. If the text is longer than the chunk size, it recursively splits the longer portions into smaller chunks. This recursive process continues until the chunks reach a size smaller than the specified chunk size. */ -public class RecursiveCharacterTextSplitter extends TextSplitter { +public class RecursiveWordTextSplitter extends TextSplitter { private static List separators = Arrays.asList("\n\n", "\n", " ", ""); /** - Constructs a RecursiveCharacterTextSplitter object with the given list of separators, chunk size, and chunk overlap. + Constructs a RecursiveWordTextSplitter object with the given list of separators, chunk size, and chunk overlap. If the separators list is null, the default list containing separators "\n\n", "\n", " ", and "" is used. @param separators The list of separators used for splitting the text into chunks. @param chunkSize The size of each chunk. @param chunkOverlap The amount of overlap between adjacent chunks. */ - public RecursiveCharacterTextSplitter(@Nullable List separators, int chunkSize, int chunkOverlap) { + public RecursiveWordTextSplitter(@Nullable List separators, int chunkSize, int chunkOverlap) { super(chunkSize, chunkOverlap); if (separators != null) { this.separators = separators; @@ -60,7 +60,7 @@ public List splitText(String text) { // Now go merging things, recursively splitting longer texts List goodSplits = new ArrayList<>(); for (String s : splits) { - if (s.length() < chunkSize) { + if (s.length() < wordCount) { goodSplits.add(s); } else { if (!goodSplits.isEmpty()) { diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java index c9946051..b7600fbc 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java @@ -2,7 +2,6 @@ import ai.knowly.langtorch.schema.io.DomainDocument; import ai.knowly.langtorch.schema.io.Metadata; -import org.apache.commons.collections4.keyvalue.MultiKey; import org.apache.commons.collections4.map.MultiKeyMap; import org.apache.commons.lang3.StringUtils; @@ -15,14 +14,20 @@ */ public abstract class TextSplitter { - public final int chunkSize; + /** + * The amount of words inside one chunk + */ + public final int wordCount; - public final int chunkOverlap; + /** + * amount of words from previous chunk, it will be empty for the first chunk + */ + public final int wordOverlap; - public TextSplitter(int chunkSize, int chunkOverlap) { - this.chunkSize = chunkSize; - this.chunkOverlap = chunkOverlap; - if (this.chunkOverlap >= this.chunkSize) { + public TextSplitter(int wordCount, int wordOverlap) { + this.wordCount = wordCount; + this.wordOverlap = wordOverlap; + if (this.wordOverlap >= this.wordCount) { throw new IllegalArgumentException("chunkOverlap cannot be equal to or greater than chunkSize"); } } @@ -103,9 +108,9 @@ public List mergeSplits(List splits, String separator) { for (String d : splits) { int length = d.length(); - if (total + length + (currentDoc.size() > 0 ? separator.length() : 0) > this.chunkSize) { - if (total > this.chunkSize) { - System.out.println("Created a chunk of size " + total + ", which is longer than the specified " + this.chunkSize); + if (total + length + (currentDoc.size() > 0 ? separator.length() : 0) > this.wordCount) { + if (total > this.wordCount) { + System.out.println("Created a chunk of size " + total + ", which is longer than the specified " + this.wordCount); } if (currentDoc.size() > 0) { @@ -114,7 +119,7 @@ public List mergeSplits(List splits, String separator) { docs.add(doc); } - while (total > this.chunkOverlap || (total + length > this.chunkSize && total > 0)) { + while (total > this.wordOverlap || (total + length > this.wordCount && total > 0)) { total -= currentDoc.get(0).length(); currentDoc.remove(0); } diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/CharacterTextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitter.java similarity index 90% rename from src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/CharacterTextSplitter.java rename to src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitter.java index 21e4b593..13c0e6fc 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/CharacterTextSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitter.java @@ -10,7 +10,7 @@ that splits text into chunks based on a specified separator. */ -public class CharacterTextSplitter extends TextSplitter { +public class WordSplitter extends TextSplitter { private static String separator = "\n\n"; @@ -22,7 +22,7 @@ public class CharacterTextSplitter extends TextSplitter { @param chunkSize The size of each chunk. @param chunkOverlap The amount of overlap between adjacent chunks. */ - public CharacterTextSplitter(@Nullable String separator, int chunkSize, int chunkOverlap) { + public WordSplitter(@Nullable String separator, int chunkSize, int chunkOverlap) { super(chunkSize, chunkOverlap); if (separator != null) { this.separator = separator; diff --git a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java index 522eab6e..5c186178 100644 --- a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java +++ b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java @@ -14,7 +14,7 @@ public class TextSplitterTest { public void testCharacterTextSplitter_splitByCharacterCount() { // Arrange. String text = "foo bar baz 123"; - CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 7, 3); + WordSplitter splitter = new WordSplitter(" ", 7, 3); // Act. List result = splitter.splitText(text); @@ -28,8 +28,8 @@ public void testCharacterTextSplitter_splitByCharacterCount() { @Test public void testCharacterTextSplitter_splitByCharacterCountWithNoEmptyDocuments() { // Arrange. - String text = "foo bar"; - CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 2, 0); + String text = "i,love,langtorchisveryxxxxxxxxx"; + WordSplitter splitter = new WordSplitter(" ", 2, 0); // Act. List result = splitter.splitText(text); @@ -44,7 +44,7 @@ public void testCharacterTextSplitter_splitByCharacterCountWithNoEmptyDocuments( public void testCharacterTextSplitter_splitByCharacterCountLongWords() { // Arrange. String text = "foo bar baz a a"; - CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 3, 1); + WordSplitter splitter = new WordSplitter(" ", 3, 1); // Act. List result = splitter.splitText(text); @@ -59,7 +59,7 @@ public void testCharacterTextSplitter_splitByCharacterCountLongWords() { public void testCharacterTextSplitter_splitByCharacterCountShorterWordsFirst() { // Arrange. String text = "a a foo bar baz"; - CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 3, 1); + WordSplitter splitter = new WordSplitter(" ", 3, 1); // Act. List result = splitter.splitText(text); @@ -74,7 +74,7 @@ public void testCharacterTextSplitter_splitByCharacterCountShorterWordsFirst() { public void testCharacterTextSplitter_splitByCharactersSplitsNotFoundEasily() { // Arrange. String text = "foo bar baz 123"; - CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 1, 0); + WordSplitter splitter = new WordSplitter(" ", 1, 0); // Act. List result = splitter.splitText(text); @@ -92,7 +92,7 @@ public void testCharacterTextSplitter_invalidArguments() { int chunkOverlap = 4; // Act. - new CharacterTextSplitter(null, chunkSize, chunkOverlap); + new WordSplitter(null, chunkSize, chunkOverlap); // Expect IllegalArgumentException to be thrown. } @@ -101,7 +101,7 @@ public void testCharacterTextSplitter_invalidArguments() { public void testCharacterTextSplitter_createDocuments() { // Arrange. List texts = Arrays.asList("foo bar", "baz"); - CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 3, 0); + WordSplitter splitter = new WordSplitter(" ", 3, 0); Metadata metadata = Metadata.createEmpty(); List metadatas = Arrays.asList(metadata, metadata); @@ -126,7 +126,7 @@ public void testCharacterTextSplitter_createDocuments() { public void testCharacterTextSplitter_createDocumentsWithMetadata() { // Arrange. List texts = Arrays.asList("foo bar", "baz"); - CharacterTextSplitter splitter = new CharacterTextSplitter(" ", 3, 0); + WordSplitter splitter = new WordSplitter(" ", 3, 0); Metadata metadata = Metadata.createEmpty(); @@ -160,7 +160,7 @@ public void testCharacterTextSplitter_createDocumentsWithMetadata() { public void testRecursiveCharacterTextSplitter_iterativeTextSplitter() { // Arrange. String text = "Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.\nThis is a weird text to write, but gotta test the splittingggg some how.\n\nBye!\n\n-H."; - RecursiveCharacterTextSplitter splitter = new RecursiveCharacterTextSplitter(null, 10, 1); + RecursiveWordTextSplitter splitter = new RecursiveWordTextSplitter(null, 10, 1); // Act. List output = splitter.splitText(text); @@ -193,7 +193,7 @@ public void testRecursiveCharacterTextSplitter_iterativeTextSplitter() { public void testTextSplitter_iterativeTextSplitter_linesLoc() { // Arrange. String text = "Hi.\nI'm Harrison.\n\nHow?\na\nb"; - RecursiveCharacterTextSplitter splitter = new RecursiveCharacterTextSplitter(null, 20, 1); + RecursiveWordTextSplitter splitter = new RecursiveWordTextSplitter(null, 20, 1); Optional> metadatas = Optional.ofNullable(null); // Act. From 3e6632e43c1b528824cd5b5944341545e41c23fc Mon Sep 17 00:00:00 2001 From: ayoola adedeji Date: Sat, 27 May 2023 10:21:32 +0100 Subject: [PATCH 08/14] Updated Unit tests --- .../text/RecursiveWordTextSplitter.java | 83 ------ .../splitter/text/WordSplitter.java | 11 +- .../knowly/langtorch/schema/io/Metadata.java | 20 ++ .../splitter/text/TextSplitterTest.java | 213 ---------------- .../splitter/text/WordSplitterTest.java | 238 ++++++++++++++++++ 5 files changed, 264 insertions(+), 301 deletions(-) delete mode 100644 src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveWordTextSplitter.java delete mode 100644 src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java create mode 100644 src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitterTest.java diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveWordTextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveWordTextSplitter.java deleted file mode 100644 index 115bfaf4..00000000 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/RecursiveWordTextSplitter.java +++ /dev/null @@ -1,83 +0,0 @@ -package ai.knowly.langtorch.preprocessing.splitter.text; - -import javax.annotation.Nullable; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -/** - The RecursiveWordTextSplitter class is a concrete implementation of the TextSplitter abstract class - that recursively splits text into chunks using a set of separators. - It applies a recursive splitting approach to handle longer texts by examining the text and selecting the appropriate - separator from the list of separators based on their presence in the text. - If the text is longer than the chunk size, it recursively splits the longer portions into smaller chunks. This recursive process continues until the chunks reach a size smaller than the specified chunk size. - */ -public class RecursiveWordTextSplitter extends TextSplitter { - - private static List separators = Arrays.asList("\n\n", "\n", " ", ""); - - /** - Constructs a RecursiveWordTextSplitter object with the given list of separators, chunk size, and chunk overlap. - If the separators list is null, the default list containing separators "\n\n", "\n", " ", and "" is used. - @param separators The list of separators used for splitting the text into chunks. - @param chunkSize The size of each chunk. - @param chunkOverlap The amount of overlap between adjacent chunks. - */ - public RecursiveWordTextSplitter(@Nullable List separators, int chunkSize, int chunkOverlap) { - super(chunkSize, chunkOverlap); - if (separators != null) { - this.separators = separators; - } - } - - /** - Splits the given text into chunks using a recursive splitting approach. - It selects an appropriate separator from the list of separators based on the presence of each separator in the text. - It recursively splits longer pieces of text into smaller chunks. - @param text The text to be split into chunks. - @return A list of strings representing the chunks of the text. - */ - @Override - public List splitText(String text) { - List finalChunks = new ArrayList<>(); - - // Get appropriate separator to use - String separator = separators.get(separators.size() - 1); - for (String s : separators) { - if (s.isEmpty() || text.contains(s)) { - separator = s; - break; - } - } - - // Now that we have the separator, split the text - String[] splits; - if (!separator.isEmpty()) { - splits = text.split(separator); - } else { - splits = text.split(""); - } - - // Now go merging things, recursively splitting longer texts - List goodSplits = new ArrayList<>(); - for (String s : splits) { - if (s.length() < wordCount) { - goodSplits.add(s); - } else { - if (!goodSplits.isEmpty()) { - List mergedText = mergeSplits(goodSplits, separator); - finalChunks.addAll(mergedText); - goodSplits.clear(); - } - List otherInfo = splitText(s); - finalChunks.addAll(otherInfo); - - } - } - if (!goodSplits.isEmpty()) { - List mergedText = mergeSplits(goodSplits, separator); - finalChunks.addAll(mergedText); - } - return finalChunks; - - } -} diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitter.java index 13c0e6fc..ca813987 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitter.java @@ -12,18 +12,18 @@ public class WordSplitter extends TextSplitter { - private static String separator = "\n\n"; + private String separator = "\n\n"; /** Constructs a CharacterTextSplitter object with the given separator, chunk size, and chunk overlap. If the separator is null, the default separator "\n\n" is used. @param separator The separator used for splitting the text into chunks. - @param chunkSize The size of each chunk. - @param chunkOverlap The amount of overlap between adjacent chunks. + @param wordCount The size of each chunk. + @param wordOverlap The amount of overlap between adjacent chunks. */ - public WordSplitter(@Nullable String separator, int chunkSize, int chunkOverlap) { - super(chunkSize, chunkOverlap); + public WordSplitter(@Nullable String separator, int wordCount, int wordOverlap) { + super(wordCount, wordOverlap); if (separator != null) { this.separator = separator; } @@ -36,6 +36,7 @@ public WordSplitter(@Nullable String separator, int chunkSize, int chunkOverlap) */ @Override public List splitText(String text) { + List splits = Arrays.asList(StringUtils.splitByWholeSeparatorPreserveAllTokens(text, this.separator.isEmpty() ? "" : this.separator)); return mergeSplits(splits, this.separator); diff --git a/src/main/java/ai/knowly/langtorch/schema/io/Metadata.java b/src/main/java/ai/knowly/langtorch/schema/io/Metadata.java index 6d17612a..0af0a37e 100644 --- a/src/main/java/ai/knowly/langtorch/schema/io/Metadata.java +++ b/src/main/java/ai/knowly/langtorch/schema/io/Metadata.java @@ -2,6 +2,8 @@ import org.apache.commons.collections4.map.MultiKeyMap; +import java.util.Objects; + public class Metadata { private final MultiKeyMap value; @@ -16,4 +18,22 @@ public MultiKeyMap getValue() { public static Metadata createEmpty(){ return new Metadata(new MultiKeyMap<>()); } + + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + Metadata other = (Metadata) obj; + return Objects.equals(value, other.value); + } + + @Override + public int hashCode() { + return Objects.hash(value); + } } diff --git a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java deleted file mode 100644 index 5c186178..00000000 --- a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitterTest.java +++ /dev/null @@ -1,213 +0,0 @@ -package ai.knowly.langtorch.preprocessing.splitter.text; - -import ai.knowly.langtorch.schema.io.DomainDocument; -import ai.knowly.langtorch.schema.io.Metadata; -import com.google.common.truth.Truth; -import org.junit.Test; - -import java.util.*; - - -public class TextSplitterTest { - - @Test - public void testCharacterTextSplitter_splitByCharacterCount() { - // Arrange. - String text = "foo bar baz 123"; - WordSplitter splitter = new WordSplitter(" ", 7, 3); - - // Act. - List result = splitter.splitText(text); - - // Assert. - List expected = new ArrayList<>(Arrays.asList("foo bar", "bar baz", "baz 123")); - - Truth.assertThat(Objects.equals(expected, result)); - } - - @Test - public void testCharacterTextSplitter_splitByCharacterCountWithNoEmptyDocuments() { - // Arrange. - String text = "i,love,langtorchisveryxxxxxxxxx"; - WordSplitter splitter = new WordSplitter(" ", 2, 0); - - // Act. - List result = splitter.splitText(text); - - // Assert. - List expected = new ArrayList<>(Arrays.asList("foo", "bar")); - - Truth.assertThat(Objects.equals(expected, result)); - } - - @Test - public void testCharacterTextSplitter_splitByCharacterCountLongWords() { - // Arrange. - String text = "foo bar baz a a"; - WordSplitter splitter = new WordSplitter(" ", 3, 1); - - // Act. - List result = splitter.splitText(text); - - // Assert. - List expected = new ArrayList<>(Arrays.asList("foo", "bar", "baz", "a a")); - - Truth.assertThat(Objects.equals(expected, result)); - } - - @Test - public void testCharacterTextSplitter_splitByCharacterCountShorterWordsFirst() { - // Arrange. - String text = "a a foo bar baz"; - WordSplitter splitter = new WordSplitter(" ", 3, 1); - - // Act. - List result = splitter.splitText(text); - - // Assert. - List expected = new ArrayList<>(Arrays.asList("a a", "foo", "bar", "baz")); - - Truth.assertThat(Objects.equals(expected, result)); - } - - @Test - public void testCharacterTextSplitter_splitByCharactersSplitsNotFoundEasily() { - // Arrange. - String text = "foo bar baz 123"; - WordSplitter splitter = new WordSplitter(" ", 1, 0); - - // Act. - List result = splitter.splitText(text); - - // Assert. - List expected = new ArrayList<>(Arrays.asList("foo", "bar", "baz", "123")); - - Truth.assertThat(Objects.equals(expected, result)); - } - - @Test(expected = IllegalArgumentException.class) - public void testCharacterTextSplitter_invalidArguments() { - // Arrange. - int chunkSize = 2; - int chunkOverlap = 4; - - // Act. - new WordSplitter(null, chunkSize, chunkOverlap); - - // Expect IllegalArgumentException to be thrown. - } - - @Test - public void testCharacterTextSplitter_createDocuments() { - // Arrange. - List texts = Arrays.asList("foo bar", "baz"); - WordSplitter splitter = new WordSplitter(" ", 3, 0); - Metadata metadata = Metadata.createEmpty(); - - List metadatas = Arrays.asList(metadata, metadata); - - // Act. - List docs = splitter.createDocuments(texts, Optional.of(metadatas)); - - // Assert. - List expectedDocs = Arrays.asList( - new DomainDocument("foo", Optional.of(metadata)), - new DomainDocument("bar", Optional.of(metadata)), - new DomainDocument("baz", Optional.of(metadata)) - ); - - Truth.assertThat(expectedDocs.size() == docs.size()); - for (int i = 0; i < docs.size(); i++) { - Truth.assertThat(Objects.equals(docs.get(i).getPageContent(), expectedDocs.get(i).getPageContent())); - } - } - - @Test - public void testCharacterTextSplitter_createDocumentsWithMetadata() { - // Arrange. - List texts = Arrays.asList("foo bar", "baz"); - WordSplitter splitter = new WordSplitter(" ", 3, 0); - - - Metadata metadata = Metadata.createEmpty(); - - metadata.getValue().put("source", "doc", "1"); - metadata.getValue().put("loc", "from", "1"); - metadata.getValue().put("loc", "to", "1"); - - List metadataList = Arrays.asList(metadata, metadata); - - Optional> metadatas = Optional.of(metadataList); - - // Act. - List docs = splitter.createDocuments(texts, metadatas); - - // Assert. - List expectedDocs = Arrays.asList( - new DomainDocument("foo", Optional.of(metadata)), - new DomainDocument("bar", Optional.of(metadata)), - new DomainDocument("baz", Optional.of(metadata)) - ); - - Truth.assertThat(expectedDocs.size() == docs.size()); - for (int i = 0; i < docs.size(); i++) { - Truth.assertThat(Objects.equals(docs.get(i).getPageContent(), expectedDocs.get(i).getPageContent())); - } - } - - - @Test - public void testRecursiveCharacterTextSplitter_iterativeTextSplitter() { - // Arrange. - String text = "Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.\nThis is a weird text to write, but gotta test the splittingggg some how.\n\nBye!\n\n-H."; - RecursiveWordTextSplitter splitter = new RecursiveWordTextSplitter(null, 10, 1); - - // Act. - List output = splitter.splitText(text); - - // Assert. - List expectedOutput = Arrays.asList( - "Hi.", - "I'm", - "Harrison.", - "How? Are?", - "You?", - "Okay then f", - "f f f f.", - "This is a", - "a weird", - "text to", - "write, but", - "gotta test", - "the", - "splittingg", - "ggg", - "some how.", - "Bye!\n\n-H." - ); - - Truth.assertThat(Objects.equals(expectedOutput, output)); - } - - @Test - public void testTextSplitter_iterativeTextSplitter_linesLoc() { - // Arrange. - String text = "Hi.\nI'm Harrison.\n\nHow?\na\nb"; - RecursiveWordTextSplitter splitter = new RecursiveWordTextSplitter(null, 20, 1); - - Optional> metadatas = Optional.ofNullable(null); - // Act. - List docs = splitter.createDocuments(Collections.singletonList(text), metadatas); - - // Assert. - DomainDocument doc1 = new DomainDocument("Hi.\nI'm Harrison.", null); - DomainDocument doc2 = new DomainDocument("How?\na\nb", null); - List expectedDocs = Arrays.asList(doc1, doc2); - - Truth.assertThat(expectedDocs.size() == docs.size()); - Truth.assertThat(Objects.equals(expectedDocs.get(0).getPageContent(), docs.get(0).getPageContent())); - Truth.assertThat(Objects.equals(expectedDocs.get(1).getPageContent(), docs.get(1).getPageContent())); - } - - -} diff --git a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitterTest.java b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitterTest.java new file mode 100644 index 00000000..78b9b252 --- /dev/null +++ b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitterTest.java @@ -0,0 +1,238 @@ +package ai.knowly.langtorch.preprocessing.splitter.text; + +import ai.knowly.langtorch.schema.io.DomainDocument; +import ai.knowly.langtorch.schema.io.Metadata; +import com.google.common.truth.Truth; +import org.junit.Test; + +import java.util.*; + + +public class WordSplitterTest { + + @Test + public void testWordSplitter_realWorldText() { + WordSplitter splitter = new WordSplitter(null, 1000, 100); + + List result = splitter.splitText(sampleText()); + List expectedResult = sampleTextExpectedSplit(); + + for (int i = 0; i < result.size(); i++) { + Truth.assertThat(result.get(i)).isEqualTo(expectedResult.get(i)); + } + } + @Test + public void testWordSplitter_splitByWordCount() { + // Arrange. + String text = "foo bar baz 123"; + WordSplitter splitter = new WordSplitter(" ", 7, 3); + + // Act. + List result = splitter.splitText(text); + + // Assert. + List expected = new ArrayList<>(Arrays.asList("foo bar", "bar baz", "baz 123")); + + Truth.assertThat(result.size()).isEqualTo(expected.size()); + for (int i = 0; i < expected.size(); i++) { + Truth.assertThat(result.get(i)).isEqualTo(expected.get(i)); + } + } + + @Test + public void testCharacterTextSplitter_splitByCharacterCountWithNoEmptyDocuments() { + // Arrange. + String text = "foo bar"; + WordSplitter splitter = new WordSplitter(" ", 2, 0); + + // Act. + List result = splitter.splitText(text); + + // Assert. + List expected = new ArrayList<>(Arrays.asList("foo", "bar")); + + for (int i = 0; i < expected.size(); i++) { + Truth.assertThat(result.get(i)).isEqualTo(expected.get(i)); + } + } + + @Test + public void testCharacterTextSplitter_splitByCharacterCountLongWords() { + // Arrange. + String text = "foo bar baz a a"; + WordSplitter splitter = new WordSplitter(" ", 3, 1); + + // Act. + List result = splitter.splitText(text); + + // Assert. + List expected = new ArrayList<>(Arrays.asList("foo", "bar", "baz", "a a")); + + for (int i = 0; i < expected.size(); i++) { + Truth.assertThat(result.get(i)).isEqualTo(expected.get(i)); + } + } + + @Test + public void testCharacterTextSplitter_splitByCharacterCountShorterWordsFirst() { + // Arrange. + String text = "a a foo bar baz"; + WordSplitter splitter = new WordSplitter(" ", 3, 1); + + // Act. + List result = splitter.splitText(text); + + // Assert. + List expected = new ArrayList<>(Arrays.asList("a a", "foo", "bar", "baz")); + + for (int i = 0; i < expected.size(); i++) { + Truth.assertThat(result.get(i)).isEqualTo(expected.get(i)); + } + } + + @Test + public void testCharacterTextSplitter_splitByCharactersSplitsNotFoundEasily() { + // Arrange. + String text = "foo bar baz 123"; + WordSplitter splitter = new WordSplitter(" ", 1, 0); + + // Act. + List result = splitter.splitText(text); + + // Assert. + List expected = new ArrayList<>(Arrays.asList("foo", "bar", "baz", "123")); + + for (int i = 0; i < expected.size(); i++) { + Truth.assertThat(result.get(i)).isEqualTo(expected.get(i)); + } + } + + @Test(expected = IllegalArgumentException.class) + public void testCharacterTextSplitter_invalidArguments() { + // Arrange. + int chunkSize = 2; + int chunkOverlap = 4; + + // Act. + new WordSplitter(null, chunkSize, chunkOverlap); + + // Expect IllegalArgumentException to be thrown. + } + + @Test + public void testWordSplitter_createDocuments() { + // Arrange. + List texts = Arrays.asList("foo bar", "baz"); + WordSplitter splitter = new WordSplitter(" ", 3, 0); + Metadata metadata = Metadata.createEmpty(); + + List metadatas = Arrays.asList(metadata, metadata); + + // Act. + List docs = splitter.createDocuments(texts, Optional.of(metadatas)); + + // Assert. + List expectedDocs = Arrays.asList( + new DomainDocument("foo", Optional.of(metadata)), + new DomainDocument("bar", Optional.of(metadata)), + new DomainDocument("baz", Optional.of(metadata)) + ); + + Truth.assertThat(expectedDocs.size() == docs.size()); + for (int i = 0; i < docs.size(); i++) { + Truth.assertThat(docs.get(i).getPageContent()).isEqualTo(expectedDocs.get(i).getPageContent()); + } + } + + @Test + public void testWordSplitter_createDocumentsWithMetadata() { + // Arrange. + List texts = Arrays.asList("foo bar", "baz"); + WordSplitter splitter = new WordSplitter(" ", 3, 0); + + + Metadata metadata = Metadata.createEmpty(); + + metadata.getValue().put("source", "doc", "1"); + metadata.getValue().put("loc", "from", "1"); + metadata.getValue().put("loc", "to", "1"); + + List metadataList = Arrays.asList(metadata, metadata); + + Optional> metadatas = Optional.of(metadataList); + + // Act. + List docs = splitter.createDocuments(texts, metadatas); + + // Assert. + List expectedDocs = Arrays.asList( + new DomainDocument("foo", Optional.of(metadata)), + new DomainDocument("bar", Optional.of(metadata)), + new DomainDocument("baz", Optional.of(metadata)) + ); + + Truth.assertThat(docs.size()).isEqualTo(expectedDocs.size()); + for (int i = 0; i < docs.size(); i++) { + Truth.assertThat(docs.get(i).getPageContent()).isEqualTo(expectedDocs.get(i).getPageContent()); + Truth.assertThat(docs.get(i).getMetadata()).isEqualTo(expectedDocs.get(i).getMetadata()); + } + } + + + + private List sampleTextExpectedSplit() { + return Arrays.asList( + "Langtorch one pager\n" + + "Langtorch is a Java framework that assists you in developing large language model applications. It is designed with reusability, composability and Fluent style in mind. It can aid you in developing workflows or pipelines that include large language models.\n" + + "\n" + + "Processor\n" + + "In Langtorch, we introduce the concept of a processor. A processor is a container for the smallest computational unit in Langtorch. The response produced by the processor can either originate from a large language model, such as OpenAI's GPT model(retrieved by rest api), or it could be a deterministic Java function.", + "A processor is an interface that includes two functions: run() and runAsync(). Anything that implements these two functions can be considered a processor. For instance, a processor could be something that sends an HTTP request to OpenAI to invoke its GPT model and generate a response. It could also be a calculator function, where the input is 1+1, and the output is 2.\n" + + "Using this approach, we can conveniently add a processor, such as the soon-to-be-publicly-available Google PALM 2 API. At the same time, when we chain different processors together, we can leverage this to avoid some of the shortcomings of large language models (LLMs). For instance, when we want to implement a chatbot, if a user asks a mathematical question, we can parse this question using the LLM's capabilities into an input for our calculator to get an accurate answer, rather than letting the LLM come to a conclusion directly.", + "Note: The processor is the smallest computational unit in Langtorch, so a processor is generally only allowed to process a single task. For example, it could have the ability to handle text completion, chat completion, or generate images based on a prompt. If the requirements are complex, such as first generating a company's slogan through text completion, and then generating an image based on the slogan, this should be accomplished by chaining different processors together, rather than completing everything within a single processor.\n" + + "\n" + + "Capability\n" + + "As previously mentioned, the processor is the smallest container of a computational unit, and often it is not sufficient to handle all situations. We need to enhance the processor!\n" + + "Here we introduce the concept of Capability. If the processor is likened to an internal combustion steam engine, then a Capability could be a steam train based on the steam engine, or a electricity generator based on the steam engine.", + "Imagine that you are implementing a chatbot. If the processor is based on OpenAI's API, sending every user's input to the OpenAI GPT-4 model and returning its response, what would the user experience be like?\n" + + "\n" + + "The reason is that the chatbot does not incorporate chat history. Therefore, in capability, we can add memory (a simple implementation of memory is to put all conversation records into the request sent to OpenAI).\n" + + "\n" + + "Workflow(chaining Capabilities)\n" + + "To make the combination of capabilities easier, we introduce the concept of a Node Adapter, and we refer to capabilities nodes composition as a Capability graph.\n" + + "However, the capability graph can only be a Directed Acyclic Graph (DAG), i.e., there are no cycles allowed.\n" + + "\n" + + "The Node Adapter is primarily used for validation and optimization of the Capability graph. It wraps the capability and also includes some information about the Capability graph, such as the current node's globally unique ID, what the next nodes are, and so on." + ); + + } + + private String sampleText() { + return "Langtorch one pager\n" + + "Langtorch is a Java framework that assists you in developing large language model applications. It is designed with reusability, composability and Fluent style in mind. It can aid you in developing workflows or pipelines that include large language models.\n" + + "\n" + + "Processor\n" + + "In Langtorch, we introduce the concept of a processor. A processor is a container for the smallest computational unit in Langtorch. The response produced by the processor can either originate from a large language model, such as OpenAI's GPT model(retrieved by rest api), or it could be a deterministic Java function.\n" + + "\n" + + "A processor is an interface that includes two functions: run() and runAsync(). Anything that implements these two functions can be considered a processor. For instance, a processor could be something that sends an HTTP request to OpenAI to invoke its GPT model and generate a response. It could also be a calculator function, where the input is 1+1, and the output is 2.\n" + + "Using this approach, we can conveniently add a processor, such as the soon-to-be-publicly-available Google PALM 2 API. At the same time, when we chain different processors together, we can leverage this to avoid some of the shortcomings of large language models (LLMs). For instance, when we want to implement a chatbot, if a user asks a mathematical question, we can parse this question using the LLM's capabilities into an input for our calculator to get an accurate answer, rather than letting the LLM come to a conclusion directly.\n" + + "\n" + + "Note: The processor is the smallest computational unit in Langtorch, so a processor is generally only allowed to process a single task. For example, it could have the ability to handle text completion, chat completion, or generate images based on a prompt. If the requirements are complex, such as first generating a company's slogan through text completion, and then generating an image based on the slogan, this should be accomplished by chaining different processors together, rather than completing everything within a single processor.\n" + + "\n" + + "Capability\n" + + "As previously mentioned, the processor is the smallest container of a computational unit, and often it is not sufficient to handle all situations. We need to enhance the processor!\n" + + "Here we introduce the concept of Capability. If the processor is likened to an internal combustion steam engine, then a Capability could be a steam train based on the steam engine, or a electricity generator based on the steam engine.\n" + + "\n" + + "Imagine that you are implementing a chatbot. If the processor is based on OpenAI's API, sending every user's input to the OpenAI GPT-4 model and returning its response, what would the user experience be like?\n" + + "\n" + + "The reason is that the chatbot does not incorporate chat history. Therefore, in capability, we can add memory (a simple implementation of memory is to put all conversation records into the request sent to OpenAI).\n" + + "\n" + + "Workflow(chaining Capabilities)\n" + + "To make the combination of capabilities easier, we introduce the concept of a Node Adapter, and we refer to capabilities nodes composition as a Capability graph.\n" + + "However, the capability graph can only be a Directed Acyclic Graph (DAG), i.e., there are no cycles allowed.\n" + + "\n" + + "The Node Adapter is primarily used for validation and optimization of the Capability graph. It wraps the capability and also includes some information about the Capability graph, such as the current node's globally unique ID, what the next nodes are, and so on."; + } + + +} From db24a954cc94bd3815122e041f798c55564b4117 Mon Sep 17 00:00:00 2001 From: ayoola adedeji Date: Sun, 28 May 2023 10:22:23 +0100 Subject: [PATCH 09/14] Responded to code review comments --- .../splitter/text/TextSplitter.java | 125 +++++++++--------- .../knowly/langtorch/schema/io/Metadata.java | 11 +- .../splitter/text/WordSplitterTest.java | 30 ++--- 3 files changed, 87 insertions(+), 79 deletions(-) diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java index b7600fbc..f43e1d76 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java @@ -2,10 +2,10 @@ import ai.knowly.langtorch.schema.io.DomainDocument; import ai.knowly.langtorch.schema.io.Metadata; +import com.google.common.flogger.FluentLogger; import org.apache.commons.collections4.map.MultiKeyMap; import org.apache.commons.lang3.StringUtils; -import javax.annotation.Nullable; import java.util.*; import java.util.stream.Collectors; @@ -14,6 +14,8 @@ */ public abstract class TextSplitter { + private static final FluentLogger logger = FluentLogger.forEnclosingClass(); + /** * The amount of words inside one chunk */ @@ -24,7 +26,7 @@ public abstract class TextSplitter { */ public final int wordOverlap; - public TextSplitter(int wordCount, int wordOverlap) { + protected TextSplitter(int wordCount, int wordOverlap) { this.wordCount = wordCount; this.wordOverlap = wordOverlap; if (this.wordOverlap >= this.wordCount) { @@ -34,70 +36,76 @@ public TextSplitter(int wordCount, int wordOverlap) { abstract public List splitText(String text); - public List createDocuments(List texts, Optional> docMetadatas) { - List metadatas = - (docMetadatas.isPresent() && docMetadatas.get().size() > 0) ? - docMetadatas.get() : Collections.nCopies(texts.size(), Metadata.createEmpty()); + public List createDocumentsSplitFromSingle(DomainDocument document) { + String text = document.getPageContent(); + Metadata metadata = document.getMetadata().isPresent() ? document.getMetadata().get() : Metadata.create(); - ArrayList documents = new ArrayList<>(); + ArrayList docsToReturn = new ArrayList<>(); - for (int i = 0; i < texts.size(); i += 1) { - String text = texts.get(i); - int lineCounterIndex = 1; - String prevChunk = null; - - for (String chunk : splitText(text)) { - int numberOfIntermediateNewLines = 0; - if (prevChunk != null) { - int indexChunk = StringUtils.indexOf(text, chunk); - int indexEndPrevChunk = StringUtils.indexOf(text, prevChunk) + prevChunk.length(); - String removedNewlinesFromSplittingText = StringUtils.substring(text, indexChunk, indexEndPrevChunk); - numberOfIntermediateNewLines = StringUtils.countMatches(removedNewlinesFromSplittingText, "\n"); - } - lineCounterIndex += numberOfIntermediateNewLines; - int newLinesCount = StringUtils.countMatches(chunk, "\n"); - - MultiKeyMap loc; - //TODO: need to end to end test how metadata is being passed back and forth - if (metadatas.get(i).getValue().containsKey("loc", "")) { - loc = metadatas.get(i).getValue(); - } else { - loc = new MultiKeyMap<>(); - } + addDocumentFromWordChunk(docsToReturn, text, metadata); + return docsToReturn; + } - loc.put("loc", "from", String.valueOf(lineCounterIndex)); - loc.put("loc", "to", String.valueOf(lineCounterIndex + newLinesCount)); + public List createDocumentsSplitFromList(List documents) { - Metadata metadataWithLinesNumber = Metadata.createEmpty(); - if (metadatas.get(i) != null && !metadatas.get(i).getValue().isEmpty()) { - metadataWithLinesNumber.getValue().putAll(metadatas.get(i).getValue()); - } - metadataWithLinesNumber.getValue().putAll(loc); + ArrayList docsToReturn = new ArrayList<>(); - documents.add(new DomainDocument(chunk, Optional.of(metadataWithLinesNumber))); - lineCounterIndex += newLinesCount; - prevChunk = chunk; - } + for (DomainDocument document : documents) { + String text = document.getPageContent(); + Metadata metadata = document.getMetadata().isPresent() ? document.getMetadata().get() : Metadata.create(); + int lineCounterIndex = 1; + + addDocumentFromWordChunk(docsToReturn, text, metadata); } - return documents; + return docsToReturn; } - public List splitDocuments(List documents) { + private void addDocumentFromWordChunk(ArrayList docsToReturn, String text, Metadata metadata) { + String prevChunk = null; + int lineCounterIndex = 1; + + for (String chunk : splitText(text)) { + int numberOfIntermediateNewLines = 0; + if (prevChunk != null) { + int indexChunk = StringUtils.indexOf(text, chunk); + int indexEndPrevChunk = StringUtils.indexOf(text, prevChunk) + prevChunk.length(); + String removedNewlinesFromSplittingText = StringUtils.substring(text, indexChunk, indexEndPrevChunk); + numberOfIntermediateNewLines = StringUtils.countMatches(removedNewlinesFromSplittingText, "\n"); + } + lineCounterIndex += numberOfIntermediateNewLines; + int newLinesCount = StringUtils.countMatches(chunk, "\n"); + + MultiKeyMap loc; + //TODO: need to end to end test how metadata is being passed back and forth + if (metadata.getValue().containsKey("loc", "")) { + loc = metadata.getValue(); + } else { + loc = new MultiKeyMap<>(); + } + + loc.put("loc", "from", String.valueOf(lineCounterIndex)); + loc.put("loc", "to", String.valueOf(lineCounterIndex + newLinesCount)); - List selectedDocs = documents.stream().filter(doc -> doc.getPageContent() != null).collect(Collectors.toList()); + Metadata metadataWithLinesNumber = Metadata.create(); + if (!metadata.getValue().isEmpty()) { + metadataWithLinesNumber.getValue().putAll(metadata.getValue()); + } + metadataWithLinesNumber.getValue().putAll(loc); - List texts = selectedDocs.stream().map(DomainDocument::getPageContent).collect(Collectors.toList()); - List metadatas = - selectedDocs.stream().map(doc -> doc.getMetadata().isPresent() ? - doc.getMetadata().get() : Metadata.createEmpty()).collect(Collectors.toList()); + docsToReturn.add(new DomainDocument(chunk, Optional.of(metadataWithLinesNumber))); + lineCounterIndex += newLinesCount; + prevChunk = chunk; + } + } - return this.createDocuments(texts, Optional.of(metadatas)); + public List splitWordsFromDocuments(List documents) { + List selectedDocs = documents.stream().filter(doc -> doc.getPageContent() != null).collect(Collectors.toList()); + return this.createDocumentsSplitFromList(selectedDocs); } - @Nullable - private String joinDocs(List docs, String separator) { + private Optional joinDocs(List docs, String separator) { String text = String.join(separator, docs); - return text.equals("") ? null : text; + return text.equals("") ? Optional.empty() : Optional.of(text); } public List mergeSplits(List splits, String separator) { @@ -110,14 +118,11 @@ public List mergeSplits(List splits, String separator) { if (total + length + (currentDoc.size() > 0 ? separator.length() : 0) > this.wordCount) { if (total > this.wordCount) { - System.out.println("Created a chunk of size " + total + ", which is longer than the specified " + this.wordCount); + logger.atInfo().log("Created a chunk of size " + total + ", which is longer than the specified " + this.wordCount); } if (currentDoc.size() > 0) { - String doc = joinDocs(currentDoc, separator); - if (doc != null) { - docs.add(doc); - } + joinDocs(currentDoc, separator).ifPresent(docs::add); while (total > this.wordOverlap || (total + length > this.wordCount && total > 0)) { total -= currentDoc.get(0).length(); @@ -130,11 +135,7 @@ public List mergeSplits(List splits, String separator) { total += length; } - String doc = joinDocs(currentDoc, separator); - if (doc != null) { - docs.add(doc); - } - + joinDocs(currentDoc, separator).ifPresent(docs::add); return docs; } diff --git a/src/main/java/ai/knowly/langtorch/schema/io/Metadata.java b/src/main/java/ai/knowly/langtorch/schema/io/Metadata.java index 0af0a37e..ed8ac0ae 100644 --- a/src/main/java/ai/knowly/langtorch/schema/io/Metadata.java +++ b/src/main/java/ai/knowly/langtorch/schema/io/Metadata.java @@ -1,5 +1,6 @@ package ai.knowly.langtorch.schema.io; +import org.apache.commons.collections4.keyvalue.MultiKey; import org.apache.commons.collections4.map.MultiKeyMap; import java.util.Objects; @@ -15,10 +16,18 @@ public MultiKeyMap getValue() { return value; } - public static Metadata createEmpty(){ + public static Metadata create(){ return new Metadata(new MultiKeyMap<>()); } + public Metadata set(MultiKey key, String value) { + this.value.put(key, value); + return this; + } + + public static Metadata copyOf(MultiKeyMap values) { + return new Metadata(values); + } @Override public boolean equals(Object obj) { diff --git a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitterTest.java b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitterTest.java index 78b9b252..83c64304 100644 --- a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitterTest.java +++ b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitterTest.java @@ -6,6 +6,7 @@ import org.junit.Test; import java.util.*; +import java.util.stream.Collectors; public class WordSplitterTest { @@ -21,6 +22,7 @@ public void testWordSplitter_realWorldText() { Truth.assertThat(result.get(i)).isEqualTo(expectedResult.get(i)); } } + @Test public void testWordSplitter_splitByWordCount() { // Arrange. @@ -122,20 +124,19 @@ public void testCharacterTextSplitter_invalidArguments() { @Test public void testWordSplitter_createDocuments() { // Arrange. - List texts = Arrays.asList("foo bar", "baz"); WordSplitter splitter = new WordSplitter(" ", 3, 0); - Metadata metadata = Metadata.createEmpty(); - - List metadatas = Arrays.asList(metadata, metadata); + List docsToSplit = + Arrays.asList("foo bar", "baz").stream() + .map(text -> new DomainDocument(text, Optional.of(Metadata.create()))).collect(Collectors.toList()); // Act. - List docs = splitter.createDocuments(texts, Optional.of(metadatas)); + List docs = splitter.createDocumentsSplitFromList(docsToSplit); // Assert. List expectedDocs = Arrays.asList( - new DomainDocument("foo", Optional.of(metadata)), - new DomainDocument("bar", Optional.of(metadata)), - new DomainDocument("baz", Optional.of(metadata)) + new DomainDocument("foo", Optional.of(Metadata.create())), + new DomainDocument("bar", Optional.of(Metadata.create())), + new DomainDocument("baz", Optional.of(Metadata.create())) ); Truth.assertThat(expectedDocs.size() == docs.size()); @@ -147,22 +148,20 @@ public void testWordSplitter_createDocuments() { @Test public void testWordSplitter_createDocumentsWithMetadata() { // Arrange. - List texts = Arrays.asList("foo bar", "baz"); WordSplitter splitter = new WordSplitter(" ", 3, 0); - - Metadata metadata = Metadata.createEmpty(); + Metadata metadata = Metadata.create(); metadata.getValue().put("source", "doc", "1"); metadata.getValue().put("loc", "from", "1"); metadata.getValue().put("loc", "to", "1"); - List metadataList = Arrays.asList(metadata, metadata); - - Optional> metadatas = Optional.of(metadataList); + List docsToSplit = + Arrays.asList("foo bar", "baz").stream() + .map(text -> new DomainDocument(text, Optional.of(metadata))).collect(Collectors.toList()); // Act. - List docs = splitter.createDocuments(texts, metadatas); + List docs = splitter.createDocumentsSplitFromList(docsToSplit); // Assert. List expectedDocs = Arrays.asList( @@ -179,7 +178,6 @@ public void testWordSplitter_createDocumentsWithMetadata() { } - private List sampleTextExpectedSplit() { return Arrays.asList( "Langtorch one pager\n" + From 70c9dba62825f56464708396c4195de60df6efb6 Mon Sep 17 00:00:00 2001 From: ayoola adedeji Date: Mon, 29 May 2023 08:10:19 +0100 Subject: [PATCH 10/14] Fixed sonar error --- .../preprocessing/splitter/text/TextSplitter.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java index f43e1d76..1b538699 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java @@ -38,8 +38,7 @@ protected TextSplitter(int wordCount, int wordOverlap) { public List createDocumentsSplitFromSingle(DomainDocument document) { String text = document.getPageContent(); - Metadata metadata = document.getMetadata().isPresent() ? document.getMetadata().get() : Metadata.create(); - + Metadata metadata = document.getMetadata().orElse(Metadata.create()); ArrayList docsToReturn = new ArrayList<>(); addDocumentFromWordChunk(docsToReturn, text, metadata); @@ -52,15 +51,16 @@ public List createDocumentsSplitFromList(List do for (DomainDocument document : documents) { String text = document.getPageContent(); - Metadata metadata = document.getMetadata().isPresent() ? document.getMetadata().get() : Metadata.create(); - int lineCounterIndex = 1; - + Metadata metadata = document.getMetadata().orElse(Metadata.create()); addDocumentFromWordChunk(docsToReturn, text, metadata); } return docsToReturn; } - private void addDocumentFromWordChunk(ArrayList docsToReturn, String text, Metadata metadata) { + private void addDocumentFromWordChunk( + ArrayList docsToReturn, + String text, + Metadata metadata) { String prevChunk = null; int lineCounterIndex = 1; From 3cb82dcf6ed34380e764c03072521578e350d868 Mon Sep 17 00:00:00 2001 From: ayoola adedeji Date: Mon, 29 May 2023 08:55:39 +0100 Subject: [PATCH 11/14] fixing code smells --- .../langtorch/preprocessing/splitter/text/TextSplitter.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java index 1b538699..c77400c6 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java @@ -34,7 +34,7 @@ protected TextSplitter(int wordCount, int wordOverlap) { } } - abstract public List splitText(String text); + public abstract List splitText(String text); public List createDocumentsSplitFromSingle(DomainDocument document) { String text = document.getPageContent(); @@ -76,7 +76,6 @@ private void addDocumentFromWordChunk( int newLinesCount = StringUtils.countMatches(chunk, "\n"); MultiKeyMap loc; - //TODO: need to end to end test how metadata is being passed back and forth if (metadata.getValue().containsKey("loc", "")) { loc = metadata.getValue(); } else { From a77bb2c579d9d9072ae073cbabd99a7dee037ae1 Mon Sep 17 00:00:00 2001 From: Weizhi Li Date: Mon, 29 May 2023 17:24:41 -0700 Subject: [PATCH 12/14] Adding extra candidate for TextSplitter --- .../splitter/text/BaseTextSplitter.java | 7 ++ .../splitter/text/SplitterOption.java | 9 ++ .../splitter/text/word/WordLevelSplitter.java | 65 ++++++++++++++ .../text/word/WordLevelSplitterOption.java | 28 ++++++ .../text/word/WordLevelSplitterTest.java | 87 +++++++++++++++++++ 5 files changed, 196 insertions(+) create mode 100644 src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/BaseTextSplitter.java create mode 100644 src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/SplitterOption.java create mode 100644 src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitter.java create mode 100644 src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterOption.java create mode 100644 src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterTest.java diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/BaseTextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/BaseTextSplitter.java new file mode 100644 index 00000000..740513dc --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/BaseTextSplitter.java @@ -0,0 +1,7 @@ +package ai.knowly.langtorch.preprocessing.splitter.text; + +import java.util.List; + +public interface BaseTextSplitter { + List splitText(S option); +} diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/SplitterOption.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/SplitterOption.java new file mode 100644 index 00000000..97f56397 --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/SplitterOption.java @@ -0,0 +1,9 @@ +package ai.knowly.langtorch.preprocessing.splitter.text; + +public abstract class SplitterOption { + String text; + + protected SplitterOption(String text) { + this.text = text; + } +} diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitter.java new file mode 100644 index 00000000..2617508d --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitter.java @@ -0,0 +1,65 @@ +package ai.knowly.langtorch.preprocessing.splitter.text.word; + +import ai.knowly.langtorch.preprocessing.splitter.text.BaseTextSplitter; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; +import java.util.List; + +/** Splits text into chunks of words. */ +public class WordLevelSplitter implements BaseTextSplitter { + + public static WordLevelSplitter create() { + return new WordLevelSplitter(); + } + + @Override + public List splitText(WordLevelSplitterOption option) { + int maxLengthPerChunk = option.getMaxLengthPerChunk(); + String text = option.getText(); + + Builder chunks = ImmutableList.builder(); + + // Validate the maxLengthPerChunk + if (maxLengthPerChunk < 1) { + throw new IllegalArgumentException("maxLengthPerChunk should be greater than 0"); + } + + String[] words = text.split("\\s+"); + int minLengthOfWord = words[0].length(); + + for (String word : words) { + minLengthOfWord = Math.min(minLengthOfWord, word.length()); + } + + if (maxLengthPerChunk < minLengthOfWord) { + throw new IllegalArgumentException( + "maxLengthPerChunk is smaller than the smallest word in the string"); + } + + StringBuilder chunk = new StringBuilder(); + int wordsLength = words.length; + + for (int i = 0; i < wordsLength; i++) { + String word = words[i]; + boolean isLastWord = i == wordsLength - 1; + if ((chunk.length() + word.length() + (isLastWord ? 0 : 1)) + <= maxLengthPerChunk) { // '+1' accounts for spaces, except for the last word + chunk.append(word); + if (!isLastWord) { + chunk.append(" "); + } + } else { + chunks.add(chunk.toString().trim()); + chunk = new StringBuilder(); + chunk.append(word).append(" "); + } + } + + // Add remaining chunk if any + if (chunk.length() > 0) { + chunks.add(chunk.toString().trim()); + } + + return chunks.build(); + } +} diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterOption.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterOption.java new file mode 100644 index 00000000..d464427a --- /dev/null +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterOption.java @@ -0,0 +1,28 @@ +package ai.knowly.langtorch.preprocessing.splitter.text.word; + +import ai.knowly.langtorch.preprocessing.splitter.text.SplitterOption; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; + +/** Options for {@link WordLevelSplitter}. */ +@EqualsAndHashCode(callSuper = true) +@Data +@Builder(toBuilder = true, setterPrefix = "set") +public class WordLevelSplitterOption extends SplitterOption { + // Unprocessed text. + private final String text; + + // The max length of a chunk. + private final int maxLengthPerChunk; + + private WordLevelSplitterOption(String text, int maxLengthPerChunk) { + super(text); + this.text = text; + this.maxLengthPerChunk = maxLengthPerChunk; + } + + public static WordLevelSplitterOption of(String text, int totalLengthOfChunk) { + return new WordLevelSplitterOption(text, totalLengthOfChunk); + } +} diff --git a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterTest.java b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterTest.java new file mode 100644 index 00000000..df27fa2c --- /dev/null +++ b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterTest.java @@ -0,0 +1,87 @@ +package ai.knowly.langtorch.preprocessing.splitter.text.word; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.List; +import org.junit.jupiter.api.Test; + +class WordLevelSplitterTest { + @Test + void testSplitText_NormalUsage() { + // Arrange. + WordLevelSplitterOption option = + WordLevelSplitterOption.builder() + .setText("Hello world, this is a test.") + .setMaxLengthPerChunk(10) + .build(); + + // Act. + List result = WordLevelSplitter.create().splitText(option); + + // Assert. + assertThat(result).containsExactly("Hello", "world,", "this is a", "test.").inOrder(); + } + + @Test + void testSplitText_SingleWord() { + // Arrange. + WordLevelSplitterOption option = + WordLevelSplitterOption.builder().setText("Hello").setMaxLengthPerChunk(10).build(); + + // Act. + List result = WordLevelSplitter.create().splitText(option); + + // Assert. + assertThat(result).containsExactly("Hello"); + } + + @Test + void testSplitText_SingleChar() { + // Arrange. + WordLevelSplitterOption option = + WordLevelSplitterOption.builder().setText("H").setMaxLengthPerChunk(1).build(); + + // Act. + List result = WordLevelSplitter.create().splitText(option); + + // Assert. + assertThat(result).containsExactly("H"); + } + + void testSplitText_MaxLengthSmallerThanWordLength() { + // Arrange. + WordLevelSplitterOption option = + WordLevelSplitterOption.builder().setText("Hello").setMaxLengthPerChunk(3).build(); + + // Act. + // Assert. + assertThrows( + IllegalArgumentException.class, () -> WordLevelSplitter.create().splitText(option)); + } + + @Test + void testSplitText_NegativeMaxLength() { + // Arrange. + WordLevelSplitterOption option = + WordLevelSplitterOption.builder().setText("Hello").setMaxLengthPerChunk(-5).build(); + + // Act. + // Assert. + assertThrows( + IllegalArgumentException.class, () -> WordLevelSplitter.create().splitText(option)); + } + + @Test + void testSplitText_EmptyString() { + // Arrange. + WordLevelSplitterOption option = + WordLevelSplitterOption.builder().setText("").setMaxLengthPerChunk(10).build(); + + // Act. + List result = WordLevelSplitter.create().splitText(option); + + // Assert. + assertThat(result).isEmpty(); + } +} From 866daf0210ec775f1943a857368e1e316ec54557 Mon Sep 17 00:00:00 2001 From: ayoola adedeji Date: Tue, 30 May 2023 07:15:02 +0100 Subject: [PATCH 13/14] removed old implementation --- .../splitter/text/TextSplitter.java | 141 ----------- .../splitter/text/WordSplitter.java | 44 ---- .../splitter/text/WordSplitterTest.java | 236 ------------------ 3 files changed, 421 deletions(-) delete mode 100644 src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java delete mode 100644 src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitter.java delete mode 100644 src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitterTest.java diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java deleted file mode 100644 index c77400c6..00000000 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java +++ /dev/null @@ -1,141 +0,0 @@ -package ai.knowly.langtorch.preprocessing.splitter.text; - -import ai.knowly.langtorch.schema.io.DomainDocument; -import ai.knowly.langtorch.schema.io.Metadata; -import com.google.common.flogger.FluentLogger; -import org.apache.commons.collections4.map.MultiKeyMap; -import org.apache.commons.lang3.StringUtils; - -import java.util.*; -import java.util.stream.Collectors; - -/** - * The TextSplitter class provides functionality for splitting text into chunks. - */ -public abstract class TextSplitter { - - private static final FluentLogger logger = FluentLogger.forEnclosingClass(); - - /** - * The amount of words inside one chunk - */ - public final int wordCount; - - /** - * amount of words from previous chunk, it will be empty for the first chunk - */ - public final int wordOverlap; - - protected TextSplitter(int wordCount, int wordOverlap) { - this.wordCount = wordCount; - this.wordOverlap = wordOverlap; - if (this.wordOverlap >= this.wordCount) { - throw new IllegalArgumentException("chunkOverlap cannot be equal to or greater than chunkSize"); - } - } - - public abstract List splitText(String text); - - public List createDocumentsSplitFromSingle(DomainDocument document) { - String text = document.getPageContent(); - Metadata metadata = document.getMetadata().orElse(Metadata.create()); - ArrayList docsToReturn = new ArrayList<>(); - - addDocumentFromWordChunk(docsToReturn, text, metadata); - return docsToReturn; - } - - public List createDocumentsSplitFromList(List documents) { - - ArrayList docsToReturn = new ArrayList<>(); - - for (DomainDocument document : documents) { - String text = document.getPageContent(); - Metadata metadata = document.getMetadata().orElse(Metadata.create()); - addDocumentFromWordChunk(docsToReturn, text, metadata); - } - return docsToReturn; - } - - private void addDocumentFromWordChunk( - ArrayList docsToReturn, - String text, - Metadata metadata) { - String prevChunk = null; - int lineCounterIndex = 1; - - for (String chunk : splitText(text)) { - int numberOfIntermediateNewLines = 0; - if (prevChunk != null) { - int indexChunk = StringUtils.indexOf(text, chunk); - int indexEndPrevChunk = StringUtils.indexOf(text, prevChunk) + prevChunk.length(); - String removedNewlinesFromSplittingText = StringUtils.substring(text, indexChunk, indexEndPrevChunk); - numberOfIntermediateNewLines = StringUtils.countMatches(removedNewlinesFromSplittingText, "\n"); - } - lineCounterIndex += numberOfIntermediateNewLines; - int newLinesCount = StringUtils.countMatches(chunk, "\n"); - - MultiKeyMap loc; - if (metadata.getValue().containsKey("loc", "")) { - loc = metadata.getValue(); - } else { - loc = new MultiKeyMap<>(); - } - - loc.put("loc", "from", String.valueOf(lineCounterIndex)); - loc.put("loc", "to", String.valueOf(lineCounterIndex + newLinesCount)); - - Metadata metadataWithLinesNumber = Metadata.create(); - if (!metadata.getValue().isEmpty()) { - metadataWithLinesNumber.getValue().putAll(metadata.getValue()); - } - metadataWithLinesNumber.getValue().putAll(loc); - - docsToReturn.add(new DomainDocument(chunk, Optional.of(metadataWithLinesNumber))); - lineCounterIndex += newLinesCount; - prevChunk = chunk; - } - } - - public List splitWordsFromDocuments(List documents) { - List selectedDocs = documents.stream().filter(doc -> doc.getPageContent() != null).collect(Collectors.toList()); - return this.createDocumentsSplitFromList(selectedDocs); - } - - private Optional joinDocs(List docs, String separator) { - String text = String.join(separator, docs); - return text.equals("") ? Optional.empty() : Optional.of(text); - } - - public List mergeSplits(List splits, String separator) { - List docs = new ArrayList<>(); - List currentDoc = new ArrayList<>(); - int total = 0; - - for (String d : splits) { - int length = d.length(); - - if (total + length + (currentDoc.size() > 0 ? separator.length() : 0) > this.wordCount) { - if (total > this.wordCount) { - logger.atInfo().log("Created a chunk of size " + total + ", which is longer than the specified " + this.wordCount); - } - - if (currentDoc.size() > 0) { - joinDocs(currentDoc, separator).ifPresent(docs::add); - - while (total > this.wordOverlap || (total + length > this.wordCount && total > 0)) { - total -= currentDoc.get(0).length(); - currentDoc.remove(0); - } - } - } - - currentDoc.add(d); - total += length; - } - - joinDocs(currentDoc, separator).ifPresent(docs::add); - return docs; - } - -} diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitter.java deleted file mode 100644 index ca813987..00000000 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitter.java +++ /dev/null @@ -1,44 +0,0 @@ -package ai.knowly.langtorch.preprocessing.splitter.text; - -import org.apache.commons.lang3.StringUtils; - -import javax.annotation.Nullable; -import java.util.Arrays; -import java.util.List; -/** - The CharacterTextSplitter class is a concrete implementation of the TextSplitter abstract class - that splits text into chunks based on a specified separator. - */ - -public class WordSplitter extends TextSplitter { - - private String separator = "\n\n"; - - /** - - Constructs a CharacterTextSplitter object with the given separator, chunk size, and chunk overlap. - If the separator is null, the default separator "\n\n" is used. - @param separator The separator used for splitting the text into chunks. - @param wordCount The size of each chunk. - @param wordOverlap The amount of overlap between adjacent chunks. - */ - public WordSplitter(@Nullable String separator, int wordCount, int wordOverlap) { - super(wordCount, wordOverlap); - if (separator != null) { - this.separator = separator; - } - } - - /** - Splits the given text into chunks based on the specified separator. - @param text The text to be split into chunks. - @return A list of strings representing the chunks of the text. - */ - @Override - public List splitText(String text) { - - List splits = - Arrays.asList(StringUtils.splitByWholeSeparatorPreserveAllTokens(text, this.separator.isEmpty() ? "" : this.separator)); - return mergeSplits(splits, this.separator); - } -} diff --git a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitterTest.java b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitterTest.java deleted file mode 100644 index 83c64304..00000000 --- a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/WordSplitterTest.java +++ /dev/null @@ -1,236 +0,0 @@ -package ai.knowly.langtorch.preprocessing.splitter.text; - -import ai.knowly.langtorch.schema.io.DomainDocument; -import ai.knowly.langtorch.schema.io.Metadata; -import com.google.common.truth.Truth; -import org.junit.Test; - -import java.util.*; -import java.util.stream.Collectors; - - -public class WordSplitterTest { - - @Test - public void testWordSplitter_realWorldText() { - WordSplitter splitter = new WordSplitter(null, 1000, 100); - - List result = splitter.splitText(sampleText()); - List expectedResult = sampleTextExpectedSplit(); - - for (int i = 0; i < result.size(); i++) { - Truth.assertThat(result.get(i)).isEqualTo(expectedResult.get(i)); - } - } - - @Test - public void testWordSplitter_splitByWordCount() { - // Arrange. - String text = "foo bar baz 123"; - WordSplitter splitter = new WordSplitter(" ", 7, 3); - - // Act. - List result = splitter.splitText(text); - - // Assert. - List expected = new ArrayList<>(Arrays.asList("foo bar", "bar baz", "baz 123")); - - Truth.assertThat(result.size()).isEqualTo(expected.size()); - for (int i = 0; i < expected.size(); i++) { - Truth.assertThat(result.get(i)).isEqualTo(expected.get(i)); - } - } - - @Test - public void testCharacterTextSplitter_splitByCharacterCountWithNoEmptyDocuments() { - // Arrange. - String text = "foo bar"; - WordSplitter splitter = new WordSplitter(" ", 2, 0); - - // Act. - List result = splitter.splitText(text); - - // Assert. - List expected = new ArrayList<>(Arrays.asList("foo", "bar")); - - for (int i = 0; i < expected.size(); i++) { - Truth.assertThat(result.get(i)).isEqualTo(expected.get(i)); - } - } - - @Test - public void testCharacterTextSplitter_splitByCharacterCountLongWords() { - // Arrange. - String text = "foo bar baz a a"; - WordSplitter splitter = new WordSplitter(" ", 3, 1); - - // Act. - List result = splitter.splitText(text); - - // Assert. - List expected = new ArrayList<>(Arrays.asList("foo", "bar", "baz", "a a")); - - for (int i = 0; i < expected.size(); i++) { - Truth.assertThat(result.get(i)).isEqualTo(expected.get(i)); - } - } - - @Test - public void testCharacterTextSplitter_splitByCharacterCountShorterWordsFirst() { - // Arrange. - String text = "a a foo bar baz"; - WordSplitter splitter = new WordSplitter(" ", 3, 1); - - // Act. - List result = splitter.splitText(text); - - // Assert. - List expected = new ArrayList<>(Arrays.asList("a a", "foo", "bar", "baz")); - - for (int i = 0; i < expected.size(); i++) { - Truth.assertThat(result.get(i)).isEqualTo(expected.get(i)); - } - } - - @Test - public void testCharacterTextSplitter_splitByCharactersSplitsNotFoundEasily() { - // Arrange. - String text = "foo bar baz 123"; - WordSplitter splitter = new WordSplitter(" ", 1, 0); - - // Act. - List result = splitter.splitText(text); - - // Assert. - List expected = new ArrayList<>(Arrays.asList("foo", "bar", "baz", "123")); - - for (int i = 0; i < expected.size(); i++) { - Truth.assertThat(result.get(i)).isEqualTo(expected.get(i)); - } - } - - @Test(expected = IllegalArgumentException.class) - public void testCharacterTextSplitter_invalidArguments() { - // Arrange. - int chunkSize = 2; - int chunkOverlap = 4; - - // Act. - new WordSplitter(null, chunkSize, chunkOverlap); - - // Expect IllegalArgumentException to be thrown. - } - - @Test - public void testWordSplitter_createDocuments() { - // Arrange. - WordSplitter splitter = new WordSplitter(" ", 3, 0); - List docsToSplit = - Arrays.asList("foo bar", "baz").stream() - .map(text -> new DomainDocument(text, Optional.of(Metadata.create()))).collect(Collectors.toList()); - - // Act. - List docs = splitter.createDocumentsSplitFromList(docsToSplit); - - // Assert. - List expectedDocs = Arrays.asList( - new DomainDocument("foo", Optional.of(Metadata.create())), - new DomainDocument("bar", Optional.of(Metadata.create())), - new DomainDocument("baz", Optional.of(Metadata.create())) - ); - - Truth.assertThat(expectedDocs.size() == docs.size()); - for (int i = 0; i < docs.size(); i++) { - Truth.assertThat(docs.get(i).getPageContent()).isEqualTo(expectedDocs.get(i).getPageContent()); - } - } - - @Test - public void testWordSplitter_createDocumentsWithMetadata() { - // Arrange. - WordSplitter splitter = new WordSplitter(" ", 3, 0); - - Metadata metadata = Metadata.create(); - - metadata.getValue().put("source", "doc", "1"); - metadata.getValue().put("loc", "from", "1"); - metadata.getValue().put("loc", "to", "1"); - - List docsToSplit = - Arrays.asList("foo bar", "baz").stream() - .map(text -> new DomainDocument(text, Optional.of(metadata))).collect(Collectors.toList()); - - // Act. - List docs = splitter.createDocumentsSplitFromList(docsToSplit); - - // Assert. - List expectedDocs = Arrays.asList( - new DomainDocument("foo", Optional.of(metadata)), - new DomainDocument("bar", Optional.of(metadata)), - new DomainDocument("baz", Optional.of(metadata)) - ); - - Truth.assertThat(docs.size()).isEqualTo(expectedDocs.size()); - for (int i = 0; i < docs.size(); i++) { - Truth.assertThat(docs.get(i).getPageContent()).isEqualTo(expectedDocs.get(i).getPageContent()); - Truth.assertThat(docs.get(i).getMetadata()).isEqualTo(expectedDocs.get(i).getMetadata()); - } - } - - - private List sampleTextExpectedSplit() { - return Arrays.asList( - "Langtorch one pager\n" + - "Langtorch is a Java framework that assists you in developing large language model applications. It is designed with reusability, composability and Fluent style in mind. It can aid you in developing workflows or pipelines that include large language models.\n" + - "\n" + - "Processor\n" + - "In Langtorch, we introduce the concept of a processor. A processor is a container for the smallest computational unit in Langtorch. The response produced by the processor can either originate from a large language model, such as OpenAI's GPT model(retrieved by rest api), or it could be a deterministic Java function.", - "A processor is an interface that includes two functions: run() and runAsync(). Anything that implements these two functions can be considered a processor. For instance, a processor could be something that sends an HTTP request to OpenAI to invoke its GPT model and generate a response. It could also be a calculator function, where the input is 1+1, and the output is 2.\n" + - "Using this approach, we can conveniently add a processor, such as the soon-to-be-publicly-available Google PALM 2 API. At the same time, when we chain different processors together, we can leverage this to avoid some of the shortcomings of large language models (LLMs). For instance, when we want to implement a chatbot, if a user asks a mathematical question, we can parse this question using the LLM's capabilities into an input for our calculator to get an accurate answer, rather than letting the LLM come to a conclusion directly.", - "Note: The processor is the smallest computational unit in Langtorch, so a processor is generally only allowed to process a single task. For example, it could have the ability to handle text completion, chat completion, or generate images based on a prompt. If the requirements are complex, such as first generating a company's slogan through text completion, and then generating an image based on the slogan, this should be accomplished by chaining different processors together, rather than completing everything within a single processor.\n" + - "\n" + - "Capability\n" + - "As previously mentioned, the processor is the smallest container of a computational unit, and often it is not sufficient to handle all situations. We need to enhance the processor!\n" + - "Here we introduce the concept of Capability. If the processor is likened to an internal combustion steam engine, then a Capability could be a steam train based on the steam engine, or a electricity generator based on the steam engine.", - "Imagine that you are implementing a chatbot. If the processor is based on OpenAI's API, sending every user's input to the OpenAI GPT-4 model and returning its response, what would the user experience be like?\n" + - "\n" + - "The reason is that the chatbot does not incorporate chat history. Therefore, in capability, we can add memory (a simple implementation of memory is to put all conversation records into the request sent to OpenAI).\n" + - "\n" + - "Workflow(chaining Capabilities)\n" + - "To make the combination of capabilities easier, we introduce the concept of a Node Adapter, and we refer to capabilities nodes composition as a Capability graph.\n" + - "However, the capability graph can only be a Directed Acyclic Graph (DAG), i.e., there are no cycles allowed.\n" + - "\n" + - "The Node Adapter is primarily used for validation and optimization of the Capability graph. It wraps the capability and also includes some information about the Capability graph, such as the current node's globally unique ID, what the next nodes are, and so on." - ); - - } - - private String sampleText() { - return "Langtorch one pager\n" + - "Langtorch is a Java framework that assists you in developing large language model applications. It is designed with reusability, composability and Fluent style in mind. It can aid you in developing workflows or pipelines that include large language models.\n" + - "\n" + - "Processor\n" + - "In Langtorch, we introduce the concept of a processor. A processor is a container for the smallest computational unit in Langtorch. The response produced by the processor can either originate from a large language model, such as OpenAI's GPT model(retrieved by rest api), or it could be a deterministic Java function.\n" + - "\n" + - "A processor is an interface that includes two functions: run() and runAsync(). Anything that implements these two functions can be considered a processor. For instance, a processor could be something that sends an HTTP request to OpenAI to invoke its GPT model and generate a response. It could also be a calculator function, where the input is 1+1, and the output is 2.\n" + - "Using this approach, we can conveniently add a processor, such as the soon-to-be-publicly-available Google PALM 2 API. At the same time, when we chain different processors together, we can leverage this to avoid some of the shortcomings of large language models (LLMs). For instance, when we want to implement a chatbot, if a user asks a mathematical question, we can parse this question using the LLM's capabilities into an input for our calculator to get an accurate answer, rather than letting the LLM come to a conclusion directly.\n" + - "\n" + - "Note: The processor is the smallest computational unit in Langtorch, so a processor is generally only allowed to process a single task. For example, it could have the ability to handle text completion, chat completion, or generate images based on a prompt. If the requirements are complex, such as first generating a company's slogan through text completion, and then generating an image based on the slogan, this should be accomplished by chaining different processors together, rather than completing everything within a single processor.\n" + - "\n" + - "Capability\n" + - "As previously mentioned, the processor is the smallest container of a computational unit, and often it is not sufficient to handle all situations. We need to enhance the processor!\n" + - "Here we introduce the concept of Capability. If the processor is likened to an internal combustion steam engine, then a Capability could be a steam train based on the steam engine, or a electricity generator based on the steam engine.\n" + - "\n" + - "Imagine that you are implementing a chatbot. If the processor is based on OpenAI's API, sending every user's input to the OpenAI GPT-4 model and returning its response, what would the user experience be like?\n" + - "\n" + - "The reason is that the chatbot does not incorporate chat history. Therefore, in capability, we can add memory (a simple implementation of memory is to put all conversation records into the request sent to OpenAI).\n" + - "\n" + - "Workflow(chaining Capabilities)\n" + - "To make the combination of capabilities easier, we introduce the concept of a Node Adapter, and we refer to capabilities nodes composition as a Capability graph.\n" + - "However, the capability graph can only be a Directed Acyclic Graph (DAG), i.e., there are no cycles allowed.\n" + - "\n" + - "The Node Adapter is primarily used for validation and optimization of the Capability graph. It wraps the capability and also includes some information about the Capability graph, such as the current node's globally unique ID, what the next nodes are, and so on."; - } - - -} From c4a52d3be46ef46745f08efef24f31edb56b49ea Mon Sep 17 00:00:00 2001 From: Weizhi Li Date: Mon, 29 May 2023 23:35:37 -0700 Subject: [PATCH 14/14] Adding text splitter --- ...aseTextSplitter.java => TextSplitter.java} | 2 +- ...rdLevelSplitter.java => WordSplitter.java} | 10 +-- ...terOption.java => WordSplitterOption.java} | 10 +-- .../text/word/WordLevelSplitterTest.java | 87 ------------------- .../splitter/text/word/WordSplitterTest.java | 85 ++++++++++++++++++ 5 files changed, 96 insertions(+), 98 deletions(-) rename src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/{BaseTextSplitter.java => TextSplitter.java} (65%) rename src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/{WordLevelSplitter.java => WordSplitter.java} (84%) rename src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/{WordLevelSplitterOption.java => WordSplitterOption.java} (61%) delete mode 100644 src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterTest.java create mode 100644 src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordSplitterTest.java diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/BaseTextSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java similarity index 65% rename from src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/BaseTextSplitter.java rename to src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java index 740513dc..a039023b 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/BaseTextSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/TextSplitter.java @@ -2,6 +2,6 @@ import java.util.List; -public interface BaseTextSplitter { +public interface TextSplitter { List splitText(S option); } diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitter.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordSplitter.java similarity index 84% rename from src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitter.java rename to src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordSplitter.java index 2617508d..5e381093 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitter.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordSplitter.java @@ -1,19 +1,19 @@ package ai.knowly.langtorch.preprocessing.splitter.text.word; -import ai.knowly.langtorch.preprocessing.splitter.text.BaseTextSplitter; +import ai.knowly.langtorch.preprocessing.splitter.text.TextSplitter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; import java.util.List; /** Splits text into chunks of words. */ -public class WordLevelSplitter implements BaseTextSplitter { +public class WordSplitter implements TextSplitter { - public static WordLevelSplitter create() { - return new WordLevelSplitter(); + public static WordSplitter create() { + return new WordSplitter(); } @Override - public List splitText(WordLevelSplitterOption option) { + public List splitText(WordSplitterOption option) { int maxLengthPerChunk = option.getMaxLengthPerChunk(); String text = option.getText(); diff --git a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterOption.java b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordSplitterOption.java similarity index 61% rename from src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterOption.java rename to src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordSplitterOption.java index d464427a..9cf47ff3 100644 --- a/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterOption.java +++ b/src/main/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordSplitterOption.java @@ -5,24 +5,24 @@ import lombok.Data; import lombok.EqualsAndHashCode; -/** Options for {@link WordLevelSplitter}. */ +/** Options for {@link WordSplitter}. */ @EqualsAndHashCode(callSuper = true) @Data @Builder(toBuilder = true, setterPrefix = "set") -public class WordLevelSplitterOption extends SplitterOption { +public class WordSplitterOption extends SplitterOption { // Unprocessed text. private final String text; // The max length of a chunk. private final int maxLengthPerChunk; - private WordLevelSplitterOption(String text, int maxLengthPerChunk) { + private WordSplitterOption(String text, int maxLengthPerChunk) { super(text); this.text = text; this.maxLengthPerChunk = maxLengthPerChunk; } - public static WordLevelSplitterOption of(String text, int totalLengthOfChunk) { - return new WordLevelSplitterOption(text, totalLengthOfChunk); + public static WordSplitterOption of(String text, int totalLengthOfChunk) { + return new WordSplitterOption(text, totalLengthOfChunk); } } diff --git a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterTest.java b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterTest.java deleted file mode 100644 index df27fa2c..00000000 --- a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordLevelSplitterTest.java +++ /dev/null @@ -1,87 +0,0 @@ -package ai.knowly.langtorch.preprocessing.splitter.text.word; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; - -import java.util.List; -import org.junit.jupiter.api.Test; - -class WordLevelSplitterTest { - @Test - void testSplitText_NormalUsage() { - // Arrange. - WordLevelSplitterOption option = - WordLevelSplitterOption.builder() - .setText("Hello world, this is a test.") - .setMaxLengthPerChunk(10) - .build(); - - // Act. - List result = WordLevelSplitter.create().splitText(option); - - // Assert. - assertThat(result).containsExactly("Hello", "world,", "this is a", "test.").inOrder(); - } - - @Test - void testSplitText_SingleWord() { - // Arrange. - WordLevelSplitterOption option = - WordLevelSplitterOption.builder().setText("Hello").setMaxLengthPerChunk(10).build(); - - // Act. - List result = WordLevelSplitter.create().splitText(option); - - // Assert. - assertThat(result).containsExactly("Hello"); - } - - @Test - void testSplitText_SingleChar() { - // Arrange. - WordLevelSplitterOption option = - WordLevelSplitterOption.builder().setText("H").setMaxLengthPerChunk(1).build(); - - // Act. - List result = WordLevelSplitter.create().splitText(option); - - // Assert. - assertThat(result).containsExactly("H"); - } - - void testSplitText_MaxLengthSmallerThanWordLength() { - // Arrange. - WordLevelSplitterOption option = - WordLevelSplitterOption.builder().setText("Hello").setMaxLengthPerChunk(3).build(); - - // Act. - // Assert. - assertThrows( - IllegalArgumentException.class, () -> WordLevelSplitter.create().splitText(option)); - } - - @Test - void testSplitText_NegativeMaxLength() { - // Arrange. - WordLevelSplitterOption option = - WordLevelSplitterOption.builder().setText("Hello").setMaxLengthPerChunk(-5).build(); - - // Act. - // Assert. - assertThrows( - IllegalArgumentException.class, () -> WordLevelSplitter.create().splitText(option)); - } - - @Test - void testSplitText_EmptyString() { - // Arrange. - WordLevelSplitterOption option = - WordLevelSplitterOption.builder().setText("").setMaxLengthPerChunk(10).build(); - - // Act. - List result = WordLevelSplitter.create().splitText(option); - - // Assert. - assertThat(result).isEmpty(); - } -} diff --git a/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordSplitterTest.java b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordSplitterTest.java new file mode 100644 index 00000000..72fa5f94 --- /dev/null +++ b/src/test/java/ai/knowly/langtorch/preprocessing/splitter/text/word/WordSplitterTest.java @@ -0,0 +1,85 @@ +package ai.knowly.langtorch.preprocessing.splitter.text.word; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.List; +import org.junit.jupiter.api.Test; + +class WordSplitterTest { + @Test + void testSplitText_NormalUsage() { + // Arrange. + WordSplitterOption option = + WordSplitterOption.builder() + .setText("Hello world, this is a test.") + .setMaxLengthPerChunk(10) + .build(); + + // Act. + List result = WordSplitter.create().splitText(option); + + // Assert. + assertThat(result).containsExactly("Hello", "world,", "this is a", "test.").inOrder(); + } + + @Test + void testSplitText_SingleWord() { + // Arrange. + WordSplitterOption option = + WordSplitterOption.builder().setText("Hello").setMaxLengthPerChunk(10).build(); + + // Act. + List result = WordSplitter.create().splitText(option); + + // Assert. + assertThat(result).containsExactly("Hello"); + } + + @Test + void testSplitText_SingleChar() { + // Arrange. + WordSplitterOption option = + WordSplitterOption.builder().setText("H").setMaxLengthPerChunk(1).build(); + + // Act. + List result = WordSplitter.create().splitText(option); + + // Assert. + assertThat(result).containsExactly("H"); + } + + void testSplitText_MaxLengthSmallerThanWordLength() { + // Arrange. + WordSplitterOption option = + WordSplitterOption.builder().setText("Hello").setMaxLengthPerChunk(3).build(); + + // Act. + // Assert. + assertThrows(IllegalArgumentException.class, () -> WordSplitter.create().splitText(option)); + } + + @Test + void testSplitText_NegativeMaxLength() { + // Arrange. + WordSplitterOption option = + WordSplitterOption.builder().setText("Hello").setMaxLengthPerChunk(-5).build(); + + // Act. + // Assert. + assertThrows(IllegalArgumentException.class, () -> WordSplitter.create().splitText(option)); + } + + @Test + void testSplitText_EmptyString() { + // Arrange. + WordSplitterOption option = + WordSplitterOption.builder().setText("").setMaxLengthPerChunk(10).build(); + + // Act. + List result = WordSplitter.create().splitText(option); + + // Assert. + assertThat(result).isEmpty(); + } +}