Skip to content

Commit

Permalink
fix: Pgvector patch (#4108)
Browse files Browse the repository at this point in the history
  • Loading branch information
HaoXuAI authored Apr 17, 2024
1 parent 0fb2351 commit ad45bb4
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 28 deletions.
19 changes: 12 additions & 7 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,12 +1740,14 @@ def _retrieve_online_documents(
query,
top_k,
)
document_feature_vals = [feature[2] for feature in document_features]
document_feature_distance_vals = [feature[3] for feature in document_features]
online_features_response = GetOnlineFeaturesResponse(results=[])

# TODO Refactor to better way of populating result
# TODO populate entity in the response after returning entity in document_features is supported
# TODO currently not return the vector value since it is same as feature value, if embedding is supported,
# the feature value can be raw text before embedded
document_feature_vals = [feature[2] for feature in document_features]
document_feature_distance_vals = [feature[4] for feature in document_features]
online_features_response = GetOnlineFeaturesResponse(results=[])
self._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={requested_feature: document_feature_vals},
Expand Down Expand Up @@ -1979,7 +1981,7 @@ def _retrieve_from_online_store(
requested_feature: str,
query: List[float],
top_k: int,
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value]]:
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value, Value]]:
"""
Search and return document features from the online document store.
"""
Expand All @@ -1994,19 +1996,22 @@ def _retrieve_from_online_store(
read_row_protos = []
row_ts_proto = Timestamp()

for row_ts, feature_val, distance_val in documents:
for row_ts, feature_val, vector_value, distance_val in documents:
# Reset timestamp to default or update if row_ts is not None
if row_ts is not None:
row_ts_proto.FromDatetime(row_ts)

if feature_val is None or distance_val is None:
if feature_val is None or vector_value is None or distance_val is None:
feature_val = Value()
vector_value = Value()
distance_val = Value()
status = FieldStatus.NOT_FOUND
else:
status = FieldStatus.PRESENT

read_row_protos.append((row_ts_proto, status, feature_val, distance_val))
read_row_protos.append(
(row_ts_proto, status, feature_val, vector_value, distance_val)
)
return read_row_protos

@staticmethod
Expand Down
55 changes: 37 additions & 18 deletions sdk/python/feast/infra/online_stores/contrib/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,7 @@ def online_write_batch(

for feature_name, val in values.items():
vector_val = None
if (
"pgvector_enabled" in config.online_store
and config.online_store.pgvector_enabled
):
if config.online_store.pgvector_enabled:
vector_val = get_list_val_str(val)
insert_values.append(
(
Expand Down Expand Up @@ -226,10 +223,7 @@ def update(

for table in tables_to_keep:
table_name = _table_id(project, table)
if (
"pgvector_enabled" in config.online_store
and config.online_store.pgvector_enabled
):
if config.online_store.pgvector_enabled:
vector_value_type = f"vector({config.online_store.vector_len})"
else:
# keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility
Expand Down Expand Up @@ -282,7 +276,14 @@ def retrieve_online_documents(
requested_feature: str,
embedding: List[float],
top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
) -> List[
Tuple[
Optional[datetime],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
]
]:
"""
Args:
Expand All @@ -297,10 +298,7 @@ def retrieve_online_documents(
"""
project = config.project

if (
"pgvector_enabled" not in config.online_store
or not config.online_store.pgvector_enabled
):
if not config.online_store.pgvector_enabled:
raise ValueError(
"pgvector is not enabled in the online store configuration"
)
Expand All @@ -309,7 +307,12 @@ def retrieve_online_documents(
query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"

result: List[
Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]
Tuple[
Optional[datetime],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
]
] = []
with self._get_conn(config) as conn, conn.cursor() as cur:
table_name = _table_id(project, table)
Expand All @@ -322,6 +325,7 @@ def retrieve_online_documents(
SELECT
entity_key,
feature_name,
value,
vector_value,
vector_value <-> %s as distance,
event_ts FROM {table_name}
Expand All @@ -338,16 +342,31 @@ def retrieve_online_documents(
)
rows = cur.fetchall()

for entity_key, feature_name, vector_value, distance, event_ts in rows:
for (
entity_key,
feature_name,
value,
vector_value,
distance,
event_ts,
) in rows:
# TODO Deserialize entity_key to return the entity in response
# entity_key_proto = EntityKeyProto()
# entity_key_proto_bin = bytes(entity_key)

# TODO Convert to List[float] for value type proto
feature_value_proto = ValueProto(string_val=vector_value)
feature_value_proto = ValueProto()
feature_value_proto.ParseFromString(bytes(value))

vector_value_proto = ValueProto(string_val=vector_value)
distance_value_proto = ValueProto(float_val=distance)
result.append((event_ts, feature_value_proto, distance_value_proto))
result.append(
(
event_ts,
feature_value_proto,
vector_value_proto,
distance_value_proto,
)
)

return result

Expand Down
9 changes: 8 additions & 1 deletion sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,14 @@ def retrieve_online_documents(
requested_feature: str,
embedding: List[float],
top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
) -> List[
Tuple[
Optional[datetime],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
]
]:
"""
Retrieves online feature values for the specified embeddings.
Expand Down
9 changes: 8 additions & 1 deletion sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,14 @@ def retrieve_online_documents(
requested_feature: str,
query: List[float],
top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
) -> List[
Tuple[
Optional[datetime],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
]
]:
"""
Searches for the top-k nearest neighbors of the given document in the online document store.
Expand Down
9 changes: 8 additions & 1 deletion sdk/python/tests/foo_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,12 @@ def retrieve_online_documents(
requested_feature: str,
query: List[float],
top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
) -> List[
Tuple[
Optional[datetime],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
]
]:
return []

0 comments on commit ad45bb4

Please sign in to comment.