Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

experiment: try rank builder #530

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 66 additions & 15 deletions benchmarks/benchmark_cb_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,26 @@ const auto JSON_CFG = R"(
}
)";

u::configuration get_config(bool dedup, bool compression)
{
u::configuration config;
cfg::create_from_json(JSON_CFG, config);
config.set(r::name::PROTOCOL_VERSION, "2");
config.set(r::name::EH_TEST, "true");
config.set(r::name::MODEL_SRC, r::value::NO_MODEL_DATA);
config.set(r::name::OBSERVATION_SENDER_IMPLEMENTATION, r::value::OBSERVATION_FILE_SENDER);
config.set(r::name::INTERACTION_SENDER_IMPLEMENTATION, r::value::INTERACTION_FILE_SENDER);
config.set(r::name::INTERACTION_FILE_NAME, "/dev/null");
config.set(r::name::OBSERVATION_FILE_NAME, "/dev/null");
config.set(r::name::MODEL_BACKGROUND_REFRESH, "false");
// config.set(r::name::MODEL_IMPLEMENTATION, r::value::PASSTHROUGH_PDF_MODEL);
config.set(r::name::VW_POOL_INIT_SIZE, "1");
config.set(r::name::INTERACTION_USE_COMPRESSION, compression ? "true" : "false");
config.set(r::name::INTERACTION_USE_DEDUP, dedup ? "true" : "false");
config.set("queue.mode", "BLOCK");
return config;
}

template <class... ExtraArgs>
static void bench_cb(benchmark::State& state, ExtraArgs&&... extra_args)
{
Expand All @@ -44,21 +64,7 @@ static void bench_cb(benchmark::State& state, ExtraArgs&&... extra_args)
std::vector<std::string> examples;
std::generate_n(std::back_inserter(examples), count, [&cb_gen] { return cb_gen.gen_example(); });

u::configuration config;
cfg::create_from_json(JSON_CFG, config);
config.set(r::name::PROTOCOL_VERSION, "2");
config.set(r::name::EH_TEST, "true");
config.set(r::name::MODEL_SRC, r::value::NO_MODEL_DATA);
config.set(r::name::OBSERVATION_SENDER_IMPLEMENTATION, r::value::OBSERVATION_FILE_SENDER);
config.set(r::name::INTERACTION_SENDER_IMPLEMENTATION, r::value::INTERACTION_FILE_SENDER);
config.set(r::name::INTERACTION_FILE_NAME, "/dev/null");
config.set(r::name::OBSERVATION_FILE_NAME, "/dev/null");
config.set(r::name::MODEL_BACKGROUND_REFRESH, "false");
// config.set(r::name::MODEL_IMPLEMENTATION, r::value::PASSTHROUGH_PDF_MODEL);
config.set(r::name::VW_POOL_INIT_SIZE, "1");
config.set(r::name::INTERACTION_USE_COMPRESSION, compression ? "true" : "false");
config.set(r::name::INTERACTION_USE_DEDUP, dedup ? "true" : "false");
config.set("queue.mode", "BLOCK");
auto config = get_config(dedup, compression);

r::api_status status;
r::live_model model(config);
Expand All @@ -82,6 +88,49 @@ static void bench_cb(benchmark::State& state, ExtraArgs&&... extra_args)
}
}

template <class... ExtraArgs>
static void bench_cb_w_builder(benchmark::State& state, ExtraArgs&&... extra_args)
{
int res[sizeof...(extra_args)] = {extra_args...};
auto shared_features = res[0];
auto action_features = res[1];
auto actions_per_decision = res[2];
auto total_actions = res[3];
auto count = res[4];
bool compression = res[5];
bool dedup = res[6];

cb_decision_gen cb_gen(shared_features, action_features, actions_per_decision, total_actions, 0, false);

std::vector<std::string> examples;
std::generate_n(std::back_inserter(examples), count, [&cb_gen] { return cb_gen.gen_example(); });

auto config = get_config(dedup, compression);

r::api_status status;
r::live_model model(config);
model.init(&status);
const auto event_id = "event_id";

auto& rank_builder = model.get_rank_builder();

r::ranking_response response;

for (auto _ : state)
{
for (size_t i = 0; i < count; i++)
{
if (rank_builder.set_event_id(event_id).set_context(examples[i].c_str()).rank(response, &status) != err::success)
{
std::cout << "there was an error so something went wrong during "
"benchmarking: "
<< status.get_error_msg() << std::endl;
}
}
benchmark::ClobberMemory();
}
}

