Skip to content

Commit

Permalink
add a new cc index: ivf sq cc (zilliztech#475)
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
  • Loading branch information
cqy123456 authored Mar 22, 2024
1 parent d6c110b commit a374a71
Show file tree
Hide file tree
Showing 10 changed files with 493 additions and 28 deletions.
3 changes: 3 additions & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ constexpr const char* INDEX_FAISS_IVFFLAT_CC = "IVF_FLAT_CC";
constexpr const char* INDEX_FAISS_IVFPQ = "IVF_PQ";
constexpr const char* INDEX_FAISS_SCANN = "SCANN";
constexpr const char* INDEX_FAISS_IVFSQ8 = "IVF_SQ8";
constexpr const char* INDEX_FAISS_IVFSQ_CC = "IVF_SQ_CC";

constexpr const char* INDEX_FAISS_GPU_IDMAP = "GPU_FAISS_FLAT";
constexpr const char* INDEX_FAISS_GPU_IVFFLAT = "GPU_FAISS_IVF_FLAT";
Expand Down Expand Up @@ -94,6 +95,8 @@ constexpr const char* SSIZE = "ssize";
constexpr const char* REORDER_K = "reorder_k";
constexpr const char* WITH_RAW_DATA = "with_raw_data";
constexpr const char* ENSURE_TOPK_FULL = "ensure_topk_full";
constexpr const char* CODE_SIZE = "code_size";
constexpr const char* RAW_DATA_STORE_PREFIX = "raw_data_store_prefix";
// RAFT Params
constexpr const char* REFINE_RATIO = "refine_ratio";
constexpr const char* CACHE_DATASET_ON_DEVICE = "cache_dataset_on_device";
Expand Down
76 changes: 72 additions & 4 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "faiss/IndexIVFFlat.h"
#include "faiss/IndexIVFPQ.h"
#include "faiss/IndexIVFPQFastScan.h"
#include "faiss/IndexIVFScalarQuantizerCC.h"
#include "faiss/IndexScaNN.h"
#include "faiss/IndexScalarQuantizer.h"
#include "faiss/index_io.h"
Expand Down Expand Up @@ -56,7 +57,8 @@ class IvfIndexNode : public IndexNode {
std::is_same<IndexType, faiss::IndexIVFPQ>::value ||
std::is_same<IndexType, faiss::IndexIVFScalarQuantizer>::value ||
std::is_same<IndexType, faiss::IndexBinaryIVF>::value ||
std::is_same<IndexType, faiss::IndexScaNN>::value,
std::is_same<IndexType, faiss::IndexScaNN>::value ||
std::is_same<IndexType, faiss::IndexIVFScalarQuantizerCC>::value,
"not support");
static_assert(std::is_same_v<DataType, fp32> || std::is_same_v<DataType, bin1>,
"IvfIndexNode only support float/bianry");
Expand Down Expand Up @@ -98,6 +100,9 @@ class IvfIndexNode : public IndexNode {
if constexpr (std::is_same<faiss::IndexBinaryIVF, IndexType>::value) {
return true;
}
if constexpr (std::is_same<faiss::IndexIVFScalarQuantizerCC, IndexType>::value) {
return index_->with_raw_data();
}
}
expected<DataSetPtr>
GetIndexMeta(const Config& cfg) const override {
Expand Down Expand Up @@ -131,6 +136,9 @@ class IvfIndexNode : public IndexNode {
if constexpr (std::is_same<faiss::IndexBinaryIVF, IndexType>::value) {
return std::make_unique<IvfBinConfig>();
}
if constexpr (std::is_same<faiss::IndexIVFScalarQuantizerCC, IndexType>::value) {
return std::make_unique<IvfSqCcConfig>();
}
};
int64_t
Dim() const override {
Expand Down Expand Up @@ -183,6 +191,12 @@ class IvfIndexNode : public IndexNode {
auto code_size = index_->code_size;
return (nb * code_size + nb * sizeof(int64_t) + nlist * code_size);
}
if constexpr (std::is_same<IndexType, faiss::IndexIVFScalarQuantizerCC>::value) {
auto nb = index_->invlists->compute_ntotal();
auto code_size = index_->code_size;
auto nlist = index_->nlist;
return (nb * code_size + nb * sizeof(int64_t) + 2 * code_size + nlist * sizeof(float));
}
};
int64_t
Count() const override {
Expand Down Expand Up @@ -211,6 +225,9 @@ class IvfIndexNode : public IndexNode {
if constexpr (std::is_same<IndexType, faiss::IndexBinaryIVF>::value) {
return knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT;
}
if constexpr (std::is_same<IndexType, faiss::IndexIVFScalarQuantizerCC>::value) {
return knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC;
}
};

private:
Expand Down Expand Up @@ -345,6 +362,22 @@ to_index_flat(std::unique_ptr<faiss::IndexFlat>&& index) {
return std::make_unique<faiss::IndexFlat>(std::move(*index));
}

// Maps the IVF_SQ_CC `code_size` parameter onto the matching faiss scalar quantizer type.
// Accepted values are 4, 6, 8 and 16, which select QT_4bit, QT_6bit, QT_8bit and QT_fp16
// respectively; every other value produces an error result carrying Status::invalid_args.
expected<faiss::ScalarQuantizer::QuantizerType>
get_ivf_sq_quantizer_type(int code_size) {
    using QuantizerType = faiss::ScalarQuantizer::QuantizerType;

    if (code_size == 4) {
        return QuantizerType::QT_4bit;
    }
    if (code_size == 6) {
        return QuantizerType::QT_6bit;
    }
    if (code_size == 8) {
        return QuantizerType::QT_8bit;
    }
    if (code_size == 16) {
        return QuantizerType::QT_fp16;
    }

    // Unsupported value: report it back to the caller instead of picking a quantizer silently.
    return expected<QuantizerType>::Err(Status::invalid_args,
                                        fmt::format("current code size {} not in (4, 6, 8, 16)", code_size));
}
} // namespace

template <typename DataType, typename IndexType>
Expand Down Expand Up @@ -535,6 +568,35 @@ IvfIndexNode<DataType, IndexType>::TrainInternal(const DataSet& dataset, const C
qzr.release();
index->own_fields = true;
}
if constexpr (std::is_same<faiss::IndexIVFScalarQuantizerCC, IndexType>::value) {
const IvfSqCcConfig& ivf_sq_cc_cfg = static_cast<const IvfSqCcConfig&>(cfg);
auto nlist = MatchNlist(rows, ivf_sq_cc_cfg.nlist.value());
auto ssize = ivf_sq_cc_cfg.ssize.value();

const bool use_elkan = ivf_sq_cc_cfg.use_elkan.value_or(true);

// create quantizer for the training
std::unique_ptr<faiss::IndexFlat> qzr =
std::make_unique<faiss::IndexFlatElkan>(dim, metric.value(), false, use_elkan);
// create index. Index does not own qzr
auto qzr_type = get_ivf_sq_quantizer_type(ivf_sq_cc_cfg.code_size.value());
if (!qzr_type.has_value()) {
LOG_KNOWHERE_ERROR_ << "fail to get ivf sq quantizer type, " << qzr_type.what();
return qzr_type.error();
}
index = std::make_unique<faiss::IndexIVFScalarQuantizerCC>(qzr.get(), dim, nlist, ssize, qzr_type.value(),
metric.value(), is_cosine, false,
ivf_sq_cc_cfg.raw_data_store_prefix);
// train
index->train(rows, (const float*)data);
// replace quantizer with a regular IndexFlat
qzr = to_index_flat(std::move(qzr));
index->quantizer = qzr.get();
// transfer ownership of qzr to index
qzr.release();
index->own_fields = true;
index->make_direct_map(true, faiss::DirectMap::ConcurrentArray);
}
index_ = std::move(index);

return Status::success;
Expand Down Expand Up @@ -624,7 +686,8 @@ IvfIndexNode<DataType, IndexType>::Search(const DataSet& dataset, const Config&
distances[i + offset] = static_cast<float>(i_distances[i + offset]);
}
}
} else if constexpr (std::is_same<IndexType, faiss::IndexIVFFlatCC>::value) {
} else if constexpr (std::is_same<IndexType, faiss::IndexIVFFlatCC>::value ||
std::is_same<IndexType, faiss::IndexIVFScalarQuantizerCC>::value) {
auto cur_query = (const float*)data + index * dim;
if (is_cosine) {
copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
Expand Down Expand Up @@ -924,7 +987,8 @@ IvfIndexNode<DataType, IndexType>::GetVectorByIds(const DataSet& dataset) const
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return expected<DataSetPtr>::Err(Status::faiss_inner_error, e.what());
}
} else if constexpr (std::is_same<IndexType, faiss::IndexScaNN>::value) {
} else if constexpr (std::is_same<IndexType, faiss::IndexScaNN>::value ||
std::is_same<IndexType, faiss::IndexIVFScalarQuantizerCC>::value) {
// we should never go here since we should call HasRawData() first
if (!index_->with_raw_data()) {
return expected<DataSetPtr>::Err(Status::not_implemented, "GetVectorByIds not implemented");
Expand Down Expand Up @@ -1083,7 +1147,8 @@ IvfIndexNode<DataType, IndexType>::Deserialize(const BinarySet& binset, const Co
} else {
index_.reset(static_cast<IndexType*>(faiss::read_index(&reader)));
}
if constexpr (!std::is_same_v<IndexType, faiss::IndexScaNN>) {
if constexpr (!std::is_same_v<IndexType, faiss::IndexScaNN> &&
!std::is_same_v<IndexType, faiss::IndexIVFScalarQuantizerCC>) {
const BaseConfig& base_cfg = static_cast<const BaseConfig&>(config);
if (HasRawData(base_cfg.metric_type.value())) {
index_->make_direct_map(true);
Expand Down Expand Up @@ -1136,6 +1201,7 @@ KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFPQ, IvfIndexNode, fp32, faiss::IndexIVFPQ);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_PQ, IvfIndexNode, fp32, faiss::IndexIVFPQ);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFSQ, IvfIndexNode, fp32, faiss::IndexIVFScalarQuantizer);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_SQ8, IvfIndexNode, fp32, faiss::IndexIVFScalarQuantizer);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_SQ_CC, IvfIndexNode, fp32, faiss::IndexIVFScalarQuantizerCC);
// fp16
KNOWHERE_MOCK_REGISTER_GLOBAL(IVFFLAT, IvfIndexNode, fp16, faiss::IndexIVFFlat);
KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_FLAT, IvfIndexNode, fp16, faiss::IndexIVFFlat);
Expand All @@ -1146,6 +1212,7 @@ KNOWHERE_MOCK_REGISTER_GLOBAL(IVFPQ, IvfIndexNode, fp16, faiss::IndexIVFPQ);
KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_PQ, IvfIndexNode, fp16, faiss::IndexIVFPQ);
KNOWHERE_MOCK_REGISTER_GLOBAL(IVFSQ, IvfIndexNode, fp16, faiss::IndexIVFScalarQuantizer);
KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_SQ8, IvfIndexNode, fp16, faiss::IndexIVFScalarQuantizer);
KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_SQ_CC, IvfIndexNode, fp16, faiss::IndexIVFScalarQuantizerCC);
// bf16
KNOWHERE_MOCK_REGISTER_GLOBAL(IVFFLAT, IvfIndexNode, bf16, faiss::IndexIVFFlat);
KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_FLAT, IvfIndexNode, bf16, faiss::IndexIVFFlat);
Expand All @@ -1156,4 +1223,5 @@ KNOWHERE_MOCK_REGISTER_GLOBAL(IVFPQ, IvfIndexNode, bf16, faiss::IndexIVFPQ);
KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_PQ, IvfIndexNode, bf16, faiss::IndexIVFPQ);
KNOWHERE_MOCK_REGISTER_GLOBAL(IVFSQ, IvfIndexNode, bf16, faiss::IndexIVFScalarQuantizer);
KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_SQ8, IvfIndexNode, bf16, faiss::IndexIVFScalarQuantizer);
KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_SQ_CC, IvfIndexNode, bf16, faiss::IndexIVFScalarQuantizerCC);
} // namespace knowhere
35 changes: 35 additions & 0 deletions src/index/ivf/ivf_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,41 @@ class IvfSqConfig : public IvfConfig {};

