Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: static model loader #603

Merged
merged 18 commits into from
Feb 8, 2024
36 changes: 35 additions & 1 deletion bindings/cs/rl.net.cli/BasicUsageCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace Rl.Net.Cli
[Verb("basicUsage", HelpText = "Basic usage of the API")]
class BasicUsageCommand : CommandBase
{
[Option(longName: "testType", HelpText = "select from (liveModel, caLoop, cbLoop, ccbLoop, pdfExample, slatesLoop) basic usage examples", Required = false, Default = "liveModel")]
[Option(longName: "testType", HelpText = "select from (liveModel, liveModelStatic, caLoop, cbLoop, ccbLoop, pdfExample, slatesLoop) basic usage examples", Required = false, Default = "liveModel")]
public string testType { get; set; }

public override void Run()
Expand All @@ -16,6 +16,9 @@ public override void Run()
case "liveModel":
BasicUsage(this.ConfigPath);
break;
case "liveModelStatic":
BasicUsageStaticModel(this.ConfigPath, this.ModelPath);
break;
case "caLoop":
BasicUsageCALoop(this.ConfigPath);
break;
Expand Down Expand Up @@ -68,6 +71,37 @@ public static void BasicUsage(string configPath)
Console.WriteLine("Basic usage live model success");
}

public static void BasicUsageStaticModel(string configPath, string modelPath)
{
const float outcome = 1.0f;
const string eventId = "event_id";
const string contextJson = "{\"GUser\":{\"id\":\"a\",\"major\":\"eng\",\"hobby\":\"hiking\"},\"_multi\":[ { \"TAction\":{\"a1\":\"f1\"} },{\"TAction\":{\"a2\":\"f2\"}}]}";

LiveModel liveModel = Helpers.CreateLiveModelWithStaticModelOrExit(configPath, modelPath);

ApiStatus apiStatus = new ApiStatus();

RankingResponse rankingResponse = new RankingResponse();
if (!liveModel.TryChooseRank(eventId, contextJson, rankingResponse, apiStatus))
{
Helpers.WriteStatusAndExit(apiStatus);
}

long actionId;
if (!rankingResponse.TryGetChosenAction(out actionId, apiStatus))
{
Helpers.WriteStatusAndExit(apiStatus);
}

Console.WriteLine($"Chosen action id: {actionId}");

if (!liveModel.TryQueueOutcomeEvent(eventId, outcome, apiStatus))
{
Helpers.WriteStatusAndExit(apiStatus);
}
Console.WriteLine("Basic usage live model success");
}

