Skip to content

Commit

Permalink
feat: static model loader (#603)
Browse files Browse the repository at this point in the history
* static model loader

* implement model loading

* add test

* use pinned memory c#

* standar memcpy

* lint

* fix includes

* clang

* memcpy in constuctor

* lint

* address comments

* lint

* comments

* lint

* CI

* copy model outside of lambda

* standard memcpy
  • Loading branch information
bassmang authored Feb 8, 2024
1 parent 4ed7401 commit 66e045a
Show file tree
Hide file tree
Showing 14 changed files with 210 additions and 2 deletions.
36 changes: 35 additions & 1 deletion bindings/cs/rl.net.cli/BasicUsageCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace Rl.Net.Cli
[Verb("basicUsage", HelpText = "Basic usage of the API")]
class BasicUsageCommand : CommandBase
{
[Option(longName: "testType", HelpText = "select from (liveModel, caLoop, cbLoop, ccbLoop, pdfExample, slatesLoop) basic usage examples", Required = false, Default = "liveModel")]
[Option(longName: "testType", HelpText = "select from (liveModel, liveModelStatic, caLoop, cbLoop, ccbLoop, pdfExample, slatesLoop) basic usage examples", Required = false, Default = "liveModel")]
public string testType { get; set; }

public override void Run()
Expand All @@ -16,6 +16,9 @@ public override void Run()
case "liveModel":
BasicUsage(this.ConfigPath);
break;
case "liveModelStatic":
BasicUsageStaticModel(this.ConfigPath, this.ModelPath);
break;
case "caLoop":
BasicUsageCALoop(this.ConfigPath);
break;
Expand Down Expand Up @@ -68,6 +71,37 @@ public static void BasicUsage(string configPath)
Console.WriteLine("Basic usage live model success");
}

public static void BasicUsageStaticModel(string configPath, string modelPath)
{
const float outcome = 1.0f;
const string eventId = "event_id";
const string contextJson = "{\"GUser\":{\"id\":\"a\",\"major\":\"eng\",\"hobby\":\"hiking\"},\"_multi\":[ { \"TAction\":{\"a1\":\"f1\"} },{\"TAction\":{\"a2\":\"f2\"}}]}";

LiveModel liveModel = Helpers.CreateLiveModelWithStaticModelOrExit(configPath, modelPath);

ApiStatus apiStatus = new ApiStatus();

RankingResponse rankingResponse = new RankingResponse();
if (!liveModel.TryChooseRank(eventId, contextJson, rankingResponse, apiStatus))
{
Helpers.WriteStatusAndExit(apiStatus);
}

long actionId;
if (!rankingResponse.TryGetChosenAction(out actionId, apiStatus))
{
Helpers.WriteStatusAndExit(apiStatus);
}

Console.WriteLine($"Chosen action id: {actionId}");

if (!liveModel.TryQueueOutcomeEvent(eventId, outcome, apiStatus))
{
Helpers.WriteStatusAndExit(apiStatus);
}
Console.WriteLine("Basic usage live model success");
}

public static void BasicUsageCALoop(string configPath)
{
const float outcome = 1.0f;
Expand Down
3 changes: 3 additions & 0 deletions bindings/cs/rl.net.cli/CommandBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ public abstract class CommandBase
[Option(longName: "config", HelpText = "the path to client config", Required = true)]
public string ConfigPath { get; set; }

[Option(longName: "model", HelpText = "the path to model file", Required = false)]
public string ModelPath { get; set; }

[Option(longName: "slates", HelpText = "Use slates for decisions", Required = false, Default = false)]
public bool UseSlates { get; set; }

Expand Down
50 changes: 50 additions & 0 deletions bindings/cs/rl.net.cli/Helpers.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.IO;
using System.Collections.Generic;

namespace Rl.Net.Cli
{
Expand Down Expand Up @@ -50,6 +51,55 @@ public static LiveModel CreateLiveModelOrExit(string clientJsonPath)
return liveModel;
}

public static LiveModel CreateLiveModelWithStaticModelOrExit(string clientJsonPath, string modelPath)
{
if (!File.Exists(clientJsonPath))
{
WriteErrorAndExit($"Could not find file with path '{clientJsonPath}'.");
}

string json = File.ReadAllText(clientJsonPath);

ApiStatus apiStatus = new ApiStatus();

Configuration config;
if (!Configuration.TryLoadConfigurationFromJson(json, out config, apiStatus))
{
WriteStatusAndExit(apiStatus);
}

string trace_log = config["trace.logger.implementation"];

config["model.source"] = "BINDING_DATA_TRANSPORT";
List<byte> modelData = null;
if (File.Exists(modelPath))
{
byte[] fileBytes = File.ReadAllBytes(modelPath);
modelData = new List<byte>(fileBytes);
}
else
{
WriteErrorAndExit($"Could not find model file with path '{modelPath}'.");
}

// Use static model FactoryContext
FactoryContext fc = new FactoryContext(modelData);
LiveModel liveModel = new LiveModel(config, fc);

liveModel.BackgroundError += LiveModel_BackgroundError;
if (trace_log == "CONSOLE_TRACE_LOGGER")
{
liveModel.TraceLoggerEvent += LiveModel_TraceLogEvent;
}

if (!liveModel.TryInit(apiStatus))
{
WriteStatusAndExit(apiStatus);
}

return liveModel;
}

public static LoopType CreateLoopOrExit<LoopType>(string clientJson, Func<Configuration, LoopType> createLoop, bool jsonStr = false) where LoopType : ILoop
{
string json = clientJson;
Expand Down
2 changes: 2 additions & 0 deletions bindings/cs/rl.net.native/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
set(rl_net_native_SOURCES
binding_sender.cc
binding_static_model.cc
binding_tracer.cc
rl.net.api_status.cc
rl.net.buffer.cc
Expand All @@ -22,6 +23,7 @@ set(rl_net_native_SOURCES

set(rl_net_native_HEADERS
binding_sender.h
binding_static_model.h
binding_tracer.h
rl.net.api_status.h
rl.net.loop_context.h
Expand Down
11 changes: 11 additions & 0 deletions bindings/cs/rl.net.native/binding_static_model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "binding_static_model.h"

using namespace rl_net_native;
using namespace reinforcement_learning;

binding_static_model::binding_static_model(const char* vw_model, const size_t len) : vw_model(vw_model), len(len) {}

int binding_static_model::get_data(model_transport::model_data& data, reinforcement_learning::api_status* status)
{
return data.set_data(vw_model, len);
}
27 changes: 27 additions & 0 deletions bindings/cs/rl.net.native/binding_static_model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include "err_constants.h"
#include "model_mgmt.h"

#include <cstring>

namespace model_transport = reinforcement_learning::model_management;

namespace rl_net_native
{

namespace constants
{
const char* const BINDING_DATA_TRANSPORT = "BINDING_DATA_TRANSPORT";
}
class binding_static_model : public model_transport::i_data_transport
{
public:
binding_static_model(const char* vw_model, const size_t len);
int get_data(model_transport::model_data& data, reinforcement_learning::api_status* status = nullptr) override;

private:
const char* vw_model;
const size_t len;
};
} // namespace rl_net_native
35 changes: 35 additions & 0 deletions bindings/cs/rl.net.native/rl.net.factory_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,40 @@ API factory_context_t* CreateFactoryContext()
return context;
}

void cleanup_data_transport_factory(data_transport_factory_t* data_transport_factory)
{
if (data_transport_factory != nullptr && data_transport_factory != &reinforcement_learning::data_transport_factory)
{
// We overrode the built-in data_transport_factory
delete data_transport_factory;
}
}

API factory_context_t* CreateFactoryContextWithStaticModel(const char* vw_model, const size_t len)
{
using namespace reinforcement_learning::model_management;
auto context = CreateFactoryContext();
char* vw_model_copy = new char[len];
memcpy(vw_model_copy, vw_model, len);

auto data_transport_factory_fn = [vw_model_copy, len](std::unique_ptr<model_management::i_data_transport>& retval,
const utility::configuration& configuration, i_trace* trace_logger,
api_status* status) -> int
{
retval.reset(new rl_net_native::binding_static_model(vw_model_copy, len));
return error_code::success;
};

data_transport_factory_t* data_transport_factory =
new data_transport_factory_t(reinforcement_learning::data_transport_factory);
data_transport_factory->register_type(rl_net_native::constants::BINDING_DATA_TRANSPORT, data_transport_factory_fn);

std::swap(context->data_transport_factory, data_transport_factory);
cleanup_data_transport_factory(data_transport_factory);

return context;
}

void cleanup_trace_logger_factory(trace_logger_factory_t* trace_logger_factory)
{
if (trace_logger_factory != nullptr && trace_logger_factory != &reinforcement_learning::trace_logger_factory)
Expand All @@ -39,6 +73,7 @@ API void DeleteFactoryContext(factory_context_t* context)
{
cleanup_trace_logger_factory(context->trace_logger_factory);
cleanup_sender_factory(context->sender_factory);
cleanup_data_transport_factory(context->data_transport_factory);

// TODO: Once we project the others, we will need to add them to cleanup.

Expand Down
6 changes: 5 additions & 1 deletion bindings/cs/rl.net.native/rl.net.factory_context.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
#pragma once

#include "binding_sender.h"
#include "binding_static_model.h"
#include "factory_resolver.h"
#include "model_mgmt.h"
#include "rl.net.native.h"

#include <cstddef>

typedef struct factory_context
{
reinforcement_learning::trace_logger_factory_t* trace_logger_factory;
Expand All @@ -16,8 +20,8 @@ typedef struct factory_context
extern "C"
{
API factory_context_t* CreateFactoryContext();
API factory_context_t* CreateFactoryContextWithStaticModel(const char* vw_model, const size_t len);
API void DeleteFactoryContext(factory_context_t* context);

API void SetFactoryContextBindingSenderFactory(
factory_context_t* context, rl_net_native::sender_create_fn create_fn, rl_net_native::sender_vtable_t vtable);
}
22 changes: 22 additions & 0 deletions bindings/cs/rl.net/FactoryContext.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using System;
using System.Threading;
using System.Runtime.InteropServices;
using System.Collections.Generic;
using System.Linq;

using Rl.Net.Native;

Expand All @@ -10,6 +12,9 @@ public sealed class FactoryContext : NativeObject<FactoryContext>
[DllImport("rlnetnative")]
private static extern IntPtr CreateFactoryContext();

[DllImport("rlnetnative")]
private static extern IntPtr CreateFactoryContextWithStaticModel(IntPtr vw_model, int len);

[DllImport("rlnetnative")]
private static extern void DeleteFactoryContext(IntPtr context);

Expand All @@ -20,6 +25,23 @@ public sealed class FactoryContext : NativeObject<FactoryContext>
{
}

public FactoryContext(IEnumerable<byte> vwModelEnumerable) : base(
new New<FactoryContext>(() => {
var vwModelArray = vwModelEnumerable.ToArray();
GCHandle handle = GCHandle.Alloc(vwModelArray, GCHandleType.Pinned);
try {
IntPtr ptr = handle.AddrOfPinnedObject();
return CreateFactoryContextWithStaticModel(ptr, vwModelArray.Length);
}
finally {
if (handle.IsAllocated)
handle.Free();
}
}),
new Delete<FactoryContext>(DeleteFactoryContext))
{
}

private GCHandleLifetime registeredSenderCreateHandle;

public void SetSenderFactory<TSender>(Func<IReadOnlyConfiguration, ErrorCallback, TSender> createSender) where TSender : ISender
Expand Down
1 change: 1 addition & 0 deletions bindings/cs/rl.net/LiveModel.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Runtime.InteropServices;
using System.Collections.Generic;

using Rl.Net.Native;

Expand Down
Binary file added examples/test_cpp/model.vw
Binary file not shown.
1 change: 1 addition & 0 deletions include/errors_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,5 @@ ERROR_CODE_DEFINITION(48, extension_error, "Error from extension: ")
ERROR_CODE_DEFINITION(49, baseline_actions_not_defined, "Baseline Actions must be defined in apprentice mode")
ERROR_CODE_DEFINITION(50, http_api_key_not_provided, "Http api key must be provided")
ERROR_CODE_DEFINITION(51, http_model_uri_not_provided, "Model Blob URI parameter was not passed in via configuration")
ERROR_CODE_DEFINITION(52, static_model_load_error, "Static model passed in C# layer is not loading properly")
//! [Error Definitions]
5 changes: 5 additions & 0 deletions include/model_mgmt.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#pragma once
#include "err_constants.h"
#include "multistep.h"

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -34,6 +36,9 @@ class model_data
char* alloc(size_t desired);
void free();

// Set data
int set_data(const char* vw_model, size_t len);

model_data() = default;
~model_data() = default;

Expand Down
13 changes: 13 additions & 0 deletions rlclientlib/model_mgmt/model_mgmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,18 @@ char* model_data::alloc(const size_t desired)

void model_data::free() { _data.clear(); }

int model_data::set_data(const char* vw_model, size_t len)
{
if (vw_model == nullptr || len == 0) { return reinforcement_learning::error_code::static_model_load_error; }

char* buffer = this->alloc(len);
if (buffer == nullptr) { return reinforcement_learning::error_code::static_model_load_error; }

memcpy(buffer, vw_model, len);
this->data_sz(len);

return reinforcement_learning::error_code::success;
}

} // namespace model_management
} // namespace reinforcement_learning

0 comments on commit 66e045a

Please sign in to comment.