Skip to content

Commit

Permalink
Fixes #3634: Updated ML procs for Azure OpenAI services
Browse files Browse the repository at this point in the history
  • Loading branch information
vga91 committed Dec 1, 2023
1 parent e79aae4 commit 5de2761
Show file tree
Hide file tree
Showing 7 changed files with 271 additions and 53 deletions.
64 changes: 63 additions & 1 deletion docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,60 @@

NOTE: You need to acquire an https://platform.openai.com/account/api-keys[OpenAI API key^] to use these procedures. Using them will incur costs on your OpenAI account. You can set the api key globally by defining the `apoc.openai.key` configuration in `apoc.conf`



All the following procedures can have the following APOC config, i.e. in `apoc.conf` or via docker env variable
.Apoc configuration
|===
|key | description | default
| apoc.ml.openai.type | "AZURE" or "OPENAI", indicates whether the API is Azure or not | "OPENAI"
| apoc.ml.openai.url | the OpenAI endpoint base url | https://api.openai.com/v1
(or empty string if `apoc.ml.openai.type=AZURE`)
| apoc.ml.azure.api.version | in case of `apoc.ml.openai.type=AZURE`, indicates the `api-version` to be passed after the `?api-version=` url
|===


Moreover, they can have the following configuration keys, as the last parameter.
If present, they take precedence over the analogous APOC configs.

.Common configuration parameter

|===
| key | description
| apiType | analogous to `apoc.ml.openai.type` APOC config
| endpoint | analogous to `apoc.ml.openai.url` APOC config
| apiVersion | analogous to `apoc.ml.azure.api.version` APOC config
|===


Therefore, we can use the following procedures with the Open AI Services provided by Azure,
pointing to the correct endpoints https://learn.microsoft.com/it-it/azure/ai-services/openai/reference[as explained in the documentation].

That is, if we want to call an endpoint like https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/embeddings?api-version=my-api-version` for example,
by passing as a configuration parameter:
```
{endpoint: "https://my-resource.openai.azure.com/openai/deployments/my-deployment-id",
apiVersion: my-api-version,
apiType: 'AZURE'
}
```

The `/embeddings` portion will be added under-the-hood.
Similarly, if we use the `apoc.ml.openai.completion`, if we want to call an endpoint like `https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/completions?api-version=my-api-version` for example,
we can write the same configuration parameter as above,
where the `/completions` portion will be added.

While using the `apoc.ml.openai.chat`, with the same configuration, the url portion `/chat/completions` will be added

Or else, we can write this `apoc.conf`:
```
apoc.ml.openai.url=https://my-resource.openai.azure.com/openai/deployments/my-deployment-id
apoc.ml.azure.api.version=my-api-version
apoc.ml.openai.type=AZURE
```



== Generate Embeddings API

This procedure `apoc.ml.openai.embedding` can take a list of text strings, and will return one row per string, with the embedding data as a 1536 element vector.
Expand All @@ -30,7 +84,15 @@ CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, {}) yield index, text, emb
|name | description
| texts | List of text strings
| apiKey | OpenAI API key
| configuration | optional map for entries like model and other request parameters
| configuration | optional map for entries like model and other request parameters.

We can also pass a custom `endpoint: <MyAndPointKey>` entry (it takes precedence over the `apoc.ml.openai.url` config).
The `<MyAndPointKey>` can be the complete andpoint (e.g. using Azure: `https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/chat/completions?api-version=my-api-version`),
or with a `%s` (e.g. using Azure: `https://my-resource.openai.azure.com/openai/deployments/my-deployment-id/%s?api-version=my-api-version`) which will eventually be replaced with `embeddings`, `chat/completion` and `completion`
by using respectively the `apoc.ml.openai.embedding`, `apoc.ml.openai.chat` and `apoc.ml.openai.completion`.

