-
Notifications
You must be signed in to change notification settings - Fork 495
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixes #3634: Updated ML procs for Azure OpenAI services
- Loading branch information
Showing
7 changed files
with
271 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
package apoc.ml; | ||
|
||
import apoc.util.TestUtil; | ||
import org.junit.BeforeClass; | ||
import org.junit.ClassRule; | ||
import org.junit.Test; | ||
import org.neo4j.test.rule.DbmsRule; | ||
import org.neo4j.test.rule.ImpermanentDbmsRule; | ||
|
||
import java.util.Map; | ||
import java.util.stream.Stream; | ||
|
||
import static apoc.ApocConfig.apocConfig; | ||
import static apoc.ml.OpenAI.API_TYPE_CONF_KEY; | ||
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY; | ||
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY; | ||
import static apoc.ml.OpenAITestUtils.getStringObjectMap; | ||
import static apoc.util.TestUtil.testCall; | ||
import static org.junit.Assume.assumeNotNull; | ||
|
||
public class OpenAIAzureIT { | ||
// In Azure, the endpoints can be different | ||
private static String OPENAI_EMBEDDING_URL; | ||
private static String OPENAI_CHAT_URL; | ||
private static String OPENAI_COMPLETION_URL; | ||
|
||
private static String OPENAI_AZURE_API_VERSION; | ||
|
||
private static String OPENAI_KEY; | ||
|
||
@ClassRule | ||
public static DbmsRule db = new ImpermanentDbmsRule(); | ||
|
||
@BeforeClass | ||
public static void setUp() throws Exception { | ||
OPENAI_KEY = System.getenv("OPENAI_KEY"); | ||
// Azure OpenAI base URLs | ||
OPENAI_EMBEDDING_URL = System.getenv("OPENAI_EMBEDDING_URL"); | ||
OPENAI_CHAT_URL = System.getenv("OPENAI_CHAT_URL"); | ||
OPENAI_COMPLETION_URL = System.getenv("OPENAI_COMPLETION_URL"); | ||
|
||
// Azure OpenAI query url (`<baseURL>/<type>/?api-version=<OPENAI_AZURE_API_VERSION>) | ||
OPENAI_AZURE_API_VERSION = System.getenv("OPENAI_AZURE_API_VERSION"); | ||
|
||
apocConfig().setProperty("ajeje", "brazorf"); | ||
|
||
Stream.of(OPENAI_EMBEDDING_URL, | ||
OPENAI_CHAT_URL, | ||
OPENAI_COMPLETION_URL, | ||
OPENAI_AZURE_API_VERSION, | ||
OPENAI_KEY) | ||
.forEach(key -> assumeNotNull("No " + key + " environment configured", key)); | ||
|
||
TestUtil.registerProcedure(db, OpenAI.class); | ||
} | ||
|
||
@Test | ||
public void embedding() { | ||
testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, $conf)", | ||
getParams(OPENAI_EMBEDDING_URL), | ||
OpenAITestUtils::extracted); | ||
} | ||
|
||
|
||
@Test | ||
public void completion() { | ||
testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey, $conf)", | ||
getParams(OPENAI_COMPLETION_URL), OpenAITestUtils::extracted1); | ||
} | ||
|
||
@Test | ||
public void chatCompletion() { | ||
testCall(db, """ | ||
CALL apoc.ml.openai.chat([ | ||
{role:"system", content:"Only answer with a single word"}, | ||
{role:"user", content:"What planet do humans live on?"} | ||
], $apiKey, $conf) | ||
""", getParams(OPENAI_CHAT_URL), | ||
(row) -> getStringObjectMap(row, "gpt-35-turbo")); | ||
} | ||
|
||
private static Map<String, Object> getParams(String endpoint) { | ||
return Map.of("apiKey", OPENAI_KEY, | ||
"conf", Map.of(ENDPOINT_CONF_KEY, endpoint, | ||
API_TYPE_CONF_KEY, OpenAI.ApiType.AZURE.name(), | ||
API_VERSION_CONF_KEY, OPENAI_AZURE_API_VERSION | ||
) | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.