Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Support for DynamodbOnlineStoreConfig endpoint_url parameter #2485

Merged
80 changes: 57 additions & 23 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

from pydantic import StrictStr
from pydantic.typing import Literal
from pydantic.typing import Literal, Union

from feast import Entity, FeatureView, utils
from feast.infra.infra_object import DYNAMODB_INFRA_OBJECT_CLASS_TYPE, InfraObject
Expand Down Expand Up @@ -50,17 +50,20 @@ class DynamoDBOnlineStoreConfig(FeastConfigBaseModel):
type: Literal["dynamodb"] = "dynamodb"
"""Online store type selector"""

batch_size: int = 40
"""Number of items to retrieve in a DynamoDB BatchGetItem call."""

endpoint_url: Union[str, None] = None
"""DynamoDB local development endpoint URL, e.g. http://localhost:8000"""

region: StrictStr
"""AWS Region Name"""

table_name_template: StrictStr = "{project}.{table_name}"
"""DynamoDB table name template"""

sort_response: bool = True
"""Whether or not to sort BatchGetItem response."""

batch_size: int = 40
"""Number of items to retrieve in a DynamoDB BatchGetItem call."""
table_name_template: StrictStr = "{project}.{table_name}"
"""DynamoDB table name template"""


class DynamoDBOnlineStore(OnlineStore):
Expand Down Expand Up @@ -95,8 +98,12 @@ def update(
"""
online_config = config.online_store
assert isinstance(online_config, DynamoDBOnlineStoreConfig)
dynamodb_client = self._get_dynamodb_client(online_config.region)
dynamodb_resource = self._get_dynamodb_resource(online_config.region)
dynamodb_client = self._get_dynamodb_client(
online_config.region, online_config.endpoint_url
)
dynamodb_resource = self._get_dynamodb_resource(
online_config.region, online_config.endpoint_url
)

for table_instance in tables_to_keep:
try:
Expand Down Expand Up @@ -141,7 +148,9 @@ def teardown(
"""
online_config = config.online_store
assert isinstance(online_config, DynamoDBOnlineStoreConfig)
dynamodb_resource = self._get_dynamodb_resource(online_config.region)
dynamodb_resource = self._get_dynamodb_resource(
online_config.region, online_config.endpoint_url
)

for table in tables:
_delete_table_idempotent(
Expand Down Expand Up @@ -175,7 +184,9 @@ def online_write_batch(
"""
online_config = config.online_store
assert isinstance(online_config, DynamoDBOnlineStoreConfig)
dynamodb_resource = self._get_dynamodb_resource(online_config.region)
dynamodb_resource = self._get_dynamodb_resource(
online_config.region, online_config.endpoint_url
)

table_instance = dynamodb_resource.Table(
_get_table_name(online_config, config, table)
Expand Down Expand Up @@ -217,7 +228,9 @@ def online_read(
"""
online_config = config.online_store
assert isinstance(online_config, DynamoDBOnlineStoreConfig)
dynamodb_resource = self._get_dynamodb_resource(online_config.region)
dynamodb_resource = self._get_dynamodb_resource(
online_config.region, online_config.endpoint_url
)
table_instance = dynamodb_resource.Table(
_get_table_name(online_config, config, table)
)
Expand Down Expand Up @@ -260,14 +273,16 @@ def online_read(
result.extend(batch_size_nones)
return result

def _get_dynamodb_client(self, region: str):
def _get_dynamodb_client(self, region: str, endpoint_url: Optional[str] = None):
if self._dynamodb_client is None:
self._dynamodb_client = _initialize_dynamodb_client(region)
self._dynamodb_client = _initialize_dynamodb_client(region, endpoint_url)
return self._dynamodb_client

def _get_dynamodb_resource(self, region: str):
def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None):
if self._dynamodb_resource is None:
self._dynamodb_resource = _initialize_dynamodb_resource(region)
self._dynamodb_resource = _initialize_dynamodb_resource(
region, endpoint_url
)
return self._dynamodb_resource

def _sort_dynamodb_response(self, responses: list, order: list):
Expand All @@ -285,12 +300,12 @@ def _sort_dynamodb_response(self, responses: list, order: list):
return table_responses_ordered


def _initialize_dynamodb_client(region: str):
return boto3.client("dynamodb", region_name=region)
def _initialize_dynamodb_client(region: str, endpoint_url: Optional[str] = None):
return boto3.client("dynamodb", region_name=region, endpoint_url=endpoint_url)


def _initialize_dynamodb_resource(region: str):
return boto3.resource("dynamodb", region_name=region)
def _initialize_dynamodb_resource(region: str, endpoint_url: Optional[str] = None):
return boto3.resource("dynamodb", region_name=region, endpoint_url=endpoint_url)


# TODO(achals): This form of user-facing templating is experimental.
Expand Down Expand Up @@ -327,13 +342,20 @@ class DynamoDBTable(InfraObject):
Attributes:
name: The name of the table.
region: The region of the table.
endpoint_url: Local DynamoDB Endpoint Url.
_dynamodb_client: Boto3 DynamoDB client.
_dynamodb_resource: Boto3 DynamoDB resource.
"""

region: str
endpoint_url = None
_dynamodb_client = None
_dynamodb_resource = None

def __init__(self, name: str, region: str):
def __init__(self, name: str, region: str, endpoint_url: Optional[str] = None):
super().__init__(name)
self.region = region
self.endpoint_url = endpoint_url

def to_infra_object_proto(self) -> InfraObjectProto:
dynamodb_table_proto = self.to_proto()
Expand Down Expand Up @@ -362,8 +384,8 @@ def from_proto(dynamodb_table_proto: DynamoDBTableProto) -> Any:
)

def update(self):
dynamodb_client = _initialize_dynamodb_client(region=self.region)
dynamodb_resource = _initialize_dynamodb_resource(region=self.region)
dynamodb_client = self._get_dynamodb_client(self.region, self.endpoint_url)
dynamodb_resource = self._get_dynamodb_resource(self.region, self.endpoint_url)

try:
dynamodb_resource.create_table(
Expand All @@ -384,5 +406,17 @@ def update(self):
dynamodb_client.get_waiter("table_exists").wait(TableName=f"{self.name}")

def teardown(self):
dynamodb_resource = _initialize_dynamodb_resource(region=self.region)
dynamodb_resource = self._get_dynamodb_resource(self.region, self.endpoint_url)
_delete_table_idempotent(dynamodb_resource, self.name)

def _get_dynamodb_client(self, region: str, endpoint_url: Optional[str] = None):
    """Return this object's DynamoDB client, creating it on first use.

    Args:
        region: Region name forwarded to ``_initialize_dynamodb_client``.
        endpoint_url: Optional endpoint URL forwarded alongside the region.

    Returns:
        The client stored on ``self._dynamodb_client``. Once a client has
        been stored, the arguments of later calls are not consulted.
    """
    existing = self._dynamodb_client
    if existing is not None:
        return existing

    created = _initialize_dynamodb_client(region, endpoint_url)
    self._dynamodb_client = created
    return created

def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None):
    """Return this object's DynamoDB resource, creating it on first use.

    Args:
        region: Region name forwarded to ``_initialize_dynamodb_resource``.
        endpoint_url: Optional endpoint URL forwarded alongside the region.

    Returns:
        The resource stored on ``self._dynamodb_resource``. Once a resource
        has been stored, the arguments of later calls are not consulted.
    """
    existing = self._dynamodb_resource
    if existing is not None:
        return existing

    created = _initialize_dynamodb_resource(region, endpoint_url)
    self._dynamodb_resource = created
    return created
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from feast.infra.online_stores.dynamodb import (
DynamoDBOnlineStore,
DynamoDBOnlineStoreConfig,
DynamoDBTable,
)
from feast.repo_config import RepoConfig
from tests.utils.online_store_utils import (
Expand Down Expand Up @@ -38,6 +39,121 @@ def repo_config():
)


def test_online_store_config_default():
    """Check the field values of a DynamoDBOnlineStoreConfig built from a region only."""
    region_name = "us-west-2"
    config = DynamoDBOnlineStoreConfig(region=region_name)

    # Only `region` was supplied, so every other field shows its default.
    assert config.type == "dynamodb"
    assert config.batch_size == 40
    assert config.endpoint_url is None
    assert config.region == region_name
    assert config.sort_response is True
    assert config.table_name_template == "{project}.{table_name}"


def test_dynamodb_table_default_params():
    """Check the attributes of a DynamoDBTable built from a name and region only."""
    expected_name = "dynamodb-test"
    expected_region = "us-west-2"

    table = DynamoDBTable(expected_name, expected_region)

    assert table.name == expected_name
    assert table.region == expected_region
    # No endpoint was passed, and no boto3 objects have been requested yet.
    assert table.endpoint_url is None
    assert table._dynamodb_client is None
    assert table._dynamodb_resource is None


def test_online_store_config_custom_params():
    """Check that explicitly supplied DynamoDBOnlineStoreConfig values are kept."""
    overrides = {
        "batch_size": 20,
        "endpoint_url": "http://localhost:8000",
        "region": "us-west-2",
        "sort_response": False,
        "table_name_template": "feast_test.dynamodb_table",
    }

    config = DynamoDBOnlineStoreConfig(**overrides)

    # `type` was not overridden, so it keeps its selector value.
    assert config.type == "dynamodb"
    # Every overridden field must read back exactly as supplied.
    for field_name, expected_value in overrides.items():
        assert getattr(config, field_name) == expected_value


def test_dynamodb_table_custom_params():
    """Check the attributes of a DynamoDBTable built with an explicit endpoint URL."""
    expected_name = "dynamodb-test"
    expected_region = "us-west-2"
    expected_endpoint = "http://localhost:8000"

    table = DynamoDBTable(expected_name, expected_region, expected_endpoint)

    assert table.name == expected_name
    assert table.region == expected_region
    assert table.endpoint_url == expected_endpoint
    # No boto3 objects have been requested yet, so both slots are still None.
    assert table._dynamodb_client is None
    assert table._dynamodb_resource is None


def test_online_store_config_dynamodb_client():
    """Check that the store's DynamoDB client reports the configured region and endpoint."""
    region_name = "us-west-2"
    local_endpoint = "http://localhost:8000"
    store = DynamoDBOnlineStore()
    store_config = DynamoDBOnlineStoreConfig(
        region=region_name,
        endpoint_url=local_endpoint,
    )

    client = store._get_dynamodb_client(
        store_config.region,
        store_config.endpoint_url,
    )

    client_meta = client.meta
    assert client_meta.region_name == region_name
    assert client_meta.endpoint_url == local_endpoint


def test_dynamodb_table_dynamodb_client():
    """Check that a DynamoDBTable's client reports the table's region and endpoint."""
    region_name = "us-west-2"
    local_endpoint = "http://localhost:8000"
    table = DynamoDBTable("dynamodb-test", region_name, local_endpoint)

    client = table._get_dynamodb_client(table.region, table.endpoint_url)

    client_meta = client.meta
    assert client_meta.region_name == region_name
    assert client_meta.endpoint_url == local_endpoint


def test_online_store_config_dynamodb_resource():
    """Check that the store's DynamoDB resource reports the configured region and endpoint."""
    region_name = "us-west-2"
    local_endpoint = "http://localhost:8000"
    store = DynamoDBOnlineStore()
    store_config = DynamoDBOnlineStoreConfig(
        region=region_name,
        endpoint_url=local_endpoint,
    )

    resource = store._get_dynamodb_resource(
        store_config.region,
        store_config.endpoint_url,
    )

    # The settings are read from the client that sits under the resource.
    underlying_meta = resource.meta.client.meta
    assert underlying_meta.region_name == region_name
    assert underlying_meta.endpoint_url == local_endpoint


def test_dynamodb_table_dynamodb_resource():
    """Check that a DynamoDBTable's resource reports the table's region and endpoint."""
    region_name = "us-west-2"
    local_endpoint = "http://localhost:8000"
    table = DynamoDBTable("dynamodb-test", region_name, local_endpoint)

    resource = table._get_dynamodb_resource(table.region, table.endpoint_url)

    # The settings are read from the client that sits under the resource.
    underlying_meta = resource.meta.client.meta
    assert underlying_meta.region_name == region_name
    assert underlying_meta.endpoint_url == local_endpoint


@mock_dynamodb2
@pytest.mark.parametrize("n_samples", [5, 50, 100])
def test_online_read(repo_config, n_samples):
Expand Down