From 9bfc2c1bca9b97bdf91b0b72741a93cfebe090bf Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Tue, 15 Aug 2023 16:07:49 -0400 Subject: [PATCH] cats learn returns prediction --- vowpalwabbit/core/src/reductions/cats.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vowpalwabbit/core/src/reductions/cats.cc b/vowpalwabbit/core/src/reductions/cats.cc index 5c2ffd91da8..d6abbfa9a4e 100644 --- a/vowpalwabbit/core/src/reductions/cats.cc +++ b/vowpalwabbit/core/src/reductions/cats.cc @@ -51,8 +51,10 @@ int cats::learn(example& ec, experimental::api_status* status = nullptr) { assert(!ec.test_only); predict(ec, status); + auto pred = ec.pred.pdf_value; VW_DBG(ec) << "cats::learn(), " << to_string(ec.l.cb_cont) << VW::debug::features_to_string(ec) << endl; _base->learn(ec); + ec.pred.pdf_value = pred; return VW::experimental::error_code::success; } @@ -199,6 +201,7 @@ std::shared_ptr VW::reductions::cats_setup(setup_base_i& s auto l = make_reduction_learner(std::move(p_reduction), p_base, predict_or_learn, predict_or_learn, stack_builder.get_setupfn_name(cats_setup)) + .set_learn_returns_prediction(true) .set_input_label_type(VW::label_type_t::CONTINUOUS) .set_output_label_type(VW::label_type_t::CONTINUOUS) .set_input_prediction_type(VW::prediction_type_t::ACTION_PDF_VALUE)