class IvfBinConfig : public IvfConfig {};

// Configuration for the IVF_SQ_CC index. Extends IvfFlatCcConfig with two train-time fields.
class IvfSqCcConfig : public IvfFlatCcConfig {
 public:
    // Lets the user choose the IVF_SQ_CC quantizer type; legal values are 4, 6, 8 and 16 (default 8).
    CFG_INT code_size;
    // When raw_data_store_prefix has a value, IVF_SQ_CC holds all vectors in a file.
    // A cc index is a just-in-time index: raw data is available after training if
    // raw_data_store_prefix has a value. An IVF_SQ_CC index created from a binaryset
    // will not keep raw data.
    CFG_STRING raw_data_store_prefix;
    KNOHWERE_DECLARE_CONFIG(IvfSqCcConfig) {
        KNOWHERE_CONFIG_DECLARE_FIELD(code_size)
            .set_default(8)
            .description("code size, range in [4, 6, 8 and 16]")
            .for_train();
        KNOWHERE_CONFIG_DECLARE_FIELD(raw_data_store_prefix)
            .description("Raw data will be set in this prefix path")
            .for_train()
            .allow_empty_without_default();
    };

    // Validates the configuration. Only TRAIN parameters are checked: code_size must be one of
    // 4, 6, 8 or 16, otherwise *err_msg is filled in, the message is logged and
    // Status::invalid_value_in_json is returned. Every other case returns Status::success.
    Status
    CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
        // Nothing to validate outside of training.
        if (param_type != PARAM_TYPE::TRAIN) {
            return Status::success;
        }

