From 6f1debf651ca14c31b9968df4f91964c30e978a8 Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Wed, 4 Sep 2024 16:26:58 +0800 Subject: [PATCH] Remove CFG_LIST and CFG_BYTES, and optimize error msg in config load (#811) Signed-off-by: Cai Yudong --- include/knowhere/config.h | 324 ++++++++----------------- scripts/run_codecov.sh | 2 +- tests/ut/test_config.cc | 488 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 586 insertions(+), 228 deletions(-) diff --git a/include/knowhere/config.h b/include/knowhere/config.h index 3bb3088da..146938387 100644 --- a/include/knowhere/config.h +++ b/include/knowhere/config.h @@ -46,20 +46,12 @@ typedef nlohmann::json Json; #define CFG_FLOAT std::optional #endif -#ifndef CFG_LIST -#define CFG_LIST std::optional> -#endif - #ifndef CFG_BOOL #define CFG_BOOL std::optional #endif #ifndef CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE -#define CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE std::optional -#endif - -#ifndef CFG_BYTES -#define CFG_BYTES std::optional> +#define CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE std::optional #endif template @@ -147,29 +139,6 @@ struct Entry { bool allow_empty_without_default = false; }; -template <> -struct Entry { - explicit Entry(CFG_LIST* v) { - val = v; - default_val = std::nullopt; - type = 0x0; - desc = std::nullopt; - } - - Entry() { - val = nullptr; - default_val = std::nullopt; - type = 0x0; - desc = std::nullopt; - } - - CFG_LIST* val; - std::optional default_val; - uint32_t type; - std::optional desc; - bool allow_empty_without_default = false; -}; - template <> struct Entry { explicit Entry(CFG_BOOL* v) { @@ -216,29 +185,6 @@ struct Entry { bool allow_empty_without_default = false; }; -template <> -struct Entry { - explicit Entry(CFG_BYTES* v) { - val = v; - default_val = std::nullopt; - type = 0x0; - desc = std::nullopt; - } - - Entry() { - val = nullptr; - default_val = std::nullopt; - type = 0x0; - desc = std::nullopt; - } - - CFG_BYTES* val; - std::optional default_val; - uint32_t type; - std::optional desc; - bool allow_empty_without_default = false; -}; - template class EntryAccess { public: @@ -336,6 +282,12 @@ class Config { static Status Load(Config& cfg, const Json& json, PARAM_TYPE type, std::string* const err_msg = nullptr) { + auto show_err_msg = [&](std::string& msg) { + LOG_KNOWHERE_ERROR_ << msg; + if (err_msg) { + *err_msg = msg; + } + }; for (const auto& it : cfg.__DICT__) { const auto& var = it.second; @@ -343,47 +295,43 @@ class Config { if (!(type & ptr->type)) { continue; } - if (json.find(it.first) == json.end() && !ptr->default_val.has_value()) { - if (ptr->allow_empty_without_default) { + if (json.find(it.first) == json.end()) { + if (!ptr->default_val.has_value()) { + if (ptr->allow_empty_without_default) { + continue; + } + std::string msg = "param '" + it.first + "' not exist in json"; + show_err_msg(msg); + return Status::invalid_param_in_json; + } else { + *ptr->val = ptr->default_val; continue; } - LOG_KNOWHERE_ERROR_ << "Invalid param [" << it.first << "] in json."; - if (err_msg) { - *err_msg = std::string("invalid param ") + it.first; - } - return Status::invalid_param_in_json; - } - if (json.find(it.first) == json.end()) { - *ptr->val = ptr->default_val; - continue; } if (!json[it.first].is_number_integer()) { - LOG_KNOWHERE_ERROR_ << "Type conflict in json: param [" << it.first << "] should be integer."; - if (err_msg) { - *err_msg = std::string("param ") + it.first + " should be integer"; - } + std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) + + ") should be integer"; + show_err_msg(msg); return Status::type_conflict_in_json; } if (ptr->range.has_value()) { if (json[it.first].get() > std::numeric_limits::max()) { - LOG_KNOWHERE_ERROR_ << "Arithmetic overflow: param [" << it.first << "] should be at most " - << std::numeric_limits::max(); - if (err_msg) { - *err_msg = std::string("param ") + it.first + " should be at most 2147483647"; - } + std::string msg = "Arithmetic overflow: param '" + it.first + "' (" + + to_string(json[it.first]) + ") should not bigger than " + + std::to_string(std::numeric_limits::max()); + show_err_msg(msg); return Status::arithmetic_overflow; } CFG_INT::value_type v = json[it.first]; - if (ptr->range.value().first <= v && v <= ptr->range.value().second) { + auto range_val = ptr->range.value(); + if (range_val.first <= v && v <= range_val.second) { *ptr->val = v; } else { - LOG_KNOWHERE_ERROR_ << "Out of range in json: param [" << it.first << "] should be in [" - << ptr->range.value().first << ", " << ptr->range.value().second << "]."; - if (err_msg) { - *err_msg = std::string("param ") + it.first + " out of range " + "[ " + - std::to_string(ptr->range.value().first) + "," + - std::to_string(ptr->range.value().second) + " ]"; - } + std::string msg = "Out of range in json: param '" + it.first + "' (" + + to_string(json[it.first]) + ") should be in range [" + + std::to_string(range_val.first) + ", " + std::to_string(range_val.second) + + "]"; + show_err_msg(msg); return Status::out_of_range_in_json; } } else { @@ -395,51 +343,43 @@ class Config { if (!(type & ptr->type)) { continue; } - if (json.find(it.first) == json.end() && !ptr->default_val.has_value()) { - if (ptr->allow_empty_without_default) { + if (json.find(it.first) == json.end()) { + if (!ptr->default_val.has_value()) { + if (ptr->allow_empty_without_default) { + continue; + } + std::string msg = "param '" + it.first + "' not exist in json"; + show_err_msg(msg); + return Status::invalid_param_in_json; + } else { + *ptr->val = ptr->default_val; continue; } - LOG_KNOWHERE_ERROR_ << "Invalid param [" << it.first << "] in json."; - if (err_msg) { - *err_msg = std::string("invalid param ") + it.first; - } - - return Status::invalid_param_in_json; - } - if (json.find(it.first) == json.end()) { - *ptr->val = ptr->default_val; - continue; } if (!json[it.first].is_number()) { - LOG_KNOWHERE_ERROR_ << "Type conflict in json: param [" << it.first << "] should be a number."; - if (err_msg) { - *err_msg = std::string("param ") + it.first + " should be a number"; - } - + std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) + + ") should be a number"; + show_err_msg(msg); return Status::type_conflict_in_json; } if (ptr->range.has_value()) { if (json[it.first].get() > std::numeric_limits::max()) { - LOG_KNOWHERE_ERROR_ << "Arithmetic overflow: param [" << it.first << "] should be at most " - << std::numeric_limits::max(); - if (err_msg) { - *err_msg = std::string("param ") + it.first + " should be at most 3.402823e+38"; - } - + std::string msg = "Arithmetic overflow: param '" + it.first + "' (" + + to_string(json[it.first]) + ") should not bigger than " + + std::to_string(std::numeric_limits::max()); + show_err_msg(msg); return Status::arithmetic_overflow; } CFG_FLOAT::value_type v = json[it.first]; - if (ptr->range.value().first <= v && v <= ptr->range.value().second) { + auto range_val = ptr->range.value(); + if (range_val.first <= v && v <= range_val.second) { *ptr->val = v; } else { - LOG_KNOWHERE_ERROR_ << "Out of range in json: param [" << it.first << "] should be in [" - << ptr->range.value().first << ", " << ptr->range.value().second << "]."; - if (err_msg) { - *err_msg = std::string("param ") + it.first + " out of range " + "[ " + - std::to_string(ptr->range.value().first) + "," + - std::to_string(ptr->range.value().second) + " ]"; - } - + std::string msg = "Out of range in json: param '" + it.first + "' (" + + to_string(json[it.first]) + ") should be in range [" + + std::to_string(range_val.first) + ", " + std::to_string(range_val.second) + + "]"; + show_err_msg(msg); return Status::out_of_range_in_json; } } else { @@ -451,88 +391,49 @@ class Config { if (!(type & ptr->type)) { continue; } - if (json.find(it.first) == json.end() && !ptr->default_val.has_value()) { - if (ptr->allow_empty_without_default) { + if (json.find(it.first) == json.end()) { + if (!ptr->default_val.has_value()) { + if (ptr->allow_empty_without_default) { + continue; + } + std::string msg = "param [" + it.first + "] not exist in json"; + show_err_msg(msg); + return Status::invalid_param_in_json; + } else { + *ptr->val = ptr->default_val; continue; } - LOG_KNOWHERE_ERROR_ << "Invalid param [" << it.first << "] in json."; - if (err_msg) { - *err_msg = std::string("invalid param ") + it.first; - } - return Status::invalid_param_in_json; - } - if (json.find(it.first) == json.end()) { - *ptr->val = ptr->default_val; - continue; } if (!json[it.first].is_string()) { - LOG_KNOWHERE_ERROR_ << "Type conflict in json: param [" << it.first << "] should be a string."; - if (err_msg) { - *err_msg = std::string("param ") + it.first + " should be a string"; - } + std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) + + ") should be a string"; + show_err_msg(msg); return Status::type_conflict_in_json; } *ptr->val = json[it.first]; } - if (const Entry* ptr = std::get_if>(&var)) { - if (!(type & ptr->type)) { - continue; - } - if (json.find(it.first) == json.end() && !ptr->default_val.has_value()) { - if (ptr->allow_empty_without_default) { - continue; - } - LOG_KNOWHERE_ERROR_ << "Invalid param [" << it.first << "] in json."; - if (err_msg) { - *err_msg = std::string("invalid param ") + it.first; - } - - return Status::invalid_param_in_json; - } - if (json.find(it.first) == json.end()) { - *ptr->val = ptr->default_val; - continue; - } - if (!json[it.first].is_array()) { - LOG_KNOWHERE_ERROR_ << "Type conflict in json: param [" << it.first << "] should be an array."; - if (err_msg) { - *err_msg = std::string("param ") + it.first + " should be an array"; - } - - return Status::type_conflict_in_json; - } - *ptr->val = CFG_LIST::value_type(); - for (auto&& i : json[it.first]) { - ptr->val->value().push_back(i); - } - } - if (const Entry* ptr = std::get_if>(&var)) { if (!(type & ptr->type)) { continue; } - if (json.find(it.first) == json.end() && !ptr->default_val.has_value()) { - if (ptr->allow_empty_without_default) { + if (json.find(it.first) == json.end()) { + if (!ptr->default_val.has_value()) { + if (ptr->allow_empty_without_default) { + continue; + } + std::string msg = "param '" + it.first + "' not exist in json"; + show_err_msg(msg); + return Status::invalid_param_in_json; + } else { + *ptr->val = ptr->default_val; continue; } - LOG_KNOWHERE_ERROR_ << "Invalid param [" << it.first << "] in json."; - if (err_msg) { - *err_msg = std::string("invalid param ") + it.first; - } - - return Status::invalid_param_in_json; - } - if (json.find(it.first) == json.end()) { - *ptr->val = ptr->default_val; - continue; } if (!json[it.first].is_boolean()) { - LOG_KNOWHERE_ERROR_ << "Type conflict in json: param [" << it.first << "] should be a boolean."; - if (err_msg) { - *err_msg = std::string("param ") + it.first + " should be a boolean"; - } - + std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) + + ") should be a boolean"; + show_err_msg(msg); return Status::type_conflict_in_json; } *ptr->val = json[it.first]; @@ -543,52 +444,20 @@ class Config { if (!(type & ptr->type)) { continue; } - if (json.find(it.first) == json.end() && !ptr->default_val.has_value()) { - if (ptr->allow_empty_without_default) { - continue; - } - LOG_KNOWHERE_ERROR_ << "Invalid param [" << it.first << "] in json."; - if (err_msg) { - *err_msg = std::string("invalid param ") + it.first; - } - return Status::invalid_param_in_json; - } if (json.find(it.first) == json.end()) { - *ptr->val = ptr->default_val; - continue; - } - *ptr->val = json[it.first]; - } - - if (const Entry* ptr = std::get_if>(&var)) { - if (!(type & ptr->type)) { - continue; - } - if (json.find(it.first) == json.end() && !ptr->default_val.has_value()) { - if (ptr->allow_empty_without_default) { + if (!ptr->default_val.has_value()) { + if (ptr->allow_empty_without_default) { + continue; + } + std::string msg = "param '" + it.first + "' not exist in json"; + show_err_msg(msg); + return Status::invalid_param_in_json; + } else { + *ptr->val = ptr->default_val; continue; } - LOG_KNOWHERE_ERROR_ << "Invalid param [" << it.first << "] in json."; - if (err_msg) { - *err_msg = std::string("invalid param ") + it.first; - } - return Status::invalid_param_in_json; - } - if (json.find(it.first) == json.end()) { - *ptr->val = ptr->default_val; - continue; - } - if (!json[it.first].is_array()) { - LOG_KNOWHERE_ERROR_ << "Type conflict in json: param [" << it.first << "] should be an array."; - if (err_msg) { - *err_msg = std::string("param ") + it.first + " should be an array"; - } - return Status::type_conflict_in_json; - } - *ptr->val = CFG_BYTES::value_type(); - for (auto&& i : json[it.first]) { - ptr->val->value().push_back(i); } + *ptr->val = json[it.first]; } } @@ -602,8 +471,8 @@ class Config { virtual ~Config() { } - using VarEntry = std::variant, Entry, Entry, Entry, Entry, - Entry, Entry>; + using VarEntry = std::variant, Entry, Entry, Entry, + Entry>; std::unordered_map __DICT__; protected: @@ -615,9 +484,10 @@ class Config { #define KNOHWERE_DECLARE_CONFIG(CONFIG) CONFIG() -#define KNOWHERE_CONFIG_DECLARE_FIELD(PARAM) \ - __DICT__[#PARAM] = knowhere::Config::VarEntry(std::in_place_type>, &PARAM); \ - EntryAccess PARAM##_access(std::get_if>(&__DICT__[#PARAM])); \ +#define KNOWHERE_CONFIG_DECLARE_FIELD(PARAM) \ + __DICT__[#PARAM] = knowhere::Config::VarEntry(std::in_place_type>, &PARAM); \ + knowhere::EntryAccess PARAM##_access( \ + std::get_if>(&__DICT__[#PARAM])); \ PARAM##_access const float defaultRangeFilter = 1.0f / 0.0; diff --git a/scripts/run_codecov.sh b/scripts/run_codecov.sh index 8eea419a5..3dce9dd0e 100755 --- a/scripts/run_codecov.sh +++ b/scripts/run_codecov.sh @@ -60,7 +60,7 @@ fi # run unittest for test in `ls ${KNOWHERE_UNITTEST_DIR}/*test*`; do - echo "Running unittest: ${KNOWHERE_UNITTEST_DIR}/$test" + echo "Running unittest: $test" # run unittest ${test} if [ $? -ne 0 ]; then diff --git a/tests/ut/test_config.cc b/tests/ut/test_config.cc index 29ed07dc6..0fd9e6d54 100644 --- a/tests/ut/test_config.cc +++ b/tests/ut/test_config.cc @@ -409,3 +409,491 @@ TEST_CASE("Test config json parse", "[config]") { #endif } + +TEST_CASE("Test config load", "[BOOL]") { + knowhere::Status s; + std::string err_msg; + + SECTION("check bool") { + class TestConfig : public knowhere::Config { + public: + CFG_BOOL bool_val; + KNOHWERE_DECLARE_CONFIG(TestConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(bool_val).description("bool field for test").for_train_and_search(); + } + }; + + TestConfig test_cfg; + knowhere::Json json; + + json = knowhere::Json::parse(R"({})"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::invalid_param_in_json); + + json = knowhere::Json::parse(R"({ + "bool_val": "a" + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::type_conflict_in_json); + + json = knowhere::Json::parse(R"({ + "bool_val": true + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.bool_val.value() == true); + } + + SECTION("check bool allow empty") { + class TestConfig : public knowhere::Config { + public: + CFG_BOOL bool_val; + KNOHWERE_DECLARE_CONFIG(TestConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(bool_val) + .description("bool field for test") + .allow_empty_without_default() + .for_train_and_search(); + } + }; + + TestConfig test_cfg; + knowhere::Json json; + + json = knowhere::Json::parse(R"({})"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + + json = knowhere::Json::parse(R"({ + "bool_val": "a" + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::type_conflict_in_json); + + json = knowhere::Json::parse(R"({ + "bool_val": true + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.bool_val.value() == true); + } + + SECTION("check bool with default") { + class TestConfig : public knowhere::Config { + public: + CFG_BOOL bool_val; + KNOHWERE_DECLARE_CONFIG(TestConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(bool_val) + .description("bool field for test") + .set_default(true) + .for_train_and_search(); + } + }; + + TestConfig test_cfg; + knowhere::Json json; + + json = knowhere::Json::parse(R"({})"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.bool_val.value() == true); + + json = knowhere::Json::parse(R"({ + "bool_val": "a" + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::type_conflict_in_json); + + json = knowhere::Json::parse(R"({ + "bool_val": false + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.bool_val.value() == false); + } +} + +TEST_CASE("Test config load", "[INT]") { + knowhere::Status s; + std::string err_msg; + + SECTION("check int") { + class TestConfig : public knowhere::Config { + public: + CFG_INT int_val; + KNOHWERE_DECLARE_CONFIG(TestConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(int_val).description("int field for test").for_train_and_search(); + } + }; + + TestConfig test_cfg; + knowhere::Json json; + + json = knowhere::Json::parse(R"({})"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::invalid_param_in_json); + + json = knowhere::Json::parse(R"({ + "int_val": "a" + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::type_conflict_in_json); + + json = knowhere::Json::parse(R"({ + "int_val": 10 + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.int_val.value() == 10); + } + + SECTION("check int allow empty") { + class TestConfig : public knowhere::Config { + public: + CFG_INT int_val; + KNOHWERE_DECLARE_CONFIG(TestConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(int_val) + .description("int field for test") + .allow_empty_without_default() + .for_train_and_search(); + } + }; + + TestConfig test_cfg; + knowhere::Json json; + + json = knowhere::Json::parse(R"({})"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + + json = knowhere::Json::parse(R"({ + "int_val": "a" + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::type_conflict_in_json); + + json = knowhere::Json::parse(R"({ + "int_val": 10 + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.int_val.value() == 10); + } + + SECTION("check int in range") { + class TestConfig : public knowhere::Config { + public: + CFG_INT int_val; + KNOHWERE_DECLARE_CONFIG(TestConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(int_val) + .description("int field for test") + .set_default(2) + .for_train_and_search() + .set_range(1, 100); + } + }; + + TestConfig test_cfg; + knowhere::Json json; + + json = knowhere::Json::parse(R"({})"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.int_val.value() == 2); + + json = knowhere::Json::parse(R"({ + "int_val": "a" + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::type_conflict_in_json); + + json = knowhere::Json::parse(R"({ + "int_val": 4294967296 + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::arithmetic_overflow); + + json = knowhere::Json::parse(R"({ + "int_val": 123 + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::out_of_range_in_json); + + json = knowhere::Json::parse(R"({ + "int_val": 10 + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.int_val.value() == 10); + } +} + +TEST_CASE("Test config load", "[FLOAT]") { + knowhere::Status s; + std::string err_msg; + + SECTION("check float") { + class TestConfig : public knowhere::Config { + public: + CFG_FLOAT float_val; + KNOHWERE_DECLARE_CONFIG(TestConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(float_val).description("float field for test").for_train_and_search(); + } + }; + + TestConfig test_cfg; + knowhere::Json json; + + json = knowhere::Json::parse(R"({})"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::invalid_param_in_json); + + json = knowhere::Json::parse(R"({ + "float_val": "a" + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::type_conflict_in_json); + + json = knowhere::Json::parse(R"({ + "float_val": 10 + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.float_val.value() == 10.0); + } + + SECTION("check float allow empty") { + class TestConfig : public knowhere::Config { + public: + CFG_FLOAT float_val; + KNOHWERE_DECLARE_CONFIG(TestConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(float_val) + .description("float field for test") + .allow_empty_without_default() + .for_train_and_search(); + } + }; + + TestConfig test_cfg; + knowhere::Json json; + + json = knowhere::Json::parse(R"({})"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + + json = knowhere::Json::parse(R"({ + "float_val": "a" + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::type_conflict_in_json); + + json = knowhere::Json::parse(R"({ + "float_val": 10 + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.float_val.value() == 10.0); + } + + SECTION("check float in range") { + class TestConfig : public knowhere::Config { + public: + CFG_FLOAT float_val; + KNOHWERE_DECLARE_CONFIG(TestConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(float_val) + .description("float field for test") + .set_default(2.0) + .for_train_and_search() + .set_range(1.0, 100.0); + } + }; + + TestConfig test_cfg; + knowhere::Json json; + + json = knowhere::Json::parse(R"({})"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.float_val.value() == 2.0); + + json = knowhere::Json::parse(R"({ + "float_val": "a" + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::type_conflict_in_json); + + json = knowhere::Json::parse(R"({ + "float_val": 1e+40 + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::arithmetic_overflow); + + json = knowhere::Json::parse(R"({ + "float_val": 123 + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::out_of_range_in_json); + + json = knowhere::Json::parse(R"({ + "float_val": 10 + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.float_val.value() == 10); + } +} + +TEST_CASE("Test config load", "[STRING]") { + knowhere::Status s; + std::string err_msg; + + SECTION("check string") { + class TestConfig : public knowhere::Config { + public: + CFG_STRING str_val; + KNOHWERE_DECLARE_CONFIG(TestConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(str_val).description("string field for test").for_train_and_search(); + } + }; + + TestConfig test_cfg; + knowhere::Json json; + + json = knowhere::Json::parse(R"({})"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::invalid_param_in_json); + + json = knowhere::Json::parse(R"({ + "str_val": 1 + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::type_conflict_in_json); + + json = knowhere::Json::parse(R"({ + "str_val": "abc" + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.str_val.value() == "abc"); + } + + SECTION("check string allow empty") { + class TestConfig : public knowhere::Config { + public: + CFG_STRING str_val; + KNOHWERE_DECLARE_CONFIG(TestConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(str_val) + .description("string field for test") + .allow_empty_without_default() + .for_train_and_search(); + } + }; + + TestConfig test_cfg; + knowhere::Json json; + + json = knowhere::Json::parse(R"({})"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + + json = knowhere::Json::parse(R"({ + "str_val": 1 + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::type_conflict_in_json); + + json = knowhere::Json::parse(R"({ + "str_val": "abc" + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.str_val.value() == "abc"); + } + + SECTION("check string with default") { + class TestConfig : public knowhere::Config { + public: + CFG_STRING str_val; + KNOHWERE_DECLARE_CONFIG(TestConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(str_val) + .description("string field for test") + .set_default("knowhere") + .for_train_and_search(); + } + }; + + TestConfig test_cfg; + knowhere::Json json; + + json = knowhere::Json::parse(R"({})"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.str_val.value() == "knowhere"); + + json = knowhere::Json::parse(R"({ + "str_val": 1 + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::type_conflict_in_json); + + json = knowhere::Json::parse(R"({ + "str_val": "abc" + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + CHECK(test_cfg.str_val.value() == "abc"); + } +} + +TEST_CASE("Test config load", "[MATERIALIZED_VIEW_SEARCH_INFO]") { + knowhere::Status s; + std::string err_msg; + + SECTION("check string") { + class TestConfig : public knowhere::Config { + public: + CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE info_val; + KNOHWERE_DECLARE_CONFIG(TestConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(info_val).description("info field for test").for_train_and_search(); + } + }; + + TestConfig test_cfg; + knowhere::Json json; + + json = knowhere::Json::parse(R"({})"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::invalid_param_in_json); + + json = knowhere::Json::parse(R"({ + "info_val": "" + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + } + + SECTION("check string allow empty") { + class TestConfig : public knowhere::Config { + public: + CFG_MATERIALIZED_VIEW_SEARCH_INFO_TYPE info_val; + KNOHWERE_DECLARE_CONFIG(TestConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(info_val) + .description("info field for test") + .allow_empty_without_default() + .for_train_and_search(); + } + }; + + TestConfig test_cfg; + knowhere::Json json; + + json = knowhere::Json::parse(R"({})"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + + json = knowhere::Json::parse(R"({ + "info_val": "" + })"); + s = knowhere::Config::Load(test_cfg, json, knowhere::TRAIN, &err_msg); + CHECK(s == knowhere::Status::success); + } +}