From 7f0737d4e5e4569cbff6c4d59961fbeac604dfa8 Mon Sep 17 00:00:00 2001 From: alxtkr77 <3098237+alxtkr77@users.noreply.github.com> Date: Sun, 23 Jul 2023 12:38:35 +0300 Subject: [PATCH] ML-4171: Redis target: move to one redis-key per entity record (#449) --- integration/test_flow_integration.py | 11 ++- integration/test_redis_specific.py | 8 +- storey/redis_driver.py | 115 +++++++++++++++------------ 3 files changed, 79 insertions(+), 55 deletions(-) diff --git a/integration/test_flow_integration.py b/integration/test_flow_integration.py index dfbf1f09..283ef31d 100644 --- a/integration/test_flow_integration.py +++ b/integration/test_flow_integration.py @@ -111,7 +111,16 @@ def _get_redis_kv_all_attrs(setup_teardown_test: ContextForTests, key: str): hash_key = RedisDriver.make_key("storey-test:", table_name, key) redis_key = RedisDriver._static_data_key(hash_key) redis_fake_server = setup_teardown_test.redis_fake_server - values = get_redis_client(redis_fake_server=redis_fake_server).hgetall(redis_key) + + cursor = 0 + values = {} + while True: + cursor, v = get_redis_client(redis_fake_server=redis_fake_server).hscan( + redis_key, cursor, match=f"[^{RedisDriver.INTERFNAL_FIELD_PREFIX}]*" + ) + values.update(v) + if cursor == 0: + break return { RedisDriver.convert_to_str(key): RedisDriver.convert_redis_value_to_python_obj(val) for key, val in values.items() diff --git a/integration/test_redis_specific.py b/integration/test_redis_specific.py index fe2054c4..cff6a965 100644 --- a/integration/test_redis_specific.py +++ b/integration/test_redis_specific.py @@ -39,7 +39,13 @@ def test_redis_driver_write(redis): hash_key = RedisDriver.make_key("storey:", table_name, "key") redis_key = RedisDriver._static_data_key(hash_key) - data = driver.redis.hgetall(redis_key) + cursor = 0 + data = {} + while True: + cursor, v = driver.redis.hscan(redis_key, cursor, match=f"[^{driver.INTERFNAL_FIELD_PREFIX}]*") + data.update(v) + if cursor == 0: + break data_strings = {} for key, val in data.items(): if isinstance(key, bytes): diff --git a/storey/redis_driver.py b/storey/redis_driver.py index e02ad766..fab29c4e 100644 --- a/storey/redis_driver.py +++ b/storey/redis_driver.py @@ -46,14 +46,15 @@ def __init__(self, redis_url=None): class RedisDriver(NeedsRedisAccess, Driver): + INTERFNAL_FIELD_PREFIX = chr(0x1) REDIS_TIMEOUT = 5 # Seconds REDIS_WATCH_INTERVAL = 1 # Seconds DATETIME_FIELD_PREFIX = "_dt:" TIMEDELTA_FIELD_PREFIX = "_td:" DEFAULT_KEY_PREFIX = "storey:" - AGGREGATION_ATTRIBUTE_PREFIX = "moving_windows" - AGGREGATION_TIME_ATTRIBUTE_PREFIX = "_" - AGGREGATION_PREFIXES = AGGREGATION_TIME_ATTRIBUTE_PREFIX + AGGREGATION_ATTRIBUTE_PREFIX = INTERFNAL_FIELD_PREFIX + "aggr_" + AGGREGATION_TIME_ATTRIBUTE_PREFIX = INTERFNAL_FIELD_PREFIX + "mtaggr_" + OBJECT_MTIME_ATTRIBUTE_PREFIX = INTERFNAL_FIELD_PREFIX + "_mtime_" def __init__( self, @@ -71,7 +72,7 @@ def __init__( self._redis = None self._key_prefix = key_prefix if key_prefix is not None else self.DEFAULT_KEY_PREFIX - self._mtime_name = "$_mtime_" + self._mtime_name = self.OBJECT_MTIME_ATTRIBUTE_PREFIX @staticmethod def asyncify(fn): @@ -237,8 +238,7 @@ def _build_feature_store_lua_update_script( # Static attributes, like "name," "age," -- everything that isn't an agg. if additional_data: - redis_keys_involved.append(self._static_data_key(redis_key_prefix)) - additional_data_lua_script = f'local additional_data_key="{self._static_data_key(redis_key_prefix)}";\n' + additional_data_lua_script = f'local redis_hash="{self._static_data_key(redis_key_prefix)}";\n' for name, value in additional_data.items(): expression_value = self._convert_python_obj_to_lua_value(value) # NOTE: This logic assumes that static attributes we're supposed @@ -247,15 +247,14 @@ def _build_feature_store_lua_update_script( if expression_value: additional_data_lua_script = ( f"{additional_data_lua_script}redis.call(" - f'"HSET",additional_data_key, "{name}", {expression_value});\n' + f'"HSET",redis_hash, "{name}", {expression_value});\n' ) else: additional_data_lua_script = ( - f'{additional_data_lua_script}redis.call("HDEL",additional_data_key, "{name}");\n' + f'{additional_data_lua_script}redis.call("HDEL",redis_hash, "{name}");\n' ) lua_script = additional_data_lua_script - list_attribute_key_aggr = self._list_key(redis_key_prefix, RedisDriver.AGGREGATION_ATTRIBUTE_PREFIX) if aggregation_element: times_updates = {} @@ -277,9 +276,7 @@ def _build_feature_store_lua_update_script( " return result\n" "end\n" ) - lua_script = ( - f"{lua_script}local old_value;local attr_name;\n" f'local aggr_key="{list_attribute_key_aggr}";\n' - ) + lua_script = f"{lua_script}local old_value;local aggr_key;\n" lua_script = f"{lua_script}{lua_tonum_function}{lua_strToArr_funct}\n" for name, bucket in aggregation_element.aggregation_buckets.items(): # Only save raw aggregates, not virtual @@ -296,6 +293,9 @@ def _build_feature_store_lua_update_script( array_time_attribute_key = self._aggregation_time_key( redis_key_prefix, aggr_time_attribute_name ) + aggr_mtime_attr_name = ( + f"{RedisDriver.AGGREGATION_TIME_ATTRIBUTE_PREFIX}{aggr_time_attribute_name}" + ) cached_time = bucket.storage_specific_cache.get(array_time_attribute_key, -1) @@ -310,29 +310,28 @@ def _build_feature_store_lua_update_script( aggregation_value, ) in aggregation_values.items(): list_attribute_name = f"{name}_{aggregation}_{feature_attr}" - if list_attribute_key_aggr not in redis_keys_involved: - redis_keys_involved.append(list_attribute_key_aggr) - lua_script = f'{lua_script}attr_name="{list_attribute_name}";\n' + lua_script = f'{lua_script}\ + aggr_key="{RedisDriver.AGGREGATION_ATTRIBUTE_PREFIX}{list_attribute_name}";\n' if cached_time < expected_time: list_attribute_key = self._list_key(redis_key_prefix, list_attribute_name) if not initialized_attributes.get(list_attribute_key, -1) == expected_time: initialized_attributes[list_attribute_key] = expected_time lua_script = ( - f'{lua_script}local t=redis.call("GET","{array_time_attribute_key}");\n' + f'{lua_script}local t=redis.call("HGET",redis_hash,"{aggr_mtime_attr_name}");\n' f'if (type(t)~="boolean" and (tonumber(t) < {expected_time})) then ' - f'redis.call("HDEL",aggr_key, attr_name); end;\n' + f'redis.call("HDEL",redis_hash, aggr_key); end;\n' ) default_value = self._convert_python_obj_to_redis_value( aggregation_value.default_value ) lua_script = ( - f'{lua_script}local curr_agg=redis.call("HGET",aggr_key, attr_name)\n' + f'{lua_script}local curr_agg=redis.call("HGET",redis_hash, aggr_key)\n' "local arr=strToArr(curr_agg);\n" f"local org_arr_len=#arr\n" f"for i=1,({bucket.total_number_of_buckets}-org_arr_len) \ do arr[#arr+1]={default_value};end;\n" - f'if org_arr_len ~= #arr then redis.call("HSET", aggr_key, attr_name,\ + f'if org_arr_len ~= #arr then redis.call("HSET", redis_hash, aggr_key,\ table.concat(arr, ",")) end;\n' ) if array_time_attribute_key not in times_updates: @@ -348,18 +347,18 @@ def _build_feature_store_lua_update_script( "old_value", aggregation_value.value ) lua_script = ( - f'{lua_script}arr=strToArr(redis.call("HGET",aggr_key, attr_name))\n' + f'{lua_script}arr=strToArr(redis.call("HGET",redis_hash, aggr_key))\n' f"old_value=tonum(arr[{lua_index_to_update}]);\n" f'arr[{lua_index_to_update}]=string.format("%.17f",{new_value_expression});\n' - 'redis.call("HSET", aggr_key, attr_name, table.concat(arr, ","))\n' + 'redis.call("HSET", redis_hash, aggr_key, table.concat(arr, ","))\n' ) - - redis_keys_involved.append(array_time_attribute_key) - lua_script = f'{lua_script}redis.call("SET","{array_time_attribute_key}",{expected_time}); \n' + lua_script = f'{lua_script}\ + redis.call("HSET",redis_hash,"{aggr_mtime_attr_name}",{expected_time});' return lua_script, condition_expression, pending_updates, redis_keys_involved async def _save_key(self, container, table_path, key, aggr_item, partitioned_by_key, additional_data): redis_key_prefix = self._make_key(container, table_path, key) + static_redis_key_prefix = self._static_data_key(redis_key_prefix) ( update_expression, mtime_condition, @@ -373,17 +372,18 @@ async def _save_key(self, container, table_path, key, aggr_item, partitioned_by_ current_time = int(time.time_ns() / 1000) if mtime_condition is not None: update_expression = ( - f'if redis.call("HGET", "{redis_key_prefix}","{self._mtime_name}") == "{mtime_condition}" then\n' - f'{update_expression}redis.call("HSET","{redis_key_prefix}","{self._mtime_name}",{current_time});\n' + f'if redis.call("HGET", "{static_redis_key_prefix}","{self._mtime_name}") == "{mtime_condition}" then\n' + f'{update_expression}\ + redis.call("HSET","{static_redis_key_prefix}","{self._mtime_name}",{current_time});\n' "return 1;else return 0;end;" ) else: update_expression = ( f"{update_expression}redis.call(" - f'"HSET","{redis_key_prefix}","{self._mtime_name}",{current_time});return 1;' + f'"HSET","{static_redis_key_prefix}","{self._mtime_name}",{current_time});return 1;' ) - redis_keys_involved.append(redis_key_prefix) + redis_keys_involved.append(static_redis_key_prefix) update_ok = await self.asyncify(self.redis.eval)( update_expression, len(redis_keys_involved), *redis_keys_involved ) @@ -402,9 +402,8 @@ async def _save_key(self, container, table_path, key, aggr_item, partitioned_by_ ) = self._build_feature_store_lua_update_script(redis_key_prefix, aggr_item, False, additional_data) update_expression = ( f"{update_expression}redis.call(" - f'"HSET","{redis_key_prefix}","{self._mtime_name}",{current_time});return 1;' + f'"HSET","{static_redis_key_prefix}","{self._mtime_name}",{current_time});return 1;' ) - redis_keys_involved.append(redis_key_prefix) update_ok = await RedisDriver.asyncify(self.redis.eval)( update_expression, len(redis_keys_involved), *redis_keys_involved @@ -413,21 +412,15 @@ async def _save_key(self, container, table_path, key, aggr_item, partitioned_by_ await self._fetch_state_by_key(aggr_item, container, table_path, key) async def _get_all_fields(self, redis_key: str): - try: - # TODO: This should be HSCAN, not HGETALL, to avoid blocking Redis - # with very large hashes. - values = await RedisDriver.asyncify(self.redis.hgetall)(redis_key) - except redis.ResponseError as e: - raise RedisError(f"Failed to get key {redis_key}. Response error was: {e}") + values = await self.redis_hscan(redis_key, f"[^{self.INTERFNAL_FIELD_PREFIX}]*") res = { RedisDriver.convert_to_str(key): RedisDriver.convert_redis_value_to_python_obj(val) for key, val in values.items() - if not str(key).startswith(RedisDriver.AGGREGATION_PREFIXES) } return res async def _get_specific_fields(self, redis_key: str, attributes: List[str]): - non_aggregation_attrs = [name for name in attributes if not name.startswith(RedisDriver.AGGREGATION_PREFIXES)] + non_aggregation_attrs = [name for name in attributes if not name.startswith(RedisDriver.INTERFNAL_FIELD_PREFIX)] try: values = await RedisDriver.asyncify(self.redis.hmget)(redis_key, non_aggregation_attrs) except redis.ResponseError as e: @@ -462,8 +455,12 @@ async def _get_associated_time_attr(self, redis_key_prefix, aggr_name): aggr_without_relevant_attr = aggr_name[:-2] feature_name_only = aggr_without_relevant_attr[: aggr_without_relevant_attr.rindex("_")] feature_with_relevant_attr = f"{feature_name_only}{aggr_name[-2:]}" + aggr_mtime_attribute_name = f"{RedisDriver.AGGREGATION_TIME_ATTRIBUTE_PREFIX}{feature_with_relevant_attr}" + associated_time_key = self._aggregation_time_key(redis_key_prefix, feature_with_relevant_attr) - time_val = await RedisDriver.asyncify(self.redis.get)(associated_time_key) + time_val = await RedisDriver.asyncify(self.redis.hget)( + self._static_data_key(redis_key_prefix), aggr_mtime_attribute_name + ) time_val = RedisDriver.convert_to_str(time_val) try: @@ -492,11 +489,10 @@ async def _load_aggregates_by_key(self, container, table_path, key): # Aggregation Redis keys start with the Redis key prefix for this Storey container, table # path, and "key," followed by ":aggr_" redis_key_prefix = self._make_key(container, table_path, key) - aggr_key_prefix = f"{redis_key_prefix}:{RedisDriver.AGGREGATION_ATTRIBUTE_PREFIX}" - all_aggr_keys = self.redis.hkeys(aggr_key_prefix) - for aggr_key in all_aggr_keys: - aggr_key = RedisDriver.convert_to_str(aggr_key) - value = await RedisDriver.asyncify(self.redis.hget)(aggr_key_prefix, aggr_key) + redis_key = self._static_data_key(redis_key_prefix) + values = await self.redis_hscan(redis_key, f"{self.AGGREGATION_ATTRIBUTE_PREFIX}*") + + for aggr_key, value in values.items(): # Build an attribute for this aggregation in the format that Storey # expects to receive from this method. The feature and aggregation # name are embedded in the Redis key. Also, drop the "_a" or "_b" @@ -504,6 +500,8 @@ async def _load_aggregates_by_key(self, container, table_path, key): # feature attributes," according to comments in the V3IO driver. value = RedisDriver.convert_to_str(value) value = value.split(",") + aggr_key = aggr_key[len(self.AGGREGATION_ATTRIBUTE_PREFIX) :] + aggr_key = RedisDriver.convert_to_str(aggr_key) feature_and_aggr_name = aggr_key[:-2] # To get the associated time, we need the aggregation name and the relevant @@ -525,18 +523,28 @@ async def _load_aggregates_by_key(self, container, table_path, key): additional_data_to_return = additional_data if additional_data else None return aggregations_to_return, additional_data_to_return + async def redis_hscan(self, redis_key, match): + try: + cursor = 0 + values = {} + while True: + cursor, v = await RedisDriver.asyncify(self.redis.hscan)(redis_key, cursor, match=match) + values.update(v) + if cursor == 0: + break + except redis.ResponseError as e: + raise RedisError(f"Failed to get key {redis_key}. Response error was: {e}") + return values + async def _fetch_state_by_key(self, aggr_item, container, table_path, key): - key = str(key) + redis_key_prefix = self._make_key(container, table_path, key) + aggregations = {} # Aggregation Redis keys start with the Redis key prefix for this Storey container, table # path, and "key," followed by ":aggr_" - aggregations = {} - redis_key_prefix = self._make_key(container, table_path, key) - aggr_key_prefix = f"{redis_key_prefix}:{RedisDriver.AGGREGATION_ATTRIBUTE_PREFIX}" - all_aggr_keys = self.redis.hkeys(aggr_key_prefix) - for aggr_key in all_aggr_keys: - aggr_key = RedisDriver.convert_to_str(aggr_key) - value = await RedisDriver.asyncify(self.redis.hget)(aggr_key_prefix, aggr_key) + redis_key = self._static_data_key(redis_key_prefix) + values = await self.redis_hscan(redis_key, f"{self.AGGREGATION_ATTRIBUTE_PREFIX}*") + for aggr_key, value in values.items(): # Build an attribute for this aggregation in the format that Storey # expects to receive from this method. The feature and aggregation # name are embedded in the Redis key. Also, drop the "_a" or "_b" @@ -544,7 +552,8 @@ async def _fetch_state_by_key(self, aggr_item, container, table_path, key): # feature attributes," according to comments in the V3IO driver. value = RedisDriver.convert_to_str(value) value = value.split(",") - + aggr_key = aggr_key[len(self.AGGREGATION_ATTRIBUTE_PREFIX) :] + aggr_key = RedisDriver.convert_to_str(aggr_key) feature_and_aggr_name = aggr_key[:-2] # To get the associated time, we need the aggregation name and the relevant