Or an `authType: `AUTH_TYPE`, which can be `authType: "BEARER"` (default config.), to pass the apiKey via the header as an `Authorization: Bearer $apiKey`,
or `authType: "API_KEY"` to pass the apiKey as an `api-key: $apiKey` header entry.
|===


Expand Down
3 changes: 3 additions & 0 deletions extended/src/main/java/apoc/ExtendedApocConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ public class ExtendedApocConfig extends LifecycleAdapter
public static final String APOC_UUID_ENABLED_DB = "apoc.uuid.enabled.%s";
public static final String APOC_UUID_FORMAT = "apoc.uuid.format";
public static final String APOC_OPENAI_KEY = "apoc.openai.key";
public static final String APOC_ML_OPENAI_URL = "apoc.ml.openai.url";
public static final String APOC_ML_OPENAI_TYPE = "apoc.ml.openai.type";
public static final String APOC_ML_OPENAI_AZURE_VERSION = "apoc.ml.azure.api.version";
public static final String APOC_AWS_KEY_ID = "apoc.aws.key.id";
public static final String APOC_AWS_SECRET_KEY = "apoc.aws.secret.key";
public enum UuidFormatType { hex, base64 }
Expand Down
69 changes: 59 additions & 10 deletions extended/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,41 @@
import apoc.Extended;
import apoc.util.JsonUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.commons.lang3.StringUtils;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

import java.io.File;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import apoc.result.MapResult;

import com.fasterxml.jackson.databind.ObjectMapper;

import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_URL;
import static apoc.ExtendedApocConfig.APOC_OPENAI_KEY;
import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_AZURE_VERSION;
import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_TYPE;


@Extended
public class OpenAI {
enum ApiType { AZURE, OPENAI }

public static final String API_TYPE_CONF_KEY = "apiType";
public static final String ENDPOINT_CONF_KEY = "endpoint";
public static final String API_VERSION_CONF_KEY = "apiVersion";

@Context
public ApocConfig apocConfig;

public static final String APOC_ML_OPENAI_URL = "apoc.ml.openai.url";

public static class EmbeddingResult {
public final long index;
public final String text;
Expand All @@ -43,25 +52,65 @@ public EmbeddingResult(long index, String text, List<Double> embedding) {
}

static Stream<Object> executeRequest(String apiKey, Map<String, Object> configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig) throws JsonProcessingException, MalformedURLException {
apiKey = apocConfig.getString(APOC_OPENAI_KEY, apiKey);
apiKey = (String) configuration.getOrDefault(APOC_OPENAI_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey));
if (apiKey == null || apiKey.isBlank())
throw new IllegalArgumentException("API Key must not be empty");
String endpoint = System.getProperty(APOC_ML_OPENAI_URL,"https://api.openai.com/v1/");
Map<String, Object> headers = Map.of(
"Content-Type", "application/json",
"Authorization", "Bearer " + apiKey


String apiTypeString = (String) configuration.getOrDefault(API_TYPE_CONF_KEY,
apocConfig.getString(APOC_ML_OPENAI_TYPE, ApiType.OPENAI.name())
);
ApiType apiType = ApiType.valueOf(apiTypeString);

String endpoint = (String) configuration.get(ENDPOINT_CONF_KEY);

String apiVersion;
Map<String, Object> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
switch (apiType) {
case AZURE -> {
endpoint = getEndpoint(endpoint, apocConfig, "");
apiVersion = "?api-version=" + configuration.getOrDefault(API_VERSION_CONF_KEY, apocConfig.getString(APOC_ML_OPENAI_AZURE_VERSION));
headers.put("api-key", apiKey);
}
default -> {
endpoint = getEndpoint(endpoint, apocConfig, "https://api.openai.com/v1");
apiVersion = "";
headers.put("Authorization", "Bearer " + apiKey);
}
}

var config = new HashMap<>(configuration);
// we remove these keys from config, since the json payload is calculated starting from the config map
Stream.of(ENDPOINT_CONF_KEY, API_TYPE_CONF_KEY, API_VERSION_CONF_KEY).forEach(config::remove);
config.putIfAbsent("model", model);
config.put(key, inputs);

String payload = new ObjectMapper().writeValueAsString(config);

var url = new URL(new URL(endpoint), path).toString();

// new URL(endpoint), path) can produce a wrong path, since endpoint can have for example embedding,
// eg: https://my-resource.openai.azure.com/openai/deployments/apoc-embeddings-model
// therefore is better to join the not-empty path pieces
var url = Stream.of(endpoint, path, apiVersion)
.filter(StringUtils::isNotBlank)
.collect(Collectors.joining(File.separator));
return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of());
}

private static String getEndpoint(String endpointConfMap, ApocConfig apocConfig, String defaultUrl) {
if (endpointConfMap != null) {
return endpointConfMap;
}

String apocConfUrl = apocConfig.getString(APOC_ML_OPENAI_URL, null);
if (apocConfUrl != null) {
return apocConfUrl;
}

return System.getProperty(APOC_ML_OPENAI_URL, defaultUrl);
}


@Procedure("apoc.ml.openai.embedding")
@Description("apoc.openai.embedding([texts], api_key, configuration) - returns the embeddings for a given text")
public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
Expand Down
90 changes: 90 additions & 0 deletions extended/src/test/java/apoc/ml/OpenAIAzureIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package apoc.ml;

import apoc.util.TestUtil;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.neo4j.test.rule.DbmsRule;
import org.neo4j.test.rule.ImpermanentDbmsRule;

import java.util.Map;
import java.util.stream.Stream;

import static apoc.ApocConfig.apocConfig;
import static apoc.ml.OpenAI.API_TYPE_CONF_KEY;
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY;
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY;
import static apoc.ml.OpenAITestUtils.getStringObjectMap;
import static apoc.util.TestUtil.testCall;
import static org.junit.Assume.assumeNotNull;

public class OpenAIAzureIT {
// In Azure, the endpoints can be different
private static String OPENAI_EMBEDDING_URL;
private static String OPENAI_CHAT_URL;
private static String OPENAI_COMPLETION_URL;

private static String OPENAI_AZURE_API_VERSION;

private static String OPENAI_KEY;

@ClassRule
public static DbmsRule db = new ImpermanentDbmsRule();

@BeforeClass
public static void setUp() throws Exception {
OPENAI_KEY = System.getenv("OPENAI_KEY");
// Azure OpenAI base URLs
OPENAI_EMBEDDING_URL = System.getenv("OPENAI_EMBEDDING_URL");
OPENAI_CHAT_URL = System.getenv("OPENAI_CHAT_URL");
OPENAI_COMPLETION_URL = System.getenv("OPENAI_COMPLETION_URL");

// Azure OpenAI query url (`<baseURL>/<type>/?api-version=<OPENAI_AZURE_API_VERSION>)
OPENAI_AZURE_API_VERSION = System.getenv("OPENAI_AZURE_API_VERSION");

