Skip to content

Commit

Permalink
cats learn returns prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Aug 15, 2023
1 parent b77e9ab commit 9bfc2c1
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions vowpalwabbit/core/src/reductions/cats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ int cats::learn(example& ec, experimental::api_status* status = nullptr)
{
assert(!ec.test_only);
predict(ec, status);
auto pred = ec.pred.pdf_value;
VW_DBG(ec) << "cats::learn(), " << to_string(ec.l.cb_cont) << VW::debug::features_to_string(ec) << endl;
_base->learn(ec);
ec.pred.pdf_value = pred;
return VW::experimental::error_code::success;
}

Expand Down Expand Up @@ -199,6 +201,7 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::cats_setup(setup_base_i& s

auto l = make_reduction_learner(std::move(p_reduction), p_base, predict_or_learn<true>, predict_or_learn<false>,
stack_builder.get_setupfn_name(cats_setup))
.set_learn_returns_prediction(true)
.set_input_label_type(VW::label_type_t::CONTINUOUS)
.set_output_label_type(VW::label_type_t::CONTINUOUS)
.set_input_prediction_type(VW::prediction_type_t::ACTION_PDF_VALUE)
Expand Down

0 comments on commit 9bfc2c1

Please sign in to comment.