diff --git a/vowpalwabbit/core/include/vw/core/reductions/active.h b/vowpalwabbit/core/include/vw/core/reductions/active.h index dfca91718c3..0bff31db2d0 100644 --- a/vowpalwabbit/core/include/vw/core/reductions/active.h +++ b/vowpalwabbit/core/include/vw/core/reductions/active.h @@ -16,10 +16,18 @@ namespace reductions class active { public: - active(float active_c0, VW::workspace* all) : active_c0(active_c0), _all(all) {} + active(float active_c0, std::shared_ptr shared_data, std::shared_ptr random_state, + VW::version_struct model_version) + : active_c0(active_c0) + , _shared_data(shared_data) + , _random_state(std::move(random_state)) + , _model_version{std::move(model_version)} + { + } float active_c0; - VW::workspace* _all = nullptr; + std::shared_ptr _shared_data; // statistics, loss + std::shared_ptr _random_state; float _min_seen_label = 0.f; float _max_seen_label = 1.f; diff --git a/vowpalwabbit/core/include/vw/core/vw_versions.h b/vowpalwabbit/core/include/vw/core/vw_versions.h index 05b9a2dc9aa..77b2429970e 100644 --- a/vowpalwabbit/core/include/vw/core/vw_versions.h +++ b/vowpalwabbit/core/include/vw/core/vw_versions.h @@ -59,6 +59,9 @@ constexpr VW::version_struct VERSION_PASS_UINT64{8, 3, 3}; /// Added serialized seen min and max labels in the --active reduction constexpr VW::version_struct VERSION_FILE_WITH_ACTIVE_SEEN_LABELS{9, 0, 0}; +/// Active seen labels was accidentally reverted out in 9.4.0 +constexpr VW::version_struct VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED{9, 4, 0}; +constexpr VW::version_struct VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED{9, 10, 0}; /// Moved option values from command line to model data constexpr VW::version_struct VERSION_FILE_WITH_L1_AND_L2_STATE_IN_MODEL_DATA{9, 0, 0}; diff --git a/vowpalwabbit/core/src/reductions/active.cc b/vowpalwabbit/core/src/reductions/active.cc index 1f55a34ef87..ea8b66c40e4 100644 --- a/vowpalwabbit/core/src/reductions/active.cc +++ b/vowpalwabbit/core/src/reductions/active.cc @@ -50,13 +50,13 @@ float query_decision(const active& a, float ec_revert_weight, float k) if (k <= 1.f) { bias = 1.f; } else { - const auto weighted_queries = static_cast(a._all->sd->weighted_labeled_examples); - const float avg_loss = (static_cast(a._all->sd->sum_loss) / k) + + const auto weighted_queries = static_cast(a._shared_data->weighted_labeled_examples); + const float avg_loss = (static_cast(a._shared_data->sum_loss) / k) + std::sqrt((1.f + 0.5f * std::log(k)) / (weighted_queries + 0.0001f)); bias = get_active_coin_bias(k, avg_loss, ec_revert_weight / k, a.active_c0); } - return (a._all->get_random_state()->get_and_update_random() < bias) ? 1.f / bias : -1.f; + return (a._random_state->get_and_update_random() < bias) ? 1.f / bias : -1.f; } template @@ -66,7 +66,7 @@ void predict_or_learn_simulation(active& a, learner& base, VW::example& ec) if (is_learn) { - const auto k = static_cast(a._all->sd->t); + const auto k = static_cast(a._shared_data->t); constexpr float threshold = 0.f; ec.confidence = fabsf(ec.pred.scalar - threshold) / base.sensitivity(ec); @@ -74,7 +74,7 @@ void predict_or_learn_simulation(active& a, learner& base, VW::example& ec) if (importance > 0.f) { - a._all->sd->queries += 1; + a._shared_data->queries += 1; ec.weight *= importance; base.learn(ec); } @@ -94,7 +94,7 @@ void predict_or_learn_active(active& a, learner& base, VW::example& ec) if (ec.l.simple.label == FLT_MAX) { - const float threshold = (a._all->sd->max_label + a._all->sd->min_label) * 0.5f; + const float threshold = (a._shared_data->max_label + a._shared_data->min_label) * 0.5f; // We want to understand the change in prediction if the label were to be // the opposite of what was predicted. 0 and 1 are used for the expected min // and max labels to be coming in from the active interactor. @@ -129,8 +129,14 @@ void active_print_result( void save_load(active& a, VW::io_buf& io, bool read, bool text) { + using namespace VW::version_definitions; if (io.num_files() == 0) { return; } - if (a._model_version >= VW::version_definitions::VERSION_FILE_WITH_ACTIVE_SEEN_LABELS) + // This code is valid if version is within + // [VERSION_FILE_WITH_ACTIVE_SEEN_LABELS, VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) + // or >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED + if ((a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS && + a._model_version < VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) || + a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED) { if (read) { @@ -195,7 +201,7 @@ std::shared_ptr VW::reductions::active_setup(VW::setup_bas if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; } if (options.was_supplied("lda")) { THROW("lda cannot be combined with active learning") } - auto data = VW::make_unique(active_c0, &all); + auto data = VW::make_unique(active_c0, all.sd, all.get_random_state(), all.runtime_state.model_file_ver); auto base = require_singleline(stack_builder.setup_base_learner()); using learn_pred_func_t = void (*)(active&, VW::LEARNER::learner&, VW::example&);