Skip to content

Commit

Permalink
Merge branch 'text-splitter' of https://github.com/AyoTheDev/langtorch
Browse files Browse the repository at this point in the history
…into AyoTheDev-text-splitter
  • Loading branch information
li2109 committed May 30, 2023
2 parents a32e372 + c4a52d3 commit 99505de
Show file tree
Hide file tree
Showing 21 changed files with 285 additions and 64 deletions.
4 changes: 4 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ dependencies {
implementation 'org.apache.commons:commons-csv:1.10.0'


// Apache Commons Lang and Apache Commons Collections
implementation 'org.apache.commons:commons-lang3:3.0'
implementation 'org.apache.commons:commons-collections4:4.4'

}

// Testing related dependencies
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import static com.google.common.util.concurrent.Futures.immediateFuture;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;

import ai.knowly.langtorch.parser.Parser;
import ai.knowly.langtorch.preprocessing.parser.Parser;
import ai.knowly.langtorch.processor.module.Processor;
import ai.knowly.langtorch.schema.chat.ChatMessage;
import ai.knowly.langtorch.schema.text.MultiChatMessage;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import static com.google.common.util.concurrent.Futures.immediateFuture;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;

import ai.knowly.langtorch.parser.Parser;
import ai.knowly.langtorch.preprocessing.parser.Parser;
import ai.knowly.langtorch.processor.module.Processor;
import ai.knowly.langtorch.schema.text.SingleText;
import com.google.common.util.concurrent.FluentFuture;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import ai.knowly.langtorch.capability.modality.text.TextCompletionTextLLMCapability;
import ai.knowly.langtorch.processor.module.openai.text.OpenAITextProcessor;
import ai.knowly.langtorch.parser.SingleTextToStringParser;
import ai.knowly.langtorch.parser.StringToSingleTextParser;
import ai.knowly.langtorch.preprocessing.parser.SingleTextToStringParser;
import ai.knowly.langtorch.preprocessing.parser.StringToSingleTextParser;
import ai.knowly.langtorch.prompt.template.PromptTemplate;
import java.util.Map;
import java.util.Optional;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package ai.knowly.langtorch.capability.module.openai;

import ai.knowly.langtorch.capability.modality.text.ChatCompletionLLMCapability;
import ai.knowly.langtorch.parser.ChatMessageToStringParser;
import ai.knowly.langtorch.parser.Parser;
import ai.knowly.langtorch.parser.StringToMultiChatMessageParser;
import ai.knowly.langtorch.preprocessing.parser.ChatMessageToStringParser;
import ai.knowly.langtorch.preprocessing.parser.Parser;
import ai.knowly.langtorch.preprocessing.parser.StringToMultiChatMessageParser;
import ai.knowly.langtorch.processor.module.openai.chat.OpenAIChatProcessor;
import ai.knowly.langtorch.schema.chat.ChatMessage;
import ai.knowly.langtorch.schema.text.MultiChatMessage;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import ai.knowly.langtorch.capability.modality.text.TextCompletionTextLLMCapability;
import ai.knowly.langtorch.processor.module.openai.text.OpenAITextProcessor;
import ai.knowly.langtorch.parser.SingleTextToStringParser;
import ai.knowly.langtorch.parser.StringToSingleTextParser;
import ai.knowly.langtorch.preprocessing.parser.SingleTextToStringParser;
import ai.knowly.langtorch.preprocessing.parser.StringToSingleTextParser;

/** A simple text capability unit that leverages openai api to generate response */
public class SimpleTextCapability extends TextCompletionTextLLMCapability<String, String> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ai.knowly.langtorch.parser;
package ai.knowly.langtorch.preprocessing.parser;

import ai.knowly.langtorch.schema.chat.ChatMessage;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ai.knowly.langtorch.parser;
package ai.knowly.langtorch.preprocessing.parser;

@FunctionalInterface
public interface Parser<T, R> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ai.knowly.langtorch.parser;
package ai.knowly.langtorch.preprocessing.parser;

