Skip to content

Commit

Permalink
rename imageModel and generateImages
Browse files Browse the repository at this point in the history
  • Loading branch information
David Motsonashvili committed Jan 28, 2025
1 parent ff56aa0 commit 58d0894
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,16 @@ internal constructor(
}

@JvmOverloads
public fun imageModel(
public fun imagenModel(
modelName: String,
generationConfig: ImagenGenerationConfig? = null,
safetySettings: ImagenSafetySettings? = null,
requestOptions: RequestOptions = RequestOptions(),
): ImageModel {
): ImagenModel {
if (location.trim().isEmpty() || location.contains("/")) {
throw InvalidLocationException(location)
}
return ImageModel(
return ImagenModel(
"projects/${firebaseApp.options.projectId}/locations/${location}/publishers/google/models/${modelName}",
firebaseApp.options.apiKey,
generationConfig,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import com.google.firebase.vertexai.type.ImagenInlineImage
import com.google.firebase.vertexai.type.ImagenSafetySettings
import com.google.firebase.vertexai.type.RequestOptions

public class ImageModel
public class ImagenModel
internal constructor(
private val modelName: String,
private val generationConfig: ImagenGenerationConfig? = null,
Expand Down Expand Up @@ -48,7 +48,7 @@ internal constructor(
),
)

public suspend fun generateImage(prompt: String): ImagenGenerationResponse<ImagenInlineImage> =
public suspend fun generateImages(prompt: String): ImagenGenerationResponse<ImagenInlineImage> =
try {
controller
.generateImage(constructRequest(prompt, null, generationConfig))
Expand All @@ -58,7 +58,7 @@ internal constructor(
throw FirebaseVertexAIException.from(e)
}

public suspend fun generateImage(
public suspend fun generateImages(
prompt: String,
gcsUri: String,
): ImagenGenerationResponse<ImagenGCSImage> =
Expand Down Expand Up @@ -93,7 +93,7 @@ internal constructor(
}

internal companion object {
private val TAG = ImageModel::class.java.simpleName
private val TAG = ImagenModel::class.java.simpleName
internal const val DEFAULT_FILTERED_ERROR =
"Unable to show generated images. All images were filtered out because they violated Vertex AI's usage guidelines. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback."
}
Expand All @@ -103,7 +103,7 @@ private fun GenerateImageResponse.validate(): GenerateImageResponse {
if (predictions.none { it.mimeType != null }) {
throw ContentBlockedException(
message = predictions.first { it.raiFilteredReason != null }.raiFilteredReason
?: ImageModel.DEFAULT_FILTERED_ERROR
?: ImagenModel.DEFAULT_FILTERED_ERROR
)
}
return this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ internal constructor(message: String, cause: Throwable? = null) :
*
* @property response The full server response.
*/
// TODO(rlazo): Add secondary constructor to pass through the message?
public class PromptBlockedException
internal constructor(
public val response: GenerateContentResponse?,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,15 +479,15 @@ internal class UnarySnapshotTests {
fun `generateImages should throw when all images filtered`() =
goldenUnaryFile("unary-failure-generate-images-all-filtered.json") {
withTimeout(testTimeout) {
shouldThrow<ContentBlockedException> { imageModel.generateImage("prompt") }
shouldThrow<ContentBlockedException> { imagenModel.generateImages("prompt") }
}
}

@Test
fun `generateImages should return when some images are filtered -- gcs`() =
goldenUnaryFile("unary-failure-generate-images-gcs-some-filtered.json") {
withTimeout(testTimeout) {
imageModel.generateImage("prompt", "gcsBucket").images.isEmpty() shouldBe false
imagenModel.generateImages("prompt", "gcsBucket").images.isEmpty() shouldBe false
}
}

Expand All @@ -498,15 +498,15 @@ internal class UnarySnapshotTests {
HttpStatusCode.BadRequest,
) {
withTimeout(testTimeout) {
shouldThrow<PromptBlockedException> { imageModel.generateImage("prompt") }
shouldThrow<PromptBlockedException> { imagenModel.generateImages("prompt") }
}
}

@Test
fun `generateImages gcs should succeed`() =
goldenUnaryFile("unary-success-generate-images-gcs.json") {
withTimeout(testTimeout) {
imageModel.generateImage("prompt", "gcsBucket").images.isEmpty() shouldBe false
imagenModel.generateImages("prompt", "gcsBucket").images.isEmpty() shouldBe false
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package com.google.firebase.vertexai.util

import com.google.firebase.vertexai.GenerativeModel
import com.google.firebase.vertexai.ImageModel
import com.google.firebase.vertexai.ImagenModel
import com.google.firebase.vertexai.common.APIController
import com.google.firebase.vertexai.type.RequestOptions
import io.kotest.matchers.collections.shouldNotBeEmpty
Expand Down Expand Up @@ -61,7 +61,7 @@ internal suspend fun ByteChannel.send(bytes: ByteArray) {
internal data class CommonTestScope(
val channel: ByteChannel,
val model: GenerativeModel,
val imageModel: ImageModel,
val imagenModel: ImagenModel,
)

/** A test that runs under a [CommonTestScope]. */
Expand Down Expand Up @@ -109,8 +109,8 @@ internal fun commonTest(
null,
)
val model = GenerativeModel("cool-model-name", controller = apiController)
val imageModel = ImageModel("cooler-model-name", controller = apiController)
CommonTestScope(channel, model, imageModel).block()
val imagenModel = ImagenModel("cooler-model-name", controller = apiController)
CommonTestScope(channel, model, imagenModel).block()
}

/**
Expand Down

0 comments on commit 58d0894

Please sign in to comment.