Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use lru_dedup_dict for rank call #569

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions external_parser/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ set(binary_parser_headers
${CMAKE_CURRENT_LIST_DIR}/joiners/i_joiner.h
${CMAKE_CURRENT_LIST_DIR}/joiners/multistep_example_joiner.h
${CMAKE_CURRENT_LIST_DIR}/log_converter.h
${CMAKE_CURRENT_LIST_DIR}/lru_dedup_cache.h
${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/lru_dedup_cache.h
${CMAKE_CURRENT_LIST_DIR}/parse_example_binary.h
${CMAKE_CURRENT_LIST_DIR}/parse_example_converter.h
${CMAKE_CURRENT_LIST_DIR}/parse_example_external.h
Expand All @@ -146,7 +146,7 @@ set(binary_parser_sources
${CMAKE_CURRENT_LIST_DIR}/joiners/example_joiner.cc
${CMAKE_CURRENT_LIST_DIR}/joiners/multistep_example_joiner.cc
${CMAKE_CURRENT_LIST_DIR}/log_converter.cc
${CMAKE_CURRENT_LIST_DIR}/lru_dedup_cache.cc
${CMAKE_CURRENT_LIST_DIR}/../rlclientlib/lru_dedup_cache.cc
${CMAKE_CURRENT_LIST_DIR}/parse_example_binary.cc
${CMAKE_CURRENT_LIST_DIR}/parse_example_converter.cc
${CMAKE_CURRENT_LIST_DIR}/parse_example_external.cc
Expand Down
2 changes: 1 addition & 1 deletion external_parser/joiners/example_joiner.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#pragma once

#include "../rlclientlib/lru_dedup_cache.h"
#include "event_processors/joined_event.h"
#include "event_processors/loop.h"
#include "joiners/i_joiner.h"
#include "lru_dedup_cache.h"
#include "metrics/metrics.h"
#include "parse_example_external.h"
#include "vw/core/error_constants.h"
Expand Down
2 changes: 1 addition & 1 deletion external_parser/joiners/i_joiner.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#pragma once

#include "../rlclientlib/lru_dedup_cache.h"
#include "event_processors/reward.h"
#include "generated/v2/CbEvent_generated.h"
#include "generated/v2/FileFormat_generated.h"
#include "generated/v2/Metadata_generated.h"
#include "lru_dedup_cache.h"
#include "metrics/metrics.h"
#include "parse_example_external.h"
#include "vw/core/error_constants.h"
Expand Down
2 changes: 1 addition & 1 deletion external_parser/unit_tests/test_lru_dedup_cache.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include <boost/test/unit_test.hpp>

#include "lru_dedup_cache.h"
#include "../rlclientlib/lru_dedup_cache.h"
#include "parse_example_external.h"
#include "test_common.h"
#include "vw/config/options_cli.h"
Expand Down
10 changes: 10 additions & 0 deletions include/live_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@ class live_model
*/
int init(api_status* status = nullptr);

/**
* @brief Load dedup cache.
* Load the dedup cache from the specified file. This cache is used to
* prevent duplicate actions from being sent to the online trainer.
* @param hash Hash of the dedup cache
* @param action_str Action string
* @return int Return error code. This will also be returned in the api_status object
*/
int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status);
bassmang marked this conversation as resolved.
Show resolved Hide resolved

/**
* @brief Choose an action, given a list of actions, action features and context features. The
* inference library chooses an action by creating a probability distribution over the actions
Expand Down
1 change: 1 addition & 0 deletions include/model_mgmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class i_model
{
public:
virtual int update(const model_data& data, bool& model_ready, api_status* status = nullptr) = 0;
virtual int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) = 0;
virtual int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status = nullptr) = 0;
virtual int choose_continuous_action(string_view features, float& action, float& pdf_value,
Expand Down
2 changes: 2 additions & 0 deletions rlclientlib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ set(PROJECT_SOURCES
logger/logger_facade.cc
logger/preamble.cc
logger/preamble_sender.cc
lru_dedup_cache.cc
model_mgmt/data_callback_fn.cc
model_mgmt/empty_data_transport.cc
model_mgmt/file_model_loader.cc
Expand Down Expand Up @@ -149,6 +150,7 @@ set(PROJECT_PRIVATE_HEADERS
logger/async_batcher.h
logger/event_logger.h
logger/logger_facade.h
lru_dedup_cache.h
model_mgmt/data_callback_fn.h
model_mgmt/empty_data_transport.h
model_mgmt/file_model_loader.h
Expand Down
6 changes: 6 additions & 0 deletions rlclientlib/extensions/onnx/src/onnx_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ int onnx_model::update(const model_management::model_data& data, bool& model_rea
return error_code::success;
}

// TODO: Implement LRU cache for ONNX models.
int onnx_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status)
{
return error_code::not_supported;
}

int onnx_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status)
{
Expand Down
1 change: 1 addition & 0 deletions rlclientlib/extensions/onnx/src/onnx_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class onnx_model : public model_management::i_model
public:
onnx_model(i_trace* trace_logger, const char* app_id, const char* output_name, bool use_unstructured_input);
int update(const model_management::model_data& data, bool& model_ready, api_status* status = nullptr) override;
int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) override;
int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status = nullptr) override;

Expand Down
6 changes: 6 additions & 0 deletions rlclientlib/live_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ std::vector<int> live_model::c_array_to_vector(const int* c_array, size_t array_
return std::vector<int>(c_array, c_array + array_size);
}

int live_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status)
{
INIT_CHECK();
return _pimpl->add_lru_dedup_cache(hash, std::move(action_str), status);
}

int live_model::choose_rank(
const char* event_id, string_view context_json, ranking_response& response, api_status* status)
{
Expand Down
5 changes: 5 additions & 0 deletions rlclientlib/live_model_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ int live_model_impl::init(api_status* status)
return error_code::success;
}

int live_model_impl::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status)
{
return _model->add_lru_dedup_cache(hash, std::move(action_str), status);
}

int live_model_impl::choose_rank(
const char* event_id, string_view context, unsigned int flags, ranking_response& response, api_status* status)
{
Expand Down
1 change: 1 addition & 0 deletions rlclientlib/live_model_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class live_model_impl

int init(api_status* status);

int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status);
int choose_rank(
const char* event_id, string_view context, unsigned int flags, ranking_response& response, api_status* status);
// here the event_id is auto-generated
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct lru_dedup_cache
void* context = nullptr);
bool exists(uint64_t dedup_id);
void clear(release_example_f release_example = lru_dedup_cache::noop_release_example_f, void* context = nullptr);
std::unordered_map<uint64_t, VW::example*>* get_dict() { return &dedup_examples; }
bassmang marked this conversation as resolved.
Show resolved Hide resolved

lru_dedup_cache() = default;
~lru_dedup_cache() = default;
Expand Down
8 changes: 7 additions & 1 deletion rlclientlib/vw_model/pdf_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace model_management
// We construct a VW object here to use the example parser to parse joined dsjson-style examples
// to extract the PDF.
pdf_model::pdf_model(i_trace* trace_logger, const utility::configuration& /*unused*/)
: _trace_logger(trace_logger), _vw(new safe_vw("--json --quiet --cb_adf"))
: _trace_logger(trace_logger), _vw(new safe_vw("--json --quiet --cb_adf", nullptr))
{
}

Expand All @@ -23,6 +23,12 @@ int pdf_model::update(const model_data& data, bool& model_ready, api_status* sta
return error_code::success;
}

// TODO: Implement LRU cache for PDF models.
int pdf_model::add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status)
{
return error_code::not_supported;
}

