Skip to content

Commit

Permalink
Merge branch 'zilliztech:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
alexanderguzhva authored Nov 22, 2023
2 parents e17f790 + 2191316 commit ee05caa
Show file tree
Hide file tree
Showing 38 changed files with 436 additions and 1,404 deletions.
32 changes: 11 additions & 21 deletions include/knowhere/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,11 @@ class Config {
}
}

return Status::success;
if (!err_msg) {
std::string tem_msg;
return cfg.CheckAndAdjust(type, &tem_msg);
}
return cfg.CheckAndAdjust(type, err_msg);
}

virtual ~Config() {
Expand All @@ -485,6 +489,12 @@ class Config {
using VarEntry =
std::variant<Entry<CFG_STRING>, Entry<CFG_FLOAT>, Entry<CFG_INT>, Entry<CFG_LIST>, Entry<CFG_BOOL>>;
std::unordered_map<std::string, VarEntry> __DICT__;

protected:
// Validation/adjustment hook keyed by the kind of parameters being processed.
// This base implementation performs no checks, leaves the configuration
// untouched, and always reports success; derived configs override it.
//
// param_type: which parameter set is being validated (ignored here).
// err_msg:    destination for a human-readable failure reason (never written here).
inline virtual Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* const err_msg) {
return Status::success;
}
};

#define KNOHWERE_DECLARE_CONFIG(CONFIG) CONFIG()
Expand Down Expand Up @@ -557,26 +567,6 @@ class BaseConfig : public Config {
.for_deserialize_from_file();
KNOWHERE_CONFIG_DECLARE_FIELD(for_tuning).set_default(false).description("for tuning").for_search();
}

// Default search-time hook: performs no validation and always returns
// Status::success. err_msg is not written by this implementation.
virtual Status
CheckAndAdjustForSearch(std::string* err_msg) {
return Status::success;
}

// Default range-search hook: performs no validation and always returns
// Status::success. err_msg is not written by this implementation.
virtual Status
CheckAndAdjustForRangeSearch(std::string* err_msg) {
return Status::success;
}

// Default iterator hook: performs no validation and always returns
// Status::success.
virtual Status
CheckAndAdjustForIterator() {
return Status::success;
}

// Default build-time hook: performs no validation and always returns
// Status::success.
virtual inline Status
CheckAndAdjustForBuild() {
return Status::success;
}
};
} // namespace knowhere

