Skip to content

Commit

Permalink
chore(chat): update docs [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
aallam committed Sep 9, 2023
1 parent 688d9c8 commit f699272
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 86 deletions.
12 changes: 6 additions & 6 deletions guides/ChatFunctionCall.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,14 @@ val openAI = OpenAI(token)
Specify the model to use for the chat request.

```kotlin
val modelId = ModelId("gpt-3.5-turbo-0613")
val modelId = ModelId("gpt-3.5-turbo")
```

### Defining the Function

Define a dummy function `currentWeather` which the model might call. This function will return hardcoded weather information.

```kotlin
@Serializable
data class WeatherInfo(val location: String, val temperature: String, val unit: String, val forecast: List<String>)

/**
* Example dummy function hard coded to return the same weather
* In production, this could be your backend API or an external API
Expand All @@ -46,6 +43,9 @@ fun currentWeather(location: String, unit: String): String {
val weatherInfo = WeatherInfo(location, "72", unit, listOf("sunny", "windy"))
return Json.encodeToString(weatherInfo)
}

@Serializable
data class WeatherInfo(val location: String, val temperature: String, val unit: String, val forecast: List<String>)
```

### Defining Function Parameters
Expand Down Expand Up @@ -102,7 +102,7 @@ val request = chatCompletionRequest {
parameters = params
}
}
functionCall = FunctionMode.Auto
functionCall = FunctionMode.Named("currentWeather") // or FunctionMode.Auto
}
```

Expand Down Expand Up @@ -166,7 +166,7 @@ suspend fun main() {
val token = System.getenv("OPENAI_API_KEY")
val openAI = OpenAI(token)

val modelId = ModelId("gpt-3.5-turbo-0613")
val modelId = ModelId("gpt-3.5-turbo")
val chatMessages = mutableListOf(
ChatMessage(
role = ChatRole.User,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class TestChatCompletions : TestOpenAI() {

@Test
fun chatCompletionsFunction() = test {
val modelId = ModelId("gpt-3.5-turbo-0613")
val modelId = ModelId("gpt-3.5-turbo")
val chatMessages = mutableListOf(
ChatMessage(
role = ChatRole.User,
Expand Down Expand Up @@ -94,7 +94,7 @@ class TestChatCompletions : TestOpenAI() {
}

val response = openAI.chatCompletion(request)
val message = response.choices.first().message ?: error("No chat response found!")
val message = response.choices.first().message
assertEquals("currentWeather", message.functionCall?.name)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import kotlinx.coroutines.runBlocking
fun main() = runBlocking {
val apiKey = System.getenv("OPENAI_API_KEY")
val token = requireNotNull(apiKey) { "OPENAI_API_KEY environment variable must be set." }
val openAI = OpenAI(token = token, logging = LoggingConfig(LogLevel.All))
val openAI = OpenAI(token = token, logging = LoggingConfig(LogLevel.None))

while (true) {
println("Select an option:")
Expand All @@ -21,8 +21,7 @@ fun main() = runBlocking {
println("7 - Whisper")
println("0 - Quit")

val option = readlnOrNull()?.toIntOrNull()
when (option) {
when (val option = readlnOrNull()?.toIntOrNull()) {
1 -> engines(openAI)
2 -> files(openAI)
3 -> moderations(openAI)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@ package com.aallam.openai.sample.jvm
import com.aallam.openai.api.chat.*
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.OpenAI
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.*
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.*

/**
* This code snippet demonstrates the use of OpenAI's chat completion capabilities
* with a focus on integrating function calls into the chat conversation.
*/
suspend fun chatFunctionCall(openAI: OpenAI) {
// *** Chat Completion with Function Call *** //

println("\n> Create Chat Completion function call...")
val modelId = ModelId("gpt-3.5-turbo-0613")
val modelId = ModelId("gpt-3.5-turbo")
val chatMessages = mutableListOf(
ChatMessage(
role = ChatRole.User,
Expand Down Expand Up @@ -51,111 +53,125 @@ suspend fun chatFunctionCall(openAI: OpenAI) {
parameters = params
}
}
functionCall = FunctionMode.Auto
functionCall = FunctionMode.Named("currentWeather") // or FunctionMode.Auto
}

val response = openAI.chatCompletion(request)
val message = response.choices.first().message
chatMessages.append(message)
message.functionCall?.let { functionCall ->
val functionResponse = callFunction(functionCall)
updateChatMessages(chatMessages, message, functionCall, functionResponse)
val functionResponse = functionCall.execute()
chatMessages.append(functionCall, functionResponse)
val secondResponse = openAI.chatCompletion(
request = ChatCompletionRequest(
model = modelId,
messages = chatMessages,
)
request = ChatCompletionRequest(model = modelId, messages = chatMessages)
)
print(secondResponse)
}
print(secondResponse.choices.first().message.content.orEmpty())
} ?: print(message.content.orEmpty())

// *** Chat Completion Stream with Function Call *** //

println("\n> Create Chat Completion function call (stream)...")
val chunks = mutableListOf<ChatChunk>()
openAI.chatCompletions(request)
.onEach { chunks += it.choices.first() }
.onCompletion {
val chatMessage = chatMessageOf(chunks)
chatMessage.functionCall?.let {
val functionResponse = callFunction(it)
updateChatMessages(chatMessages, message, it, functionResponse)
}
}
.collect()
val chatMessage = openAI.chatCompletions(request)
.map { completion -> completion.choices.first() }
.fold(initial = ChatMessageAssembler()) { assembler, chunk -> assembler.merge(chunk) }
.build()

chatMessages.append(chatMessage)
chatMessage.functionCall?.let { functionCall ->
val functionResponse = functionCall.execute()
chatMessages.append(functionCall, functionResponse)
}

openAI.chatCompletions(
ChatCompletionRequest(
model = modelId,
messages = chatMessages,
)
)
openAI.chatCompletions(request = ChatCompletionRequest(model = modelId, messages = chatMessages))
.onEach { print(it.choices.first().delta.content.orEmpty()) }
.onCompletion { println() }
.collect()
}

@Serializable
data class WeatherInfo(val location: String, val temperature: String, val unit: String, val forecast: List<String>)
/**
* A map that associates function names with their corresponding functions.
*/
private val availableFunctions = mapOf("currentWeather" to ::callCurrentWeather)

/**
* Example dummy function hard coded to return the same weather
* In production, this could be your backend API or an external API
* Example dummy function for retrieving weather information based on location and temperature unit.
* In a production scenario, this function could be replaced with an actual backend or external API call.
*/
fun currentWeather(location: String, unit: String): String {
private fun callCurrentWeather(args: JsonObject): String {
val location = args.getValue("location").jsonPrimitive.content
val unit = args["unit"]?.jsonPrimitive?.content ?: "fahrenheit"
return currentWeather(location, unit)
}

/**
* Example dummy function for retrieving weather information based on location and temperature unit.
*/
private fun currentWeather(location: String, unit: String): String {
val weatherInfo = WeatherInfo(location, "72", unit, listOf("sunny", "windy"))
return Json.encodeToString(weatherInfo)
}

private fun callFunction(functionCall: FunctionCall): String {
val availableFunctions = mapOf("currentWeather" to ::currentWeather)
val functionToCall = availableFunctions[functionCall.name] ?: error("Function ${functionCall.name} not found")
val functionArgs = functionCall.argumentsAsJson()
/**
* Serializable data class to represent weather information.
*/
@Serializable
data class WeatherInfo(val location: String, val temperature: String, val unit: String, val forecast: List<String>)


return functionToCall(
functionArgs.getValue("location").jsonPrimitive.content,
functionArgs["unit"]?.jsonPrimitive?.content ?: "fahrenheit"
)
/**
* Executes a function call and returns its result.
*/
private fun FunctionCall.execute(): String {
val functionToCall = availableFunctions[name] ?: error("Function $name not found")
val functionArgs = argumentsAsJson()
return functionToCall(functionArgs)
}

private fun updateChatMessages(
chatMessages: MutableList<ChatMessage>,
message: ChatMessage,
functionCall: FunctionCall,
functionResponse: String
) {
chatMessages.add(
ChatMessage(
role = message.role,
content = message.content.orEmpty(), // required to not be empty in this case
functionCall = message.functionCall
)
)
chatMessages.add(
ChatMessage(role = ChatRole.Function, name = functionCall.name, content = functionResponse)
)
/**
* Appends a chat message to a list of chat messages.
*/
private fun MutableList<ChatMessage>.append(message: ChatMessage) {
add(ChatMessage(role = message.role, content = message.content.orEmpty(), functionCall = message.functionCall))
}

fun chatMessageOf(chunks: List<ChatChunk>): ChatMessage {
val funcName = StringBuilder()
val funcArgs = StringBuilder()
var role: ChatRole? = null
val content = StringBuilder()
/**
* Appends a function call and response to a list of chat messages.
*/
private fun MutableList<ChatMessage>.append(functionCall: FunctionCall, functionResponse: String) {
add(ChatMessage(role = ChatRole.Function, name = functionCall.name, content = functionResponse))
}

chunks.forEach { chunk ->
role = chunk.delta.role ?: role
chunk.delta.content?.let { content.append(it) }
/**
* A class to help assemble chat messages from chat chunks.
*/
class ChatMessageAssembler {
private val chatFuncName = StringBuilder()
private val chatFuncArgs = StringBuilder()
private val chatContent = StringBuilder()
private var chatRole: ChatRole? = null

/**
* Merges a chat chunk into the chat message being assembled.
*/
fun merge(chunk: ChatChunk): ChatMessageAssembler {
chatRole = chunk.delta.role ?: chatRole
chunk.delta.content?.let { chatContent.append(it) }
chunk.delta.functionCall?.let { call ->
funcName.append(call.name)
funcArgs.append(call.arguments)
call.nameOrNull?.let { chatFuncName.append(it) }
call.argumentsOrNull?.let { chatFuncArgs.append(it) }
}
return this
}

return chatMessage {
this.role = role
this.content = content.toString()
if (funcName.isNotEmpty() || funcArgs.isNotEmpty()) {
functionCall = FunctionCall(funcName.toString(), funcArgs.toString())
name = funcName.toString()
/**
* Builds and returns the assembled chat message.
*/
fun build(): ChatMessage = chatMessage {
this.role = chatRole
this.content = chatContent.toString()
if (chatFuncName.isNotEmpty() || chatFuncArgs.isNotEmpty()) {
this.functionCall = FunctionCall(chatFuncName.toString(), chatFuncArgs.toString())
this.name = chatFuncName.toString()
}
}
}

0 comments on commit f699272

Please sign in to comment.