Skip to content

Commit

Permalink
refactor: Migrate ccb to new make learner structure (#2923)
Browse files Browse the repository at this point in the history
* refactor: Migrate ccb to new make learner structure

* Fixes

* Update conditional_contextual_bandit.cc

* Update conditional_contextual_bandit.cc

* Update learner.h

* formatting
  • Loading branch information
jackgerrits authored Apr 8, 2021
1 parent 367a563 commit 9c24a77
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
29 changes: 16 additions & 13 deletions vowpalwabbit/conditional_contextual_bandit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,26 @@ void return_v_array(v_array<T>&& array, VW::v_array_pool<T>& pool)

struct ccb
{
vw* all;
example* shared;
vw* all = nullptr;
example* shared = nullptr;
std::vector<example*> actions, slots;
std::vector<uint32_t> origin_index;
CB::cb_class cb_label;
std::vector<bool> exclude_list, include_list;
namespace_interactions generated_interactions;
namespace_interactions* original_interactions;
std::vector<CCB::label> stored_labels;
size_t action_with_label;
size_t action_with_label = 0;

multi_ex cb_ex;

// All of these hashes are with a hasher seeded with the below namespace hash.
std::vector<uint64_t> slot_id_hashes;
uint64_t id_namespace_hash;
uint64_t id_namespace_hash = 0;
std::string id_namespace_str;

size_t base_learner_stride_shift;
bool all_slots_loss_report;
size_t base_learner_stride_shift = 0;
bool all_slots_loss_report = false;

VW::v_array_pool<CB::cb_class> cb_label_pool;
VW::v_array_pool<ACTION_SCORE::action_score> action_score_pool;
Expand Down Expand Up @@ -653,7 +653,7 @@ void save_load(ccb& sm, io_buf& io, bool read, bool text)

base_learner* ccb_explore_adf_setup(options_i& options, vw& all)
{
auto data = scoped_calloc_or_throw<ccb>();
auto data = VW::make_unique<ccb>();
bool ccb_explore_adf_option = false;
bool all_slots_loss_report = false;

Expand Down Expand Up @@ -698,12 +698,15 @@ base_learner* ccb_explore_adf_setup(options_i& options, vw& all)
data->id_namespace_str.append("_id");
data->id_namespace_hash = VW::hash_space(all, data->id_namespace_str);

learner<ccb, multi_ex>& l = init_learner(data, base, learn_or_predict<true>, learn_or_predict<false>, 1,
prediction_type_t::decision_probs, all.get_setupfn_name(ccb_explore_adf_setup), true);

l.set_finish_example(finish_multiline_example);
l.set_save_load(save_load);
return make_base(l);
auto* l = VW::LEARNER::make_reduction_learner(std::move(data), base, learn_or_predict<true>, learn_or_predict<false>,
all.get_setupfn_name(ccb_explore_adf_setup))
.set_learn_returns_prediction(true)
.set_prediction_type(prediction_type_t::decision_probs)
.set_label_type(label_type_t::ccb)
.set_finish_example(finish_multiline_example)
.set_save_load(save_load)
.build();
return make_base(*l);
}

bool ec_is_example_header(example const& ec) { return ec.l.conditional_contextual_bandit.type == example_type::shared; }
Expand Down
18 changes: 16 additions & 2 deletions vowpalwabbit/learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -817,8 +817,22 @@ struct reduction_learner_builder
set_params_per_weight(1);

this->_learner->pred_type = base->pred_type;
// TODO add label type as something learner knows about itself.
// this->_learner.label_type = label_type;
// TODO add label type as something learner knows about itself, this will enable more type checking and better
// description of the learner. this->_learner.label_type = label_type;
}

reduction_learner_builder<DataT, ExampleT, BaseLearnerT>& set_prediction_type(prediction_type_t pred_type)
{
this->_learner->pred_type = pred_type;
return *this;
}

reduction_learner_builder<DataT, ExampleT, BaseLearnerT>& set_label_type(label_type_t label_type)
{
// TODO add label type as something learner knows about itself, this will enable more type checking and better
// description of the learner. this->_learner.label_type = label_type;
std::ignore = label_type;
return *this;
}

reduction_learner_builder<DataT, ExampleT, BaseLearnerT>& set_params_per_weight(size_t params_per_weight)
Expand Down

0 comments on commit 9c24a77

Please sign in to comment.