diff --git a/README.md b/README.md index 1dc252b..b8be988 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ export OPENAI_API_KEY= - Java 1.8+ - [An OpenAI API Key](https://platform.openai.com/account/api-keys) - Windows/Linux/Mac +- Google API key ## Want to contribute? diff --git a/build.gradle b/build.gradle index 23e4f93..c54c6b0 100644 --- a/build.gradle +++ b/build.gradle @@ -1,8 +1,16 @@ plugins { id "com.vanniktech.maven.publish" version "0.19.0" apply false + id "nebula.lint" version "18.1.0" } allprojects { + apply plugin :"java" + apply plugin :"nebula.lint" + // https://www.baeldung.com/gradle-finding-unused-dependencies + gradleLint { + rules=['unused-dependency'] + } + plugins.withId("com.vanniktech.maven.publish") { mavenPublish { sonatypeHost = "S01" diff --git a/lib/build.gradle b/lib/build.gradle index 16be462..34f2c0d 100644 --- a/lib/build.gradle +++ b/lib/build.gradle @@ -9,14 +9,9 @@ dependencies { implementation 'com.theokanning.openai-gpt3-java:service:0.12.0' compileOnly 'org.projectlombok:lombok:1.18.24' annotationProcessor 'org.projectlombok:lombok:1.18.24' - // https://mvnrepository.com/artifact/commons-cli/commons-cli implementation 'commons-cli:commons-cli:1.4' - - implementation "org.deeplearning4j:deeplearning4j-core:1.0.0-M1" - implementation 'com.knuddels:jtokkit:0.4.0' - implementation "org.nd4j:nd4j-native-platform:1.0.0-M1" // https://mvnrepository.com/artifact/com.google.code.gson/gson implementation 'com.google.code.gson:gson:2.10' // https://mvnrepository.com/artifact/org.seleniumhq.selenium/selenium-chrome-driver @@ -26,22 +21,14 @@ dependencies { implementation 'io.github.bonigarcia:webdrivermanager:5.3.3' // https://mvnrepository.com/artifact/org.json/json implementation 'org.json:json:20230227' - // https://mvnrepository.com/artifact/com.google.apis/google-api-services-customsearch implementation 'com.google.apis:google-api-services-customsearch:v1-rev86-1.25.0' // https://mvnrepository.com/artifact/org.projectlombok/lombok compileOnly 'org.projectlombok:lombok:1.18.26' // https://mvnrepository.com/artifact/com.fasterxml.jackson.core/jackson-core implementation 'com.fasterxml.jackson.core:jackson-core:2.15.1' - - - - //Thanks for using https://jar-download.com - // https://mvnrepository.com/artifact/org.jsoup/jsoup implementation 'org.jsoup:jsoup:1.16.1' - - implementation 'com.fasterxml.jackson.core:jackson-databind:2.14.2' testImplementation(platform('org.junit:junit-bom:5.8.2')) testImplementation('org.junit.jupiter:junit-jupiter') diff --git a/lib/src/main/java/com/frazik/instructgpt/Agent.java b/lib/src/main/java/com/frazik/instructgpt/Agent.java index 82eceb1..6ef809e 100644 --- a/lib/src/main/java/com/frazik/instructgpt/Agent.java +++ b/lib/src/main/java/com/frazik/instructgpt/Agent.java @@ -37,7 +37,7 @@ public Agent(String name, String description, List goals, String model) this.name = name; this.description = description; this.goals = goals; - this.memory = new LocalMemory(new OpenAIEmbeddingProvider()); + this.memory = new LocalMemory(); this.tools = Arrays.asList(new Browser(), new GoogleSearch()); this.openAIModel = new OpenAIModel(model); } @@ -64,7 +64,7 @@ private List> getFullPrompt(String userInput) { prompt.add(currentTimePrompt.getPrompt()); // Retrieve relevant memory - List relevantMemory = memory.get(history.subListToString(Math.max(0, history.getSize() - 10), history.getSize()), 10); + List relevantMemory = memory.get(10); if (relevantMemory != null) { int tokenLimit = 2500; @@ -163,7 +163,7 @@ public Response chat(String message, boolean runTool) { .withRole("system") .formatted(0, this.stagingResponse, output, message) .build(); - this.memory.add(humanFeedbackPrompt.getContent(), null); + this.memory.add(humanFeedbackPrompt.getContent()); } else { Prompt noApprovePrompt = new Prompt.Builder("no_approve") .withRole("system") @@ -174,7 +174,7 @@ public Response chat(String message, boolean runTool) { .withRole("system") .formatted(0, this.stagingResponse, message) .build(); - this.memory.add(noApproveReplayPrompt.getContent(), null); + this.memory.add(noApproveReplayPrompt.getContent()); } this.stagingTool = null; this.stagingResponse = null; diff --git a/lib/src/main/java/com/frazik/instructgpt/memory/LocalMemory.java b/lib/src/main/java/com/frazik/instructgpt/memory/LocalMemory.java index d7fbdbb..ee4cab4 100644 --- a/lib/src/main/java/com/frazik/instructgpt/memory/LocalMemory.java +++ b/lib/src/main/java/com/frazik/instructgpt/memory/LocalMemory.java @@ -1,12 +1,8 @@ package com.frazik.instructgpt.memory; -import com.frazik.instructgpt.embedding.EmbeddingProvider; -import lombok.extern.slf4j.Slf4j; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; +import lombok.extern.slf4j.Slf4j;; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; @@ -14,54 +10,28 @@ public class LocalMemory extends Memory { private final List docs; - private INDArray embs; - private final EmbeddingProvider embeddingProvider; - public LocalMemory(EmbeddingProvider embeddingProvider) { + public LocalMemory() { super(); this.docs = new ArrayList<>(); - this.embs = null; - this.embeddingProvider = embeddingProvider; } @Override - public void add(String doc, String key) { - if (key == null) { - key = doc; - } - double[] embeddings = this.embeddingProvider.get(key); - INDArray emb = Nd4j.create(embeddings); - - if (this.embs == null) { - this.embs = Nd4j.expandDims(emb, 0); - } else { - this.embs = Nd4j.concat(0, this.embs, Nd4j.expandDims(emb, 0)); - } - this.docs.add(doc); + public void add(String doc) { + this.docs.add(0, doc); } @Override - public List get(String query, int k) { - if (this.embs == null) { - return new ArrayList<>(); - } - double[] embeddings = embeddingProvider.get(query); - INDArray scores; - try (INDArray emb = Nd4j.create(embeddings)) { - - scores = this.embs.mmul(emb); - } - int[] idxs = Nd4j.argMax(scores, 0).toIntVector(); - String[] results = new String[k]; - for (int i = 0; i < k; i++) { - results[i] = this.docs.get(idxs[i]); + public List get(int k) { + // get last k docs, or all docs if k > docs.size() + if (k > this.docs.size()) { + return this.docs; } - return Arrays.asList(results); + return this.docs.subList(0, k); } @Override public void clear() { this.docs.clear(); - this.embs = null; } } diff --git a/lib/src/main/java/com/frazik/instructgpt/memory/Memory.java b/lib/src/main/java/com/frazik/instructgpt/memory/Memory.java index 9a939c2..d556f57 100644 --- a/lib/src/main/java/com/frazik/instructgpt/memory/Memory.java +++ b/lib/src/main/java/com/frazik/instructgpt/memory/Memory.java @@ -1,7 +1,7 @@ package com.frazik.instructgpt.memory; import java.util.List; public abstract class Memory { - public abstract void add(String doc, String key); - public abstract List get(String query, int k); + public abstract void add(String doc); + public abstract List get(int k); public abstract void clear(); } diff --git a/lib/src/test/java/com/frazik/instructgpt/RegressionTests.java b/lib/src/test/java/com/frazik/instructgpt/RegressionTests.java index 4ada821..d01428e 100644 --- a/lib/src/test/java/com/frazik/instructgpt/RegressionTests.java +++ b/lib/src/test/java/com/frazik/instructgpt/RegressionTests.java @@ -28,19 +28,6 @@ void embeddingTest() { assertEquals(0.005179715, embeddings[embeddings.length - 2], 0.05); } - @Test - void memoryTest() { - EmbeddingProvider embeddingProvider = new OpenAIEmbeddingProvider(); - Memory memory = new LocalMemory(embeddingProvider); - - memory.add("Hello world1", null); - memory.add("Hello world2", null); - - List rez = memory.get("Hello world1", 1); - - assertEquals(Arrays.asList("Hello world1"), rez); - } - @Test void modelTest() { Model model = new OpenAIModel("gpt-3.5-turbo"); @@ -83,7 +70,7 @@ void basicChatResult() { Agent agent = new Agent(name, description, goals, "gpt-3.5-turbo"); Response resp = agent.chat(); - assertEquals(resp.getCommand(), "google_search"); + assertEquals("google_search", resp.getCommand()); } }