        const auto requested = code_size.value();
        const bool supported = requested == 4 || requested == 6 || requested == 8 || requested == 16;
        if (supported) {
            return Status::success;
        }

        *err_msg = "compress a vector into (code_size * dim)/8 bytes, code size value should be in 4, 6, 8 and 16";
        LOG_KNOWHERE_ERROR_ << *err_msg;
        return Status::invalid_value_in_json;
    }
};

} // namespace knowhere

#endif /* IVF_CONFIG_H */
47 changes: 29 additions & 18 deletions tests/ut/test_ivfflat_cc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

#include <filesystem>
#include <future>

#include "catch2/catch_approx.hpp"
Expand Down Expand Up @@ -44,21 +45,27 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") {
return json;
};

auto ivfflat_gen = [&base_gen]() {
auto ivf_gen = [&base_gen]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::NLIST] = 128;
json[knowhere::indexparam::NPROBE] = 16;
json[knowhere::indexparam::ENSURE_TOPK_FULL] = false;
return json;
};

auto ivfflatcc_gen = [&ivfflat_gen]() {
knowhere::Json json = ivfflat_gen();
auto ivf_cc_gen = [&ivf_gen]() {
knowhere::Json json = ivf_gen();
json[knowhere::meta::NUM_BUILD_THREAD] = 1;
json[knowhere::indexparam::SSIZE] = 48;
return json;
};

auto ivf_sq_8_cc_gen = [&ivf_cc_gen]() {
knowhere::Json json = ivf_cc_gen();
json[knowhere::indexparam::CODE_SIZE] = 8;
return json;
};

SECTION("Test Concurrent Invlists ") {
size_t nlist = 128;
size_t code_size = 512;
Expand Down Expand Up @@ -140,7 +147,8 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") {
SECTION("Test Add & Search & RangeSearch Serialized ") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivf_cc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC, ivf_sq_8_cc_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
auto cfg_json = gen().dump();
Expand Down Expand Up @@ -183,29 +191,29 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") {

SECTION("Test Build & Search Correctness") {
using std::make_tuple;
auto [index_name, cc_index_name] = GENERATE_REF(table<std::string, std::string>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC),
}));
auto ivf = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(index_name, version);
auto ivf_cc = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(cc_index_name, version);

