feat: Add interfaces for batch materialization engine #2901

Merged: 11 commits merged on Jul 6, 2022
Changes from 4 commits
13 changes: 13 additions & 0 deletions sdk/python/feast/infra/materialization/__init__.py
@@ -0,0 +1,13 @@
from .batch_materialization_engine import (
    BatchMaterializationEngine,
    MaterializationJob,
    MaterializationTask,
)
from .local_engine import LocalMaterializationEngine

__all__ = [
    "MaterializationJob",
    "MaterializationTask",
    "BatchMaterializationEngine",
    "LocalMaterializationEngine",
]
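With these re-exports in place, downstream code can pull the interfaces from the package root rather than from the individual modules. A minimal illustrative import (not part of the diff itself):

    from feast.infra.materialization import (
        BatchMaterializationEngine,
        LocalMaterializationEngine,
        MaterializationJob,
        MaterializationTask,
    )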
61 changes: 61 additions & 0 deletions sdk/python/feast/infra/materialization/batch_materialization_engine.py
@@ -0,0 +1,61 @@
import dataclasses
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Callable, List, Optional, Union

from tqdm import tqdm

from feast.batch_feature_view import BatchFeatureView
from feast.infra.offline_stores.offline_store import OfflineStore
from feast.infra.online_stores.online_store import OnlineStore
from feast.repo_config import RepoConfig
from feast.stream_feature_view import StreamFeatureView


@dataclasses.dataclass
class MaterializationTask:
    project: str
    feature_view: Union[BatchFeatureView, StreamFeatureView]
    start_time: datetime
    end_time: datetime
    tqdm_builder: Callable[[int], tqdm]


class MaterializationJob(ABC):
    task: MaterializationTask

    @abstractmethod
    def status(self) -> str:
        ...

    @abstractmethod
    def should_be_retried(self) -> bool:
        ...

    @abstractmethod
    def job_id(self) -> str:
        ...

    @abstractmethod
    def url(self) -> Optional[str]:
        ...


class BatchMaterializationEngine(ABC):
    def __init__(
        self,
        *,
        repo_config: RepoConfig,
        offline_store: OfflineStore,
        online_store: OnlineStore,
        **kwargs,
    ):
        self.repo_config = repo_config
        self.offline_store = offline_store
        self.online_store = online_store

    @abstractmethod
    def materialize(
        self, registry, tasks: List[MaterializationTask]
    ) -> List[MaterializationJob]:
        ...
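To see what the new interface asks of an implementation, the sketch below shows the minimal surface a third-party engine would have to fill in. It is illustrative only: the NoOp* names are invented here and are not part of this PR; only the imported interfaces come from the files above.

    # Hypothetical sketch of a custom engine built on the new interfaces.
    # NoOpMaterializationJob / NoOpEngine are illustrative names, not part of this PR.
    from typing import List, Optional

    from feast.infra.materialization import (
        BatchMaterializationEngine,
        MaterializationJob,
        MaterializationTask,
    )


    class NoOpMaterializationJob(MaterializationJob):
        def __init__(self, job_id: str) -> None:
            self._job_id = job_id

        def status(self) -> str:
            return "success"

        def should_be_retried(self) -> bool:
            return False

        def job_id(self) -> str:
            return self._job_id

        def url(self) -> Optional[str]:
            return None


    class NoOpEngine(BatchMaterializationEngine):
        def materialize(
            self, registry, tasks: List[MaterializationTask]
        ) -> List[MaterializationJob]:
            # A real engine would read each feature view's data from
            # self.offline_store and write it to self.online_store;
            # this sketch only records a job id per task.
            return [
                NoOpMaterializationJob(
                    job_id=f"{t.feature_view.name}-{t.start_time}-{t.end_time}"
                )
                for t in tasks
            ]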
302 changes: 302 additions & 0 deletions sdk/python/feast/infra/materialization/local_engine.py
@@ -0,0 +1,302 @@
from datetime import datetime
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import dask.dataframe as dd
import pandas as pd
import pyarrow as pa
from tqdm import tqdm

from feast import (
    BatchFeatureView,
    Entity,
    FeatureView,
    RepoConfig,
    StreamFeatureView,
    ValueType,
)
from feast.feature_view import DUMMY_ENTITY_ID
from feast.infra.offline_stores.offline_store import OfflineStore
from feast.infra.online_stores.online_store import OnlineStore
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import FeastConfigBaseModel
from feast.type_map import python_values_to_proto_values

from .batch_materialization_engine import (
    BatchMaterializationEngine,
    MaterializationJob,
    MaterializationTask,
)

DEFAULT_BATCH_SIZE = 10_000


class LocalMaterializationEngineConfig(FeastConfigBaseModel):
"""Batch Materialization Engine config for local in-process engine"""

type: Literal["local"] = "local"
""" Type selector"""


class LocalMaterializationJob(MaterializationJob):
    def __init__(self, job_id: str) -> None:
        super().__init__()
        self._job_id: str = job_id

    def status(self) -> str:
        return "success"

    def should_be_retried(self) -> bool:
        return False

    def job_id(self) -> str:
        # Return the stored id; calling self.job_id() here would recurse forever.
        return self._job_id

    def url(self) -> Optional[str]:
        return None


class LocalMaterializationEngine(BatchMaterializationEngine):
    def __init__(
        self,
        *,
        repo_config: RepoConfig,
        offline_store: OfflineStore,
        online_store: OnlineStore,
        **kwargs,
    ):
        super().__init__(
            repo_config=repo_config,
            offline_store=offline_store,
            online_store=online_store,
            **kwargs,
        )

    def materialize(
        self, registry, tasks: List[MaterializationTask]
    ) -> List[MaterializationJob]:
        return [
            self.materialize_one(
                registry,
                task.feature_view,
                task.start_time,
                task.end_time,
                task.project,
                task.tqdm_builder,
            )
            for task in tasks
        ]

    def materialize_one(
        self,
        registry,
        feature_view: Union[BatchFeatureView, StreamFeatureView],
        start_date: datetime,
        end_date: datetime,
        project: str,
        tqdm_builder: Callable[[int], tqdm],
    ):
        entities = []
        for entity_name in feature_view.entities:
            entities.append(registry.get_entity(entity_name, project))

        (
            join_key_columns,
            feature_name_columns,
            timestamp_field,
            created_timestamp_column,
        ) = _get_column_names(feature_view, entities)

        offline_job = self.offline_store.pull_latest_from_table_or_query(
            config=self.repo_config,
            data_source=feature_view.batch_source,
            join_key_columns=join_key_columns,
            feature_name_columns=feature_name_columns,
            timestamp_field=timestamp_field,
            created_timestamp_column=created_timestamp_column,
            start_date=start_date,
            end_date=end_date,
        )

        table = offline_job.to_arrow()

        if feature_view.batch_source.field_mapping is not None:
            table = _run_field_mapping(table, feature_view.batch_source.field_mapping)

        join_key_to_value_type = {
            entity.name: entity.dtype.to_value_type()
            for entity in feature_view.entity_columns
        }

        with tqdm_builder(table.num_rows) as pbar:
            for batch in table.to_batches(DEFAULT_BATCH_SIZE):
                rows_to_write = _convert_arrow_to_proto(
                    batch, feature_view, join_key_to_value_type
                )
                self.online_store.online_write_batch(
                    self.repo_config,
                    feature_view,
                    rows_to_write,
                    lambda x: pbar.update(x),
                )
        job_id = f"{feature_view.name}-{start_date}-{end_date}"
        return LocalMaterializationJob(job_id=job_id)


def _get_column_names(
    feature_view: FeatureView, entities: List[Entity]
) -> Tuple[List[str], List[str], str, Optional[str]]:
    """
    If a field mapping exists, run it in reverse on the join keys,
    feature names, event timestamp column, and created timestamp column
    to get the names of the relevant columns in the offline feature store table.

    Returns:
        Tuple containing the list of reverse-mapped join_keys,
        reverse-mapped feature names, reverse-mapped event timestamp column,
        and reverse-mapped created timestamp column that will be passed into
        the query to the offline store.
    """
    # if we have mapped fields, use the original field names in the call to the offline store
    timestamp_field = feature_view.batch_source.timestamp_field
    feature_names = [feature.name for feature in feature_view.features]
    created_timestamp_column = feature_view.batch_source.created_timestamp_column
    join_keys = [
        entity.join_key for entity in entities if entity.join_key != DUMMY_ENTITY_ID
    ]
    if feature_view.batch_source.field_mapping is not None:
        reverse_field_mapping = {
            v: k for k, v in feature_view.batch_source.field_mapping.items()
        }
        timestamp_field = (
            reverse_field_mapping[timestamp_field]
            if timestamp_field in reverse_field_mapping.keys()
            else timestamp_field
        )
        created_timestamp_column = (
            reverse_field_mapping[created_timestamp_column]
            if created_timestamp_column
            and created_timestamp_column in reverse_field_mapping.keys()
            else created_timestamp_column
        )
        join_keys = [
            reverse_field_mapping[col] if col in reverse_field_mapping.keys() else col
            for col in join_keys
        ]
        feature_names = [
            reverse_field_mapping[col] if col in reverse_field_mapping.keys() else col
            for col in feature_names
        ]

    # We need to exclude join keys and timestamp columns from the list of features, after they are mapped to
    # their final column names via the `field_mapping` field of the source.
    feature_names = [
        name
        for name in feature_names
        if name not in join_keys
        and name != timestamp_field
        and name != created_timestamp_column
    ]
    return (
        join_keys,
        feature_names,
        timestamp_field,
        created_timestamp_column,
    )


def _run_field_mapping(table: pa.Table, field_mapping: Dict[str, str],) -> pa.Table:
    # run field mapping in the forward direction
    cols = table.column_names
    mapped_cols = [
        field_mapping[col] if col in field_mapping.keys() else col for col in cols
    ]
    table = table.rename_columns(mapped_cols)
    return table


def _run_dask_field_mapping(
    table: dd.DataFrame, field_mapping: Dict[str, str],
):
    if field_mapping:
        # run field mapping in the forward direction
        table = table.rename(columns=field_mapping)
        table = table.persist()

    return table


def _coerce_datetime(ts):
    """
    Depending on the underlying time resolution, arrow to_pydict() sometimes returns a pd
    timestamp type (for nanosecond resolution) and sometimes a standard python datetime
    (for microsecond resolution).
    While the pd timestamp class is a subclass of python datetime, it doesn't always behave
    the same way. We convert it to a normal datetime so that consumers downstream don't have
    to deal with these quirks.
    """
    if isinstance(ts, pd.Timestamp):
        return ts.to_pydatetime()
    else:
        return ts


def _convert_arrow_to_proto(
    table: Union[pa.Table, pa.RecordBatch],
    feature_view: FeatureView,
    join_keys: Dict[str, ValueType],
) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]:
    # Avoid ChunkedArrays, which guarantees `zero_copy_only` is available.
    if isinstance(table, pa.Table):
        table = table.to_batches()[0]

    columns = [
        (field.name, field.dtype.to_value_type()) for field in feature_view.features
    ] + list(join_keys.items())

    proto_values_by_column = {
        column: python_values_to_proto_values(
            table.column(column).to_numpy(zero_copy_only=False), value_type
        )
        for column, value_type in columns
    }

    entity_keys = [
        EntityKeyProto(
            join_keys=join_keys,
            entity_values=[proto_values_by_column[k][idx] for k in join_keys],
        )
        for idx in range(table.num_rows)
    ]

    # Serialize the features per row
    feature_dict = {
        feature.name: proto_values_by_column[feature.name]
        for feature in feature_view.features
    }
    features = [dict(zip(feature_dict, vars)) for vars in zip(*feature_dict.values())]

    # Convert event_timestamps
    event_timestamps = [
        _coerce_datetime(val)
        for val in pd.to_datetime(
            table.column(feature_view.batch_source.timestamp_field).to_numpy(
                zero_copy_only=False
            )
        )
    ]

    # Convert created_timestamps if they exist
    if feature_view.batch_source.created_timestamp_column:
        created_timestamps = [
            _coerce_datetime(val)
            for val in pd.to_datetime(
                table.column(
                    feature_view.batch_source.created_timestamp_column
                ).to_numpy(zero_copy_only=False)
            )
        ]
    else:
        created_timestamps = [None] * table.num_rows

    return list(zip(entity_keys, features, event_timestamps, created_timestamps))
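For context, the caller of the local engine would normally be Feast itself during materialization. The sketch below shows how a MaterializationTask would be handed to LocalMaterializationEngine; it is a hypothetical usage example, and repo_config, offline_store, online_store, registry, and driver_stats_fv are assumed stand-ins that are not defined in this PR.

    # Hypothetical usage sketch; repo_config, offline_store, online_store,
    # registry, and driver_stats_fv are assumed to already exist.
    from datetime import datetime, timedelta

    from tqdm import tqdm

    from feast.infra.materialization import MaterializationTask
    from feast.infra.materialization.local_engine import LocalMaterializationEngine

    engine = LocalMaterializationEngine(
        repo_config=repo_config,
        offline_store=offline_store,
        online_store=online_store,
    )

    task = MaterializationTask(
        project="my_project",
        feature_view=driver_stats_fv,
        start_time=datetime.utcnow() - timedelta(days=1),
        end_time=datetime.utcnow(),
        # The engine calls this with the row count and uses the result as a
        # context manager, so any tqdm-like object works here.
        tqdm_builder=lambda length: tqdm(total=length),
    )

    jobs = engine.materialize(registry, [task])
    print(jobs[0].job_id(), jobs[0].status())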