int pdf_model::choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status)
{
Expand Down
1 change: 1 addition & 0 deletions rlclientlib/vw_model/pdf_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class pdf_model : public i_model
public:
pdf_model(i_trace* trace_logger, const utility::configuration& config);
int update(const model_data& data, bool& model_ready, api_status* status = nullptr) override;
int add_lru_dedup_cache(uint64_t hash, std::string action_str, api_status* status = nullptr) override;
int choose_rank(const char* event_id, uint64_t rnd_seed, string_view features, std::vector<int>& action_ids,
std::vector<float>& action_pdf, std::string& model_version, api_status* status = nullptr) override;
int choose_continuous_action(string_view features, float& action, float& pdf_value, std::string& model_version,
Expand Down
68 changes: 52 additions & 16 deletions rlclientlib/vw_model/safe_vw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ namespace reinforcement_learning
{
static const std::string SEED_TAG = "seed=";

safe_vw::safe_vw(std::shared_ptr<safe_vw> master) : _master(std::move(master))
safe_vw::safe_vw(std::shared_ptr<safe_vw> master, lru_dedup_cache* dedup_cache)
bassmang marked this conversation as resolved.
Show resolved Hide resolved
: _master(std::move(master)), _dedup_cache(dedup_cache)
{
_vw = VW::seed_vw_model(_master->_vw, "", nullptr, nullptr);
init();
}

safe_vw::safe_vw(const char* model_data, size_t len)
safe_vw::safe_vw(const char* model_data, size_t len, lru_dedup_cache* dedup_cache) : _dedup_cache(dedup_cache)
{
io_buf buf;
buf.add_file(VW::io::create_buffer_view(model_data, len));
Expand All @@ -34,7 +35,8 @@ safe_vw::safe_vw(const char* model_data, size_t len)
init();
}

safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_commandline)
safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_commandline, lru_dedup_cache* dedup_cache)
: _dedup_cache(dedup_cache)
{
io_buf buf;
buf.add_file(VW::io::create_buffer_view(model_data, len));
Expand All @@ -43,7 +45,7 @@ safe_vw::safe_vw(const char* model_data, size_t len, const std::string& vw_comma
init();
}

safe_vw::safe_vw(const std::string& vw_commandline)
safe_vw::safe_vw(const std::string& vw_commandline, lru_dedup_cache* dedup_cache) : _dedup_cache(dedup_cache)
{
_vw = VW::initialize(vw_commandline);
init();
Expand Down Expand Up @@ -120,6 +122,24 @@ void safe_vw::parse_context_with_pdf(string_view context, std::vector<int>& acti
for (auto&& ex : examples) { _example_pool.emplace_back(ex); }
}

void safe_vw::add_lru_dedup_cache(uint64_t hash, std::string action_str)
{
if (_dedup_cache == nullptr) { _dedup_cache = new lru_dedup_cache(); }
bassmang marked this conversation as resolved.
Show resolved Hide resolved
VW::multi_ex examples;
examples.push_back(get_or_create_example());

if (_vw->audit)
{
_vw->audit_buffer->clear();
VW::read_line_json_s<true>(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this);
}
else
{
VW::read_line_json_s<false>(*_vw, examples, &action_str[0], action_str.size(), get_or_create_example_f, this);
}
_dedup_cache->add(hash, examples[0]);
}

void safe_vw::rank(string_view context, std::vector<int>& actions, std::vector<float>& scores)
{
VW::multi_ex examples;
Expand All @@ -131,9 +151,14 @@ void safe_vw::rank(string_view context, std::vector<int>& actions, std::vector<f
if (_vw->audit)
{
_vw->audit_buffer->clear();
VW::read_line_json_s<true>(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this);
VW::read_line_json_s<true>(
*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, _dedup_cache->get_dict());
}
else
{
VW::read_line_json_s<false>(
*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this, _dedup_cache->get_dict());
}
else { VW::read_line_json_s<false>(*_vw, examples, &line_vec[0], line_vec.size(), get_or_create_example_f, this); }

// finalize example
VW::setup_examples(*_vw, examples);
Expand Down Expand Up @@ -372,19 +397,30 @@ void safe_vw::init()
}
}

safe_vw_factory::safe_vw_factory(std::string command_line) : _command_line(std::move(command_line)) {}
safe_vw_factory::safe_vw_factory(std::string command_line, lru_dedup_cache* dedup_cache)
: _command_line(std::move(command_line)), _dedup_cache(dedup_cache)
{
}

safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data) : _master_data(master_data) {}
safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data, lru_dedup_cache* dedup_cache)
: _master_data(master_data), _dedup_cache(dedup_cache)
{
}

safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data) : _master_data(master_data) {}
safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data, lru_dedup_cache* dedup_cache)
: _master_data(master_data), _dedup_cache(dedup_cache)
{
}

