diff --git a/faust/stores/rocksdb.py b/faust/stores/rocksdb.py index 9d62711c6..5182b9592 100644 --- a/faust/stores/rocksdb.py +++ b/faust/stores/rocksdb.py @@ -283,15 +283,24 @@ def _open_for_partition(self, partition: int) -> DB: return self.rocksdb_options.open(self.partition_path(partition)) def _get(self, key: bytes) -> Optional[bytes]: - dbvalue = self._get_bucket_for_key(key) - if dbvalue is None: - return None - db, value = dbvalue + event = current_event() + if event is not None: + partition = event.message.partition + db = self._db_for_partition(partition) + value = db.get(key) + if value is not None: + self._key_index[key] = partition + return value + else: + dbvalue = self._get_bucket_for_key(key) + if dbvalue is None: + return None + db, value = dbvalue - if value is None: - if db.key_may_exist(key)[0]: - return db.get(key) - return value + if value is None: + if db.key_may_exist(key)[0]: + return db.get(key) + return value def _get_bucket_for_key(self, key: bytes) -> Optional[_DBValueTuple]: dbs: Iterable[PartitionDB] @@ -374,6 +383,8 @@ async def _try_open_db_for_partition( return self._db_for_partition(partition) except rocksdb.errors.RocksIOError as exc: if i == max_retries - 1 or "lock" not in repr(exc): + # release all the locks and crash + await self.stop() raise self.log.info( "DB for partition %r is locked! Retry in 1s...", partition @@ -383,11 +394,21 @@ async def _try_open_db_for_partition( ... def _contains(self, key: bytes) -> bool: - for db in self._dbs_for_key(key): - # bloom filter: false positives possible, but not false negatives - if db.key_may_exist(key)[0] and db.get(key) is not None: + event = current_event() + if event is not None: + partition = event.message.partition + db = self._db_for_partition(partition) + value = db.get(key) + if value is not None: return True - return False + else: + return False + else: + for db in self._dbs_for_key(key): + # bloom filter: false positives possible, but not false negatives + if db.key_may_exist(key)[0] and db.get(key) is not None: + return True + return False def _dbs_for_key(self, key: bytes) -> Iterable[DB]: # Returns cached db if key is in index, otherwise all dbs