Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Add streaming tests #225

Merged
merged 1 commit into from
Apr 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ OpenAiApi api = retrofit.create(OpenAiApi.class);
OpenAiService service = new OpenAiService(api);
```


### Streaming thread shutdown
If you want to shut down your process immediately after streaming responses, call `OpenAiService.shutdown()`.
This is not necessary for non-streaming calls.

## Running the example project
All the [example](example/src/main/java/example/OpenAiApiExample.java) project requires is your OpenAI api token
Expand Down
2 changes: 1 addition & 1 deletion client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ apply plugin: "com.vanniktech.maven.publish"
dependencies {
api project(":api")
api 'com.squareup.retrofit2:retrofit:2.9.0'
implementation 'com.squareup.retrofit2:adapter-rxjava2:2.9.0'
api 'com.squareup.retrofit2:adapter-rxjava2:2.9.0'
implementation 'com.squareup.retrofit2:converter-jackson:2.9.0'

testImplementation(platform('org.junit:junit-bom:5.8.2'))
Expand Down
26 changes: 26 additions & 0 deletions example/src/main/java/example/OpenAiApiExample.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
package example;

import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import com.theokanning.openai.service.OpenAiService;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.image.CreateImageRequest;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

class OpenAiApiExample {
public static void main(String... args) {
String token = System.getenv("OPENAI_TOKEN");
Expand All @@ -26,5 +33,24 @@ public static void main(String... args) {

System.out.println("\nImage is located at:");
System.out.println(service.createImage(request).getData().get(0).getUrl());

System.out.println("Streaming chat completion...");
final List<ChatMessage> messages = new ArrayList<>();
final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a dog and will speak as such.");
messages.add(systemMessage);
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
.builder()
.model("gpt-3.5-turbo")
.messages(messages)
.n(1)
.maxTokens(50)
.logitBias(new HashMap<>())
.build();

service.streamChatCompletion(chatCompletionRequest)
.doOnError(Throwable::printStackTrace)
.blockingForEach(System.out::println);

service.shutdownExecutor();
}
}
79 changes: 0 additions & 79 deletions example/src/main/java/example/OpenAiApiStreamExample.java

This file was deleted.

3 changes: 2 additions & 1 deletion service/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ dependencies {
implementation 'com.squareup.retrofit2:converter-jackson:2.9.0'

testImplementation(platform('org.junit:junit-bom:5.8.2'))
testImplementation('org.junit.jupiter:junit-jupiter')
testImplementation 'org.junit.jupiter:junit-jupiter'
testImplementation 'com.squareup.retrofit2:retrofit-mock:2.9.0'
}

compileJava {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import okhttp3.Response;

import java.io.IOException;
import java.util.Objects;

/**
* OkHttp Interceptor that adds an authorization token header
Expand All @@ -14,6 +15,7 @@ public class AuthenticationInterceptor implements Interceptor {
private final String token;

AuthenticationInterceptor(String token) {
Objects.requireNonNull(token, "OpenAI token required");
this.token = token;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public class OpenAiService {

private static final String BASE_URL = "https://api.openai.com/";
private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(10);
private static final ObjectMapper errorMapper = defaultObjectMapper();
private static final ObjectMapper mapper = defaultObjectMapper();

private final OpenAiApi api;
private final ExecutorService executorService;
Expand All @@ -72,24 +72,34 @@ public OpenAiService(final String token) {
* @param timeout http read timeout, Duration.ZERO means no timeout
*/
public OpenAiService(final String token, final Duration timeout) {
this(defaultClient(token, timeout));
ObjectMapper mapper = defaultObjectMapper();
OkHttpClient client = defaultClient(token, timeout);
Retrofit retrofit = defaultRetrofit(client, mapper);

this.api = retrofit.create(OpenAiApi.class);
this.executorService = client.dispatcher().executorService();
}

/**
* Creates a new OpenAiService that wraps OpenAiApi
* Creates a new OpenAiService that wraps OpenAiApi.
* Use this if you need more customization, but use OpenAiService(api, executorService) if you use streaming and
* want to shut down instantly
*
* @param client OkHttpClient to be used for api calls
* @param api OpenAiApi instance to use for all methods
*/
public OpenAiService(OkHttpClient client){
this(buildApi(client), client.dispatcher().executorService());
public OpenAiService(final OpenAiApi api) {
this.api = api;
this.executorService = null;
}

/**
* Creates a new OpenAiService that wraps OpenAiApi.
* The ExecutoryService must be the one you get from the client you created the api with
* otherwise shutdownExecutor() won't work. Use this if you need more customization.
* The ExecutorService must be the one you get from the client you created the api with
* otherwise shutdownExecutor() won't work.
* <p>
* Use this if you need more customization.
*
* @param api OpenAiApi instance to use for all methods
* @param api OpenAiApi instance to use for all methods
* @param executorService the ExecutorService from client.dispatcher().executorService()
*/
public OpenAiService(final OpenAiApi api, final ExecutorService executorService) {
Expand All @@ -109,37 +119,21 @@ public CompletionResult createCompletion(CompletionRequest request) {
return execute(api.createCompletion(request));
}

public Flowable<byte[]> streamCompletionBytes(CompletionRequest request) {
request.setStream(true);

return stream(api.createCompletionStream(request), true).map(sse -> {
return sse.toBytes();
});
}

public Flowable<CompletionChunk> streamCompletion(CompletionRequest request) {
request.setStream(true);
return stream(api.createCompletionStream(request), CompletionChunk.class);
}
request.setStream(true);

return stream(api.createCompletionStream(request), CompletionChunk.class);
}

public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) {
return execute(api.createChatCompletion(request));
}

public Flowable<byte[]> streamChatCompletionBytes(ChatCompletionRequest request) {
request.setStream(true);
public Flowable<ChatCompletionChunk> streamChatCompletion(ChatCompletionRequest request) {
request.setStream(true);

return stream(api.createChatCompletionStream(request), true).map(sse -> {
return sse.toBytes();
});
}

public Flowable<ChatCompletionChunk> streamChatCompletion(ChatCompletionRequest request) {
request.setStream(true);

return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class);
}
return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class);
}

public EditResult createEdit(EditRequest request) {
return execute(api.createEdit(request));
Expand Down Expand Up @@ -271,7 +265,7 @@ public static <T> T execute(Single<T> apiCall) {
}
String errorBody = e.response().errorBody().string();

OpenAiError error = errorMapper.readValue(errorBody, OpenAiError.class);
OpenAiError error = mapper.readValue(errorBody, OpenAiError.class);
throw new OpenAiHttpException(error, e, e.code());
} catch (IOException ex) {
// couldn't parse OpenAI error
Expand All @@ -283,52 +277,50 @@ public static <T> T execute(Single<T> apiCall) {
/**
* Calls the Open AI api and returns a Flowable of SSE for streaming
* omitting the last message.
*
*
* @param apiCall The api call
*/
public static Flowable<SSE> stream(Call<ResponseBody> apiCall) {
return stream(apiCall, false);
}
return stream(apiCall, false);
}

/**
* Calls the Open AI api and returns a Flowable of SSE for streaming.
*
* @param apiCall The api call
*
* @param apiCall The api call
* @param emitDone If true the last message ([DONE]) is emitted
*/
public static Flowable<SSE> stream(Call<ResponseBody> apiCall, boolean emitDone) {
return Flowable.create(emitter -> {
apiCall.enqueue(new ResponseBodyCallback(emitter, emitDone));
}, BackpressureStrategy.BUFFER);
}
public static Flowable<SSE> stream(Call<ResponseBody> apiCall, boolean emitDone) {
return Flowable.create(emitter -> apiCall.enqueue(new ResponseBodyCallback(emitter, emitDone)), BackpressureStrategy.BUFFER);
}

/**
* Calls the Open AI api and returns a Flowable of type T for streaming
* omitting the last message.
*
*
* @param apiCall The api call
* @param cl Class of type T to return
* @param cl Class of type T to return
*/
public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl) {
return stream(apiCall).map(sse -> {
return errorMapper.readValue(sse.getData(), cl);
});
}
public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl) {
return stream(apiCall).map(sse -> mapper.readValue(sse.getData(), cl));
}

/**
* Shuts down the OkHttp ExecutorService.
* The default behaviour of OkHttp's ExecutorService (ConnectionPool)
* is to shutdown after an idle timeout of 60s.
* Call this method to shutdown the ExecutorService immediately.
* The default behaviour of OkHttp's ExecutorService (ConnectionPool)
* is to shut down after an idle timeout of 60s.
* Call this method to shut down the ExecutorService immediately.
*/
public void shutdownExecutor(){
public void shutdownExecutor() {
Objects.requireNonNull(this.executorService, "executorService must be set in order to shut down");
this.executorService.shutdown();
}

public static OpenAiApi buildApi(OkHttpClient client) {
public static OpenAiApi buildApi(String token, Duration timeout) {
ObjectMapper mapper = defaultObjectMapper();
OkHttpClient client = defaultClient(token, timeout);
Retrofit retrofit = defaultRetrofit(client, mapper);

return retrofit.create(OpenAiApi.class);
}

Expand All @@ -341,8 +333,6 @@ public static ObjectMapper defaultObjectMapper() {
}

public static OkHttpClient defaultClient(String token, Duration timeout) {
Objects.requireNonNull(token, "OpenAI token required");

return new OkHttpClient.Builder()
.addInterceptor(new AuthenticationInterceptor(token))
.connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS))
Expand Down
Loading