From 2a6f58a5a4ec834b199b5ee16487afbe57444f48 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Wed, 7 Apr 2021 11:39:36 -0400 Subject: [PATCH] Migrate bfgs to new style --- vowpalwabbit/bfgs.cc | 86 +++++++++++++++++++++++------------------- vowpalwabbit/learner.h | 8 ++-- 2 files changed, 51 insertions(+), 43 deletions(-) diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc index 1efb3bbed7b..2fca2a1943a 100644 --- a/vowpalwabbit/bfgs.cc +++ b/vowpalwabbit/bfgs.cc @@ -61,47 +61,49 @@ constexpr float max_precond_ratio = 10000.f; struct bfgs { - vw* all; // prediction, regressor - int m; - float rel_threshold; // termination threshold + vw* all = nullptr; // prediction, regressor + int m = 0; + float rel_threshold = 0.f; // termination threshold - double wolfe1_bound; + double wolfe1_bound = 0.0; - size_t final_pass; + size_t final_pass = 0; std::chrono::time_point t_start_global; std::chrono::time_point t_end_global; - double net_time; + double net_time = 0.0; v_array predictions; - size_t example_number; - size_t current_pass; - size_t no_win_counter; - size_t early_stop_thres; + size_t example_number = 0; + size_t current_pass = 0; + size_t no_win_counter = 0; + size_t early_stop_thres = 0; // default transition behavior - bool first_hessian_on; - bool backstep_on; + bool first_hessian_on = false; + bool backstep_on = false; // set by initializer - int mem_stride; - bool output_regularizer; - float* mem; - double* rho; - double* alpha; + int mem_stride = 0; + bool output_regularizer = false; + float* mem = nullptr; + double* rho = nullptr; + double* alpha = nullptr; - weight* regularizers; + weight* regularizers = nullptr; // the below needs to be included when resetting, in addition to preconditioner and derivative - int lastj, origin; - double loss_sum, previous_loss_sum; - float step_size; - double importance_weight_sum; - double curvature; + int lastj = 0; + int origin = 0; + double loss_sum = 0.0; + double previous_loss_sum = 0.0; + float step_size = 
0.f; + double importance_weight_sum = 0.0; + double curvature = 0.0; // first pass specification - bool first_pass; - bool gradient_pass; - bool preconditioner_pass; + bool first_pass = false; + bool gradient_pass = false; + bool preconditioner_pass = false; ~bfgs() { @@ -1053,7 +1055,7 @@ void init_driver(bfgs& b) { b.backstep_on = true; } base_learner* bfgs_setup(options_i& options, vw& all) { - auto b = scoped_calloc_or_throw(); + auto b = VW::make_unique(); bool conjugate_gradient = false; bool bfgs_option = false; option_group_definition bfgs_outer_options("Conjugate Gradient options"); @@ -1109,20 +1111,26 @@ base_learner* bfgs_setup(options_i& options, vw& all) all.weights.stride_shift(2); void (*learn_ptr)(bfgs&, base_learner&, example&) = nullptr; - if (all.audit) + void (*predict_ptr)(bfgs&, base_learner&, example&) = nullptr; + std::string learner_name; + if (all.audit || all.hash_inv) + { learn_ptr = learn; + predict_ptr = predict; + learner_name = all.get_setupfn_name(bfgs_setup) + "-audit"; + } else + { learn_ptr = learn; + predict_ptr = predict; + learner_name = all.get_setupfn_name(bfgs_setup); + } - learner* l; - if (all.audit || all.hash_inv) - l = &init_learner(b, learn_ptr, predict, all.weights.stride(), all.get_setupfn_name(bfgs_setup) + "-audit"); - else - l = &init_learner(b, learn_ptr, predict, all.weights.stride(), all.get_setupfn_name(bfgs_setup)); - - l->set_save_load(save_load); - l->set_init_driver(init_driver); - l->set_end_pass(end_pass); - - return make_base(*l); + return make_base(*make_base_learner( + std::move(b), learn_ptr, predict_ptr, learner_name, prediction_type_t::scalar, label_type_t::simple) + .set_params_per_weight(all.weights.stride()) + .set_save_load(save_load) + .set_init_driver(init_driver) + .set_end_pass(end_pass) + .build()); } diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h index b2612fb38eb..4b7d0f4e00e 100644 --- a/vowpalwabbit/learner.h +++ b/vowpalwabbit/learner.h @@ -764,25 +764,25 @@ struct 
common_learner_builder FluentBuilderT& set_finish(void (*fn_ptr)(DataT&)) { - _learner->finisher_fd = func_data(_learner->learn_fd.data, _learner->learn_fd.base, (finish_fptr_type)(fn_ptr)); + _learner->finisher_fd = tuple_dbf(_learner->learn_fd.data, _learner->learn_fd.base, (finish_fptr_type)(fn_ptr)); return *static_cast(this); } FluentBuilderT& set_end_pass(void (*fn_ptr)(DataT&)) { - _learner->end_pass_fd = func_data(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr); + _learner->end_pass_fd = tuple_dbf(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr); return *static_cast(this); } FluentBuilderT& set_end_examples(void (*fn_ptr)(DataT&)) { - _learner->end_examples_fd = func_data(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr); + _learner->end_examples_fd = tuple_dbf(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr); return *static_cast(this); } FluentBuilderT& set_init_driver(void (*fn_ptr)(DataT&)) { - _learner->init_fd = func_data(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr); + _learner->init_fd = tuple_dbf(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr); return *static_cast(this); }