apocConfig().setProperty("ajeje", "brazorf");

Stream.of(OPENAI_EMBEDDING_URL,
OPENAI_CHAT_URL,
OPENAI_COMPLETION_URL,
OPENAI_AZURE_API_VERSION,
OPENAI_KEY)
.forEach(key -> assumeNotNull("No " + key + " environment configured", key));

TestUtil.registerProcedure(db, OpenAI.class);
}

@Test
public void embedding() {
testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, $conf)",
getParams(OPENAI_EMBEDDING_URL),
OpenAITestUtils::extracted);
}


@Test
public void completion() {
testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey, $conf)",
getParams(OPENAI_COMPLETION_URL), OpenAITestUtils::extracted1);
}

@Test
public void chatCompletion() {
testCall(db, """
CALL apoc.ml.openai.chat([
{role:"system", content:"Only answer with a single word"},
{role:"user", content:"What planet do humans live on?"}
], $apiKey, $conf)
""", getParams(OPENAI_CHAT_URL),
(row) -> getStringObjectMap(row, "gpt-35-turbo"));
}

private static Map<String, Object> getParams(String endpoint) {
return Map.of("apiKey", OPENAI_KEY,
"conf", Map.of(ENDPOINT_CONF_KEY, endpoint,
API_TYPE_CONF_KEY, OpenAI.ApiType.AZURE.name(),
API_VERSION_CONF_KEY, OPENAI_AZURE_API_VERSION
)
);
}
}
48 changes: 7 additions & 41 deletions extended/src/test/java/apoc/ml/OpenAIIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.List;
import java.util.Map;

