Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial implementation to API spec #6607

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4212392
Initial implementation to API spec
Dec 18, 2024
5effc29
add block_none
Jan 8, 2025
63a0d14
Make ImagenImage not subclass ImagenImageRepresentible
Jan 13, 2025
7319072
update to match API doc
Jan 16, 2025
0d43e2f
add tests and minor adjustments to support them
Jan 22, 2025
36a322a
fixes for comments (removing enums, fixing gradle.properties)
Jan 24, 2025
ad604f2
Add content blocked exception to match API spec
Jan 24, 2025
3532e18
minor fixes for comments
Jan 27, 2025
fb1da34
set image types to internal constructor only
Jan 27, 2025
a554fba
hide constructor for ImagenGenerationResponse
Jan 27, 2025
1156de9
minor rename to match API doc
Jan 27, 2025
0ee91ac
remove internal constructor on ImagenSafetySettings
Jan 28, 2025
d6a1688
rename imageModel and generateImages
Jan 28, 2025
4c786a3
add builder to ImagenGenerationConfig
Jan 28, 2025
9a4838c
delete ImagenImage
Jan 29, 2025
be2313d
Add documentation to imagen (#6616)
davidmotson Jan 30, 2025
e881958
Davidmotson.imagen java (#6618)
davidmotson Jan 30, 2025
d7296c3
rebase completed, one failing test to fix
Jan 30, 2025
4d44301
fixes for tests
Feb 3, 2025
fd5ceea
remove unintentionally included test
Feb 3, 2025
c100854
ktfmt
Feb 3, 2025
0f56452
Merge branch 'main' into davidmotson.imagen_support
davidmotson Feb 3, 2025
db0e4d7
Move ImagenPrompt and ImagenParameters into their request class
Feb 3, 2025
335379a
added release notes
Feb 3, 2025
81607a5
make a public preview annotation for Imagen (#6668)
davidmotson Feb 4, 2025
c02652d
fix serialization issue in ImagenParameters
Feb 6, 2025
db8550d
remove nullability from some internal types for safety
Feb 11, 2025
00612f7
hide gcs implementation
Feb 11, 2025
5732e30
remove gcs tests
Feb 11, 2025
3ac41de
Merge branch 'main' into davidmotson.imagen_support
Feb 11, 2025
29a317b
update api.txt for removing gcp
Feb 11, 2025
77a6d7d
format
Feb 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions firebase-vertexai/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Unreleased
- [changed] Internal improvements to correctly handle empty model responses.
- [feature] Added support for generating images with Imagen models.


# 16.0.2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import com.google.firebase.auth.internal.InternalAuthProvider
import com.google.firebase.inject.Provider
import com.google.firebase.vertexai.type.Content
import com.google.firebase.vertexai.type.GenerationConfig
import com.google.firebase.vertexai.type.ImagenGenerationConfig
import com.google.firebase.vertexai.type.ImagenSafetySettings
import com.google.firebase.vertexai.type.InvalidLocationException
import com.google.firebase.vertexai.type.RequestOptions
import com.google.firebase.vertexai.type.SafetySetting
Expand Down Expand Up @@ -79,6 +81,36 @@ internal constructor(
)
}

/**
 * Builds an [ImagenModel] for the given model name, scoped to this instance's project and
 * location.
 *
 * @param modelName The name of the model to use, for example `"imagen-3.0-generate-001"`.
 * @param generationConfig The configuration parameters to use for image generation.
 * @param safetySettings The safety bounds the model will abide by during image generation.
 * @param requestOptions Configuration options for sending requests to the backend.
 * @return The initialized [ImagenModel] instance.
 * @throws InvalidLocationException if the configured location is blank or contains a `/`.
 */
@JvmOverloads
public fun imagenModel(
  modelName: String,
  generationConfig: ImagenGenerationConfig? = null,
  safetySettings: ImagenSafetySettings? = null,
  requestOptions: RequestOptions = RequestOptions(),
): ImagenModel {
  // The location is spliced into a resource path below, so it must be a single non-blank segment.
  val locationIsUsable = location.isNotBlank() && "/" !in location
  if (!locationIsUsable) {
    throw InvalidLocationException(location)
  }

  val projectId = firebaseApp.options.projectId
  val resourceName =
    "projects/$projectId/locations/$location/publishers/google/models/$modelName"

  return ImagenModel(
    resourceName,
    firebaseApp.options.apiKey,
    generationConfig,
    safetySettings,
    requestOptions,
    appCheckProvider.get(),
    internalAuthProvider.get(),
  )
}

public companion object {
/** The [FirebaseVertexAI] instance for the default [FirebaseApp] */
@JvmStatic
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@
package com.google.firebase.vertexai

import android.graphics.Bitmap
import android.util.Log
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
import com.google.firebase.auth.internal.InternalAuthProvider
import com.google.firebase.vertexai.common.APIController
import com.google.firebase.vertexai.common.AppCheckHeaderProvider
import com.google.firebase.vertexai.common.CountTokensRequest
import com.google.firebase.vertexai.common.GenerateContentRequest
import com.google.firebase.vertexai.common.HeaderProvider
import com.google.firebase.vertexai.type.Content
import com.google.firebase.vertexai.type.CountTokensResponse
import com.google.firebase.vertexai.type.FinishReason
Expand All @@ -38,12 +37,9 @@ import com.google.firebase.vertexai.type.SerializationException
import com.google.firebase.vertexai.type.Tool
import com.google.firebase.vertexai.type.ToolConfig
import com.google.firebase.vertexai.type.content
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.tasks.await

/**
* Represents a multimodal model (like Gemini), capable of generating content based on various input
Expand All @@ -57,10 +53,8 @@ internal constructor(
private val tools: List<Tool>? = null,
private val toolConfig: ToolConfig? = null,
private val systemInstruction: Content? = null,
private val controller: APIController
private val controller: APIController,
) {

@JvmOverloads
internal constructor(
modelName: String,
apiKey: String,
Expand All @@ -84,42 +78,8 @@ internal constructor(
modelName,
requestOptions,
"gl-kotlin/${KotlinVersion.CURRENT} fire/${BuildConfig.VERSION_NAME}",
object : HeaderProvider {
override val timeout: Duration
get() = 10.seconds

override suspend fun generateHeaders(): Map<String, String> {
val headers = mutableMapOf<String, String>()
if (appCheckTokenProvider == null) {
Log.w(TAG, "AppCheck not registered, skipping")
} else {
val token = appCheckTokenProvider.getToken(false).await()

if (token.error != null) {
Log.w(TAG, "Error obtaining AppCheck token", token.error)
}
// The Firebase App Check backend can differentiate between apps without App Check, and
// wrongly configured apps by verifying the value of the token, so it always needs to be
// included.
headers["X-Firebase-AppCheck"] = token.token
}

if (internalAuthProvider == null) {
Log.w(TAG, "Auth not registered, skipping")
} else {
try {
val token = internalAuthProvider.getAccessToken(false).await()

headers["Authorization"] = "Firebase ${token.token!!}"
} catch (e: Exception) {
Log.w(TAG, "Error getting Auth token ", e)
}
}

return headers
}
}
)
AppCheckHeaderProvider(TAG, appCheckTokenProvider, internalAuthProvider),
),
)

/**
Expand Down Expand Up @@ -247,7 +207,7 @@ internal constructor(
generationConfig?.toInternal(),
tools?.map { it.toInternal() },
toolConfig?.toInternal(),
systemInstruction?.copy(role = "system")?.toInternal()
systemInstruction?.copy(role = "system")?.toInternal(),
)

private fun constructCountTokensRequest(vararg prompt: Content) =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.vertexai

import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
import com.google.firebase.auth.internal.InternalAuthProvider
import com.google.firebase.vertexai.common.APIController
import com.google.firebase.vertexai.common.AppCheckHeaderProvider
import com.google.firebase.vertexai.common.ContentBlockedException
import com.google.firebase.vertexai.common.GenerateImageRequest
import com.google.firebase.vertexai.type.FirebaseVertexAIException
import com.google.firebase.vertexai.type.ImagenGCSImage
import com.google.firebase.vertexai.type.ImagenGenerationConfig
import com.google.firebase.vertexai.type.ImagenGenerationResponse
import com.google.firebase.vertexai.type.ImagenInlineImage
import com.google.firebase.vertexai.type.ImagenSafetySettings
import com.google.firebase.vertexai.type.RequestOptions

/**
 * Represents a generative model (like Imagen), capable of generating images based on various input
 * types.
 */
public class ImagenModel
internal constructor(
  private val modelName: String,
  private val generationConfig: ImagenGenerationConfig? = null,
  private val safetySettings: ImagenSafetySettings? = null,
  private val controller: APIController,
) {
  /**
   * Convenience constructor that builds the [APIController] from the API key, request options and
   * the optional App Check / Auth providers.
   */
  @JvmOverloads
  internal constructor(
    modelName: String,
    apiKey: String,
    generationConfig: ImagenGenerationConfig? = null,
    safetySettings: ImagenSafetySettings? = null,
    requestOptions: RequestOptions = RequestOptions(),
    appCheckTokenProvider: InteropAppCheckTokenProvider? = null,
    internalAuthProvider: InternalAuthProvider? = null,
  ) : this(
    modelName,
    generationConfig,
    safetySettings,
    APIController(
      apiKey,
      modelName,
      requestOptions,
      "gl-kotlin/${KotlinVersion.CURRENT} fire/${BuildConfig.VERSION_NAME}",
      AppCheckHeaderProvider(TAG, appCheckTokenProvider, internalAuthProvider),
    ),
  )

  /**
   * Generates images, returning the result directly to the caller.
   *
   * Any failure, including a response in which every image was filtered out, is rethrown through
   * [FirebaseVertexAIException.from].
   *
   * @param prompt The input(s) given to the model as a prompt.
   * @return The generated images as inline data.
   */
  public suspend fun generateImages(prompt: String): ImagenGenerationResponse<ImagenInlineImage> =
    try {
      controller
        .generateImage(constructRequest(prompt, null, generationConfig))
        .validate()
        .toPublicInline()
    } catch (e: Throwable) {
      throw FirebaseVertexAIException.from(e)
    }

  /**
   * Generates images, storing the result in Google Cloud Storage and returning a URL.
   *
   * Any failure, including a response in which every image was filtered out, is rethrown through
   * [FirebaseVertexAIException.from].
   *
   * @param prompt The input(s) given to the model as a prompt.
   * @param gcsUri Specifies where in Google Cloud Storage to store the image (for example, a
   * specific bucket or folder).
   * @return The generated images as Google Cloud Storage references.
   */
  public suspend fun generateImages(
    prompt: String,
    gcsUri: String,
  ): ImagenGenerationResponse<ImagenGCSImage> =
    try {
      controller
        .generateImage(constructRequest(prompt, gcsUri, generationConfig))
        .validate()
        .toPublicGCS()
    } catch (e: Throwable) {
      throw FirebaseVertexAIException.from(e)
    }

  /**
   * Builds the backend request for a single [prompt].
   *
   * Every generation option is read from [config] (never from the class property directly) so the
   * request always reflects exactly the configuration the caller passed in. Safety options come
   * from [safetySettings].
   *
   * @param prompt The text prompt to send.
   * @param gcsUri Destination passed as `storageUri`, or `null` when no storage URI is requested.
   * @param config The generation options to apply, or `null` to leave them unset.
   */
  private fun constructRequest(
    prompt: String,
    gcsUri: String?,
    config: ImagenGenerationConfig?,
  ): GenerateImageRequest =
    GenerateImageRequest(
      listOf(GenerateImageRequest.ImagenPrompt(prompt)),
      GenerateImageRequest.ImagenParameters(
        // Request a single image when the caller did not specify a count.
        sampleCount = config?.numberOfImages ?: 1,
        // Always ask for the filter reason; validate() reads it from blocked responses.
        includeRaiReason = true,
        addWatermark = config?.addWatermark,
        personGeneration = safetySettings?.personFilterLevel?.internalVal,
        negativePrompt = config?.negativePrompt,
        safetySetting = safetySettings?.safetyFilterLevel?.internalVal,
        storageUri = gcsUri,
        aspectRatio = config?.aspectRatio?.internalVal,
        imageOutputOptions = config?.imageFormat?.toInternal(),
      ),
    )

  internal companion object {
    private val TAG = ImagenModel::class.java.simpleName

    /** Fallback message used when a blocked response does not carry its own filter reason. */
    internal const val DEFAULT_FILTERED_ERROR =
      "Unable to show generated images. All images were filtered out because they violated Vertex AI's usage guidelines. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback."
  }
}

/**
 * Verifies that the response contains at least one prediction with a MIME type.
 *
 * When no prediction has a MIME type, the response is treated as blocked and a
 * [ContentBlockedException] is thrown. Its message is the first non-null `raiFilteredReason` found
 * in the predictions, or [ImagenModel.DEFAULT_FILTERED_ERROR] when no prediction carries one (for
 * example, when the prediction list is empty).
 *
 * @return This response, unchanged, when at least one prediction has a MIME type.
 * @throws ContentBlockedException if no prediction has a MIME type.
 */
private fun ImagenGenerationResponse.Internal.validate(): ImagenGenerationResponse.Internal {
  if (predictions.none { it.mimeType != null }) {
    // firstOrNull (rather than first) keeps an absent reason from surfacing as a
    // NoSuchElementException and lets the default message actually be used.
    val filteredReason = predictions.firstOrNull { it.raiFilteredReason != null }?.raiFilteredReason
    throw ContentBlockedException(message = filteredReason ?: ImagenModel.DEFAULT_FILTERED_ERROR)
  }
  return this
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.google.firebase.vertexai.type.CountTokensResponse
import com.google.firebase.vertexai.type.FinishReason
import com.google.firebase.vertexai.type.GRpcErrorResponse
import com.google.firebase.vertexai.type.GenerateContentResponse
import com.google.firebase.vertexai.type.ImagenGenerationResponse
import com.google.firebase.vertexai.type.RequestOptions
import com.google.firebase.vertexai.type.Response
import io.ktor.client.HttpClient
Expand Down Expand Up @@ -58,12 +59,15 @@ import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json

/**
 * Module-internal [Json] configuration.
 *
 * NOTE(review): the experimental opt-in is presumably required by `explicitNulls`; confirm
 * against the kotlinx.serialization version in use before removing it.
 */
@OptIn(ExperimentalSerializationApi::class)
internal val JSON = Json {
  // Null handling.
  explicitNulls = false

  // Input tolerance.
  isLenient = true
  ignoreUnknownKeys = true

  // Output formatting.
  prettyPrint = false
}

/**
Expand Down Expand Up @@ -122,6 +126,19 @@ internal constructor(
throw FirebaseCommonAIException.from(e)
}

/**
 * Posts [request] to the model's `:predict` endpoint and decodes the reply.
 *
 * The HTTP response is passed through `validateResponse` before its body is deserialized. Any
 * failure along the way is rethrown through [FirebaseCommonAIException.from].
 *
 * @param request The image generation request to send.
 * @return The decoded internal response.
 */
suspend fun generateImage(request: GenerateImageRequest): ImagenGenerationResponse.Internal {
  try {
    val predictUrl = "${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:predict"
    val httpResponse =
      client.post(predictUrl) {
        applyCommonConfiguration(request)
        applyHeaderProvider()
      }
    // Reject error responses before attempting to decode the body.
    validateResponse(httpResponse)
    return httpResponse.body<ImagenGenerationResponse.Internal>()
  } catch (e: Throwable) {
    throw FirebaseCommonAIException.from(e)
  }
}

fun generateContentStream(
request: GenerateContentRequest
): Flow<GenerateContentResponse.Internal> =
Expand Down Expand Up @@ -151,6 +168,7 @@ internal constructor(
when (request) {
is GenerateContentRequest -> setBody<GenerateContentRequest>(request)
is CountTokensRequest -> setBody<CountTokensRequest>(request)
is GenerateImageRequest -> setBody<GenerateImageRequest>(request)
}
contentType(ContentType.Application.Json)
header("x-goog-api-key", key)
Expand Down Expand Up @@ -258,6 +276,9 @@ private suspend fun validateResponse(response: HttpResponse) {
if (message.contains("quota")) {
throw QuotaExceededException(message)
}
if (message.contains("The prompt could not be submitted")) {
throw PromptBlockedException(message)
}
getServiceDisabledErrorDetailsOrNull(error)?.let {
val errorMessage =
if (it.metadata?.get("service") == "firebasevertexai.googleapis.com") {
Expand Down
Loading
Loading