diff --git a/bindings/cs/CMakeLists.txt b/bindings/cs/CMakeLists.txt index 70bf843ca..7c3d0bb1a 100644 --- a/bindings/cs/CMakeLists.txt +++ b/bindings/cs/CMakeLists.txt @@ -4,4 +4,5 @@ include(FindDotnet) add_subdirectory(rl.net.native) add_subdirectory(rl.net) add_subdirectory(rl.net.cli) +add_subdirectory(rl.net.flatbuffer) add_subdirectory(rl.net.cli.test) diff --git a/bindings/cs/rl.net.cli.test/CMakeLists.txt b/bindings/cs/rl.net.cli.test/CMakeLists.txt index 0b986480d..e5f21148c 100644 --- a/bindings/cs/rl.net.cli.test/CMakeLists.txt +++ b/bindings/cs/rl.net.cli.test/CMakeLists.txt @@ -17,7 +17,7 @@ else() add_custom_target(rl.net.cli.test ALL COMMAND ${DOTNET_COMMAND} build ${CMAKE_CURRENT_SOURCE_DIR} -o $ -v m --nologo --no-dependencies /clp:NoSummary COMMENT Building rl.net.cli.test - DEPENDS rl.net.cli + DEPENDS rl.net.cli rl.net.flatbuffers SOURCES ${RL_NET_CLI_TEST_SOURCES}) endif() diff --git a/bindings/cs/rl.net.cli.test/UnicodeTest.cs b/bindings/cs/rl.net.cli.test/UnicodeTest.cs index 8527c856c..1055b58bd 100644 --- a/bindings/cs/rl.net.cli.test/UnicodeTest.cs +++ b/bindings/cs/rl.net.cli.test/UnicodeTest.cs @@ -146,7 +146,7 @@ void CleanupPInvokeOverrides() { NativeMethods.ConfigurationGetOverride = null; NativeMethods.ConfigurationSetOverride = null; - NativeMethods.LiveModelChooseRankOverride = null; + //NativeMethods.LiveModelChooseRankOverride = null; NativeMethods.LiveModelChooseRankWithFlagsOverride = null; NativeMethods.LiveModelReportActionTakenOverride = null; NativeMethods.LiveModelReportOutcomeFOverride = null; @@ -374,8 +374,8 @@ public void Test_LiveModel_RequestEpisodicE2E() private void Run_LiveModelChooseRank_Test(LiveModel liveModel, string eventId, string contextJson) { - NativeMethods.LiveModelChooseRankOverride = - (IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, IntPtr rankingResponse, IntPtr apiStatus) => + NativeMethods.LiveModelChooseRankWithFlagsOverride = + (IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr apiStatus) => { string eventIdMarshalledBack = NativeMethods.StringMarshallingFunc(eventIdPtr); Assert.AreEqual(eventId, eventIdMarshalledBack, "Marshalling eventId does not work properly in LiveModelChooseRank"); @@ -417,11 +417,11 @@ public void Test_LiveModel_ChooseRank() private void Run_LiveModelRequestDecision_Test(LiveModel liveModel, string contextJson) { - NativeMethods.LiveModelRequestDecisionOverride = - (IntPtr liveModelPtr, IntPtr contextJsonPtr, int contextJsonSize, IntPtr rankingResponse, IntPtr ApiStatus) => + NativeMethods.LiveModelRequestDecisionWithFlagsOverride = + (IntPtr liveModelPtr, IntPtr contextJsonPtr, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr ApiStatus) => { string contextJsonMarshalledBack = NativeMethods.StringMarshallingFunc(contextJsonPtr); - Assert.AreEqual(contextJson, contextJsonMarshalledBack, "Marshalling contextJson does not work properly in LiveModelRequestDecision"); + Assert.AreEqual(contextJson, contextJsonMarshalledBack, "Marshalling contextJson does not work properly in LiveModelRequestDecisionWithFlags"); return NativeMethods.SuccessStatus; }; @@ -454,11 +454,11 @@ public void Test_LiveModel_RequestDecision() private void Run_LiveModelRequestMultiSlotDetailed_Test(LiveModel liveModel, string contextJson, string eventId) { - NativeMethods.LiveModelRequestMultiSlotDecisionDetailedOverride = - (IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, IntPtr rankingResponse, IntPtr ApiStatus) => + NativeMethods.LiveModelRequestMultiSlotDecisionDetailedWithFlagsOverride = + (IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr ApiStatus) => { string contextJsonMarshalledBack = NativeMethods.StringMarshallingFunc(contextJsonPtr); - Assert.AreEqual(contextJson, contextJsonMarshalledBack, "Marshalling contextJson does not work properly in LiveModelRequestMultiSlotDecisionDetailed"); + Assert.AreEqual(contextJson, contextJsonMarshalledBack, "Marshalling contextJson does not work properly in LiveModelRequestDecisionDetailedWithFlags"); return NativeMethods.SuccessStatus; }; @@ -491,11 +491,11 @@ public void Test_LiveModel_RequestMultiSlotDecisionDetailed() private void Run_LiveModelRequestMultiSlot_Test(LiveModel liveModel, string contextJson, string eventId) { - NativeMethods.LiveModelRequestMultiSlotDecisionOverride = - (IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, IntPtr rankingResponse, IntPtr ApiStatus) => + NativeMethods.LiveModelRequestMultiSlotDecisionWithFlagsOverride = + (IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr ApiStatus) => { string contextJsonMarshalledBack = NativeMethods.StringMarshallingFunc(contextJsonPtr); - Assert.AreEqual(contextJson, contextJsonMarshalledBack, "Marshalling contextJson does not work properly in LiveModelRequestMultiSlotDecision"); + Assert.AreEqual(contextJson, contextJsonMarshalledBack, "Marshalling contextJson does not work properly in LiveModelRequestDecisionWithFlags"); return NativeMethods.SuccessStatus; }; @@ -528,8 +528,8 @@ public void Test_LiveModel_RequestMultiSlotDecision() private void Run_LiveModelRequestContinuousAction_Test(LiveModel liveModel, string contextJson) { - NativeMethods.LiveModelRequestContinuousActionOverride = - (IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, IntPtr continuousActionResponse, IntPtr ApiStatus) => + NativeMethods.LiveModelRequestContinuousActionWithFlagsOverride = + (IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, uint flags, IntPtr continuousActionResponse, IntPtr ApiStatus) => { string contextJsonMarshalledBack = NativeMethods.StringMarshallingFunc(contextJsonPtr); Assert.AreEqual(contextJson, contextJsonMarshalledBack, "Marshalling contextJson does not work properly in LiveModelRequestContinuousAction"); diff --git a/bindings/cs/rl.net.cli.test/rl.net.cli.test.csproj b/bindings/cs/rl.net.cli.test/rl.net.cli.test.csproj index c07727cd9..3fcb7efbc 100644 --- a/bindings/cs/rl.net.cli.test/rl.net.cli.test.csproj +++ b/bindings/cs/rl.net.cli.test/rl.net.cli.test.csproj @@ -27,6 +27,7 @@ + diff --git a/bindings/cs/rl.net.flatbuffer/.gitignore b/bindings/cs/rl.net.flatbuffer/.gitignore new file mode 100644 index 000000000..e4484093b --- /dev/null +++ b/bindings/cs/rl.net.flatbuffer/.gitignore @@ -0,0 +1,2 @@ +generated/ +AnyCPU/ \ No newline at end of file diff --git a/bindings/cs/rl.net.flatbuffer/CMakeLists.txt b/bindings/cs/rl.net.flatbuffer/CMakeLists.txt new file mode 100644 index 000000000..fa7b5becd --- /dev/null +++ b/bindings/cs/rl.net.flatbuffer/CMakeLists.txt @@ -0,0 +1,44 @@ +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_LIST_DIR}/../cmake/Modules/") + +# First try to find the config version. Newer, used by vcpkg etc +find_package(Flatbuffers CONFIG) +if(TARGET flatbuffers::flatbuffers AND TARGET flatbuffers::flatc) + get_property(flatc_location TARGET flatbuffers::flatc PROPERTY LOCATION) + message(STATUS "Found Flatbuffers with CONFIG, flatc located at: ${flatc_location}") +else() + # Fallback to the old version + find_package(Flatbuffers MODULE REQUIRED) + set(flatc_location ${FLATBUFFERS_FLATC_EXECUTABLE}) + message(STATUS "Found Flatbuffers with MODULE, flatc located at: ${flatc_location}") +endif() + +include(FlatbufferUtils) + +set(RL_NET_VWFB_SCHEMA_FILES + "${CMAKE_SOURCE_DIR}/ext_libs/vowpal_wabbit/vowpalwabbit/fb_parser/schema/example.fbs" ) + +set(RL_NET_FLATBUFFERS_SOURCES + VWExampleBuilder.cs +) + +add_flatbuffer_schema( + TARGET rl.net.flatbuffers_generated + SCHEMAS ${RL_NET_VWFB_SCHEMA_FILES} + OUTPUT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/generated/vwfb + GENERATED_FILES_VAR RL_NET_FLATBUFFERS_SOURCES + FLATC_LANGUAGE CSharp + FLATC_EXE ${flatc_location} +) + +# find all the files + +if (rlclientlib_DOTNET_USE_MSPROJECT) + include_external_msproject(rl.net.flatbuffers ${CMAKE_CURRENT_SOURCE_DIR}/rl.net.flatbuffers.csproj) +else() + add_custom_target(rl.net.flatbuffers + COMMAND ${DOTNET_COMMAND} build ${CMAKE_CURRENT_SOURCE_DIR} -o $ -v n --nologo --no-dependencies /clp:NoSummary --configuration "$<$:Debug>$<$:Release>$<$:Release>" + COMMENT Building rl.net.flatbuffers + SOURCES ${RL_NET_FLATBUFFERS_SOURCES}) +endif() + +add_dependencies(rl.net.flatbuffers rl.net.flatbuffers_generated) \ No newline at end of file diff --git a/bindings/cs/rl.net.flatbuffer/VWExampleBuilder.cs b/bindings/cs/rl.net.flatbuffer/VWExampleBuilder.cs new file mode 100644 index 000000000..92cc0976c --- /dev/null +++ b/bindings/cs/rl.net.flatbuffer/VWExampleBuilder.cs @@ -0,0 +1,117 @@ + +using Google.FlatBuffers; +using System.Collections.Generic; +using vwfb = VW.parsers.flatbuffer; + +namespace Rl.Net.Flatbuffers +{ + public interface ICBExampleBuilder + { + bool BuildingShared { get; } + int BuildingAction { get; } + + void PushNamespace(string name = ""); + + void AddFeature(string name, float value = 1.0f); + void AddFeature(string name, string value); + + int PushAction(); + + byte[] FinishExample(); + } + + public class VW_CBExampleBuilder : ICBExampleBuilder + { + struct namespace_prototype + { + public string name; + + public List feature_names; + public List feature_values; + + public Offset Build(FlatBufferBuilder builder) + { + var nameOffset = builder.CreateString(this.name); + + var feature_names_array = vwfb.Namespace.CreateFeatureNamesVector(builder, this.feature_names.ToArray()); + var feature_indicies_array = vwfb.Namespace.CreateFeatureHashesVector(builder, new ulong[0]); + var feature_values_array = vwfb.Namespace.CreateFeatureValuesVector(builder, this.feature_values.ToArray()); + + // The presence of the name will get VW to compute a hash for us + return vwfb.Namespace.CreateNamespace(builder, nameOffset, (byte)this.name[0], 0, feature_names_array, feature_values_array, feature_indicies_array); + } + } + + private List> built_example_offsets; + private List namespaces = new List(); + + private FlatBufferBuilder builder; + + public VW_CBExampleBuilder() + { + this.builder = new FlatBufferBuilder(64); // todo - figure out right min size? + } + + public void PushNamespace(string name = "") + { + this.namespaces.Add(new namespace_prototype() + { + name = name, + feature_names = new List(), + feature_values = new List() + }); + } + + public void AddFeature(string name, string value) + { + this.AddFeature(name + "_" + value); + } + + public void AddFeature(string name, float value = 1.0f) + { + if (this.namespaces.Count == 0) + { + // TODO: is this the right hash? + this.PushNamespace(""); + } + + StringOffset nameOffset = this.builder.CreateString(name); + this.namespaces[this.namespaces.Count - 1].feature_names.Add(nameOffset); + this.namespaces[this.namespaces.Count - 1].feature_values.Add(value); + } + + public int PushAction() + { + this.CollectExample(); + + return this.built_example_offsets.Count - 1; + } + + private void CollectExample() + { + Offset[] namespaceOffsets = new Offset[this.namespaces.Count]; + for (int i = 0; i < this.namespaces.Count; i++) + { + namespaceOffsets[i] = this.namespaces[i].Build(this.builder); + } + + var namespacesVector = vwfb.Example.CreateNamespacesVector(this.builder, namespaceOffsets); + var result = vwfb.Example.CreateExample(this.builder, namespacesVector); + } + + public byte[] FinishExample() + { + this.CollectExample(); + + var multi_ex = vwfb.MultiExample.CreateMultiExample(this.builder, vwfb.MultiExample.CreateExamplesVector(this.builder, this.built_example_offsets.ToArray())); + var root = vwfb.ExampleRoot.CreateExampleRoot(this.builder, vwfb.ExampleType.MultiExample, multi_ex.Value); + + this.builder.FinishSizePrefixed(root.Value); + + return this.builder.SizedByteArray(); + } + + public bool BuildingShared { get => this.built_example_offsets.Count == 0; } + public int BuildingAction { get => this.built_example_offsets.Count - 1; } + } +} \ No newline at end of file diff --git a/bindings/cs/rl.net.flatbuffer/rl.net.flatbuffer.csproj b/bindings/cs/rl.net.flatbuffer/rl.net.flatbuffer.csproj new file mode 100644 index 000000000..fc6fb6025 --- /dev/null +++ b/bindings/cs/rl.net.flatbuffer/rl.net.flatbuffer.csproj @@ -0,0 +1,58 @@ + + + + netstandard2.0 + $(SolutionDir). + $(BinaryOutputBase.Trim())\$(Platform)\$(Configuration) + AnyCPU;x64 + x64 + false + rl.net.flatbuffer + Rl.Net.Flatbuffers + true + True + $(MSBuildProjectDir)..\..\..\ext_libs\vowpal_wabbit\cs\vw_key.snk + + + + + + true + + + + true + + + + true + + + + true + true + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/bindings/cs/rl.net.native/rl.net.live_model.cc b/bindings/cs/rl.net.native/rl.net.live_model.cc index b91432670..9683e8b01 100644 --- a/bindings/cs/rl.net.native/rl.net.live_model.cc +++ b/bindings/cs/rl.net.native/rl.net.live_model.cc @@ -74,35 +74,40 @@ API int LiveModelInit(livemodel_context_t* context, reinforcement_learning::api_ return context->livemodel->init(status); } -API int LiveModelChooseRank(livemodel_context_t* context, const char* event_id, const char* context_json, - int context_json_size, reinforcement_learning::ranking_response* resp, reinforcement_learning::api_status* status) -{ - if (event_id == nullptr) - { - return context->livemodel->choose_rank({context_json, static_cast(context_json_size)}, *resp, status); - } - - return context->livemodel->choose_rank( - event_id, {context_json, static_cast(context_json_size)}, *resp, status); -} +//API int LiveModelChooseRank(livemodel_context_t* context, const char* event_id, const char* context_json, +// int context_json_size, reinforcement_learning::ranking_response* resp, reinforcement_learning::api_status* status) +//{ +// if (event_id == nullptr) +// { +// return context->livemodel->choose_rank({context_json, static_cast(context_json_size)}, *resp, status); +// } +// +// return context->livemodel->choose_rank( +// event_id, {context_json, static_cast(context_json_size)}, *resp, status); +//} API int LiveModelChooseRankWithFlags(livemodel_context_t* context, const char* event_id, const char* context_json, int context_json_size, unsigned int flags, reinforcement_learning::ranking_response* resp, reinforcement_learning::api_status* status) { + if (event_id == nullptr) + { + return context->livemodel->choose_rank({context_json, static_cast(context_json_size)}, flags, *resp, status); + } + return context->livemodel->choose_rank( event_id, {context_json, static_cast(context_json_size)}, flags, *resp, status); } -API int LiveModelRequestContinuousAction(livemodel_context_t* context, const char* event_id, const char* context_json, - int context_json_size, reinforcement_learning::continuous_action_response* resp, - reinforcement_learning::api_status* status) -{ - RL_IGNORE_DEPRECATED_USAGE_START - return context->livemodel->request_continuous_action( - event_id, {context_json, static_cast(context_json_size)}, *resp, status); - RL_IGNORE_DEPRECATED_USAGE_END -} +//API int LiveModelRequestContinuousAction(livemodel_context_t* context, const char* event_id, const char* context_json, +// int context_json_size, reinforcement_learning::continuous_action_response* resp, +// reinforcement_learning::api_status* status) +//{ +// RL_IGNORE_DEPRECATED_USAGE_START +// return context->livemodel->request_continuous_action( +// event_id, {context_json, static_cast(context_json_size)}, *resp, status); +// RL_IGNORE_DEPRECATED_USAGE_END +//} API int LiveModelRequestContinuousActionWithFlags(livemodel_context_t* context, const char* event_id, const char* context_json, int context_json_size, unsigned int flags, @@ -114,13 +119,13 @@ API int LiveModelRequestContinuousActionWithFlags(livemodel_context_t* context, RL_IGNORE_DEPRECATED_USAGE_END } -API int LiveModelRequestDecision(livemodel_context_t* context, const char* context_json, int context_json_size, - reinforcement_learning::decision_response* resp, reinforcement_learning::api_status* status) -{ - RL_IGNORE_DEPRECATED_USAGE_START - return context->livemodel->request_decision({context_json, static_cast(context_json_size)}, *resp, status); - RL_IGNORE_DEPRECATED_USAGE_END -} +//API int LiveModelRequestDecision(livemodel_context_t* context, const char* context_json, int context_json_size, +// reinforcement_learning::decision_response* resp, reinforcement_learning::api_status* status) +//{ +// RL_IGNORE_DEPRECATED_USAGE_START +// return context->livemodel->request_decision({context_json, static_cast(context_json_size)}, *resp, status); +// RL_IGNORE_DEPRECATED_USAGE_END +//} API int LiveModelRequestDecisionWithFlags(livemodel_context_t* context, const char* context_json, int context_json_size, unsigned int flags, reinforcement_learning::decision_response* resp, reinforcement_learning::api_status* status) @@ -131,19 +136,19 @@ API int LiveModelRequestDecisionWithFlags(livemodel_context_t* context, const ch RL_IGNORE_DEPRECATED_USAGE_END } -API int LiveModelRequestMultiSlotDecision(livemodel_context_t* context, const char* event_id, const char* context_json, - int context_json_size, reinforcement_learning::multi_slot_response* resp, - reinforcement_learning::api_status* status) -{ - RL_IGNORE_DEPRECATED_USAGE_START - if (event_id == nullptr) - return context->livemodel->request_multi_slot_decision( - {context_json, static_cast(context_json_size)}, *resp, status); - else - return context->livemodel->request_multi_slot_decision( - event_id, {context_json, static_cast(context_json_size)}, *resp, status); - RL_IGNORE_DEPRECATED_USAGE_END -} +//API int LiveModelRequestMultiSlotDecision(livemodel_context_t* context, const char* event_id, const char* context_json, +// int context_json_size, reinforcement_learning::multi_slot_response* resp, +// reinforcement_learning::api_status* status) +//{ +// RL_IGNORE_DEPRECATED_USAGE_START +// if (event_id == nullptr) +// return context->livemodel->request_multi_slot_decision( +// {context_json, static_cast(context_json_size)}, *resp, status); +// else +// return context->livemodel->request_multi_slot_decision( +// event_id, {context_json, static_cast(context_json_size)}, *resp, status); +// RL_IGNORE_DEPRECATED_USAGE_END +//} API int LiveModelRequestMultiSlotDecisionWithFlags(livemodel_context_t* context, const char* event_id, const char* context_json, int context_json_size, unsigned int flags, @@ -171,19 +176,19 @@ API int LiveModelRequestMultiSlotDecisionWithBaselineAndFlags(livemodel_context_ RL_IGNORE_DEPRECATED_USAGE_END } -API int LiveModelRequestMultiSlotDecisionDetailed(livemodel_context_t* context, const char* event_id, - const char* context_json, int context_json_size, reinforcement_learning::multi_slot_response_detailed* resp, - reinforcement_learning::api_status* status) -{ - RL_IGNORE_DEPRECATED_USAGE_START - if (event_id == nullptr) - return context->livemodel->request_multi_slot_decision( - {context_json, static_cast(context_json_size)}, *resp, status); - else - return context->livemodel->request_multi_slot_decision( - event_id, {context_json, static_cast(context_json_size)}, *resp, status); - RL_IGNORE_DEPRECATED_USAGE_END -} +//API int LiveModelRequestMultiSlotDecisionDetailed(livemodel_context_t* context, const char* event_id, +// const char* context_json, int context_json_size, reinforcement_learning::multi_slot_response_detailed* resp, +// reinforcement_learning::api_status* status) +//{ +// RL_IGNORE_DEPRECATED_USAGE_START +// if (event_id == nullptr) +// return context->livemodel->request_multi_slot_decision( +// {context_json, static_cast(context_json_size)}, *resp, status); +// else +// return context->livemodel->request_multi_slot_decision( +// event_id, {context_json, static_cast(context_json_size)}, *resp, status); +// RL_IGNORE_DEPRECATED_USAGE_END +//} API int LiveModelRequestMultiSlotDecisionDetailedWithFlags(livemodel_context_t* context, const char* event_id, const char* context_json, int context_json_size, unsigned int flags, @@ -222,14 +227,14 @@ API int LiveModelRequestEpisodicDecisionWithFlags(livemodel_context_t* context, RL_IGNORE_DEPRECATED_USAGE_END } -API int LiveModelRequestEpisodicDecision(livemodel_context_t* context, const char* event_id, const char* previous_id, - const char* context_json, reinforcement_learning::ranking_response& resp, - reinforcement_learning::episode_state& episode, reinforcement_learning::api_status* status) -{ - RL_IGNORE_DEPRECATED_USAGE_START - return context->livemodel->request_episodic_decision(event_id, previous_id, context_json, resp, episode, status); - RL_IGNORE_DEPRECATED_USAGE_END -} +//API int LiveModelRequestEpisodicDecision(livemodel_context_t* context, const char* event_id, const char* previous_id, +// const char* context_json, reinforcement_learning::ranking_response& resp, +// reinforcement_learning::episode_state& episode, reinforcement_learning::api_status* status) +//{ +// RL_IGNORE_DEPRECATED_USAGE_START +// return context->livemodel->request_episodic_decision(event_id, previous_id, context_json, resp, episode, status); +// RL_IGNORE_DEPRECATED_USAGE_END +//} API int LiveModelReportActionTaken( livemodel_context_t* context, const char* event_id, reinforcement_learning::api_status* status) diff --git a/bindings/cs/rl.net.native/rl.net.live_model.h b/bindings/cs/rl.net.native/rl.net.live_model.h index 0fbadff89..0fdecb3ad 100644 --- a/bindings/cs/rl.net.native/rl.net.live_model.h +++ b/bindings/cs/rl.net.native/rl.net.live_model.h @@ -23,29 +23,29 @@ extern "C" API int LiveModelInit(livemodel_context_t* livemodel, reinforcement_learning::api_status* status = nullptr); - API int LiveModelChooseRank(livemodel_context_t* livemodel, const char* event_id, const char* context_json, + /*API int LiveModelChooseRank(livemodel_context_t* livemodel, const char* event_id, const char* context_json, int context_json_size, reinforcement_learning::ranking_response* resp, - reinforcement_learning::api_status* status = nullptr); + reinforcement_learning::api_status* status = nullptr);*/ API int LiveModelChooseRankWithFlags(livemodel_context_t* livemodel, const char* event_id, const char* context_json, int context_json_size, unsigned int flags, reinforcement_learning::ranking_response* resp, reinforcement_learning::api_status* status = nullptr); - API int LiveModelRequestContinuousAction(livemodel_context_t* livemodel, const char* event_id, + /*API int LiveModelRequestContinuousAction(livemodel_context_t* livemodel, const char* event_id, const char* context_json, int context_json_size, reinforcement_learning::continuous_action_response* resp, - reinforcement_learning::api_status* status = nullptr); + reinforcement_learning::api_status* status = nullptr);*/ API int LiveModelRequestContinuousActionWithFlags(livemodel_context_t* livemodel, const char* event_id, const char* context_json, int context_json_size, unsigned int flags, reinforcement_learning::continuous_action_response* resp, reinforcement_learning::api_status* status = nullptr); - API int LiveModelRequestDecision(livemodel_context_t* livemodel, const char* context_json, int context_json_size, - reinforcement_learning::decision_response* resp, reinforcement_learning::api_status* status = nullptr); + /*API int LiveModelRequestDecision(livemodel_context_t* livemodel, const char* context_json, int context_json_size, + reinforcement_learning::decision_response* resp, reinforcement_learning::api_status* status = nullptr);*/ API int LiveModelRequestDecisionWithFlags(livemodel_context_t* livemodel, const char* context_json, int context_json_size, unsigned int flags, reinforcement_learning::decision_response* resp, reinforcement_learning::api_status* status = nullptr); - API int LiveModelRequestMultiSlotDecision(livemodel_context_t* context, const char* event_id, + /*API int LiveModelRequestMultiSlotDecision(livemodel_context_t* context, const char* event_id, const char* context_json, int context_json_size, reinforcement_learning::multi_slot_response* resp, - reinforcement_learning::api_status* status = nullptr); + reinforcement_learning::api_status* status = nullptr);*/ API int LiveModelRequestMultiSlotDecisionWithFlags(livemodel_context_t* context, const char* event_id, const char* context_json, int context_json_size, unsigned int flags, reinforcement_learning::multi_slot_response* resp, reinforcement_learning::api_status* status = nullptr); @@ -54,9 +54,9 @@ extern "C" reinforcement_learning::multi_slot_response* resp, const int* baseline_actions, const size_t baseline_actions_size, reinforcement_learning::api_status* status = nullptr); - API int LiveModelRequestMultiSlotDecisionDetailed(livemodel_context_t* context, const char* event_id, + /*API int LiveModelRequestMultiSlotDecisionDetailed(livemodel_context_t* context, const char* event_id, const char* context_json, int context_json_size, reinforcement_learning::multi_slot_response_detailed* resp, - reinforcement_learning::api_status* status = nullptr); + reinforcement_learning::api_status* status = nullptr);*/ API int LiveModelRequestMultiSlotDecisionDetailedWithFlags(livemodel_context_t* context, const char* event_id, const char* context_json, int context_json_size, unsigned int flags, reinforcement_learning::multi_slot_response_detailed* resp, reinforcement_learning::api_status* status = nullptr); @@ -68,9 +68,9 @@ extern "C" const char* previous_id, const char* context_json, unsigned int flags, reinforcement_learning::ranking_response& resp, reinforcement_learning::episode_state& episode, reinforcement_learning::api_status* status); - API int LiveModelRequestEpisodicDecision(livemodel_context_t* context, const char* event_id, const char* previous_id, + /*API int LiveModelRequestEpisodicDecision(livemodel_context_t* context, const char* event_id, const char* previous_id, const char* context_json, reinforcement_learning::ranking_response& resp, - reinforcement_learning::episode_state& episode, reinforcement_learning::api_status* status); + reinforcement_learning::episode_state& episode, reinforcement_learning::api_status* status);*/ API int LiveModelReportActionTaken( livemodel_context_t* livemodel, const char* event_id, reinforcement_learning::api_status* status = nullptr); diff --git a/bindings/cs/rl.net/CBLoop.cs b/bindings/cs/rl.net/CBLoop.cs index 5b8846697..e43f6d6de 100644 --- a/bindings/cs/rl.net/CBLoop.cs +++ b/bindings/cs/rl.net/CBLoop.cs @@ -19,34 +19,34 @@ internal static partial class NativeMethods [DllImport("rlnetnative")] public static extern int CBLoopInit(IntPtr cbLoop, IntPtr apiStatus); - [DllImport("rlnetnative", EntryPoint = "CBLoopChooseRank")] - private static extern int CBLoopChooseRankNative(IntPtr cbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr rankingResponse, IntPtr apiStatus); + // [DllImport("rlnetnative", EntryPoint = "CBLoopChooseRank")] + // private static extern int CBLoopChooseRankNative(IntPtr cbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr rankingResponse, IntPtr apiStatus); - internal static Func CBLoopChooseRankOverride { get; set; } + // internal static Func CBLoopChooseRankOverride { get; set; } - public static int CBLoopChooseRank(IntPtr cbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr rankingResponse, IntPtr apiStatus) - { - if (CBLoopChooseRankOverride != null) - { - return CBLoopChooseRankOverride(cbLoop, eventId, contextJson, contextJsonSize, rankingResponse, apiStatus); - } + // public static int CBLoopChooseRank(IntPtr cbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr rankingResponse, IntPtr apiStatus) + // { + // if (CBLoopChooseRankOverride != null) + // { + // return CBLoopChooseRankOverride(cbLoop, eventId, contextJson, contextJsonSize, rankingResponse, apiStatus); + // } - return CBLoopChooseRankNative(cbLoop, eventId, contextJson, contextJsonSize, rankingResponse, apiStatus); - } + // return CBLoopChooseRankNative(cbLoop, eventId, contextJson, contextJsonSize, rankingResponse, apiStatus); + // } [DllImport("rlnetnative", EntryPoint = "CBLoopChooseRankWithFlags")] - private static extern int CBLoopChooseRankWithFlagsNative(IntPtr cbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr apiStatus); + private static extern int CBLoopChooseRankWithFlagsNative(IntPtr cbLoop, IntPtr eventId, IntPtr contextBytes, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr apiStatus); internal static Func CBLoopChooseRankWithFlagsOverride { get; set; } - public static int CBLoopChooseRankWithFlags(IntPtr cbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr apiStatus) + public static int CBLoopChooseRankWithFlags(IntPtr cbLoop, IntPtr eventId, IntPtr contextBytes, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr apiStatus) { if (CBLoopChooseRankWithFlagsOverride != null) { - return CBLoopChooseRankWithFlagsOverride(cbLoop, eventId, contextJson, contextJsonSize, flags, rankingResponse, apiStatus); + return CBLoopChooseRankWithFlagsOverride(cbLoop, eventId, contextBytes, contextJsonSize, flags, rankingResponse, apiStatus); } - return CBLoopChooseRankWithFlagsNative(cbLoop, eventId, contextJson, contextJsonSize, flags, rankingResponse, apiStatus); + return CBLoopChooseRankWithFlagsNative(cbLoop, eventId, contextBytes, contextJsonSize, flags, rankingResponse, apiStatus); } [DllImport("rlnetnative", EntryPoint = "CBLoopReportActionTaken")] @@ -165,51 +165,62 @@ private static void CheckJsonString(string json) } } - unsafe private static int CBLoopChooseRank(IntPtr cbLoop, string eventId, string contextJson, IntPtr rankingResponse, IntPtr apiStatus) + private const uint DEFAULT_FLAGS = (uint)ActionFlags.Default; + + // unsafe private static int CBLoopChooseRank(IntPtr cbLoop, string eventId, string contextJson, IntPtr rankingResponse, IntPtr apiStatus) + // { + // CheckJsonString(contextJson); + + // fixed (byte* contextJsonUtf8Bytes = NativeMethods.StringEncoding.GetBytes(contextJson)) + // { + // int contextJsonSize = NativeMethods.StringEncoding.GetByteCount(contextJson); + // IntPtr contextJsonUtf8Ptr = new IntPtr(contextJsonUtf8Bytes); + + // // It is important to pass null on faithfully here, because we rely on this to switch between auto-generate + // // eventId and use supplied eventId at the rl.net.native layer. + // if (eventId == null) + // { + // return NativeMethods.CBLoopChooseRank(cbLoop, IntPtr.Zero, contextJsonUtf8Ptr, contextJsonSize, rankingResponse, apiStatus); + // } + + // fixed (byte* eventIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(eventId)) + // { + // return NativeMethods.CBLoopChooseRank(cbLoop, new IntPtr(eventIdUtf8Bytes), contextJsonUtf8Ptr, contextJsonSize, rankingResponse, apiStatus); + // } + // } + // } + + // TODO: Should we reduce the rl.net.native interface to only have one of these? + unsafe private static int CBLoopChooseRankWithFlags(IntPtr cbLoop, string eventId, byte[] contextBytes, uint flags, IntPtr rankingResponse, IntPtr apiStatus) { - CheckJsonString(contextJson); + //CheckJsonString(contextJson); - fixed (byte* contextJsonUtf8Bytes = NativeMethods.StringEncoding.GetBytes(contextJson)) + fixed (byte* contextBytesFixed = contextBytes) { - int contextJsonSize = NativeMethods.StringEncoding.GetByteCount(contextJson); - IntPtr contextJsonUtf8Ptr = new IntPtr(contextJsonUtf8Bytes); + int contextBytesSize = contextBytes.Length; + IntPtr contextBytesPtr = new IntPtr(contextBytesFixed); // It is important to pass null on faithfully here, because we rely on this to switch between auto-generate // eventId and use supplied eventId at the rl.net.native layer. if (eventId == null) { - return NativeMethods.CBLoopChooseRank(cbLoop, IntPtr.Zero, contextJsonUtf8Ptr, contextJsonSize, rankingResponse, apiStatus); + return NativeMethods.CBLoopChooseRankWithFlags(cbLoop, IntPtr.Zero, contextBytesPtr, contextBytesSize, flags, rankingResponse, apiStatus); } fixed (byte* eventIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(eventId)) { - return NativeMethods.CBLoopChooseRank(cbLoop, new IntPtr(eventIdUtf8Bytes), contextJsonUtf8Ptr, contextJsonSize, rankingResponse, apiStatus); + return NativeMethods.CBLoopChooseRankWithFlags(cbLoop, new IntPtr(eventIdUtf8Bytes), contextBytesPtr, contextBytesSize, flags, rankingResponse, apiStatus); } } } - // TODO: Should we reduce the rl.net.native interface to only have one of these? unsafe private static int CBLoopChooseRankWithFlags(IntPtr cbLoop, string eventId, string contextJson, uint flags, IntPtr rankingResponse, IntPtr apiStatus) { CheckJsonString(contextJson); - fixed (byte* contextJsonUtf8Bytes = NativeMethods.StringEncoding.GetBytes(contextJson)) - { - int contextJsonSize = NativeMethods.StringEncoding.GetByteCount(contextJson); - IntPtr contextJsonUtf8Ptr = new IntPtr(contextJsonUtf8Bytes); + byte[] contextJsonEncodedBytes = NativeMethods.StringEncoding.GetBytes(contextJson); - // It is important to pass null on faithfully here, because we rely on this to switch between auto-generate - // eventId and use supplied eventId at the rl.net.native layer. - if (eventId == null) - { - return NativeMethods.CBLoopChooseRankWithFlags(cbLoop, IntPtr.Zero, contextJsonUtf8Ptr, contextJsonSize, flags, rankingResponse, apiStatus); - } - - fixed (byte* eventIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(eventId)) - { - return NativeMethods.CBLoopChooseRankWithFlags(cbLoop, new IntPtr(eventIdUtf8Bytes), contextJsonUtf8Ptr, contextJsonSize, flags, rankingResponse, apiStatus); - } - } + return CBLoopChooseRankWithFlags(cbLoop, eventId, contextJsonEncodedBytes, flags, rankingResponse, apiStatus); } unsafe private static int CBLoopReportActionTaken(IntPtr cbLoop, string eventId, IntPtr apiStatus) @@ -331,9 +342,23 @@ public bool TryChooseRank(string eventId, string contextJson, out RankingRespons return this.TryChooseRank(eventId, contextJson, response, apiStatus); } + public bool TryChooseRank(string eventId, byte[] contextBytes, out RankingResponse response, ApiStatus apiStatus = null) + { + response = new RankingResponse(); + return this.TryChooseRank(eventId, contextBytes, response, apiStatus); + } + public bool TryChooseRank(string eventId, string contextJson, RankingResponse response, ApiStatus apiStatus = null) { - int result = CBLoopChooseRank(this.DangerousGetHandle(), eventId, contextJson, response.DangerousGetHandle(), apiStatus.ToNativeHandleOrNullptrDangerous()); + int result = CBLoopChooseRankWithFlags(this.DangerousGetHandle(), eventId, contextJson, CBLoop.DEFAULT_FLAGS, response.DangerousGetHandle(), apiStatus.ToNativeHandleOrNullptrDangerous()); + + GC.KeepAlive(this); + return result == NativeMethods.SuccessStatus; + } + + public bool TryChooseRank(string eventId, byte[] contextBytes, RankingResponse response, ApiStatus apiStatus = null) + { + int result = CBLoopChooseRankWithFlags(this.DangerousGetHandle(), eventId, contextBytes, CBLoop.DEFAULT_FLAGS, response.DangerousGetHandle(), apiStatus.ToNativeHandleOrNullptrDangerous()); GC.KeepAlive(this); return result == NativeMethods.SuccessStatus; @@ -352,12 +377,31 @@ public RankingResponse ChooseRank(string eventId, string contextJson) return result; } + public RankingResponse ChooseRank(string eventId, byte[] contextBytes) + { + RankingResponse result = new RankingResponse(); + + using (ApiStatus apiStatus = new ApiStatus()) + if (!this.TryChooseRank(eventId, contextBytes, result, apiStatus)) + { + throw new RLException(apiStatus); + } + + return result; + } + public bool TryChooseRank(string eventId, string contextJson, ActionFlags flags, out RankingResponse response, ApiStatus apiStatus = null) { response = new RankingResponse(); return this.TryChooseRank(eventId, contextJson, flags, response, apiStatus); } + public bool TryChooseRank(string eventId, byte[] contextBytes, ActionFlags flags, out RankingResponse response, ApiStatus apiStatus = null) + { + response = new RankingResponse(); + return this.TryChooseRank(eventId, contextBytes, flags, response, apiStatus); + } + public bool TryChooseRank(string eventId, string contextJson, ActionFlags flags, RankingResponse response, ApiStatus apiStatus = null) { int result = CBLoopChooseRankWithFlags(this.DangerousGetHandle(), eventId, contextJson, (uint)flags, response.DangerousGetHandle(), apiStatus.ToNativeHandleOrNullptrDangerous()); @@ -366,6 +410,14 @@ public bool TryChooseRank(string eventId, string contextJson, ActionFlags flags, return result == NativeMethods.SuccessStatus; } + public bool TryChooseRank(string eventId, byte[] contextBytes, ActionFlags flags, RankingResponse response, ApiStatus apiStatus = null) + { + int result = CBLoopChooseRankWithFlags(this.DangerousGetHandle(), eventId, contextBytes, (uint)flags, response.DangerousGetHandle(), apiStatus.ToNativeHandleOrNullptrDangerous()); + + GC.KeepAlive(this); + return result == NativeMethods.SuccessStatus; + } + public RankingResponse ChooseRank(string eventId, string contextJson, ActionFlags flags) { RankingResponse result = new RankingResponse(); @@ -379,6 +431,19 @@ public RankingResponse ChooseRank(string eventId, string contextJson, ActionFlag return result; } + public RankingResponse ChooseRank(string eventId, byte[] contextBytes, ActionFlags flags) + { + RankingResponse result = new RankingResponse(); + + using (ApiStatus apiStatus = new ApiStatus()) + if (!this.TryChooseRank(eventId, contextBytes, flags, result, apiStatus)) + { + throw new RLException(apiStatus); + } + + return result; + } + [Obsolete("Use TryQueueActionTakenEvent instead.")] public bool TryReportActionTaken(string eventId, ApiStatus apiStatus = null) => this.TryQueueActionTakenEvent(eventId, apiStatus); diff --git a/bindings/cs/rl.net/LiveModel.cs b/bindings/cs/rl.net/LiveModel.cs index 2e829d011..5685a98a0 100644 --- a/bindings/cs/rl.net/LiveModel.cs +++ b/bindings/cs/rl.net/LiveModel.cs @@ -20,50 +20,50 @@ internal static partial class NativeMethods [DllImport("rlnetnative")] public static extern int LiveModelInit(IntPtr liveModel, IntPtr apiStatus); - [DllImport("rlnetnative", EntryPoint = "LiveModelChooseRank")] - private static extern int LiveModelChooseRankNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr rankingResponse, IntPtr apiStatus); + //[DllImport("rlnetnative", EntryPoint = "LiveModelChooseRank")] + //private static extern int LiveModelChooseRankNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr rankingResponse, IntPtr apiStatus); - internal static Func LiveModelChooseRankOverride { get; set; } + // internal static Func LiveModelChooseRankOverride { get; set; } - public static int LiveModelChooseRank(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr rankingResponse, IntPtr apiStatus) - { - if (LiveModelChooseRankOverride != null) - { - return LiveModelChooseRankOverride(liveModel, eventId, contextJson, contextJsonSize, rankingResponse, apiStatus); - } + // public static int LiveModelChooseRank(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr rankingResponse, IntPtr apiStatus) + // { + // if (LiveModelChooseRankOverride != null) + // { + // return LiveModelChooseRankOverride(liveModel, eventId, contextJson, contextJsonSize, rankingResponse, apiStatus); + // } - return LiveModelChooseRankNative(liveModel, eventId, contextJson, contextJsonSize, rankingResponse, apiStatus); - } + // return LiveModelChooseRankNative(liveModel, eventId, contextJson, contextJsonSize, rankingResponse, apiStatus); + // } [DllImport("rlnetnative", EntryPoint = "LiveModelChooseRankWithFlags")] - private static extern int LiveModelChooseRankWithFlagsNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr apiStatus); + private static extern int LiveModelChooseRankWithFlagsNative(IntPtr liveModel, IntPtr eventId, IntPtr contextBytes, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr apiStatus); internal static Func LiveModelChooseRankWithFlagsOverride { get; set; } - public static int LiveModelChooseRankWithFlags(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr apiStatus) + public static int LiveModelChooseRankWithFlags(IntPtr liveModel, IntPtr eventId, IntPtr contextBytes, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr apiStatus) { if (LiveModelChooseRankWithFlagsOverride != null) { - return LiveModelChooseRankWithFlagsOverride(liveModel, eventId, contextJson, contextJsonSize, flags, rankingResponse, apiStatus); + return LiveModelChooseRankWithFlagsOverride(liveModel, eventId, contextBytes, contextJsonSize, flags, rankingResponse, apiStatus); } - return LiveModelChooseRankWithFlagsNative(liveModel, eventId, contextJson, contextJsonSize, flags, rankingResponse, apiStatus); + return LiveModelChooseRankWithFlagsNative(liveModel, eventId, contextBytes, contextJsonSize, flags, rankingResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestContinuousAction")] - private static extern int LiveModelRequestContinuousActionNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr continuousActionResponse, IntPtr apiStatus); + // [DllImport("rlnetnative", EntryPoint = "LiveModelRequestContinuousAction")] + // private static extern int LiveModelRequestContinuousActionNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr continuousActionResponse, IntPtr apiStatus); - internal static Func LiveModelRequestContinuousActionOverride { get; set; } + // internal static Func LiveModelRequestContinuousActionOverride { get; set; } - public static int LiveModelRequestContinuousAction(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr continuousActionResponse, IntPtr apiStatus) - { - if (LiveModelRequestContinuousActionOverride != null) - { - return LiveModelRequestContinuousActionOverride(liveModel, eventId, contextJson, contextJsonSize, continuousActionResponse, apiStatus); - } + // public static int LiveModelRequestContinuousAction(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr continuousActionResponse, IntPtr apiStatus) + // { + // if (LiveModelRequestContinuousActionOverride != null) + // { + // return LiveModelRequestContinuousActionOverride(liveModel, eventId, contextJson, contextJsonSize, continuousActionResponse, apiStatus); + // } - return LiveModelRequestContinuousActionNative(liveModel, eventId, contextJson, contextJsonSize, continuousActionResponse, apiStatus); - } + // return LiveModelRequestContinuousActionNative(liveModel, eventId, contextJson, contextJsonSize, continuousActionResponse, apiStatus); + // } [DllImport("rlnetnative", EntryPoint = "LiveModelRequestContinuousActionWithFlags")] private static extern int LiveModelRequestContinuousActionWithFlagsNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr continuousActionResponse, IntPtr apiStatus); @@ -80,20 +80,20 @@ public static int LiveModelRequestContinuousActionWithFlags(IntPtr liveModel, In return LiveModelRequestContinuousActionWithFlagsNative(liveModel, eventId, contextJson, contextJsonSize, flags, continuousActionResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestDecision")] - private static extern int LiveModelRequestDecisionNative(IntPtr liveModel, IntPtr contextJson, int contextJsonSize, IntPtr decisionResponse, IntPtr apiStatus); + // [DllImport("rlnetnative", EntryPoint = "LiveModelRequestDecision")] + // private static extern int LiveModelRequestDecisionNative(IntPtr liveModel, IntPtr contextJson, int contextJsonSize, IntPtr decisionResponse, IntPtr apiStatus); - internal static Func LiveModelRequestDecisionOverride { get; set; } + // internal static Func LiveModelRequestDecisionOverride { get; set; } - public static int LiveModelRequestDecision(IntPtr liveModel, IntPtr contextJson, int contextJsonSize, IntPtr decisionResponse, IntPtr apiStatus) - { - if (LiveModelRequestDecisionOverride != null) - { - return LiveModelRequestDecisionOverride(liveModel, contextJson, contextJsonSize, decisionResponse, apiStatus); - } + // public static int LiveModelRequestDecision(IntPtr liveModel, IntPtr contextJson, int contextJsonSize, IntPtr decisionResponse, IntPtr apiStatus) + // { + // if (LiveModelRequestDecisionOverride != null) + // { + // return LiveModelRequestDecisionOverride(liveModel, contextJson, contextJsonSize, decisionResponse, apiStatus); + // } - return LiveModelRequestDecisionNative(liveModel, contextJson, contextJsonSize, decisionResponse, apiStatus); - } + // return LiveModelRequestDecisionNative(liveModel, contextJson, contextJsonSize, decisionResponse, apiStatus); + // } [DllImport("rlnetnative", EntryPoint = "LiveModelRequestDecisionWithFlags")] private static extern int LiveModelRequestDecisionWithFlagsNative(IntPtr liveModel, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr decisionResponse, IntPtr apiStatus); @@ -110,20 +110,20 @@ public static int LiveModelRequestDecisionWithFlags(IntPtr liveModel, IntPtr con return LiveModelRequestDecisionWithFlagsNative(liveModel, contextJson, contextJsonSize, flags, decisionResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestMultiSlotDecision")] - private static extern int LiveModelRequestMultiSlotDecisionNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr multiSlotResponse, IntPtr apiStatus); + // [DllImport("rlnetnative", EntryPoint = "LiveModelRequestMultiSlotDecision")] + // private static extern int LiveModelRequestMultiSlotDecisionNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr multiSlotResponse, IntPtr apiStatus); - internal static Func LiveModelRequestMultiSlotDecisionOverride { get; set; } + // internal static Func LiveModelRequestMultiSlotDecisionOverride { get; set; } - public static int LiveModelRequestMultiSlotDecision(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr multiSlotResponse, IntPtr apiStatus) - { - if (LiveModelRequestMultiSlotDecisionOverride != null) - { - return LiveModelRequestMultiSlotDecisionOverride(liveModel, eventId, contextJson, contextJsonSize, multiSlotResponse, apiStatus); - } + // public static int LiveModelRequestMultiSlotDecision(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr multiSlotResponse, IntPtr apiStatus) + // { + // if (LiveModelRequestMultiSlotDecisionOverride != null) + // { + // return LiveModelRequestMultiSlotDecisionOverride(liveModel, eventId, contextJson, contextJsonSize, multiSlotResponse, apiStatus); + // } - return LiveModelRequestMultiSlotDecisionNative(liveModel, eventId, contextJson, contextJsonSize, multiSlotResponse, apiStatus); - } + // return LiveModelRequestMultiSlotDecisionNative(liveModel, eventId, contextJson, contextJsonSize, multiSlotResponse, apiStatus); + // } [DllImport("rlnetnative", EntryPoint = "LiveModelRequestMultiSlotDecisionWithFlags")] private static extern int LiveModelRequestMultiSlotDecisionWithFlagsNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr multiSlotResponse, IntPtr apiStatus); @@ -155,20 +155,20 @@ public static int LiveModelRequestMultiSlotDecisionWithBaselineAndFlags(IntPtr l return LiveModelRequestMultiSlotDecisionWithBaselineAndFlagsNative(liveModel, eventId, contextJson, contextJsonSize, flags, multiSlotResponse, baselineActions, baselineActionsSize, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestMultiSlotDecisionDetailed")] - private static extern int LiveModelRequestMultiSlotDecisionDetailedNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr multiSlotResponseDetailed, IntPtr apiStatus); + // [DllImport("rlnetnative", EntryPoint = "LiveModelRequestMultiSlotDecisionDetailed")] + // private static extern int LiveModelRequestMultiSlotDecisionDetailedNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr multiSlotResponseDetailed, IntPtr apiStatus); - internal static Func LiveModelRequestMultiSlotDecisionDetailedOverride { get; set; } + // internal static Func LiveModelRequestMultiSlotDecisionDetailedOverride { get; set; } - public static int LiveModelRequestMultiSlotDecisionDetailed(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr multiSlotResponseDetailed, IntPtr apiStatus) - { - if (LiveModelRequestMultiSlotDecisionDetailedOverride != null) - { - return LiveModelRequestMultiSlotDecisionDetailedOverride(liveModel, eventId, contextJson, contextJsonSize, multiSlotResponseDetailed, apiStatus); - } + // public static int LiveModelRequestMultiSlotDecisionDetailed(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr multiSlotResponseDetailed, IntPtr apiStatus) + // { + // if (LiveModelRequestMultiSlotDecisionDetailedOverride != null) + // { + // return LiveModelRequestMultiSlotDecisionDetailedOverride(liveModel, eventId, contextJson, contextJsonSize, multiSlotResponseDetailed, apiStatus); + // } - return LiveModelRequestMultiSlotDecisionDetailedNative(liveModel, eventId, contextJson, contextJsonSize, multiSlotResponseDetailed, apiStatus); - } + // return LiveModelRequestMultiSlotDecisionDetailedNative(liveModel, eventId, contextJson, contextJsonSize, multiSlotResponseDetailed, apiStatus); + // } [DllImport("rlnetnative", EntryPoint = "LiveModelRequestMultiSlotDecisionDetailedWithFlags")] private static extern int LiveModelRequestMultiSlotDecisionDetailedWithFlagsNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr multiSlotResponseDetailed, IntPtr apiStatus); @@ -391,6 +391,8 @@ private static void CheckJsonString(string json) } } + private const uint DEFAULT_FLAGS = (uint)ActionFlags.Default; + unsafe private static int LiveModelChooseRank(IntPtr liveModel, string eventId, string contextJson, IntPtr rankingResponse, IntPtr apiStatus) { CheckJsonString(contextJson); @@ -404,12 +406,38 @@ unsafe private static int LiveModelChooseRank(IntPtr liveModel, string eventId, // eventId and use supplied eventId at the rl.net.native layer. if (eventId == null) { - return NativeMethods.LiveModelChooseRank(liveModel, IntPtr.Zero, contextJsonUtf8Ptr, contextJsonSize, rankingResponse, apiStatus); + return NativeMethods.LiveModelChooseRankWithFlags(liveModel, IntPtr.Zero, contextJsonUtf8Ptr, contextJsonSize, LiveModel.DEFAULT_FLAGS, rankingResponse, apiStatus); } fixed (byte* eventIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(eventId)) { - return NativeMethods.LiveModelChooseRank(liveModel, new IntPtr(eventIdUtf8Bytes), contextJsonUtf8Ptr, contextJsonSize, rankingResponse, apiStatus); + return NativeMethods.LiveModelChooseRankWithFlags(liveModel, new IntPtr(eventIdUtf8Bytes), contextJsonUtf8Ptr, contextJsonSize, LiveModel.DEFAULT_FLAGS, rankingResponse, apiStatus); + } + } + } + + unsafe private static int LiveModelChooseRankWithFlags(IntPtr liveModel, string eventId, byte[] contextBinary, uint flags, IntPtr rankingReponse, IntPtr apiStatus) + { + if (contextBinary == null) + { + throw new ArgumentNullException("contextBinary"); + } + + fixed (byte* contextBinaryBytes = contextBinary) + { + IntPtr contextBinaryPtr = new IntPtr(contextBinaryBytes); + int contextBinarySize = contextBinary.Length; + + // It is important to pass null on faithfully here, because we rely on this to switch between auto-generate + // eventId and use supplied eventId at the rl.net.native layer. + if (eventId == null) + { + return NativeMethods.LiveModelChooseRankWithFlags(liveModel, IntPtr.Zero, contextBinaryPtr, contextBinarySize, LiveModel.DEFAULT_FLAGS, rankingReponse, apiStatus); + } + + fixed (byte* eventIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(eventId)) + { + return NativeMethods.LiveModelChooseRankWithFlags(liveModel, new IntPtr(eventIdUtf8Bytes), contextBinaryPtr, contextBinarySize, LiveModel.DEFAULT_FLAGS, rankingReponse, apiStatus); } } } @@ -451,12 +479,12 @@ unsafe private static int LiveModelRequestContinuousAction(IntPtr liveModel, str // eventId and use supplied eventId at the rl.net.native layer. if (eventId == null) { - return NativeMethods.LiveModelRequestContinuousAction(liveModel, IntPtr.Zero, contextJsonUtf8Ptr, contextJsonSize, continuousActionResponse, apiStatus); + return NativeMethods.LiveModelRequestContinuousActionWithFlags(liveModel, IntPtr.Zero, contextJsonUtf8Ptr, contextJsonSize, LiveModel.DEFAULT_FLAGS, continuousActionResponse, apiStatus); } fixed (byte* eventIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(eventId)) { - return NativeMethods.LiveModelRequestContinuousAction(liveModel, new IntPtr(eventIdUtf8Bytes), contextJsonUtf8Ptr, contextJsonSize, continuousActionResponse, apiStatus); + return NativeMethods.LiveModelRequestContinuousActionWithFlags(liveModel, new IntPtr(eventIdUtf8Bytes), contextJsonUtf8Ptr, contextJsonSize, LiveModel.DEFAULT_FLAGS, continuousActionResponse, apiStatus); } } } @@ -491,7 +519,7 @@ unsafe private static int LiveModelRequestDecision(IntPtr liveModel, string cont fixed (byte* contextJsonUtf8Bytes = NativeMethods.StringEncoding.GetBytes(contextJson)) { int contextJsonSize = NativeMethods.StringEncoding.GetByteCount(contextJson); - return NativeMethods.LiveModelRequestDecision(liveModel, new IntPtr(contextJsonUtf8Bytes), contextJsonSize, decisionResponse, apiStatus); + return NativeMethods.LiveModelRequestDecisionWithFlags(liveModel, new IntPtr(contextJsonUtf8Bytes), contextJsonSize, LiveModel.DEFAULT_FLAGS, decisionResponse, apiStatus); } } @@ -515,12 +543,12 @@ unsafe private static int LiveModelRequestMultiSlotDecision(IntPtr liveModel, st int contextJsonSize = NativeMethods.StringEncoding.GetByteCount(contextJson); if (eventId == null) { - return NativeMethods.LiveModelRequestMultiSlotDecision(liveModel, IntPtr.Zero, (IntPtr)contextJsonUtf8Bytes, contextJsonSize, multiSlotResponse, apiStatus); + return NativeMethods.LiveModelRequestMultiSlotDecisionWithFlags(liveModel, IntPtr.Zero, (IntPtr)contextJsonUtf8Bytes, contextJsonSize, LiveModel.DEFAULT_FLAGS, multiSlotResponse, apiStatus); } fixed (byte* eventIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(eventId)) { - return NativeMethods.LiveModelRequestMultiSlotDecision(liveModel, (IntPtr)eventIdUtf8Bytes, (IntPtr)contextJsonUtf8Bytes, contextJsonSize, multiSlotResponse, apiStatus); + return NativeMethods.LiveModelRequestMultiSlotDecisionWithFlags(liveModel, (IntPtr)eventIdUtf8Bytes, (IntPtr)contextJsonUtf8Bytes, contextJsonSize, LiveModel.DEFAULT_FLAGS, multiSlotResponse, apiStatus); } } } @@ -573,12 +601,12 @@ unsafe private static int LiveModelRequestMultiSlotDecisionDetailed(IntPtr liveM int contextJsonSize = NativeMethods.StringEncoding.GetByteCount(contextJson); if (eventId == null) { - return NativeMethods.LiveModelRequestMultiSlotDecisionDetailed(liveModel, IntPtr.Zero, (IntPtr)contextJsonUtf8Bytes, contextJsonSize, multiSlotResponseDetailed, apiStatus); + return NativeMethods.LiveModelRequestMultiSlotDecisionDetailedWithFlags(liveModel, IntPtr.Zero, (IntPtr)contextJsonUtf8Bytes, contextJsonSize, LiveModel.DEFAULT_FLAGS, multiSlotResponseDetailed, apiStatus); } fixed (byte* eventIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(eventId)) { - return NativeMethods.LiveModelRequestMultiSlotDecisionDetailed(liveModel, (IntPtr)eventIdUtf8Bytes, (IntPtr)contextJsonUtf8Bytes, contextJsonSize, multiSlotResponseDetailed, apiStatus); + return NativeMethods.LiveModelRequestMultiSlotDecisionDetailedWithFlags(liveModel, (IntPtr)eventIdUtf8Bytes, (IntPtr)contextJsonUtf8Bytes, contextJsonSize, LiveModel.DEFAULT_FLAGS, multiSlotResponseDetailed, apiStatus); } } } @@ -837,6 +865,12 @@ public bool TryChooseRank(string eventId, string contextJson, out RankingRespons return this.TryChooseRank(eventId, contextJson, response, apiStatus); } + public bool TryChooseRank(string eventId, byte[] contextBytes, out RankingResponse response, ApiStatus apiStatus = null) + { + response = new RankingResponse(); + return this.TryChooseRank(eventId, contextBytes, response, apiStatus); + } + public bool TryChooseRank(string eventId, string contextJson, RankingResponse response, ApiStatus apiStatus = null) { int result = LiveModelChooseRank(this.DangerousGetHandle(), eventId, contextJson, response.DangerousGetHandle(), apiStatus.ToNativeHandleOrNullptrDangerous()); @@ -845,6 +879,14 @@ public bool TryChooseRank(string eventId, string contextJson, RankingResponse re return result == NativeMethods.SuccessStatus; } + public bool TryChooseRank(string eventId, byte[] contextBinary, RankingResponse response, ApiStatus apiStatus = null) + { + int result = LiveModelChooseRankWithFlags(this.DangerousGetHandle(), eventId, contextBinary, LiveModel.DEFAULT_FLAGS, response.DangerousGetHandle(), apiStatus.ToNativeHandleOrNullptrDangerous()); + + GC.KeepAlive(this); + return result == NativeMethods.SuccessStatus; + } + public RankingResponse ChooseRank(string eventId, string contextJson) { RankingResponse result = new RankingResponse(); @@ -858,12 +900,31 @@ public RankingResponse ChooseRank(string eventId, string contextJson) return result; } + public RankingResponse ChooseRank(string eventId, byte[] contextBinary) + { + RankingResponse result = new RankingResponse(); + + using (ApiStatus apiStatus = new ApiStatus()) + if (!this.TryChooseRank(eventId, contextBinary, result, apiStatus)) + { + throw new RLException(apiStatus); + } + + return result; + } + public bool TryChooseRank(string eventId, string contextJson, ActionFlags flags, out RankingResponse response, ApiStatus apiStatus = null) { response = new RankingResponse(); return this.TryChooseRank(eventId, contextJson, flags, response, apiStatus); } + public bool TryChooseRank(string eventId, byte[] contextBinary, ActionFlags flags, out RankingResponse response, ApiStatus apiStatus = null) + { + response = new RankingResponse(); + return this.TryChooseRank(eventId, contextBinary, flags, response, apiStatus); + } + public bool TryChooseRank(string eventId, string contextJson, ActionFlags flags, RankingResponse response, ApiStatus apiStatus = null) { int result = LiveModelChooseRankWithFlags(this.DangerousGetHandle(), eventId, contextJson, (uint)flags, response.DangerousGetHandle(), apiStatus.ToNativeHandleOrNullptrDangerous()); @@ -872,6 +933,14 @@ public bool TryChooseRank(string eventId, string contextJson, ActionFlags flags, return result == NativeMethods.SuccessStatus; } + public bool TryChooseRank(string eventId, byte[] contextBinary, ActionFlags flags, RankingResponse response, ApiStatus apiStatus = null) + { + int result = LiveModelChooseRankWithFlags(this.DangerousGetHandle(), eventId, contextBinary, (uint)flags, response.DangerousGetHandle(), apiStatus.ToNativeHandleOrNullptrDangerous()); + + GC.KeepAlive(this); + return result == NativeMethods.SuccessStatus; + } + public RankingResponse ChooseRank(string eventId, string contextJson, ActionFlags flags) { RankingResponse result = new RankingResponse(); @@ -885,6 +954,19 @@ public RankingResponse ChooseRank(string eventId, string contextJson, ActionFlag return result; } + public RankingResponse ChooseRank(string eventId, byte[] contextBinary, ActionFlags flags) + { + RankingResponse result = new RankingResponse(); + + using (ApiStatus apiStatus = new ApiStatus()) + if (!this.TryChooseRank(eventId, contextBinary, flags, result, apiStatus)) + { + throw new RLException(apiStatus); + } + + return result; + } + public bool TryRequestContinuousAction(string eventId, string contextJson, out ContinuousActionResponse response, ApiStatus apiStatus = null) { response = new ContinuousActionResponse(); diff --git a/cmake/Modules/FlatbufferUtils.cmake b/cmake/Modules/FlatbufferUtils.cmake index 88abf0ed1..1e1d78796 100644 --- a/cmake/Modules/FlatbufferUtils.cmake +++ b/cmake/Modules/FlatbufferUtils.cmake @@ -3,11 +3,25 @@ include(CMakeParseArguments) function(add_flatbuffer_schema) cmake_parse_arguments(ADD_FB_SCHEMA_ARGS "" - "TARGET;FLATC_EXE;OUTPUT_DIR;FLATC_EXTRA_SCHEMA_ARGS" + "TARGET;FLATC_EXE;OUTPUT_DIR;FLATC_LANGUAGE;GENERATED_FILES;FLATC_EXTRA_SCHEMA_ARGS" "SCHEMAS" ${ARGN} ) + # check that FLATC_LANGAUGE is either undefined (in which case set it to CXX), or is in any of + # { CXX, CSharp } + if(NOT DEFINED ADD_FB_SCHEMA_ARGS_FLATC_LANGUAGE) + set(ADD_FB_SCHEMA_ARGS_FLATC_LANGUAGE "CXX") + endif() + + message(STATUS "ADD_FB_SCHEMA_ARGS_FLATC_LANGUAGE: ${ADD_FB_SCHEMA_ARGS_FLATC_LANGUAGE}") + + set (RL_FLATC_VALID_LANGUAGES "CXX" "CSharp") + + if(NOT ADD_FB_SCHEMA_ARGS_FLATC_LANGUAGE IN_LIST RL_FLATC_VALID_LANGUAGES) + message(FATAL_ERROR "FLATC_LANGUAGE must be either CXX or CSharp") + endif() + if(NOT DEFINED ADD_FB_SCHEMA_ARGS_TARGET) message(FATAL_ERROR "Missing TARGET argument to build_flatbuffers") endif() @@ -35,13 +49,40 @@ function(add_flatbuffer_schema) foreach(schema IN ITEMS ${ADD_FB_SCHEMA_ARGS_SCHEMAS}) get_filename_component(filename ${schema} NAME_WE) - set(generated_file_name ${ADD_FB_SCHEMA_ARGS_OUTPUT_DIR}/${filename}_generated.h) - add_custom_command( - OUTPUT ${generated_file_name} - COMMAND ${ADD_FB_SCHEMA_ARGS_FLATC_EXE} ${FLATC_SCHEMA_ARGS} -o ${ADD_FB_SCHEMA_ARGS_OUTPUT_DIR} -c ${schema} - DEPENDS ${schema} - ) - list(APPEND ALL_SCHEMAS ${generated_file_name}) + + # if the target language is CXX, we generate a header file + if(${ADD_FB_SCHEMA_ARGS_FLATC_LANGUAGE} STREQUAL "CXX") + set(generated_file_name ${ADD_FB_SCHEMA_ARGS_OUTPUT_DIR}/${filename}_generated.h) + add_custom_command( + OUTPUT ${generated_file_name} + COMMAND ${ADD_FB_SCHEMA_ARGS_FLATC_EXE} ${FLATC_SCHEMA_ARGS} -o ${ADD_FB_SCHEMA_ARGS_OUTPUT_DIR} -c ${schema} + DEPENDS ${schema} + ) + list(APPEND ALL_SCHEMAS ${generated_file_name}) + endif() + + # if the target language is CSharp, we generate a C# file + if(${ADD_FB_SCHEMA_ARGS_FLATC_LANGUAGE} STREQUAL "CSharp") + set(generated_file_name ${ADD_FB_SCHEMA_ARGS_OUTPUT_DIR}/${filename}_generated.cs) + add_custom_command( + OUTPUT ${generated_file_name} + COMMAND ${ADD_FB_SCHEMA_ARGS_FLATC_EXE} ${FLATC_SCHEMA_ARGS} -o ${ADD_FB_SCHEMA_ARGS_OUTPUT_DIR} --csharp ${schema} + DEPENDS ${schema} + ) + list(APPEND ALL_SCHEMAS ${generated_file_name}) + endif() + + if (DEFINED ADD_FB_SCHEMA_ARGS_GENERATED_FILES) + list(APPEND ${ADD_FB_SCHEMA_ARGS_GENERATED_FILES} ${generated_file_name}) + endif () + + # set(generated_file_name ${ADD_FB_SCHEMA_ARGS_OUTPUT_DIR}/${filename}_generated.h) + # add_custom_command( + # OUTPUT ${generated_file_name} + # COMMAND ${ADD_FB_SCHEMA_ARGS_FLATC_EXE} ${FLATC_SCHEMA_ARGS} -o ${ADD_FB_SCHEMA_ARGS_OUTPUT_DIR} -c ${schema} + # DEPENDS ${schema} + # ) + # list(APPEND ALL_SCHEMAS ${generated_file_name}) endforeach() add_custom_target(${ADD_FB_SCHEMA_ARGS_TARGET} DEPENDS ${ALL_SCHEMAS})