Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: federated delta endpoint #583

Open
wants to merge 12 commits into
base: local_loop_prototype
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ const char* const REWARD_FUNCTION_MAX = "REWARD_FUNCTION_MAX";
// These are outside of #ifdef section so that we can recognize them as invalid
// configuration options when rlclientlib is compiled without RL_BUILD_FEDERATION
//
// Use local_loop_controller for model data
// Use federated_loop_controller for model data
const char* const LOCAL_LOOP_MODEL_DATA = "LOCAL_LOOP_MODEL_DATA";
// Send events to local_loop_controller
// Send events to federated_loop_controller
const char* const LOCAL_LOOP_SENDER = "LOCAL_LOOP_SENDER";

const char* get_default_episode_sender();
Expand Down
6 changes: 4 additions & 2 deletions rlclientlib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ endif()

if(RL_BUILD_FEDERATION)
list(APPEND PROJECT_SOURCES
federation/apim_federated_client.cc
federation/local_client.cc
federation/local_loop_controller.cc
federation/federated_loop_controller.cc
federation/sender_joined_log_provider.cc
federation/vw_trainable_model.cc
)
Expand Down Expand Up @@ -209,7 +210,8 @@ if(RL_BUILD_FEDERATION)
federation/federated_client.h
federation/joined_log_provider.h
federation/local_client.h
federation/local_loop_controller.h
federation/apim_federated_client.h
federation/federated_loop_controller.h
federation/sender_joined_log_provider.h
federation/vw_trainable_model.h
)
Expand Down
11 changes: 6 additions & 5 deletions rlclientlib/factory_resolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include "vw_model/vw_model.h"

#ifdef RL_BUILD_FEDERATION
# include "federation/local_loop_controller.h"
# include "federation/federated_loop_controller.h"
#endif

#ifdef USE_AZURE_FACTORIES
Expand Down Expand Up @@ -121,12 +121,13 @@ int file_model_loader_create(std::unique_ptr<m::i_data_transport>& retval, const
}

