@@ -20,7 +20,23 @@ def _serialize_val(
20
20
return struct .pack ("<l" , v .int64_val ), ValueType .INT64
21
21
return struct .pack ("<q" , v .int64_val ), ValueType .INT64
22
22
else :
23
- raise ValueError (f"Value type not supported for Firestore: { v } " )
23
+ raise ValueError (f"Value type not supported for feast feature store: { v } " )
24
+
25
+
26
+ def _deserialize_value (value_type , value_bytes ) -> ValueProto :
27
+ if value_type == ValueType .INT64 :
28
+ value = struct .unpack ("<q" , value_bytes )[0 ]
29
+ return ValueProto (int64_val = value )
30
+ if value_type == ValueType .INT32 :
31
+ value = struct .unpack ("<i" , value_bytes )[0 ]
32
+ return ValueProto (int32_val = value )
33
+ elif value_type == ValueType .STRING :
34
+ value = value_bytes .decode ("utf-8" )
35
+ return ValueProto (string_val = value )
36
+ elif value_type == ValueType .BYTES :
37
+ return ValueProto (bytes_val = value_bytes )
38
+ else :
39
+ raise ValueError (f"Unsupported value type: { value_type } " )
24
40
25
41
26
42
def serialize_entity_key_prefix (entity_keys : List [str ]) -> bytes :
@@ -50,6 +66,15 @@ def serialize_entity_key(
50
66
serialize to the same byte string[1].
51
67
52
68
[1] https://developers.google.com/protocol-buffers/docs/encoding
69
+
70
+ Args:
71
+ entity_key_serialization_version: version of the entity key serialization
72
+ version 1: int64 values are serialized as 4 bytes
73
+ version 2: int64 values are serialized as 8 bytes
74
+ version 3: entity_key size is added to the serialization for deserialization purposes
75
+ entity_key: EntityKeyProto
76
+
77
+ Returns: bytes of the serialized entity key
53
78
"""
54
79
sorted_keys , sorted_values = zip (
55
80
* sorted (zip (entity_key .join_keys , entity_key .entity_values ))
@@ -58,6 +83,8 @@ def serialize_entity_key(
58
83
output : List [bytes ] = []
59
84
for k in sorted_keys :
60
85
output .append (struct .pack ("<I" , ValueType .STRING ))
86
+ if entity_key_serialization_version > 2 :
87
+ output .append (struct .pack ("<I" , len (k )))
61
88
output .append (k .encode ("utf8" ))
62
89
for v in sorted_values :
63
90
val_bytes , value_type = _serialize_val (
@@ -74,6 +101,57 @@ def serialize_entity_key(
74
101
return b"" .join (output )
75
102
76
103
104
+ def deserialize_entity_key (
105
+ serialized_entity_key : bytes , entity_key_serialization_version = 3
106
+ ) -> EntityKeyProto :
107
+ """
108
+ Deserialize entity key from a bytestring. This function can only be used with entity_key_serialization_version > 2.
109
+ Args:
110
+ entity_key_serialization_version: version of the entity key serialization
111
+ serialized_entity_key: serialized entity key bytes
112
+
113
+ Returns: EntityKeyProto
114
+
115
+ """
116
+ if entity_key_serialization_version <= 2 :
117
+ raise ValueError (
118
+ "Deserialization of entity key with version <= 2 is not supported. Please use version > 2 by setting entity_key_serialization_version=3"
119
+ )
120
+ offset = 0
121
+ keys = []
122
+ values = []
123
+ while offset < len (serialized_entity_key ):
124
+ key_type = struct .unpack_from ("<I" , serialized_entity_key , offset )[0 ]
125
+ offset += 4
126
+
127
+ # Read the length of the key
128
+ key_length = struct .unpack_from ("<I" , serialized_entity_key , offset )[0 ]
129
+ offset += 4
130
+
131
+ if key_type == ValueType .STRING :
132
+ key = struct .unpack_from (f"<{ key_length } s" , serialized_entity_key , offset )[
133
+ 0
134
+ ]
135
+ keys .append (key .decode ("utf-8" ).rstrip ("\x00 " ))
136
+ offset += key_length
137
+ else :
138
+ raise ValueError (f"Unsupported key type: { key_type } " )
139
+
140
+ (value_type ,) = struct .unpack_from ("<I" , serialized_entity_key , offset )
141
+ offset += 4
142
+
143
+ (value_length ,) = struct .unpack_from ("<I" , serialized_entity_key , offset )
144
+ offset += 4
145
+
146
+ # Read the value based on its type and length
147
+ value_bytes = serialized_entity_key [offset : offset + value_length ]
148
+ value = _deserialize_value (value_type , value_bytes )
149
+ values .append (value )
150
+ offset += value_length
151
+
152
+ return EntityKeyProto (join_keys = keys , entity_values = values )
153
+
154
+
77
155
def get_list_val_str (val ):
78
156
accept_value_types = [
79
157
"float_list_val" ,
0 commit comments