Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/azure oauth #604

Merged
merged 26 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a04d24f
Initial commit for Azure Oauth work
peterychang Jan 29, 2024
a93981f
Fixing blob storage requests
peterychang Jan 30, 2024
ebba307
Merge branch 'master' into dev/azure_oauth
peterychang Feb 2, 2024
0a9f167
remove azure dependencies from mainline cmake files
peterychang Feb 2, 2024
e5ccf4f
Add AzureVcpkg to modules
peterychang Feb 2, 2024
b713eda
Fix keytype conversions
peterychang Feb 2, 2024
729160b
fix bad string conversion
peterychang Feb 2, 2024
9bc06eb
Update callback signature. Conditionally compile azure code
peterychang Feb 2, 2024
9d3ddcd
Add cmake option for azure libs, fix compile issues
peterychang Feb 2, 2024
7d19e78
separate azure dependencies from default
peterychang Feb 2, 2024
246697e
fix conditional compile for azure libs
peterychang Feb 5, 2024
1a9212b
rl_sim tenant id as parameter
peterychang Feb 6, 2024
980f523
testing workflow fixes
peterychang Feb 6, 2024
bbf8cc6
remove rapidjson required version for now
peterychang Feb 6, 2024
9235700
update workflows to fetch full vcpkg submodule
peterychang Feb 6, 2024
b59c0fb
run lint
peterychang Feb 6, 2024
5b5de4d
fixing more workflows
peterychang Feb 6, 2024
7c59f60
fixing clang tidy issues
peterychang Feb 6, 2024
81ff98c
more clang tidy fixes
peterychang Feb 6, 2024
f19c3e8
lint
peterychang Feb 6, 2024
b71e34d
fixing compile issues
peterychang Feb 6, 2024
b45d6a9
clang tidy
peterychang Feb 6, 2024
c70dac0
fix github workflow
peterychang Feb 6, 2024
cb7fb61
Remove unnecessary cmake files/commands
peterychang Feb 7, 2024
ec2f546
review comments
peterychang Feb 8, 2024
44260dc
Merge branch 'master' into dev/azure_oauth
peterychang Feb 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/asan.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ jobs:
- uses: actions/checkout@v2
with:
submodules: recursive
- run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow
- name: Setup MSVC Developer Command Prompt
if: ${{ startsWith(matrix.os, 'windows') }}
uses: ilammy/msvc-dev-cmd@v1
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/build_rlclientlib.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: recursive
- run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow
- uses: lukka/get-cmake@latest
- name: Install ONNX
run: |
Expand Down Expand Up @@ -83,6 +84,7 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: recursive
- run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow
- uses: lukka/get-cmake@latest
- run: echo "VCPKG_COMMIT=$(git rev-parse :ext_libs/vcpkg)" >> $GITHUB_ENV
shell: bash
Expand Down Expand Up @@ -124,6 +126,7 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: recursive
- run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow
- name: Setup MSVC Developer Command Prompt
uses: ilammy/msvc-dev-cmd@v1
- uses: lukka/get-cmake@latest
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build_vw_bp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: recursive
- run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow
- name: Setup MSVC Developer Command Prompt
if: ${{ startsWith(matrix.config.os, 'windows') }}
uses: ilammy/msvc-dev-cmd@v1
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/codeql-analysis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ jobs:
uses: actions/checkout@v3
with:
submodules: recursive
- run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow
- uses: lukka/get-cmake@latest
- run: echo "VCPKG_COMMIT=$(git rev-parse :ext_libs/vcpkg)" >> $GITHUB_ENV
shell: bash
Expand Down Expand Up @@ -101,6 +102,7 @@ jobs:
uses: actions/checkout@v3
with:
submodules: recursive
- run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow
- name: Setup MSVC Developer Command Prompt
uses: ilammy/msvc-dev-cmd@v1
- uses: lukka/get-cmake@latest
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/daily_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: recursive
- run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow
- name: Update VW to latest
shell: bash
run: |
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/run_benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: recursive
- run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow
- uses: lukka/get-cmake@latest
- run: echo "VCPKG_COMMIT=$(git rev-parse :ext_libs/vcpkg)" >> $GITHUB_ENV
shell: bash
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/vcpkg_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: recursive
- run: git -C ${{ github.workspace }}/ext_libs/vcpkg fetch --unshallow
- uses: lukka/get-cmake@latest
- run: echo "VCPKG_COMMIT=$(git rev-parse :ext_libs/vcpkg)" >> $GITHUB_ENV
shell: bash
Expand Down
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ if(RL_BUILD_BENCHMARKS)
list(APPEND VCPKG_MANIFEST_FEATURES "benchmarks")
endif()