#ifdef RL_BUILD_FEDERATION
int local_loop_controller_create(std::unique_ptr<m::i_data_transport>& retval, const u::configuration& config,
int federated_loop_controller_create(std::unique_ptr<m::i_data_transport>& retval, const u::configuration& config,
i_trace* trace_logger, api_status* status)
{
TRACE_INFO(trace_logger, "Local loop controller i_data_transport created.");
std::unique_ptr<local_loop_controller> output;
RETURN_IF_FAIL(local_loop_controller::create(output, config, trace_logger, status));
std::unique_ptr<federated_loop_controller> output;
std::unique_ptr<model_management::i_data_transport> transport;
RETURN_IF_FAIL(federated_loop_controller::create(output, config, std::move(transport), trace_logger, status));
retval = std::move(output);
return error_code::success;
}
Expand Down Expand Up @@ -158,7 +159,7 @@ void factory_initializer::register_default_factories()
data_transport_factory.register_type(value::FILE_MODEL_DATA, file_model_loader_create);

#ifdef RL_BUILD_FEDERATION
data_transport_factory.register_type(value::LOCAL_LOOP_MODEL_DATA, local_loop_controller_create);
data_transport_factory.register_type(value::LOCAL_LOOP_MODEL_DATA, federated_loop_controller_create);
#else
data_transport_factory.register_type(value::LOCAL_LOOP_MODEL_DATA,
[](std::unique_ptr<m::i_data_transport>&, const u::configuration&, i_trace* trace_logger, api_status* status)
Expand Down
95 changes: 95 additions & 0 deletions rlclientlib/federation/apim_federated_client.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#include "federation/apim_federated_client.h"

#include "api_status.h"
#include "constants.h"
#include "err_constants.h"
#include "trace_logger.h"
#include "utility/vw_logger_adapter.h"
#include "vw/config/options_cli.h"
#include "vw/core/global_data.h"
#include "vw/core/io_buf.h"
#include "vw/core/merge.h"
#include "vw/core/parse_primitives.h"
#include "vw/core/vw.h"
#include "vw/io/io_adapter.h"

namespace reinforcement_learning
{
apim_federated_client::apim_federated_client(std::unique_ptr<VW::workspace> initial_model, i_trace* trace_logger,
std::unique_ptr<model_management::i_data_transport> transport)
: _current_model(std::move(initial_model))
, _state(state_t::model_available)
, _trace_logger(trace_logger)
, _transport(std::move(transport))
{
}

apim_federated_client::~apim_federated_client() = default;

int apim_federated_client::create(std::unique_ptr<i_federated_client>& output, const utility::configuration& config,
i_trace* trace_logger, std::unique_ptr<model_management::i_data_transport> transport, api_status* status)
{
std::string cmd_line = "--cb_explore_adf --json --quiet --epsilon 0.0 --first_only --id ";
cmd_line += config.get("id", "default_id");
// Create empty model based on ML args on first call
std::string initial_command_line(config.get(name::MODEL_VW_INITIAL_COMMAND_LINE, cmd_line.c_str()));

// TODO try catch
auto args = VW::make_unique<VW::config::options_cli>(VW::split_command_line(initial_command_line));
auto logger = utility::make_vw_trace_logger(trace_logger);
auto workspace = VW::initialize_experimental(std::move(args), nullptr, nullptr, nullptr, &logger);
workspace->id += "/0"; // initialize iteration id to 0

output = std::unique_ptr<i_federated_client>(
new apim_federated_client(std::move(workspace), trace_logger, std::move(transport)));
return error_code::success;
}

int apim_federated_client::try_get_model(const std::string& app_id,
/* inout */ model_management::model_data& data, /* out */ bool& model_received, api_status* status)
{
_transport->get_data(data, status);
return error_code::success;
}

int apim_federated_client::report_result(const uint8_t* payload, size_t size, api_status* status)
{
switch (_state)
{
case state_t::model_available:
{
RETURN_ERROR_LS(_trace_logger, status, invalid_argument)
<< "Cannot call report_result again until try_get_model has been called.";
}
break;
case state_t::model_retrieved:
{
// Payload must be a delta
// Apply delta to current model and move into model available state.
auto view = VW::io::create_buffer_view(reinterpret_cast<const char*>(payload), size);
auto delta = VW::model_delta::deserialize(*view);
auto new_model = *_current_model + *delta;

// Increment iteration id for new workspace
try
{
int iteration_id = std::stoi(_current_model->id.substr(_current_model->id.find('/') + 1, std::string::npos));
iteration_id++;
new_model->id = _current_model->id.substr(0, _current_model->id.find('/')) + "/" + std::to_string(iteration_id);
}
catch (const std::exception& e)
{
RETURN_ERROR_ARG(_trace_logger, status, model_update_error, e.what());
}

// Update current model
_current_model.reset(new_model.release());
_state = state_t::model_available;
}
break;
default:
RETURN_ERROR_LS(_trace_logger, status, invalid_argument) << "Invalid state.";
}
return error_code::success;
}
} // namespace reinforcement_learning
46 changes: 46 additions & 0 deletions rlclientlib/federation/apim_federated_client.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#pragma once

#include "configuration.h"
#include "federation/federated_client.h"
#include "model_mgmt.h"
#include "trace_logger.h"
#include "vw/core/vw_fwd.h"

namespace reinforcement_learning
{
class apim_federated_client : i_federated_client
{
public:
RL_ATTR(nodiscard)
static int create(std::unique_ptr<i_federated_client>& output, const utility::configuration& config,
i_trace* trace_logger = nullptr, std::unique_ptr<model_management::i_data_transport> transport = nullptr,
api_status* status = nullptr);

RL_ATTR(nodiscard)
int try_get_model(const std::string& app_id,
/* inout */ model_management::model_data& data, /* out */ bool& model_received,
api_status* status = nullptr) override;

RL_ATTR(nodiscard) int report_result(const uint8_t* payload, size_t size, api_status* status = nullptr) override;

~apim_federated_client() override;

private:
enum class state_t
{
model_available,
model_retrieved
};

apim_federated_client(std::unique_ptr<VW::workspace> initial_model, i_trace* trace_logger,
std::unique_ptr<model_management::i_data_transport> transport);

state_t _state;
std::unique_ptr<VW::workspace> _current_model;
i_trace* _trace_logger;
std::unique_ptr<model_management::i_data_transport> _transport;
};

// Read MODEL_VW_INITIAL_COMMAND_LINE

} // namespace reinforcement_learning
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
#include "federation/local_loop_controller.h"
#include "federation/federated_loop_controller.h"

#include "constants.h"
#include "err_constants.h"
#include "federation/apim_federated_client.h"
#include "federation/local_client.h"
#include "federation/sender_joined_log_provider.h"
#include "model_mgmt.h"
#include "vw/io/io_adapter.h"

namespace reinforcement_learning
{
int local_loop_controller::create(std::unique_ptr<local_loop_controller>& output,
const reinforcement_learning::utility::configuration& config, i_trace* trace_logger, api_status* status)
int federated_loop_controller::create(std::unique_ptr<federated_loop_controller>& output,
const reinforcement_learning::utility::configuration& config,
std::unique_ptr<model_management::i_data_transport> transport, i_trace* trace_logger, api_status* status)
{
std::string app_id = config.get(name::APP_ID, "");

std::unique_ptr<i_federated_client> federated_client;
std::unique_ptr<trainable_vw_model> trainable_model;
std::unique_ptr<sender_joined_log_provider> sender_joiner;
RETURN_IF_FAIL(local_client::create(federated_client, config, trace_logger, status));
if (config.get("federated_client", "local") == "local")
{
RETURN_IF_FAIL(local_client::create(federated_client, config, trace_logger, status));
}
else
{
RETURN_IF_FAIL(apim_federated_client::create(federated_client, config, trace_logger, std::move(transport), status));
}
RETURN_IF_FAIL(trainable_vw_model::create(trainable_model, config, trace_logger, status));
RETURN_IF_FAIL(sender_joined_log_provider::create(sender_joiner, config, trace_logger, status));

Expand All @@ -30,14 +39,14 @@ int local_loop_controller::create(std::unique_ptr<local_loop_controller>& output
joiner = std::static_pointer_cast<i_joined_log_provider>(sender_joiner_shared);
event_sink = std::static_pointer_cast<i_event_sink>(sender_joiner_shared);

output = std::unique_ptr<local_loop_controller>(new local_loop_controller(std::move(app_id),
output = std::unique_ptr<federated_loop_controller>(new federated_loop_controller(std::move(app_id),
std::move(federated_client), std::move(trainable_model), std::move(joiner), std::move(event_sink)));
return error_code::success;
}

local_loop_controller::local_loop_controller(std::string app_id, std::unique_ptr<i_federated_client>&& federated_client,
std::unique_ptr<trainable_vw_model>&& trainable_model, std::shared_ptr<i_joined_log_provider>&& joiner,
std::shared_ptr<i_event_sink>&& event_sink)
federated_loop_controller::federated_loop_controller(std::string app_id,
std::unique_ptr<i_federated_client>&& federated_client, std::unique_ptr<trainable_vw_model>&& trainable_model,
std::shared_ptr<i_joined_log_provider>&& joiner, std::shared_ptr<i_event_sink>&& event_sink)
: _app_id(std::move(app_id))
, _federated_client(std::move(federated_client))
, _trainable_model(std::move(trainable_model))
Expand All @@ -46,7 +55,7 @@ local_loop_controller::local_loop_controller(std::string app_id, std::unique_ptr
{
}

int local_loop_controller::update_global(api_status* status)
int federated_loop_controller::update_global(api_status* status)
{
// ask for a new global model
model_management::model_data data;
Expand All @@ -70,21 +79,21 @@ int local_loop_controller::update_global(api_status* status)
return error_code::success;
}

int local_loop_controller::update_local(api_status* status)
int federated_loop_controller::update_local(api_status* status)
{
std::unique_ptr<VW::io::reader> binary_log;
RETURN_IF_FAIL(_joiner->invoke_join(binary_log, status));
RETURN_IF_FAIL(_trainable_model->learn(std::move(binary_log), status));
return error_code::success;
}

int local_loop_controller::get_data(model_management::model_data& data, api_status* status)
int federated_loop_controller::get_data(model_management::model_data& data, api_status* status)
{
RETURN_IF_FAIL(update_global(status));
RETURN_IF_FAIL(_trainable_model->get_data(data, status));
return error_code::success;
}

std::unique_ptr<i_sender> local_loop_controller::get_local_sender() { return _event_sink->get_sender_proxy(); }
std::unique_ptr<i_sender> federated_loop_controller::get_local_sender() { return _event_sink->get_sender_proxy(); }

} // namespace reinforcement_learning
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@

namespace reinforcement_learning
{
// The local_loop_controller will "plug in" to rlclientlib as an i_data_transport object.
// The federated_loop_controller will "plug in" to rlclientlib as an i_data_transport object.
// It exposes a get_local_sender_factory() function that creates i_sender proxy objects.
// These proxy objects will send events to its internal event sink.
// The initialization code for live_model_impl must register this factory function correctly.
class local_loop_controller : public model_management::i_data_transport
class federated_loop_controller : public model_management::i_data_transport
{
public:
RL_ATTR(nodiscard)
static int create(std::unique_ptr<local_loop_controller>& output, const utility::configuration& config,
i_trace* trace_logger = nullptr, api_status* status = nullptr);
static int create(std::unique_ptr<federated_loop_controller>& output, const utility::configuration& config,
std::unique_ptr<model_management::i_data_transport> transport = nullptr, i_trace* trace_logger = nullptr,
api_status* status = nullptr);

// Get model data in binary format
// This will perform joining and training on any observed events, and then return the updated model
Expand All @@ -35,11 +36,11 @@ class local_loop_controller : public model_management::i_data_transport
// Returns a i_sender proxy object to be used for sending events to the internal event sink
std::unique_ptr<i_sender> get_local_sender();

virtual ~local_loop_controller() = default;
virtual ~federated_loop_controller() = default;

protected:
// Constructor is private because objects should be created using the factory function
local_loop_controller(std::string app_id, std::unique_ptr<i_federated_client>&& federated_client,
federated_loop_controller(std::string app_id, std::unique_ptr<i_federated_client>&& federated_client,
std::unique_ptr<trainable_vw_model>&& trainable_model, std::shared_ptr<i_joined_log_provider>&& joiner,
std::shared_ptr<i_event_sink>&& event_sink);

Expand Down
6 changes: 3 additions & 3 deletions rlclientlib/live_model_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include "vw_model/safe_vw.h"

#ifdef RL_BUILD_FEDERATION
# include "federation/local_loop_controller.h"
# include "federation/federated_loop_controller.h"
#endif

#include <boost/uuid/random_generator.hpp>
Expand Down Expand Up @@ -696,10 +696,10 @@ int live_model_impl::init_local_loop(api_status* status)
// This function should only be called when the configuration is set to use LOCAL_LOOP_MODEL_DATA
assert(model_src == value::LOCAL_LOOP_MODEL_DATA);

// Creating i_data_transport with type LOCAL_LOOP_MODEL_DATA results in local_loop_controller
// Creating i_data_transport with type LOCAL_LOOP_MODEL_DATA results in federated_loop_controller
std::unique_ptr<m::i_data_transport> output;
RETURN_IF_FAIL(_t_factory->create(output, model_src, _configuration, _trace_logger.get(), status));
std::unique_ptr<local_loop_controller> llc(reinterpret_cast<local_loop_controller*>(output.release()));
std::unique_ptr<federated_loop_controller> llc(reinterpret_cast<federated_loop_controller*>(output.release()));

// Create senders with default sender implementation set to LOCAL_LOOP_SENDER
std::string interaction_sender_type =
Expand Down
2 changes: 1 addition & 1 deletion unit_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ if(RL_BUILD_FEDERATION)
list(APPEND TEST_SOURCES
eud_test.cc
local_client_test.cc
local_loop_controller_test.cc
federated_loop_controller_test.cc
local_loop_end_to_end.cc
sender_joined_log_provider_test.cc
trainable_model_test.cc
Expand Down
Loading