From 66e045a9c4622f2817e5a31f7a6a4d8d0186c183 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Thu, 8 Feb 2024 16:31:16 -0500 Subject: [PATCH] feat: static model loader (#603) * static model loader * implement model loading * add test * use pinned memory c# * standar memcpy * lint * fix includes * clang * memcpy in constuctor * lint * address comments * lint * comments * lint * CI * copy model outside of lambda * standard memcpy --- bindings/cs/rl.net.cli/BasicUsageCommand.cs | 36 ++++++++++++- bindings/cs/rl.net.cli/CommandBase.cs | 3 ++ bindings/cs/rl.net.cli/Helpers.cs | 50 ++++++++++++++++++ bindings/cs/rl.net.native/CMakeLists.txt | 2 + .../cs/rl.net.native/binding_static_model.cc | 11 ++++ .../cs/rl.net.native/binding_static_model.h | 27 ++++++++++ .../rl.net.native/rl.net.factory_context.cc | 35 ++++++++++++ .../cs/rl.net.native/rl.net.factory_context.h | 6 ++- bindings/cs/rl.net/FactoryContext.cs | 22 ++++++++ bindings/cs/rl.net/LiveModel.cs | 1 + examples/test_cpp/model.vw | Bin 0 -> 445 bytes include/errors_data.h | 1 + include/model_mgmt.h | 5 ++ rlclientlib/model_mgmt/model_mgmt.cc | 13 +++++ 14 files changed, 210 insertions(+), 2 deletions(-) create mode 100644 bindings/cs/rl.net.native/binding_static_model.cc create mode 100644 bindings/cs/rl.net.native/binding_static_model.h create mode 100644 examples/test_cpp/model.vw diff --git a/bindings/cs/rl.net.cli/BasicUsageCommand.cs b/bindings/cs/rl.net.cli/BasicUsageCommand.cs index 42bb8882c..8c3ff38d2 100644 --- a/bindings/cs/rl.net.cli/BasicUsageCommand.cs +++ b/bindings/cs/rl.net.cli/BasicUsageCommand.cs @@ -6,7 +6,7 @@ namespace Rl.Net.Cli [Verb("basicUsage", HelpText = "Basic usage of the API")] class BasicUsageCommand : CommandBase { - [Option(longName: "testType", HelpText = "select from (liveModel, caLoop, cbLoop, ccbLoop, pdfExample, slatesLoop) basic usage examples", Required = false, Default = "liveModel")] + [Option(longName: "testType", HelpText = "select from (liveModel, liveModelStatic, caLoop, cbLoop, ccbLoop, pdfExample, slatesLoop) basic usage examples", Required = false, Default = "liveModel")] public string testType { get; set; } public override void Run() @@ -16,6 +16,9 @@ public override void Run() case "liveModel": BasicUsage(this.ConfigPath); break; + case "liveModelStatic": + BasicUsageStaticModel(this.ConfigPath, this.ModelPath); + break; case "caLoop": BasicUsageCALoop(this.ConfigPath); break; @@ -68,6 +71,37 @@ public static void BasicUsage(string configPath) Console.WriteLine("Basic usage live model success"); } + public static void BasicUsageStaticModel(string configPath, string modelPath) + { + const float outcome = 1.0f; + const string eventId = "event_id"; + const string contextJson = "{\"GUser\":{\"id\":\"a\",\"major\":\"eng\",\"hobby\":\"hiking\"},\"_multi\":[ { \"TAction\":{\"a1\":\"f1\"} },{\"TAction\":{\"a2\":\"f2\"}}]}"; + + LiveModel liveModel = Helpers.CreateLiveModelWithStaticModelOrExit(configPath, modelPath); + + ApiStatus apiStatus = new ApiStatus(); + + RankingResponse rankingResponse = new RankingResponse(); + if (!liveModel.TryChooseRank(eventId, contextJson, rankingResponse, apiStatus)) + { + Helpers.WriteStatusAndExit(apiStatus); + } + + long actionId; + if (!rankingResponse.TryGetChosenAction(out actionId, apiStatus)) + { + Helpers.WriteStatusAndExit(apiStatus); + } + + Console.WriteLine($"Chosen action id: {actionId}"); + + if (!liveModel.TryQueueOutcomeEvent(eventId, outcome, apiStatus)) + { + Helpers.WriteStatusAndExit(apiStatus); + } + Console.WriteLine("Basic usage live model success"); + } + public static void BasicUsageCALoop(string configPath) { const float outcome = 1.0f; diff --git a/bindings/cs/rl.net.cli/CommandBase.cs b/bindings/cs/rl.net.cli/CommandBase.cs index 2bf7228db..afee463f1 100644 --- a/bindings/cs/rl.net.cli/CommandBase.cs +++ b/bindings/cs/rl.net.cli/CommandBase.cs @@ -7,6 +7,9 @@ public abstract class CommandBase [Option(longName: "config", HelpText = "the path to client config", Required = true)] public string ConfigPath { get; set; } + [Option(longName: "model", HelpText = "the path to model file", Required = false)] + public string ModelPath { get; set; } + [Option(longName: "slates", HelpText = "Use slates for decisions", Required = false, Default = false)] public bool UseSlates { get; set; } diff --git a/bindings/cs/rl.net.cli/Helpers.cs b/bindings/cs/rl.net.cli/Helpers.cs index c8bf1dca8..6be23f9b6 100644 --- a/bindings/cs/rl.net.cli/Helpers.cs +++ b/bindings/cs/rl.net.cli/Helpers.cs @@ -1,5 +1,6 @@ using System; using System.IO; +using System.Collections.Generic; namespace Rl.Net.Cli { @@ -50,6 +51,55 @@ public static LiveModel CreateLiveModelOrExit(string clientJsonPath) return liveModel; } + public static LiveModel CreateLiveModelWithStaticModelOrExit(string clientJsonPath, string modelPath) + { + if (!File.Exists(clientJsonPath)) + { + WriteErrorAndExit($"Could not find file with path '{clientJsonPath}'."); + } + + string json = File.ReadAllText(clientJsonPath); + + ApiStatus apiStatus = new ApiStatus(); + + Configuration config; + if (!Configuration.TryLoadConfigurationFromJson(json, out config, apiStatus)) + { + WriteStatusAndExit(apiStatus); + } + + string trace_log = config["trace.logger.implementation"]; + + config["model.source"] = "BINDING_DATA_TRANSPORT"; + List modelData = null; + if (File.Exists(modelPath)) + { + byte[] fileBytes = File.ReadAllBytes(modelPath); + modelData = new List(fileBytes); + } + else + { + WriteErrorAndExit($"Could not find model file with path '{modelPath}'."); + } + + // Use static model FactoryContext + FactoryContext fc = new FactoryContext(modelData); + LiveModel liveModel = new LiveModel(config, fc); + + liveModel.BackgroundError += LiveModel_BackgroundError; + if (trace_log == "CONSOLE_TRACE_LOGGER") + { + liveModel.TraceLoggerEvent += LiveModel_TraceLogEvent; + } + + if (!liveModel.TryInit(apiStatus)) + { + WriteStatusAndExit(apiStatus); + } + + return liveModel; + } + public static LoopType CreateLoopOrExit(string clientJson, Func createLoop, bool jsonStr = false) where LoopType : ILoop { string json = clientJson; diff --git a/bindings/cs/rl.net.native/CMakeLists.txt b/bindings/cs/rl.net.native/CMakeLists.txt index 6f0d3327f..c7dba0dd5 100644 --- a/bindings/cs/rl.net.native/CMakeLists.txt +++ b/bindings/cs/rl.net.native/CMakeLists.txt @@ -1,5 +1,6 @@ set(rl_net_native_SOURCES binding_sender.cc + binding_static_model.cc binding_tracer.cc rl.net.api_status.cc rl.net.buffer.cc @@ -22,6 +23,7 @@ set(rl_net_native_SOURCES set(rl_net_native_HEADERS binding_sender.h + binding_static_model.h binding_tracer.h rl.net.api_status.h rl.net.loop_context.h diff --git a/bindings/cs/rl.net.native/binding_static_model.cc b/bindings/cs/rl.net.native/binding_static_model.cc new file mode 100644 index 000000000..9de13f730 --- /dev/null +++ b/bindings/cs/rl.net.native/binding_static_model.cc @@ -0,0 +1,11 @@ +#include "binding_static_model.h" + +using namespace rl_net_native; +using namespace reinforcement_learning; + +binding_static_model::binding_static_model(const char* vw_model, const size_t len) : vw_model(vw_model), len(len) {} + +int binding_static_model::get_data(model_transport::model_data& data, reinforcement_learning::api_status* status) +{ + return data.set_data(vw_model, len); +} diff --git a/bindings/cs/rl.net.native/binding_static_model.h b/bindings/cs/rl.net.native/binding_static_model.h new file mode 100644 index 000000000..a2a0074d9 --- /dev/null +++ b/bindings/cs/rl.net.native/binding_static_model.h @@ -0,0 +1,27 @@ +#pragma once + +#include "err_constants.h" +#include "model_mgmt.h" + +#include + +namespace model_transport = reinforcement_learning::model_management; + +namespace rl_net_native +{ + +namespace constants +{ +const char* const BINDING_DATA_TRANSPORT = "BINDING_DATA_TRANSPORT"; +} +class binding_static_model : public model_transport::i_data_transport +{ +public: + binding_static_model(const char* vw_model, const size_t len); + int get_data(model_transport::model_data& data, reinforcement_learning::api_status* status = nullptr) override; + +private: + const char* vw_model; + const size_t len; +}; +} // namespace rl_net_native \ No newline at end of file diff --git a/bindings/cs/rl.net.native/rl.net.factory_context.cc b/bindings/cs/rl.net.native/rl.net.factory_context.cc index c2cbe0027..909e0ff2f 100644 --- a/bindings/cs/rl.net.native/rl.net.factory_context.cc +++ b/bindings/cs/rl.net.native/rl.net.factory_context.cc @@ -17,6 +17,40 @@ API factory_context_t* CreateFactoryContext() return context; } +void cleanup_data_transport_factory(data_transport_factory_t* data_transport_factory) +{ + if (data_transport_factory != nullptr && data_transport_factory != &reinforcement_learning::data_transport_factory) + { + // We overrode the built-in data_transport_factory + delete data_transport_factory; + } +} + +API factory_context_t* CreateFactoryContextWithStaticModel(const char* vw_model, const size_t len) +{ + using namespace reinforcement_learning::model_management; + auto context = CreateFactoryContext(); + char* vw_model_copy = new char[len]; + memcpy(vw_model_copy, vw_model, len); + + auto data_transport_factory_fn = [vw_model_copy, len](std::unique_ptr& retval, + const utility::configuration& configuration, i_trace* trace_logger, + api_status* status) -> int + { + retval.reset(new rl_net_native::binding_static_model(vw_model_copy, len)); + return error_code::success; + }; + + data_transport_factory_t* data_transport_factory = + new data_transport_factory_t(reinforcement_learning::data_transport_factory); + data_transport_factory->register_type(rl_net_native::constants::BINDING_DATA_TRANSPORT, data_transport_factory_fn); + + std::swap(context->data_transport_factory, data_transport_factory); + cleanup_data_transport_factory(data_transport_factory); + + return context; +} + void cleanup_trace_logger_factory(trace_logger_factory_t* trace_logger_factory) { if (trace_logger_factory != nullptr && trace_logger_factory != &reinforcement_learning::trace_logger_factory) @@ -39,6 +73,7 @@ API void DeleteFactoryContext(factory_context_t* context) { cleanup_trace_logger_factory(context->trace_logger_factory); cleanup_sender_factory(context->sender_factory); + cleanup_data_transport_factory(context->data_transport_factory); // TODO: Once we project the others, we will need to add them to cleanup. diff --git a/bindings/cs/rl.net.native/rl.net.factory_context.h b/bindings/cs/rl.net.native/rl.net.factory_context.h index e53564d47..15c97dd4a 100644 --- a/bindings/cs/rl.net.native/rl.net.factory_context.h +++ b/bindings/cs/rl.net.native/rl.net.factory_context.h @@ -1,9 +1,13 @@ #pragma once #include "binding_sender.h" +#include "binding_static_model.h" #include "factory_resolver.h" +#include "model_mgmt.h" #include "rl.net.native.h" +#include + typedef struct factory_context { reinforcement_learning::trace_logger_factory_t* trace_logger_factory; @@ -16,8 +20,8 @@ typedef struct factory_context extern "C" { API factory_context_t* CreateFactoryContext(); + API factory_context_t* CreateFactoryContextWithStaticModel(const char* vw_model, const size_t len); API void DeleteFactoryContext(factory_context_t* context); - API void SetFactoryContextBindingSenderFactory( factory_context_t* context, rl_net_native::sender_create_fn create_fn, rl_net_native::sender_vtable_t vtable); } diff --git a/bindings/cs/rl.net/FactoryContext.cs b/bindings/cs/rl.net/FactoryContext.cs index 59fd6239d..16533f873 100644 --- a/bindings/cs/rl.net/FactoryContext.cs +++ b/bindings/cs/rl.net/FactoryContext.cs @@ -1,6 +1,8 @@ using System; using System.Threading; using System.Runtime.InteropServices; +using System.Collections.Generic; +using System.Linq; using Rl.Net.Native; @@ -10,6 +12,9 @@ public sealed class FactoryContext : NativeObject [DllImport("rlnetnative")] private static extern IntPtr CreateFactoryContext(); + [DllImport("rlnetnative")] + private static extern IntPtr CreateFactoryContextWithStaticModel(IntPtr vw_model, int len); + [DllImport("rlnetnative")] private static extern void DeleteFactoryContext(IntPtr context); @@ -20,6 +25,23 @@ public sealed class FactoryContext : NativeObject { } + public FactoryContext(IEnumerable vwModelEnumerable) : base( + new New(() => { + var vwModelArray = vwModelEnumerable.ToArray(); + GCHandle handle = GCHandle.Alloc(vwModelArray, GCHandleType.Pinned); + try { + IntPtr ptr = handle.AddrOfPinnedObject(); + return CreateFactoryContextWithStaticModel(ptr, vwModelArray.Length); + } + finally { + if (handle.IsAllocated) + handle.Free(); + } + }), + new Delete(DeleteFactoryContext)) + { + } + private GCHandleLifetime registeredSenderCreateHandle; public void SetSenderFactory(Func createSender) where TSender : ISender diff --git a/bindings/cs/rl.net/LiveModel.cs b/bindings/cs/rl.net/LiveModel.cs index 59f5b3b41..2e829d011 100644 --- a/bindings/cs/rl.net/LiveModel.cs +++ b/bindings/cs/rl.net/LiveModel.cs @@ -1,5 +1,6 @@ using System; using System.Runtime.InteropServices; +using System.Collections.Generic; using Rl.Net.Native; diff --git a/examples/test_cpp/model.vw b/examples/test_cpp/model.vw new file mode 100644 index 0000000000000000000000000000000000000000..6441a19b66f99122c3cd994643cff7d299b0b53d GIT binary patch literal 445 zcmZQ$U|_J+v(PhOU<9%lazO+GLxa5#gaIcrfGh=F-Q=YB#FR7$BekL+C%-5aAy86T zkgAYdQUsDN&QDB?&jBjUEzK#(%*o74g^Ly?=4AspsRhNEIr(`C26~1D5MXFxX<%q* z#=ru!OLeNPFB8;uRwxa2A&`CJ%o!O6h!_Wmb_fl9Y7gOq$OZ=py8$K! #include +#include #include #include #include @@ -34,6 +36,9 @@ class model_data char* alloc(size_t desired); void free(); + // Set data + int set_data(const char* vw_model, size_t len); + model_data() = default; ~model_data() = default; diff --git a/rlclientlib/model_mgmt/model_mgmt.cc b/rlclientlib/model_mgmt/model_mgmt.cc index a4d416694..f0493e4d0 100644 --- a/rlclientlib/model_mgmt/model_mgmt.cc +++ b/rlclientlib/model_mgmt/model_mgmt.cc @@ -25,5 +25,18 @@ char* model_data::alloc(const size_t desired) void model_data::free() { _data.clear(); } +int model_data::set_data(const char* vw_model, size_t len) +{ + if (vw_model == nullptr || len == 0) { return reinforcement_learning::error_code::static_model_load_error; } + + char* buffer = this->alloc(len); + if (buffer == nullptr) { return reinforcement_learning::error_code::static_model_load_error; } + + memcpy(buffer, vw_model, len); + this->data_sz(len); + + return reinforcement_learning::error_code::success; +} + } // namespace model_management } // namespace reinforcement_learning