Skip to content

Commit

Permalink
feat: LVM - Added support for Image Generation models
Browse files Browse the repository at this point in the history
Features:

* Generate images from text prompt (prompt, negative prompt, width, height, seed, guidance scale)
* Edit an existing images using a text prompt (and optional mask)
* Upscale image
* Show image (works in notebook environments)
* Save image (saved image also includes the image generation parameters)

Example usage:
```python
model = ImageGenerationModel.from_pretrained("imagegeneration@002")
images = model.generate_images(
    prompt="Astronaut riding a horse",
    # Optional:
    number_of_images=1,
    width=1024,
    height=768,
    seed=1,
    guidance_scale=15,
)
images[0].show()
images[0].save("image1.png")
```
PiperOrigin-RevId: 557736987
  • Loading branch information
Ark-kun authored and copybara-github committed Aug 17, 2023
1 parent caee592 commit b3729c1
Show file tree
Hide file tree
Showing 4 changed files with 793 additions and 1 deletion.
74 changes: 74 additions & 0 deletions tests/system/aiplatform/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,77 @@ def test_multi_modal_embedding_model(self):
# The service is expected to return the embeddings of size 1408
assert len(embeddings.image_embedding) == 1408
assert len(embeddings.text_embedding) == 1408

def test_image_generation_model_generate_images(self):
"""Tests the image generation model generating images."""
model = vision_models.ImageGenerationModel.from_pretrained(
"imagegeneration@001"
)

width = 1024
height = 768
number_of_images = 4
seed = 1
guidance_scale = 15

prompt1 = "Astronaut riding a horse"
negative_prompt1 = "bad quality"
image_response = model.generate_images(
prompt=prompt1,
# Optional:
negative_prompt=negative_prompt1,
number_of_images=number_of_images,
width=width,
height=height,
seed=seed,
guidance_scale=guidance_scale,
)

assert len(image_response.images) == number_of_images
for idx, image in enumerate(image_response):
assert image._pil_image.size == (width, height)
assert image.generation_parameters
assert image.generation_parameters["prompt"] == prompt1
assert image.generation_parameters["negative_prompt"] == negative_prompt1
assert image.generation_parameters["width"] == width
assert image.generation_parameters["height"] == height
assert image.generation_parameters["seed"] == seed
assert image.generation_parameters["guidance_scale"] == guidance_scale
assert image.generation_parameters["index_of_image_in_batch"] == idx

# Test saving and loading images
with tempfile.TemporaryDirectory() as temp_dir:
image_path = os.path.join(temp_dir, "image.png")
image_response[0].save(location=image_path)
image1 = vision_models.GeneratedImage.load_from_file(image_path)
assert image1._pil_image.size == (width, height)
assert image1.generation_parameters
assert image1.generation_parameters["prompt"] == prompt1

# Preparing mask
mask_path = os.path.join(temp_dir, "mask.png")
mask_pil_image = PIL_Image.new(mode="RGB", size=(width, height))
mask_pil_image.save(mask_path, format="PNG")
mask_image = vision_models.Image.load_from_file(mask_path)

# Test generating image from base image
prompt2 = "Ancient book style"
image_response2 = model.edit_image(
prompt=prompt2,
# Optional:
number_of_images=number_of_images,
seed=seed,
guidance_scale=guidance_scale,
base_image=image1,
mask=mask_image,
)
assert len(image_response2.images) == number_of_images
for idx, image in enumerate(image_response2):
assert image._pil_image.size == (width, height)
assert image.generation_parameters
assert image.generation_parameters["prompt"] == prompt2
assert image.generation_parameters["seed"] == seed
assert image.generation_parameters["guidance_scale"] == guidance_scale
assert image.generation_parameters["index_of_image_in_batch"] == idx
assert "base_image_hash" in image.generation_parameters
assert "mask_hash" in image.generation_parameters
Loading

0 comments on commit b3729c1

Please sign in to comment.