Expand Down
44 changes: 43 additions & 1 deletion python/knowhere/knowhere.i
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@ import_array();
%include <std_pair.i>
%include <std_map.i>
%include <std_shared_ptr.i>
%include <std_vector.i>
%include <exception.i>
%shared_ptr(knowhere::DataSet)
%shared_ptr(knowhere::BinarySet)
%template(DataSetPtr) std::shared_ptr<knowhere::DataSet>;
%template(BinarySetPtr) std::shared_ptr<knowhere::BinarySet>;
%template(int64_float_pair) std::pair<long long int, float>;
%include <knowhere/expected.h>
%include <knowhere/dataset.h>
%include <knowhere/binaryset.h>
Expand Down Expand Up @@ -104,7 +106,7 @@ del Enum
%inline %{

class GILReleaser {
public:
public:
GILReleaser() : save(PyEval_SaveThread()) {
}
~GILReleaser() {
Expand All @@ -113,6 +115,28 @@ public:
PyThreadState* save;
};

// Holds a shared IndexNode::iterator and forwards HasNext()/Next() to it.
// The wrapped pointer is guaranteed non-null for the lifetime of the object:
// construction from a null pointer (including the defaulted argument) throws.
class AnnIteratorWrap {
 public:
    // Throws std::runtime_error when `it` is null.
    AnnIteratorWrap(std::shared_ptr<IndexNode::iterator> it = nullptr) : it_(it) {
        if (!it_) {
            throw std::runtime_error("ann iterator must not be nullptr.");
        }
    }

    ~AnnIteratorWrap() {
    }

    // Forwards to the wrapped iterator's HasNext().
    bool
    HasNext() {
        const bool more = it_->HasNext();
        return more;
    }

    // Forwards to the wrapped iterator's Next() and returns its (id, distance) pair.
    std::pair<int64_t, float>
    Next() {
        std::pair<int64_t, float> entry = it_->Next();
        return entry;
    }

 private:
    std::shared_ptr<IndexNode::iterator> it_;
};

class IndexWrap {
public:
IndexWrap(const std::string& name, const int32_t& version) {
Expand Down Expand Up @@ -157,6 +181,22 @@ class IndexWrap {
}
}

// Runs idx.AnnIterator() for `dataset` with the JSON-encoded parameters in
// `json` and wraps every returned iterator in an AnnIteratorWrap.
//
// dataset: input dataset; dereferenced unconditionally, so it must be non-null.
// json:    parameters as JSON text, parsed with knowhere::Json::parse.
// bitset:  forwarded unchanged to idx.AnnIterator().
// status:  out-parameter; set to the error reported by idx.AnnIterator() on
//          failure, or knowhere::Status::success otherwise.
//
// Returns an empty vector on failure, otherwise one wrapper per iterator.
std::vector<AnnIteratorWrap>
GetAnnIterator(knowhere::DataSetPtr dataset, const std::string& json, const knowhere::BitsetView& bitset, knowhere::Status& status) {
    GILReleaser rel;  // constructed first so it is in scope for the whole call
    auto res = idx.AnnIterator(*dataset, knowhere::Json::parse(json), bitset);
    std::vector<AnnIteratorWrap> result;
    if (!res.has_value()) {
        status = res.error();
        return result;
    }
    status = knowhere::Status::success;

    // Bind once and iterate by const reference: the previous by-value loop
    // variable copied each shared_ptr (an atomic refcount bump) only to copy
    // it again into the wrapper.
    const auto& iterators = res.value();
    result.reserve(iterators.size());
    for (const auto& it : iterators) {
        result.emplace_back(it);
    }
    return result;
}

knowhere::DataSetPtr
RangeSearch(knowhere::DataSetPtr dataset, const std::string& json, const knowhere::BitsetView& bitset, knowhere::Status& status){
GILReleaser rel;
Expand Down Expand Up @@ -468,3 +508,5 @@ SetSimdType(const std::string type) {
}

%}

%template(AnnIteratorWrapVector) std::vector<AnnIteratorWrap>;
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_readme():
description=(
"A library for efficient similarity search and clustering of vectors."
),
url="https://github.com/milvus-io/knowhere",
url="https://github.com/zilliztech/knowhere",
author="Milvus Team",
author_email="milvus-team@zilliz.com",
license='Apache License 2.0',
Expand Down
1 change: 1 addition & 0 deletions src/common/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ static const std::unordered_set<std::string> ext_legal_json_keys = {"metric_type
"M", // HNSW param
"efConstruction", // HNSW param
"ef", // HNSW param
"seed_ef", // HNSW iterator param
"level",
"index_type",
"index_mode",
Expand Down
13 changes: 0 additions & 13 deletions src/common/index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ inline Status
Index<T>::Build(const DataSet& dataset, const Json& json) {
auto cfg = this->node->CreateConfig();
RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::TRAIN, "Build"));
RETURN_IF_ERROR(cfg->CheckAndAdjustForBuild());

#ifdef NOT_COMPILE_FOR_SWIG
TimeRecorder rc("Build index", 2);
Expand Down Expand Up @@ -77,10 +76,6 @@ Index<T>::Search(const DataSet& dataset, const Json& json, const BitsetView& bit
if (load_status != Status::success) {
return expected<DataSetPtr>::Err(load_status, msg);
}
const Status search_status = cfg->CheckAndAdjustForSearch(&msg);
if (search_status != Status::success) {
return expected<DataSetPtr>::Err(search_status, msg);
}

#ifdef NOT_COMPILE_FOR_SWIG
TimeRecorder rc("Search");
Expand All @@ -105,10 +100,6 @@ Index<T>::AnnIterator(const DataSet& dataset, const Json& json, const BitsetView
if (status != Status::success) {
return expected<std::vector<std::shared_ptr<IndexNode::iterator>>>::Err(status, msg);
}
status = cfg->CheckAndAdjustForIterator();
if (status != Status::success) {
return expected<std::vector<std::shared_ptr<IndexNode::iterator>>>::Err(status, "invalid params for iterator");
}

#ifdef NOT_COMPILE_FOR_SWIG
// note that this time includes only the initial search phase of iterator.
Expand All @@ -133,10 +124,6 @@ Index<T>::RangeSearch(const DataSet& dataset, const Json& json, const BitsetView
if (status != Status::success) {
return expected<DataSetPtr>::Err(status, std::move(msg));
}
status = cfg->CheckAndAdjustForRangeSearch(&msg);
if (status != Status::success) {
return expected<DataSetPtr>::Err(status, std::move(msg));
}

#ifdef NOT_COMPILE_FOR_SWIG
TimeRecorder rc("Range Search");
Expand Down
9 changes: 3 additions & 6 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,23 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

#include "knowhere/feder/DiskANN.h"

#include <omp.h>

#include <cstdint>

#include "common/range_util.h"
#include "diskann/aux_utils.h"
#include "diskann/linux_aligned_file_reader.h"
#include "diskann/pq_flash_index.h"
#include "fmt/core.h"
#include "index/diskann/diskann_config.h"
#include "knowhere/comp/index_param.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/dataset.h"
#include "knowhere/expected.h"
#ifndef _WINDOWS
#include "diskann/linux_aligned_file_reader.h"
#else
#include "diskann/windows_aligned_file_reader.h"
#endif
#include "knowhere/factory.h"
#include "knowhere/feder/DiskANN.h"
#include "knowhere/file_manager.h"
#include "knowhere/log.h"
#include "knowhere/utils.h"
Expand Down
40 changes: 22 additions & 18 deletions src/index/diskann/diskann_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,28 @@ class DiskANNConfig : public BaseConfig {
.for_search();
}

// Search-time validation of search_list_size against k.
// When search_list_size is unset it is derived as max(k, kSearchListSizeMinValue).
// When it is set but smaller than k, the reason is written to *err_msg
// (which must therefore be non-null), logged, and
// Status::out_of_range_in_json is returned.
inline Status
CheckAndAdjustForSearch(std::string* err_msg) override {
    if (!search_list_size.has_value()) {
        search_list_size = std::max(k.value(), kSearchListSizeMinValue);
        return Status::success;
    }

    const bool too_small = k.value() > search_list_size.value();
    if (!too_small) {
        return Status::success;
    }

    *err_msg = "search_list_size(" + std::to_string(search_list_size.value()) + ") should be larger than k(" +
               std::to_string(k.value()) + ")";
    LOG_KNOWHERE_ERROR_ << *err_msg;
    return Status::out_of_range_in_json;
}

inline Status
CheckAndAdjustForBuild() override {
if (!search_list_size.has_value()) {
search_list_size = kDefaultSearchListSizeForBuild;
// Validates/adjusts search_list_size depending on which parameter set is
// being processed.
//
//  - TRAIN:  an unset search_list_size becomes kDefaultSearchListSizeForBuild.
//  - SEARCH: an unset search_list_size becomes max(k, kSearchListSizeMinValue);
//            a set value smaller than k is rejected with
//            Status::out_of_range_in_json, and the reason is written to
//            *err_msg (dereferenced unconditionally on that path) and logged.
//  - any other param_type: nothing is checked.
Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
    if (param_type == PARAM_TYPE::TRAIN) {
        if (!search_list_size.has_value()) {
            search_list_size = kDefaultSearchListSizeForBuild;
        }
        return Status::success;
    }

    if (param_type != PARAM_TYPE::SEARCH) {
        return Status::success;
    }

    if (!search_list_size.has_value()) {
        search_list_size = std::max(k.value(), kSearchListSizeMinValue);
        return Status::success;
    }

    if (k.value() > search_list_size.value()) {
        *err_msg = "search_list_size(" + std::to_string(search_list_size.value()) +
                   ") should be larger than k(" + std::to_string(k.value()) + ")";
        LOG_KNOWHERE_ERROR_ << *err_msg;
        return Status::out_of_range_in_json;
    }
    return Status::success;
}
Expand Down
42 changes: 23 additions & 19 deletions src/index/hnsw/hnsw_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,29 @@ class HnswConfig : public BaseConfig {
.for_feder();
}

// Search-time validation of ef against k.
// When ef is unset it is derived as max(k, kEfMinValue). When it is set but
// smaller than k, the reason is written to *err_msg (which must therefore be
// non-null), logged, and Status::out_of_range_in_json is returned.
inline Status
CheckAndAdjustForSearch(std::string* err_msg) override {
    if (!ef.has_value()) {
        ef = std::max(k.value(), kEfMinValue);
        return Status::success;
    }

    const bool ef_below_k = k.value() > ef.value();
    if (!ef_below_k) {
        return Status::success;
    }

    *err_msg = "ef(" + std::to_string(ef.value()) + ") should be larger than k(" + std::to_string(k.value()) + ")";
    LOG_KNOWHERE_ERROR_ << *err_msg;
    return Status::out_of_range_in_json;
}

inline Status
CheckAndAdjustForRangeSearch(std::string* err_msg) override {
if (!ef.has_value()) {
// if ef is not set by user, set it to default
ef = kDefaultRangeSearchEf;
// Validates/adjusts ef depending on which parameter set is being processed.
//
//  - SEARCH:       an unset ef becomes max(k, kEfMinValue); a set ef smaller
//                  than k is rejected with Status::out_of_range_in_json, and
//                  the reason is written to *err_msg (dereferenced
//                  unconditionally on that path) and logged.
//  - RANGE_SEARCH: an unset ef becomes kDefaultRangeSearchEf.
//  - any other param_type: nothing is checked.
Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
    if (param_type == PARAM_TYPE::RANGE_SEARCH) {
        // if ef is not set by user, set it to default
        if (!ef.has_value()) {
            ef = kDefaultRangeSearchEf;
        }
        return Status::success;
    }

    if (param_type != PARAM_TYPE::SEARCH) {
        return Status::success;
    }

    if (!ef.has_value()) {
        ef = std::max(k.value(), kEfMinValue);
        return Status::success;
    }

    if (k.value() > ef.value()) {
        *err_msg = "ef(" + std::to_string(ef.value()) + ") should be larger than k(" +
                   std::to_string(k.value()) + ")";
        LOG_KNOWHERE_ERROR_ << *err_msg;
        return Status::out_of_range_in_json;
    }
    return Status::success;
}
Expand Down
75 changes: 40 additions & 35 deletions src/index/ivf/ivf_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,41 +79,46 @@ class ScannConfig : public IvfFlatConfig {
.for_train();
}

// Search-time validation for SCANN.
// Fails with Status::invalid_instruction_set when faiss::support_pq_fast_scan
// is false. Otherwise an unset reorder_k defaults to k, and a reorder_k
// smaller than k is rejected with Status::out_of_range_in_json. On either
// failure the reason is written to *err_msg (must be non-null) and logged.
inline Status
CheckAndAdjustForSearch(std::string* err_msg) override {
    if (!faiss::support_pq_fast_scan) {
        *err_msg = "SCANN index is not supported on the current CPU model, avx2 support is needed for x86 arch.";
        LOG_KNOWHERE_ERROR_ << *err_msg;
        return Status::invalid_instruction_set;
    }

    if (!reorder_k.has_value()) {
        reorder_k = k.value();
        return Status::success;
    }

    const bool reorder_below_k = reorder_k.value() < k.value();
    if (!reorder_below_k) {
        return Status::success;
    }

    *err_msg = "reorder_k(" + std::to_string(reorder_k.value()) + ") should be larger than k(" +
               std::to_string(k.value()) + ")";
    LOG_KNOWHERE_ERROR_ << *err_msg;
    return Status::out_of_range_in_json;
}

// Range-search validation for SCANN: succeeds when
// faiss::support_pq_fast_scan is true; otherwise writes the reason to
// *err_msg (must be non-null), logs it, and returns
// Status::invalid_instruction_set.
inline Status
CheckAndAdjustForRangeSearch(std::string* err_msg) override {
    if (faiss::support_pq_fast_scan) {
        return Status::success;
    }

    *err_msg = "SCANN index is not supported on the current CPU model, avx2 support is needed for x86 arch.";
    LOG_KNOWHERE_ERROR_ << *err_msg;
    return Status::invalid_instruction_set;
}

inline Status
CheckAndAdjustForBuild() override {
if (!faiss::support_pq_fast_scan) {
LOG_KNOWHERE_ERROR_
<< "SCANN index is not supported on the current CPU model, avx2 support is needed for x86 arch.";
return Status::invalid_instruction_set;
// Validates/adjusts SCANN parameters depending on which parameter set is
// being processed.
//
//  - TRAIN, SEARCH, RANGE_SEARCH: fail with Status::invalid_instruction_set
//    when faiss::support_pq_fast_scan is false.
//  - SEARCH additionally: an unset reorder_k defaults to k; a reorder_k
//    smaller than k is rejected with Status::out_of_range_in_json.
//  - any other param_type: nothing is checked.
//
// err_msg may be null. When it is non-null, every failure path stores its
// reason there; the reason is always logged.
Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
    // Route messages through a local buffer when the caller passed no
    // destination. The previous code heap-allocated a std::string in that
    // case, which leaked and was never visible to the caller anyway.
    std::string local_msg;
    std::string& msg = (err_msg != nullptr) ? *err_msg : local_msg;

    switch (param_type) {
        case PARAM_TYPE::TRAIN:
        case PARAM_TYPE::SEARCH:
        case PARAM_TYPE::RANGE_SEARCH: {
            if (!faiss::support_pq_fast_scan) {
                // Report the reason to the caller as well as the log.
                msg = "SCANN index is not supported on the current CPU model, avx2 support is "
                      "needed for x86 arch.";
                LOG_KNOWHERE_ERROR_ << msg;
                return Status::invalid_instruction_set;
            }
            break;
        }
        default:
            break;
    }

    if (param_type == PARAM_TYPE::SEARCH) {
        if (!reorder_k.has_value()) {
            reorder_k = k.value();
        } else if (reorder_k.value() < k.value()) {
            msg = "reorder_k(" + std::to_string(reorder_k.value()) + ") should be larger than k(" +
                  std::to_string(k.value()) + ")";
            LOG_KNOWHERE_ERROR_ << msg;
            return Status::out_of_range_in_json;
        }
    }
    return Status::success;
}
Expand Down
Loading

0 comments on commit ee05caa

Please sign in to comment.