diff --git a/.github/workflows/asan.yml b/.github/workflows/asan.yml index 0e4033c19..17142d9f9 100644 --- a/.github/workflows/asan.yml +++ b/.github/workflows/asan.yml @@ -33,6 +33,7 @@ jobs: - uses: actions/checkout@v2 with: submodules: recursive + - run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow - name: Setup MSVC Developer Command Prompt if: ${{ startsWith(matrix.os, 'windows') }} uses: ilammy/msvc-dev-cmd@v1 diff --git a/.github/workflows/build_rlclientlib.yml b/.github/workflows/build_rlclientlib.yml index 878c99489..3300a5ff2 100644 --- a/.github/workflows/build_rlclientlib.yml +++ b/.github/workflows/build_rlclientlib.yml @@ -29,6 +29,7 @@ jobs: - uses: actions/checkout@v3 with: submodules: recursive + - run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow - uses: lukka/get-cmake@latest - name: Install ONNX run: | @@ -83,6 +84,7 @@ jobs: - uses: actions/checkout@v3 with: submodules: recursive + - run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow - uses: lukka/get-cmake@latest - run: echo "VCPKG_COMMIT=$(git rev-parse :ext_libs/vcpkg)" >> $GITHUB_ENV shell: bash @@ -124,6 +126,7 @@ jobs: - uses: actions/checkout@v3 with: submodules: recursive + - run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow - name: Setup MSVC Developer Command Prompt uses: ilammy/msvc-dev-cmd@v1 - uses: lukka/get-cmake@latest diff --git a/.github/workflows/build_vw_bp.yml b/.github/workflows/build_vw_bp.yml index 044abb11e..21c3e37e7 100644 --- a/.github/workflows/build_vw_bp.yml +++ b/.github/workflows/build_vw_bp.yml @@ -38,6 +38,7 @@ jobs: - uses: actions/checkout@v3 with: submodules: recursive + - run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow - name: Setup MSVC Developer Command Prompt if: ${{ startsWith(matrix.config.os, 'windows') }} uses: ilammy/msvc-dev-cmd@v1 diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 0dfaae16a..36726e8ab 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -30,6 +30,7 @@ jobs: uses: actions/checkout@v3 with: submodules: recursive + - run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow - uses: lukka/get-cmake@latest - run: echo "VCPKG_COMMIT=$(git rev-parse :ext_libs/vcpkg)" >> $GITHUB_ENV shell: bash @@ -101,6 +102,7 @@ jobs: uses: actions/checkout@v3 with: submodules: recursive + - run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow - name: Setup MSVC Developer Command Prompt uses: ilammy/msvc-dev-cmd@v1 - uses: lukka/get-cmake@latest diff --git a/.github/workflows/daily_integration.yml b/.github/workflows/daily_integration.yml index c4f187144..986e10923 100644 --- a/.github/workflows/daily_integration.yml +++ b/.github/workflows/daily_integration.yml @@ -18,6 +18,7 @@ jobs: - uses: actions/checkout@v3 with: submodules: recursive + - run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow - name: Update VW to latest shell: bash run: | diff --git a/.github/workflows/run_benchmarks.yml b/.github/workflows/run_benchmarks.yml index 030d8587b..075e2d1fb 100644 --- a/.github/workflows/run_benchmarks.yml +++ b/.github/workflows/run_benchmarks.yml @@ -23,6 +23,7 @@ jobs: - uses: actions/checkout@v3 with: submodules: recursive + - run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow - uses: lukka/get-cmake@latest - run: echo "VCPKG_COMMIT=$(git rev-parse :ext_libs/vcpkg)" >> $GITHUB_ENV shell: bash diff --git a/.github/workflows/vcpkg_build.yml b/.github/workflows/vcpkg_build.yml index fbc510a20..4fb2fc27a 100644 --- a/.github/workflows/vcpkg_build.yml +++ b/.github/workflows/vcpkg_build.yml @@ -29,6 +29,7 @@ jobs: - uses: actions/checkout@v3 with: submodules: recursive + - run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow - uses: lukka/get-cmake@latest - run: echo "VCPKG_COMMIT=$(git rev-parse :ext_libs/vcpkg)" >> $GITHUB_ENV shell: bash diff --git a/CMakeLists.txt b/CMakeLists.txt index 8ee51d28b..a0aa41108 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,6 +33,10 @@ if(RL_BUILD_BENCHMARKS) list(APPEND VCPKG_MANIFEST_FEATURES "benchmarks") endif() +option(RL_LINK_AZURE_LIBS "Whether to build components requiring the use of Azure libraries. Requires C++14 or greater" OFF) +if(RL_LINK_AZURE_LIBS) + list(APPEND VCPKG_MANIFEST_FEATURES "azurelibs") +endif() project(reinforcement_learning) # Add support for building library with latest version of C++ supported by your compiler diff --git a/examples/rl_sim_cpp/CMakeLists.txt b/examples/rl_sim_cpp/CMakeLists.txt index 5bf816e2b..39a3ce35e 100644 --- a/examples/rl_sim_cpp/CMakeLists.txt +++ b/examples/rl_sim_cpp/CMakeLists.txt @@ -1,8 +1,25 @@ -add_executable(rl_sim_cpp.out +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_LIST_DIR}/../../cmake/Modules/") + +set(RL_SIM_SOURCES main.cc person.cc robot_joint.cc rl_sim.cc ) +if(RL_LINK_AZURE_LIBS) + list(APPEND RL_SIM_SOURCES + azure_credentials.cc + ) +endif() + +add_executable(rl_sim_cpp.out + ${RL_SIM_SOURCES} +) target_link_libraries(rl_sim_cpp.out PRIVATE Boost::program_options rlclientlib) + +if(RL_LINK_AZURE_LIBS) + target_compile_definitions(rl_sim_cpp.out PRIVATE LINK_AZURE_LIBS) + find_package(azure-identity-cpp CONFIG REQUIRED) + target_link_libraries(rl_sim_cpp.out PRIVATE Azure::azure-identity) +endif() diff --git a/examples/rl_sim_cpp/azure_credentials.cc b/examples/rl_sim_cpp/azure_credentials.cc new file mode 100644 index 000000000..a2affffd1 --- /dev/null +++ b/examples/rl_sim_cpp/azure_credentials.cc @@ -0,0 +1,67 @@ +#ifdef LINK_AZURE_LIBS +# include "azure_credentials.h" + +# include "err_constants.h" +# include "future_compat.h" + +# include +# include +// These are needed because azure does a bad time conversion +# include +# include +# include +# include + +using namespace reinforcement_learning; + +AzureCredentials::AzureCredentials(const std::string& tenant_id) : _tenant_id(tenant_id), _creds(create_options()) {} + +Azure::Identity::AzureCliCredentialOptions AzureCredentials::create_options() +{ + Azure::Identity::AzureCliCredentialOptions options; + options.TenantId = _tenant_id; + options.AdditionallyAllowedTenants.push_back("*"); + return options; +} + +int AzureCredentials::get_credentials( + const std::vector& scopes, std::string& token_out, std::chrono::system_clock::time_point& expiry_out) +{ +# ifdef HAS_STD14 + Azure::Core::Credentials::TokenRequestContext request_context; + request_context.Scopes = scopes; + // TODO: needed? + request_context.TenantId = _tenant_id; + Azure::Core::Context context; + try + { + auto auth = _creds.GetToken(request_context, context); + token_out = auth.Token; + + // Casting from an azure DateTime object to a time_point does the calculation + // incorrectly. The expiration is returned as a local time, but the library + // assumes that it is GMT, and converts the value incorrectly. + // See: https://github.com/Azure/azure-sdk-for-cpp/issues/5075 + // expiry_out = static_cast(auth.ExpiresOn); + std::string dt_string = auth.ExpiresOn.ToString(); + std::tm tm = {}; + std::istringstream ss(dt_string); + ss >> std::get_time(&tm, "%Y-%m-%dT%H:%M:%SZ"); + expiry_out = std::chrono::system_clock::from_time_t(std::mktime(&tm)); + } + catch (std::exception& e) + { + std::cout << "Error getting auth token: " << e.what(); + return error_code::external_error; + } + catch (...) + { + std::cout << "Unknown error while getting auth token"; + return error_code::external_error; + } +# else +# error Requires C++14 or greater +# endif + return error_code::success; +} +#endif \ No newline at end of file diff --git a/examples/rl_sim_cpp/azure_credentials.h b/examples/rl_sim_cpp/azure_credentials.h new file mode 100644 index 000000000..30c2256e0 --- /dev/null +++ b/examples/rl_sim_cpp/azure_credentials.h @@ -0,0 +1,30 @@ +#pragma once + +#ifdef LINK_AZURE_LIBS +# include "api_status.h" +# include "configuration.h" +# include "future_compat.h" + +# include +# include +# include +# include +# include + +class AzureCredentials +{ +public: + AzureCredentials(const std::string& tenant_id); + int get_credentials(const std::vector& scopes, std::string& token_out, + std::chrono::system_clock::time_point& expiry_out); + +private: + std::string _tenant_id; +# ifdef HAS_STD14 + Azure::Identity::AzureCliCredentialOptions create_options(); + + // Azure::Identity::DefaultAzureCredential _creds; + Azure::Identity::AzureCliCredential _creds; +# endif +}; +#endif \ No newline at end of file diff --git a/examples/rl_sim_cpp/main.cc b/examples/rl_sim_cpp/main.cc index 5852fbaef..47eb72e1c 100644 --- a/examples/rl_sim_cpp/main.cc +++ b/examples/rl_sim_cpp/main.cc @@ -38,7 +38,9 @@ po::variables_map process_cmd_line(const int argc, char** argv) "random_seed", po::value()->default_value(rand()), "Random seed. Default is random")( "delay", po::value()->default_value(2000), "Delay between events in ms")( "quiet", po::bool_switch(), "Suppress logs")("random_ids", po::value()->default_value(true), - "Use randomly generated Event IDs. Default is true")("throughput", "print throughput stats"); + "Use randomly generated Event IDs. Default is true")("throughput", "print throughput stats")( + "azure_oauth_factories", po::value()->default_value(false), "Use oauth for azure factores. Default false")( + "azure_tenant_id", po::value()->default_value(""), "Tenant ID for use with azure oauth factories."); po::variables_map vm; store(parse_command_line(argc, argv, desc), vm); diff --git a/examples/rl_sim_cpp/rl_sim.cc b/examples/rl_sim_cpp/rl_sim.cc index 724de55a5..0cb4be798 100644 --- a/examples/rl_sim_cpp/rl_sim.cc +++ b/examples/rl_sim_cpp/rl_sim.cc @@ -1,6 +1,7 @@ #include "api_status.h" #include "constants.h" #include "factory_resolver.h" +#include "future_compat.h" #include "live_model.h" #include "multistep.h" #include "person.h" @@ -13,6 +14,7 @@ #include #include #include +#include #include using namespace std; @@ -487,6 +489,17 @@ int rl_sim::init_rl() wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::EPISODE_HTTP_API_SENDER)); sender_factory = &factory; } + // probably incompatible with the throughput option? + else if (_options["azure_oauth_factories"].as()) + { +#ifdef LINK_AZURE_LIBS + // Note: This requires C++14 or better + using namespace std::placeholders; + reinforcement_learning::oauth_callback_t callback = + std::bind(&AzureCredentials::get_credentials, &_creds, _1, _2, _3); + reinforcement_learning::register_default_factories_callback(callback); +#endif + } // Initialize the API _rl = std::unique_ptr(new r::live_model(config, _on_error, this, @@ -637,7 +650,12 @@ std::string rl_sim::create_event_id() return oss.str(); } -rl_sim::rl_sim(boost::program_options::variables_map vm) : _options(std::move(vm)), _loop_kind(CB) +rl_sim::rl_sim(boost::program_options::variables_map vm) + : _options(std::move(vm)) + , _loop_kind(CB) +#ifdef LINK_AZURE_LIBS + , _creds(_options["azure_tenant_id"].as()) +#endif { if (_options["ccb"].as()) { _loop_kind = CCB; } else if (_options["slates"].as()) { _loop_kind = Slates; } @@ -699,4 +717,4 @@ std::string get_dist_str(const reinforcement_learning::decision_response& respon } ret += ")"; return ret; -} +} \ No newline at end of file diff --git a/examples/rl_sim_cpp/rl_sim.h b/examples/rl_sim_cpp/rl_sim.h index 8e583e62d..5cd68bd92 100644 --- a/examples/rl_sim_cpp/rl_sim.h +++ b/examples/rl_sim_cpp/rl_sim.h @@ -6,6 +6,7 @@ * @date 2018-07-18 */ #pragma once +#include "azure_credentials.h" #include "live_model.h" #include "person.h" #include "robot_joint.h" @@ -177,4 +178,7 @@ class rl_sim int64_t _delay = 2000; bool _quiet = false; bool _random_ids = true; +#ifdef LINK_AZURE_LIBS + AzureCredentials _creds; +#endif }; diff --git a/include/constants.h b/include/constants.h index 29c770e8b..c4f3c632e 100644 --- a/include/constants.h +++ b/include/constants.h @@ -26,6 +26,7 @@ const char* const LEARNING_MODE = "rank.learning.mode"; const char* const PROTOCOL_VERSION = "protocol.version"; const char* const HTTP_API_KEY = "http.api.key"; const char* const HTTP_API_HEADER_KEY_NAME = "http.api.header.key.name"; +const char* const HTTP_API_OAUTH_TOKEN_TYPE = "http.api.oauth.token.type"; const char* const AUDIT_ENABLED = "audit.enabled"; const char* const AUDIT_OUTPUT_PATH = "audit.output.path"; @@ -118,6 +119,7 @@ const char* const AZURE_STORAGE_BLOB = "AZURE_STORAGE_BLOB"; const char* const NO_MODEL_DATA = "NO_MODEL_DATA"; const char* const HTTP_MODEL_DATA = "HTTP_MODEL_DATA"; const char* const FILE_MODEL_DATA = "FILE_MODEL_DATA"; +const char* const HTTP_MODEL_DATA_OAUTH = "HTTP_MODEL_DATA_OAUTH"; const char* const VW = "VW"; const char* const PASSTHROUGH_PDF_MODEL = "PASSTHROUGH_PDF"; const char* const EPISODE_EH_SENDER = "EPISODE_EH_SENDER"; @@ -129,6 +131,9 @@ const char* const INTERACTION_FILE_SENDER = "INTERACTION_FILE_SENDER"; const char* const EPISODE_HTTP_API_SENDER = "EPISODE_HTTP_API_SENDER"; const char* const OBSERVATION_HTTP_API_SENDER = "OBSERVATION_HTTP_API_SENDER"; const char* const INTERACTION_HTTP_API_SENDER = "INTERACTION_HTTP_API_SENDER"; +const char* const EPISODE_HTTP_API_SENDER_OAUTH = "EPISODE_HTTP_API_SENDER_OAUTH"; +const char* const OBSERVATION_HTTP_API_SENDER_OAUTH = "OBSERVATION_HTTP_API_SENDER_OAUTH"; +const char* const INTERACTION_HTTP_API_SENDER_OAUTH = "INTERACTION_HTTP_API_SENDER_OAUTH"; const char* const NULL_TRACE_LOGGER = "NULL_TRACE_LOGGER"; const char* const CONSOLE_TRACE_LOGGER = "CONSOLE_TRACE_LOGGER"; const char* const NULL_TIME_PROVIDER = "NULL_TIME_PROVIDER"; @@ -139,6 +144,7 @@ const char* const LEARNING_MODE_LOGGINGONLY = "LOGGINGONLY"; const char* const CONTENT_ENCODING_IDENTITY = "IDENTITY"; const char* const CONTENT_ENCODING_DEDUP = "DEDUP"; const char* const HTTP_API_DEFAULT_HEADER_KEY_NAME = "Ocp-Apim-Subscription-Key"; +const char* const HTTP_API_DEFAULT_OAUTH_TOKEN_TYPE = "Bearer"; const char* const TRACE_LOG_LEVEL_DEFAULT = "info"; const char* const QUEUE_MODE_DROP = "DROP"; diff --git a/include/factory_resolver.h b/include/factory_resolver.h index bd8650896..a4cae3316 100644 --- a/include/factory_resolver.h +++ b/include/factory_resolver.h @@ -1,5 +1,9 @@ #pragma once +#include "oauth_callback_fn.h" #include "object_factory.h" + +#include +#include namespace reinforcement_learning { namespace utility @@ -70,4 +74,11 @@ struct factory_initializer // Every translation unit gets a factory_initializer // only one translation unit will initialize it static factory_initializer _init; + +// no-op if USE_AZURE_FACTORIES is not defined +/** + * @brief Register default factories with an authentication callback + */ +void register_default_factories_callback(oauth_callback_t& callback); + } // namespace reinforcement_learning diff --git a/include/oauth_callback_fn.h b/include/oauth_callback_fn.h new file mode 100644 index 000000000..80fdaddaa --- /dev/null +++ b/include/oauth_callback_fn.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include +#include +#include + +namespace reinforcement_learning +{ +using oauth_callback_t = + std::function&, std::string&, std::chrono::system_clock::time_point&)>; +} \ No newline at end of file diff --git a/rlclientlib/CMakeLists.txt b/rlclientlib/CMakeLists.txt index 29f405cdf..c7b3236b5 100644 --- a/rlclientlib/CMakeLists.txt +++ b/rlclientlib/CMakeLists.txt @@ -107,6 +107,7 @@ if(vw_USE_AZURE_FACTORIES) list(APPEND PROJECT_SOURCES azure_factories.cc model_mgmt/restapi_data_transport.cc + model_mgmt/restapi_data_transport_oauth.cc utility/eventhub_http_authorization.cc utility/header_authorization.cc utility/http_client.cc @@ -141,6 +142,7 @@ set(PROJECT_PUBLIC_HEADERS ../include/multi_slot_response_detailed.h ../include/multistep.h ../include/object_factory.h + ../include/oauth_callback_fn.h ../include/personalization.h ../include/ranking_response.h ../include/rl_string_view.h @@ -186,10 +188,12 @@ if(vw_USE_AZURE_FACTORIES) azure_factories.h logger/http_transport_client.h model_mgmt/restapi_data_transport.h + model_mgmt/restapi_data_transport_oauth.h utility/eventhub_http_authorization.h utility/header_authorization.h utility/http_client.h utility/http_helper.h + utility/api_header_token.h ) endif() @@ -209,7 +213,6 @@ if(vw_USE_AZURE_FACTORIES) target_compile_definitions(rlclientlib PRIVATE USE_AZURE_FACTORIES) endif() - if(RL_USE_ZSTD) target_compile_definitions(rlclientlib PRIVATE USE_ZSTD) target_link_libraries(rlclientlib PRIVATE libzstd_static) diff --git a/rlclientlib/azure_factories.cc b/rlclientlib/azure_factories.cc index 687fcd53b..86a981b09 100644 --- a/rlclientlib/azure_factories.cc +++ b/rlclientlib/azure_factories.cc @@ -5,10 +5,14 @@ #include "logger/event_logger.h" #include "logger/http_transport_client.h" #include "model_mgmt/restapi_data_transport.h" +#include "model_mgmt/restapi_data_transport_oauth.h" +#include "utility/api_header_token.h" #include "utility/eventhub_http_authorization.h" #include "utility/header_authorization.h" #include "utility/http_helper.h" +#include + namespace reinforcement_learning { namespace m = model_management; @@ -33,18 +37,46 @@ int observation_api_sender_create(std::unique_ptr& retval, const u::co int interaction_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); +int oauth_restapi_data_transport_create(oauth_callback_t& callback, std::unique_ptr& retval, + const u::configuration& config, i_trace* trace_logger, api_status* status); +int episode_api_sender_oauth_create(oauth_callback_t& callback, std::unique_ptr& retval, + const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); +int observation_api_sender_oauth_create(oauth_callback_t& callback, std::unique_ptr& retval, + const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); +int interaction_api_sender_oauth_create(oauth_callback_t& callback, std::unique_ptr& retval, + const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); + void register_azure_factories() { data_transport_factory.register_type(value::AZURE_STORAGE_BLOB, restapi_data_transport_create); - data_transport_factory.register_type(value::HTTP_MODEL_DATA, authenticated_restapi_data_transport_create); sender_factory.register_type(value::OBSERVATION_EH_SENDER, observation_sender_create); sender_factory.register_type(value::INTERACTION_EH_SENDER, interaction_sender_create); sender_factory.register_type(value::EPISODE_EH_SENDER, episode_sender_create); + + // These functions need to have 2 nearly identical versions. One will use the standard + // header_authorization, which uses a hard coded key in the config + // The other will use the new OAUTH callback interface to generate its keys. + // The latter will require manual registration by the user to get the necessary callback + data_transport_factory.register_type(value::HTTP_MODEL_DATA, authenticated_restapi_data_transport_create); sender_factory.register_type(value::OBSERVATION_HTTP_API_SENDER, observation_api_sender_create); sender_factory.register_type(value::INTERACTION_HTTP_API_SENDER, interaction_api_sender_create); sender_factory.register_type(value::EPISODE_HTTP_API_SENDER, episode_api_sender_create); } +void register_azure_oauth_factories(oauth_callback_t& callback) +{ + // TODO: bind functions? + using namespace std::placeholders; + data_transport_factory.register_type( + value::HTTP_MODEL_DATA_OAUTH, std::bind(oauth_restapi_data_transport_create, callback, _1, _2, _3, _4)); + sender_factory.register_type(value::OBSERVATION_HTTP_API_SENDER_OAUTH, + std::bind(observation_api_sender_oauth_create, callback, _1, _2, _3, _4, _5)); + sender_factory.register_type(value::INTERACTION_HTTP_API_SENDER_OAUTH, + std::bind(interaction_api_sender_oauth_create, callback, _1, _2, _3, _4, _5)); + sender_factory.register_type( + value::EPISODE_HTTP_API_SENDER_OAUTH, std::bind(episode_api_sender_oauth_create, callback, _1, _2, _3, _4, _5)); +} + int restapi_data_transport_create(std::unique_ptr& retval, const u::configuration& config, i_trace* trace_logger, api_status* status) { @@ -101,6 +133,19 @@ int create_apim_http_api_sender(std::unique_ptr& retval, const u::conf return error_code::success; } +int create_apim_http_api_oauth_sender(oauth_callback_t& callback, std::unique_ptr& retval, + const u::configuration& cfg, const char* api_host, int tasks_limit, int max_http_retries, + std::chrono::milliseconds max_http_retry_duration, error_callback_fn* error_cb, i_trace* trace_logger, + api_status* status) +{ + i_http_client* client = nullptr; + RETURN_IF_FAIL(create_http_client(api_host, cfg, &client, status)); + retval.reset( + new http_transport_client>(client, tasks_limit, max_http_retries, + max_http_retry_duration, trace_logger, error_cb, callback, "https://eventhubs.azure.net//.default")); + return error_code::success; +} + // Creates i_sender object for sending episode data to the apim endpoint. int episode_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) @@ -165,4 +210,47 @@ int interaction_sender_create(std::unique_ptr& retval, const u::config error_cb)); return error_code::success; } + +int oauth_restapi_data_transport_create(oauth_callback_t& callback, std::unique_ptr& retval, + const u::configuration& config, i_trace* trace_logger, api_status* status) +{ + const auto* model_uri = config.get(name::MODEL_BLOB_URI, nullptr); + if (model_uri == nullptr) { RETURN_ERROR(trace_logger, status, http_model_uri_not_provided); } + i_http_client* client = nullptr; + RETURN_IF_FAIL(create_http_client(model_uri, config, &client, status)); + retval.reset(new m::restapi_data_transport_oauth(std::unique_ptr(client), config, + m::model_source::HTTP_API, trace_logger, callback, "https://storage.azure.com//.default")); + return error_code::success; +} + +int episode_api_sender_oauth_create(oauth_callback_t& callback, std::unique_ptr& retval, + const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) +{ + const auto* const api_host = cfg.get(name::EPISODE_HTTP_API_HOST, "localhost:8080"); + return create_apim_http_api_oauth_sender(callback, retval, cfg, api_host, + cfg.get_int(name::EPISODE_APIM_MAX_HTTP_RETRIES, 4), cfg.get_int(name::EPISODE_APIM_TASKS_LIMIT, 4), + std::chrono::milliseconds(cfg.get_int(name::EPISODE_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, + trace_logger, status); +} + +int observation_api_sender_oauth_create(oauth_callback_t& callback, std::unique_ptr& retval, + const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) +{ + const auto* const api_host = cfg.get(name::OBSERVATION_HTTP_API_HOST, "localhost:8080"); + return create_apim_http_api_oauth_sender(callback, retval, cfg, api_host, + cfg.get_int(name::OBSERVATION_APIM_TASKS_LIMIT, 16), cfg.get_int(name::OBSERVATION_APIM_MAX_HTTP_RETRIES, 4), + std::chrono::milliseconds(cfg.get_int(name::OBSERVATION_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, + trace_logger, status); +} + +int interaction_api_sender_oauth_create(oauth_callback_t& callback, std::unique_ptr& retval, + const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) +{ + const auto* const api_host = cfg.get(name::INTERACTION_HTTP_API_HOST, "localhost:8080"); + return create_apim_http_api_oauth_sender(callback, retval, cfg, api_host, + cfg.get_int(name::INTERACTION_APIM_TASKS_LIMIT, 16), cfg.get_int(name::INTERACTION_APIM_MAX_HTTP_RETRIES, 4), + std::chrono::milliseconds(cfg.get_int(name::INTERACTION_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, + trace_logger, status); +} + } // namespace reinforcement_learning diff --git a/rlclientlib/azure_factories.h b/rlclientlib/azure_factories.h index 1d500b07d..b8d61f5a8 100644 --- a/rlclientlib/azure_factories.h +++ b/rlclientlib/azure_factories.h @@ -1,6 +1,10 @@ #pragma once +#include "oauth_callback_fn.h" + namespace reinforcement_learning { void register_azure_factories(); + +void register_azure_oauth_factories(oauth_callback_t& callback); } diff --git a/rlclientlib/factory_resolver.cc b/rlclientlib/factory_resolver.cc index cb6880a24..e66d45013 100644 --- a/rlclientlib/factory_resolver.cc +++ b/rlclientlib/factory_resolver.cc @@ -72,6 +72,13 @@ factory_initializer::~factory_initializer() } } +void register_default_factories_callback(oauth_callback_t& callback) +{ +#ifdef USE_AZURE_FACTORIES + register_azure_oauth_factories(callback); +#endif +} + template int model_create( std::unique_ptr& retval, const u::configuration& c, i_trace* trace_logger, api_status* status) diff --git a/rlclientlib/logger/http_transport_client.h b/rlclientlib/logger/http_transport_client.h index d0baab533..58fa13a34 100644 --- a/rlclientlib/logger/http_transport_client.h +++ b/rlclientlib/logger/http_transport_client.h @@ -40,8 +40,9 @@ class http_transport_client : public i_sender virtual int init(const utility::configuration& config, api_status* status) override; // Takes the ownership of the i_http_client and delete it at the end of lifetime + template http_transport_client(i_http_client* client, size_t tasks_count, size_t MAX_RETRIES, - std::chrono::milliseconds max_retry_duration, i_trace* trace, error_callback_fn* _error_cb); + std::chrono::milliseconds max_retry_duration, i_trace* trace, error_callback_fn* _error_cb, Args&&... args); ~http_transport_client(); protected: @@ -162,6 +163,7 @@ pplx::task http_transport_client::http_request_ta utility::stl_container_adapter container(_post_data.get()); const size_t container_size = container.size(); + const auto stream = concurrency::streams::bytestream::open_istream(container); request.set_body(stream, container_size); @@ -278,14 +280,17 @@ int http_transport_client::v_send(const buffer& post_data, api_s } template +template http_transport_client::http_transport_client(i_http_client* client, size_t max_tasks_count, - size_t max_retries, std::chrono::milliseconds max_retry_duration, i_trace* trace, error_callback_fn* error_callback) + size_t max_retries, std::chrono::milliseconds max_retry_duration, i_trace* trace, error_callback_fn* error_callback, + Args&&... args) : _client(client) , _max_tasks_count(max_tasks_count) , _max_retry_count(max_retries) , _max_retry_duration(max_retry_duration) , _trace(trace) , _error_callback(error_callback) + , _authorization(std::forward(args)...) { } diff --git a/rlclientlib/model_mgmt/restapi_data_transport_oauth.cc b/rlclientlib/model_mgmt/restapi_data_transport_oauth.cc new file mode 100644 index 000000000..35ffb74d7 --- /dev/null +++ b/rlclientlib/model_mgmt/restapi_data_transport_oauth.cc @@ -0,0 +1,204 @@ +#include "restapi_data_transport_oauth.h" + +#include "api_status.h" +#include "factory_resolver.h" +#include "trace_logger.h" +#include "utility/api_header_token.h" + +#include +#include +#include + +#include + +using namespace web; // Common features like URIs. +using namespace web::http; // Common HTTP functionality +using namespace std::chrono; + +namespace u = reinforcement_learning::utility; +namespace e = reinforcement_learning::error_code; + +namespace reinforcement_learning +{ +namespace model_management +{ +restapi_data_transport_oauth::restapi_data_transport_oauth( + i_http_client* httpcli, i_trace* trace, oauth_callback_t& callback, std::string scope) + : _httpcli(httpcli), _datasz{0}, _trace{trace}, _headerimpl(callback, std::move(scope)) +{ +} +restapi_data_transport_oauth::restapi_data_transport_oauth(std::unique_ptr&& httpcli, + utility::configuration cfg, model_source model_source, i_trace* trace, oauth_callback_t& callback, + std::string scope) + : _httpcli(std::move(httpcli)) + , _cfg(std::move(cfg)) + , _model_source(model_source) + , _datasz{0} + , _trace{trace} + , _headerimpl(callback, std::move(scope)) +{ +} + +/* + * Example successful response + * + * Received response status code:200 + * Accept-Ranges = bytes + * Content-Length = 7666 + * Content-MD5 = VuJg8VgcBQwevGhJR2Yehw== + * Content-Type = application/octet-stream + * Date = Mon, 28 May 2018 14:41:02 GMT + * ETag = "0x8D5C03A2AEC2189" + * Last-Modified = Tue, 22 May 2018 23:17:20 GMT + * Server = Windows-Azure-Blob/1.0 Microsoft-HTTPAPI/2.0 + * x-ms-blob-type = BlockBlob + * x-ms-lease-state = available + * x-ms-lease-status = unlocked + * x-ms-request-id = 241f3513-801e-0041-0991-f6893e000000 + * x-ms-server-encrypted = true + * x-ms-version = 2017-04-17 + */ + +int restapi_data_transport_oauth::get_data_info( + ::utility::datetime& last_modified, ::utility::size64_t& sz, api_status* status) +{ + // Get request URI and start the request. + http_request request(_method_type); + RETURN_IF_FAIL(add_authentication_header(request.headers(), status)); + // Build request URI and start the request. + auto request_task = _httpcli->request(request).then( + [&](http_response response) + { + if (response.status_code() != 200) + { + // if the call using HEAD fails, try with GET only once and return the results of GET request call + if (_retry_get_data) + { + _retry_get_data = false; + _method_type = methods::GET; + RETURN_IF_FAIL(get_data_info(last_modified, sz, status)); + return error_code::success; + } + + RETURN_ERROR_ARG( + _trace, status, http_bad_status_code, "Found: ", response.status_code(), _httpcli->get_url()); + } + const auto iter = response.headers().find(U("Last-Modified")); + if (iter == response.headers().end()) + { + RETURN_ERROR_ARG(_trace, status, last_modified_not_found, _httpcli->get_url()); + } + + last_modified = ::utility::datetime::from_string(iter->second); + if (last_modified.to_interval() == 0) + { + RETURN_ERROR_ARG(_trace, status, last_modified_invalid, _httpcli->get_url()); + } + + sz = response.headers().content_length(); + + return error_code::success; + }); + + // Wait for all the outstanding I/O to complete and handle any exceptions + try + { + return request_task.get(); + } + catch (const std::exception& e) + { + RETURN_ERROR_LS(_trace, status, exception_during_http_req) << e.what() << "\n URL: " << _httpcli->get_url(); + } +} + +int restapi_data_transport_oauth::add_authentication_header(http_headers& header, api_status* status) +{ + if (_model_source != model_source::AZURE) + { + RETURN_IF_FAIL(_headerimpl.init(_cfg, status, _trace)); + RETURN_IF_FAIL(_headerimpl.insert_authorization_header(header, status, _trace)); + } + return error_code::success; +} + +int restapi_data_transport_oauth::get_data(model_data& ret, api_status* status) +{ + ::utility::datetime curr_last_modified; + ::utility::size64_t curr_datasz = 0; + _method_type = methods::HEAD; + _retry_get_data = true; + RETURN_IF_FAIL(get_data_info(curr_last_modified, curr_datasz, status)); + + if (curr_last_modified == _last_modified && curr_datasz == _datasz) { return error_code::success; } + _method_type = methods::GET; + http_request request(_method_type); + RETURN_IF_FAIL(add_authentication_header(request.headers(), status)); + // Build request URI and start the request. + auto request_task = + _httpcli + ->request(request) + // Handle response headers arriving. + .then( + [&](const pplx::task& resp_task) + { + auto response = resp_task.get(); + if (response.status_code() != 200) + { + RETURN_ERROR_ARG( + _trace, status, http_bad_status_code, "Found: ", response.status_code(), _httpcli->get_url()); + } + + const auto iter = response.headers().find(U("Last-Modified")); + if (iter == response.headers().end()) + { + RETURN_ERROR_ARG(_trace, status, last_modified_not_found, _httpcli->get_url()); + } + + curr_last_modified = ::utility::datetime::from_string(iter->second); + if (curr_last_modified.to_interval() == 0) + { + RETURN_ERROR_ARG(_trace, status, last_modified_invalid, + "Found: ", ::utility::conversions::to_utf8string(curr_last_modified.to_string()), + _httpcli->get_url()); + } + + curr_datasz = response.headers().content_length(); + if (curr_datasz > 0) + { + auto* const buff = ret.alloc(curr_datasz); + const Concurrency::streams::rawptr_buffer rb(buff, curr_datasz, std::ios::out); + + // Write response body into the file. + const auto readval = + response.body().read_to_end(rb).get(); // need to use task.get to throw exceptions properly + + ret.data_sz(readval); + ret.increment_refresh_count(); + _datasz = readval; + } + else { ret.data_sz(0); } + + _last_modified = curr_last_modified; + return error_code::success; + }); + + // Wait for all the outstanding I/O to complete and handle any exceptions + try + { + request_task.wait(); + } + catch (const std::exception& e) + { + ret.free(); + RETURN_ERROR_LS(_trace, status, exception_during_http_req) << e.what(); + } + catch (...) + { + ret.free(); + RETURN_ERROR_LS(_trace, status, exception_during_http_req) << error_code::unknown_s; + } + + return request_task.get(); +} +} // namespace model_management +} // namespace reinforcement_learning diff --git a/rlclientlib/model_mgmt/restapi_data_transport_oauth.h b/rlclientlib/model_mgmt/restapi_data_transport_oauth.h new file mode 100644 index 000000000..02c05caf3 --- /dev/null +++ b/rlclientlib/model_mgmt/restapi_data_transport_oauth.h @@ -0,0 +1,45 @@ +#pragma once +#include "model_mgmt.h" +#include "oauth_callback_fn.h" +#include "utility/api_header_token.h" +#include "utility/http_client.h" + +#include + +#include +#include + +// TODO: This is basically just a copy/paste of restapi_data_transport +// We could templatize that object similar to how http_transport_client works, +// but there's a lot of code that would need to be shifted to the header +namespace reinforcement_learning +{ +class i_trace; +namespace model_management +{ +class restapi_data_transport_oauth : public i_data_transport +{ +public: + // Takes the ownership of the i_http_client and delete it at the end of lifetime + restapi_data_transport_oauth(i_http_client* httpcli, i_trace* trace, oauth_callback_t& callback, std::string scope); + restapi_data_transport_oauth(std::unique_ptr&& httpcli, utility::configuration cfg, + model_source model_source, i_trace* trace, oauth_callback_t& callback, std::string scope); + + int get_data(model_data& ret, api_status* status) override; + +private: + using time_t = std::chrono::time_point; + int get_data_info(::utility::datetime& last_modified, ::utility::size64_t& sz, api_status* status); + int add_authentication_header(http_headers& header, api_status* status); + std::unique_ptr _httpcli; + ::utility::datetime _last_modified; + uint64_t _datasz; + i_trace* _trace; + const utility::configuration _cfg; + model_source _model_source = model_source::AZURE; + method _method_type = methods::HEAD; + bool _retry_get_data = true; + api_header_token_callback _headerimpl; +}; +} // namespace model_management +} // namespace reinforcement_learning diff --git a/rlclientlib/utility/api_header_token.h b/rlclientlib/utility/api_header_token.h new file mode 100644 index 000000000..82ff74491 --- /dev/null +++ b/rlclientlib/utility/api_header_token.h @@ -0,0 +1,134 @@ +#pragma once + +#include "api_status.h" +#include "configuration.h" +#include "constants.h" +#include "oauth_callback_fn.h" +#include "trace_logger.h" + +#include + +#include +#include +#include +#include + +using namespace web::http; + +namespace reinforcement_learning +{ +class eventhub_headers +{ +public: + void insert_additional_headers(http_headers& headers) + { + http_headers::key_type content_type; +#ifdef _WIN32 + content_type = ::utility::conversions::utf8_to_utf16("Content-Type"); +#else + content_type = "Content-Type"; +#endif + + headers.add(content_type, "application/atom+xml;type=entry;charset=utf-8"); + } +}; +class blob_storage_headers +{ +public: + void insert_additional_headers(http_headers& headers) + { + http_headers::key_type version; +#ifdef _WIN32 + version = ::utility::conversions::utf8_to_utf16("x-ms-version"); +#else + version = "x-ms-version"; +#endif + + // For version, see https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-azure-active-directory + headers.add(version, "2017-11-09"); + } +}; + +template +class api_header_token_callback +{ +public: + api_header_token_callback(oauth_callback_t& token_cb, std::string scope) + : _token_callback(token_cb), _scopes{std::move(scope)} + { + } + ~api_header_token_callback() = default; + + int init(const utility::configuration& config, api_status* status, i_trace* trace) + { + // The transport client calls init and insert_header on every message + // Just bail out if we've already been initialized + if (_initialized) { return error_code::success; } + +#ifdef _WIN32 + _http_api_header_key_name = ::utility::conversions::utf8_to_utf16( + config.get(name::HTTP_API_HEADER_KEY_NAME, value::HTTP_API_DEFAULT_HEADER_KEY_NAME)); +#else + _http_api_header_key_name = config.get(name::HTTP_API_HEADER_KEY_NAME, value::HTTP_API_DEFAULT_HEADER_KEY_NAME); +#endif + _token_type = config.get(name::HTTP_API_OAUTH_TOKEN_TYPE, value::HTTP_API_DEFAULT_OAUTH_TOKEN_TYPE); + RETURN_IF_FAIL(refresh_auth_token(status, trace)); + _initialized = true; + return error_code::success; + } + + int insert_authorization_header(http_headers& headers, api_status* status, i_trace* trace) + { + if (!_initialized) + { + int result = error_code::not_initialized; + api_status::try_update(status, result, error_code::not_initialized_s); + return result; + } + using namespace std::chrono; + system_clock::time_point now = system_clock::now(); + // TODO: make this configurable? + system_clock::time_point refresh_time = _token_expiry - std::chrono::seconds(10); + if (now >= refresh_time) { RETURN_IF_FAIL(refresh_auth_token(status, trace)); } + std::string header_value = _token_type + " " + _bearer_token; + headers.add(_http_api_header_key_name, header_value.c_str()); + _additional_headers.insert_additional_headers(headers); + + return error_code::success; + } + + api_header_token_callback(const api_header_token_callback&) = delete; + api_header_token_callback(api_header_token_callback&&) = delete; + api_header_token_callback& operator=(const api_header_token_callback&) = delete; + api_header_token_callback& operator=(api_header_token_callback&&) = delete; + +private: + int refresh_auth_token(api_status* status, i_trace* trace) + { + using namespace std::chrono; + system_clock::time_point tp; + RETURN_IF_FAIL(_token_callback(_scopes, _bearer_token, _token_expiry)); + + if (_bearer_token.empty()) + { + int result = error_code::external_error; + api_status::try_update(status, result, error_code::external_error_s); + return result; + } + + if (_bearer_token.empty()) { RETURN_ERROR(trace, status, http_api_key_not_provided); } + return error_code::success; + } + +private: + http_headers::key_type _http_api_header_key_name; + std::string _token_type; + oauth_callback_t _token_callback; + std::vector _scopes; + + std::string _bearer_token; + std::chrono::system_clock::time_point _token_expiry; + bool _initialized = false; + Resource _additional_headers; +}; +} // namespace reinforcement_learning diff --git a/vcpkg.json b/vcpkg.json index 51b42b277..126428b73 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -1,7 +1,8 @@ { "$schema": "https://raw.githubusercontent.com/microsoft/vcpkg/master/scripts/vcpkg.schema.json", "name": "reinforcement-learning", - "version": "1.0.0", + "version": "1.0.2", + "builtin-baseline": "f30434939d5516ce764c549ab04e3d23d312180a", "dependencies": [ "boost-align", "boost-asio", @@ -23,10 +24,21 @@ "spdlog", "zlib" ], + "overrides": [ + {"name": "cpprestsdk", "version": "2.10.18"}, + {"name": "flatbuffers", "version": "23.1.21"}, + {"name": "fmt", "version": "9.1.0"}, + {"name": "spdlog", "version": "1.11.0"}, + {"name": "zlib", "version": "1.2.13"} + ], "features": { "benchmarks": { "description": "Build Benchmarks", "dependencies": [{"name":"benchmark", "version>=":"1.7.1"}] + }, + "azurelibs": { + "description": "Build Azure-specific code", + "dependencies": [{"name":"azure-identity-cpp"}] } } }