// Patch summary and review notes. This text sits before the first "diff --git" header and is
// not part of any hunk; everything after this comment block is unchanged.
//
// What the patch does:
//  - Renames local_loop_controller to federated_loop_controller in include/constants.h comments,
//    rlclientlib (CMakeLists.txt, factory_resolver.cc, live_model_impl.cc, federation/*) and unit_test/*.
//  - Adds federation/apim_federated_client.{h,cc}: an i_federated_client that keeps the model returned
//    by VW::initialize_experimental plus a transport object handed in by the caller.
//  - federated_loop_controller::create() gains a `transport` parameter; it creates a local_client when
//    the "federated_client" config value equals "local" (the default used in the lookup), otherwise an
//    apim_federated_client that receives the transport.
//
// NOTE(review): apim_federated_client::try_get_model() calls _transport->get_data() but discards its
//   return value, never assigns the `model_received` out-parameter, and does not change _state.
// NOTE(review): _state is initialised to model_available in the constructor and is only ever assigned
//   model_available again inside report_result(); nothing in this patch sets model_retrieved, so
//   report_result() appears to always take the model_available error branch - confirm intent.
// NOTE(review): federated_loop_controller_create() and the end-to-end test both pass a
//   default-constructed transport; try_get_model() dereferences _transport without a null check.
// NOTE(review): the `status` parameter of apim_federated_client::create() is unused, and the VW
//   initialisation there is not guarded (the patch itself carries a "TODO try catch").
// NOTE(review): the constructor's initializer list names _current_model before _state, while the
//   header declares _state before _current_model.
// NOTE(review): apim_federated_client.h writes `class apim_federated_client : i_federated_client`
//   with no access specifier on the base - verify this is intended.
// NOTE(review): the TRACE_INFO text in federated_loop_controller_create() still reads
//   "Local loop controller i_data_transport created."
// NOTE(review): live_model_impl::init_local_loop() downcasts the released transport pointer with
//   reinterpret_cast - presumably inherited from the pre-rename code; verify.
// NOTE(review): template arguments and #include targets look stripped in this copy of the patch
//   (e.g. "std::unique_ptr&", "reinterpret_cast(payload)", a bare "#include") - presumably an
//   extraction artifact; verify against the original commit before applying.
diff --git a/include/constants.h b/include/constants.h index 1b4962230..0bb94f579 100644 --- a/include/constants.h +++ b/include/constants.h @@ -175,9 +175,9 @@ const char* const REWARD_FUNCTION_MAX = "REWARD_FUNCTION_MAX"; // These are outside of #ifdef section so that we can recognize them as invalid // configuration options when rlclientlib is compiled without RL_BUILD_FEDERATION // -// Use local_loop_controller for model data +// Use federated_loop_controller for model data const char* const LOCAL_LOOP_MODEL_DATA = "LOCAL_LOOP_MODEL_DATA"; -// Send events to local_loop_controller +// Send events to federated_loop_controller const char* const LOCAL_LOOP_SENDER = "LOCAL_LOOP_SENDER"; const char* get_default_episode_sender(); diff --git a/rlclientlib/CMakeLists.txt b/rlclientlib/CMakeLists.txt index e8feb6ca0..92d74fc6d 100644 --- a/rlclientlib/CMakeLists.txt +++ b/rlclientlib/CMakeLists.txt @@ -119,8 +119,9 @@ endif() if(RL_BUILD_FEDERATION) list(APPEND PROJECT_SOURCES + federation/apim_federated_client.cc federation/local_client.cc - federation/local_loop_controller.cc + federation/federated_loop_controller.cc federation/sender_joined_log_provider.cc federation/vw_trainable_model.cc ) @@ -209,7 +210,8 @@ if(RL_BUILD_FEDERATION) federation/federated_client.h federation/joined_log_provider.h federation/local_client.h - federation/local_loop_controller.h + federation/apim_federated_client.h + federation/federated_loop_controller.h federation/sender_joined_log_provider.h federation/vw_trainable_model.h ) diff --git a/rlclientlib/factory_resolver.cc b/rlclientlib/factory_resolver.cc index c51763f03..850ce169e 100644 --- a/rlclientlib/factory_resolver.cc +++ b/rlclientlib/factory_resolver.cc @@ -10,7 +10,7 @@ #include "vw_model/vw_model.h" #ifdef RL_BUILD_FEDERATION -# include "federation/local_loop_controller.h" +# include "federation/federated_loop_controller.h" #endif #ifdef USE_AZURE_FACTORIES @@ -121,12 +121,13 @@ int file_model_loader_create(std::unique_ptr& 
retval, const } #ifdef RL_BUILD_FEDERATION -int local_loop_controller_create(std::unique_ptr& retval, const u::configuration& config, +int federated_loop_controller_create(std::unique_ptr& retval, const u::configuration& config, i_trace* trace_logger, api_status* status) { TRACE_INFO(trace_logger, "Local loop controller i_data_transport created."); - std::unique_ptr output; - RETURN_IF_FAIL(local_loop_controller::create(output, config, trace_logger, status)); + std::unique_ptr output; + std::unique_ptr transport; + RETURN_IF_FAIL(federated_loop_controller::create(output, config, std::move(transport), trace_logger, status)); retval = std::move(output); return error_code::success; } @@ -158,7 +159,7 @@ void factory_initializer::register_default_factories() data_transport_factory.register_type(value::FILE_MODEL_DATA, file_model_loader_create); #ifdef RL_BUILD_FEDERATION - data_transport_factory.register_type(value::LOCAL_LOOP_MODEL_DATA, local_loop_controller_create); + data_transport_factory.register_type(value::LOCAL_LOOP_MODEL_DATA, federated_loop_controller_create); #else data_transport_factory.register_type(value::LOCAL_LOOP_MODEL_DATA, [](std::unique_ptr&, const u::configuration&, i_trace* trace_logger, api_status* status) diff --git a/rlclientlib/federation/apim_federated_client.cc b/rlclientlib/federation/apim_federated_client.cc new file mode 100644 index 000000000..e42aa222b --- /dev/null +++ b/rlclientlib/federation/apim_federated_client.cc @@ -0,0 +1,95 @@ +#include "federation/apim_federated_client.h" + +#include "api_status.h" +#include "constants.h" +#include "err_constants.h" +#include "trace_logger.h" +#include "utility/vw_logger_adapter.h" +#include "vw/config/options_cli.h" +#include "vw/core/global_data.h" +#include "vw/core/io_buf.h" +#include "vw/core/merge.h" +#include "vw/core/parse_primitives.h" +#include "vw/core/vw.h" +#include "vw/io/io_adapter.h" + +namespace reinforcement_learning +{ 
+apim_federated_client::apim_federated_client(std::unique_ptr initial_model, i_trace* trace_logger, + std::unique_ptr transport) + : _current_model(std::move(initial_model)) + , _state(state_t::model_available) + , _trace_logger(trace_logger) + , _transport(std::move(transport)) +{ +} + +apim_federated_client::~apim_federated_client() = default; + +int apim_federated_client::create(std::unique_ptr& output, const utility::configuration& config, + i_trace* trace_logger, std::unique_ptr transport, api_status* status) +{ + std::string cmd_line = "--cb_explore_adf --json --quiet --epsilon 0.0 --first_only --id "; + cmd_line += config.get("id", "default_id"); + // Create empty model based on ML args on first call + std::string initial_command_line(config.get(name::MODEL_VW_INITIAL_COMMAND_LINE, cmd_line.c_str())); + + // TODO try catch + auto args = VW::make_unique(VW::split_command_line(initial_command_line)); + auto logger = utility::make_vw_trace_logger(trace_logger); + auto workspace = VW::initialize_experimental(std::move(args), nullptr, nullptr, nullptr, &logger); + workspace->id += "/0"; // initialize iteration id to 0 + + output = std::unique_ptr( + new apim_federated_client(std::move(workspace), trace_logger, std::move(transport))); + return error_code::success; +} + +int apim_federated_client::try_get_model(const std::string& app_id, + /* inout */ model_management::model_data& data, /* out */ bool& model_received, api_status* status) +{ + _transport->get_data(data, status); + return error_code::success; +} + +int apim_federated_client::report_result(const uint8_t* payload, size_t size, api_status* status) +{ + switch (_state) + { + case state_t::model_available: + { + RETURN_ERROR_LS(_trace_logger, status, invalid_argument) + << "Cannot call report_result again until try_get_model has been called."; + } + break; + case state_t::model_retrieved: + { + // Payload must be a delta + // Apply delta to current model and move into model available state. 
+ auto view = VW::io::create_buffer_view(reinterpret_cast(payload), size); + auto delta = VW::model_delta::deserialize(*view); + auto new_model = *_current_model + *delta; + + // Increment iteration id for new workspace + try + { + int iteration_id = std::stoi(_current_model->id.substr(_current_model->id.find('/') + 1, std::string::npos)); + iteration_id++; + new_model->id = _current_model->id.substr(0, _current_model->id.find('/')) + "/" + std::to_string(iteration_id); + } + catch (const std::exception& e) + { + RETURN_ERROR_ARG(_trace_logger, status, model_update_error, e.what()); + } + + // Update current model + _current_model.reset(new_model.release()); + _state = state_t::model_available; + } + break; + default: + RETURN_ERROR_LS(_trace_logger, status, invalid_argument) << "Invalid state."; + } + return error_code::success; +} +} // namespace reinforcement_learning diff --git a/rlclientlib/federation/apim_federated_client.h b/rlclientlib/federation/apim_federated_client.h new file mode 100644 index 000000000..36aa1fd32 --- /dev/null +++ b/rlclientlib/federation/apim_federated_client.h @@ -0,0 +1,46 @@ +#pragma once + +#include "configuration.h" +#include "federation/federated_client.h" +#include "model_mgmt.h" +#include "trace_logger.h" +#include "vw/core/vw_fwd.h" + +namespace reinforcement_learning +{ +class apim_federated_client : i_federated_client +{ +public: + RL_ATTR(nodiscard) + static int create(std::unique_ptr& output, const utility::configuration& config, + i_trace* trace_logger = nullptr, std::unique_ptr transport = nullptr, + api_status* status = nullptr); + + RL_ATTR(nodiscard) + int try_get_model(const std::string& app_id, + /* inout */ model_management::model_data& data, /* out */ bool& model_received, + api_status* status = nullptr) override; + + RL_ATTR(nodiscard) int report_result(const uint8_t* payload, size_t size, api_status* status = nullptr) override; + + ~apim_federated_client() override; + +private: + enum class state_t + { 
model_available, + model_retrieved + }; + + apim_federated_client(std::unique_ptr initial_model, i_trace* trace_logger, + std::unique_ptr transport); + + state_t _state; + std::unique_ptr _current_model; + i_trace* _trace_logger; + std::unique_ptr _transport; +}; + +// Read MODEL_VW_INITIAL_COMMAND_LINE + +} // namespace reinforcement_learning \ No newline at end of file diff --git a/rlclientlib/federation/local_loop_controller.cc b/rlclientlib/federation/federated_loop_controller.cc similarity index 67% rename from rlclientlib/federation/local_loop_controller.cc rename to rlclientlib/federation/federated_loop_controller.cc index bf473a831..64aca8a1f 100644 --- a/rlclientlib/federation/local_loop_controller.cc +++ b/rlclientlib/federation/federated_loop_controller.cc @@ -1,7 +1,8 @@ -#include "federation/local_loop_controller.h" +#include "federation/federated_loop_controller.h" #include "constants.h" #include "err_constants.h" +#include "federation/apim_federated_client.h" #include "federation/local_client.h" #include "federation/sender_joined_log_provider.h" #include "model_mgmt.h" @@ -9,15 +10,23 @@ namespace reinforcement_learning { -int local_loop_controller::create(std::unique_ptr& output, - const reinforcement_learning::utility::configuration& config, i_trace* trace_logger, api_status* status) +int federated_loop_controller::create(std::unique_ptr& output, + const reinforcement_learning::utility::configuration& config, + std::unique_ptr transport, i_trace* trace_logger, api_status* status) { std::string app_id = config.get(name::APP_ID, ""); std::unique_ptr federated_client; std::unique_ptr trainable_model; std::unique_ptr sender_joiner; - RETURN_IF_FAIL(local_client::create(federated_client, config, trace_logger, status)); + if (config.get("federated_client", "local") == "local") + { + RETURN_IF_FAIL(local_client::create(federated_client, config, trace_logger, status)); + } + else + { + RETURN_IF_FAIL(apim_federated_client::create(federated_client, config, 
trace_logger, std::move(transport), status)); + } RETURN_IF_FAIL(trainable_vw_model::create(trainable_model, config, trace_logger, status)); RETURN_IF_FAIL(sender_joined_log_provider::create(sender_joiner, config, trace_logger, status)); @@ -30,14 +39,14 @@ int local_loop_controller::create(std::unique_ptr& output joiner = std::static_pointer_cast(sender_joiner_shared); event_sink = std::static_pointer_cast(sender_joiner_shared); - output = std::unique_ptr(new local_loop_controller(std::move(app_id), + output = std::unique_ptr(new federated_loop_controller(std::move(app_id), std::move(federated_client), std::move(trainable_model), std::move(joiner), std::move(event_sink))); return error_code::success; } -local_loop_controller::local_loop_controller(std::string app_id, std::unique_ptr&& federated_client, - std::unique_ptr&& trainable_model, std::shared_ptr&& joiner, - std::shared_ptr&& event_sink) +federated_loop_controller::federated_loop_controller(std::string app_id, + std::unique_ptr&& federated_client, std::unique_ptr&& trainable_model, + std::shared_ptr&& joiner, std::shared_ptr&& event_sink) : _app_id(std::move(app_id)) , _federated_client(std::move(federated_client)) , _trainable_model(std::move(trainable_model)) @@ -46,7 +55,7 @@ local_loop_controller::local_loop_controller(std::string app_id, std::unique_ptr { } -int local_loop_controller::update_global(api_status* status) +int federated_loop_controller::update_global(api_status* status) { // ask for a new global model model_management::model_data data; @@ -70,7 +79,7 @@ int local_loop_controller::update_global(api_status* status) return error_code::success; } -int local_loop_controller::update_local(api_status* status) +int federated_loop_controller::update_local(api_status* status) { std::unique_ptr binary_log; RETURN_IF_FAIL(_joiner->invoke_join(binary_log, status)); @@ -78,13 +87,13 @@ int local_loop_controller::update_local(api_status* status) return error_code::success; } -int 
local_loop_controller::get_data(model_management::model_data& data, api_status* status) +int federated_loop_controller::get_data(model_management::model_data& data, api_status* status) { RETURN_IF_FAIL(update_global(status)); RETURN_IF_FAIL(_trainable_model->get_data(data, status)); return error_code::success; } -std::unique_ptr local_loop_controller::get_local_sender() { return _event_sink->get_sender_proxy(); } +std::unique_ptr federated_loop_controller::get_local_sender() { return _event_sink->get_sender_proxy(); } } // namespace reinforcement_learning diff --git a/rlclientlib/federation/local_loop_controller.h b/rlclientlib/federation/federated_loop_controller.h similarity index 80% rename from rlclientlib/federation/local_loop_controller.h rename to rlclientlib/federation/federated_loop_controller.h index 093291b1a..1ecf8f5ee 100644 --- a/rlclientlib/federation/local_loop_controller.h +++ b/rlclientlib/federation/federated_loop_controller.h @@ -16,16 +16,17 @@ namespace reinforcement_learning { -// The local_loop_controller will "plug in" to rlclientlib as an i_data_transport object. +// The federated_loop_controller will "plug in" to rlclientlib as an i_data_transport object. // It exposes a get_local_sender_factory() function that creates i_sender proxy objects. // These proxy objects will send events to its internal event sink. // The initialization code for live_model_impl must register this factory function correctly. 
-class local_loop_controller : public model_management::i_data_transport +class federated_loop_controller : public model_management::i_data_transport { public: RL_ATTR(nodiscard) - static int create(std::unique_ptr& output, const utility::configuration& config, - i_trace* trace_logger = nullptr, api_status* status = nullptr); + static int create(std::unique_ptr& output, const utility::configuration& config, + std::unique_ptr transport = nullptr, i_trace* trace_logger = nullptr, + api_status* status = nullptr); // Get model data in binary format // This will perform joining and training on any observed events, and then return the updated model @@ -35,11 +36,11 @@ class local_loop_controller : public model_management::i_data_transport // Returns a i_sender proxy object to be used for sending events to the internal event sink std::unique_ptr get_local_sender(); - virtual ~local_loop_controller() = default; + virtual ~federated_loop_controller() = default; protected: // Constructor is private because objects should be created using the factory function - local_loop_controller(std::string app_id, std::unique_ptr&& federated_client, + federated_loop_controller(std::string app_id, std::unique_ptr&& federated_client, std::unique_ptr&& trainable_model, std::shared_ptr&& joiner, std::shared_ptr&& event_sink); diff --git a/rlclientlib/live_model_impl.cc b/rlclientlib/live_model_impl.cc index f2f3c4907..1379b456b 100644 --- a/rlclientlib/live_model_impl.cc +++ b/rlclientlib/live_model_impl.cc @@ -18,7 +18,7 @@ #include "vw_model/safe_vw.h" #ifdef RL_BUILD_FEDERATION -# include "federation/local_loop_controller.h" +# include "federation/federated_loop_controller.h" #endif #include @@ -696,10 +696,10 @@ int live_model_impl::init_local_loop(api_status* status) // This function should only be called when the configuration is set to use LOCAL_LOOP_MODEL_DATA assert(model_src == value::LOCAL_LOOP_MODEL_DATA); - // Creating i_data_transport with type LOCAL_LOOP_MODEL_DATA results in 
local_loop_controller + // Creating i_data_transport with type LOCAL_LOOP_MODEL_DATA results in federated_loop_controller std::unique_ptr output; RETURN_IF_FAIL(_t_factory->create(output, model_src, _configuration, _trace_logger.get(), status)); - std::unique_ptr llc(reinterpret_cast(output.release())); + std::unique_ptr llc(reinterpret_cast(output.release())); // Create senders with default sender implementation set to LOCAL_LOOP_SENDER std::string interaction_sender_type = diff --git a/unit_test/CMakeLists.txt b/unit_test/CMakeLists.txt index 65be79147..3b12156f9 100644 --- a/unit_test/CMakeLists.txt +++ b/unit_test/CMakeLists.txt @@ -52,7 +52,7 @@ if(RL_BUILD_FEDERATION) list(APPEND TEST_SOURCES eud_test.cc local_client_test.cc - local_loop_controller_test.cc + federated_loop_controller_test.cc local_loop_end_to_end.cc sender_joined_log_provider_test.cc trainable_model_test.cc diff --git a/unit_test/local_loop_controller_test.cc b/unit_test/federated_loop_controller_test.cc similarity index 85% rename from unit_test/local_loop_controller_test.cc rename to unit_test/federated_loop_controller_test.cc index 47d62792b..60b2ff237 100644 --- a/unit_test/local_loop_controller_test.cc +++ b/unit_test/federated_loop_controller_test.cc @@ -6,8 +6,8 @@ #include "err_constants.h" #include "federation/event_sink.h" #include "federation/federated_client.h" +#include "federation/federated_loop_controller.h" #include "federation/local_client.h" -#include "federation/local_loop_controller.h" #include "federation/sender_joined_log_provider.h" #include "vw/core/shared_data.h" #include "vw/core/vw.h" @@ -16,19 +16,19 @@ using namespace reinforcement_learning; namespace { -// Wrapper around local_loop_controller to allow us to access member variables -class test_local_loop_controller : public local_loop_controller +// Wrapper around federated_loop_controller to allow us to access member variables +class test_federated_loop_controller : public federated_loop_controller { public: - 
test_local_loop_controller(std::string app_id, std::unique_ptr&& federated_client, + test_federated_loop_controller(std::string app_id, std::unique_ptr&& federated_client, std::unique_ptr&& trainable_model, std::shared_ptr&& joiner, std::shared_ptr&& event_sink) - : local_loop_controller(std::move(app_id), std::move(federated_client), std::move(trainable_model), + : federated_loop_controller(std::move(app_id), std::move(federated_client), std::move(trainable_model), std::move(joiner), std::move(event_sink)) { } - virtual ~test_local_loop_controller() = default; + virtual ~test_federated_loop_controller() = default; i_federated_client* get_client() { return _federated_client.get(); } trainable_vw_model* get_model() { return _trainable_model.get(); } @@ -121,7 +121,7 @@ utility::configuration get_test_config() return config; } -std::unique_ptr create_test_local_loop_controller(utility::configuration config) +std::unique_ptr create_test_federated_loop_controller(utility::configuration config) { std::unique_ptr trainable_model; std::unique_ptr sender_joiner; @@ -132,16 +132,16 @@ std::unique_ptr create_test_local_loop_controller(utility std::shared_ptr event_sink(new mock_event_sink()); std::unique_ptr federated_client(new mock_federated_client()); - return std::unique_ptr(new test_local_loop_controller("test_app_id", + return std::unique_ptr(new test_federated_loop_controller("test_app_id", std::move(federated_client), std::move(trainable_model), std::move(joiner), std::move(event_sink))); } } // namespace BOOST_AUTO_TEST_CASE(sender_factory_test) { - // create the local_loop_controller + // create the federated_loop_controller auto config = get_test_config(); - auto test_llc = create_test_local_loop_controller(config); + auto test_llc = create_test_federated_loop_controller(config); // create a sender and send some data std::unique_ptr sender = test_llc->get_local_sender(); @@ -153,7 +153,7 @@ BOOST_AUTO_TEST_CASE(sender_factory_test) sender->send(buffer_in); // get 
the data out of event sink - auto event_sink_out = dynamic_cast(test_llc.get())->get_event_sink(); + auto event_sink_out = dynamic_cast(test_llc.get())->get_event_sink(); BOOST_CHECK_NE(event_sink_out, nullptr); auto buffer_out = dynamic_cast(event_sink_out)->get_latest_event(); BOOST_CHECK_NE(buffer_out, nullptr); @@ -162,10 +162,10 @@ BOOST_AUTO_TEST_CASE(sender_factory_test) BOOST_AUTO_TEST_CASE(update_get_model_data) { - // create the local_loop_controller + // create the federated_loop_controller auto config = get_test_config(); - auto llc = create_test_local_loop_controller(config); - auto test_llc = dynamic_cast(llc.get()); + auto llc = create_test_federated_loop_controller(config); + auto test_llc = dynamic_cast(llc.get()); BOOST_CHECK_NE(test_llc, nullptr); auto mock_client = dynamic_cast(test_llc->get_client()); BOOST_CHECK_NE(mock_client, nullptr); diff --git a/unit_test/local_loop_end_to_end.cc b/unit_test/local_loop_end_to_end.cc index edf48afc1..9173e7bb1 100644 --- a/unit_test/local_loop_end_to_end.cc +++ b/unit_test/local_loop_end_to_end.cc @@ -2,7 +2,7 @@ #include "common_test_utils.h" #include "constants.h" -#include "federation/local_loop_controller.h" +#include "federation/federated_loop_controller.h" #include "live_model.h" #include "ranking_response.h" @@ -96,16 +96,17 @@ BOOST_AUTO_TEST_CASE(local_loop_end_to_end_test) auto config = get_test_config(); // create a custom data_transport_factory_t that saves a pointer - // to the local_loop_controller that was created - local_loop_controller* test_local_loop_controller = nullptr; + // to the federated_loop_controller that was created + federated_loop_controller* test_federated_loop_controller = nullptr; data_transport_factory_t test_data_transport_factory; test_data_transport_factory.register_type(value::LOCAL_LOOP_MODEL_DATA, [&](std::unique_ptr& retval, const utility::configuration& cfg, i_trace* trace_logger, api_status* status) { - std::unique_ptr output; - 
RETURN_IF_FAIL(local_loop_controller::create(output, cfg, trace_logger, status)); - test_local_loop_controller = output.get(); + std::unique_ptr output; + std::unique_ptr transport; + RETURN_IF_FAIL(federated_loop_controller::create(output, cfg, std::move(transport), trace_logger, status)); + test_federated_loop_controller = output.get(); retval = std::move(output); return error_code::success; }); @@ -116,7 +117,7 @@ BOOST_AUTO_TEST_CASE(local_loop_end_to_end_test) config, nullptr, nullptr, &reinforcement_learning::trace_logger_factory, &test_data_transport_factory); model.init(&status); BOOST_TEST(status.get_error_code() == error_code::success, status.get_error_msg()); - BOOST_CHECK_NE(test_local_loop_controller, nullptr); + BOOST_CHECK_NE(test_federated_loop_controller, nullptr); // do some inference calls and report the outcome constexpr int iterations = 100; @@ -129,7 +130,7 @@ BOOST_AUTO_TEST_CASE(local_loop_end_to_end_test) // check that updated model has learned from previous outcomes model_management::model_data model_data; - test_local_loop_controller->get_data(model_data, &status); + test_federated_loop_controller->get_data(model_data, &status); BOOST_TEST(status.get_error_code() == error_code::success, status.get_error_msg()); auto vw = test_utils::create_vw(config.get(name::MODEL_VW_INITIAL_COMMAND_LINE, nullptr), model_data); BOOST_CHECK_EQUAL(vw->sd->weighted_labeled_examples, iterations);