safe_vw_factory::safe_vw_factory(const model_management::model_data& master_data, std::string command_line)
: _master_data(master_data), _command_line(std::move(command_line))
safe_vw_factory::safe_vw_factory(
const model_management::model_data& master_data, std::string command_line, lru_dedup_cache* dedup_cache)
: _master_data(master_data), _command_line(std::move(command_line)), _dedup_cache(dedup_cache)
{
}

safe_vw_factory::safe_vw_factory(const model_management::model_data&& master_data, std::string command_line)
: _master_data(master_data), _command_line(std::move(command_line))
safe_vw_factory::safe_vw_factory(
const model_management::model_data&& master_data, std::string command_line, lru_dedup_cache* dedup_cache)
: _master_data(master_data), _command_line(std::move(command_line)), _dedup_cache(dedup_cache)
{
}

Expand All @@ -393,13 +429,13 @@ safe_vw* safe_vw_factory::operator()()
if ((_master_data.data() != nullptr) && !_command_line.empty())
{
// Construct new vw object from raw model data and command line argument
return new safe_vw(_master_data.data(), _master_data.data_sz(), _command_line);
return new safe_vw(_master_data.data(), _master_data.data_sz(), _command_line, _dedup_cache);
}
if (_master_data.data() != nullptr)
{
// Construct new vw object from raw model data.
return new safe_vw(_master_data.data(), _master_data.data_sz());
return new safe_vw(_master_data.data(), _master_data.data_sz(), _dedup_cache);
}
return new safe_vw(_command_line);
return new safe_vw(_command_line, _dedup_cache);
}
} // namespace reinforcement_learning
24 changes: 15 additions & 9 deletions rlclientlib/vw_model/safe_vw.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "lru_dedup_cache.h"
#include "model_mgmt.h"
#include "vw/core/vw.h"

