diff --git a/external_parser/CMakeLists.txt b/external_parser/CMakeLists.txt index 99a822c0d..9b316806a 100644 --- a/external_parser/CMakeLists.txt +++ b/external_parser/CMakeLists.txt @@ -135,7 +135,7 @@ set(binary_parser_headers ${CMAKE_CURRENT_LIST_DIR}/joiners/i_joiner.h ${CMAKE_CURRENT_LIST_DIR}/joiners/multistep_example_joiner.h ${CMAKE_CURRENT_LIST_DIR}/log_converter.h - ${CMAKE_CURRENT_LIST_DIR}/lru_dedup_cache.h + ${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/example_cache/lru_dedup_cache.h ${CMAKE_CURRENT_LIST_DIR}/parse_example_binary.h ${CMAKE_CURRENT_LIST_DIR}/parse_example_converter.h ${CMAKE_CURRENT_LIST_DIR}/parse_example_external.h @@ -146,7 +146,7 @@ set(binary_parser_sources ${CMAKE_CURRENT_LIST_DIR}/joiners/example_joiner.cc ${CMAKE_CURRENT_LIST_DIR}/joiners/multistep_example_joiner.cc ${CMAKE_CURRENT_LIST_DIR}/log_converter.cc - ${CMAKE_CURRENT_LIST_DIR}/lru_dedup_cache.cc + ${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/example_cache/lru_dedup_cache.cc ${CMAKE_CURRENT_LIST_DIR}/parse_example_binary.cc ${CMAKE_CURRENT_LIST_DIR}/parse_example_converter.cc ${CMAKE_CURRENT_LIST_DIR}/parse_example_external.cc @@ -161,6 +161,7 @@ target_include_directories(rl_binary_parser ${CMAKE_CURRENT_LIST_DIR}/../ext_libs/zstd/lib/ ${CMAKE_CURRENT_LIST_DIR}/../ext_libs/date/ ) +target_include_directories(rl_binary_parser PRIVATE ${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/example_cache/) # If flatbuffers found via CONFIG, add its target as a library dependency # Otherwise, the flatbuffers MODULE defines FLATBUFFERS_INCLUDE_DIR to add to the include path diff --git a/external_parser/joiners/example_joiner.h b/external_parser/joiners/example_joiner.h index efc11127c..c44346aa4 100644 --- a/external_parser/joiners/example_joiner.h +++ b/external_parser/joiners/example_joiner.h @@ -1,9 +1,9 @@ #pragma once +#include "../rlclientlib/example_cache/lru_dedup_cache.h" #include "event_processors/joined_event.h" #include "event_processors/loop.h" #include "joiners/i_joiner.h" 
-#include "lru_dedup_cache.h" #include "metrics/metrics.h" #include "parse_example_external.h" #include "vw/core/error_constants.h" diff --git a/external_parser/joiners/i_joiner.h b/external_parser/joiners/i_joiner.h index c7dd5c8f4..fad083684 100644 --- a/external_parser/joiners/i_joiner.h +++ b/external_parser/joiners/i_joiner.h @@ -1,10 +1,10 @@ #pragma once +#include "../rlclientlib/example_cache/lru_dedup_cache.h" #include "event_processors/reward.h" #include "generated/v2/CbEvent_generated.h" #include "generated/v2/FileFormat_generated.h" #include "generated/v2/Metadata_generated.h" -#include "lru_dedup_cache.h" #include "metrics/metrics.h" #include "parse_example_external.h" #include "vw/core/error_constants.h" diff --git a/external_parser/unit_tests/test_lru_dedup_cache.cc b/external_parser/unit_tests/test_lru_dedup_cache.cc index f6abdc267..11973cf4a 100644 --- a/external_parser/unit_tests/test_lru_dedup_cache.cc +++ b/external_parser/unit_tests/test_lru_dedup_cache.cc @@ -1,6 +1,6 @@ #include -#include "lru_dedup_cache.h" +#include "../../rlclientlib/example_cache/lru_dedup_cache.h" #include "parse_example_external.h" #include "test_common.h" #include "vw/config/options_cli.h" diff --git a/include/live_model.h b/include/live_model.h index 7ece185da..0ec565ee7 100644 --- a/include/live_model.h +++ b/include/live_model.h @@ -107,6 +107,16 @@ class live_model */ int init(api_status* status = nullptr); + /** + * @brief Load a single action into the dedup cache. + * Stores action_str under action_id in the cache used to prevent duplicate actions from being sent to the online trainer. + * @param action_id Id (e.g. a hash of the action) under which the action is stored in the dedup cache + * @param action_str Action definition as a JSON string + * @param status api_status object that receives error details on failure + * @return int Return error code. This will also be returned in the api_status object + */ + int load_action(uint64_t action_id, std::string action_str, api_status* status); + /** * @brief Choose an action, given a list of actions, action features and context features. 
The * inference library chooses an action by creating a probability distribution over the actions diff --git a/include/model_mgmt.h b/include/model_mgmt.h index 4b381dba7..0b7528f4a 100644 --- a/include/model_mgmt.h +++ b/include/model_mgmt.h @@ -74,6 +74,7 @@ class i_model { public: virtual int update(const model_data& data, bool& model_ready, api_status* status = nullptr) = 0; + virtual int load_action(uint64_t action_id, std::string action_str, api_status* status = nullptr) = 0; virtual int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status = nullptr) = 0; virtual int choose_continuous_action(string_view features, float& action, float& pdf_value, diff --git a/rlclientlib/CMakeLists.txt b/rlclientlib/CMakeLists.txt index f2a58afb7..3c53068cc 100644 --- a/rlclientlib/CMakeLists.txt +++ b/rlclientlib/CMakeLists.txt @@ -55,6 +55,7 @@ set(PROJECT_SOURCES decision_response.cc dedup.cc error_callback_fn.cc + example_cache/lru_dedup_cache.cc factory_resolver.cc generic_event.cc learning_mode.cc @@ -142,6 +143,7 @@ set(PROJECT_PUBLIC_HEADERS set(PROJECT_PRIVATE_HEADERS console_tracer.h dedup.h + example_cache/lru_dedup_cache.h federation/federated_client.h federation/joined_log_provider.h generic_event.h @@ -211,6 +213,7 @@ target_include_directories(rlclientlib ${CMAKE_CURRENT_SOURCE_DIR}/../include PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/example_cache ${CMAKE_CURRENT_SOURCE_DIR}/../ext_libs/date ) diff --git a/external_parser/lru_dedup_cache.cc b/rlclientlib/example_cache/lru_dedup_cache.cc similarity index 100% rename from external_parser/lru_dedup_cache.cc rename to rlclientlib/example_cache/lru_dedup_cache.cc diff --git a/external_parser/lru_dedup_cache.h b/rlclientlib/example_cache/lru_dedup_cache.h similarity index 100% rename from external_parser/lru_dedup_cache.h rename to rlclientlib/example_cache/lru_dedup_cache.h 
diff --git a/rlclientlib/extensions/onnx/src/onnx_model.cc b/rlclientlib/extensions/onnx/src/onnx_model.cc index 7d4241941..16f548a85 100644 --- a/rlclientlib/extensions/onnx/src/onnx_model.cc +++ b/rlclientlib/extensions/onnx/src/onnx_model.cc @@ -138,6 +138,9 @@ int onnx_model::update(const model_management::model_data& data, bool& model_rea return error_code::success; } +// TODO: Implement LRU cache for ONNX models. +int onnx_model::load_action(uint64_t, std::string, api_status*) { return error_code::not_supported; } + int onnx_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status) { diff --git a/rlclientlib/extensions/onnx/src/onnx_model.h b/rlclientlib/extensions/onnx/src/onnx_model.h index d57aa943a..08bf1f805 100644 --- a/rlclientlib/extensions/onnx/src/onnx_model.h +++ b/rlclientlib/extensions/onnx/src/onnx_model.h @@ -20,6 +20,7 @@ class onnx_model : public model_management::i_model public: onnx_model(i_trace* trace_logger, const char* app_id, const char* output_name, bool use_unstructured_input); int update(const model_management::model_data& data, bool& model_ready, api_status* status = nullptr) override; + int load_action(uint64_t action_id, std::string action_str, api_status* status = nullptr) override; int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status = nullptr) override; diff --git a/rlclientlib/live_model.cc b/rlclientlib/live_model.cc index ab4776764..e15b39548 100644 --- a/rlclientlib/live_model.cc +++ b/rlclientlib/live_model.cc @@ -61,6 +61,12 @@ std::vector live_model::c_array_to_vector(const int* c_array, size_t array_ return std::vector(c_array, c_array + array_size); } +int live_model::load_action(uint64_t action_id, std::string action_str, api_status* status) +{ + INIT_CHECK(); + return 
_pimpl->load_action(action_id, std::move(action_str), status); +} + int live_model::choose_rank( const char* event_id, string_view context_json, ranking_response& response, api_status* status) { diff --git a/rlclientlib/live_model_impl.cc b/rlclientlib/live_model_impl.cc index ebd1c8c22..681edad3e 100644 --- a/rlclientlib/live_model_impl.cc +++ b/rlclientlib/live_model_impl.cc @@ -74,6 +74,11 @@ int live_model_impl::init(api_status* status) return error_code::success; } +int live_model_impl::load_action(uint64_t action_id, std::string action_str, api_status* status) +{ + return _model->load_action(action_id, std::move(action_str), status); +} + int live_model_impl::choose_rank( const char* event_id, string_view context, unsigned int flags, ranking_response& response, api_status* status) { diff --git a/rlclientlib/live_model_impl.h b/rlclientlib/live_model_impl.h index 8a5a0f237..6bce84b16 100644 --- a/rlclientlib/live_model_impl.h +++ b/rlclientlib/live_model_impl.h @@ -28,6 +28,7 @@ class live_model_impl int init(api_status* status); + int load_action(uint64_t action_id, std::string action_str, api_status* status); int choose_rank( const char* event_id, string_view context, unsigned int flags, ranking_response& response, api_status* status); // here the event_id is auto-generated diff --git a/rlclientlib/vw_model/pdf_model.cc b/rlclientlib/vw_model/pdf_model.cc index 8c08da0b5..f1ed8c44a 100644 --- a/rlclientlib/vw_model/pdf_model.cc +++ b/rlclientlib/vw_model/pdf_model.cc @@ -23,6 +23,9 @@ int pdf_model::update(const model_data& data, bool& model_ready, api_status* sta return error_code::success; } +// TODO: Implement LRU cache for PDF models. 
+int pdf_model::load_action(uint64_t, std::string, api_status*) { return error_code::not_supported; } + int pdf_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status) { diff --git a/rlclientlib/vw_model/pdf_model.h b/rlclientlib/vw_model/pdf_model.h index 883e1fbee..77bdac3a5 100644 --- a/rlclientlib/vw_model/pdf_model.h +++ b/rlclientlib/vw_model/pdf_model.h @@ -20,6 +20,7 @@ class pdf_model : public i_model public: pdf_model(i_trace* trace_logger, const utility::configuration& config); int update(const model_data& data, bool& model_ready, api_status* status = nullptr) override; + int load_action(uint64_t action_id, std::string action_str, api_status* status = nullptr) override; int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status = nullptr) override; int choose_continuous_action(string_view features, float& action, float& pdf_value, std::string& model_version, diff --git a/rlclientlib/vw_model/safe_vw.cc b/rlclientlib/vw_model/safe_vw.cc index 93db5e62b..92e9262c7 100644 --- a/rlclientlib/vw_model/safe_vw.cc +++ b/rlclientlib/vw_model/safe_vw.cc @@ -120,7 +120,28 @@ void safe_vw::parse_context_with_pdf(string_view context, std::vector& acti for (auto&& ex : examples) { _example_pool.emplace_back(ex); } } -void safe_vw::rank(string_view context, std::vector& actions, std::vector& scores) +void safe_vw::load_action(uint64_t action_id, std::string action_str, lru_dedup_cache* action_cache) +{ + VW::multi_ex examples; + examples.push_back(get_or_create_example()); + + if (_vw->audit) + { + _vw->audit_buffer->clear(); + VW::read_line_json_s(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this); + } + else + { + VW::read_line_json_s(*_vw, examples, &action_str[0], action_str.size(), 
get_or_create_example_f, this); + } + action_cache->add(action_id, examples[0]); + + // clean up examples and push examples back into pool for re-use + for (auto&& ex : examples) { _example_pool.emplace_back(ex); } +} + +void safe_vw::rank( + string_view context, std::vector& actions, std::vector& scores, lru_dedup_cache* action_cache) { VW::multi_ex examples; examples.push_back(get_or_create_example()); @@ -128,12 +149,19 @@ void safe_vw::rank(string_view context, std::vector& actions, std::vectordedup_examples; + // check for null if (_vw->audit) { _vw->audit_buffer->clear(); - VW::read_line_json_s(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this); + VW::read_line_json_s( + *_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, action_dict); + } + else + { + VW::read_line_json_s( + *_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, action_dict); } - else { VW::read_line_json_s(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this); } // finalize example VW::setup_examples(*_vw, examples); diff --git a/rlclientlib/vw_model/safe_vw.h b/rlclientlib/vw_model/safe_vw.h index c97b4b064..0405a8994 100644 --- a/rlclientlib/vw_model/safe_vw.h +++ b/rlclientlib/vw_model/safe_vw.h @@ -1,5 +1,6 @@ #pragma once +#include "lru_dedup_cache.h" #include "model_mgmt.h" #include "vw/core/vw.h" @@ -27,7 +28,9 @@ class safe_vw ~safe_vw(); void parse_context_with_pdf(string_view context, std::vector& actions, std::vector& scores); - void rank(string_view context, std::vector& actions, std::vector& scores); + void load_action(uint64_t action_id, std::string action_str, lru_dedup_cache* action_cache); + void rank(string_view context, std::vector& actions, std::vector& scores, + lru_dedup_cache* action_cache = nullptr); void choose_continuous_action(string_view context, float& action, float& pdf_value); // Used for CCB void rank_decisions(const std::vector& event_ids, string_view context, 
diff --git a/rlclientlib/vw_model/vw_model.cc b/rlclientlib/vw_model/vw_model.cc index 53d951a93..3a2e818c4 100644 --- a/rlclientlib/vw_model/vw_model.cc +++ b/rlclientlib/vw_model/vw_model.cc @@ -67,6 +67,14 @@ int vw_model::update(const model_data& data, bool& model_ready, api_status* stat return error_code::success; } +int vw_model::load_action(uint64_t action_id, std::string action_str, api_status* status) +{ + std::lock_guard lock(_mutex); + auto vw = _vw_pool.get_or_create(); + vw->load_action(action_id, action_str, &_action_cache); + return error_code::success; +} + int vw_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status) { @@ -75,7 +83,8 @@ int vw_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view f auto vw = _vw_pool.get_or_create(); // Get a ranked list of action_ids and corresponding pdf - vw->rank(features, action_ids, action_pdf); + std::lock_guard lock(_mutex); + vw->rank(features, action_ids, action_pdf, &_action_cache); if (_audit) { write_audit_log(event_id, vw->get_audit_data()); } @@ -97,6 +106,7 @@ int vw_model::choose_rank_multistep(const char* event_id, uint64_t rnd_seed, str const episode_history& history, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status) { + // Do not lock _mutex here: choose_rank() acquires it itself, and re-locking a non-recursive std::mutex on the same thread is undefined behavior (self-deadlock). return choose_rank(event_id, rnd_seed, features, action_ids, action_pdf, model_version, status); } @@ -132,6 +142,7 @@ int vw_model::request_decision(const std::vector& event_ids, string auto vw = _vw_pool.get_or_create(); // Get a ranked list of action_ids and corresponding pdf + std::lock_guard lock(_mutex); vw->rank_decisions(event_ids, features, actions_ids, action_pdfs); model_version = vw->id(); diff --git a/rlclientlib/vw_model/vw_model.h b/rlclientlib/vw_model/vw_model.h index dd872b6e9..68f8772f1 100644 --- a/rlclientlib/vw_model/vw_model.h +++ 
b/rlclientlib/vw_model/vw_model.h @@ -1,10 +1,13 @@ #pragma once #include "../utility/versioned_object_pool.h" +#include "lru_dedup_cache.h" #include "model_mgmt.h" #include "multistep.h" #include "safe_vw.h" #include "trace_logger.h" +#include + namespace reinforcement_learning { namespace utility @@ -26,6 +29,7 @@ class vw_model : public i_model vw_model(i_trace* trace_logger, const utility::configuration& config); int update(const model_data& data, bool& model_ready, api_status* status = nullptr) override; + int load_action(uint64_t action_id, std::string action_str, api_status* status = nullptr) override; int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector& action_ids, std::vector& action_pdf, std::string& model_version, api_status* status = nullptr) override; int choose_continuous_action(string_view features, float& action, float& pdf_value, std::string& model_version, @@ -48,6 +52,8 @@ class vw_model : public i_model const std::string _quiet_commandline_options{"--json --quiet"}; const std::string _upgrade_to_CCB_vw_commandline_options{"--ccb_explore_adf --json --quiet"}; utility::versioned_object_pool _vw_pool; + lru_dedup_cache _action_cache; + std::mutex _mutex; i_trace* _trace_logger; }; } // namespace model_management diff --git a/unit_test/live_model_test.cc b/unit_test/live_model_test.cc index 83de9cd1a..fdcd05aa0 100644 --- a/unit_test/live_model_test.cc +++ b/unit_test/live_model_test.cc @@ -448,7 +448,7 @@ BOOST_AUTO_TEST_CASE(live_model_ranking_request_check_response_pdf_explore_only) r::ranking_response response; const auto JSON_CB_CONTEXT_3ACTIONS = - R"({"GUser":{"id":"a","major":"eng","hobby":"hiking"},"_multi":[{"TAction":{"a1":"f1"} },{"TAction":{"a2":"f2"}},{"TAction":{"a3":"f3"}}]})"; + R"({"GUser":{"id":"a","major":"eng","hobby":"hiking"},"_multi":[{"TAction":{"a1":"f1"}},{"TAction":{"a2":"f2"}},{"TAction":{"a3":"f3"}}]})"; // request ranking BOOST_CHECK_EQUAL(model.choose_rank(event_id, 
JSON_CB_CONTEXT_3ACTIONS, response), err::success); @@ -468,6 +468,54 @@ BOOST_AUTO_TEST_CASE(live_model_ranking_request_check_response_pdf_explore_only) } } +BOOST_AUTO_TEST_CASE(live_model_ranking_request_check_response_pdf_explore_only_dedup) +{ + // create a simple ds configuration + u::configuration config; + cfg::create_from_json(JSON_CFG, config); + config.set(r::name::EH_TEST, "true"); + config.set(r::name::MODEL_SRC, r::value::NO_MODEL_DATA); + config.set(r::name::OBSERVATION_SENDER_IMPLEMENTATION, r::value::OBSERVATION_FILE_SENDER); + config.set(r::name::INTERACTION_SENDER_IMPLEMENTATION, r::value::INTERACTION_FILE_SENDER); + config.set( + r::name::MODEL_VW_INITIAL_COMMAND_LINE, "--cb_explore_adf --json --quiet --epsilon 0.3 --first_only --id N/A"); + + r::api_status status; + + // create the ds live_model, and initialize it with the config + r::live_model model(config); + + BOOST_CHECK_EQUAL(model.init(&status), err::success); + const auto event_id = "event_id"; + + r::ranking_response response; + + const auto JSON_CB_CONTEXT_3ACTIONS_DEDUP = + R"({"GUser":{"id":"a","major":"eng","hobby":"hiking"},"_multi":[{"__aid":1},{"__aid":2},{"__aid":3}]})"; + + // add dedup + BOOST_CHECK_EQUAL(model.load_action(1, "{\"TAction\":{\"a1\":\"f1\"}}", &status), err::success); + BOOST_CHECK_EQUAL(model.load_action(2, "{\"TAction\":{\"a2\":\"f2\"}}", &status), err::success); + BOOST_CHECK_EQUAL(model.load_action(3, "{\"TAction\":{\"a3\":\"f3\"}}", &status), err::success); + + // request ranking + BOOST_CHECK_EQUAL(model.choose_rank(event_id, JSON_CB_CONTEXT_3ACTIONS_DEDUP, response), err::success); + + size_t num_actions = response.size(); + BOOST_CHECK_EQUAL(num_actions, 3); + + const float EXPECTED_PROBS[3] = {0.8f, 0.1f, 0.1f}; + + // check that our PDF is what we expected + r::ranking_response::iterator it = response.begin(); + + for (uint32_t i = 0; i < num_actions; i++) + { + auto action_probability = *(it + i); + BOOST_CHECK_CLOSE(action_probability.probability, 
EXPECTED_PROBS[action_probability.action_id], FLOAT_TOL); + } +} + BOOST_AUTO_TEST_CASE(live_model_ranking_w_las_request_check_response_pdf_explore_only) { // create a simple ds configuration