From 8d4d435854bd035a1a9f2b5cde88459c3a1d5f9b Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Tue, 15 Aug 2023 15:37:51 -0400 Subject: [PATCH] copy predictions for compat --- vowpalwabbit/core/include/vw/core/example.h | 1 + vowpalwabbit/core/src/example.cc | 47 +++++++++++++++++++++ vowpalwabbit/core/src/global_data.cc | 4 ++ 3 files changed, 52 insertions(+) diff --git a/vowpalwabbit/core/include/vw/core/example.h b/vowpalwabbit/core/include/vw/core/example.h index cd2185a8b08..df8e9f925e1 100644 --- a/vowpalwabbit/core/include/vw/core/example.h +++ b/vowpalwabbit/core/include/vw/core/example.h @@ -84,6 +84,7 @@ class polyprediction std::string to_string(const v_array& scalars, int decimal_precision = details::DEFAULT_FLOAT_PRECISION); void swap_prediction(polyprediction& a, polyprediction& b, prediction_type_t prediction_type); +void copy_prediction(const VW::polyprediction& src, VW::polyprediction& dest, VW::prediction_type_t prediction_type); class example : public example_predict // core example datatype. { diff --git a/vowpalwabbit/core/src/example.cc b/vowpalwabbit/core/src/example.cc index cb2c88be038..5afff571620 100644 --- a/vowpalwabbit/core/src/example.cc +++ b/vowpalwabbit/core/src/example.cc @@ -62,6 +62,53 @@ void VW::swap_prediction(VW::polyprediction& a, VW::polyprediction& b, VW::predi } } +void VW::copy_prediction(const VW::polyprediction& src, VW::polyprediction& dest, VW::prediction_type_t prediction_type) +{ + switch (prediction_type) + { + case VW::prediction_type_t::SCALAR: + dest.scalar = src.scalar; + break; + case VW::prediction_type_t::SCALARS: + dest.scalars = src.scalars; + break; + case VW::prediction_type_t::ACTION_SCORES: + dest.a_s = src.a_s; + break; + case VW::prediction_type_t::PDF: + dest.pdf = src.pdf; + break; + case VW::prediction_type_t::ACTION_PROBS: + dest.a_s = src.a_s; + break; + case VW::prediction_type_t::MULTICLASS: + dest.multiclass = src.multiclass; + break; + case VW::prediction_type_t::MULTILABELS: + dest.multilabels = src.multilabels; + break; + case VW::prediction_type_t::PROB: + dest.prob = src.prob; + break; + case VW::prediction_type_t::MULTICLASS_PROBS: + dest.scalars = src.scalars; + break; + case VW::prediction_type_t::DECISION_PROBS: + dest.decision_scores = src.decision_scores; + break; + case VW::prediction_type_t::ACTION_PDF_VALUE: + dest.pdf_value = src.pdf_value; + break; + case VW::prediction_type_t::ACTIVE_MULTICLASS: + dest.active_multiclass = src.active_multiclass; + break; + case VW::prediction_type_t::NOPRED: + // Noop + break; + } +} + + float calculate_total_sum_features_squared(bool permutations, VW::example& ec) { float sum_features_squared = 0.f; diff --git a/vowpalwabbit/core/src/global_data.cc b/vowpalwabbit/core/src/global_data.cc index ff3833d0682..4a759a7588a 100644 --- a/vowpalwabbit/core/src/global_data.cc +++ b/vowpalwabbit/core/src/global_data.cc @@ -94,6 +94,8 @@ void workspace::learn(example& ec) VW::LEARNER::require_singleline(l)->predict(ec); VW::polyprediction saved_prediction; VW::swap_prediction(ec.pred, saved_prediction, l->get_output_prediction_type()); + // Some reductions break without this line, TODO fix and remove + VW::copy_prediction(saved_prediction, ec.pred, l->get_output_prediction_type()); VW::LEARNER::require_singleline(l)->learn(ec); VW::swap_prediction(saved_prediction, ec.pred, l->get_output_prediction_type()); } @@ -113,6 +115,8 @@ void workspace::learn(multi_ex& ec) VW::LEARNER::require_multiline(l)->predict(ec); VW::polyprediction saved_prediction; VW::swap_prediction(ec[0]->pred, saved_prediction, l->get_output_prediction_type()); + // Some reductions break without this line, TODO fix and remove + VW::copy_prediction(saved_prediction, ec[0]->pred, l->get_output_prediction_type()); VW::LEARNER::require_multiline(l)->learn(ec); VW::swap_prediction(saved_prediction, ec[0]->pred, l->get_output_prediction_type()); }