option(RL_LINK_AZURE_LIBS "Whether to build components requiring the use of Azure libraries. Requires C++14 or greater" OFF)
peterychang marked this conversation as resolved.
Show resolved Hide resolved
if(RL_LINK_AZURE_LIBS)
list(APPEND VCPKG_MANIFEST_FEATURES "azurelibs")
endif()
project(reinforcement_learning)

# Add support for building library with latest version of C++ supported by your compiler
Expand Down
19 changes: 18 additions & 1 deletion examples/rl_sim_cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,25 @@
add_executable(rl_sim_cpp.out
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_LIST_DIR}/../../cmake/Modules/")

set(RL_SIM_SOURCES
main.cc
person.cc
robot_joint.cc
rl_sim.cc
)
if(RL_LINK_AZURE_LIBS)
list(APPEND RL_SIM_SOURCES
azure_credentials.cc
)
endif()

add_executable(rl_sim_cpp.out
${RL_SIM_SOURCES}
)

target_link_libraries(rl_sim_cpp.out PRIVATE Boost::program_options rlclientlib)

if(RL_LINK_AZURE_LIBS)
target_compile_definitions(rl_sim_cpp.out PRIVATE LINK_AZURE_LIBS)
find_package(azure-identity-cpp CONFIG REQUIRED)
target_link_libraries(rl_sim_cpp.out PRIVATE Azure::azure-identity)
endif()
67 changes: 67 additions & 0 deletions examples/rl_sim_cpp/azure_credentials.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#ifdef LINK_AZURE_LIBS
# include "azure_credentials.h"

# include "err_constants.h"
# include "future_compat.h"

# include <azure/core/datetime.hpp>
# include <chrono>
// These are needed because azure does a bad time conversion
# include <exception>
# include <iomanip>
# include <iostream>
# include <sstream>

using namespace reinforcement_learning;

AzureCredentials::AzureCredentials(const std::string& tenant_id) : _tenant_id(tenant_id), _creds(create_options()) {}

Azure::Identity::AzureCliCredentialOptions AzureCredentials::create_options()
{
Azure::Identity::AzureCliCredentialOptions options;
options.TenantId = _tenant_id;
options.AdditionallyAllowedTenants.push_back("*");
return options;
}

int AzureCredentials::get_credentials(
const std::vector<std::string>& scopes, std::string& token_out, std::chrono::system_clock::time_point& expiry_out)
{
# ifdef HAS_STD14
peterychang marked this conversation as resolved.
Show resolved Hide resolved
Azure::Core::Credentials::TokenRequestContext request_context;
request_context.Scopes = scopes;
// TODO: needed?
request_context.TenantId = _tenant_id;
Azure::Core::Context context;
try
{
auto auth = _creds.GetToken(request_context, context);
token_out = auth.Token;

// Casting from an azure DateTime object to a time_point does the calculation
// incorrectly. The expiration is returned as a local time, but the library
// assumes that it is GMT, and converts the value incorrectly.
// See: https://github.com/Azure/azure-sdk-for-cpp/issues/5075
// expiry_out = static_cast<std::chrono::system_clock::time_point>(auth.ExpiresOn);
std::string dt_string = auth.ExpiresOn.ToString();
std::tm tm = {};
std::istringstream ss(dt_string);
ss >> std::get_time(&tm, "%Y-%m-%dT%H:%M:%SZ");
expiry_out = std::chrono::system_clock::from_time_t(std::mktime(&tm));
}
catch (std::exception& e)
{
std::cout << "Error getting auth token: " << e.what();
return error_code::external_error;
}
catch (...)
{
std::cout << "Unknown error while getting auth token";
return error_code::external_error;
}
# else
# error Requires C++14 or greater
# endif
return error_code::success;
}
#endif
30 changes: 30 additions & 0 deletions examples/rl_sim_cpp/azure_credentials.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#ifdef LINK_AZURE_LIBS
# include "api_status.h"
# include "configuration.h"
# include "future_compat.h"

