diff --git a/src/storage/src/redis_sets.cc b/src/storage/src/redis_sets.cc index 5f33d9574..fecdfd50b 100644 --- a/src/storage/src/redis_sets.cc +++ b/src/storage/src/redis_sets.cc @@ -913,7 +913,11 @@ rocksdb::Status Redis::SMove(const Slice& source, const Slice& destination, cons } rocksdb::Status Redis::SPop(const Slice& key, std::vector* members, int64_t cnt) { - std::default_random_engine engine; + if (cnt <= 0) { + return Status::InvalidArgument("Count must be greater than zero"); + } + + std::default_random_engine engine(std::random_device{}()); std::string meta_value; rocksdb::WriteBatch batch; @@ -925,7 +929,7 @@ rocksdb::Status Redis::SPop(const Slice& key, std::vector* members, if (ExpectedStale(meta_value)) { s = Status::NotFound(); } else { - return Status::InvalidArgument( + return Status::InvalidArgument( "WRONGTYPE, key: " + key.ToString() + ", expect type: " + DataTypeStrings[static_cast(DataType::kSets)] + ", get type: " + DataTypeStrings[static_cast(GetMetaValueType(meta_value))]); @@ -940,63 +944,51 @@ rocksdb::Status Redis::SPop(const Slice& key, std::vector* members, } else { int32_t length = parsed_sets_meta_value.Count(); if (length < cnt) { - int32_t size = parsed_sets_meta_value.Count(); - int32_t cur_index = 0; + // 删除全部成员 uint64_t version = parsed_sets_meta_value.Version(); SetsMemberKey sets_member_key(key, version, Slice()); auto iter = db_->NewIterator(default_read_options_, handles_[kSetsDataCF]); for (iter->Seek(sets_member_key.EncodeSeekKey()); - iter->Valid() && cur_index < size; - iter->Next(), cur_index++) { - - batch.Delete(handles_[kSetsDataCF], iter->key()); - ParsedSetsMemberKey parsed_sets_member_key(iter->key()); - members->push_back(parsed_sets_member_key.member().ToString()); - + iter->Valid(); + iter->Next()) { + batch.Delete(handles_[kSetsDataCF], iter->key()); + ParsedSetsMemberKey parsed_sets_member_key(iter->key()); + members->push_back(parsed_sets_member_key.member().ToString()); } - - //parsed_sets_meta_value.ModifyCount(-cnt); - //batch.Put(handles_[kMetaCF], key, meta_value); batch.Delete(handles_[kMetaCF], base_meta_key.Encode()); delete iter; } else { + // 随机删除部分成员 engine.seed(time(nullptr)); - int32_t cur_index = 0; - int32_t size = parsed_sets_meta_value.Count(); - int32_t target_index = -1; uint64_t version = parsed_sets_meta_value.Version(); - std::unordered_set sets_index; - int32_t modnum = size; - - for (int64_t cur_round = 0; - cur_round < cnt; - cur_round++) { - do { - target_index = static_cast( engine() % modnum); - } while (sets_index.find(target_index) != sets_index.end()); - sets_index.insert(target_index); - } - SetsMemberKey sets_member_key(key, version, Slice()); - int64_t del_count = 0; - KeyStatisticsDurationGuard guard(this, DataType::kSets, key.ToString()); + int32_t size = parsed_sets_meta_value.Count(); + std::vector indices(size); + std::iota(indices.begin(), indices.end(), 0); // 生成 0 到 size-1 的索引 + std::shuffle(indices.begin(), indices.end(), engine); + indices.resize(cnt); // 保留 cnt 个随机索引 + std::sort(indices.begin(), indices.end()); // 排序以优化迭代器遍历 + auto iter = db_->NewIterator(default_read_options_, handles_[kSetsDataCF]); + int32_t cur_index = 0; + int64_t del_count = 0; + auto target_iter = indices.begin(); + for (iter->Seek(sets_member_key.EncodeSeekKey()); - iter->Valid() && cur_index < size; - iter->Next(), cur_index++) { - if (del_count == cnt) { - break; - } - if (sets_index.find(cur_index) != sets_index.end()) { - del_count++; + iter->Valid() && target_iter != indices.end(); + iter->Next(), ++cur_index) { + if (cur_index == *target_iter) { batch.Delete(handles_[kSetsDataCF], iter->key()); ParsedSetsMemberKey parsed_sets_member_key(iter->key()); members->push_back(parsed_sets_member_key.member().ToString()); + ++del_count; + ++target_iter; } } if (!parsed_sets_meta_value.CheckModifyCount(static_cast(-cnt))) { + delete iter; return Status::InvalidArgument("set size overflow"); } parsed_sets_meta_value.ModifyCount(static_cast(-cnt));