auto ivf_flat = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, version);
auto ivf_flat_cc = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, version);

knowhere::Json ivf_flat_json = knowhere::Json::parse(ivfflat_gen().dump());
knowhere::Json ivf_flat_cc_json = knowhere::Json::parse(ivfflatcc_gen().dump());
knowhere::Json ivf_json = knowhere::Json::parse(ivf_gen().dump());
knowhere::Json ivf_cc_json = knowhere::Json::parse(ivf_cc_gen().dump());

auto train_ds = GenDataSet(nb, dim, seed);
auto query_ds = GenDataSet(nq, dim, seed);

auto flat_res = ivf_flat.Build(*train_ds, ivf_flat_json);
auto flat_res = ivf.Build(*train_ds, ivf_json);
REQUIRE(flat_res == knowhere::Status::success);
auto cc_res = ivf_flat_cc.Build(*train_ds, ivf_flat_json);
auto cc_res = ivf_cc.Build(*train_ds, ivf_json);
REQUIRE(cc_res == knowhere::Status::success);

// test search
{
auto flat_results = ivf_flat.Search(*query_ds, ivf_flat_json, nullptr);
auto flat_results = ivf.Search(*query_ds, ivf_json, nullptr);
REQUIRE(flat_results.has_value());

auto cc_results = ivf_flat_cc.Search(*query_ds, ivf_flat_json, nullptr);
auto cc_results = ivf_cc.Search(*query_ds, ivf_json, nullptr);
REQUIRE(cc_results.has_value());

auto flat_ids = flat_results.value()->GetIds();
Expand All @@ -219,10 +227,10 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") {
}
// test range_search
{
auto flat_results = ivf_flat.RangeSearch(*query_ds, ivf_flat_json, nullptr);
auto flat_results = ivf.RangeSearch(*query_ds, ivf_json, nullptr);
REQUIRE(flat_results.has_value());

auto cc_results = ivf_flat_cc.RangeSearch(*query_ds, ivf_flat_json, nullptr);
auto cc_results = ivf_cc.RangeSearch(*query_ds, ivf_json, nullptr);
REQUIRE(cc_results.has_value());

auto flat_ids = flat_results.value()->GetIds();
Expand All @@ -242,12 +250,14 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") {
SECTION("Test Add & Search & RangeSearch ConCurrent") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivf_cc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC, ivf_sq_8_cc_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
json[knowhere::indexparam::RAW_DATA_STORE_PREFIX] = std::filesystem::current_path().string() + "/";
auto train_ds = GenDataSet(nb, dim, seed);
auto res = idx.Build(*train_ds, json);
REQUIRE(res == knowhere::Status::success);
Expand Down Expand Up @@ -290,6 +300,7 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") {
retrieve_task_list.push_back(std::async(
std::launch::async, [&idx, &retrieve_ids_set] { return idx.GetVectorByIds(*retrieve_ids_set); }));
}

for (auto& task : add_task_list) {
REQUIRE(task.get() == knowhere::Status::success);
}
Expand Down
Loading

0 comments on commit a374a71

Please sign in to comment.