feat: Add reasoningEffort parameter to OpenAI API and Chat Options
This commit introduces the `reasoningEffort` parameter to the OpenAI API integration, allowing control over how much reasoning effort reasoning models such as `o1` spend on a response.

Changes:
- Adds `reasoningEffort` field to `OpenAiApi.ChatCompletionRequest`.
- Adds `reasoningEffort` field and builder method to `OpenAiChatOptions`.

Signed-off-by: Alexandros Pappas <apappascs@gmail.com>
apappascs committed Feb 4, 2025
1 parent 54463e6 commit ccb37fe
Showing 3 changed files with 54 additions and 8 deletions.
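Before the diff, a minimal usage sketch of what the commit enables through `OpenAiChatOptions`. This is not part of the commit: only `reasoningEffort(..)` is introduced here, while the `builder()` entry point, the `model(..)` option, and the `ChatModel`/`Prompt`/`ChatResponse` types are assumed from the surrounding Spring AI codebase.

```java
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatOptions;

class ReasoningEffortExample {

	// chatModel is assumed to be a configured OpenAiChatModel instance.
	ChatResponse askWithLowEffort(ChatModel chatModel) {
		// "low" trades some reasoning depth for faster, cheaper responses;
		// "medium" (the default) and "high" are the other supported values.
		OpenAiChatOptions options = OpenAiChatOptions.builder()
			.model("o1")
			.reasoningEffort("low")
			.build();

		return chatModel.call(new Prompt(
				"If a train travels 100 miles in 2 hours, what is its average speed?", options));
	}

}
```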
OpenAiChatOptions.java
@@ -182,6 +182,15 @@ public class OpenAiChatOptions implements FunctionCallingOptions {
* Developer-defined tags and values used for filtering completions in the <a href="https://platform.openai.com/chat-completions">dashboard</a>.
*/
private @JsonProperty("metadata") Map<String, String> metadata;

/**
* Constrains effort on reasoning for reasoning models. Currently supported values are low, medium, and high.
* Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response.
* Optional. Defaults to medium.
* Only for 'o1' models.
*/
private @JsonProperty("reasoning_effort") String reasoningEffort;

/**
* OpenAI Tool Function Callbacks to register with the ChatModel.
* For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution.
@@ -256,6 +265,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
.toolContext(fromOptions.getToolContext())
.store(fromOptions.getStore())
.metadata(fromOptions.getMetadata())
.reasoningEffort(fromOptions.getReasoningEffort())
.build();
}

@@ -520,6 +530,14 @@ public void setMetadata(Map<String, String> metadata) {
this.metadata = metadata;
}

public String getReasoningEffort() {
return this.reasoningEffort;
}

public void setReasoningEffort(String reasoningEffort) {
this.reasoningEffort = reasoningEffort;
}

@Override
public OpenAiChatOptions copy() {
return OpenAiChatOptions.fromOptions(this);
@@ -532,7 +550,7 @@ public int hashCode() {
this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice,
this.user, this.parallelToolCalls, this.functionCallbacks, this.functions, this.httpHeaders,
this.proxyToolCalls, this.toolContext, this.outputModalities, this.outputAudio, this.store,
this.metadata);
this.metadata, this.reasoningEffort);
}

@Override
@@ -563,7 +581,8 @@ public boolean equals(Object o) {
&& Objects.equals(this.proxyToolCalls, other.proxyToolCalls)
&& Objects.equals(this.outputModalities, other.outputModalities)
&& Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store)
&& Objects.equals(this.metadata, other.metadata);
&& Objects.equals(this.metadata, other.metadata)
&& Objects.equals(this.reasoningEffort, other.reasoningEffort);
}

@Override
@@ -740,6 +759,11 @@ public Builder metadata(Map<String, String> metadata) {
return this;
}

public Builder reasoningEffort(String reasoningEffort) {
this.options.reasoningEffort = reasoningEffort;
return this;
}

public OpenAiChatOptions build() {
return this.options;
}
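Taken together, the additions above give `OpenAiChatOptions` a complete accessor surface for the new field. A hypothetical snippet, grounded only in what this file's diff shows (builder method, getter, `copy()` via `fromOptions`, `equals`/`hashCode`) plus the class's usual static `builder()` factory and AssertJ, which the project's tests already use:

```java
import static org.assertj.core.api.Assertions.assertThat;

import org.springframework.ai.openai.OpenAiChatOptions;

class OpenAiChatOptionsReasoningEffortSketch {

	void reasoningEffortIsCarriedByCopies() {
		OpenAiChatOptions options = OpenAiChatOptions.builder()
			.reasoningEffort("high")
			.build();

		// The new getter exposes the configured value.
		assertThat(options.getReasoningEffort()).isEqualTo("high");

		// copy() delegates to fromOptions(), which now propagates reasoningEffort,
		// so copies compare equal and hash consistently with the original.
		OpenAiChatOptions copy = options.copy();
		assertThat(copy.getReasoningEffort()).isEqualTo("high");
		assertThat(copy).isEqualTo(options).hasSameHashCodeAs(options);
	}

}
```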
OpenAiApi.java
@@ -58,6 +58,7 @@
* @author Mariusz Bernacki
* @author Thomas Vitale
* @author David Frizelle
* @author Alexandros Pappas
*/
public class OpenAiApi {

@@ -804,6 +805,7 @@ public record ChatCompletionRequest(// @formatter:off
@JsonProperty("messages") List<ChatCompletionMessage> messages,
@JsonProperty("model") String model,
@JsonProperty("store") Boolean store,
@JsonProperty("reasoning_effort") String reasoningEffort,
@JsonProperty("metadata") Map<String, String> metadata,
@JsonProperty("frequency_penalty") Double frequencyPenalty,
@JsonProperty("logit_bias") Map<String, Integer> logitBias,
@@ -836,7 +838,7 @@ public record ChatCompletionRequest(// @formatter:off
* @param temperature What sampling temperature to use, between 0 and 1.
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null,
this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, null,
null, null, null, false, null, temperature, null,
null, null, null, null);
}
@@ -849,7 +851,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
* @param audio Parameters for audio output. Required when audio output is requested with outputModalities: ["audio"].
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, AudioParameters audio, boolean stream) {
this(messages, model, null, null, null, null, null, null,
this(messages, model, null, null, null, null, null, null, null,
null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null,
null, null, null, stream, null, null, null,
null, null, null, null);
@@ -865,7 +867,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
* as they become available, with the stream terminated by a data: [DONE] message.
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature, boolean stream) {
this(messages, model, null, null, null, null, null, null, null, null, null,
this(messages, model, null, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, stream, null, temperature, null,
null, null, null, null);
}
@@ -881,7 +883,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
List<FunctionTool> tools, Object toolChoice) {
this(messages, model, null, null, null, null, null, null, null, null, null,
this(messages, model, null, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, false, null, 0.8, null,
tools, toolChoice, null, null);
}
@@ -894,7 +896,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
* as they become available, with the stream terminated by a data: [DONE] message.
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
this(messages, null, null, null, null, null, null, null, null, null, null,
this(messages, null, null, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, stream, null, null, null,
null, null, null, null);
}
@@ -906,7 +908,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean strea
* @return A new {@link ChatCompletionRequest} with the specified stream options.
*/
public ChatCompletionRequest streamOptions(StreamOptions streamOptions) {
return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs,
return new ChatCompletionRequest(this.messages, this.model, this.store, this.reasoningEffort, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs,
this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty,
this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP,
this.tools, this.toolChoice, this.parallelToolCalls, this.user);
OpenAiApiIT.java
@@ -40,6 +40,7 @@
/**
* @author Christian Tzolov
* @author Thomas Vitale
* @author Alexandros Pappas
*/
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class OpenAiApiIT {
@@ -66,6 +67,25 @@ void chatCompletionStream() {
assertThat(response.collectList().block()).isNotNull();
}

@Test
void validateReasoningTokens() {
ChatCompletionMessage userMessage = new ChatCompletionMessage(
"If a train travels 100 miles in 2 hours, what is its average speed?", ChatCompletionMessage.Role.USER);
ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "o1", null,
"low", null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false,
null, null, null, null, null, null, null);
ResponseEntity<ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);

assertThat(response).isNotNull();
assertThat(response.getBody()).isNotNull();

OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = response.getBody()
.usage()
.completionTokenDetails();
assertThat(completionTokenDetails).isNotNull();
assertThat(completionTokenDetails.reasoningTokens()).isPositive();
}

@Test
void embeddings() {
ResponseEntity<EmbeddingList<Embedding>> response = this.openAiApi
