Skip to content

Commit

Permalink
copy predictions for compat
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Aug 15, 2023
1 parent cb27769 commit 8d4d435
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
1 change: 1 addition & 0 deletions vowpalwabbit/core/include/vw/core/example.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class polyprediction

std::string to_string(const v_array<float>& scalars, int decimal_precision = details::DEFAULT_FLOAT_PRECISION);
void swap_prediction(polyprediction& a, polyprediction& b, prediction_type_t prediction_type);
void copy_prediction(const VW::polyprediction& src, VW::polyprediction& dest, VW::prediction_type_t prediction_type);

class example : public example_predict // core example datatype.
{
Expand Down
47 changes: 47 additions & 0 deletions vowpalwabbit/core/src/example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,53 @@ void VW::swap_prediction(VW::polyprediction& a, VW::polyprediction& b, VW::predi
}
}

void VW::copy_prediction(const VW::polyprediction& src, VW::polyprediction& dest, VW::prediction_type_t prediction_type)
{
switch (prediction_type)
{
case VW::prediction_type_t::SCALAR:
dest.scalar = src.scalar;
break;
case VW::prediction_type_t::SCALARS:
dest.scalars = src.scalars;
break;
case VW::prediction_type_t::ACTION_SCORES:
dest.a_s = src.a_s;
break;
case VW::prediction_type_t::PDF:
dest.pdf = src.pdf;
break;
case VW::prediction_type_t::ACTION_PROBS:
dest.a_s = src.a_s;
break;
case VW::prediction_type_t::MULTICLASS:
dest.multiclass = src.multiclass;
break;
case VW::prediction_type_t::MULTILABELS:
dest.multilabels = src.multilabels;
break;
case VW::prediction_type_t::PROB:
dest.prob = src.prob;
break;
case VW::prediction_type_t::MULTICLASS_PROBS:
dest.scalars = src.scalars;
break;
case VW::prediction_type_t::DECISION_PROBS:
dest.decision_scores = src.decision_scores;
break;
case VW::prediction_type_t::ACTION_PDF_VALUE:
dest.pdf_value = src.pdf_value;
break;
case VW::prediction_type_t::ACTIVE_MULTICLASS:
dest.active_multiclass = src.active_multiclass;
break;
case VW::prediction_type_t::NOPRED:
// Noop
break;
}
}


float calculate_total_sum_features_squared(bool permutations, VW::example& ec)
{
float sum_features_squared = 0.f;
Expand Down
4 changes: 4 additions & 0 deletions vowpalwabbit/core/src/global_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ void workspace::learn(example& ec)
VW::LEARNER::require_singleline(l)->predict(ec);
VW::polyprediction saved_prediction;
VW::swap_prediction(ec.pred, saved_prediction, l->get_output_prediction_type());
// Some reductions break without this line, TODO fix and remove
VW::copy_prediction(saved_prediction, ec.pred, l->get_output_prediction_type());
VW::LEARNER::require_singleline(l)->learn(ec);
VW::swap_prediction(saved_prediction, ec.pred, l->get_output_prediction_type());
}
Expand All @@ -113,6 +115,8 @@ void workspace::learn(multi_ex& ec)
VW::LEARNER::require_multiline(l)->predict(ec);
VW::polyprediction saved_prediction;
VW::swap_prediction(ec[0]->pred, saved_prediction, l->get_output_prediction_type());
// Some reductions break without this line, TODO fix and remove
VW::copy_prediction(saved_prediction, ec[0]->pred, l->get_output_prediction_type());
VW::LEARNER::require_multiline(l)->learn(ec);
VW::swap_prediction(saved_prediction, ec[0]->pred, l->get_output_prediction_type());
}
Expand Down

0 comments on commit 8d4d435

Please sign in to comment.