Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reproducible parameter alias resolution for wrappers (fixes #5304) #5338

Merged
merged 8 commits into from
Jul 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,15 @@ lgb.check.eval <- function(params, eval) {
# ways, the first item in this list is used:
#
# 1. the main (non-alias) parameter found in `params`
# 2. the first alias of that parameter found in `params`
# 2. the alias with the highest priority found in `params`
# 3. the keyword argument passed in
#
# For example, "num_iterations" can also be provided to lgb.train()
# via keyword "nrounds". lgb.train() will choose one value for this parameter
# based on the first match in this list:
#
# 1. params[["num_iterations]]
# 2. the first alias of "num_iterations" found in params
# 2. the highest priority alias of "num_iterations" found in params
# 3. the nrounds keyword argument
#
# If multiple aliases are found in `params` for the same parameter, they are
Expand All @@ -197,7 +197,7 @@ lgb.check.eval <- function(params, eval) {
lgb.check.wrapper_param <- function(main_param_name, params, alternative_kwarg_value) {

aliases <- .PARAMETER_ALIASES()[[main_param_name]]
aliases_provided <- names(params)[names(params) %in% aliases]
aliases_provided <- aliases[aliases %in% names(params)]
aliases_provided <- aliases_provided[aliases_provided != main_param_name]

# prefer the main parameter
Expand Down
1 change: 1 addition & 0 deletions R-package/tests/testthat/test_parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ test_that(".PARAMETER_ALIASES() returns a named list of character vectors, where
expect_true(all(sapply(param_aliases, is.character)))
expect_true(length(unique(names(param_aliases))) == length(param_aliases))
expect_equal(sort(param_aliases[["task"]]), c("task", "task_type"))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, looking at this again....why does this assertion need to be removed?

I understand why this PR adds the other one below that doesn't use sort(), but could we still keep the rest of this test intact? That would give me a bit more confidence that this change is working as expected, and having tests here referencing two different parameters increases the likelihood of catching bugs of the form ".PARAMETER_ALIASES() returns unexpected output for some but not all parameters".

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Restored in e45fc48

expect_equal(param_aliases[["bagging_fraction"]], c("bagging_fraction", "bagging", "sub_row", "subsample"))
})

test_that(".PARAMETER_ALIASES() uses the internal session cache", {
Expand Down
6 changes: 3 additions & 3 deletions R-package/tests/testthat/test_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ test_that("lgb.check.wrapper_param() prefers alias to keyword arg", {
expect_equal(params[["num_iterations"]], num_tree)
expect_identical(params, list(num_iterations = num_tree))

# switching the order should switch which one is chosen
# switching the order shouldn't switch which one is chosen
params2 <- lgb.check.wrapper_param(
main_param_name = "num_iterations"
, params = list(
Expand All @@ -132,6 +132,6 @@ test_that("lgb.check.wrapper_param() prefers alias to keyword arg", {
)
, alternative_kwarg_value = kwarg_val
)
expect_equal(params2[["num_iterations"]], n_estimators)
expect_identical(params2, list(num_iterations = n_estimators))
expect_equal(params2[["num_iterations"]], num_tree)
expect_identical(params2, list(num_iterations = num_tree))
})
26 changes: 13 additions & 13 deletions helpers/parameter_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,20 +359,20 @@ def gen_parameter_code(
str_to_write += " return str_buf.str();\n"
str_to_write += "}\n\n"

str_to_write += "const std::string Config::DumpAliases() {\n"
str_to_write += " std::stringstream str_buf;\n"
str_to_write += ' str_buf << "{";\n'
for idx, name in enumerate(names):
if idx > 0:
str_to_write += ', ";\n'
aliases = '\\", \\"'.join([alias for alias in names_with_aliases[name]])
aliases = f'[\\"{aliases}\\"]' if aliases else '[]'
str_to_write += f' str_buf << "\\"{name}\\": {aliases}'
str_to_write += '";\n'
str_to_write += ' str_buf << "}";\n'
str_to_write += " return str_buf.str();\n"
str_to_write += "}\n\n"
str_to_write += """const std::unordered_map<std::string, std::vector<std::string>>& Config::parameter2aliases() {
static std::unordered_map<std::string, std::vector<std::string>> map({"""
for name in names:
str_to_write += '\n {"' + name + '", '
if names_with_aliases[name]:
str_to_write += '{"' + '", "'.join(names_with_aliases[name]) + '"}},'
else:
str_to_write += '{}},'
str_to_write += """
});
return map;
}

"""
str_to_write += "} // namespace LightGBM\n"
with open(config_out_cpp, "w") as config_out_cpp_file:
config_out_cpp_file.write(str_to_write)
Expand Down
17 changes: 14 additions & 3 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ struct Config {
const std::unordered_map<std::string, std::string>& params,
const std::string& name, bool* out);

/*!
* \brief Sort aliases by length and then alphabetically
* \param x Alias 1
* \param y Alias 2
* \return true if x has higher priority than y
*/
inline static bool SortAlias(const std::string& x, const std::string& y);

static void KV2Map(std::unordered_map<std::string, std::string>* params, const char* kv);
static std::unordered_map<std::string, std::string> Str2Map(const char* parameters);

Expand Down Expand Up @@ -1063,6 +1071,7 @@ struct Config {
bool is_data_based_parallel = false;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params);
static const std::unordered_map<std::string, std::string>& alias_table();
static const std::unordered_map<std::string, std::vector<std::string>>& parameter2aliases();
static const std::unordered_set<std::string>& parameter_set();
std::vector<std::vector<double>> auc_mu_weights_matrix;
std::vector<std::vector<int>> interaction_constraints_vector;
Expand Down Expand Up @@ -1131,6 +1140,10 @@ inline bool Config::GetBool(
return false;
}

inline bool Config::SortAlias(const std::string& x, const std::string& y) {
return x.size() < y.size() || (x.size() == y.size() && x < y);
}

struct ParameterAlias {
static void KeyAliasTransform(std::unordered_map<std::string, std::string>* params) {
std::unordered_map<std::string, std::string> tmp_map;
Expand All @@ -1139,9 +1152,7 @@ struct ParameterAlias {
if (alias != Config::alias_table().end()) { // found alias
auto alias_set = tmp_map.find(alias->second);
if (alias_set != tmp_map.end()) { // alias already set
// set priority by length & alphabetically to ensure reproducible behavior
if (alias_set->second.size() < pair.first.size() ||
(alias_set->second.size() == pair.first.size() && alias_set->second < pair.first)) {
if (Config::SortAlias(alias_set->second, pair.first)) {
Log::Warning("%s is set with %s=%s, %s=%s will be ignored. Current value: %s=%s",
alias->second.c_str(), alias_set->second.c_str(), params->at(alias_set->second).c_str(),
pair.first.c_str(), pair.second.c_str(), alias->second.c_str(), params->at(alias_set->second).c_str());
Expand Down
17 changes: 12 additions & 5 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ class _ConfigAliases:
aliases = None

@staticmethod
def _get_all_param_aliases() -> Dict[str, Set[str]]:
def _get_all_param_aliases() -> Dict[str, List[str]]:
buffer_len = 1 << 20
tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len)
Expand All @@ -361,7 +361,7 @@ def _get_all_param_aliases() -> Dict[str, Set[str]]:
ptr_string_buffer))
aliases = json.loads(
string_buffer.value.decode('utf-8'),
object_hook=lambda obj: {k: set(v) | {k} for k, v in obj.items()}
object_hook=lambda obj: {k: [k] + v for k, v in obj.items()}
jmoralez marked this conversation as resolved.
Show resolved Hide resolved
)
return aliases

Expand All @@ -371,9 +371,15 @@ def get(cls, *args) -> Set[str]:
cls.aliases = cls._get_all_param_aliases()
ret = set()
for i in args:
ret |= cls.aliases.get(i, {i})
ret.update(cls.get_sorted(i))
return ret

@classmethod
def get_sorted(cls, name: str) -> List[str]:
if cls.aliases is None:
cls.aliases = cls._get_all_param_aliases()
return cls.aliases.get(name, [name])

@classmethod
def get_by_alias(cls, *args) -> Set[str]:
if cls.aliases is None:
Expand All @@ -382,7 +388,7 @@ def get_by_alias(cls, *args) -> Set[str]:
for arg in args:
for aliases in cls.aliases.values():
if arg in aliases:
ret |= aliases
ret.update(aliases)
break
return ret

Expand All @@ -408,7 +414,8 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va
# avoid side effects on passed-in parameters
params = deepcopy(params)

aliases = _ConfigAliases.get(main_param_name) - {main_param_name}
aliases = _ConfigAliases.get_sorted(main_param_name)
aliases = [a for a in aliases if a != main_param_name]

# if main_param_name was provided, keep that value and remove all aliases
if main_param_name in params.keys():
Expand Down
25 changes: 25 additions & 0 deletions src/io/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,4 +411,29 @@ std::string Config::ToString() const {
return str_buf.str();
}

const std::string Config::DumpAliases() {
auto map = Config::parameter2aliases();
for (auto& pair : map) {
std::sort(pair.second.begin(), pair.second.end(), SortAlias);
}
std::stringstream str_buf;
str_buf << "{\n";
bool first = true;
for (const auto& pair : map) {
if (first) {
str_buf << " \"";
first = false;
} else {
str_buf << " , \"";
}
str_buf << pair.first << "\": [";
if (pair.second.size() > 0) {
str_buf << "\"" << CommonC::Join(pair.second, "\", \"") << "\"";
}
str_buf << "]\n";
}
str_buf << "}\n";
return str_buf.str();
}

} // namespace LightGBM
Loading