|
14 | 14 | from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesResponse
|
15 | 15 | from feast.protos.feast.types import Value_pb2
|
16 | 16 | from feast.repo_config import RepoConfig
|
| 17 | +from feast.types import from_value_type |
17 | 18 | from feast.value_type import ValueType
|
18 | 19 |
|
19 | 20 | from .lib.embedded import DataTable, NewOnlineFeatureService, OnlineFeatureServiceConfig
|
20 | 21 | from .lib.go import Slice_string
|
| 22 | +from .type_map import FEAST_TYPE_TO_ARROW_TYPE, arrow_array_to_array_of_proto |
21 | 23 |
|
22 | 24 | if TYPE_CHECKING:
|
23 | 25 | from feast.feature_store import FeatureStore
|
24 | 26 |
|
25 | 27 |
|
26 |
| -ARROW_TYPE_TO_PROTO_FIELD = { |
27 |
| - pa.int32(): "int32_val", |
28 |
| - pa.int64(): "int64_val", |
29 |
| - pa.float32(): "float_val", |
30 |
| - pa.float64(): "double_val", |
31 |
| - pa.bool_(): "bool_val", |
32 |
| - pa.string(): "string_val", |
33 |
| - pa.binary(): "bytes_val", |
34 |
| - pa.time32("s"): "unix_timestamp_val", |
35 |
| -} |
36 |
| - |
37 |
| -ARROW_LIST_TYPE_TO_PROTO_FIELD = { |
38 |
| - pa.int32(): "int32_list_val", |
39 |
| - pa.int64(): "int64_list_val", |
40 |
| - pa.float32(): "float_list_val", |
41 |
| - pa.float64(): "double_list_val", |
42 |
| - pa.bool_(): "bool_list_val", |
43 |
| - pa.string(): "string_list_val", |
44 |
| - pa.binary(): "bytes_list_val", |
45 |
| - pa.time32("s"): "unix_timestamp_list_val", |
46 |
| -} |
47 |
| - |
48 |
| -ARROW_LIST_TYPE_TO_PROTO_LIST_CLASS = { |
49 |
| - pa.int32(): Value_pb2.Int32List, |
50 |
| - pa.int64(): Value_pb2.Int64List, |
51 |
| - pa.float32(): Value_pb2.FloatList, |
52 |
| - pa.float64(): Value_pb2.DoubleList, |
53 |
| - pa.bool_(): Value_pb2.BoolList, |
54 |
| - pa.string(): Value_pb2.StringList, |
55 |
| - pa.binary(): Value_pb2.BytesList, |
56 |
| - pa.time32("s"): Value_pb2.Int64List, |
57 |
| -} |
58 |
| - |
59 |
| -# used for entity types only |
60 |
| -PROTO_TYPE_TO_ARROW_TYPE = { |
61 |
| - ValueType.INT32: pa.int32(), |
62 |
| - ValueType.INT64: pa.int64(), |
63 |
| - ValueType.FLOAT: pa.float32(), |
64 |
| - ValueType.DOUBLE: pa.float64(), |
65 |
| - ValueType.STRING: pa.string(), |
66 |
| - ValueType.BYTES: pa.binary(), |
67 |
| -} |
68 |
| - |
69 |
| - |
70 | 28 | class EmbeddedOnlineFeatureServer:
|
71 | 29 | def __init__(
|
72 | 30 | self, repo_path: str, repo_config: RepoConfig, feature_store: "FeatureStore"
|
@@ -179,8 +137,10 @@ def _to_arrow(value, type_hint: Optional[ValueType]) -> pa.Array:
|
179 | 137 | if isinstance(value, Value_pb2.RepeatedValue):
|
180 | 138 | _proto_to_arrow(value)
|
181 | 139 |
|
182 |
| - if type_hint in PROTO_TYPE_TO_ARROW_TYPE: |
183 |
| - return pa.array(value, PROTO_TYPE_TO_ARROW_TYPE[type_hint]) |
| 140 | + if type_hint: |
| 141 | + feast_type = from_value_type(type_hint) |
| 142 | + if feast_type in FEAST_TYPE_TO_ARROW_TYPE: |
| 143 | + return pa.array(value, FEAST_TYPE_TO_ARROW_TYPE[feast_type]) |
184 | 144 |
|
185 | 145 | return pa.array(value)
|
186 | 146 |
|
@@ -263,31 +223,9 @@ def record_batch_to_online_response(record_batch):
|
263 | 223 | [Value_pb2.Value()] * len(record_batch.columns[idx])
|
264 | 224 | )
|
265 | 225 | else:
|
266 |
| - if isinstance(field.type, pa.ListType): |
267 |
| - proto_list_class = ARROW_LIST_TYPE_TO_PROTO_LIST_CLASS[ |
268 |
| - field.type.value_type |
269 |
| - ] |
270 |
| - proto_field_name = ARROW_LIST_TYPE_TO_PROTO_FIELD[field.type.value_type] |
271 |
| - |
272 |
| - column = record_batch.columns[idx] |
273 |
| - if field.type.value_type == pa.time32("s"): |
274 |
| - column = column.cast(pa.list_(pa.int32())) |
275 |
| - |
276 |
| - for v in column.tolist(): |
277 |
| - feature_vector.values.append( |
278 |
| - Value_pb2.Value(**{proto_field_name: proto_list_class(val=v)}) |
279 |
| - ) |
280 |
| - else: |
281 |
| - proto_field_name = ARROW_TYPE_TO_PROTO_FIELD[field.type] |
282 |
| - |
283 |
| - column = record_batch.columns[idx] |
284 |
| - if field.type == pa.time32("s"): |
285 |
| - column = column.cast(pa.int32()) |
286 |
| - |
287 |
| - for v in column.tolist(): |
288 |
| - feature_vector.values.append( |
289 |
| - Value_pb2.Value(**{proto_field_name: v}) |
290 |
| - ) |
| 226 | + feature_vector.values.extend( |
| 227 | + arrow_array_to_array_of_proto(field.type, record_batch.columns[idx]) |
| 228 | + ) |
291 | 229 |
|
292 | 230 | resp.results.append(feature_vector)
|
293 | 231 | resp.metadata.feature_names.val.append(field.name)
|
|
0 commit comments