refactor: migrate bfgs to new style learner creation #2926

Merged
merged 2 commits on Apr 8, 2021
86 changes: 47 additions & 39 deletions vowpalwabbit/bfgs.cc
@@ -61,47 +61,49 @@ constexpr float max_precond_ratio = 10000.f;

 struct bfgs
 {
-  vw* all; // prediction, regressor
-  int m;
-  float rel_threshold; // termination threshold
+  vw* all = nullptr; // prediction, regressor
+  int m = 0;
+  float rel_threshold = 0.f; // termination threshold

-  double wolfe1_bound;
+  double wolfe1_bound = 0.0;

-  size_t final_pass;
+  size_t final_pass = 0;

   std::chrono::time_point<std::chrono::system_clock> t_start_global;
   std::chrono::time_point<std::chrono::system_clock> t_end_global;
-  double net_time;
+  double net_time = 0.0;

   v_array<float> predictions;
-  size_t example_number;
-  size_t current_pass;
-  size_t no_win_counter;
-  size_t early_stop_thres;
+  size_t example_number = 0;
+  size_t current_pass = 0;
+  size_t no_win_counter = 0;
+  size_t early_stop_thres = 0;

   // default transition behavior
-  bool first_hessian_on;
-  bool backstep_on;
+  bool first_hessian_on = false;
+  bool backstep_on = false;

   // set by initializer
-  int mem_stride;
-  bool output_regularizer;
-  float* mem;
-  double* rho;
-  double* alpha;
+  int mem_stride = 0;
+  bool output_regularizer = false;
+  float* mem = nullptr;
+  double* rho = nullptr;
+  double* alpha = nullptr;

-  weight* regularizers;
+  weight* regularizers = nullptr;
   // the below needs to be included when resetting, in addition to preconditioner and derivative
-  int lastj, origin;
-  double loss_sum, previous_loss_sum;
-  float step_size;
-  double importance_weight_sum;
-  double curvature;
+  int lastj = 0;
+  int origin = 0;
+  double loss_sum = 0.0;
+  double previous_loss_sum = 0.0;
+  float step_size = 0.f;
+  double importance_weight_sum = 0.0;
+  double curvature = 0.0;

   // first pass specification
-  bool first_pass;
-  bool gradient_pass;
-  bool preconditioner_pass;
+  bool first_pass = false;
+  bool gradient_pass = false;
+  bool preconditioner_pass = false;

   ~bfgs()
   {
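The in-class initializers added above pair with the allocation change further down in this diff: scoped_calloc_or_throw zero-filled the whole struct, while VW::make_unique gives no such blanket guarantee, so each member now states its own default and the zero state becomes a property of the type rather than of the allocator. A minimal standalone sketch of the difference, using illustrative types that are not part of VW:

#include <iostream>
#include <memory>

struct without_init
{
  int m;            // indeterminate when the object is default-initialized
  float threshold;  // same
};

struct with_init
{
  int m = 0;              // well-defined no matter how the object is constructed
  float threshold = 0.f;
};

int main()
{
  // `new T` with no parentheses default-initializes: the scalars in
  // without_init hold garbage, and reading them would be undefined behavior.
  std::unique_ptr<without_init> a(new without_init);

  // The in-class initializers make the zero state part of the type itself, so
  // it does not matter whether the object comes from calloc, new, or make_unique.
  std::unique_ptr<with_init> b(new with_init);
  std::cout << b->m << " " << b->threshold << "\n";  // always prints "0 0"
  return 0;
}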
@@ -1054,7 +1056,7 @@ void init_driver(bfgs& b) { b.backstep_on = true; }

 base_learner* bfgs_setup(options_i& options, vw& all)
 {
-  auto b = scoped_calloc_or_throw<bfgs>();
+  auto b = VW::make_unique<bfgs>();
   bool conjugate_gradient = false;
   bool bfgs_option = false;
   option_group_definition bfgs_outer_options("Conjugate Gradient options");
Expand Down Expand Up @@ -1110,20 +1112,26 @@ base_learner* bfgs_setup(options_i& options, vw& all)
   all.weights.stride_shift(2);

   void (*learn_ptr)(bfgs&, base_learner&, example&) = nullptr;
-  if (all.audit)
+  void (*predict_ptr)(bfgs&, base_learner&, example&) = nullptr;
+  std::string learner_name;
+  if (all.audit || all.hash_inv)
   {
     learn_ptr = learn<true>;
+    predict_ptr = predict<true>;
+    learner_name = all.get_setupfn_name(bfgs_setup) + "-audit";
   }
   else
   {
     learn_ptr = learn<false>;
+    predict_ptr = predict<false>;
+    learner_name = all.get_setupfn_name(bfgs_setup);
   }

-  learner<bfgs, example>* l;
-  if (all.audit || all.hash_inv)
-    l = &init_learner(b, learn_ptr, predict<true>, all.weights.stride(), all.get_setupfn_name(bfgs_setup) + "-audit");
-  else
-    l = &init_learner(b, learn_ptr, predict<false>, all.weights.stride(), all.get_setupfn_name(bfgs_setup));
-
-  l->set_save_load(save_load);
-  l->set_init_driver(init_driver);
-  l->set_end_pass(end_pass);
-
-  return make_base(*l);
+  return make_base(*make_base_learner(
+      std::move(b), learn_ptr, predict_ptr, learner_name, prediction_type_t::scalar, label_type_t::simple)
+                        .set_params_per_weight(all.weights.stride())
+                        .set_save_load(save_load)
+                        .set_init_driver(init_driver)
+                        .set_end_pass(end_pass)
+                        .build());
 }
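Taken together, the new-style setup replaces init_learner plus a series of setter calls on the returned learner with a single fluent-builder expression: the reduction data is moved into make_base_learner along with the learn/predict functions, the learner name, and the prediction/label types, and the optional hooks (params per weight, save_load, init_driver, end_pass) are chained before build(). A self-contained toy version of that builder pattern, with entirely hypothetical types that only mimic the shape of the API above:

#include <iostream>
#include <memory>
#include <string>
#include <utility>

// Toy fluent builder -- not VW's learner API. Hooks are chained on the builder
// and the learner is only materialized by build(), so callers never see a
// half-configured object.
struct toy_learner
{
  std::string name;
  size_t params_per_weight = 1;
  void (*end_pass_hook)() = nullptr;
};

struct toy_builder
{
  explicit toy_builder(std::string name) { l.name = std::move(name); }

  toy_builder& set_params_per_weight(size_t n)
  {
    l.params_per_weight = n;
    return *this;
  }

  toy_builder& set_end_pass(void (*hook)())
  {
    l.end_pass_hook = hook;
    return *this;
  }

  std::unique_ptr<toy_learner> build() { return std::make_unique<toy_learner>(std::move(l)); }

  toy_learner l;
};

void report_end_pass() { std::cout << "end of pass\n"; }

int main()
{
  auto learner = toy_builder("bfgs")
                     .set_params_per_weight(4)
                     .set_end_pass(report_end_pass)
                     .build();
  std::cout << learner->name << ", stride " << learner->params_per_weight << "\n";
  if (learner->end_pass_hook != nullptr) { learner->end_pass_hook(); }
  return 0;
}

The design point mirrored here is that, unlike the old init_learner flow where set_save_load and friends mutated an already-constructed learner, the builder collects every option first and exposes the learner only once it is complete.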
8 changes: 4 additions & 4 deletions vowpalwabbit/learner.h
@@ -764,25 +764,25 @@ struct common_learner_builder

   FluentBuilderT& set_finish(void (*fn_ptr)(DataT&))
   {
-    _learner->finisher_fd = func_data(_learner->learn_fd.data, _learner->learn_fd.base, (finish_fptr_type)(fn_ptr));
+    _learner->finisher_fd = tuple_dbf(_learner->learn_fd.data, _learner->learn_fd.base, (finish_fptr_type)(fn_ptr));
Collaborator: nit: tuple_dbf -> tuple_data_base_fn

Member Author: This was a bug I introduced in the original PR and was not caught because the template was not instantiated. The rename would be good but is out of scope for this change.

     return *static_cast<FluentBuilderT*>(this);
   }

   FluentBuilderT& set_end_pass(void (*fn_ptr)(DataT&))
   {
-    _learner->end_pass_fd = func_data(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr);
+    _learner->end_pass_fd = tuple_dbf(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr);
     return *static_cast<FluentBuilderT*>(this);
   }

   FluentBuilderT& set_end_examples(void (*fn_ptr)(DataT&))
   {
-    _learner->end_examples_fd = func_data(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr);
+    _learner->end_examples_fd = tuple_dbf(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr);
     return *static_cast<FluentBuilderT*>(this);
   }

   FluentBuilderT& set_init_driver(void (*fn_ptr)(DataT&))
   {
-    _learner->init_fd = func_data(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr);
+    _learner->init_fd = tuple_dbf(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr);
     return *static_cast<FluentBuilderT*>(this);
   }
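The "was not caught because the template was not instantiated" remark in the review thread above refers to standard C++ template behavior: the body of a member function of a class template is only fully checked when that member is actually instantiated, so a broken call can sit in the source without the compiler ever looking at it. A small standalone illustration with hypothetical types, simplified from the func_data/tuple_dbf situation above:

template <typename DataT>
struct setter
{
  DataT data{};

  void set_finish()
  {
    // Bug: DataT (int, below) has no member named `finish`. Lookup into a
    // dependent type is deferred, so this only fails to compile when
    // set_finish() is instantiated.
    data.finish();
  }
};

int main()
{
  setter<int> s;      // compiles: the class is instantiated, but set_finish() is not
  (void)s;
  // s.set_finish();  // uncommenting this forces instantiation and the build fails
  return 0;
}

That is presumably how the stray func_data call survived the earlier PR: nothing instantiated these setters until a reduction, such as this bfgs migration, actually used them.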

Expand Down