public static void BasicUsageCALoop(string configPath)
{
const float outcome = 1.0f;
Expand Down
3 changes: 3 additions & 0 deletions bindings/cs/rl.net.cli/CommandBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ public abstract class CommandBase
[Option(longName: "config", HelpText = "the path to client config", Required = true)]
public string ConfigPath { get; set; }

[Option(longName: "model", HelpText = "the path to model file", Required = false)]
public string ModelPath { get; set; }

[Option(longName: "slates", HelpText = "Use slates for decisions", Required = false, Default = false)]
public bool UseSlates { get; set; }

Expand Down
50 changes: 50 additions & 0 deletions bindings/cs/rl.net.cli/Helpers.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.IO;
using System.Collections.Generic;

namespace Rl.Net.Cli
{
Expand Down Expand Up @@ -50,6 +51,55 @@ public static LiveModel CreateLiveModelOrExit(string clientJsonPath)
return liveModel;
}

public static LiveModel CreateLiveModelWithStaticModelOrExit(string clientJsonPath, string modelPath)
{
if (!File.Exists(clientJsonPath))
{
WriteErrorAndExit($"Could not find file with path '{clientJsonPath}'.");
}

string json = File.ReadAllText(clientJsonPath);

ApiStatus apiStatus = new ApiStatus();

Configuration config;
if (!Configuration.TryLoadConfigurationFromJson(json, out config, apiStatus))
{
WriteStatusAndExit(apiStatus);
}

string trace_log = config["trace.logger.implementation"];

config["model.source"] = "BINDING_DATA_TRANSPORT";
List<byte> modelData = null;
if (File.Exists(modelPath))
{
byte[] fileBytes = File.ReadAllBytes(modelPath);
modelData = new List<byte>(fileBytes);
}
else
{
WriteErrorAndExit($"Could not find model file with path '{modelPath}'.");
}

// Use static model FactoryContext
FactoryContext fc = new FactoryContext(modelData);
LiveModel liveModel = new LiveModel(config, fc);

liveModel.BackgroundError += LiveModel_BackgroundError;
if (trace_log == "CONSOLE_TRACE_LOGGER")
{
liveModel.TraceLoggerEvent += LiveModel_TraceLogEvent;
}

if (!liveModel.TryInit(apiStatus))
{
WriteStatusAndExit(apiStatus);
}

return liveModel;
}

public static LoopType CreateLoopOrExit<LoopType>(string clientJson, Func<Configuration, LoopType> createLoop, bool jsonStr = false) where LoopType : ILoop
{
string json = clientJson;
Expand Down
2 changes: 2 additions & 0 deletions bindings/cs/rl.net.native/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
set(rl_net_native_SOURCES
binding_sender.cc
binding_static_model.cc
binding_tracer.cc
rl.net.api_status.cc
rl.net.buffer.cc
Expand All @@ -22,6 +23,7 @@ set(rl_net_native_SOURCES

set(rl_net_native_HEADERS
binding_sender.h
binding_static_model.h
binding_tracer.h
rl.net.api_status.h
rl.net.loop_context.h
Expand Down
11 changes: 11 additions & 0 deletions bindings/cs/rl.net.native/binding_static_model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "binding_static_model.h"

using namespace rl_net_native;
using namespace reinforcement_learning;

binding_static_model::binding_static_model(const char* vw_model, const size_t len) : vw_model(vw_model), len(len) {}

int binding_static_model::get_data(model_transport::model_data& data, reinforcement_learning::api_status* status)
{
return data.set_data(vw_model, len);
}
27 changes: 27 additions & 0 deletions bindings/cs/rl.net.native/binding_static_model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include "err_constants.h"
#include "model_mgmt.h"

bassmang marked this conversation as resolved.
Show resolved Hide resolved
#include <cstring>

namespace model_transport = reinforcement_learning::model_management;

namespace rl_net_native
{

namespace constants
{
const char* const BINDING_DATA_TRANSPORT = "BINDING_DATA_TRANSPORT";
}
class binding_static_model : public model_transport::i_data_transport
{
public:
binding_static_model(const char* vw_model, const size_t len);
int get_data(model_transport::model_data& data, reinforcement_learning::api_status* status = nullptr) override;

private:
const char* vw_model;
const size_t len;
};
} // namespace rl_net_native
35 changes: 35 additions & 0 deletions bindings/cs/rl.net.native/rl.net.factory_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,40 @@ API factory_context_t* CreateFactoryContext()
return context;
}

void cleanup_data_transport_factory(data_transport_factory_t* data_transport_factory)
{
if (data_transport_factory != nullptr && data_transport_factory != &reinforcement_learning::data_transport_factory)
{
// We overrode the built-in data_transport_factory
delete data_transport_factory;
}
}

API factory_context_t* CreateFactoryContextWithStaticModel(const char* vw_model, const size_t len)
{
using namespace reinforcement_learning::model_management;
auto context = CreateFactoryContext();
char* vw_model_copy = new char[len];
memcpy(vw_model_copy, vw_model, len);

auto data_transport_factory_fn = [vw_model_copy, len](std::unique_ptr<model_management::i_data_transport>& retval,
const utility::configuration& configuration, i_trace* trace_logger,
api_status* status) -> int
{
retval.reset(new rl_net_native::binding_static_model(vw_model_copy, len));
return error_code::success;
};

data_transport_factory_t* data_transport_factory =
new data_transport_factory_t(reinforcement_learning::data_transport_factory);
data_transport_factory->register_type(rl_net_native::constants::BINDING_DATA_TRANSPORT, data_transport_factory_fn);

std::swap(context->data_transport_factory, data_transport_factory);
cleanup_data_transport_factory(data_transport_factory);

return context;
}

void cleanup_trace_logger_factory(trace_logger_factory_t* trace_logger_factory)
{
if (trace_logger_factory != nullptr && trace_logger_factory != &reinforcement_learning::trace_logger_factory)
Expand All @@ -39,6 +73,7 @@ API void DeleteFactoryContext(factory_context_t* context)
{
cleanup_trace_logger_factory(context->trace_logger_factory);
cleanup_sender_factory(context->sender_factory);
cleanup_data_transport_factory(context->data_transport_factory);

// TODO: Once we project the others, we will need to add them to cleanup.

Expand Down
6 changes: 5 additions & 1 deletion bindings/cs/rl.net.native/rl.net.factory_context.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
#pragma once

#include "binding_sender.h"
#include "binding_static_model.h"
#include "factory_resolver.h"
#include "model_mgmt.h"
#include "rl.net.native.h"

#include <cstddef>

typedef struct factory_context
{
reinforcement_learning::trace_logger_factory_t* trace_logger_factory;
Expand All @@ -16,8 +20,8 @@ typedef struct factory_context
extern "C"
{
API factory_context_t* CreateFactoryContext();
API factory_context_t* CreateFactoryContextWithStaticModel(const char* vw_model, const size_t len);
API void DeleteFactoryContext(factory_context_t* context);

API void SetFactoryContextBindingSenderFactory(
factory_context_t* context, rl_net_native::sender_create_fn create_fn, rl_net_native::sender_vtable_t vtable);
}
22 changes: 22 additions & 0 deletions bindings/cs/rl.net/FactoryContext.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using System;
using System.Threading;
using System.Runtime.InteropServices;
using System.Collections.Generic;
using System.Linq;

using Rl.Net.Native;

Expand All @@ -10,6 +12,9 @@ public sealed class FactoryContext : NativeObject<FactoryContext>
[DllImport("rlnetnative")]
private static extern IntPtr CreateFactoryContext();

[DllImport("rlnetnative")]
private static extern IntPtr CreateFactoryContextWithStaticModel(IntPtr vw_model, int len);

[DllImport("rlnetnative")]
private static extern void DeleteFactoryContext(IntPtr context);

Expand All @@ -20,6 +25,23 @@ public sealed class FactoryContext : NativeObject<FactoryContext>
{
}

public FactoryContext(IEnumerable<byte> vwModelEnumerable) : base(
new New<FactoryContext>(() => {
var vwModelArray = vwModelEnumerable.ToArray();
GCHandle handle = GCHandle.Alloc(vwModelArray, GCHandleType.Pinned);
try {
Copy link
Member

@rajan-chari rajan-chari Feb 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to check if this makes a copy here? Unclear to me...
CreateFactoryContextWithStaticModel does make a copy. Wondering if there are 2 copies happening.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a copy happening here but I believe it is necessary since they pass in a List which doesn't have a way to obtain a pointer to it's internal array

Copy link
Member

@lokitoth lokitoth Feb 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a way to avoid the copy on .NET 5 or above with List<> types, but not generally, no.

See https://learn.microsoft.com/en-us/dotnet/api/system.runtime.interopservices.collectionsmarshal.asspan?view=net-5.0

IntPtr ptr = handle.AddrOfPinnedObject();
return CreateFactoryContextWithStaticModel(ptr, vwModelArray.Length);
}
finally {
if (handle.IsAllocated)
handle.Free();
}
}),
new Delete<FactoryContext>(DeleteFactoryContext))
{
}

private GCHandleLifetime registeredSenderCreateHandle;

public void SetSenderFactory<TSender>(Func<IReadOnlyConfiguration, ErrorCallback, TSender> createSender) where TSender : ISender
Expand Down
1 change: 1 addition & 0 deletions bindings/cs/rl.net/LiveModel.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Runtime.InteropServices;
using System.Collections.Generic;

using Rl.Net.Native;

Expand Down
Binary file added examples/test_cpp/model.vw
Binary file not shown.
1 change: 1 addition & 0 deletions include/errors_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,5 @@ ERROR_CODE_DEFINITION(48, extension_error, "Error from extension: ")
ERROR_CODE_DEFINITION(49, baseline_actions_not_defined, "Baseline Actions must be defined in apprentice mode")
ERROR_CODE_DEFINITION(50, http_api_key_not_provided, "Http api key must be provided")
ERROR_CODE_DEFINITION(51, http_model_uri_not_provided, "Model Blob URI parameter was not passed in via configuration")
ERROR_CODE_DEFINITION(52, static_model_load_error, "Static model passed in C# layer is not loading properly")
//! [Error Definitions]
5 changes: 5 additions & 0 deletions include/model_mgmt.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#pragma once
#include "err_constants.h"
#include "multistep.h"

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -34,6 +36,9 @@ class model_data
char* alloc(size_t desired);
void free();

// Set data
int set_data(const char* vw_model, size_t len);

model_data() = default;
~model_data() = default;

Expand Down
13 changes: 13 additions & 0 deletions rlclientlib/model_mgmt/model_mgmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,18 @@ char* model_data::alloc(const size_t desired)

void model_data::free() { _data.clear(); }

int model_data::set_data(const char* vw_model, size_t len)
{
if (vw_model == nullptr || len == 0) { return reinforcement_learning::error_code::static_model_load_error; }

char* buffer = this->alloc(len);
if (buffer == nullptr) { return reinforcement_learning::error_code::static_model_load_error; }

memcpy(buffer, vw_model, len);
this->data_sz(len);

return reinforcement_learning::error_code::success;
}

} // namespace model_management
} // namespace reinforcement_learning
Loading