Skip to content

Commit

Permalink
ML-4171: Redis target: move to one redis-key per entity record
Browse files Browse the repository at this point in the history
  • Loading branch information
alxtkr77 committed Jul 17, 2023
1 parent 5f4b4ad commit 55ef0cd
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 54 deletions.
2 changes: 1 addition & 1 deletion integration/test_flow_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _get_redis_kv_all_attrs(setup_teardown_test: ContextForTests, key: str):
hash_key = RedisDriver.make_key("storey-test:", table_name, key)
redis_key = RedisDriver._static_data_key(hash_key)
redis_fake_server = setup_teardown_test.redis_fake_server
values = get_redis_client(redis_fake_server=redis_fake_server).hgetall(redis_key)
_, values = get_redis_client(redis_fake_server=redis_fake_server).hscan(redis_key, 0, match=f"[^{chr(0x1)}]*")
return {
RedisDriver.convert_to_str(key): RedisDriver.convert_redis_value_to_python_obj(val)
for key, val in values.items()
Expand Down
3 changes: 1 addition & 2 deletions integration/test_redis_specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def test_redis_driver_write(redis):
table_name = f"{table_name}/"
hash_key = RedisDriver.make_key("storey:", table_name, "key")
redis_key = RedisDriver._static_data_key(hash_key)

data = driver.redis.hgetall(redis_key)
_, data = driver.redis.hscan(redis_key, 0, match=f"[^{driver.INTERFNAL_FIELD_PREFIX}]*")
data_strings = {}
for key, val in data.items():
if isinstance(key, bytes):
Expand Down
133 changes: 82 additions & 51 deletions storey/redis_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@ def __init__(self, redis_url=None):


class RedisDriver(NeedsRedisAccess, Driver):
INTERFNAL_FIELD_PREFIX = chr(0x1)
REDIS_TIMEOUT = 5 # Seconds
REDIS_WATCH_INTERVAL = 1 # Seconds
DATETIME_FIELD_PREFIX = "_dt:"
TIMEDELTA_FIELD_PREFIX = "_td:"
DEFAULT_KEY_PREFIX = "storey:"
AGGREGATION_ATTRIBUTE_PREFIX = "moving_windows"
AGGREGATION_TIME_ATTRIBUTE_PREFIX = "_"
AGGREGATION_PREFIXES = AGGREGATION_TIME_ATTRIBUTE_PREFIX
AGGREGATION_ATTRIBUTE_PREFIX = INTERFNAL_FIELD_PREFIX + "aggr_"
AGGREGATION_TIME_ATTRIBUTE_PREFIX = INTERFNAL_FIELD_PREFIX + "mtaggr_"
OBJECT_MTIME_ATTRIBUTE_PREFIX = INTERFNAL_FIELD_PREFIX + "_mtime_"

def __init__(
self,
Expand All @@ -71,7 +72,7 @@ def __init__(
self._redis = None

self._key_prefix = key_prefix if key_prefix is not None else self.DEFAULT_KEY_PREFIX
self._mtime_name = "$_mtime_"
self._mtime_name = self.OBJECT_MTIME_ATTRIBUTE_PREFIX

@staticmethod
def asyncify(fn):
Expand Down Expand Up @@ -237,8 +238,7 @@ def _build_feature_store_lua_update_script(

# Static attributes, like "name," "age," -- everything that isn't an agg.
if additional_data:
redis_keys_involved.append(self._static_data_key(redis_key_prefix))
additional_data_lua_script = f'local additional_data_key="{self._static_data_key(redis_key_prefix)}";\n'
additional_data_lua_script = f'local redis_hash="{self._static_data_key(redis_key_prefix)}";\n'
for name, value in additional_data.items():
expression_value = self._convert_python_obj_to_lua_value(value)
# NOTE: This logic assumes that static attributes we're supposed
Expand All @@ -247,15 +247,14 @@ def _build_feature_store_lua_update_script(
if expression_value:
additional_data_lua_script = (
f"{additional_data_lua_script}redis.call("
f'"HSET",additional_data_key, "{name}", {expression_value});\n'
f'"HSET",redis_hash, "{name}", {expression_value});\n'
)
else:
additional_data_lua_script = (
f'{additional_data_lua_script}redis.call("HDEL",additional_data_key, "{name}");\n'
f'{additional_data_lua_script}redis.call("HDEL",redis_hash, "{name}");\n'
)

lua_script = additional_data_lua_script
list_attribute_key_aggr = self._list_key(redis_key_prefix, RedisDriver.AGGREGATION_ATTRIBUTE_PREFIX)

if aggregation_element:
times_updates = {}
Expand All @@ -277,9 +276,7 @@ def _build_feature_store_lua_update_script(
" return result\n"
"end\n"
)
lua_script = (
f"{lua_script}local old_value;local attr_name;\n" f'local aggr_key="{list_attribute_key_aggr}";\n'
)
lua_script = f"{lua_script}local old_value;local aggr_key;\n"
lua_script = f"{lua_script}{lua_tonum_function}{lua_strToArr_funct}\n"
for name, bucket in aggregation_element.aggregation_buckets.items():
# Only save raw aggregates, not virtual
Expand All @@ -296,6 +293,9 @@ def _build_feature_store_lua_update_script(
array_time_attribute_key = self._aggregation_time_key(
redis_key_prefix, aggr_time_attribute_name
)
aggr_mtime_attr_name = (
f"{RedisDriver.AGGREGATION_TIME_ATTRIBUTE_PREFIX}{aggr_time_attribute_name}"
)

cached_time = bucket.storage_specific_cache.get(array_time_attribute_key, -1)

Expand All @@ -310,29 +310,28 @@ def _build_feature_store_lua_update_script(
aggregation_value,
) in aggregation_values.items():
list_attribute_name = f"{name}_{aggregation}_{feature_attr}"
if list_attribute_key_aggr not in redis_keys_involved:
redis_keys_involved.append(list_attribute_key_aggr)
lua_script = f'{lua_script}attr_name="{list_attribute_name}";\n'
lua_script = f'{lua_script}\
aggr_key="{RedisDriver.AGGREGATION_ATTRIBUTE_PREFIX}{list_attribute_name}";\n'

if cached_time < expected_time:
list_attribute_key = self._list_key(redis_key_prefix, list_attribute_name)
if not initialized_attributes.get(list_attribute_key, -1) == expected_time:
initialized_attributes[list_attribute_key] = expected_time
lua_script = (
f'{lua_script}local t=redis.call("GET","{array_time_attribute_key}");\n'
f'{lua_script}local t=redis.call("HGET",redis_hash,"{aggr_mtime_attr_name}");\n'
f'if (type(t)~="boolean" and (tonumber(t) < {expected_time})) then '
f'redis.call("HDEL",aggr_key, attr_name); end;\n'
f'redis.call("HDEL",redis_hash, aggr_key); end;\n'
)
default_value = self._convert_python_obj_to_redis_value(
aggregation_value.default_value
)
lua_script = (
f'{lua_script}local curr_agg=redis.call("HGET",aggr_key, attr_name)\n'
f'{lua_script}local curr_agg=redis.call("HGET",redis_hash, aggr_key)\n'
"local arr=strToArr(curr_agg);\n"
f"local org_arr_len=#arr\n"
f"for i=1,({bucket.total_number_of_buckets}-org_arr_len) \
do arr[#arr+1]={default_value};end;\n"
f'if org_arr_len ~= #arr then redis.call("HSET", aggr_key, attr_name,\
f'if org_arr_len ~= #arr then redis.call("HSET", redis_hash, aggr_key,\
table.concat(arr, ",")) end;\n'
)
if array_time_attribute_key not in times_updates:
Expand All @@ -348,18 +347,18 @@ def _build_feature_store_lua_update_script(
"old_value", aggregation_value.value
)
lua_script = (
f'{lua_script}arr=strToArr(redis.call("HGET",aggr_key, attr_name))\n'
f'{lua_script}arr=strToArr(redis.call("HGET",redis_hash, aggr_key))\n'
f"old_value=tonum(arr[{lua_index_to_update}]);\n"
f'arr[{lua_index_to_update}]=string.format("%.17f",{new_value_expression});\n'
'redis.call("HSET", aggr_key, attr_name, table.concat(arr, ","))\n'
'redis.call("HSET", redis_hash, aggr_key, table.concat(arr, ","))\n'
)

redis_keys_involved.append(array_time_attribute_key)
lua_script = f'{lua_script}redis.call("SET","{array_time_attribute_key}",{expected_time}); \n'
lua_script = f'{lua_script}\
redis.call("HSET",redis_hash,"{aggr_mtime_attr_name}",{expected_time});'
return lua_script, condition_expression, pending_updates, redis_keys_involved

async def _save_key(self, container, table_path, key, aggr_item, partitioned_by_key, additional_data):
redis_key_prefix = self._make_key(container, table_path, key)
static_redis_key_prefix = self._static_data_key(redis_key_prefix)
(
update_expression,
mtime_condition,
Expand All @@ -373,17 +372,18 @@ async def _save_key(self, container, table_path, key, aggr_item, partitioned_by_
current_time = int(time.time_ns() / 1000)
if mtime_condition is not None:
update_expression = (
f'if redis.call("HGET", "{redis_key_prefix}","{self._mtime_name}") == "{mtime_condition}" then\n'
f'{update_expression}redis.call("HSET","{redis_key_prefix}","{self._mtime_name}",{current_time});\n'
f'if redis.call("HGET", "{static_redis_key_prefix}","{self._mtime_name}") == "{mtime_condition}" then\n'
f'{update_expression}\
redis.call("HSET","{static_redis_key_prefix}","{self._mtime_name}",{current_time});\n'
"return 1;else return 0;end;"
)
else:
update_expression = (
f"{update_expression}redis.call("
f'"HSET","{redis_key_prefix}","{self._mtime_name}",{current_time});return 1;'
f'"HSET","{static_redis_key_prefix}","{self._mtime_name}",{current_time});return 1;'
)

redis_keys_involved.append(redis_key_prefix)
redis_keys_involved.append(static_redis_key_prefix)
update_ok = await self.asyncify(self.redis.eval)(
update_expression, len(redis_keys_involved), *redis_keys_involved
)
Expand All @@ -402,9 +402,8 @@ async def _save_key(self, container, table_path, key, aggr_item, partitioned_by_
) = self._build_feature_store_lua_update_script(redis_key_prefix, aggr_item, False, additional_data)
update_expression = (
f"{update_expression}redis.call("
f'"HSET","{redis_key_prefix}","{self._mtime_name}",{current_time});return 1;'
f'"HSET","{static_redis_key_prefix}","{self._mtime_name}",{current_time});return 1;'
)
redis_keys_involved.append(redis_key_prefix)

update_ok = await RedisDriver.asyncify(self.redis.eval)(
update_expression, len(redis_keys_involved), *redis_keys_involved
Expand All @@ -414,20 +413,26 @@ async def _save_key(self, container, table_path, key, aggr_item, partitioned_by_

async def _get_all_fields(self, redis_key: str):
    """Fetch every non-internal field of the given Redis hash.

    Uses HSCAN (not HGETALL) so very large hashes do not block Redis.
    Internal fields (aggregations, mtime bookkeeping) all start with
    INTERFNAL_FIELD_PREFIX (chr(0x1)), so the HSCAN ``match`` pattern
    ``[^\\x01]*`` excludes them server-side and no client-side filtering
    is needed.

    :param redis_key: full Redis key of the entity's static-data hash.
    :return: dict mapping field name (str) to the decoded Python value.
    :raises RedisError: if Redis answers with a response error.
    """
    try:
        cursor = 0
        values = {}
        # Iterate the full HSCAN cursor cycle; Redis signals completion
        # by returning cursor 0.
        while True:
            cursor, batch = await RedisDriver.asyncify(self.redis.hscan)(
                redis_key, cursor, match=f"[^{self.INTERFNAL_FIELD_PREFIX}]*"
            )
            values.update(batch)
            if cursor == 0:
                break
    except redis.ResponseError as e:
        raise RedisError(f"Failed to get key {redis_key}. Response error was: {e}")
    # Decode Redis bytes into native Python objects for the caller.
    res = {
        RedisDriver.convert_to_str(key): RedisDriver.convert_redis_value_to_python_obj(val)
        for key, val in values.items()
    }
    return res

async def _get_specific_fields(self, redis_key: str, attributes: List[str]):
non_aggregation_attrs = [name for name in attributes if not name.startswith(RedisDriver.AGGREGATION_PREFIXES)]
non_aggregation_attrs = [name for name in attributes if not name.startswith(RedisDriver.INTERFNAL_FIELD_PREFIX)]
try:
values = await RedisDriver.asyncify(self.redis.hmget)(redis_key, non_aggregation_attrs)
except redis.ResponseError as e:
Expand Down Expand Up @@ -462,8 +467,12 @@ async def _get_associated_time_attr(self, redis_key_prefix, aggr_name):
aggr_without_relevant_attr = aggr_name[:-2]
feature_name_only = aggr_without_relevant_attr[: aggr_without_relevant_attr.rindex("_")]
feature_with_relevant_attr = f"{feature_name_only}{aggr_name[-2:]}"
aggr_mtime_attribute_name = f"{RedisDriver.AGGREGATION_TIME_ATTRIBUTE_PREFIX}{feature_with_relevant_attr}"

associated_time_key = self._aggregation_time_key(redis_key_prefix, feature_with_relevant_attr)
time_val = await RedisDriver.asyncify(self.redis.get)(associated_time_key)
time_val = await RedisDriver.asyncify(self.redis.hget)(
self._static_data_key(redis_key_prefix), aggr_mtime_attribute_name
)
time_val = RedisDriver.convert_to_str(time_val)

try:
Expand Down Expand Up @@ -492,18 +501,30 @@ async def _load_aggregates_by_key(self, container, table_path, key):
# Aggregation Redis keys start with the Redis key prefix for this Storey container, table
# path, and "key," followed by ":aggr_"
redis_key_prefix = self._make_key(container, table_path, key)
aggr_key_prefix = f"{redis_key_prefix}:{RedisDriver.AGGREGATION_ATTRIBUTE_PREFIX}"
all_aggr_keys = self.redis.hkeys(aggr_key_prefix)
for aggr_key in all_aggr_keys:
aggr_key = RedisDriver.convert_to_str(aggr_key)
value = await RedisDriver.asyncify(self.redis.hget)(aggr_key_prefix, aggr_key)
redis_key = self._static_data_key(redis_key_prefix)
try:
cursor = 0
values = {}
while True:
cursor, v = await RedisDriver.asyncify(self.redis.hscan)(
redis_key, cursor, match=f"{self.AGGREGATION_ATTRIBUTE_PREFIX}*"
)
values.update(v)
if cursor == 0:
break
except redis.ResponseError as e:
raise RedisError(f"Failed to get key {redis_key}. Response error was: {e}")

for aggr_key, value in values.items():
# Build an attribute for this aggregation in the format that Storey
# expects to receive from this method. The feature and aggregation
# name are embedded in the Redis key. Also, drop the "_a" or "_b"
# portion of the key, which is "the relevant attribute out of the 2
# feature attributes," according to comments in the V3IO driver.
value = RedisDriver.convert_to_str(value)
value = value.split(",")
aggr_key = aggr_key[len(self.AGGREGATION_ATTRIBUTE_PREFIX) :]
aggr_key = RedisDriver.convert_to_str(aggr_key)
feature_and_aggr_name = aggr_key[:-2]

# To get the associated time, we need the aggregation name and the relevant
Expand All @@ -526,25 +547,35 @@ async def _load_aggregates_by_key(self, container, table_path, key):
return aggregations_to_return, additional_data_to_return

async def _fetch_state_by_key(self, aggr_item, container, table_path, key):
key = str(key)
redis_key_prefix = self._make_key(container, table_path, key)
aggregations = {}
# Aggregation Redis keys start with the Redis key prefix for this Storey container, table
# path, and "key," followed by ":aggr_"
aggregations = {}

redis_key_prefix = self._make_key(container, table_path, key)
aggr_key_prefix = f"{redis_key_prefix}:{RedisDriver.AGGREGATION_ATTRIBUTE_PREFIX}"
all_aggr_keys = self.redis.hkeys(aggr_key_prefix)
for aggr_key in all_aggr_keys:
aggr_key = RedisDriver.convert_to_str(aggr_key)
value = await RedisDriver.asyncify(self.redis.hget)(aggr_key_prefix, aggr_key)
redis_key = self._static_data_key(redis_key_prefix)
try:
cursor = 0
values = {}
while True:
cursor, v = await RedisDriver.asyncify(self.redis.hscan)(
redis_key, cursor, match=f"{self.AGGREGATION_ATTRIBUTE_PREFIX}*"
)
values.update(v)
if cursor == 0:
break
except redis.ResponseError as e:
raise RedisError(f"Failed to get key {redis_key}. Response error was: {e}")

for aggr_key, value in values.items():
# Build an attribute for this aggregation in the format that Storey
# expects to receive from this method. The feature and aggregation
# name are embedded in the Redis key. Also, drop the "_a" or "_b"
# portion of the key, which is "the relevant attribute out of the 2
# feature attributes," according to comments in the V3IO driver.
# feature attributes," according to comments in the V3IO driver.
value = RedisDriver.convert_to_str(value)
value = value.split(",")

aggr_key = aggr_key[len(self.AGGREGATION_ATTRIBUTE_PREFIX) :]
aggr_key = RedisDriver.convert_to_str(aggr_key)
feature_and_aggr_name = aggr_key[:-2]

# To get the associated time, we need the aggregation name and the relevant
Expand Down

0 comments on commit 55ef0cd

Please sign in to comment.