Skip to content

Commit

Permalink
feat: LVM - Added support for Videos from GCS URI for multimodal embeddings
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 606757856
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Feb 13, 2024
1 parent e35ab64 commit f3bd3bf
Show file tree
Hide file tree
Showing 5 changed files with 411 additions and 9 deletions.
11 changes: 11 additions & 0 deletions tests/system/aiplatform/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def _load_image_from_gcs(
return vision_models.Image.load_from_file(gcs_uri)


def _load_video_from_gcs(
    gcs_uri: str = "gs://cloud-samples-data/vertex-ai-vision/highway_vehicles.mp4",
) -> vision_models.Video:
    """Loads the sample test video from the given GCS URI."""
    video = vision_models.Video.load_from_file(gcs_uri)
    return video


class VisionModelTestSuite(e2e_base.TestEndToEnd):
"""System tests for vision models."""

Expand Down Expand Up @@ -98,13 +104,18 @@ def test_multi_modal_embedding_model_with_gcs_uri(self):
"multimodalembedding@001"
)
image = _load_image_from_gcs()
video = _load_video_from_gcs()
video_segment_config = vision_models.VideoSegmentConfig()
embeddings = model.get_embeddings(
image=image,
video=video,
# Optional:
contextual_text="this is a car",
video_segment_config=video_segment_config,
)
# The service is expected to return the embeddings of size 1408
assert len(embeddings.image_embedding) == 1408
assert len(embeddings.video_embeddings[0].embedding) == 1408
assert len(embeddings.text_embedding) == 1408

def test_image_generation_model_generate_images(self):
Expand Down
212 changes: 212 additions & 0 deletions tests/unit/aiplatform/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ def generate_image_from_gcs_uri(
return ga_vision_models.Image.load_from_file(gcs_uri)


def generate_video_from_gcs_uri(
    gcs_uri: str = "gs://cloud-samples-data/vertex-ai-vision/highway_vehicles.mp4",
) -> ga_vision_models.Video:
    """Builds a Video object for unit tests from the given GCS URI."""
    video = ga_vision_models.Video.load_from_file(gcs_uri)
    return video


@pytest.mark.usefixtures("google_auth_mock")
class TestImageGenerationModels:
"""Unit tests for the image generation models."""
Expand Down Expand Up @@ -888,6 +894,212 @@ def test_image_embedding_model_with_gcs_uri(self):
assert embedding_response.image_embedding == test_embeddings
assert embedding_response.text_embedding == test_embeddings

def test_video_embedding_model_with_only_video(self):
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
),
) as mock_get_publisher_model:
model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
"multimodalembedding@001"
)

mock_get_publisher_model.assert_called_once_with(
name="publishers/google/models/multimodalembedding@001",
retry=base._DEFAULT_RETRY,
)

test_video_embeddings = [
ga_vision_models.VideoEmbedding(
start_offset_sec=0,
end_offset_sec=7,
embedding=[0, 7],
)
]

gca_predict_response = gca_prediction_service.PredictResponse()
gca_predict_response.predictions.append(
{
"videoEmbeddings": [
{
"startOffsetSec": test_video_embeddings[0].start_offset_sec,
"endOffsetSec": test_video_embeddings[0].end_offset_sec,
"embedding": test_video_embeddings[0].embedding,
}
]
}
)

video = generate_video_from_gcs_uri()

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response,
):
embedding_response = model.get_embeddings(video=video)

assert (
embedding_response.video_embeddings[0].embedding
== test_video_embeddings[0].embedding
)
assert (
embedding_response.video_embeddings[0].start_offset_sec
== test_video_embeddings[0].start_offset_sec
)
assert (
embedding_response.video_embeddings[0].end_offset_sec
== test_video_embeddings[0].end_offset_sec
)
assert not embedding_response.text_embedding
assert not embedding_response.image_embedding

def test_video_embedding_model_with_video_and_text(self):
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
),
) as mock_get_publisher_model:
model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
"multimodalembedding@001"
)

