Skip to content

Commit

Permalink
Merge branch 'master' into gd_get
Browse files Browse the repository at this point in the history
  • Loading branch information
bassmang authored Oct 26, 2023
2 parents 24a0b63 + 4734d3f commit 384fb17
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
12 changes: 10 additions & 2 deletions vowpalwabbit/core/include/vw/core/reductions/active.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,18 @@ namespace reductions
class active
{
public:
active(float active_c0, VW::workspace* all) : active_c0(active_c0), _all(all) {}
active(float active_c0, std::shared_ptr<shared_data> shared_data, std::shared_ptr<rand_state> random_state,
VW::version_struct model_version)
: active_c0(active_c0)
, _shared_data(shared_data)
, _random_state(std::move(random_state))
, _model_version{std::move(model_version)}
{
}

float active_c0;
VW::workspace* _all = nullptr;
std::shared_ptr<shared_data> _shared_data; // statistics, loss
std::shared_ptr<rand_state> _random_state;

float _min_seen_label = 0.f;
float _max_seen_label = 1.f;
Expand Down
3 changes: 3 additions & 0 deletions vowpalwabbit/core/include/vw/core/vw_versions.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ constexpr VW::version_struct VERSION_PASS_UINT64{8, 3, 3};

/// Added serialized seen min and max labels in the --active reduction
constexpr VW::version_struct VERSION_FILE_WITH_ACTIVE_SEEN_LABELS{9, 0, 0};
/// Active seen labels was accidentally reverted out in 9.4.0
constexpr VW::version_struct VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED{9, 4, 0};
constexpr VW::version_struct VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED{9, 10, 0};

/// Moved option values from command line to model data
constexpr VW::version_struct VERSION_FILE_WITH_L1_AND_L2_STATE_IN_MODEL_DATA{9, 0, 0};
Expand Down
22 changes: 14 additions & 8 deletions vowpalwabbit/core/src/reductions/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ float query_decision(const active& a, float ec_revert_weight, float k)
if (k <= 1.f) { bias = 1.f; }
else
{
const auto weighted_queries = static_cast<float>(a._all->sd->weighted_labeled_examples);
const float avg_loss = (static_cast<float>(a._all->sd->sum_loss) / k) +
const auto weighted_queries = static_cast<float>(a._shared_data->weighted_labeled_examples);
const float avg_loss = (static_cast<float>(a._shared_data->sum_loss) / k) +
std::sqrt((1.f + 0.5f * std::log(k)) / (weighted_queries + 0.0001f));
bias = get_active_coin_bias(k, avg_loss, ec_revert_weight / k, a.active_c0);
}

return (a._all->get_random_state()->get_and_update_random() < bias) ? 1.f / bias : -1.f;
return (a._random_state->get_and_update_random() < bias) ? 1.f / bias : -1.f;
}

template <bool is_learn>
Expand All @@ -66,15 +66,15 @@ void predict_or_learn_simulation(active& a, learner& base, VW::example& ec)

if (is_learn)
{
const auto k = static_cast<float>(a._all->sd->t);
const auto k = static_cast<float>(a._shared_data->t);
constexpr float threshold = 0.f;

ec.confidence = fabsf(ec.pred.scalar - threshold) / base.sensitivity(ec);
const float importance = query_decision(a, ec.confidence, k);

if (importance > 0.f)
{
a._all->sd->queries += 1;
a._shared_data->queries += 1;
ec.weight *= importance;
base.learn(ec);
}
Expand All @@ -94,7 +94,7 @@ void predict_or_learn_active(active& a, learner& base, VW::example& ec)

if (ec.l.simple.label == FLT_MAX)
{
const float threshold = (a._all->sd->max_label + a._all->sd->min_label) * 0.5f;
const float threshold = (a._shared_data->max_label + a._shared_data->min_label) * 0.5f;
// We want to understand the change in prediction if the label were to be
// the opposite of what was predicted. 0 and 1 are used for the expected min
// and max labels to be coming in from the active interactor.
Expand Down Expand Up @@ -129,8 +129,14 @@ void active_print_result(

void save_load(active& a, VW::io_buf& io, bool read, bool text)
{
using namespace VW::version_definitions;
if (io.num_files() == 0) { return; }
if (a._model_version >= VW::version_definitions::VERSION_FILE_WITH_ACTIVE_SEEN_LABELS)
// This code is valid if version is within
// [VERSION_FILE_WITH_ACTIVE_SEEN_LABELS, VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED)
// or >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED
if ((a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS &&
a._model_version < VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_REVERTED) ||
a._model_version >= VERSION_FILE_WITH_ACTIVE_SEEN_LABELS_FIXED)
{
if (read)
{
Expand Down Expand Up @@ -195,7 +201,7 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::active_setup(VW::setup_bas
if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; }

if (options.was_supplied("lda")) { THROW("lda cannot be combined with active learning") }
auto data = VW::make_unique<active>(active_c0, &all);
auto data = VW::make_unique<active>(active_c0, all.sd, all.get_random_state(), all.runtime_state.model_file_ver);
auto base = require_singleline(stack_builder.setup_base_learner());

using learn_pred_func_t = void (*)(active&, VW::LEARNER::learner&, VW::example&);
Expand Down

0 comments on commit 384fb17

Please sign in to comment.