Skip to content

Commit

Permalink
Merge pull request #12 from vyfor/dev
Browse files Browse the repository at this point in the history
feat: Allow specifying default values in client configuration
  • Loading branch information
vyfor authored Nov 10, 2024
2 parents 47f069d + 788b29d commit 196dfd4
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 70 deletions.
103 changes: 58 additions & 45 deletions src/commonMain/kotlin/io/github/vyfor/groqkt/GroqClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import io.github.vyfor.groqkt.api.chat.ChatCompletionRequest
import io.github.vyfor.groqkt.api.chat.StreamingChatCompletion
import io.github.vyfor.groqkt.api.model.Model
import io.github.vyfor.groqkt.api.model.Models
import io.github.vyfor.groqkt.util.applyDefaults
import io.github.vyfor.groqkt.util.parse
import io.github.vyfor.groqkt.util.parseHeaders
import io.github.vyfor.groqkt.util.validate
Expand Down Expand Up @@ -67,6 +68,7 @@ class GroqClient(
contentType(ContentType.Application.Json)
setBody(
ChatCompletionRequest.Builder()
.applyDefaults(config.defaults?.chatCompletion)
.apply {
block()
stream = false
Expand Down Expand Up @@ -119,6 +121,7 @@ class GroqClient(
contentType(ContentType.Application.Json)
setBody(
ChatCompletionRequest.Builder()
.applyDefaults(config.defaults?.chatCompletion)
.apply {
block()
stream = true
Expand Down Expand Up @@ -165,7 +168,7 @@ class GroqClient(
"filename=\"${data.filename.encodeURLPathPart()}\"")
},
)
append("model", data.model.id)
append("model", data.model!!.id)
data.prompt?.let { append("prompt", it) }
data.responseFormat?.let { append("response_format", it.name) }
data.temperature?.let { append("temperature", it.toString()) }
Expand All @@ -186,21 +189,25 @@ class GroqClient(
.submitFormWithBinaryData(
AudioTranslationRequest.ENDPOINT,
formData {
AudioTranslationRequest.Builder().apply(block).build().let { data ->
append(
"file",
data.file,
Headers.build {
append(
HttpHeaders.ContentDisposition,
"filename=\"${data.filename.encodeURLPathPart()}\"")
},
)
append("model", data.model.id)
data.prompt?.let { append("prompt", it) }
data.responseFormat?.let { append("response_format", it.name) }
data.temperature?.let { append("temperature", it.toString()) }
}
AudioTranslationRequest.Builder()
.applyDefaults(config.defaults?.audioTranslation)
.apply(block)
.build()
.let { data ->
append(
"file",
data.file,
Headers.build {
append(
HttpHeaders.ContentDisposition,
"filename=\"${data.filename.encodeURLPathPart()}\"")
},
)
append("model", data.model!!.id)
data.prompt?.let { append("prompt", it) }
data.responseFormat?.let { append("response_format", it.name) }
data.temperature?.let { append("temperature", it.toString()) }
}
},
) {
contentType(ContentType.MultiPart.FormData)
Expand Down Expand Up @@ -229,7 +236,7 @@ class GroqClient(
},
)
}
append("model", data.model.id)
append("model", data.model!!.id)
data.language?.let { append("language", it) }
data.prompt?.let { append("prompt", it) }
data.responseFormat?.let { append("response_format", it.name) }
Expand Down Expand Up @@ -260,34 +267,40 @@ class GroqClient(
.submitFormWithBinaryData(
AudioTranscriptionRequest.ENDPOINT,
formData {
AudioTranscriptionRequest.Builder().apply(block).build().let { data ->
data.file?.let {
append(
"file",
it,
Headers.build {
append(
HttpHeaders.ContentDisposition,
"filename=\"${data.filename.encodeURLPathPart()}\"")
},
)
}
append("model", data.model.id)
data.url?.let { append("url", it) }
data.language?.let { append("language", it) }
data.prompt?.let { append("prompt", it) }
data.responseFormat?.let { append("response_format", it.name) }
data.temperature?.let { append("temperature", it.toString()) }
data.timestampGranularities?.let {
append(
"timestamp_granularities",
config.json.encodeToString(
buildJsonArray { it.forEach { enum -> add(JsonPrimitive(enum.value)) } }
.toString(),
),
)
}
}
AudioTranscriptionRequest.Builder()
.applyDefaults(config.defaults?.audioTranscription)
.apply(block)
.build()
.let { data ->
data.file?.let {
append(
"file",
it,
Headers.build {
append(
HttpHeaders.ContentDisposition,
"filename=\"${data.filename.encodeURLPathPart()}\"")
},
)
}
append("model", data.model!!.id)
data.url?.let { append("url", it) }
data.language?.let { append("language", it) }
data.prompt?.let { append("prompt", it) }
data.responseFormat?.let { append("response_format", it.name) }
data.temperature?.let { append("temperature", it.toString()) }
data.timestampGranularities?.let {
append(
"timestamp_granularities",
config.json.encodeToString(
buildJsonArray {
it.forEach { enum -> add(JsonPrimitive(enum.value)) }
}
.toString(),
),
)
}
}
},
) {
contentType(ContentType.MultiPart.FormData)
Expand Down
64 changes: 64 additions & 0 deletions src/commonMain/kotlin/io/github/vyfor/groqkt/GroqConfig.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
package io.github.vyfor.groqkt

import io.github.vyfor.groqkt.GroqClient.Companion.BASE_URL
import io.github.vyfor.groqkt.api.audio.transcription.AudioTranscriptionRequest
import io.github.vyfor.groqkt.api.audio.translation.AudioTranslationRequest
import io.github.vyfor.groqkt.api.chat.ChatCompletionRequest
import io.ktor.client.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.contentnegotiation.*
Expand All @@ -19,10 +22,27 @@ import kotlinx.serialization.json.JsonNamingStrategy
*
* @property json The JSON serializer.
* @property client The HTTP client.
* @property defaults The default values for the [GroqClient]. These values are applied to every
* request made using a DSL function.
*/
data class GroqConfig(
val json: Json,
val client: HttpClient,
val defaults: GroqDefaults?,
)

/**
* Default values for use with the [GroqClient]. These values are applied to every request made
* using a DSL function.
*
* @property chatCompletion The default values for [ChatCompletionRequest].
* @property audioTranslation The default values for [AudioTranslationRequest].
* @property audioTranscription The default values for [AudioTranscriptionRequest].
*/
data class GroqDefaults(
val chatCompletion: (ChatCompletionRequest.Builder.() -> Unit)? = null,
val audioTranslation: (AudioTranslationRequest.Builder.() -> Unit)? = null,
val audioTranscription: (AudioTranscriptionRequest.Builder.() -> Unit)? = null,
)

/**
Expand All @@ -42,6 +62,7 @@ class GroqConfigBuilder(
namingStrategy = JsonNamingStrategy.SnakeCase
classDiscriminatorMode = ClassDiscriminatorMode.NONE
}
private var defaults: GroqDefaults? = null
var client: HttpClient = HttpClient {
install(ContentNegotiation) { json(json) }

Expand Down Expand Up @@ -74,9 +95,52 @@ class GroqConfigBuilder(
}
}

/**
* Sets the default values for the [GroqClient]. These values are applied to every request made
* using a DSL function.
*
* @param block The default values for the [GroqClient].
*/
fun defaults(block: GroqDefaultsBuilder.() -> Unit) {
defaults = GroqDefaultsBuilder().apply(block).build()
}

internal fun build(): GroqConfig =
GroqConfig(
json,
client,
defaults,
)
}

/**
* Groq defaults builder class.
*
* @property chatCompletion The default values for [ChatCompletionRequest].
* @property audioTranslation The default values for [AudioTranslationRequest].
* @property audioTranscription The default values for [AudioTranscriptionRequest].
*/
class GroqDefaultsBuilder {
private var chatCompletion: (ChatCompletionRequest.Builder.() -> Unit)? = null
private var audioTranslation: (AudioTranslationRequest.Builder.() -> Unit)? = null
private var audioTranscription: (AudioTranscriptionRequest.Builder.() -> Unit)? = null

fun chatCompletion(block: ChatCompletionRequest.Builder.() -> Unit) {
chatCompletion = block
}

fun audioTranslation(block: AudioTranslationRequest.Builder.() -> Unit) {
audioTranslation = block
}

fun audioTranscription(block: AudioTranscriptionRequest.Builder.() -> Unit) {
audioTranscription = block
}

internal fun build(): GroqDefaults =
GroqDefaults(
chatCompletion,
audioTranslation,
audioTranscription,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ data class AudioTranscriptionRequest(
val file: ByteArray? = null,
val url: String? = null,
val language: String? = null,
val model: GroqModel,
val model: GroqModel?,
val prompt: String? = null,
val responseFormat: AudioResponseFormat? = null,
val temperature: Double? = null,
Expand All @@ -49,6 +49,7 @@ data class AudioTranscriptionRequest(
var filename: String = "audio.mp3"

init {
require(model != null) { "model must be set" }
require(file != null || url != null) { "either file or url must be set" }
}

Expand Down Expand Up @@ -130,7 +131,7 @@ data class AudioTranscriptionRequest(
file,
url,
language,
requireNotNull(model) { "model must be set" },
model,
prompt,
responseFormat,
temperature,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,17 @@ import kotlinx.serialization.Serializable
@Serializable
data class AudioTranslationRequest(
val file: ByteArray,
val model: GroqModel,
val model: GroqModel?,
val prompt: String? = null,
val responseFormat: AudioResponseFormat? = null,
val temperature: Double? = null,
) {
var filename: String = "audio.mp3"

init {
require(model != null) { "model must be set" }
}

companion object {
const val ENDPOINT = "audio/translations"
}
Expand Down Expand Up @@ -103,7 +107,7 @@ data class AudioTranslationRequest(
fun build(): AudioTranslationRequest {
return AudioTranslationRequest(
requireNotNull(file) { "file must be set" },
requireNotNull(model) { "model must be set" },
model,
prompt,
responseFormat,
temperature,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ data class ChatCompletionRequest(
/* val logprobs: Boolean? = null, */
val maxTokens: Int? = null,
val messages: List<CompletionMessage>,
val model: GroqModel,
val model: GroqModel?,
val n: Int? = null,
val parallelToolCalls: Boolean? = null,
var presencePenalty: Double? = null,
Expand All @@ -96,9 +96,10 @@ data class ChatCompletionRequest(
}

init {
require(n == null || n == 1) { "Currently only n = 1 is supported." }
require(model != null) { "model must be set" }
require(n == null || n == 1) { "currently only n = 1 is supported." }
require(streamOptions == null || stream == true) { "streamOptions must have stream = true." }
require(tools == null || tools.size <= 128) { "Currently only up to 128 tools are supported." }
require(tools == null || tools.size <= 128) { "currently only up to 128 tools are supported." }
require(messages.isNotEmpty()) { "messages must not be empty." }
presencePenalty = presencePenalty?.coerceIn(-2.0, 2.0)
temperature = temperature?.coerceIn(-2.0, 2.0)
Expand Down Expand Up @@ -210,7 +211,7 @@ data class ChatCompletionRequest(
functions,
maxTokens,
requireNotNull(messages) { "messages must be set" },
requireNotNull(model) { "model must be set" },
model,
n,
parallelToolCalls,
presencePenalty,
Expand Down Expand Up @@ -243,18 +244,13 @@ class ChatCompletionMessageBuilder {
fun image(image: String) {
messages.add(
CompletionMessage.User(
UserMessageType.Array(
imageContent =
Image(ImageObject(url = image)))))
UserMessageType.Array(imageContent = Image(ImageObject(url = image)))))
}

fun user(content: String?, image: String?, name: String? = null) {
messages.add(
CompletionMessage.User(
UserMessageType.Array(
Text(content),
Image(ImageObject(url = image))),
name))
UserMessageType.Array(Text(content), Image(ImageObject(url = image))), name))
}

fun assistant(
Expand Down Expand Up @@ -544,17 +540,10 @@ sealed class CompletionMessage(val role: String) {
fun text(content: String) = User(UserMessageType.Text(content))

fun image(image: String) =
User(
UserMessageType.Array(
imageContent =
Image(ImageObject(url = image))))
User(UserMessageType.Array(imageContent = Image(ImageObject(url = image))))

fun user(content: String?, image: String?, name: String? = null) =
User(
UserMessageType.Array(
Text(content),
Image(ImageObject(url = image))),
name)
User(UserMessageType.Array(Text(content), Image(ImageObject(url = image))), name)

fun assistant(
content: String,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@file:Suppress("unused")
@file:Suppress("unused", "NAME_SHADOWING")

package io.github.vyfor.groqkt.util

Expand All @@ -14,6 +14,9 @@ import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.minutes
import kotlin.time.Duration.Companion.seconds

internal inline fun <reified T> T.applyDefaults(noinline defaults: (T.() -> Unit)?): T =
defaults?.let { defaults -> apply(defaults) } ?: this

internal suspend inline fun <reified T> HttpResponse.validate(): Result<T> =
if (status.isSuccess()) {
Result.success(body<T>())
Expand Down

0 comments on commit 196dfd4

Please sign in to comment.