Skip to content

Commit

Permalink
Merge pull request #37 from frazik-main/24-fix-localmemory-funcionality
Browse files Browse the repository at this point in the history
24 fix localmemory funcionality
  • Loading branch information
MatKollar authored Jul 19, 2023
2 parents 3070188 + 2aad892 commit a8500d3
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 72 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ export OPENAI_API_KEY=<your-openai-api-key>
- Java 1.8+
- [An OpenAI API Key](https://platform.openai.com/account/api-keys)
- Windows/Linux/Mac
- Google API key

## Want to contribute?

Expand Down
8 changes: 8 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
plugins {
id "com.vanniktech.maven.publish" version "0.19.0" apply false
id "nebula.lint" version "18.1.0"
}

allprojects {
apply plugin :"java"
apply plugin :"nebula.lint"
// https://www.baeldung.com/gradle-finding-unused-dependencies
gradleLint {
rules=['unused-dependency']
}

plugins.withId("com.vanniktech.maven.publish") {
mavenPublish {
sonatypeHost = "S01"
Expand Down
13 changes: 0 additions & 13 deletions lib/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,9 @@ dependencies {
implementation 'com.theokanning.openai-gpt3-java:service:0.12.0'
compileOnly 'org.projectlombok:lombok:1.18.24'
annotationProcessor 'org.projectlombok:lombok:1.18.24'

// https://mvnrepository.com/artifact/commons-cli/commons-cli
implementation 'commons-cli:commons-cli:1.4'

implementation "org.deeplearning4j:deeplearning4j-core:1.0.0-M1"

implementation 'com.knuddels:jtokkit:0.4.0'
implementation "org.nd4j:nd4j-native-platform:1.0.0-M1"
// https://mvnrepository.com/artifact/com.google.code.gson/gson
implementation 'com.google.code.gson:gson:2.10'
// https://mvnrepository.com/artifact/org.seleniumhq.selenium/selenium-chrome-driver
Expand All @@ -26,22 +21,14 @@ dependencies {
implementation 'io.github.bonigarcia:webdrivermanager:5.3.3'
// https://mvnrepository.com/artifact/org.json/json
implementation 'org.json:json:20230227'

// https://mvnrepository.com/artifact/com.google.apis/google-api-services-customsearch
implementation 'com.google.apis:google-api-services-customsearch:v1-rev86-1.25.0'
// https://mvnrepository.com/artifact/org.projectlombok/lombok
compileOnly 'org.projectlombok:lombok:1.18.26'
// https://mvnrepository.com/artifact/com.fasterxml.jackson.core/jackson-core
implementation 'com.fasterxml.jackson.core:jackson-core:2.15.1'



//Thanks for using https://jar-download.com

// https://mvnrepository.com/artifact/org.jsoup/jsoup
implementation 'org.jsoup:jsoup:1.16.1'


implementation 'com.fasterxml.jackson.core:jackson-databind:2.14.2'
testImplementation(platform('org.junit:junit-bom:5.8.2'))
testImplementation('org.junit.jupiter:junit-jupiter')
Expand Down
8 changes: 4 additions & 4 deletions lib/src/main/java/com/frazik/instructgpt/Agent.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public Agent(String name, String description, List<String> goals, String model)
this.name = name;
this.description = description;
this.goals = goals;
this.memory = new LocalMemory(new OpenAIEmbeddingProvider());
this.memory = new LocalMemory();
this.tools = Arrays.asList(new Browser(), new GoogleSearch());
this.openAIModel = new OpenAIModel(model);
}
Expand All @@ -64,7 +64,7 @@ private List<Map<String, String>> getFullPrompt(String userInput) {
prompt.add(currentTimePrompt.getPrompt());

// Retrieve relevant memory
List<String> relevantMemory = memory.get(history.subListToString(Math.max(0, history.getSize() - 10), history.getSize()), 10);
List<String> relevantMemory = memory.get(10);

if (relevantMemory != null) {
int tokenLimit = 2500;
Expand Down Expand Up @@ -163,7 +163,7 @@ public Response chat(String message, boolean runTool) {
.withRole("system")
.formatted(0, this.stagingResponse, output, message)
.build();
this.memory.add(humanFeedbackPrompt.getContent(), null);
this.memory.add(humanFeedbackPrompt.getContent());
} else {
Prompt noApprovePrompt = new Prompt.Builder("no_approve")
.withRole("system")
Expand All @@ -174,7 +174,7 @@ public Response chat(String message, boolean runTool) {
.withRole("system")
.formatted(0, this.stagingResponse, message)
.build();
this.memory.add(noApproveReplayPrompt.getContent(), null);
this.memory.add(noApproveReplayPrompt.getContent());
}
this.stagingTool = null;
this.stagingResponse = null;
Expand Down
48 changes: 9 additions & 39 deletions lib/src/main/java/com/frazik/instructgpt/memory/LocalMemory.java
Original file line number Diff line number Diff line change
@@ -1,67 +1,37 @@
package com.frazik.instructgpt.memory;

import com.frazik.instructgpt.embedding.EmbeddingProvider;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import lombok.extern.slf4j.Slf4j;;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;


@Slf4j
public class LocalMemory extends Memory {

private final List<String> docs;
private INDArray embs;
private final EmbeddingProvider embeddingProvider;

public LocalMemory(EmbeddingProvider embeddingProvider) {
public LocalMemory() {
super();
this.docs = new ArrayList<>();
this.embs = null;
this.embeddingProvider = embeddingProvider;
}

@Override
public void add(String doc, String key) {
if (key == null) {
key = doc;
}
double[] embeddings = this.embeddingProvider.get(key);
INDArray emb = Nd4j.create(embeddings);

if (this.embs == null) {
this.embs = Nd4j.expandDims(emb, 0);
} else {
this.embs = Nd4j.concat(0, this.embs, Nd4j.expandDims(emb, 0));
}
this.docs.add(doc);
public void add(String doc) {
this.docs.add(0, doc);
}

@Override
public List<String> get(String query, int k) {
if (this.embs == null) {
return new ArrayList<>();
}
double[] embeddings = embeddingProvider.get(query);
INDArray scores;
try (INDArray emb = Nd4j.create(embeddings)) {

scores = this.embs.mmul(emb);
}
int[] idxs = Nd4j.argMax(scores, 0).toIntVector();
String[] results = new String[k];
for (int i = 0; i < k; i++) {
results[i] = this.docs.get(idxs[i]);
public List<String> get(int k) {
// get last k docs, or all docs if k > docs.size()
if (k > this.docs.size()) {
return this.docs;
}
return Arrays.asList(results);
return this.docs.subList(0, k);
}
@Override
public void clear() {
this.docs.clear();
this.embs = null;
}

}
4 changes: 2 additions & 2 deletions lib/src/main/java/com/frazik/instructgpt/memory/Memory.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.frazik.instructgpt.memory;
import java.util.List;
public abstract class Memory {
public abstract void add(String doc, String key);
public abstract List<String> get(String query, int k);
public abstract void add(String doc);
public abstract List<String> get(int k);
public abstract void clear();
}
15 changes: 1 addition & 14 deletions lib/src/test/java/com/frazik/instructgpt/RegressionTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,6 @@ void embeddingTest() {
assertEquals(0.005179715, embeddings[embeddings.length - 2], 0.05);
}

@Test
void memoryTest() {
EmbeddingProvider embeddingProvider = new OpenAIEmbeddingProvider();
Memory memory = new LocalMemory(embeddingProvider);

memory.add("Hello world1", null);
memory.add("Hello world2", null);

List<String> rez = memory.get("Hello world1", 1);

assertEquals(Arrays.asList("Hello world1"), rez);
}

@Test
void modelTest() {
Model model = new OpenAIModel("gpt-3.5-turbo");
Expand Down Expand Up @@ -83,7 +70,7 @@ void basicChatResult() {

Agent agent = new Agent(name, description, goals, "gpt-3.5-turbo");
Response resp = agent.chat();
assertEquals(resp.getCommand(), "google_search");
assertEquals("google_search", resp.getCommand());
}

}

0 comments on commit a8500d3

Please sign in to comment.