import ai.knowly.langtorch.schema.text.SingleText;
import ai.knowly.langtorch.prompt.template.PromptTemplate;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ai.knowly.langtorch.parser;
package ai.knowly.langtorch.preprocessing.parser;

import ai.knowly.langtorch.schema.text.SingleText;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ai.knowly.langtorch.parser;
package ai.knowly.langtorch.preprocessing.parser;

import ai.knowly.langtorch.schema.chat.UserMessage;
import ai.knowly.langtorch.schema.text.MultiChatMessage;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ai.knowly.langtorch.parser;
package ai.knowly.langtorch.preprocessing.parser;

import ai.knowly.langtorch.schema.text.SingleText;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package ai.knowly.langtorch.preprocessing.splitter.text;

/**
 * Base type for the option objects accepted by a {@code TextSplitter}.
 *
 * <p>It carries only the text supplied at construction time; concrete subclasses add whatever
 * extra settings their splitter needs.
 */
public abstract class SplitterOption {
  // Text handed to the constructor. Package-private, so it is only visible inside this package.
  String text;

  /**
   * Stores the given text on the new instance.
   *
   * @param text the text supplied by the subclass; stored as-is, without validation
   */
  protected SplitterOption(String text) {
    super();
    this.text = text;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package ai.knowly.langtorch.preprocessing.splitter.text;

import java.util.List;

/**
 * Contract for components that break text into a list of string pieces.
 *
 * @param <S> the concrete {@link SplitterOption} subtype this splitter accepts
 */
public interface TextSplitter<S extends SplitterOption> {
/**
 * Splits according to the supplied option object.
 *
 * @param option the splitter-specific option object
 * @return the resulting pieces as strings
 */
List<String> splitText(S option);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package ai.knowly.langtorch.preprocessing.splitter.text.word;

import ai.knowly.langtorch.preprocessing.splitter.text.TextSplitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import java.util.List;

/**
 * Splits text into chunks of whole words.
 *
 * <p>The text is tokenized on runs of whitespace and the words are packed greedily, joined by a
 * single space, so that no returned chunk is longer than the configured maximum length.
 */
public class WordSplitter implements TextSplitter<WordSplitterOption> {

  /** Returns a new {@link WordSplitter}. */
  public static WordSplitter create() {
    return new WordSplitter();
  }

  /**
   * Splits the option's text into chunks of at most {@code maxLengthPerChunk} characters without
   * breaking any word.
   *
   * @param option carries the text to split and the maximum chunk length
   * @return the chunks in their original order; empty when the text is empty or contains only
   *     whitespace
   * @throws IllegalArgumentException if {@code maxLengthPerChunk} is less than 1, or if any word
   *     is longer than {@code maxLengthPerChunk} and therefore cannot fit in a chunk
   */
  @Override
  public List<String> splitText(WordSplitterOption option) {
    int maxLengthPerChunk = option.getMaxLengthPerChunk();
    String text = option.getText();

    // Validate the maxLengthPerChunk
    if (maxLengthPerChunk < 1) {
      throw new IllegalArgumentException("maxLengthPerChunk should be greater than 0");
    }

    // split() returns an EMPTY array for whitespace-only text and a leading "" token when the
    // text starts with whitespace, so never index words[0] and skip empty tokens below.
    String[] words = text.split("\\s+");

    // Words are never broken, so every word (not just the shortest one) must fit in a chunk.
    int maxLengthOfWord = 0;
    for (String word : words) {
      maxLengthOfWord = Math.max(maxLengthOfWord, word.length());
    }
    if (maxLengthPerChunk < maxLengthOfWord) {
      throw new IllegalArgumentException(
          "maxLengthPerChunk is smaller than the longest word in the string");
    }

    Builder<String> chunks = ImmutableList.builder();
    StringBuilder chunk = new StringBuilder();

    for (String word : words) {
      if (word.isEmpty()) {
        continue;
      }
      if (chunk.length() == 0) {
        // Guaranteed to fit by the validation above.
        chunk.append(word);
      } else if (chunk.length() + 1 + word.length() <= maxLengthPerChunk) {
        // '+1' is the single separating space; no trailing space is ever stored or counted.
        chunk.append(' ').append(word);
      } else {
        chunks.add(chunk.toString());
        chunk.setLength(0);
        chunk.append(word);
      }
    }

    // Add remaining chunk if any
    if (chunk.length() > 0) {
      chunks.add(chunk.toString());
    }

    return chunks.build();
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package ai.knowly.langtorch.preprocessing.splitter.text.word;

import ai.knowly.langtorch.preprocessing.splitter.text.SplitterOption;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;

/**
 * Options for {@link WordSplitter}: the text to split and the maximum length of each chunk.
 *
 * <p>Instances are created through the Lombok-generated builder (setter methods are prefixed with
 * {@code set}) or through {@link #of(String, int)}.
 */
// callSuper must be false: SplitterOption does not override equals/hashCode, so calling super
// would fall through to Object's identity-based versions and two options with the same text and
// maxLengthPerChunk would never be equal. Both values are fields of this class, so comparing them
// here is sufficient.
@EqualsAndHashCode(callSuper = false)
@Data
@Builder(toBuilder = true, setterPrefix = "set")
public class WordSplitterOption extends SplitterOption {
  // Unprocessed text.
  private final String text;

  // The max length of a chunk.
  private final int maxLengthPerChunk;

  private WordSplitterOption(String text, int maxLengthPerChunk) {
    super(text);
    this.text = text;
    this.maxLengthPerChunk = maxLengthPerChunk;
  }

  /**
   * Creates an option from its two values.
   *
   * @param text the unprocessed text
   * @param totalLengthOfChunk the max length of a chunk (stored as {@code maxLengthPerChunk})
   * @return a new option holding the given values
   */
  public static WordSplitterOption of(String text, int totalLengthOfChunk) {
    return new WordSplitterOption(text, totalLengthOfChunk);
  }
}
23 changes: 23 additions & 0 deletions src/main/java/ai/knowly/langtorch/schema/io/DomainDocument.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package ai.knowly.langtorch.schema.io;

import java.util.Optional;

/**
 * A document made of its page content and optional {@link Metadata}. It can be used both as an
 * {@link Input} and as an {@link Output}.
 */
public class DomainDocument implements Input, Output {

  private final String pageContent;

  private final Optional<Metadata> metadata;

  /**
   * Creates a document.
   *
   * @param pageContent the content of the document; stored as-is
   * @param metadata the metadata attached to the document; a {@code null} reference is treated as
   *     {@link Optional#empty()}
   */
  public DomainDocument(String pageContent, Optional<Metadata> metadata) {
    this.pageContent = pageContent;
    // An Optional reference should never itself be null; normalize it here so that the result of
    // getMetadata() can always be chained without a null check.
    this.metadata = metadata == null ? Optional.empty() : metadata;
  }

  /** Returns the page content given at construction time. */
  public String getPageContent() {
    return pageContent;
  }

  /** Returns the metadata, or an empty {@link Optional} when none was supplied; never null. */
  public Optional<Metadata> getMetadata() {
    return metadata;
  }
}
48 changes: 48 additions & 0 deletions src/main/java/ai/knowly/langtorch/schema/io/Metadata.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package ai.knowly.langtorch.schema.io;

import org.apache.commons.collections4.keyvalue.MultiKey;
import org.apache.commons.collections4.map.MultiKeyMap;

import java.util.Objects;

/** Metadata stored as string values in a {@link MultiKeyMap} keyed by string multi-keys. */
public class Metadata {
  private final MultiKeyMap<String, String> value;

  /**
   * Wraps the given map without copying it, so later changes to the map are visible through this
   * instance. Use {@link #copyOf(MultiKeyMap)} for an independent snapshot.
   *
   * @param values the backing map
   */
  public Metadata(MultiKeyMap<String, String> values) {
    this.value = values;
  }

  /** Returns the backing map itself (not a copy). */
  public MultiKeyMap<String, String> getValue() {
    return value;
  }

  /** Creates an empty {@link Metadata} backed by a new map. */
  public static Metadata create() {
    return new Metadata(new MultiKeyMap<>());
  }

  /**
   * Associates the value with the key in the backing map.
   *
   * @param key the multi-key to store under
   * @param value the value to store
   * @return this instance, so calls can be chained
   */
  public Metadata set(MultiKey<String> key, String value) {
    this.value.put(key, value);
    return this;
  }

  /**
   * Creates a {@link Metadata} holding a copy of the given entries.
   *
   * <p>The entries are copied into a new map, so later changes to {@code values} do not affect the
   * returned instance and vice versa. Keys and values themselves are not cloned.
   *
   * @param values the entries to copy; {@code null} is treated as no entries
   * @return a new, independent {@link Metadata}
   */
  public static Metadata copyOf(MultiKeyMap<String, String> values) {
    MultiKeyMap<String, String> copy = new MultiKeyMap<>();
    if (values != null) {
      copy.putAll(values);
    }
    return new Metadata(copy);
  }

  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    if (obj == null || getClass() != obj.getClass()) {
      return false;
    }
    Metadata other = (Metadata) obj;
    return Objects.equals(value, other.value);
  }

  @Override
  public int hashCode() {
    return Objects.hash(value);
  }
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ai.knowly.langtorch.parser;
package ai.knowly.langtorch.preprocessing.parser;

import ai.knowly.langtorch.prompt.template.PromptTemplate;
import ai.knowly.langtorch.schema.text.SingleText;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package ai.knowly.langtorch.preprocessing.splitter.text.word;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.util.List;
import org.junit.jupiter.api.Test;

/** Unit tests for {@link WordSplitter#splitText(WordSplitterOption)}. */
class WordSplitterTest {
  @Test
  void testSplitText_NormalUsage() {
    // Arrange.
    WordSplitterOption option =
        WordSplitterOption.builder()
            .setText("Hello world, this is a test.")
            .setMaxLengthPerChunk(10)
            .build();

    // Act.
    List<String> result = WordSplitter.create().splitText(option);

    // Assert.
    assertThat(result).containsExactly("Hello", "world,", "this is a", "test.").inOrder();
  }

  @Test
  void testSplitText_SingleWord() {
    // Arrange.
    WordSplitterOption option =
        WordSplitterOption.builder().setText("Hello").setMaxLengthPerChunk(10).build();

    // Act.
    List<String> result = WordSplitter.create().splitText(option);

    // Assert.
    assertThat(result).containsExactly("Hello");
  }

  @Test
  void testSplitText_SingleChar() {
    // Arrange.
    WordSplitterOption option =
        WordSplitterOption.builder().setText("H").setMaxLengthPerChunk(1).build();

    // Act.
    List<String> result = WordSplitter.create().splitText(option);

    // Assert.
    assertThat(result).containsExactly("H");
  }

  // Without @Test JUnit never discovers this method, so the case was silently skipped.
  @Test
  void testSplitText_MaxLengthSmallerThanWordLength() {
    // Arrange.
    WordSplitterOption option =
        WordSplitterOption.builder().setText("Hello").setMaxLengthPerChunk(3).build();

    // Act.
    // Assert.
    assertThrows(IllegalArgumentException.class, () -> WordSplitter.create().splitText(option));
  }

  @Test
  void testSplitText_NegativeMaxLength() {
    // Arrange.
    WordSplitterOption option =
        WordSplitterOption.builder().setText("Hello").setMaxLengthPerChunk(-5).build();

    // Act.
    // Assert.
    assertThrows(IllegalArgumentException.class, () -> WordSplitter.create().splitText(option));
  }

  @Test
  void testSplitText_EmptyString() {
    // Arrange.
    WordSplitterOption option =
        WordSplitterOption.builder().setText("").setMaxLengthPerChunk(10).build();

    // Act.
    List<String> result = WordSplitter.create().splitText(option);

    // Assert.
    assertThat(result).isEmpty();
  }
}
Loading

0 comments on commit 99505de

Please sign in to comment.