From cdf1ea706e30f542a0fac890e9b09aa7b968f4e5 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Tue, 25 Apr 2023 15:09:24 -0400 Subject: [PATCH 01/11] test --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ae04a9a02..6ea5dbc82 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![MacOS Build Status](https://img.shields.io/azure-devops/build/vowpalwabbit/3934113c-9e2b-4dbc-8972-72ab9b9b4342/22/master?label=MacOS%20build&logo=Azure%20Devops)](https://dev.azure.com/vowpalwabbit/Vowpal%20Wabbit/_build?definitionId=32) [![Windows Build status](https://ci.appveyor.com/api/projects/status/57p7o5v34onsqma2/branch/master?svg=true)](https://ci.appveyor.com/project/JohnLangford/reinforcement-learning/branch/master) [![Integration with latest VW](https://github.com/VowpalWabbit/reinforcement_learning/actions/workflows/daily_integration.yml/badge.svg?event=schedule)](https://github.com/VowpalWabbit/reinforcement_learning/actions/workflows/daily_integration.yml) - + # RL Client Library Interaction-side integration library for Reinforcement Learning loops: Predict, Log, [Learn,] Update From 2a0ac3a7590a64a56bacc6012fbd689766f2ca0a Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Tue, 25 Apr 2023 15:11:04 -0400 Subject: [PATCH 02/11] rm --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6ea5dbc82..ae04a9a02 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![MacOS Build Status](https://img.shields.io/azure-devops/build/vowpalwabbit/3934113c-9e2b-4dbc-8972-72ab9b9b4342/22/master?label=MacOS%20build&logo=Azure%20Devops)](https://dev.azure.com/vowpalwabbit/Vowpal%20Wabbit/_build?definitionId=32) [![Windows Build status](https://ci.appveyor.com/api/projects/status/57p7o5v34onsqma2/branch/master?svg=true)](https://ci.appveyor.com/project/JohnLangford/reinforcement-learning/branch/master) [![Integration with latest VW](https://github.com/VowpalWabbit/reinforcement_learning/actions/workflows/daily_integration.yml/badge.svg?event=schedule)](https://github.com/VowpalWabbit/reinforcement_learning/actions/workflows/daily_integration.yml) - + # RL Client Library Interaction-side integration library for Reinforcement Learning loops: Predict, Log, [Learn,] Update From e98853f0a915c8137e65a0e702e0870805a7ed15 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Wed, 26 Apr 2023 11:37:06 -0400 Subject: [PATCH 03/11] delta apim sender --- include/constants.h | 6 ++++++ rlclientlib/azure_factories.cc | 15 +++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/include/constants.h b/include/constants.h index 1b4962230..eda75438c 100644 --- a/include/constants.h +++ b/include/constants.h @@ -84,6 +84,12 @@ const char* const OBSERVATION_APIM_MAX_HTTP_RETRIES = "observation.apim.max_http const char* const OBSERVATION_APIM_MAX_HTTP_RETRY_DURATION_MS = "observation.apim.max_http_retry_duration_ms"; const char* const OBSERVATION_SUBSAMPLE_RATE = "observation.subsample.rate"; +// Delta +const char* const DELTA_APIM_TASKS_LIMIT = "delta.apim.tasks_limit"; +const char* const DELTA_HTTP_API_HOST = "delta.http.api.host"; +const char* const DELTA_APIM_MAX_HTTP_RETRIES = "delta.apim.max_http_retries"; +const char* const DELTA_APIM_MAX_HTTP_RETRY_DURATION_MS = "delta.apim.max_http_retry_duration_ms"; + // global sender properties const char* const SEND_HIGH_WATER_MARK = "send.highwatermark"; const char* const SEND_QUEUE_MAX_CAPACITY_KB = "send.queue.maxcapacity.kb"; diff --git a/rlclientlib/azure_factories.cc b/rlclientlib/azure_factories.cc index 687fcd53b..7819802a9 100644 --- a/rlclientlib/azure_factories.cc +++ b/rlclientlib/azure_factories.cc @@ -32,6 +32,8 @@ int observation_api_sender_create(std::unique_ptr& retval, const u::co error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); int interaction_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); +int delta_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, + error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); void register_azure_factories() { @@ -42,6 +44,7 @@ void register_azure_factories() sender_factory.register_type(value::EPISODE_EH_SENDER, episode_sender_create); sender_factory.register_type(value::OBSERVATION_HTTP_API_SENDER, observation_api_sender_create); sender_factory.register_type(value::INTERACTION_HTTP_API_SENDER, interaction_api_sender_create); + sender_factory.register_type(value::INTERACTION_HTTP_API_SENDER, delta_api_sender_create); sender_factory.register_type(value::EPISODE_HTTP_API_SENDER, episode_api_sender_create); } @@ -165,4 +168,16 @@ int interaction_sender_create(std::unique_ptr& retval, const u::config error_cb)); return error_code::success; } + +// Creates i_sender object for sending delta data to the apim endpoint. +int delta_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, + error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) +{ + const auto* const api_host = cfg.get(name::DELTA_HTTP_API_HOST, "localhost:8080"); + return create_apim_http_api_sender(retval, cfg, api_host, cfg.get_int(name::DELTA_APIM_TASKS_LIMIT, 16), + cfg.get_int(name::DELTA_APIM_MAX_HTTP_RETRIES, 4), + std::chrono::milliseconds(cfg.get_int(name::DELTA_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, + trace_logger, status); +} + } // namespace reinforcement_learning From 7a9332df8d435593aa4b3a8f888e4a9c1a288861 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Wed, 26 Apr 2023 14:28:28 -0400 Subject: [PATCH 04/11] report result --- include/constants.h | 1 + rlclientlib/azure_factories.cc | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/include/constants.h b/include/constants.h index eda75438c..e6ddaa663 100644 --- a/include/constants.h +++ b/include/constants.h @@ -143,6 +143,7 @@ const char* const INTERACTION_FILE_SENDER = "INTERACTION_FILE_SENDER"; const char* const EPISODE_HTTP_API_SENDER = "EPISODE_HTTP_API_SENDER"; const char* const OBSERVATION_HTTP_API_SENDER = "OBSERVATION_HTTP_API_SENDER"; const char* const INTERACTION_HTTP_API_SENDER = "INTERACTION_HTTP_API_SENDER"; +const char* const DELTA_HTTP_API_SENDER = "INTERACTION_HTTP_API_SENDER"; const char* const NULL_TRACE_LOGGER = "NULL_TRACE_LOGGER"; const char* const CONSOLE_TRACE_LOGGER = "CONSOLE_TRACE_LOGGER"; const char* const NULL_TIME_PROVIDER = "NULL_TIME_PROVIDER"; diff --git a/rlclientlib/azure_factories.cc b/rlclientlib/azure_factories.cc index 7819802a9..8e9ac9aac 100644 --- a/rlclientlib/azure_factories.cc +++ b/rlclientlib/azure_factories.cc @@ -2,6 +2,8 @@ #include "constants.h" #include "factory_resolver.h" +#include "federation/federated_client.h" +#include "federation/local_client.h" #include "logger/event_logger.h" #include "logger/http_transport_client.h" #include "model_mgmt/restapi_data_transport.h" @@ -44,7 +46,7 @@ void register_azure_factories() sender_factory.register_type(value::EPISODE_EH_SENDER, episode_sender_create); sender_factory.register_type(value::OBSERVATION_HTTP_API_SENDER, observation_api_sender_create); sender_factory.register_type(value::INTERACTION_HTTP_API_SENDER, interaction_api_sender_create); - sender_factory.register_type(value::INTERACTION_HTTP_API_SENDER, delta_api_sender_create); + sender_factory.register_type(value::DELTA_HTTP_API_SENDER, delta_api_sender_create); sender_factory.register_type(value::EPISODE_HTTP_API_SENDER, episode_api_sender_create); } @@ -169,12 +171,25 @@ int interaction_sender_create(std::unique_ptr& retval, const u::config return error_code::success; } +int create_apim_delta_http_api_sender(std::unique_ptr& retval, const u::configuration& cfg, + const char* api_host, int tasks_limit, int max_http_retries, std::chrono::milliseconds max_http_retry_duration, + error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) +{ + std::unique_ptr client; + RETURN_IF_FAIL(local_client::create(client, cfg, trace_logger, status)); + model_management::model_data delta; + bool model_received = false; + RETURN_IF_FAIL(client->try_get_model(api_host, delta, model_received)); + client->report_result(reinterpret_cast(delta.data()), delta.data_sz(), status); + return error_code::success; +} + // Creates i_sender object for sending delta data to the apim endpoint. int delta_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) { const auto* const api_host = cfg.get(name::DELTA_HTTP_API_HOST, "localhost:8080"); - return create_apim_http_api_sender(retval, cfg, api_host, cfg.get_int(name::DELTA_APIM_TASKS_LIMIT, 16), + return create_apim_delta_http_api_sender(retval, cfg, api_host, cfg.get_int(name::DELTA_APIM_TASKS_LIMIT, 16), cfg.get_int(name::DELTA_APIM_MAX_HTTP_RETRIES, 4), std::chrono::milliseconds(cfg.get_int(name::DELTA_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, trace_logger, status); From 5de2eba09b89ec3593ab2fb443fd0b13c950f28a Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Wed, 26 Apr 2023 14:28:41 -0400 Subject: [PATCH 05/11] clang --- rlclientlib/azure_factories.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rlclientlib/azure_factories.cc b/rlclientlib/azure_factories.cc index 8e9ac9aac..5667eb38b 100644 --- a/rlclientlib/azure_factories.cc +++ b/rlclientlib/azure_factories.cc @@ -34,8 +34,8 @@ int observation_api_sender_create(std::unique_ptr& retval, const u::co error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); int interaction_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); -int delta_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, - error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); +int delta_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, + i_trace* trace_logger, api_status* status); void register_azure_factories() { @@ -185,8 +185,8 @@ int create_apim_delta_http_api_sender(std::unique_ptr& retval, const u } // Creates i_sender object for sending delta data to the apim endpoint. -int delta_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, - error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) +int delta_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, + i_trace* trace_logger, api_status* status) { const auto* const api_host = cfg.get(name::DELTA_HTTP_API_HOST, "localhost:8080"); return create_apim_delta_http_api_sender(retval, cfg, api_host, cfg.get_int(name::DELTA_APIM_TASKS_LIMIT, 16), From 41d0663144ceaeabb640a47664ba920de9808036 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Tue, 2 May 2023 16:42:37 -0400 Subject: [PATCH 06/11] apim client --- include/constants.h | 7 -- rlclientlib/CMakeLists.txt | 2 + rlclientlib/azure_factories.cc | 30 ------- .../federation/apim_federated_client.cc | 90 +++++++++++++++++++ .../federation/apim_federated_client.h | 42 +++++++++ .../federation/local_loop_controller.cc | 10 ++- 6 files changed, 143 insertions(+), 38 deletions(-) create mode 100644 rlclientlib/federation/apim_federated_client.cc create mode 100644 rlclientlib/federation/apim_federated_client.h diff --git a/include/constants.h b/include/constants.h index e6ddaa663..1b4962230 100644 --- a/include/constants.h +++ b/include/constants.h @@ -84,12 +84,6 @@ const char* const OBSERVATION_APIM_MAX_HTTP_RETRIES = "observation.apim.max_http const char* const OBSERVATION_APIM_MAX_HTTP_RETRY_DURATION_MS = "observation.apim.max_http_retry_duration_ms"; const char* const OBSERVATION_SUBSAMPLE_RATE = "observation.subsample.rate"; -// Delta -const char* const DELTA_APIM_TASKS_LIMIT = "delta.apim.tasks_limit"; -const char* const DELTA_HTTP_API_HOST = "delta.http.api.host"; -const char* const DELTA_APIM_MAX_HTTP_RETRIES = "delta.apim.max_http_retries"; -const char* const DELTA_APIM_MAX_HTTP_RETRY_DURATION_MS = "delta.apim.max_http_retry_duration_ms"; - // global sender properties const char* const SEND_HIGH_WATER_MARK = "send.highwatermark"; const char* const SEND_QUEUE_MAX_CAPACITY_KB = "send.queue.maxcapacity.kb"; @@ -143,7 +137,6 @@ const char* const INTERACTION_FILE_SENDER = "INTERACTION_FILE_SENDER"; const char* const EPISODE_HTTP_API_SENDER = "EPISODE_HTTP_API_SENDER"; const char* const OBSERVATION_HTTP_API_SENDER = "OBSERVATION_HTTP_API_SENDER"; const char* const INTERACTION_HTTP_API_SENDER = "INTERACTION_HTTP_API_SENDER"; -const char* const DELTA_HTTP_API_SENDER = "INTERACTION_HTTP_API_SENDER"; const char* const NULL_TRACE_LOGGER = "NULL_TRACE_LOGGER"; const char* const CONSOLE_TRACE_LOGGER = "CONSOLE_TRACE_LOGGER"; const char* const NULL_TIME_PROVIDER = "NULL_TIME_PROVIDER"; diff --git a/rlclientlib/CMakeLists.txt b/rlclientlib/CMakeLists.txt index e8feb6ca0..73ac9973f 100644 --- a/rlclientlib/CMakeLists.txt +++ b/rlclientlib/CMakeLists.txt @@ -119,6 +119,7 @@ endif() if(RL_BUILD_FEDERATION) list(APPEND PROJECT_SOURCES + federation/apim_federated_client.cc federation/local_client.cc federation/local_loop_controller.cc federation/sender_joined_log_provider.cc @@ -209,6 +210,7 @@ if(RL_BUILD_FEDERATION) federation/federated_client.h federation/joined_log_provider.h federation/local_client.h + federation/apim_federated_client.h federation/local_loop_controller.h federation/sender_joined_log_provider.h federation/vw_trainable_model.h diff --git a/rlclientlib/azure_factories.cc b/rlclientlib/azure_factories.cc index 5667eb38b..687fcd53b 100644 --- a/rlclientlib/azure_factories.cc +++ b/rlclientlib/azure_factories.cc @@ -2,8 +2,6 @@ #include "constants.h" #include "factory_resolver.h" -#include "federation/federated_client.h" -#include "federation/local_client.h" #include "logger/event_logger.h" #include "logger/http_transport_client.h" #include "model_mgmt/restapi_data_transport.h" @@ -34,8 +32,6 @@ int observation_api_sender_create(std::unique_ptr& retval, const u::co error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); int interaction_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); -int delta_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, - i_trace* trace_logger, api_status* status); void register_azure_factories() { @@ -46,7 +42,6 @@ void register_azure_factories() sender_factory.register_type(value::EPISODE_EH_SENDER, episode_sender_create); sender_factory.register_type(value::OBSERVATION_HTTP_API_SENDER, observation_api_sender_create); sender_factory.register_type(value::INTERACTION_HTTP_API_SENDER, interaction_api_sender_create); - sender_factory.register_type(value::DELTA_HTTP_API_SENDER, delta_api_sender_create); sender_factory.register_type(value::EPISODE_HTTP_API_SENDER, episode_api_sender_create); } @@ -170,29 +165,4 @@ int interaction_sender_create(std::unique_ptr& retval, const u::config error_cb)); return error_code::success; } - -int create_apim_delta_http_api_sender(std::unique_ptr& retval, const u::configuration& cfg, - const char* api_host, int tasks_limit, int max_http_retries, std::chrono::milliseconds max_http_retry_duration, - error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) -{ - std::unique_ptr client; - RETURN_IF_FAIL(local_client::create(client, cfg, trace_logger, status)); - model_management::model_data delta; - bool model_received = false; - RETURN_IF_FAIL(client->try_get_model(api_host, delta, model_received)); - client->report_result(reinterpret_cast(delta.data()), delta.data_sz(), status); - return error_code::success; -} - -// Creates i_sender object for sending delta data to the apim endpoint. -int delta_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, - i_trace* trace_logger, api_status* status) -{ - const auto* const api_host = cfg.get(name::DELTA_HTTP_API_HOST, "localhost:8080"); - return create_apim_delta_http_api_sender(retval, cfg, api_host, cfg.get_int(name::DELTA_APIM_TASKS_LIMIT, 16), - cfg.get_int(name::DELTA_APIM_MAX_HTTP_RETRIES, 4), - std::chrono::milliseconds(cfg.get_int(name::DELTA_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, - trace_logger, status); -} - } // namespace reinforcement_learning diff --git a/rlclientlib/federation/apim_federated_client.cc b/rlclientlib/federation/apim_federated_client.cc new file mode 100644 index 000000000..b8a8a0536 --- /dev/null +++ b/rlclientlib/federation/apim_federated_client.cc @@ -0,0 +1,90 @@ +#include "federation/apim_federated_client.h" + +#include "api_status.h" +#include "constants.h" +#include "err_constants.h" +#include "trace_logger.h" +#include "utility/vw_logger_adapter.h" +#include "vw/config/options_cli.h" +#include "vw/core/global_data.h" +#include "vw/core/io_buf.h" +#include "vw/core/merge.h" +#include "vw/core/parse_primitives.h" +#include "vw/core/vw.h" +#include "vw/io/io_adapter.h" + +namespace reinforcement_learning +{ +apim_federated_client::apim_federated_client(std::unique_ptr initial_model, i_trace* trace_logger) + : _current_model(std::move(initial_model)), _state(state_t::model_available), _trace_logger(trace_logger) +{ +} + +apim_federated_client::~apim_federated_client() = default; + +int apim_federated_client::create(std::unique_ptr& output, const utility::configuration& config, + i_trace* trace_logger, api_status* status) +{ + std::string cmd_line = "--cb_explore_adf --json --quiet --epsilon 0.0 --first_only --id "; + cmd_line += config.get("id", "default_id"); + // Create empty model based on ML args on first call + std::string initial_command_line(config.get(name::MODEL_VW_INITIAL_COMMAND_LINE, cmd_line.c_str())); + + // TODO try catch + auto args = VW::make_unique(VW::split_command_line(initial_command_line)); + auto logger = utility::make_vw_trace_logger(trace_logger); + auto workspace = VW::initialize_experimental(std::move(args), nullptr, nullptr, nullptr, &logger); + workspace->id += "/0"; // initialize iteration id to 0 + + output = std::unique_ptr(new apim_federated_client(std::move(workspace), trace_logger)); + return error_code::success; +} + +int apim_federated_client::try_get_model(const std::string& app_id, + /* inout */ model_management::model_data& data, /* out */ bool& model_received, api_status* status) +{ + + return error_code::success; +} + +int apim_federated_client::report_result(const uint8_t* payload, size_t size, api_status* status) +{ + switch (_state) + { + case state_t::model_available: + { + RETURN_ERROR_LS(_trace_logger, status, invalid_argument) + << "Cannot call report_result again until try_get_model has been called."; + } + break; + case state_t::model_retrieved: + { + // Payload must be a delta + // Apply delta to current model and move into model available state. + auto view = VW::io::create_buffer_view(reinterpret_cast(payload), size); + auto delta = VW::model_delta::deserialize(*view); + auto new_model = *_current_model + *delta; + + // Increment iteration id for new workspace + try + { + int iteration_id = std::stoi(_current_model->id.substr(_current_model->id.find('/') + 1, std::string::npos)); + iteration_id++; + new_model->id = _current_model->id.substr(0, _current_model->id.find('/')) + "/" + std::to_string(iteration_id); + } + catch (const std::exception& e) + { + RETURN_ERROR_ARG(_trace_logger, status, model_update_error, e.what()); + } + + // Update current model + _current_model.reset(new_model.release()); + _state = state_t::model_available; + } + break; + default: + RETURN_ERROR_LS(_trace_logger, status, invalid_argument) << "Invalid state."; + } + return error_code::success; +} +} // namespace reinforcement_learning diff --git a/rlclientlib/federation/apim_federated_client.h b/rlclientlib/federation/apim_federated_client.h new file mode 100644 index 000000000..936adc503 --- /dev/null +++ b/rlclientlib/federation/apim_federated_client.h @@ -0,0 +1,42 @@ +#pragma once + +#include "configuration.h" +#include "federation/federated_client.h" +#include "trace_logger.h" +#include "vw/core/vw_fwd.h" + +namespace reinforcement_learning +{ +class apim_federated_client : i_federated_client +{ +public: + RL_ATTR(nodiscard) + static int create(std::unique_ptr& output, const utility::configuration& config, + i_trace* trace_logger = nullptr, api_status* status = nullptr); + + RL_ATTR(nodiscard) + int try_get_model(const std::string& app_id, + /* inout */ model_management::model_data& data, /* out */ bool& model_received, + api_status* status = nullptr) override; + + RL_ATTR(nodiscard) int report_result(const uint8_t* payload, size_t size, api_status* status = nullptr) override; + + ~apim_federated_client() override; + +private: + enum class state_t + { + model_available, + model_retrieved + }; + + apim_federated_client(std::unique_ptr initial_model, i_trace* trace_logger); + + state_t _state; + std::unique_ptr _current_model; + i_trace* _trace_logger; +}; + +// Read MODEL_VW_INITIAL_COMMAND_LINE + +} // namespace reinforcement_learning \ No newline at end of file diff --git a/rlclientlib/federation/local_loop_controller.cc b/rlclientlib/federation/local_loop_controller.cc index bf473a831..100463403 100644 --- a/rlclientlib/federation/local_loop_controller.cc +++ b/rlclientlib/federation/local_loop_controller.cc @@ -2,6 +2,7 @@ #include "constants.h" #include "err_constants.h" +#include "federation/apim_federated_client.h" #include "federation/local_client.h" #include "federation/sender_joined_log_provider.h" #include "model_mgmt.h" @@ -17,7 +18,14 @@ int local_loop_controller::create(std::unique_ptr& output std::unique_ptr federated_client; std::unique_ptr trainable_model; std::unique_ptr sender_joiner; - RETURN_IF_FAIL(local_client::create(federated_client, config, trace_logger, status)); + if (config.get("federated_client", "local") == "local") + { + RETURN_IF_FAIL(local_client::create(federated_client, config, trace_logger, status)); + } + else + { + RETURN_IF_FAIL(apim_federated_client::create(federated_client, config, trace_logger, status)); + } RETURN_IF_FAIL(trainable_vw_model::create(trainable_model, config, trace_logger, status)); RETURN_IF_FAIL(sender_joined_log_provider::create(sender_joiner, config, trace_logger, status)); From cb4c88e5f648ab5930dab60d4f6577faa7ed53c8 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Tue, 2 May 2023 16:43:00 -0400 Subject: [PATCH 07/11] clang --- rlclientlib/federation/apim_federated_client.cc | 1 - rlclientlib/federation/local_loop_controller.cc | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/rlclientlib/federation/apim_federated_client.cc b/rlclientlib/federation/apim_federated_client.cc index b8a8a0536..7506023c8 100644 --- a/rlclientlib/federation/apim_federated_client.cc +++ b/rlclientlib/federation/apim_federated_client.cc @@ -43,7 +43,6 @@ int apim_federated_client::create(std::unique_ptr& output, c int apim_federated_client::try_get_model(const std::string& app_id, /* inout */ model_management::model_data& data, /* out */ bool& model_received, api_status* status) { - return error_code::success; } diff --git a/rlclientlib/federation/local_loop_controller.cc b/rlclientlib/federation/local_loop_controller.cc index 100463403..a2bc9adce 100644 --- a/rlclientlib/federation/local_loop_controller.cc +++ b/rlclientlib/federation/local_loop_controller.cc @@ -22,10 +22,7 @@ int local_loop_controller::create(std::unique_ptr& output { RETURN_IF_FAIL(local_client::create(federated_client, config, trace_logger, status)); } - else - { - RETURN_IF_FAIL(apim_federated_client::create(federated_client, config, trace_logger, status)); - } + else { RETURN_IF_FAIL(apim_federated_client::create(federated_client, config, trace_logger, status)); } RETURN_IF_FAIL(trainable_vw_model::create(trainable_model, config, trace_logger, status)); RETURN_IF_FAIL(sender_joined_log_provider::create(sender_joiner, config, trace_logger, status)); From cf65ab6b7ec96173b68e385746ecb8b35f417c7d Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Tue, 16 May 2023 09:51:57 -0400 Subject: [PATCH 08/11] add data transport --- rlclientlib/factory_resolver.cc | 3 ++- rlclientlib/federation/apim_federated_client.cc | 9 +++++---- rlclientlib/federation/apim_federated_client.h | 6 ++++-- rlclientlib/federation/local_loop_controller.cc | 4 ++-- rlclientlib/federation/local_loop_controller.h | 2 +- unit_test/local_loop_end_to_end.cc | 3 ++- 6 files changed, 16 insertions(+), 11 deletions(-) diff --git a/rlclientlib/factory_resolver.cc b/rlclientlib/factory_resolver.cc index c51763f03..0a9dbc3c0 100644 --- a/rlclientlib/factory_resolver.cc +++ b/rlclientlib/factory_resolver.cc @@ -126,7 +126,8 @@ int local_loop_controller_create(std::unique_ptr& retval, c { TRACE_INFO(trace_logger, "Local loop controller i_data_transport created."); std::unique_ptr output; - RETURN_IF_FAIL(local_loop_controller::create(output, config, trace_logger, status)); + std::unique_ptr transport; + RETURN_IF_FAIL(local_loop_controller::create(output, config, std::move(transport), trace_logger, status)); retval = std::move(output); return error_code::success; } diff --git a/rlclientlib/federation/apim_federated_client.cc b/rlclientlib/federation/apim_federated_client.cc index 7506023c8..ed2dad4ae 100644 --- a/rlclientlib/federation/apim_federated_client.cc +++ b/rlclientlib/federation/apim_federated_client.cc @@ -15,15 +15,15 @@ namespace reinforcement_learning { -apim_federated_client::apim_federated_client(std::unique_ptr initial_model, i_trace* trace_logger) - : _current_model(std::move(initial_model)), _state(state_t::model_available), _trace_logger(trace_logger) +apim_federated_client::apim_federated_client(std::unique_ptr initial_model, i_trace* trace_logger, std::unique_ptr transport) + : _current_model(std::move(initial_model)), _state(state_t::model_available), _trace_logger(trace_logger), _transport(std::move(transport)) { } apim_federated_client::~apim_federated_client() = default; int apim_federated_client::create(std::unique_ptr& output, const utility::configuration& config, - i_trace* trace_logger, api_status* status) + i_trace* trace_logger, std::unique_ptr transport, api_status* status) { std::string cmd_line = "--cb_explore_adf --json --quiet --epsilon 0.0 --first_only --id "; cmd_line += config.get("id", "default_id"); @@ -36,13 +36,14 @@ int apim_federated_client::create(std::unique_ptr& output, c auto workspace = VW::initialize_experimental(std::move(args), nullptr, nullptr, nullptr, &logger); workspace->id += "/0"; // initialize iteration id to 0 - output = std::unique_ptr(new apim_federated_client(std::move(workspace), trace_logger)); + output = std::unique_ptr(new apim_federated_client(std::move(workspace), trace_logger, std::move(transport))); return error_code::success; } int apim_federated_client::try_get_model(const std::string& app_id, /* inout */ model_management::model_data& data, /* out */ bool& model_received, api_status* status) { + _transport->get_data(data, status); return error_code::success; } diff --git a/rlclientlib/federation/apim_federated_client.h b/rlclientlib/federation/apim_federated_client.h index 936adc503..538f2b9c5 100644 --- a/rlclientlib/federation/apim_federated_client.h +++ b/rlclientlib/federation/apim_federated_client.h @@ -1,6 +1,7 @@ #pragma once #include "configuration.h" +#include "model_mgmt.h" #include "federation/federated_client.h" #include "trace_logger.h" #include "vw/core/vw_fwd.h" @@ -12,7 +13,7 @@ class apim_federated_client : i_federated_client public: RL_ATTR(nodiscard) static int create(std::unique_ptr& output, const utility::configuration& config, - i_trace* trace_logger = nullptr, api_status* status = nullptr); + i_trace* trace_logger = nullptr, std::unique_ptr transport = nullptr, api_status* status = nullptr); RL_ATTR(nodiscard) int try_get_model(const std::string& app_id, @@ -30,11 +31,12 @@ class apim_federated_client : i_federated_client model_retrieved }; - apim_federated_client(std::unique_ptr initial_model, i_trace* trace_logger); + apim_federated_client(std::unique_ptr initial_model, i_trace* trace_logger, std::unique_ptr transport); state_t _state; std::unique_ptr _current_model; i_trace* _trace_logger; + std::unique_ptr _transport; }; // Read MODEL_VW_INITIAL_COMMAND_LINE diff --git a/rlclientlib/federation/local_loop_controller.cc b/rlclientlib/federation/local_loop_controller.cc index a2bc9adce..794a96f9a 100644 --- a/rlclientlib/federation/local_loop_controller.cc +++ b/rlclientlib/federation/local_loop_controller.cc @@ -11,7 +11,7 @@ namespace reinforcement_learning { int local_loop_controller::create(std::unique_ptr& output, - const reinforcement_learning::utility::configuration& config, i_trace* trace_logger, api_status* status) + const reinforcement_learning::utility::configuration& config, std::unique_ptr transport, i_trace* trace_logger, api_status* status) { std::string app_id = config.get(name::APP_ID, ""); @@ -22,7 +22,7 @@ int local_loop_controller::create(std::unique_ptr& output { RETURN_IF_FAIL(local_client::create(federated_client, config, trace_logger, status)); } - else { RETURN_IF_FAIL(apim_federated_client::create(federated_client, config, trace_logger, status)); } + else { RETURN_IF_FAIL(apim_federated_client::create(federated_client, config, trace_logger, std::move(transport), status)); } RETURN_IF_FAIL(trainable_vw_model::create(trainable_model, config, trace_logger, status)); RETURN_IF_FAIL(sender_joined_log_provider::create(sender_joiner, config, trace_logger, status)); diff --git a/rlclientlib/federation/local_loop_controller.h b/rlclientlib/federation/local_loop_controller.h index 093291b1a..7946f1570 100644 --- a/rlclientlib/federation/local_loop_controller.h +++ b/rlclientlib/federation/local_loop_controller.h @@ -25,7 +25,7 @@ class local_loop_controller : public model_management::i_data_transport public: RL_ATTR(nodiscard) static int create(std::unique_ptr& output, const utility::configuration& config, - i_trace* trace_logger = nullptr, api_status* status = nullptr); + std::unique_ptr transport = nullptr, i_trace* trace_logger = nullptr, api_status* status = nullptr); // Get model data in binary format // This will perform joining and training on any observed events, and then return the updated model diff --git a/unit_test/local_loop_end_to_end.cc b/unit_test/local_loop_end_to_end.cc index edf48afc1..362c21038 100644 --- a/unit_test/local_loop_end_to_end.cc +++ b/unit_test/local_loop_end_to_end.cc @@ -104,7 +104,8 @@ BOOST_AUTO_TEST_CASE(local_loop_end_to_end_test) i_trace* trace_logger, api_status* status) { std::unique_ptr output; - RETURN_IF_FAIL(local_loop_controller::create(output, cfg, trace_logger, status)); + std::unique_ptr transport; + RETURN_IF_FAIL(local_loop_controller::create(output, cfg, std::move(transport), trace_logger, status)); test_local_loop_controller = output.get(); retval = std::move(output); return error_code::success; From cf8d158b5b860c42bb2bd0af70173256c11e4881 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Tue, 16 May 2023 09:52:10 -0400 Subject: [PATCH 09/11] clanbg --- rlclientlib/federation/apim_federated_client.cc | 11 ++++++++--- rlclientlib/federation/apim_federated_client.h | 8 +++++--- rlclientlib/federation/local_loop_controller.cc | 8 ++++++-- rlclientlib/federation/local_loop_controller.h | 3 ++- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/rlclientlib/federation/apim_federated_client.cc b/rlclientlib/federation/apim_federated_client.cc index ed2dad4ae..e42aa222b 100644 --- a/rlclientlib/federation/apim_federated_client.cc +++ b/rlclientlib/federation/apim_federated_client.cc @@ -15,8 +15,12 @@ namespace reinforcement_learning { -apim_federated_client::apim_federated_client(std::unique_ptr initial_model, i_trace* trace_logger, std::unique_ptr transport) - : _current_model(std::move(initial_model)), _state(state_t::model_available), _trace_logger(trace_logger), _transport(std::move(transport)) +apim_federated_client::apim_federated_client(std::unique_ptr initial_model, i_trace* trace_logger, + std::unique_ptr transport) + : _current_model(std::move(initial_model)) + , _state(state_t::model_available) + , _trace_logger(trace_logger) + , _transport(std::move(transport)) { } @@ -36,7 +40,8 @@ int apim_federated_client::create(std::unique_ptr& output, c auto workspace = VW::initialize_experimental(std::move(args), nullptr, nullptr, nullptr, &logger); workspace->id += "/0"; // initialize iteration id to 0 - output = std::unique_ptr(new apim_federated_client(std::move(workspace), trace_logger, std::move(transport))); + output = std::unique_ptr( + new apim_federated_client(std::move(workspace), trace_logger, std::move(transport))); return error_code::success; } diff --git a/rlclientlib/federation/apim_federated_client.h b/rlclientlib/federation/apim_federated_client.h index 538f2b9c5..36aa1fd32 100644 --- a/rlclientlib/federation/apim_federated_client.h +++ b/rlclientlib/federation/apim_federated_client.h @@ -1,8 +1,8 @@ #pragma once #include "configuration.h" -#include "model_mgmt.h" #include "federation/federated_client.h" +#include "model_mgmt.h" #include "trace_logger.h" #include "vw/core/vw_fwd.h" @@ -13,7 +13,8 @@ class apim_federated_client : i_federated_client public: RL_ATTR(nodiscard) static int create(std::unique_ptr& output, const utility::configuration& config, - i_trace* trace_logger = nullptr, std::unique_ptr transport = nullptr, api_status* status = nullptr); + i_trace* trace_logger = nullptr, std::unique_ptr transport = nullptr, + api_status* status = nullptr); RL_ATTR(nodiscard) int try_get_model(const std::string& app_id, @@ -31,7 +32,8 @@ class apim_federated_client : i_federated_client model_retrieved }; - apim_federated_client(std::unique_ptr initial_model, i_trace* trace_logger, std::unique_ptr transport); + apim_federated_client(std::unique_ptr initial_model, i_trace* trace_logger, + std::unique_ptr transport); state_t _state; std::unique_ptr _current_model; diff --git a/rlclientlib/federation/local_loop_controller.cc b/rlclientlib/federation/local_loop_controller.cc index 794a96f9a..15aa28c2e 100644 --- a/rlclientlib/federation/local_loop_controller.cc +++ b/rlclientlib/federation/local_loop_controller.cc @@ -11,7 +11,8 @@ namespace reinforcement_learning { int local_loop_controller::create(std::unique_ptr& output, - const reinforcement_learning::utility::configuration& config, std::unique_ptr transport, i_trace* trace_logger, api_status* status) + const reinforcement_learning::utility::configuration& config, + std::unique_ptr transport, i_trace* trace_logger, api_status* status) { std::string app_id = config.get(name::APP_ID, ""); @@ -22,7 +23,10 @@ int local_loop_controller::create(std::unique_ptr& output { RETURN_IF_FAIL(local_client::create(federated_client, config, trace_logger, status)); } - else { RETURN_IF_FAIL(apim_federated_client::create(federated_client, config, trace_logger, std::move(transport), status)); } + else + { + RETURN_IF_FAIL(apim_federated_client::create(federated_client, config, trace_logger, std::move(transport), status)); + } RETURN_IF_FAIL(trainable_vw_model::create(trainable_model, config, trace_logger, status)); RETURN_IF_FAIL(sender_joined_log_provider::create(sender_joiner, config, trace_logger, status)); diff --git a/rlclientlib/federation/local_loop_controller.h b/rlclientlib/federation/local_loop_controller.h index 7946f1570..f065acc5d 100644 --- a/rlclientlib/federation/local_loop_controller.h +++ b/rlclientlib/federation/local_loop_controller.h @@ -25,7 +25,8 @@ class local_loop_controller : public model_management::i_data_transport public: RL_ATTR(nodiscard) static int create(std::unique_ptr& output, const utility::configuration& config, - std::unique_ptr transport = nullptr, i_trace* trace_logger = nullptr, api_status* status = nullptr); + std::unique_ptr transport = nullptr, i_trace* trace_logger = nullptr, + api_status* status = nullptr); // Get model data in binary format // This will perform joining and training on any observed events, and then return the updated model From a7a9aae036fb268a890b76647aea09cce279b0ef Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Tue, 16 May 2023 14:35:47 -0400 Subject: [PATCH 10/11] federated controller naming --- include/constants.h | 4 +-- rlclientlib/CMakeLists.txt | 4 +-- rlclientlib/factory_resolver.cc | 10 +++---- ...roller.cc => federated_loop_controller.cc} | 16 +++++------ ...ntroller.h => federated_loop_controller.h} | 10 +++---- rlclientlib/live_model_impl.cc | 6 ++-- unit_test/CMakeLists.txt | 2 +- ...t.cc => federated_loop_controller_test.cc} | 28 +++++++++---------- unit_test/local_loop_end_to_end.cc | 16 +++++------ 9 files changed, 48 insertions(+), 48 deletions(-) rename rlclientlib/federation/{local_loop_controller.cc => federated_loop_controller.cc} (82%) rename rlclientlib/federation/{local_loop_controller.h => federated_loop_controller.h} (85%) rename unit_test/{local_loop_controller_test.cc => federated_loop_controller_test.cc} (85%) diff --git a/include/constants.h b/include/constants.h index 1b4962230..0bb94f579 100644 --- a/include/constants.h +++ b/include/constants.h @@ -175,9 +175,9 @@ const char* const REWARD_FUNCTION_MAX = "REWARD_FUNCTION_MAX"; // These are outside of #ifdef section so that we can recognize them as invalid // configuration options when rlclientlib is compiled without RL_BUILD_FEDERATION // -// Use local_loop_controller for model data +// Use federated_loop_controller for model data const char* const LOCAL_LOOP_MODEL_DATA = "LOCAL_LOOP_MODEL_DATA"; -// Send events to local_loop_controller +// Send events to federated_loop_controller const char* const LOCAL_LOOP_SENDER = "LOCAL_LOOP_SENDER"; const char* get_default_episode_sender(); diff --git a/rlclientlib/CMakeLists.txt b/rlclientlib/CMakeLists.txt index 73ac9973f..92d74fc6d 100644 --- a/rlclientlib/CMakeLists.txt +++ b/rlclientlib/CMakeLists.txt @@ -121,7 +121,7 @@ if(RL_BUILD_FEDERATION) list(APPEND PROJECT_SOURCES federation/apim_federated_client.cc federation/local_client.cc - federation/local_loop_controller.cc + federation/federated_loop_controller.cc federation/sender_joined_log_provider.cc federation/vw_trainable_model.cc ) @@ -211,7 +211,7 @@ if(RL_BUILD_FEDERATION) federation/joined_log_provider.h federation/local_client.h federation/apim_federated_client.h - federation/local_loop_controller.h + federation/federated_loop_controller.h federation/sender_joined_log_provider.h federation/vw_trainable_model.h ) diff --git a/rlclientlib/factory_resolver.cc b/rlclientlib/factory_resolver.cc index 0a9dbc3c0..850ce169e 100644 --- a/rlclientlib/factory_resolver.cc +++ b/rlclientlib/factory_resolver.cc @@ -10,7 +10,7 @@ #include "vw_model/vw_model.h" #ifdef RL_BUILD_FEDERATION -# include "federation/local_loop_controller.h" +# include "federation/federated_loop_controller.h" #endif #ifdef USE_AZURE_FACTORIES @@ -121,13 +121,13 @@ int file_model_loader_create(std::unique_ptr& retval, const } #ifdef RL_BUILD_FEDERATION -int local_loop_controller_create(std::unique_ptr& retval, const u::configuration& config, +int federated_loop_controller_create(std::unique_ptr& retval, const u::configuration& config, i_trace* trace_logger, api_status* status) { TRACE_INFO(trace_logger, "Local loop controller i_data_transport created."); - std::unique_ptr output; + std::unique_ptr output; std::unique_ptr transport; - RETURN_IF_FAIL(local_loop_controller::create(output, config, std::move(transport), trace_logger, status)); + RETURN_IF_FAIL(federated_loop_controller::create(output, config, std::move(transport), trace_logger, status)); retval = std::move(output); return error_code::success; } @@ -159,7 +159,7 @@ void factory_initializer::register_default_factories() data_transport_factory.register_type(value::FILE_MODEL_DATA, file_model_loader_create); #ifdef RL_BUILD_FEDERATION - data_transport_factory.register_type(value::LOCAL_LOOP_MODEL_DATA, local_loop_controller_create); + data_transport_factory.register_type(value::LOCAL_LOOP_MODEL_DATA, federated_loop_controller_create); #else data_transport_factory.register_type(value::LOCAL_LOOP_MODEL_DATA, [](std::unique_ptr&, const u::configuration&, i_trace* trace_logger, api_status* status) diff --git a/rlclientlib/federation/local_loop_controller.cc b/rlclientlib/federation/federated_loop_controller.cc similarity index 82% rename from rlclientlib/federation/local_loop_controller.cc rename to rlclientlib/federation/federated_loop_controller.cc index 15aa28c2e..16a893a99 100644 --- a/rlclientlib/federation/local_loop_controller.cc +++ b/rlclientlib/federation/federated_loop_controller.cc @@ -1,4 +1,4 @@ -#include "federation/local_loop_controller.h" +#include "federation/federated_loop_controller.h" #include "constants.h" #include "err_constants.h" @@ -10,7 +10,7 @@ namespace reinforcement_learning { -int local_loop_controller::create(std::unique_ptr& output, +int federated_loop_controller::create(std::unique_ptr& output, const reinforcement_learning::utility::configuration& config, std::unique_ptr transport, i_trace* trace_logger, api_status* status) { @@ -39,12 +39,12 @@ int local_loop_controller::create(std::unique_ptr& output joiner = std::static_pointer_cast(sender_joiner_shared); event_sink = std::static_pointer_cast(sender_joiner_shared); - output = std::unique_ptr(new local_loop_controller(std::move(app_id), + output = std::unique_ptr(new federated_loop_controller(std::move(app_id), std::move(federated_client), std::move(trainable_model), std::move(joiner), std::move(event_sink))); return error_code::success; } -local_loop_controller::local_loop_controller(std::string app_id, std::unique_ptr&& federated_client, +federated_loop_controller::federated_loop_controller(std::string app_id, std::unique_ptr&& federated_client, std::unique_ptr&& trainable_model, std::shared_ptr&& joiner, std::shared_ptr&& event_sink) : _app_id(std::move(app_id)) @@ -55,7 +55,7 @@ local_loop_controller::local_loop_controller(std::string app_id, std::unique_ptr { } -int local_loop_controller::update_global(api_status* status) +int federated_loop_controller::update_global(api_status* status) { // ask for a new global model model_management::model_data data; @@ -79,7 +79,7 @@ int local_loop_controller::update_global(api_status* status) return error_code::success; } -int local_loop_controller::update_local(api_status* status) +int federated_loop_controller::update_local(api_status* status) { std::unique_ptr binary_log; RETURN_IF_FAIL(_joiner->invoke_join(binary_log, status)); @@ -87,13 +87,13 @@ int local_loop_controller::update_local(api_status* status) return error_code::success; } -int local_loop_controller::get_data(model_management::model_data& data, api_status* status) +int federated_loop_controller::get_data(model_management::model_data& data, api_status* status) { RETURN_IF_FAIL(update_global(status)); RETURN_IF_FAIL(_trainable_model->get_data(data, status)); return error_code::success; } -std::unique_ptr local_loop_controller::get_local_sender() { return _event_sink->get_sender_proxy(); } +std::unique_ptr federated_loop_controller::get_local_sender() { return _event_sink->get_sender_proxy(); } } // namespace reinforcement_learning diff --git a/rlclientlib/federation/local_loop_controller.h b/rlclientlib/federation/federated_loop_controller.h similarity index 85% rename from rlclientlib/federation/local_loop_controller.h rename to rlclientlib/federation/federated_loop_controller.h index f065acc5d..1ecf8f5ee 100644 --- a/rlclientlib/federation/local_loop_controller.h +++ b/rlclientlib/federation/federated_loop_controller.h @@ -16,15 +16,15 @@ namespace reinforcement_learning { -// The local_loop_controller will "plug in" to rlclientlib as an i_data_transport object. +// The federated_loop_controller will "plug in" to rlclientlib as an i_data_transport object. // It exposes a get_local_sender_factory() function that creates i_sender proxy objects. // These proxy objects will send events to its internal event sink. // The initialization code for live_model_impl must register this factory function correctly. -class local_loop_controller : public model_management::i_data_transport +class federated_loop_controller : public model_management::i_data_transport { public: RL_ATTR(nodiscard) - static int create(std::unique_ptr& output, const utility::configuration& config, + static int create(std::unique_ptr& output, const utility::configuration& config, std::unique_ptr transport = nullptr, i_trace* trace_logger = nullptr, api_status* status = nullptr); @@ -36,11 +36,11 @@ class local_loop_controller : public model_management::i_data_transport // Returns a i_sender proxy object to be used for sending events to the internal event sink std::unique_ptr get_local_sender(); - virtual ~local_loop_controller() = default; + virtual ~federated_loop_controller() = default; protected: // Constructor is private because objects should be created using the factory function - local_loop_controller(std::string app_id, std::unique_ptr&& federated_client, + federated_loop_controller(std::string app_id, std::unique_ptr&& federated_client, std::unique_ptr&& trainable_model, std::shared_ptr&& joiner, std::shared_ptr&& event_sink); diff --git a/rlclientlib/live_model_impl.cc b/rlclientlib/live_model_impl.cc index f2f3c4907..1379b456b 100644 --- a/rlclientlib/live_model_impl.cc +++ b/rlclientlib/live_model_impl.cc @@ -18,7 +18,7 @@ #include "vw_model/safe_vw.h" #ifdef RL_BUILD_FEDERATION -# include "federation/local_loop_controller.h" +# include "federation/federated_loop_controller.h" #endif #include @@ -696,10 +696,10 @@ int live_model_impl::init_local_loop(api_status* status) // This function should only be called when the configuration is set to use LOCAL_LOOP_MODEL_DATA assert(model_src == value::LOCAL_LOOP_MODEL_DATA); - // Creating i_data_transport with type LOCAL_LOOP_MODEL_DATA results in local_loop_controller + // Creating i_data_transport with type LOCAL_LOOP_MODEL_DATA results in federated_loop_controller std::unique_ptr output; RETURN_IF_FAIL(_t_factory->create(output, model_src, _configuration, _trace_logger.get(), status)); - std::unique_ptr llc(reinterpret_cast(output.release())); + std::unique_ptr llc(reinterpret_cast(output.release())); // Create senders with default sender implementation set to LOCAL_LOOP_SENDER std::string interaction_sender_type = diff --git a/unit_test/CMakeLists.txt b/unit_test/CMakeLists.txt index 65be79147..3b12156f9 100644 --- a/unit_test/CMakeLists.txt +++ b/unit_test/CMakeLists.txt @@ -52,7 +52,7 @@ if(RL_BUILD_FEDERATION) list(APPEND TEST_SOURCES eud_test.cc local_client_test.cc - local_loop_controller_test.cc + federated_loop_controller_test.cc local_loop_end_to_end.cc sender_joined_log_provider_test.cc trainable_model_test.cc diff --git a/unit_test/local_loop_controller_test.cc b/unit_test/federated_loop_controller_test.cc similarity index 85% rename from unit_test/local_loop_controller_test.cc rename to unit_test/federated_loop_controller_test.cc index 47d62792b..4043bd6df 100644 --- a/unit_test/local_loop_controller_test.cc +++ b/unit_test/federated_loop_controller_test.cc @@ -7,7 +7,7 @@ #include "federation/event_sink.h" #include "federation/federated_client.h" #include "federation/local_client.h" -#include "federation/local_loop_controller.h" +#include "federation/federated_loop_controller.h" #include "federation/sender_joined_log_provider.h" #include "vw/core/shared_data.h" #include "vw/core/vw.h" @@ -16,19 +16,19 @@ using namespace reinforcement_learning; namespace { -// Wrapper around local_loop_controller to allow us to access member variables -class test_local_loop_controller : public local_loop_controller +// Wrapper around federated_loop_controller to allow us to access member variables +class test_federated_loop_controller : public federated_loop_controller { public: - test_local_loop_controller(std::string app_id, std::unique_ptr&& federated_client, + test_federated_loop_controller(std::string app_id, std::unique_ptr&& federated_client, std::unique_ptr&& trainable_model, std::shared_ptr&& joiner, std::shared_ptr&& event_sink) - : local_loop_controller(std::move(app_id), std::move(federated_client), std::move(trainable_model), + : federated_loop_controller(std::move(app_id), std::move(federated_client), std::move(trainable_model), std::move(joiner), std::move(event_sink)) { } - virtual ~test_local_loop_controller() = default; + virtual ~test_federated_loop_controller() = default; i_federated_client* get_client() { return _federated_client.get(); } trainable_vw_model* get_model() { return _trainable_model.get(); } @@ -121,7 +121,7 @@ utility::configuration get_test_config() return config; } -std::unique_ptr create_test_local_loop_controller(utility::configuration config) +std::unique_ptr create_test_federated_loop_controller(utility::configuration config) { std::unique_ptr trainable_model; std::unique_ptr sender_joiner; @@ -132,16 +132,16 @@ std::unique_ptr create_test_local_loop_controller(utility std::shared_ptr event_sink(new mock_event_sink()); std::unique_ptr federated_client(new mock_federated_client()); - return std::unique_ptr(new test_local_loop_controller("test_app_id", + return std::unique_ptr(new test_federated_loop_controller("test_app_id", std::move(federated_client), std::move(trainable_model), std::move(joiner), std::move(event_sink))); } } // namespace BOOST_AUTO_TEST_CASE(sender_factory_test) { - // create the local_loop_controller + // create the federated_loop_controller auto config = get_test_config(); - auto test_llc = create_test_local_loop_controller(config); + auto test_llc = create_test_federated_loop_controller(config); // create a sender and send some data std::unique_ptr sender = test_llc->get_local_sender(); @@ -153,7 +153,7 @@ BOOST_AUTO_TEST_CASE(sender_factory_test) sender->send(buffer_in); // get the data out of event sink - auto event_sink_out = dynamic_cast(test_llc.get())->get_event_sink(); + auto event_sink_out = dynamic_cast(test_llc.get())->get_event_sink(); BOOST_CHECK_NE(event_sink_out, nullptr); auto buffer_out = dynamic_cast(event_sink_out)->get_latest_event(); BOOST_CHECK_NE(buffer_out, nullptr); @@ -162,10 +162,10 @@ BOOST_AUTO_TEST_CASE(sender_factory_test) BOOST_AUTO_TEST_CASE(update_get_model_data) { - // create the local_loop_controller + // create the federated_loop_controller auto config = get_test_config(); - auto llc = create_test_local_loop_controller(config); - auto test_llc = dynamic_cast(llc.get()); + auto llc = create_test_federated_loop_controller(config); + auto test_llc = dynamic_cast(llc.get()); BOOST_CHECK_NE(test_llc, nullptr); auto mock_client = dynamic_cast(test_llc->get_client()); BOOST_CHECK_NE(mock_client, nullptr); diff --git a/unit_test/local_loop_end_to_end.cc b/unit_test/local_loop_end_to_end.cc index 362c21038..9173e7bb1 100644 --- a/unit_test/local_loop_end_to_end.cc +++ b/unit_test/local_loop_end_to_end.cc @@ -2,7 +2,7 @@ #include "common_test_utils.h" #include "constants.h" -#include "federation/local_loop_controller.h" +#include "federation/federated_loop_controller.h" #include "live_model.h" #include "ranking_response.h" @@ -96,17 +96,17 @@ BOOST_AUTO_TEST_CASE(local_loop_end_to_end_test) auto config = get_test_config(); // create a custom data_transport_factory_t that saves a pointer - // to the local_loop_controller that was created - local_loop_controller* test_local_loop_controller = nullptr; + // to the federated_loop_controller that was created + federated_loop_controller* test_federated_loop_controller = nullptr; data_transport_factory_t test_data_transport_factory; test_data_transport_factory.register_type(value::LOCAL_LOOP_MODEL_DATA, [&](std::unique_ptr& retval, const utility::configuration& cfg, i_trace* trace_logger, api_status* status) { - std::unique_ptr output; + std::unique_ptr output; std::unique_ptr transport; - RETURN_IF_FAIL(local_loop_controller::create(output, cfg, std::move(transport), trace_logger, status)); - test_local_loop_controller = output.get(); + RETURN_IF_FAIL(federated_loop_controller::create(output, cfg, std::move(transport), trace_logger, status)); + test_federated_loop_controller = output.get(); retval = std::move(output); return error_code::success; }); @@ -117,7 +117,7 @@ BOOST_AUTO_TEST_CASE(local_loop_end_to_end_test) config, nullptr, nullptr, &reinforcement_learning::trace_logger_factory, &test_data_transport_factory); model.init(&status); BOOST_TEST(status.get_error_code() == error_code::success, status.get_error_msg()); - BOOST_CHECK_NE(test_local_loop_controller, nullptr); + BOOST_CHECK_NE(test_federated_loop_controller, nullptr); // do some inference calls and report the outcome constexpr int iterations = 100; @@ -130,7 +130,7 @@ BOOST_AUTO_TEST_CASE(local_loop_end_to_end_test) // check that updated model has learned from previous outcomes model_management::model_data model_data; - test_local_loop_controller->get_data(model_data, &status); + test_federated_loop_controller->get_data(model_data, &status); BOOST_TEST(status.get_error_code() == error_code::success, status.get_error_msg()); auto vw = test_utils::create_vw(config.get(name::MODEL_VW_INITIAL_COMMAND_LINE, nullptr), model_data); BOOST_CHECK_EQUAL(vw->sd->weighted_labeled_examples, iterations); From 3f59665b4cea0e732a2b5566e386b09b3efb87e4 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Tue, 16 May 2023 14:37:01 -0400 Subject: [PATCH 11/11] clang --- rlclientlib/federation/federated_loop_controller.cc | 6 +++--- unit_test/federated_loop_controller_test.cc | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rlclientlib/federation/federated_loop_controller.cc b/rlclientlib/federation/federated_loop_controller.cc index 16a893a99..64aca8a1f 100644 --- a/rlclientlib/federation/federated_loop_controller.cc +++ b/rlclientlib/federation/federated_loop_controller.cc @@ -44,9 +44,9 @@ int federated_loop_controller::create(std::unique_ptr return error_code::success; } -federated_loop_controller::federated_loop_controller(std::string app_id, std::unique_ptr&& federated_client, - std::unique_ptr&& trainable_model, std::shared_ptr&& joiner, - std::shared_ptr&& event_sink) +federated_loop_controller::federated_loop_controller(std::string app_id, + std::unique_ptr&& federated_client, std::unique_ptr&& trainable_model, + std::shared_ptr&& joiner, std::shared_ptr&& event_sink) : _app_id(std::move(app_id)) , _federated_client(std::move(federated_client)) , _trainable_model(std::move(trainable_model)) diff --git a/unit_test/federated_loop_controller_test.cc b/unit_test/federated_loop_controller_test.cc index 4043bd6df..60b2ff237 100644 --- a/unit_test/federated_loop_controller_test.cc +++ b/unit_test/federated_loop_controller_test.cc @@ -6,8 +6,8 @@ #include "err_constants.h" #include "federation/event_sink.h" #include "federation/federated_client.h" -#include "federation/local_client.h" #include "federation/federated_loop_controller.h" +#include "federation/local_client.h" #include "federation/sender_joined_log_provider.h" #include "vw/core/shared_data.h" #include "vw/core/vw.h"