// characteristics of the benchmark examples that will be generated are:

// x shared features
Expand All @@ -92,6 +141,8 @@ static void bench_cb(benchmark::State& state, ExtraArgs&&... extra_args)
// compression (on/off)
// dedup (on/off)
BENCHMARK_CAPTURE(bench_cb, non_dedupable_payload, 20, 10, 50, 2000, 500, false, false);
BENCHMARK_CAPTURE(bench_cb_w_builder, non_dedupable_payload_w_builder, 20, 10, 50, 2000, 500, false, false);

BENCHMARK_CAPTURE(bench_cb, non_dedupable_payload_compression, 20, 10, 50, 2000, 500, true, false);
BENCHMARK_CAPTURE(bench_cb, non_dedupable_payload_dedup, 20, 10, 50, 2000, 500, false, true);
BENCHMARK_CAPTURE(bench_cb, non_dedupable_payload_compression_dedup, 20, 10, 50, 2000, 500, true, true);
14 changes: 7 additions & 7 deletions include/live_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@
* @date 2018-07-18
*/
#pragma once
#include "action_flags.h"
#include "continuous_action_response.h"
#include "decision_response.h"
#include "err_constants.h"
#include "factory_resolver.h"
#include "future_compat.h"
#include "multi_slot_response.h"
#include "multi_slot_response_detailed.h"
#include "multistep.h"
#include "ranking_response.h"
#include "request_builders.h"
#include "sender.h"
#include "vw/core/example.h"

#include <functional>
#include <memory>
Expand Down Expand Up @@ -107,6 +103,8 @@ class live_model
*/
int init(api_status* status = nullptr);

rank_builder& get_rank_builder();

