Skip to content

Commit

Permalink
rebase completed, one failing test to fix
Browse files Browse the repository at this point in the history
  • Loading branch information
David Motsonashvili committed Jan 30, 2025
1 parent e881958 commit d7296c3
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 205 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ import com.google.firebase.auth.internal.InternalAuthProvider
import com.google.firebase.vertexai.common.APIController
import com.google.firebase.vertexai.common.CountTokensRequest
import com.google.firebase.vertexai.common.GenerateContentRequest
import com.google.firebase.vertexai.common.HeaderProvider
import com.google.firebase.vertexai.internal.util.AppCheckHeaderProvider
import com.google.firebase.vertexai.internal.util.toInternal
import com.google.firebase.vertexai.internal.util.toPublic
import com.google.firebase.vertexai.common.AppCheckHeaderProvider
import com.google.firebase.vertexai.type.Content
import com.google.firebase.vertexai.type.CountTokensResponse
import com.google.firebase.vertexai.type.FinishReason
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,10 @@ import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
import com.google.firebase.auth.internal.InternalAuthProvider
import com.google.firebase.vertexai.common.APIController
import com.google.firebase.vertexai.common.ContentBlockedException
import com.google.firebase.vertexai.internal.GenerateImageRequest
import com.google.firebase.vertexai.internal.GenerateImageResponse
import com.google.firebase.vertexai.internal.ImagenParameters
import com.google.firebase.vertexai.internal.ImagenPromptInstance
import com.google.firebase.vertexai.internal.util.AppCheckHeaderProvider
import com.google.firebase.vertexai.internal.util.toInternal
import com.google.firebase.vertexai.internal.util.toPublicGCS
import com.google.firebase.vertexai.internal.util.toPublicInline
import com.google.firebase.vertexai.common.AppCheckHeaderProvider
import com.google.firebase.vertexai.common.GenerateImageRequest
import com.google.firebase.vertexai.common.ImagenParameters
import com.google.firebase.vertexai.common.ImagenPromptInstance
import com.google.firebase.vertexai.type.FirebaseVertexAIException
import com.google.firebase.vertexai.type.ImagenGCSImage
import com.google.firebase.vertexai.type.ImagenGenerationConfig
Expand Down Expand Up @@ -133,7 +129,7 @@ internal constructor(
}
}

private fun GenerateImageResponse.validate(): GenerateImageResponse {
private fun ImagenGenerationResponse.Internal.validate(): ImagenGenerationResponse.Internal {
if (predictions.none { it.mimeType != null }) {
throw ContentBlockedException(
message = predictions.first { it.raiFilteredReason != null }.raiFilteredReason
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ import com.google.firebase.vertexai.type.CountTokensResponse
import com.google.firebase.vertexai.type.FinishReason
import com.google.firebase.vertexai.type.GRpcErrorResponse
import com.google.firebase.vertexai.type.GenerateContentResponse
import com.google.firebase.vertexai.internal.GenerateImageRequest
import com.google.firebase.vertexai.internal.GenerateImageResponse
import com.google.firebase.vertexai.type.ImagenGenerationResponse
import com.google.firebase.vertexai.type.RequestOptions
import com.google.firebase.vertexai.type.Response
import io.ktor.client.HttpClient
Expand Down Expand Up @@ -59,6 +58,7 @@ import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.launch
import kotlinx.coroutines.tasks.await
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.json.Json

Expand Down Expand Up @@ -124,15 +124,15 @@ internal constructor(
throw FirebaseCommonAIException.from(e)
}

suspend fun generateImage(request: GenerateImageRequest): GenerateImageResponse =
suspend fun generateImage(request: GenerateImageRequest): ImagenGenerationResponse.Internal =
try {
client
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:predict") {
applyCommonConfiguration(request)
applyHeaderProvider()
}
.also { validateResponse(it) }
.body<GenerateImageResponse>()
.body<ImagenGenerationResponse.Internal>()
} catch (e: Throwable) {
throw FirebaseCommonAIException.from(e)
}
Expand Down Expand Up @@ -315,4 +315,4 @@ private fun GenerateContentResponse.Internal.validate() = apply {
?.mapNotNull { it.finishReason }
?.firstOrNull { it != FinishReason.Internal.STOP }
?.let { throw ResponseStoppedException(this) }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.vertexai.common

import android.util.Log
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
import com.google.firebase.auth.internal.InternalAuthProvider
import kotlinx.coroutines.tasks.await
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

internal class AppCheckHeaderProvider(
private val logTag: String,
private val appCheckTokenProvider: InteropAppCheckTokenProvider? = null,
private val internalAuthProvider: InternalAuthProvider? = null,
) : HeaderProvider {
override val timeout: Duration
get() = 10.seconds

override suspend fun generateHeaders(): Map<String, String> {
val headers = mutableMapOf<String, String>()
if (appCheckTokenProvider == null) {
Log.w(logTag, "AppCheck not registered, skipping")
} else {
val token = appCheckTokenProvider.getToken(false).await()

if (token.error != null) {
Log.w(logTag, "Error obtaining AppCheck token", token.error)
}
// The Firebase App Check backend can differentiate between apps without App Check, and
// wrongly configured apps by verifying the value of the token, so it always needs to be
// included.
headers["X-Firebase-AppCheck"] = token.token
}

if (internalAuthProvider == null) {
Log.w(logTag, "Auth not registered, skipping")
} else {
try {
val token = internalAuthProvider.getAccessToken(false).await()

headers["Authorization"] = "Firebase ${token.token!!}"
} catch (e: Exception) {
Log.w(logTag, "Error getting Auth token ", e)
}
}

return headers
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.google.firebase.vertexai.common
import com.google.firebase.vertexai.common.util.fullModelName
import com.google.firebase.vertexai.type.Content
import com.google.firebase.vertexai.type.GenerationConfig
import com.google.firebase.vertexai.type.ImagenImageFormat
import com.google.firebase.vertexai.type.SafetySetting
import com.google.firebase.vertexai.type.Tool
import com.google.firebase.vertexai.type.ToolConfig
Expand Down Expand Up @@ -65,3 +66,24 @@ internal data class CountTokensRequest(
)
}
}

@Serializable
internal data class GenerateImageRequest(
val instances: List<ImagenPromptInstance>,
val parameters: ImagenParameters,
) : Request {}

@Serializable internal data class ImagenPromptInstance(val prompt: String)

@Serializable
internal data class ImagenParameters(
val sampleCount: Int = 1,
val includeRaiReason: Boolean = true,
val storageUri: String?,
val negativePrompt: String?,
val aspectRatio: String?,
val safetySetting: String?,
val personGeneration: String?,
val addWatermark: Boolean?,
val imageOutputOptions: ImagenImageFormat.Internal?,
)

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ public class ImagenGenerationConfig(
/**
* Builder for creating a [ImagenGenerationConfig].
*
* This is mainly intended for Java interop. For Kotlin, use [imagenGenerationConfig] for a
* more idiomatic experience.
* This is mainly intended for Java interop. For Kotlin, use [imagenGenerationConfig] for a more
* idiomatic experience.
*
* @property negativePrompt See [ImagenGenerationConfig.negativePrompt].
* @property numberOfImages See [ImagenGenerationConfig.numberOfImages].
Expand Down
Loading

0 comments on commit d7296c3

Please sign in to comment.