Expand All @@ -14,19 +15,21 @@ class safe_vw
std::shared_ptr<safe_vw> _master;
VW::workspace* _vw;
std::vector<VW::example*> _example_pool;
lru_dedup_cache* _dedup_cache;

VW::example* get_or_create_example();
static VW::example& get_or_create_example_f(void* vw);

public:
safe_vw(std::shared_ptr<safe_vw> master);
safe_vw(const char* model_data, size_t len, const std::string& vw_commandline);
safe_vw(const char* model_data, size_t len);
safe_vw(const std::string& vw_commandline);
safe_vw(std::shared_ptr<safe_vw> master, lru_dedup_cache* dedup_cache);
safe_vw(const char* model_data, size_t len, const std::string& vw_commandline, lru_dedup_cache* dedup_cache);
safe_vw(const char* model_data, size_t len, lru_dedup_cache* dedup_cache);
safe_vw(const std::string& vw_commandline, lru_dedup_cache* dedup_cache);

~safe_vw();

void parse_context_with_pdf(string_view context, std::vector<int>& actions, std::vector<float>& scores);
void add_lru_dedup_cache(uint64_t hash, std::string action_str);
void rank(string_view context, std::vector<int>& actions, std::vector<float>& scores);
void choose_continuous_action(string_view context, float& action, float& pdf_value);
// Used for CCB
Expand Down Expand Up @@ -57,14 +60,17 @@ class safe_vw_factory
{
model_management::model_data _master_data;
std::string _command_line;
lru_dedup_cache* _dedup_cache;

public:
// model_data is copied and stored in the factory object.
safe_vw_factory(std::string command_line);
safe_vw_factory(const model_management::model_data& master_data);
safe_vw_factory(const model_management::model_data&& master_data);
safe_vw_factory(const model_management::model_data& master_data, std::string command_line);
safe_vw_factory(const model_management::model_data&& master_data, std::string command_line);
safe_vw_factory(std::string command_line, lru_dedup_cache* dedup_cache);
safe_vw_factory(const model_management::model_data& master_data, lru_dedup_cache* dedup_cache);
safe_vw_factory(const model_management::model_data&& master_data, lru_dedup_cache* dedup_cache);
safe_vw_factory(
const model_management::model_data& master_data, std::string command_line, lru_dedup_cache* dedup_cache);
safe_vw_factory(
const model_management::model_data&& master_data, std::string command_line, lru_dedup_cache* dedup_cache);

safe_vw* operator()();
};
Expand Down
Loading