/**
* @brief Choose an action, given a list of actions, action features and context features. The
* inference library chooses an action by creating a probability distribution over the actions
Expand Down Expand Up @@ -487,6 +485,8 @@ class live_model
const std::vector<int> default_baseline_vector = std::vector<int>();
static std::vector<int> c_array_to_vector(
const int* c_array, size_t array_size); //! Convert baseline_actions from c array to std vector.

rank_builder _rank_builder;
};

/**
Expand All @@ -511,4 +511,4 @@ live_model::live_model(const utility::configuration& config, error_fn_t<ErrCntxt
s_factory, time_prov_factory)
{
}
} // namespace reinforcement_learning
} // namespace reinforcement_learning
58 changes: 58 additions & 0 deletions include/request_builders.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include "action_flags.h"
#include "continuous_action_response.h"
#include "decision_response.h"
#include "multi_slot_response.h"
#include "multi_slot_response_detailed.h"
#include "ranking_response.h"
#include "rl_string_view.h"

namespace reinforcement_learning
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

following vw's learner builder style

{
class live_model_impl;

template <class FluentBuilderT>
class basic_builder
{
protected:
const char* _event_id = nullptr;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to be string_view after I finally merge the other string_view event id PR

unsigned int _flags = action_flags::DEFAULT;
string_view _context_json;

public:
FluentBuilderT& set_event_id(const char* event_id)
{
_event_id = event_id;
return *static_cast<FluentBuilderT*>(this);
}

FluentBuilderT& set_flags(unsigned int flags)
{
_flags = flags;
return *static_cast<FluentBuilderT*>(this);
}

FluentBuilderT& set_context(string_view context_json)
{
_context_json = context_json;
return *static_cast<FluentBuilderT*>(this);
}

FluentBuilderT& clear()
{
_event_id = nullptr;
_flags = action_flags::DEFAULT;
_context_json = {};
return *static_cast<FluentBuilderT*>(this);
}
};

class rank_builder : public basic_builder<rank_builder>
{
live_model_impl* model;

public:
rank_builder(live_model_impl* live_model);
int rank(ranking_response& resp, api_status* status = nullptr);
rank_builder& clear();
};
} // namespace reinforcement_learning
2 changes: 2 additions & 0 deletions rlclientlib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ set(PROJECT_SOURCES
multi_slot_response_detailed.cc
ranking_event.cc
ranking_response.cc
request_builders.cc
sampling.cc
serialization/payload_serializer.cc
slot_ranking.cc
Expand Down Expand Up @@ -129,6 +130,7 @@ set(PROJECT_PUBLIC_HEADERS
../include/object_factory.h
../include/personalization.h
../include/ranking_response.h
../include/request_builders.h
../include/rl_string_view.h
../include/sender.h
../include/slot_ranking.h
Expand Down
23 changes: 13 additions & 10 deletions rlclientlib/live_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,23 @@ namespace reinforcement_learning
live_model::live_model(const utility::configuration& config, error_fn fn, void* err_context,
trace_logger_factory_t* trace_factory, data_transport_factory_t* t_factory, model_factory_t* m_factory,
sender_factory_t* s_factory, time_provider_factory_t* time_prov_factory)
: _pimpl(new live_model_impl(
config, fn, err_context, trace_factory, t_factory, m_factory, s_factory, time_prov_factory))
, _rank_builder(_pimpl.get())
{
_pimpl = std::unique_ptr<live_model_impl>(
new live_model_impl(config, fn, err_context, trace_factory, t_factory, m_factory, s_factory, time_prov_factory));
}

live_model::live_model(const utility::configuration& config, std::function<void(const api_status&)> error_cb,
trace_logger_factory_t* trace_factory, data_transport_factory_t* t_factory, model_factory_t* m_factory,
sender_factory_t* s_factory, time_provider_factory_t* time_prov_factory)
: _pimpl(new live_model_impl(
config, std::move(error_cb), trace_factory, t_factory, m_factory, s_factory, time_prov_factory))
, _rank_builder(_pimpl.get())
{
_pimpl = std::unique_ptr<live_model_impl>(new live_model_impl(
config, std::move(error_cb), trace_factory, t_factory, m_factory, s_factory, time_prov_factory));
}

live_model::live_model(live_model&& other) noexcept
live_model::live_model(live_model&& other) noexcept : _pimpl(other._pimpl.release()), _rank_builder(_pimpl.get())
{
std::swap(_pimpl, other._pimpl);
_initialized = other._initialized;
}

Expand All @@ -44,6 +45,8 @@ live_model& live_model::operator=(live_model&& other) noexcept
return *this;
}

rank_builder& live_model::get_rank_builder() { return _rank_builder; }

int live_model::init(api_status* status)
{
if (_initialized) { return error_code::success; }
Expand All @@ -64,29 +67,29 @@ int live_model::choose_rank(
const char* event_id, string_view context_json, ranking_response& response, api_status* status)
{
INIT_CHECK();
return choose_rank(event_id, context_json, action_flags::DEFAULT, response, status);
return _rank_builder.set_event_id(event_id).set_context(context_json).rank(response, status);
}

int live_model::choose_rank(string_view context_json, ranking_response& response, api_status* status)
{
INIT_CHECK();
return choose_rank(context_json, action_flags::DEFAULT, response, status);
return _rank_builder.set_context(context_json).rank(response, status);
}

// not implemented yet
int live_model::choose_rank(
const char* event_id, string_view context_json, unsigned int flags, ranking_response& response, api_status* status)
{
INIT_CHECK();
return _pimpl->choose_rank(event_id, context_json, flags, response, status);
return _rank_builder.set_event_id(event_id).set_flags(flags).set_context(context_json).rank(response, status);
}

// not implemented yet
int live_model::choose_rank(
string_view context_json, unsigned int flags, ranking_response& response, api_status* status)
{
INIT_CHECK();
return _pimpl->choose_rank(context_json, flags, response, status);
return _rank_builder.set_context(context_json).set_flags(flags).rank(response, status);
}

int live_model::request_continuous_action(const char* event_id, string_view context_json, unsigned int flags,
Expand Down
27 changes: 27 additions & 0 deletions rlclientlib/request_builders.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#include "request_builders.h"

#include "live_model_impl.h"

#include <boost/uuid/random_generator.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <string>
#include <vector>

namespace reinforcement_learning
{

rank_builder::rank_builder(live_model_impl* live_model) : model(live_model) {}

int rank_builder::rank(ranking_response& resp, api_status* status)
{
const auto uuid = boost::uuids::to_string(boost::uuids::random_generator()());
if (_event_id == nullptr) { set_event_id(uuid.c_str()); }

auto err_resp = model->choose_rank(_event_id, _context_json, _flags, resp, status);
clear();
return err_resp;
}

rank_builder& rank_builder::clear() { return basic_builder<rank_builder>::clear(); }

} // namespace reinforcement_learning