import static apoc.ml.OpenAITestUtils.getStringObjectMap;
import static apoc.util.TestUtil.testCall;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand All @@ -34,34 +35,15 @@ public void setUp() throws Exception {

@Test
public void getEmbedding() {
testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey)", Map.of("apiKey",openaiKey),(row) -> {
System.out.println("row = " + row);
assertEquals(0L, row.get("index"));
assertEquals("Some Text", row.get("text"));
var embedding = (List<Double>) row.get("embedding");
assertEquals(1536, embedding.size());
assertEquals(true, embedding.stream().allMatch(d -> d instanceof Double));
});
testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey)", Map.of("apiKey",openaiKey),
OpenAITestUtils::extracted);
}

@Test
public void completion() {
testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey)",
Map.of("apiKey",openaiKey),(row) -> {
System.out.println("row = " + row);
var result = (Map<String,Object>)row.get("value");
assertEquals(true, result.get("created") instanceof Number);
assertEquals(true, result.containsKey("choices"));
var finishReason = (String)((List<Map>) result.get("choices")).get(0).get("finish_reason");
assertEquals(true, finishReason.matches("stop|length"));
String text = (String) ((List<Map>) result.get("choices")).get(0).get("text");
assertEquals(true, text != null && !text.isBlank());
assertEquals(true, text.toLowerCase().contains("blue"));
assertEquals(true, result.containsKey("usage"));
assertEquals(true, ((Map)result.get("usage")).get("prompt_tokens") instanceof Number);
assertEquals("text-davinci-003", result.get("model"));
assertEquals("text_completion", result.get("object"));
});
Map.of("apiKey",openaiKey),
OpenAITestUtils::extracted1);
}

@Test
Expand All @@ -71,24 +53,8 @@ public void chatCompletion() {
{role:"system", content:"Only answer with a single word"},
{role:"user", content:"What planet do humans live on?"}
], $apiKey)
""", Map.of("apiKey",openaiKey), (row) -> {
System.out.println("row = " + row);
var result = (Map<String,Object>)row.get("value");
assertEquals(true, result.get("created") instanceof Number);
assertEquals(true, result.containsKey("choices"));

Map message = ((List<Map<String,Map>>) result.get("choices")).get(0).get("message");
assertEquals("assistant", message.get("role"));
// assertEquals("stop", message.get("finish_reason"));
String text = (String) message.get("content");
assertEquals(true, text != null && !text.isBlank());


assertEquals(true, result.containsKey("usage"));
assertEquals(true, ((Map)result.get("usage")).get("prompt_tokens") instanceof Number);
assertTrue(result.get("model").toString().startsWith("gpt-3.5-turbo"));
assertEquals("chat.completion", result.get("object"));
});
""", Map.of("apiKey",openaiKey),
(row) -> getStringObjectMap(row, "gpt-3.5-turbo"));

/*
{
Expand Down
3 changes: 2 additions & 1 deletion extended/src/test/java/apoc/ml/OpenAITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import static apoc.ApocConfig.APOC_IMPORT_FILE_ENABLED;
import static apoc.ApocConfig.apocConfig;
import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_URL;
import static apoc.util.TestUtil.getUrlFileName;
import static apoc.util.TestUtil.testCall;
import static org.junit.jupiter.api.Assertions.assertEquals;
Expand All @@ -32,7 +33,7 @@ public void setUp() throws Exception {
// openaiKey = System.getenv("OPENAI_KEY");
// Assume.assumeNotNull("No OPENAI_KEY environment configured", openaiKey);
var path = Paths.get(getUrlFileName("embeddings").toURI()).getParent().toUri();
System.setProperty(OpenAI.APOC_ML_OPENAI_URL, path.toString());
System.setProperty(APOC_ML_OPENAI_URL, path.toString());
apocConfig().setProperty(APOC_IMPORT_FILE_ENABLED, true);
TestUtil.registerProcedure(db, OpenAI.class);
}
Expand Down
Loading

0 comments on commit 5de2761

Please sign in to comment.