diff --git a/vowpalwabbit/conditional_contextual_bandit.cc b/vowpalwabbit/conditional_contextual_bandit.cc index f4a4957e949..165e2057450 100644 --- a/vowpalwabbit/conditional_contextual_bandit.cc +++ b/vowpalwabbit/conditional_contextual_bandit.cc @@ -44,8 +44,8 @@ void return_v_array(v_array&& array, VW::v_array_pool& pool) struct ccb { - vw* all; - example* shared; + vw* all = nullptr; + example* shared = nullptr; std::vector actions, slots; std::vector origin_index; CB::cb_class cb_label; @@ -53,17 +53,17 @@ struct ccb namespace_interactions generated_interactions; namespace_interactions* original_interactions; std::vector stored_labels; - size_t action_with_label; + size_t action_with_label = 0; multi_ex cb_ex; // All of these hashes are with a hasher seeded with the below namespace hash. std::vector slot_id_hashes; - uint64_t id_namespace_hash; + uint64_t id_namespace_hash = 0; std::string id_namespace_str; - size_t base_learner_stride_shift; - bool all_slots_loss_report; + size_t base_learner_stride_shift = 0; + bool all_slots_loss_report = false; VW::v_array_pool cb_label_pool; VW::v_array_pool action_score_pool; @@ -653,7 +653,7 @@ void save_load(ccb& sm, io_buf& io, bool read, bool text) base_learner* ccb_explore_adf_setup(options_i& options, vw& all) { - auto data = scoped_calloc_or_throw(); + auto data = VW::make_unique(); bool ccb_explore_adf_option = false; bool all_slots_loss_report = false; @@ -698,12 +698,15 @@ base_learner* ccb_explore_adf_setup(options_i& options, vw& all) data->id_namespace_str.append("_id"); data->id_namespace_hash = VW::hash_space(all, data->id_namespace_str); - learner& l = init_learner(data, base, learn_or_predict, learn_or_predict, 1, - prediction_type_t::decision_probs, all.get_setupfn_name(ccb_explore_adf_setup), true); - - l.set_finish_example(finish_multiline_example); - l.set_save_load(save_load); - return make_base(l); + auto* l = VW::LEARNER::make_reduction_learner(std::move(data), base, learn_or_predict, learn_or_predict, + all.get_setupfn_name(ccb_explore_adf_setup)) + .set_learn_returns_prediction(true) + .set_prediction_type(prediction_type_t::decision_probs) + .set_label_type(label_type_t::ccb) + .set_finish_example(finish_multiline_example) + .set_save_load(save_load) + .build(); + return make_base(*l); } bool ec_is_example_header(example const& ec) { return ec.l.conditional_contextual_bandit.type == example_type::shared; } diff --git a/vowpalwabbit/learner.h b/vowpalwabbit/learner.h index 4b7d0f4e00e..a7d92692b68 100644 --- a/vowpalwabbit/learner.h +++ b/vowpalwabbit/learner.h @@ -817,8 +817,22 @@ struct reduction_learner_builder set_params_per_weight(1); this->_learner->pred_type = base->pred_type; - // TODO add label type as something learner knows about itself. - // this->_learner.label_type = label_type; + // TODO add label type as something learner knows about itself, this will enable more type checking and better + // description of the learner. this->_learner.label_type = label_type; + } + + reduction_learner_builder& set_prediction_type(prediction_type_t pred_type) + { + this->_learner->pred_type = pred_type; + return *this; + } + + reduction_learner_builder& set_label_type(label_type_t label_type) + { + // TODO add label type as something learner knows about itself, this will enable more type checking and better + // description of the learner. this->_learner.label_type = label_type; + std::ignore = label_type; + return *this; } reduction_learner_builder& set_params_per_weight(size_t params_per_weight)