# include <azure/identity/azure_cli_credential.hpp>
# include <azure/identity/default_azure_credential.hpp>
# include <chrono>
# include <memory>
# include <string>

class AzureCredentials
peterychang marked this conversation as resolved.
Show resolved Hide resolved
{
public:
AzureCredentials(const std::string& tenant_id);
int get_credentials(const std::vector<std::string>& scopes, std::string& token_out,
std::chrono::system_clock::time_point& expiry_out);

private:
std::string _tenant_id;
# ifdef HAS_STD14
Azure::Identity::AzureCliCredentialOptions create_options();

// Azure::Identity::DefaultAzureCredential _creds;
Azure::Identity::AzureCliCredential _creds;
# endif
};
#endif
4 changes: 3 additions & 1 deletion examples/rl_sim_cpp/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ po::variables_map process_cmd_line(const int argc, char** argv)
"random_seed", po::value<uint64_t>()->default_value(rand()), "Random seed. Default is random")(
"delay", po::value<int64_t>()->default_value(2000), "Delay between events in ms")(
"quiet", po::bool_switch(), "Suppress logs")("random_ids", po::value<bool>()->default_value(true),
"Use randomly generated Event IDs. Default is true")("throughput", "print throughput stats");
"Use randomly generated Event IDs. Default is true")("throughput", "print throughput stats")(
"azure_oauth_factories", po::value<bool>()->default_value(false), "Use oauth for azure factores. Default false")(
"azure_tenant_id", po::value<std::string>()->default_value(""), "Tenant ID for use with azure oauth factories.");

po::variables_map vm;
store(parse_command_line(argc, argv, desc), vm);
Expand Down
22 changes: 20 additions & 2 deletions examples/rl_sim_cpp/rl_sim.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "api_status.h"
#include "constants.h"
#include "factory_resolver.h"
#include "future_compat.h"
#include "live_model.h"
#include "multistep.h"
#include "person.h"
Expand All @@ -13,6 +14,7 @@
#include <boost/uuid/uuid_io.hpp>
#include <chrono>
#include <cmath>
#include <functional>
#include <thread>

using namespace std;
Expand Down Expand Up @@ -487,6 +489,17 @@ int rl_sim::init_rl()
wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::EPISODE_HTTP_API_SENDER));
sender_factory = &factory;
}
// probably incompatible with the throughput option?
else if (_options["azure_oauth_factories"].as<bool>())
{
#ifdef LINK_AZURE_LIBS
// Note: This requires C++14 or better
using namespace std::placeholders;
reinforcement_learning::oauth_callback_t callback =
std::bind(&AzureCredentials::get_credentials, &_creds, _1, _2, _3);
reinforcement_learning::register_default_factories_callback(callback);
#endif
}

// Initialize the API
_rl = std::unique_ptr<r::live_model>(new r::live_model(config, _on_error, this,
Expand Down Expand Up @@ -637,7 +650,12 @@ std::string rl_sim::create_event_id()
return oss.str();
}

rl_sim::rl_sim(boost::program_options::variables_map vm) : _options(std::move(vm)), _loop_kind(CB)
rl_sim::rl_sim(boost::program_options::variables_map vm)
: _options(std::move(vm))
, _loop_kind(CB)
#ifdef LINK_AZURE_LIBS
, _creds(_options["azure_tenant_id"].as<std::string>())
#endif
{
if (_options["ccb"].as<bool>()) { _loop_kind = CCB; }
else if (_options["slates"].as<bool>()) { _loop_kind = Slates; }
Expand Down Expand Up @@ -699,4 +717,4 @@ std::string get_dist_str(const reinforcement_learning::decision_response& respon
}
ret += ")";
return ret;
}
}
4 changes: 4 additions & 0 deletions examples/rl_sim_cpp/rl_sim.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* @date 2018-07-18
*/
#pragma once
#include "azure_credentials.h"
#include "live_model.h"
#include "person.h"
#include "robot_joint.h"
Expand Down Expand Up @@ -177,4 +178,7 @@ class rl_sim
int64_t _delay = 2000;
bool _quiet = false;
bool _random_ids = true;
#ifdef LINK_AZURE_LIBS
AzureCredentials _creds;
#endif
};
6 changes: 6 additions & 0 deletions include/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const char* const LEARNING_MODE = "rank.learning.mode";
const char* const PROTOCOL_VERSION = "protocol.version";
const char* const HTTP_API_KEY = "http.api.key";
const char* const HTTP_API_HEADER_KEY_NAME = "http.api.header.key.name";
const char* const HTTP_API_OAUTH_TOKEN_TYPE = "http.api.oauth.token.type";
const char* const AUDIT_ENABLED = "audit.enabled";
const char* const AUDIT_OUTPUT_PATH = "audit.output.path";

