Migrate bfgs to new style
jackgerrits committed Apr 7, 2021
1 parent 488ffa8 commit 2a6f58a
Showing 2 changed files with 51 additions and 43 deletions.
86 changes: 47 additions & 39 deletions vowpalwabbit/bfgs.cc
@@ -61,47 +61,49 @@ constexpr float max_precond_ratio = 10000.f;
 
 struct bfgs
 {
-  vw* all;  // prediction, regressor
-  int m;
-  float rel_threshold;  // termination threshold
+  vw* all = nullptr;  // prediction, regressor
+  int m = 0;
+  float rel_threshold = 0.f;  // termination threshold
 
-  double wolfe1_bound;
+  double wolfe1_bound = 0.0;
 
-  size_t final_pass;
+  size_t final_pass = 0;
 
   std::chrono::time_point<std::chrono::system_clock> t_start_global;
   std::chrono::time_point<std::chrono::system_clock> t_end_global;
-  double net_time;
+  double net_time = 0.0;
 
   v_array<float> predictions;
-  size_t example_number;
-  size_t current_pass;
-  size_t no_win_counter;
-  size_t early_stop_thres;
+  size_t example_number = 0;
+  size_t current_pass = 0;
+  size_t no_win_counter = 0;
+  size_t early_stop_thres = 0;
 
   // default transition behavior
-  bool first_hessian_on;
-  bool backstep_on;
+  bool first_hessian_on = false;
+  bool backstep_on = false;
 
   // set by initializer
-  int mem_stride;
-  bool output_regularizer;
-  float* mem;
-  double* rho;
-  double* alpha;
+  int mem_stride = 0;
+  bool output_regularizer = false;
+  float* mem = nullptr;
+  double* rho = nullptr;
+  double* alpha = nullptr;
 
-  weight* regularizers;
+  weight* regularizers = nullptr;
   // the below needs to be included when resetting, in addition to preconditioner and derivative
-  int lastj, origin;
-  double loss_sum, previous_loss_sum;
-  float step_size;
-  double importance_weight_sum;
-  double curvature;
+  int lastj = 0;
+  int origin = 0;
+  double loss_sum = 0.0;
+  double previous_loss_sum = 0.0;
+  float step_size = 0.f;
+  double importance_weight_sum = 0.0;
+  double curvature = 0.0;
 
   // first pass specification
-  bool first_pass;
-  bool gradient_pass;
-  bool preconditioner_pass;
+  bool first_pass = false;
+  bool gradient_pass = false;
+  bool preconditioner_pass = false;
 
   ~bfgs()
   {
@@ -1053,7 +1055,7 @@ void init_driver(bfgs& b) { b.backstep_on = true; }
 
 base_learner* bfgs_setup(options_i& options, vw& all)
 {
-  auto b = scoped_calloc_or_throw<bfgs>();
+  auto b = VW::make_unique<bfgs>();
   bool conjugate_gradient = false;
   bool bfgs_option = false;
   option_group_definition bfgs_outer_options("Conjugate Gradient options");
@@ -1109,20 +1111,26 @@ base_learner* bfgs_setup(options_i& options, vw& all)
   all.weights.stride_shift(2);
 
   void (*learn_ptr)(bfgs&, base_learner&, example&) = nullptr;
-  if (all.audit)
+  void (*predict_ptr)(bfgs&, base_learner&, example&) = nullptr;
+  std::string learner_name;
+  if (all.audit || all.hash_inv)
   {
     learn_ptr = learn<true>;
+    predict_ptr = predict<true>;
+    learner_name = all.get_setupfn_name(bfgs_setup) + "-audit";
   }
   else
   {
     learn_ptr = learn<false>;
+    predict_ptr = predict<false>;
+    learner_name = all.get_setupfn_name(bfgs_setup);
   }
 
-  learner<bfgs, example>* l;
-  if (all.audit || all.hash_inv)
-    l = &init_learner(b, learn_ptr, predict<true>, all.weights.stride(), all.get_setupfn_name(bfgs_setup) + "-audit");
-  else
-    l = &init_learner(b, learn_ptr, predict<false>, all.weights.stride(), all.get_setupfn_name(bfgs_setup));
-
-  l->set_save_load(save_load);
-  l->set_init_driver(init_driver);
-  l->set_end_pass(end_pass);
-
-  return make_base(*l);
+  return make_base(*make_base_learner(
+      std::move(b), learn_ptr, predict_ptr, learner_name, prediction_type_t::scalar, label_type_t::simple)
+                        .set_params_per_weight(all.weights.stride())
+                        .set_save_load(save_load)
+                        .set_init_driver(init_driver)
+                        .set_end_pass(end_pass)
+                        .build());
 }
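Two idioms drive the bfgs.cc half of this change. First, the old scoped_calloc_or_throw allocator zeroed the struct's raw memory, which is why members could previously be left without initializers; once every member carries a default member initializer, value construction through VW::make_unique yields the same zeroed state and the calloc helper is no longer needed. Second, the positional init_learner call followed by imperative set_* mutations on the resulting learner is replaced by a fluent builder configured through chained setters and finalized with build(). Below is a self-contained sketch of both idioms; the names (widget, widget_builder) are hypothetical stand-ins, not VW's real types, and the builder is deliberately tiny rather than a model of VW's actual learner builder.

#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <utility>

// (1) Default member initializers: every field has a defined value as soon
// as the object exists, so std::make_unique (value construction) can stand
// in for an allocator that zeroed raw memory with calloc.
struct widget
{
  int m = 0;
  float rel_threshold = 0.f;
  float* mem = nullptr;  // pointers start null instead of relying on calloc
};

// (2) A minimal fluent builder: required state enters the constructor once,
// optional hooks are chained setters, and build() finalizes the object.
class widget_builder
{
public:
  widget_builder(std::unique_ptr<widget> data, std::string name)
      : _data(std::move(data)), _name(std::move(name))
  {
  }

  widget_builder& set_params_per_weight(std::size_t n)
  {
    _params_per_weight = n;
    return *this;  // returning *this is what enables the chained-call syntax
  }

  widget_builder& set_end_pass(void (*hook)(widget&))
  {
    _end_pass = hook;
    return *this;
  }

  std::unique_ptr<widget> build()
  {
    // Purely for demonstration, run the hook once so its effect is visible.
    if (_end_pass != nullptr) { _end_pass(*_data); }
    return std::move(_data);
  }

private:
  std::unique_ptr<widget> _data;
  std::string _name;
  std::size_t _params_per_weight = 1;
  void (*_end_pass)(widget&) = nullptr;
};

int main()
{
  auto w = widget_builder(std::make_unique<widget>(), "bfgs-sketch")
               .set_params_per_weight(4)
               .set_end_pass([](widget& wd) { wd.m += 1; })
               .build();
  std::cout << w->m << '\n';  // prints 1: build() invoked the hook once
  return 0;
}

In the real hunk above, make_base_learner supplies the constructor-style arguments (data, learn and predict pointers, name, prediction and label types), while set_params_per_weight, set_save_load, set_init_driver, and set_end_pass are the chained hooks.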
8 changes: 4 additions & 4 deletions vowpalwabbit/learner.h
@@ -764,25 +764,25 @@ struct common_learner_builder
 
   FluentBuilderT& set_finish(void (*fn_ptr)(DataT&))
   {
-    _learner->finisher_fd = func_data(_learner->learn_fd.data, _learner->learn_fd.base, (finish_fptr_type)(fn_ptr));
+    _learner->finisher_fd = tuple_dbf(_learner->learn_fd.data, _learner->learn_fd.base, (finish_fptr_type)(fn_ptr));
     return *static_cast<FluentBuilderT*>(this);
   }
 
   FluentBuilderT& set_end_pass(void (*fn_ptr)(DataT&))
   {
-    _learner->end_pass_fd = func_data(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr);
+    _learner->end_pass_fd = tuple_dbf(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr);
     return *static_cast<FluentBuilderT*>(this);
   }
 
   FluentBuilderT& set_end_examples(void (*fn_ptr)(DataT&))
   {
-    _learner->end_examples_fd = func_data(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr);
+    _learner->end_examples_fd = tuple_dbf(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr);
     return *static_cast<FluentBuilderT*>(this);
   }
 
   FluentBuilderT& set_init_driver(void (*fn_ptr)(DataT&))
   {
-    _learner->init_fd = func_data(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr);
+    _learner->init_fd = tuple_dbf(_learner->learn_fd.data, _learner->learn_fd.base, (func_data::fn)fn_ptr);
     return *static_cast<FluentBuilderT*>(this);
   }

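A note on why the learner.h hunk reads as a fix rather than a rename: func_data is a plain aggregate holding a (data, base, function) triple, and tuple_dbf is the existing factory helper that packs one. Calling func_data(...) with three arguments, as the new builder code did, would presumably require a matching constructor that the aggregate does not declare, so the setters are pointed back at the factory. The sketch below reconstructs the helper's rough shape from this era of the codebase; treat the exact signature as an assumption, not a quotation. Only the func_data::fn member type is directly evidenced by the casts in the hunk above.

// Hedged reconstruction, not a verbatim copy of learner.h.
struct base_learner;  // opaque forward declaration for the sketch

struct func_data
{
  using fn = void (*)(void* data);  // the cast target seen in the hunk above
  void* data = nullptr;
  base_learner* base = nullptr;
  fn func = nullptr;
};

inline func_data tuple_dbf(void* data, base_learner* base, void (*func)(void*))
{
  func_data out;
  out.data = data;
  out.base = base;
  out.func = func;
  return out;  // "dbf": the (data, base, func) triple packed into one struct
}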
