diff --git a/benchmarks/benchmark_cb_v2.cc b/benchmarks/benchmark_cb_v2.cc index 1f4248bdd..90a556b8f 100644 --- a/benchmarks/benchmark_cb_v2.cc +++ b/benchmarks/benchmark_cb_v2.cc @@ -27,6 +27,26 @@ const auto JSON_CFG = R"( } )"; +u::configuration get_config(bool dedup, bool compression) +{ + u::configuration config; + cfg::create_from_json(JSON_CFG, config); + config.set(r::name::PROTOCOL_VERSION, "2"); + config.set(r::name::EH_TEST, "true"); + config.set(r::name::MODEL_SRC, r::value::NO_MODEL_DATA); + config.set(r::name::OBSERVATION_SENDER_IMPLEMENTATION, r::value::OBSERVATION_FILE_SENDER); + config.set(r::name::INTERACTION_SENDER_IMPLEMENTATION, r::value::INTERACTION_FILE_SENDER); + config.set(r::name::INTERACTION_FILE_NAME, "/dev/null"); + config.set(r::name::OBSERVATION_FILE_NAME, "/dev/null"); + config.set(r::name::MODEL_BACKGROUND_REFRESH, "false"); + // config.set(r::name::MODEL_IMPLEMENTATION, r::value::PASSTHROUGH_PDF_MODEL); + config.set(r::name::VW_POOL_INIT_SIZE, "1"); + config.set(r::name::INTERACTION_USE_COMPRESSION, compression ? "true" : "false"); + config.set(r::name::INTERACTION_USE_DEDUP, dedup ? "true" : "false"); + config.set("queue.mode", "BLOCK"); + return config; +} + template static void bench_cb(benchmark::State& state, ExtraArgs&&... extra_args) { @@ -44,21 +64,7 @@ static void bench_cb(benchmark::State& state, ExtraArgs&&... extra_args) std::vector examples; std::generate_n(std::back_inserter(examples), count, [&cb_gen] { return cb_gen.gen_example(); }); - u::configuration config; - cfg::create_from_json(JSON_CFG, config); - config.set(r::name::PROTOCOL_VERSION, "2"); - config.set(r::name::EH_TEST, "true"); - config.set(r::name::MODEL_SRC, r::value::NO_MODEL_DATA); - config.set(r::name::OBSERVATION_SENDER_IMPLEMENTATION, r::value::OBSERVATION_FILE_SENDER); - config.set(r::name::INTERACTION_SENDER_IMPLEMENTATION, r::value::INTERACTION_FILE_SENDER); - config.set(r::name::INTERACTION_FILE_NAME, "/dev/null"); - config.set(r::name::OBSERVATION_FILE_NAME, "/dev/null"); - config.set(r::name::MODEL_BACKGROUND_REFRESH, "false"); - // config.set(r::name::MODEL_IMPLEMENTATION, r::value::PASSTHROUGH_PDF_MODEL); - config.set(r::name::VW_POOL_INIT_SIZE, "1"); - config.set(r::name::INTERACTION_USE_COMPRESSION, compression ? "true" : "false"); - config.set(r::name::INTERACTION_USE_DEDUP, dedup ? "true" : "false"); - config.set("queue.mode", "BLOCK"); + auto config = get_config(dedup, compression); r::api_status status; r::live_model model(config); @@ -82,6 +88,49 @@ static void bench_cb(benchmark::State& state, ExtraArgs&&... extra_args) } } +template +static void bench_cb_w_builder(benchmark::State& state, ExtraArgs&&... extra_args) +{ + int res[sizeof...(extra_args)] = {extra_args...}; + auto shared_features = res[0]; + auto action_features = res[1]; + auto actions_per_decision = res[2]; + auto total_actions = res[3]; + auto count = res[4]; + bool compression = res[5]; + bool dedup = res[6]; + + cb_decision_gen cb_gen(shared_features, action_features, actions_per_decision, total_actions, 0, false); + + std::vector examples; + std::generate_n(std::back_inserter(examples), count, [&cb_gen] { return cb_gen.gen_example(); }); + + auto config = get_config(dedup, compression); + + r::api_status status; + r::live_model model(config); + model.init(&status); + const auto event_id = "event_id"; + + auto& rank_builder = model.get_rank_builder(); + + r::ranking_response response; + + for (auto _ : state) + { + for (size_t i = 0; i < count; i++) + { + if (rank_builder.set_event_id(event_id).set_context(examples[i].c_str()).rank(response, &status) != err::success) + { + std::cout << "there was an error so something went wrong during " + "benchmarking: " + << status.get_error_msg() << std::endl; + } + } + benchmark::ClobberMemory(); + } +} + // characteristics of the benchmark examples that will be generated are: // x shared features @@ -92,6 +141,8 @@ static void bench_cb(benchmark::State& state, ExtraArgs&&... extra_args) // compression (on/off) // dedup (on/off) BENCHMARK_CAPTURE(bench_cb, non_dedupable_payload, 20, 10, 50, 2000, 500, false, false); +BENCHMARK_CAPTURE(bench_cb_w_builder, non_dedupable_payload_w_builder, 20, 10, 50, 2000, 500, false, false); + BENCHMARK_CAPTURE(bench_cb, non_dedupable_payload_compression, 20, 10, 50, 2000, 500, true, false); BENCHMARK_CAPTURE(bench_cb, non_dedupable_payload_dedup, 20, 10, 50, 2000, 500, false, true); BENCHMARK_CAPTURE(bench_cb, non_dedupable_payload_compression_dedup, 20, 10, 50, 2000, 500, true, true); \ No newline at end of file diff --git a/include/live_model.h b/include/live_model.h index 7ece185da..d9f28e633 100644 --- a/include/live_model.h +++ b/include/live_model.h @@ -6,17 +6,13 @@ * @date 2018-07-18 */ #pragma once -#include "action_flags.h" -#include "continuous_action_response.h" -#include "decision_response.h" #include "err_constants.h" #include "factory_resolver.h" #include "future_compat.h" -#include "multi_slot_response.h" -#include "multi_slot_response_detailed.h" #include "multistep.h" -#include "ranking_response.h" +#include "request_builders.h" #include "sender.h" +#include "vw/core/example.h" #include #include @@ -107,6 +103,8 @@ class live_model */ int init(api_status* status = nullptr); + rank_builder& get_rank_builder(); + /** * @brief Choose an action, given a list of actions, action features and context features. The * inference library chooses an action by creating a probability distribution over the actions @@ -487,6 +485,8 @@ class live_model const std::vector default_baseline_vector = std::vector(); static std::vector c_array_to_vector( const int* c_array, size_t array_size); //! Convert baseline_actions from c array to std vector. + + rank_builder _rank_builder; }; /** @@ -511,4 +511,4 @@ live_model::live_model(const utility::configuration& config, error_fn_t +class basic_builder +{ +protected: + const char* _event_id = nullptr; + unsigned int _flags = action_flags::DEFAULT; + string_view _context_json; + +public: + FluentBuilderT& set_event_id(const char* event_id) + { + _event_id = event_id; + return *static_cast(this); + } + + FluentBuilderT& set_flags(unsigned int flags) + { + _flags = flags; + return *static_cast(this); + } + + FluentBuilderT& set_context(string_view context_json) + { + _context_json = context_json; + return *static_cast(this); + } + + FluentBuilderT& clear() + { + _event_id = nullptr; + _flags = action_flags::DEFAULT; + _context_json = {}; + return *static_cast(this); + } +}; + +class rank_builder : public basic_builder +{ + live_model_impl* model; + +public: + rank_builder(live_model_impl* live_model); + int rank(ranking_response& resp, api_status* status = nullptr); + rank_builder& clear(); +}; +} // namespace reinforcement_learning \ No newline at end of file diff --git a/rlclientlib/CMakeLists.txt b/rlclientlib/CMakeLists.txt index 13d6e27bc..eeb940d3a 100644 --- a/rlclientlib/CMakeLists.txt +++ b/rlclientlib/CMakeLists.txt @@ -75,6 +75,7 @@ set(PROJECT_SOURCES multi_slot_response_detailed.cc ranking_event.cc ranking_response.cc + request_builders.cc sampling.cc serialization/payload_serializer.cc slot_ranking.cc @@ -129,6 +130,7 @@ set(PROJECT_PUBLIC_HEADERS ../include/object_factory.h ../include/personalization.h ../include/ranking_response.h + ../include/request_builders.h ../include/rl_string_view.h ../include/sender.h ../include/slot_ranking.h diff --git a/rlclientlib/live_model.cc b/rlclientlib/live_model.cc index f781d1c2e..991109aee 100644 --- a/rlclientlib/live_model.cc +++ b/rlclientlib/live_model.cc @@ -16,22 +16,23 @@ namespace reinforcement_learning live_model::live_model(const utility::configuration& config, error_fn fn, void* err_context, trace_logger_factory_t* trace_factory, data_transport_factory_t* t_factory, model_factory_t* m_factory, sender_factory_t* s_factory, time_provider_factory_t* time_prov_factory) + : _pimpl(new live_model_impl( + config, fn, err_context, trace_factory, t_factory, m_factory, s_factory, time_prov_factory)) + , _rank_builder(_pimpl.get()) { - _pimpl = std::unique_ptr( - new live_model_impl(config, fn, err_context, trace_factory, t_factory, m_factory, s_factory, time_prov_factory)); } live_model::live_model(const utility::configuration& config, std::function error_cb, trace_logger_factory_t* trace_factory, data_transport_factory_t* t_factory, model_factory_t* m_factory, sender_factory_t* s_factory, time_provider_factory_t* time_prov_factory) + : _pimpl(new live_model_impl( + config, std::move(error_cb), trace_factory, t_factory, m_factory, s_factory, time_prov_factory)) + , _rank_builder(_pimpl.get()) { - _pimpl = std::unique_ptr(new live_model_impl( - config, std::move(error_cb), trace_factory, t_factory, m_factory, s_factory, time_prov_factory)); } -live_model::live_model(live_model&& other) noexcept +live_model::live_model(live_model&& other) noexcept : _pimpl(other._pimpl.release()), _rank_builder(_pimpl.get()) { - std::swap(_pimpl, other._pimpl); _initialized = other._initialized; } @@ -44,6 +45,8 @@ live_model& live_model::operator=(live_model&& other) noexcept return *this; } +rank_builder& live_model::get_rank_builder() { return _rank_builder; } + int live_model::init(api_status* status) { if (_initialized) { return error_code::success; } @@ -64,13 +67,13 @@ int live_model::choose_rank( const char* event_id, string_view context_json, ranking_response& response, api_status* status) { INIT_CHECK(); - return choose_rank(event_id, context_json, action_flags::DEFAULT, response, status); + return _rank_builder.set_event_id(event_id).set_context(context_json).rank(response, status); } int live_model::choose_rank(string_view context_json, ranking_response& response, api_status* status) { INIT_CHECK(); - return choose_rank(context_json, action_flags::DEFAULT, response, status); + return _rank_builder.set_context(context_json).rank(response, status); } // not implemented yet @@ -78,7 +81,7 @@ int live_model::choose_rank( const char* event_id, string_view context_json, unsigned int flags, ranking_response& response, api_status* status) { INIT_CHECK(); - return _pimpl->choose_rank(event_id, context_json, flags, response, status); + return _rank_builder.set_event_id(event_id).set_flags(flags).set_context(context_json).rank(response, status); } // not implemented yet @@ -86,7 +89,7 @@ int live_model::choose_rank( string_view context_json, unsigned int flags, ranking_response& response, api_status* status) { INIT_CHECK(); - return _pimpl->choose_rank(context_json, flags, response, status); + return _rank_builder.set_context(context_json).set_flags(flags).rank(response, status); } int live_model::request_continuous_action(const char* event_id, string_view context_json, unsigned int flags, diff --git a/rlclientlib/request_builders.cc b/rlclientlib/request_builders.cc new file mode 100644 index 000000000..43522e5d5 --- /dev/null +++ b/rlclientlib/request_builders.cc @@ -0,0 +1,27 @@ +#include "request_builders.h" + +#include "live_model_impl.h" + +#include +#include +#include +#include + +namespace reinforcement_learning +{ + +rank_builder::rank_builder(live_model_impl* live_model) : model(live_model) {} + +int rank_builder::rank(ranking_response& resp, api_status* status) +{ + const auto uuid = boost::uuids::to_string(boost::uuids::random_generator()()); + if (_event_id == nullptr) { set_event_id(uuid.c_str()); } + + auto err_resp = model->choose_rank(_event_id, _context_json, _flags, resp, status); + clear(); + return err_resp; +} + +rank_builder& rank_builder::clear() { return basic_builder::clear(); } + +} // namespace reinforcement_learning \ No newline at end of file