Expand Down Expand Up @@ -118,6 +119,7 @@ const char* const AZURE_STORAGE_BLOB = "AZURE_STORAGE_BLOB";
const char* const NO_MODEL_DATA = "NO_MODEL_DATA";
const char* const HTTP_MODEL_DATA = "HTTP_MODEL_DATA";
const char* const FILE_MODEL_DATA = "FILE_MODEL_DATA";
const char* const HTTP_MODEL_DATA_OAUTH = "HTTP_MODEL_DATA_OAUTH";
const char* const VW = "VW";
const char* const PASSTHROUGH_PDF_MODEL = "PASSTHROUGH_PDF";
const char* const EPISODE_EH_SENDER = "EPISODE_EH_SENDER";
Expand All @@ -129,6 +131,9 @@ const char* const INTERACTION_FILE_SENDER = "INTERACTION_FILE_SENDER";
const char* const EPISODE_HTTP_API_SENDER = "EPISODE_HTTP_API_SENDER";
const char* const OBSERVATION_HTTP_API_SENDER = "OBSERVATION_HTTP_API_SENDER";
const char* const INTERACTION_HTTP_API_SENDER = "INTERACTION_HTTP_API_SENDER";
const char* const EPISODE_HTTP_API_SENDER_OAUTH = "EPISODE_HTTP_API_SENDER_OAUTH";
const char* const OBSERVATION_HTTP_API_SENDER_OAUTH = "OBSERVATION_HTTP_API_SENDER_OAUTH";
const char* const INTERACTION_HTTP_API_SENDER_OAUTH = "INTERACTION_HTTP_API_SENDER_OAUTH";
const char* const NULL_TRACE_LOGGER = "NULL_TRACE_LOGGER";
const char* const CONSOLE_TRACE_LOGGER = "CONSOLE_TRACE_LOGGER";
const char* const NULL_TIME_PROVIDER = "NULL_TIME_PROVIDER";
Expand All @@ -139,6 +144,7 @@ const char* const LEARNING_MODE_LOGGINGONLY = "LOGGINGONLY";
const char* const CONTENT_ENCODING_IDENTITY = "IDENTITY";
const char* const CONTENT_ENCODING_DEDUP = "DEDUP";
const char* const HTTP_API_DEFAULT_HEADER_KEY_NAME = "Ocp-Apim-Subscription-Key";
const char* const HTTP_API_DEFAULT_OAUTH_TOKEN_TYPE = "Bearer";
const char* const TRACE_LOG_LEVEL_DEFAULT = "info";

const char* const QUEUE_MODE_DROP = "DROP";
Expand Down
11 changes: 11 additions & 0 deletions include/factory_resolver.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#pragma once
#include "oauth_callback_fn.h"
#include "object_factory.h"

#include <chrono>
#include <vector>
namespace reinforcement_learning
{
namespace utility
Expand Down Expand Up @@ -70,4 +74,11 @@ struct factory_initializer
// Every translation unit gets a factory_initializer
// only one translation unit will initialize it
static factory_initializer _init;

// no-op if USE_AZURE_FACTORIES is not defined
/**
* @brief Register default factories with an authentication callback
*/
void register_default_factories_callback(oauth_callback_t& callback);

} // namespace reinforcement_learning
12 changes: 12 additions & 0 deletions include/oauth_callback_fn.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include <chrono>
#include <functional>
#include <string>
#include <vector>

namespace reinforcement_learning
{
using oauth_callback_t =
std::function<int(const std::vector<std::string>&, std::string&, std::chrono::system_clock::time_point&)>;
}
Loading
Loading