mock_get_publisher_model.assert_called_once_with(
name="publishers/google/models/multimodalembedding@001",
retry=base._DEFAULT_RETRY,
)

test_text_embedding = [0, 0]
test_video_embeddings = [
ga_vision_models.VideoEmbedding(
start_offset_sec=0,
end_offset_sec=7,
embedding=test_text_embedding,
)
]
gca_predict_response = gca_prediction_service.PredictResponse()
gca_predict_response.predictions.append(
{
"textEmbedding": test_text_embedding,
"videoEmbeddings": [
{
"startOffsetSec": test_video_embeddings[0].start_offset_sec,
"endOffsetSec": test_video_embeddings[0].end_offset_sec,
"embedding": test_video_embeddings[0].embedding,
}
],
}
)

video = generate_video_from_gcs_uri()

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response,
):
embedding_response = model.get_embeddings(
video=video, contextual_text="hello world"
)

assert (
embedding_response.video_embeddings[0].embedding
== test_video_embeddings[0].embedding
)
assert (
embedding_response.video_embeddings[0].start_offset_sec
== test_video_embeddings[0].start_offset_sec
)
assert (
embedding_response.video_embeddings[0].end_offset_sec
== test_video_embeddings[0].end_offset_sec
)
assert embedding_response.text_embedding == test_text_embedding
assert not embedding_response.image_embedding

def test_multimodal_embedding_model_with_image_video_and_text(self):
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
),
) as mock_get_publisher_model:
model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
"multimodalembedding@001"
)

mock_get_publisher_model.assert_called_once_with(
name="publishers/google/models/multimodalembedding@001",
retry=base._DEFAULT_RETRY,
)

test_embedding = [0, 0]
test_video_embeddings = [
ga_vision_models.VideoEmbedding(
start_offset_sec=0,
end_offset_sec=7,
embedding=test_embedding,
)
]
gca_predict_response = gca_prediction_service.PredictResponse()
gca_predict_response.predictions.append(
{
"textEmbedding": test_embedding,
"imageEmbedding": test_embedding,
"videoEmbeddings": [
{
"startOffsetSec": test_video_embeddings[0].start_offset_sec,
"endOffsetSec": test_video_embeddings[0].end_offset_sec,
"embedding": test_video_embeddings[0].embedding,
}
],
}
)

image = generate_image_from_file()
video = generate_video_from_gcs_uri()

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response,
):
embedding_response = model.get_embeddings(
video=video, image=image, contextual_text="hello world"
)

assert (
embedding_response.video_embeddings[0].embedding
== test_video_embeddings[0].embedding
)
assert (
embedding_response.video_embeddings[0].start_offset_sec
== test_video_embeddings[0].start_offset_sec
)
assert (
embedding_response.video_embeddings[0].end_offset_sec
== test_video_embeddings[0].end_offset_sec
)
assert embedding_response.text_embedding == test_embedding
assert embedding_response.image_embedding == test_embedding


@pytest.mark.usefixtures("google_auth_mock")
class ImageTextModelTests:
Expand Down
6 changes: 6 additions & 0 deletions vertexai/preview/vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
GeneratedImage,
MultiModalEmbeddingModel,
MultiModalEmbeddingResponse,
Video,
VideoEmbedding,
VideoSegmentConfig,
)

__all__ = [
Expand All @@ -36,4 +39,7 @@
"GeneratedImage",
"MultiModalEmbeddingModel",
"MultiModalEmbeddingResponse",
"Video",
"VideoEmbedding",
"VideoSegmentConfig",
]
6 changes: 6 additions & 0 deletions vertexai/vision_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
ImageTextModel,
MultiModalEmbeddingModel,
MultiModalEmbeddingResponse,
Video,
VideoEmbedding,
VideoSegmentConfig,
)

__all__ = [
Expand All @@ -30,4 +33,7 @@
"ImageTextModel",
"MultiModalEmbeddingModel",
"MultiModalEmbeddingResponse",
"Video",
"VideoEmbedding",
"VideoSegmentConfig",
]
Loading

0 comments on commit f3bd3bf

Please sign in to comment.