From c9f692da624a0126591322fd0714ee69a4400678 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Mon, 9 Oct 2023 15:01:44 -0400 Subject: [PATCH] feat: cb loop csharp bindings (#587) * fix: make benchmarks compatable with windows * double include * fix valgrind ubuntu * feat: cb loop csharp bindings * clang * address comments: basic usage and nullptr throw --------- Co-authored-by: zwd-ms <71728747+zwd-ms@users.noreply.github.com> --- bindings/cs/rl.net.cli.test/CMakeLists.txt | 2 + bindings/cs/rl.net.cli.test/MockSender.cs | 70 +++ .../SenderExtensibilityCBLoopTest.cs | 500 ++++++++++++++++ .../SenderExtensibilityTest.cs | 57 +- bindings/cs/rl.net.cli/BasicUsageCommand.cs | 124 ++++ bindings/cs/rl.net.cli/CMakeLists.txt | 1 + bindings/cs/rl.net.cli/EntryPoints.cs | 64 +-- bindings/cs/rl.net.cli/Helpers.cs | 30 + bindings/cs/rl.net.native/CMakeLists.txt | 3 + bindings/cs/rl.net.native/binding_tracer.cc | 2 +- bindings/cs/rl.net.native/binding_tracer.h | 6 +- bindings/cs/rl.net.native/rl.net.base_loop.h | 23 + bindings/cs/rl.net.native/rl.net.cb_loop.cc | 133 +++++ bindings/cs/rl.net.native/rl.net.cb_loop.h | 47 ++ .../cs/rl.net.native/rl.net.live_model.cc | 23 +- bindings/cs/rl.net.native/rl.net.live_model.h | 19 +- bindings/cs/rl.net.native/rl.net.native.h | 1 + .../cs/rl.net.native/rl.net.native.vcxproj | 3 + bindings/cs/rl.net/BaseLoop.cs | 13 + bindings/cs/rl.net/CBLoop.cs | 541 ++++++++++++++++++ bindings/cs/rl.net/CMakeLists.txt | 2 + bindings/cs/rl.net/LiveModel.cs | 4 - 22 files changed, 1514 insertions(+), 154 deletions(-) create mode 100644 bindings/cs/rl.net.cli.test/MockSender.cs create mode 100644 bindings/cs/rl.net.cli.test/SenderExtensibilityCBLoopTest.cs create mode 100644 bindings/cs/rl.net.cli/BasicUsageCommand.cs create mode 100644 bindings/cs/rl.net.native/rl.net.base_loop.h create mode 100644 bindings/cs/rl.net.native/rl.net.cb_loop.cc create mode 100644 bindings/cs/rl.net.native/rl.net.cb_loop.h create mode 100644 bindings/cs/rl.net/BaseLoop.cs create mode 100644 bindings/cs/rl.net/CBLoop.cs diff --git a/bindings/cs/rl.net.cli.test/CMakeLists.txt b/bindings/cs/rl.net.cli.test/CMakeLists.txt index abeadaf1d..52b711b45 100644 --- a/bindings/cs/rl.net.cli.test/CMakeLists.txt +++ b/bindings/cs/rl.net.cli.test/CMakeLists.txt @@ -1,7 +1,9 @@ set (RL_NET_CLI_TEST_SOURCES CleanupContainer.cs + MockSender.cs ReplayStepProviderTest.cs SenderExtensibilityTest.cs + SenderExtensibilityCBLoopTest.cs TempFileDisposable.cs TestBase.cs UnicodeTest.cs diff --git a/bindings/cs/rl.net.cli.test/MockSender.cs b/bindings/cs/rl.net.cli.test/MockSender.cs new file mode 100644 index 000000000..b71a8cbda --- /dev/null +++ b/bindings/cs/rl.net.cli.test/MockSender.cs @@ -0,0 +1,70 @@ +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using Rl.Net; +using Rl.Net.Native; + +namespace Rl.Net.Cli.Test +{ + using SenderFactory = Func; + using BackgroundErrorCallback = Action; + + internal class MockSender : ISender + { + public Action Init + { + get; + set; + } + + public Action Send + { + get; + set; + } + + void ISender.Init(ApiStatus status) + { + if (this.Init != null) + { + this.Init(status); + } + } + + void ISender.Send(SharedBuffer buffer, ApiStatus status) + { + if (this.Send != null) + { + this.Send(buffer, status); + } + } + } + + + internal class MockAsyncSender : AsyncSender + { + public MockAsyncSender(ErrorCallback callback) : base(callback) + { + } + + public new Func Send + { + get; + set; + } + + protected override Task SendAsync(SharedBuffer buffer) + { + if (this.Send != null) + { + return this.Send(buffer, this.RaiseBackgroundError); + } + + return Task.CompletedTask; + } + } +} \ No newline at end of file diff --git a/bindings/cs/rl.net.cli.test/SenderExtensibilityCBLoopTest.cs b/bindings/cs/rl.net.cli.test/SenderExtensibilityCBLoopTest.cs new file mode 100644 index 000000000..e50fea084 --- /dev/null +++ b/bindings/cs/rl.net.cli.test/SenderExtensibilityCBLoopTest.cs @@ -0,0 +1,500 @@ +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using Rl.Net; +using Rl.Net.Native; +using Rl.Net.Cli.Test; + +namespace Rl.Net.Cli.Test +{ + using SenderFactory = Func; + using BackgroundErrorCallback = Action; + + [TestClass] + public class SenderExtensibilityCBLoopTest : TestBase + { + const string CustomSenderConfigJson = +@"{ + ""ApplicationID"": ""ßïϱTèƨƭÂƥƥℓïçáƭïôñNá₥è-ℓôř"", + ""IsExplorationEnabled"": true, + ""InitialExplorationEpsilon"": 1.0, + ""model.source"": ""NO_MODEL_DATA"", + ""model.implementation"": ""PASSTHROUGH_PDF"", + ""model.backgroundrefresh"": false, + ""observation.sender.implementation"": ""BINDING_SENDER"", + ""interaction.sender.implementation"": ""BINDING_SENDER"", + ""observation.send.batchintervalms"": 10, + ""interaction.send.batchintervalms"": 10 +} +"; + + const string FileSenderConfigJson = +@"{ + ""ApplicationID"": ""ßïϱTèƨƭÂƥƥℓïçáƭïôñNá₥è-ℓôř"", + ""IsExplorationEnabled"": true, + ""InitialExplorationEpsilon"": 1.0, + ""model.source"": ""NO_MODEL_DATA"", + ""model.implementation"": ""PASSTHROUGH_PDF"", + ""model.backgroundrefresh"": false, + ""observation.sender.implementation"": ""OBSERVATION_FILE_SENDER"", + ""interaction.sender.implementation"": ""INTERACTION_FILE_SENDER"" +} +"; + + private CBLoop CreateCBLoop(FactoryContext factoryContext = null) + { + Configuration config; + ApiStatus apiStatus = new ApiStatus(); + if (!Configuration.TryLoadConfigurationFromJson(CustomSenderConfigJson, out config, apiStatus)) + { + Assert.Fail("Failed to parse pseudolocalized configuration JSON: " + apiStatus.ErrorMessage); + } + + CBLoop cbLoop = factoryContext == null ? new CBLoop(config) : new CBLoop(config, factoryContext); + + return cbLoop; + } + + [TestMethod] + public void Test_CustomSender_FailsWhenNotRegistered() + { + const int TypeNotRegisteredError = 10; // see errors_data.h + + ApiStatus apiStatus = new ApiStatus(); + CBLoop cbLoop = CreateCBLoop(); + + Assert.IsFalse(cbLoop.TryInit(apiStatus), "Should not be able to configure a model with BINDING_SENDER if custom factory is not set."); + + Assert.AreEqual(TypeNotRegisteredError, apiStatus.ErrorCode); + } + + [TestMethod] + public void Test_CustomSender_FactoryNotCalled_WhenNotRequested() + { + ApiStatus apiStatus = new ApiStatus(); + + FactoryContext factoryContext = new FactoryContext(); + + bool factoryCalled = false; + Func customFactory = + (IReadOnlyConfiguration readOnlyConfig, ErrorCallback callback) => + { + factoryCalled = true; + + return new MockSender(); + }; + + Configuration config = new Configuration(); + if (!Configuration.TryLoadConfigurationFromJson(FileSenderConfigJson, out config, apiStatus)) + { + Assert.Fail("Failed to parse pseudolocalized configuration JSON: " + apiStatus.ErrorMessage); + } + + TempFileDisposable interactionDisposable = new TempFileDisposable(); + this.TestCleanup.Add(interactionDisposable); + + TempFileDisposable observationDisposable = new TempFileDisposable(); + this.TestCleanup.Add(observationDisposable); + + config["interaction.file.name"] = interactionDisposable.Path; + config["observation.file.name"] = observationDisposable.Path; + + factoryContext.SetSenderFactory(customFactory); + + CBLoop cbLoop = new CBLoop(config, factoryContext); + cbLoop.Init(); + + Assert.IsFalse(factoryCalled, "Custom factory should not be called unless BINDING_SENDER is selected in configuration."); + } + + private FactoryContext CreateFactoryContext(SenderFactory customFactory) + { + FactoryContext factoryContext = new FactoryContext(); + factoryContext.SetSenderFactory(customFactory); + + return factoryContext; + } + + private FactoryContext CreateFactoryContext(Action initAction = null, Action sendAction = null) + { + return CreateFactoryContext( + (config, callback) => + { + return new MockSender + { + Init = initAction, + Send = sendAction + }; + } + ); + } + + private FactoryContext CreateFactoryContext(Func asyncSendFunc = null) + { + return CreateFactoryContext( + (config, callback) => + { + return new MockAsyncSender(callback) + { + Send = asyncSendFunc + }; + } + ); + } + + [TestMethod] + public void Test_CustomSender_FactoryCalled_WhenRequested() + { + ApiStatus apiStatus = new ApiStatus(); + + bool factoryCalled = false; + FactoryContext factoryContext = CreateFactoryContext( + (IReadOnlyConfiguration config, ErrorCallback callback) => + { + factoryCalled = true; + + return new MockSender(); + }); + + + CBLoop cbLoop = CreateCBLoop(factoryContext); + cbLoop.Init(); + + Assert.IsTrue(factoryCalled, "Custom factory must be called when BINDING_SENDER is selected in configuration."); + } + + [TestMethod] + public void Test_CustomSender_InitSuccess() + { + bool initCalled = false; + void SenderInit(ApiStatus status) + { + initCalled = true; + } + + FactoryContext factoryContext = CreateFactoryContext(initAction: SenderInit); + + CBLoop cbLoop = CreateCBLoop(factoryContext); + cbLoop.Init(); + + Assert.IsTrue(initCalled, "MockSender.Init should be called and succeed, which means CBLoop.Init should succeed."); + } + + const string OpaqueErrorMessage = "Opaque error in external code. §ô₥è Tèжƭ Fřô₥ ÐôƭNèƭ ℓôřè₥"; + + private void Run_TestCustomSender_InitFailure(Action senderInit, string expectedString, bool expectPrefix = false) + { + FactoryContext factoryContext = CreateFactoryContext(initAction: senderInit); + + ApiStatus apiStatus = new ApiStatus(); + CBLoop cbLoop = CreateCBLoop(factoryContext); + Assert.IsFalse(cbLoop.TryInit(apiStatus), "MockSender.Init should be called and fail, which means CBLoop.Init should fail."); + + Assert.AreEqual(NativeMethods.OpaqueBindingError, apiStatus.ErrorCode); + + if (!expectPrefix) + { + Assert.AreEqual(expectedString, apiStatus.ErrorMessage); + } + else + { + Assert.IsTrue(apiStatus.ErrorMessage.StartsWith(expectedString)); + } + } + + [TestMethod] + public void Test_CustomSender_InitFailure_ViaApiStatusBuilder() + { + const string ExpectedString = OpaqueErrorMessage; + void SenderInit(ApiStatus status) + { + new ApiStatusBuilder(NativeMethods.OpaqueBindingError) + .Append("§ô₥è Tèжƭ Fřô₥ ÐôƭNèƭ ℓôřè₥") + .UpdateApiStatus(status); + } + + this.Run_TestCustomSender_InitFailure(SenderInit, ExpectedString); + } + + [TestMethod] + public void Test_CustomSender_InitFailure_ViaRLException() + { + const string ExpectedString = OpaqueErrorMessage; + void SenderInit(ApiStatus status) + { + throw new RLException("§ô₥è Tèжƭ Fřô₥ ÐôƭNèƭ ℓôřè₥"); + } + + this.Run_TestCustomSender_InitFailure(SenderInit, ExpectedString); + } + + [TestMethod] + public void Test_CustomSender_InitFailure_ViaNonRLException() + { + string expectedStringPrefix = OpaqueErrorMessage; + void SenderInit(ApiStatus status) + { + try + { + throw new Exception("§ô₥è Tèжƭ Fřô₥ ÐôƭNèƭ ℓôřè₥"); + } + catch (Exception e) + { + expectedStringPrefix += "\n" + e.StackTrace; + throw; + } + } + + this.Run_TestCustomSender_InitFailure(SenderInit, expectedStringPrefix, expectPrefix: true); + } + + const string EventId = "É1ß9ÐÇ83-8ÐF5-45É5-8Ð59-29F05ßÉ89ß96"; + const string ContextJsonWithPdf = +@"{ + ""共有"": {""δèƒ"": ""ℓƙωèř áωèℓ ωèJƙř""}, + ""_multi"": [{ + ""ωôω"": 1 + },{ + ""ωôω"": 2 + }], + ""p"": [0.3, 0.7] +} +"; + + [TestMethod] + public void Test_CustomSender_SendSuccess() + { + ManualResetEventSlim senderCalledWaiter = new ManualResetEventSlim(initialState: false); + + bool sendCalled = false; + void SenderSend(SharedBuffer buffer, ApiStatus status) + { + sendCalled = true; + senderCalledWaiter.Set(); + } + + FactoryContext factoryContext = CreateFactoryContext(sendAction: SenderSend); + + CBLoop cbLoop = CreateCBLoop(factoryContext); + cbLoop.Init(); + RankingResponse response = cbLoop.ChooseRank(EventId, ContextJsonWithPdf); + + senderCalledWaiter.Wait(TimeSpan.FromSeconds(1)); + + Assert.IsTrue(sendCalled); + } + + private void Run_TestCustomSender_SendFailure(FactoryContext factoryContext, string expectedString, bool expectPrefix = false) + { + ManualResetEventSlim backgroundMessageWaiter = new ManualResetEventSlim(initialState: false); + + int backgroundErrorCount = 0; + int backgroundErrorCode = 0; + string backgroundErrorMessage = null; + void OnBackgroundError(object sender, ApiStatus args) + { + Assert.AreEqual(0, backgroundErrorCount++, "Do not duplicate background errors."); + backgroundErrorCode = args.ErrorCode; + backgroundErrorMessage = args.ErrorMessage; + + backgroundMessageWaiter.Set(); + } + + CBLoop cbLoop = CreateCBLoop(factoryContext); + cbLoop.BackgroundError += OnBackgroundError; + + cbLoop.Init(); + + ApiStatus apiStatus = new ApiStatus(); + RankingResponse response; + Assert.IsTrue(cbLoop.TryChooseRank(EventId, ContextJsonWithPdf, out response, apiStatus)); + Assert.AreEqual(NativeMethods.SuccessStatus, apiStatus.ErrorCode, "Errors from ISender.Send should be background errors."); + + backgroundMessageWaiter.Wait(TimeSpan.FromSeconds(1)); + + Assert.AreEqual(NativeMethods.OpaqueBindingError, backgroundErrorCode, "Error from ISender did not get raised."); + + if (!expectPrefix) + { + Assert.AreEqual(OpaqueErrorMessage, backgroundErrorMessage); + } + else + { + Assert.IsTrue(backgroundErrorMessage.StartsWith(OpaqueErrorMessage)); + } + } + + private void Run_TestCustomSender_SendFailure(Action senderSend, string expectedString, bool expectPrefix = false) + { + FactoryContext factoryContext = CreateFactoryContext(sendAction: senderSend); + + Run_TestCustomSender_SendFailure(factoryContext, expectedString, expectPrefix); + } + + [TestMethod] + public void Test_CustomSender_SendFailure_ViaApiStatusBuilder() + { + void SenderSend(SharedBuffer buffer, ApiStatus status) + { + new ApiStatusBuilder(NativeMethods.OpaqueBindingError) + .Append("§ô₥è Tèжƭ Fřô₥ ÐôƭNèƭ ℓôřè₥") + .UpdateApiStatus(status); + } + + this.Run_TestCustomSender_SendFailure(SenderSend, OpaqueErrorMessage); + } + + [TestMethod] + public void Test_CustomSender_SendFailure_ViaRLException() + { + void SenderSend(SharedBuffer buffer, ApiStatus status) + { + throw new RLException("§ô₥è Tèжƭ Fřô₥ ÐôƭNèƭ ℓôřè₥"); + } + + this.Run_TestCustomSender_SendFailure(SenderSend, OpaqueErrorMessage); + } + + [TestMethod] + public void Test_CustomSender_SendFailure_ViaNonRLException() + { + string expectedStringPrefix = OpaqueErrorMessage; + void SenderSend(SharedBuffer buffer, ApiStatus status) + { + try + { + throw new Exception("§ô₥è Tèжƭ Fřô₥ ÐôƭNèƭ ℓôřè₥"); + } + catch (Exception e) + { + expectedStringPrefix += "\n" + e.StackTrace; + throw; + } + } + + this.Run_TestCustomSender_SendFailure(SenderSend, expectedStringPrefix, expectPrefix: true); + } + + [TestMethod] + public void Test_AsyncSender_SendSuccess() + { + ManualResetEventSlim senderCalledWaiter = new ManualResetEventSlim(initialState: false); + + bool sendCalled = false; + Task AsyncSenderSend(SharedBuffer buffer, BackgroundErrorCallback raiseBackgroundError) + { + sendCalled = true; + senderCalledWaiter.Set(); + + return Task.CompletedTask; + } + + FactoryContext factoryContext = CreateFactoryContext(asyncSendFunc: AsyncSenderSend); + + CBLoop cbLoop = CreateCBLoop(factoryContext); + cbLoop.Init(); + RankingResponse response = cbLoop.ChooseRank(EventId, ContextJsonWithPdf); + + senderCalledWaiter.Wait(TimeSpan.FromSeconds(1)); + + Assert.IsTrue(sendCalled); + } + + private void Run_TestAsyncSender_SendFailure(Func asyncSenderSend, string expectedString, bool expectPrefix = false) + { + FactoryContext factoryContext = CreateFactoryContext(asyncSendFunc: asyncSenderSend); + + Run_TestCustomSender_SendFailure(factoryContext, expectedString, expectPrefix); + } + + [TestMethod] + public void Test_AsyncSender_SendFailure_ViaExplicitRaise() + { + Task AsyncSenderSend(SharedBuffer buffer, BackgroundErrorCallback raiseBackgroundError) + { + ApiStatusBuilder statusBuilder = new ApiStatusBuilder(NativeMethods.OpaqueBindingError) + .Append("§ô₥è Tèжƭ Fřô₥ ÐôƭNèƭ ℓôřè₥"); + + raiseBackgroundError(statusBuilder.ToApiStatus()); + + return Task.CompletedTask; + } + + this.Run_TestAsyncSender_SendFailure(AsyncSenderSend, OpaqueErrorMessage); + } + + [TestMethod] + public void Test_AsyncSender_SendFailure_ViaRLException() + { + + #pragma warning disable 1998 + async Task AsyncSenderSend(SharedBuffer buffer, BackgroundErrorCallback raiseBackgroundError) + { + throw new RLException("§ô₥è Tèжƭ Fřô₥ ÐôƭNèƭ ℓôřè₥"); + } + #pragma warning restore + + this.Run_TestAsyncSender_SendFailure(AsyncSenderSend, OpaqueErrorMessage); + } + + [TestMethod] + public void Test_AsyncSender_SendFailure_ViaRLExceptionInTask() + { + Task AsyncSenderSend(SharedBuffer buffer, BackgroundErrorCallback raiseBackgroundError) + { + return Task.FromException(new RLException("§ô₥è Tèжƭ Fřô₥ ÐôƭNèƭ ℓôřè₥")); + } + + this.Run_TestAsyncSender_SendFailure(AsyncSenderSend, OpaqueErrorMessage); + } + + [TestMethod] + public void Test_AsyncSender_SendFailure_ViaNonRLException() + { + string expectedStringPrefix = OpaqueErrorMessage; + + #pragma warning disable 1998 + async Task AsyncSenderSend(SharedBuffer buffer, BackgroundErrorCallback raiseBackgroundError) + { + try + { + throw new Exception("§ô₥è Tèжƭ Fřô₥ ÐôƭNèƭ ℓôřè₥"); + } + catch (Exception e) + { + expectedStringPrefix += "\n" + e.StackTrace; + throw e; + } + } + #pragma warning restore + + this.Run_TestAsyncSender_SendFailure(AsyncSenderSend, OpaqueErrorMessage, expectPrefix: true); + } + + [TestMethod] + public void Test_AsyncSender_SendFailure_ViaNonRLExceptionInTask() + { + string expectedStringPrefix = OpaqueErrorMessage; + Task AsyncSenderSend(SharedBuffer buffer, BackgroundErrorCallback raiseBackgroundError) + { + try + { + throw new Exception("§ô₥è Tèжƭ Fřô₥ ÐôƭNèƭ ℓôřè₥"); + } + catch (Exception e) + { + expectedStringPrefix += "\n" + e.StackTrace; + return Task.FromException(e); + } + } + + this.Run_TestAsyncSender_SendFailure(AsyncSenderSend, expectedStringPrefix, expectPrefix: true); + } + } + + +} diff --git a/bindings/cs/rl.net.cli.test/SenderExtensibilityTest.cs b/bindings/cs/rl.net.cli.test/SenderExtensibilityTest.cs index 3adb8e5f0..ad53b47fd 100644 --- a/bindings/cs/rl.net.cli.test/SenderExtensibilityTest.cs +++ b/bindings/cs/rl.net.cli.test/SenderExtensibilityTest.cs @@ -7,66 +7,13 @@ using Newtonsoft.Json.Linq; using Rl.Net; using Rl.Net.Native; +using Rl.Net.Cli.Test; namespace Rl.Net.Cli.Test { using SenderFactory = Func; using BackgroundErrorCallback = Action; - internal class MockSender : ISender - { - public Action Init - { - get; - set; - } - - public Action Send - { - get; - set; - } - - void ISender.Init(ApiStatus status) - { - if (this.Init != null) - { - this.Init(status); - } - } - - void ISender.Send(SharedBuffer buffer, ApiStatus status) - { - if (this.Send != null) - { - this.Send(buffer, status); - } - } - } - - - internal class MockAsyncSender : AsyncSender - { - public MockAsyncSender(ErrorCallback callback) : base(callback) - { - } - - public new Func Send - { - get; - set; - } - - protected override Task SendAsync(SharedBuffer buffer) - { - if (this.Send != null) - { - return this.Send(buffer, this.RaiseBackgroundError); - } - - return Task.CompletedTask; - } - } - + [TestClass] public class SenderExtensibilityTest : TestBase { diff --git a/bindings/cs/rl.net.cli/BasicUsageCommand.cs b/bindings/cs/rl.net.cli/BasicUsageCommand.cs new file mode 100644 index 000000000..7b95cdc63 --- /dev/null +++ b/bindings/cs/rl.net.cli/BasicUsageCommand.cs @@ -0,0 +1,124 @@ +using System; +using CommandLine; + +namespace Rl.Net.Cli +{ + [Verb("basicUsage", HelpText = "Basic usage of the API")] + class BasicUsageCommand : CommandBase + { + [Option(longName: "testType", HelpText = "select from (liveModel, CBLoop, PdfExample) basic usage examples", Required = false, Default = "liveModel")] + public string testType { get; set; } + + public override void Run() + { + switch (testType) + { + case "liveModel": + BasicUsage(this.ConfigPath); + break; + case "CBLoop": + BasicUsageCBLoop(this.ConfigPath); + break; + case "PdfExample": + PdfExample(this.ConfigPath); + break; + default: + Console.WriteLine("Invalid test type"); + break; + } + } + + public static void BasicUsage(string configPath) + { + const float outcome = 1.0f; + const string eventId = "event_id"; + const string contextJson = "{\"GUser\":{\"id\":\"a\",\"major\":\"eng\",\"hobby\":\"hiking\"},\"_multi\":[ { \"TAction\":{\"a1\":\"f1\"} },{\"TAction\":{\"a2\":\"f2\"}}]}"; + + LiveModel liveModel = Helpers.CreateLiveModelOrExit(configPath); + + ApiStatus apiStatus = new ApiStatus(); + + RankingResponse rankingResponse = new RankingResponse(); + if (!liveModel.TryChooseRank(eventId, contextJson, rankingResponse, apiStatus)) + { + Helpers.WriteStatusAndExit(apiStatus); + } + + long actionId; + if (!rankingResponse.TryGetChosenAction(out actionId, apiStatus)) + { + Helpers.WriteStatusAndExit(apiStatus); + } + + Console.WriteLine($"Chosen action id: {actionId}"); + + if (!liveModel.TryQueueOutcomeEvent(eventId, outcome, apiStatus)) + { + Helpers.WriteStatusAndExit(apiStatus); + } + Console.WriteLine("Basice usage live model success"); + } + + public static void BasicUsageCBLoop(string configPath) + { + const float outcome = 1.0f; + const string eventId = "event_id"; + const string contextJson = "{\"GUser\":{\"id\":\"a\",\"major\":\"eng\",\"hobby\":\"hiking\"},\"_multi\":[ { \"TAction\":{\"a1\":\"f1\"} },{\"TAction\":{\"a2\":\"f2\"}}]}"; + + CBLoop cb_loop = Helpers.CreateCBLoopOrExit(configPath); + + ApiStatus apiStatus = new ApiStatus(); + + RankingResponse rankingResponse = new RankingResponse(); + if (!cb_loop.TryChooseRank(eventId, contextJson, rankingResponse, apiStatus)) + { + Helpers.WriteStatusAndExit(apiStatus); + } + + long actionId; + if (!rankingResponse.TryGetChosenAction(out actionId, apiStatus)) + { + Helpers.WriteStatusAndExit(apiStatus); + } + + Console.WriteLine($"Chosen action id: {actionId}"); + + if (!cb_loop.TryQueueOutcomeEvent(eventId, outcome, apiStatus)) + { + Helpers.WriteStatusAndExit(apiStatus); + } + Console.WriteLine("Basice usage cb loop success"); + } + + public static void PdfExample(string configPath) + { + const float outcome = 1.0f; + const string eventId = "event_id"; + const string contextJson = "{\"GUser\":{\"id\":\"a\",\"major\":\"eng\",\"hobby\":\"hiking\"},\"_multi\":[ { \"TAction\":{\"a1\":\"f1\"} },{\"TAction\":{\"a2\":\"f2\"}}],\"p\":[0.2, 0.8]}"; + + LiveModel liveModel = Helpers.CreateLiveModelOrExit(configPath); + + ApiStatus apiStatus = new ApiStatus(); + + RankingResponse rankingResponse = new RankingResponse(); + if (!liveModel.TryChooseRank(eventId, contextJson, rankingResponse, apiStatus)) + { + Helpers.WriteStatusAndExit(apiStatus); + } + + long actionId; + if (!rankingResponse.TryGetChosenAction(out actionId, apiStatus)) + { + Helpers.WriteStatusAndExit(apiStatus); + } + + Console.WriteLine($"Chosen action id: {actionId}"); + + if (!liveModel.TryQueueOutcomeEvent(eventId, outcome, apiStatus)) + { + Helpers.WriteStatusAndExit(apiStatus); + } + Console.WriteLine("Basice usage pdf example success"); + } + } +} diff --git a/bindings/cs/rl.net.cli/CMakeLists.txt b/bindings/cs/rl.net.cli/CMakeLists.txt index 15bef1ce2..35439b1c4 100644 --- a/bindings/cs/rl.net.cli/CMakeLists.txt +++ b/bindings/cs/rl.net.cli/CMakeLists.txt @@ -1,4 +1,5 @@ set (RL_NET_CLI_SOURCES + BasicUsageCommand.cs CommandBase.cs EntryPoints.cs Helpers.cs diff --git a/bindings/cs/rl.net.cli/EntryPoints.cs b/bindings/cs/rl.net.cli/EntryPoints.cs index 0fb3cc75b..89794ce92 100644 --- a/bindings/cs/rl.net.cli/EntryPoints.cs +++ b/bindings/cs/rl.net.cli/EntryPoints.cs @@ -10,70 +10,8 @@ static class EntryPoints public static void Main(string[] args) { Parser.Default.ParseArguments - (args) + (args) .WithParsed(command => command.Run()); - //BasicUsage(args[0]); - //PdfExample(args[0]); - } - - public static void BasicUsage(string configPath) - { - const float outcome = 1.0f; - const string eventId = "event_id"; - const string contextJson = "{\"GUser\":{\"id\":\"a\",\"major\":\"eng\",\"hobby\":\"hiking\"},\"_multi\":[ { \"TAction\":{\"a1\":\"f1\"} },{\"TAction\":{\"a2\":\"f2\"}}]}"; - - LiveModel liveModel = Helpers.CreateLiveModelOrExit(configPath); - - ApiStatus apiStatus = new ApiStatus(); - - RankingResponse rankingResponse = new RankingResponse(); - if (!liveModel.TryChooseRank(eventId, contextJson, rankingResponse, apiStatus)) - { - Helpers.WriteStatusAndExit(apiStatus); - } - - long actionId; - if (!rankingResponse.TryGetChosenAction(out actionId, apiStatus)) - { - Helpers.WriteStatusAndExit(apiStatus); - } - - Console.WriteLine($"Chosen action id: {actionId}"); - - if (!liveModel.TryQueueOutcomeEvent(eventId, outcome, apiStatus)) - { - Helpers.WriteStatusAndExit(apiStatus); - } - } - - public static void PdfExample(string configPath) - { - const float outcome = 1.0f; - const string eventId = "event_id"; - const string contextJson = "{\"GUser\":{\"id\":\"a\",\"major\":\"eng\",\"hobby\":\"hiking\"},\"_multi\":[ { \"TAction\":{\"a1\":\"f1\"} },{\"TAction\":{\"a2\":\"f2\"}}],\"p\":[0.2, 0.8]}"; - - LiveModel liveModel = Helpers.CreateLiveModelOrExit(configPath); - - ApiStatus apiStatus = new ApiStatus(); - - RankingResponse rankingResponse = new RankingResponse(); - if (!liveModel.TryChooseRank(eventId, contextJson, rankingResponse, apiStatus)) - { - Helpers.WriteStatusAndExit(apiStatus); - } - - long actionId; - if (!rankingResponse.TryGetChosenAction(out actionId, apiStatus)) - { - Helpers.WriteStatusAndExit(apiStatus); - } - - Console.WriteLine($"Chosen action id: {actionId}"); - - if (!liveModel.TryQueueOutcomeEvent(eventId, outcome, apiStatus)) - { - Helpers.WriteStatusAndExit(apiStatus); - } } public static IEnumerable LazyReadLines(this TextReader textReader) diff --git a/bindings/cs/rl.net.cli/Helpers.cs b/bindings/cs/rl.net.cli/Helpers.cs index 028d2c204..5116290f0 100644 --- a/bindings/cs/rl.net.cli/Helpers.cs +++ b/bindings/cs/rl.net.cli/Helpers.cs @@ -50,6 +50,36 @@ public static LiveModel CreateLiveModelOrExit(string clientJsonPath) return liveModel; } + public static CBLoop CreateCBLoopOrExit(string clientJsonPath) + { + if (!File.Exists(clientJsonPath)) + { + WriteErrorAndExit($"Could not find file with path '{clientJsonPath}'."); + } + + string json = File.ReadAllText(clientJsonPath); + + ApiStatus apiStatus = new ApiStatus(); + + Configuration config; + if (!Configuration.TryLoadConfigurationFromJson(json, out config, apiStatus)) + { + WriteStatusAndExit(apiStatus); + } + + CBLoop cb_loop = new CBLoop(config); + + cb_loop.BackgroundError += LiveModel_BackgroundError; + cb_loop.TraceLoggerEvent += LiveModel_TraceLogEvent; + + if (!cb_loop.TryInit(apiStatus)) + { + WriteStatusAndExit(apiStatus); + } + + return cb_loop; + } + public static void LiveModel_BackgroundError(object sender, ApiStatus e) { Console.Error.WriteLine(e.ErrorMessage); diff --git a/bindings/cs/rl.net.native/CMakeLists.txt b/bindings/cs/rl.net.native/CMakeLists.txt index 243b82409..6c46ce17d 100644 --- a/bindings/cs/rl.net.native/CMakeLists.txt +++ b/bindings/cs/rl.net.native/CMakeLists.txt @@ -3,6 +3,7 @@ set(rl_net_native_SOURCES binding_tracer.cc rl.net.api_status.cc rl.net.buffer.cc + rl.net.cb_loop.cc rl.net.config.cc rl.net.continuous_action_response.cc rl.net.decision_response.cc @@ -20,7 +21,9 @@ set(rl_net_native_HEADERS binding_sender.h binding_tracer.h rl.net.api_status.h + rl.net.base_loop.h rl.net.buffer.h + rl.net.cb_loop.h rl.net.config.h rl.net.continuous_action_response.h rl.net.decision_response.h diff --git a/bindings/cs/rl.net.native/binding_tracer.cc b/bindings/cs/rl.net.native/binding_tracer.cc index 80d36be90..8507c0491 100644 --- a/bindings/cs/rl.net.native/binding_tracer.cc +++ b/bindings/cs/rl.net.native/binding_tracer.cc @@ -2,7 +2,7 @@ namespace rl_net_native { -binding_tracer::binding_tracer(livemodel_context& _context) : context(_context) {} +binding_tracer::binding_tracer(base_loop_context& _context) : context(_context) {} void binding_tracer::log(int log_level, const std::string& msg) { diff --git a/bindings/cs/rl.net.native/binding_tracer.h b/bindings/cs/rl.net.native/binding_tracer.h index e1394a1d2..75aa2f5fe 100644 --- a/bindings/cs/rl.net.native/binding_tracer.h +++ b/bindings/cs/rl.net.native/binding_tracer.h @@ -1,5 +1,5 @@ #pragma once -#include "rl.net.live_model.h" +#include "rl.net.base_loop.h" #include "trace_logger.h" namespace rl_net_native @@ -8,10 +8,10 @@ class binding_tracer : public reinforcement_learning::i_trace { public: // Inherited via i_trace - binding_tracer(livemodel_context& _context); + binding_tracer(base_loop_context& _context); void log(int log_level, const std::string& msg) override; private: - livemodel_context& context; + base_loop_context& context; }; } // namespace rl_net_native diff --git a/bindings/cs/rl.net.native/rl.net.base_loop.h b/bindings/cs/rl.net.native/rl.net.base_loop.h new file mode 100644 index 000000000..db15f2112 --- /dev/null +++ b/bindings/cs/rl.net.native/rl.net.base_loop.h @@ -0,0 +1,23 @@ +#pragma once + +#include "rl.net.factory_context.h" + +namespace rl_net_native +{ +namespace constants +{ +const char* const BINDING_TRACE_LOGGER = "BINDING_TRACE_LOGGER"; +} + +typedef void (*trace_logger_callback_t)(int log_level, const char* msg); +} // namespace rl_net_native + +typedef struct base_loop_context +{ + // callback funtion to user when there is background error. + rl_net_native::background_error_callback_t background_error_callback; + // callback funtion to user for trace log. + rl_net_native::trace_logger_callback_t trace_logger_callback; + // A trace log factory instance holder of one loop instance for binding calls. + reinforcement_learning::trace_logger_factory_t* trace_logger_factory; +} base_loop_context_t; \ No newline at end of file diff --git a/bindings/cs/rl.net.native/rl.net.cb_loop.cc b/bindings/cs/rl.net.native/rl.net.cb_loop.cc new file mode 100644 index 000000000..29475fbca --- /dev/null +++ b/bindings/cs/rl.net.native/rl.net.cb_loop.cc @@ -0,0 +1,133 @@ +#include "rl.net.cb_loop.h" + +#include "binding_tracer.h" +#include "constants.h" +#include "err_constants.h" +#include "trace_logger.h" + +#include + +static void pipe_background_error_callback(const reinforcement_learning::api_status& status, cb_loop_context_t* context) +{ + auto managed_backgroud_error_callback_local = context->base_loop_context.background_error_callback; + if (managed_backgroud_error_callback_local) { managed_backgroud_error_callback_local(status); } +} + +API cb_loop_context_t* CreateCBLoop( + reinforcement_learning::utility::configuration* config, factory_context_t* factory_context) +{ + cb_loop_context_t* context = new cb_loop_context_t; + context->base_loop_context.background_error_callback = nullptr; + context->base_loop_context.trace_logger_callback = nullptr; + context->base_loop_context.trace_logger_factory = nullptr; + + // Create a trace log factory by passing in below creator. It allows CBLoop to use trace_logger provided by user. + const auto binding_tracer_create = [context](std::unique_ptr& retval, + const reinforcement_learning::utility::configuration& cfg, + reinforcement_learning::i_trace* trace_logger, + reinforcement_learning::api_status* status) + { + retval.reset(new rl_net_native::binding_tracer(context->base_loop_context)); + return reinforcement_learning::error_code::success; + }; + + // TODO: Unify this factory projection and the sender_factory projection in FactoryContext. + reinforcement_learning::trace_logger_factory_t* trace_logger_factory = + new reinforcement_learning::trace_logger_factory_t(*factory_context->trace_logger_factory); + + // Register the type in factor to use trace logger creatation function. + trace_logger_factory->register_type(rl_net_native::constants::BINDING_TRACE_LOGGER, binding_tracer_create); + + // This is a clone of cleanup_trace_logger_factory + std::swap(trace_logger_factory, factory_context->trace_logger_factory); + if (trace_logger_factory != nullptr && trace_logger_factory != &reinforcement_learning::trace_logger_factory) + { + delete trace_logger_factory; + } + + // Set TRACE_LOG_IMPLEMENTATION configuration to use trace logger. + config->set(reinforcement_learning::name::TRACE_LOG_IMPLEMENTATION, rl_net_native::constants::BINDING_TRACE_LOGGER); + + context->cb_loop = new reinforcement_learning::cb_loop(*config, pipe_background_error_callback, context, + factory_context->trace_logger_factory, factory_context->data_transport_factory, factory_context->model_factory, + factory_context->sender_factory, factory_context->time_provider_factory); + + return context; +} + +API void DeleteCBLoop(cb_loop_context_t* context) +{ + // Since the cb_loop destructor waits for queues to drain, this can have unhappy consequences, + // so detach the callback pipe first. This will cause all background callbacks to no-op in the + // unmanaged side, which maintains expected thread semantics (the user of the bindings) + context->base_loop_context.background_error_callback = nullptr; + context->base_loop_context.trace_logger_callback = nullptr; + + delete context->base_loop_context.trace_logger_factory; + delete context->cb_loop; + delete context; +} + +API int CBLoopInit(cb_loop_context_t* context, reinforcement_learning::api_status* status) +{ + if (context == nullptr) { return reinforcement_learning::error_code::not_initialized; } + return context->cb_loop->init(status); +} + +API int CBLoopChooseRank(cb_loop_context_t* context, const char* event_id, const char* context_json, + int context_json_size, reinforcement_learning::ranking_response* resp, reinforcement_learning::api_status* status) +{ + if (event_id == nullptr) + { + return context->cb_loop->choose_rank({context_json, static_cast(context_json_size)}, *resp, status); + } + + return context->cb_loop->choose_rank(event_id, {context_json, static_cast(context_json_size)}, *resp, status); +} + +API int CBLoopChooseRankWithFlags(cb_loop_context_t* context, const char* event_id, const char* context_json, + int context_json_size, unsigned int flags, reinforcement_learning::ranking_response* resp, + reinforcement_learning::api_status* status) +{ + return context->cb_loop->choose_rank( + event_id, {context_json, static_cast(context_json_size)}, flags, *resp, status); +} + +API int CBLoopReportActionTaken( + cb_loop_context_t* context, const char* event_id, reinforcement_learning::api_status* status) +{ + return context->cb_loop->report_action_taken(event_id, status); +} + +API int CBLoopReportActionMultiIdTaken(cb_loop_context_t* context, const char* primary_id, const char* secondary_id, + reinforcement_learning::api_status* status) +{ + return context->cb_loop->report_action_taken(primary_id, secondary_id, status); +} + +API int CBLoopReportOutcomeF( + cb_loop_context_t* context, const char* event_id, float outcome, reinforcement_learning::api_status* status) +{ + return context->cb_loop->report_outcome(event_id, outcome, status); +} + +API int CBLoopReportOutcomeJson(cb_loop_context_t* context, const char* event_id, const char* outcome_json, + reinforcement_learning::api_status* status) +{ + return context->cb_loop->report_outcome(event_id, outcome_json, status); +} + +API int CBLoopRefreshModel(cb_loop_context_t* context, reinforcement_learning::api_status* status) +{ + return context->cb_loop->refresh_model(status); +} + +API void CBLoopSetCallback(cb_loop_context_t* cb_loop, rl_net_native::background_error_callback_t callback) +{ + cb_loop->base_loop_context.background_error_callback = callback; +} + +API void CBLoopSetTrace(cb_loop_context_t* cb_loop, rl_net_native::trace_logger_callback_t callback) +{ + cb_loop->base_loop_context.trace_logger_callback = callback; +} diff --git a/bindings/cs/rl.net.native/rl.net.cb_loop.h b/bindings/cs/rl.net.native/rl.net.cb_loop.h new file mode 100644 index 000000000..5a06ced27 --- /dev/null +++ b/bindings/cs/rl.net.native/rl.net.cb_loop.h @@ -0,0 +1,47 @@ +#pragma once + +#include "constants.h" +#include "rl.net.base_loop.h" +#include "rl.net.factory_context.h" +#include "rl.net.native.h" + +typedef struct cb_loop_context +{ + // reinforcement learning cb_loop instance. + reinforcement_learning::cb_loop* cb_loop; + // contains base fields for all loops + base_loop_context_t base_loop_context; +} cb_loop_context_t; + +// Global exports +extern "C" +{ + // NOTE: THIS IS NOT POLYMORPHISM SAFE! + API cb_loop_context_t* CreateCBLoop( + reinforcement_learning::utility::configuration* config, factory_context_t* factory_context); + API void DeleteCBLoop(cb_loop_context_t* context); + + API int CBLoopInit(cb_loop_context_t* cb_loop, reinforcement_learning::api_status* status = nullptr); + + API int CBLoopChooseRank(cb_loop_context_t* cb_loop, const char* event_id, const char* context_json, + int context_json_size, reinforcement_learning::ranking_response* resp, + reinforcement_learning::api_status* status = nullptr); + API int CBLoopChooseRankWithFlags(cb_loop_context_t* cb_loop, const char* event_id, const char* context_json, + int context_json_size, unsigned int flags, reinforcement_learning::ranking_response* resp, + reinforcement_learning::api_status* status = nullptr); + API int CBLoopReportActionTaken( + cb_loop_context_t* cb_loop, const char* event_id, reinforcement_learning::api_status* status = nullptr); + API int CBLoopReportActionMultiIdTaken(cb_loop_context_t* cb_loop, const char* primary_id, const char* secondary_id, + reinforcement_learning::api_status* status = nullptr); + + API int CBLoopReportOutcomeF(cb_loop_context_t* cb_loop, const char* event_id, float outcome, + reinforcement_learning::api_status* status = nullptr); + API int CBLoopReportOutcomeJson(cb_loop_context_t* cb_loop, const char* event_id, const char* outcomeJson, + reinforcement_learning::api_status* status = nullptr); + + API int CBLoopRefreshModel(cb_loop_context_t* context, reinforcement_learning::api_status* status = nullptr); + + API void CBLoopSetCallback(cb_loop_context_t* cb_loop, rl_net_native::background_error_callback_t callback = nullptr); + API void CBLoopSetTrace( + cb_loop_context_t* cb_loop, rl_net_native::trace_logger_callback_t trace_logger_callback = nullptr); +} diff --git a/bindings/cs/rl.net.native/rl.net.live_model.cc b/bindings/cs/rl.net.native/rl.net.live_model.cc index 7525d6ff8..c18fef97d 100644 --- a/bindings/cs/rl.net.native/rl.net.live_model.cc +++ b/bindings/cs/rl.net.native/rl.net.live_model.cc @@ -3,7 +3,6 @@ #include "binding_tracer.h" #include "constants.h" #include "err_constants.h" -#include "rl.net.live_model.h" #include "trace_logger.h" #include @@ -11,7 +10,7 @@ static void pipe_background_error_callback( const reinforcement_learning::api_status& status, livemodel_context_t* context) { - auto managed_backgroud_error_callback_local = context->background_error_callback; + auto managed_backgroud_error_callback_local = context->base_loop_context.background_error_callback; if (managed_backgroud_error_callback_local) { managed_backgroud_error_callback_local(status); } } @@ -19,9 +18,9 @@ API livemodel_context_t* CreateLiveModel( reinforcement_learning::utility::configuration* config, factory_context_t* factory_context) { livemodel_context_t* context = new livemodel_context_t; - context->background_error_callback = nullptr; - context->trace_logger_callback = nullptr; - context->trace_logger_factory = nullptr; + context->base_loop_context.background_error_callback = nullptr; + context->base_loop_context.trace_logger_callback = nullptr; + context->base_loop_context.trace_logger_factory = nullptr; // Create a trace log factory by passing in below creator. It allows LiveModel to use trace_logger provided by user. const auto binding_tracer_create = [context](std::unique_ptr& retval, @@ -29,7 +28,7 @@ API livemodel_context_t* CreateLiveModel( reinforcement_learning::i_trace* trace_logger, reinforcement_learning::api_status* status) { - retval.reset(new rl_net_native::binding_tracer(*context)); + retval.reset(new rl_net_native::binding_tracer(context->base_loop_context)); return reinforcement_learning::error_code::success; }; @@ -62,10 +61,10 @@ API void DeleteLiveModel(livemodel_context_t* context) // Since the livemodel destructor waits for queues to drain, this can have unhappy consequences, // so detach the callback pipe first. This will cause all background callbacks to no-op in the // unmanaged side, which maintains expected thread semantics (the user of the bindings) - context->background_error_callback = nullptr; - context->trace_logger_callback = nullptr; + context->base_loop_context.background_error_callback = nullptr; + context->base_loop_context.trace_logger_callback = nullptr; - delete context->trace_logger_factory; + delete context->base_loop_context.trace_logger_factory; delete context->livemodel; delete context; } @@ -287,10 +286,10 @@ API int LiveModelRefreshModel(livemodel_context_t* context, reinforcement_learni API void LiveModelSetCallback(livemodel_context_t* livemodel, rl_net_native::background_error_callback_t callback) { - livemodel->background_error_callback = callback; + livemodel->base_loop_context.background_error_callback = callback; } API void LiveModelSetTrace(livemodel_context_t* livemodel, rl_net_native::trace_logger_callback_t callback) { - livemodel->trace_logger_callback = callback; -} + livemodel->base_loop_context.trace_logger_callback = callback; +} \ No newline at end of file diff --git a/bindings/cs/rl.net.native/rl.net.live_model.h b/bindings/cs/rl.net.native/rl.net.live_model.h index ea4118a86..09d3bdd43 100644 --- a/bindings/cs/rl.net.native/rl.net.live_model.h +++ b/bindings/cs/rl.net.native/rl.net.live_model.h @@ -1,29 +1,16 @@ #pragma once #include "constants.h" +#include "rl.net.base_loop.h" #include "rl.net.factory_context.h" #include "rl.net.native.h" -namespace rl_net_native -{ -namespace constants -{ -const char* const BINDING_TRACE_LOGGER = "BINDING_TRACE_LOGGER"; -} - -typedef void (*trace_logger_callback_t)(int log_level, const char* msg); -} // namespace rl_net_native - typedef struct livemodel_context { // reinforcement learning live_model instance. reinforcement_learning::live_model* livemodel; - // callback funtion to user when there is background error. - rl_net_native::background_error_callback_t background_error_callback; - // callback funtion to user for trace log. - rl_net_native::trace_logger_callback_t trace_logger_callback; - // A trace log factory instance holder of one live_model instance for binding calls. - reinforcement_learning::trace_logger_factory_t* trace_logger_factory; + // contains base fields for all loops + base_loop_context_t base_loop_context; } livemodel_context_t; // Global exports diff --git a/bindings/cs/rl.net.native/rl.net.native.h b/bindings/cs/rl.net.native/rl.net.native.h index 7244f7e2f..696b22368 100644 --- a/bindings/cs/rl.net.native/rl.net.native.h +++ b/bindings/cs/rl.net.native/rl.net.native.h @@ -1,5 +1,6 @@ #pragma once +#include "cb_loop.h" #include "config_utility.h" #include "live_model.h" diff --git a/bindings/cs/rl.net.native/rl.net.native.vcxproj b/bindings/cs/rl.net.native/rl.net.native.vcxproj index 1a11c3f39..d8940c4f6 100644 --- a/bindings/cs/rl.net.native/rl.net.native.vcxproj +++ b/bindings/cs/rl.net.native/rl.net.native.vcxproj @@ -95,6 +95,7 @@ + @@ -111,7 +112,9 @@ + + diff --git a/bindings/cs/rl.net/BaseLoop.cs b/bindings/cs/rl.net/BaseLoop.cs new file mode 100644 index 000000000..e40106be6 --- /dev/null +++ b/bindings/cs/rl.net/BaseLoop.cs @@ -0,0 +1,13 @@ +using System; + +namespace Rl.Net +{ + namespace Native + { + internal static partial class NativeMethods + { + public delegate void managed_background_error_callback_t(IntPtr apiStatus); + public delegate void managed_trace_callback_t(int logLevel, IntPtr msgUtf8Ptr); + } + } +} \ No newline at end of file diff --git a/bindings/cs/rl.net/CBLoop.cs b/bindings/cs/rl.net/CBLoop.cs new file mode 100644 index 000000000..82253411a --- /dev/null +++ b/bindings/cs/rl.net/CBLoop.cs @@ -0,0 +1,541 @@ +using System; +using System.Runtime.InteropServices; + +using Rl.Net.Native; + +namespace Rl.Net +{ + namespace Native + { + // The publics in this class are just a verbose, but jittably-efficient way of enabling overriding a native invocation + internal static partial class NativeMethods + { + [DllImport("rlnetnative")] + public static extern IntPtr CreateCBLoop(IntPtr config, IntPtr factoryContext); + + [DllImport("rlnetnative")] + public static extern void DeleteCBLoop(IntPtr cbLoop); + + [DllImport("rlnetnative")] + public static extern int CBLoopInit(IntPtr cbLoop, IntPtr apiStatus); + + [DllImport("rlnetnative", EntryPoint = "CBLoopChooseRank")] + private static extern int CBLoopChooseRankNative(IntPtr cbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr rankingResponse, IntPtr apiStatus); + + internal static Func CBLoopChooseRankOverride { get; set; } + + public static int CBLoopChooseRank(IntPtr cbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr rankingResponse, IntPtr apiStatus) + { + if (CBLoopChooseRankOverride != null) + { + return CBLoopChooseRankOverride(cbLoop, eventId, contextJson, contextJsonSize, rankingResponse, apiStatus); + } + + return CBLoopChooseRankNative(cbLoop, eventId, contextJson, contextJsonSize, rankingResponse, apiStatus); + } + + [DllImport("rlnetnative", EntryPoint = "CBLoopChooseRankWithFlags")] + private static extern int CBLoopChooseRankWithFlagsNative(IntPtr cbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr apiStatus); + + internal static Func CBLoopChooseRankWithFlagsOverride { get; set; } + + public static int CBLoopChooseRankWithFlags(IntPtr cbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr apiStatus) + { + if (CBLoopChooseRankWithFlagsOverride != null) + { + return CBLoopChooseRankWithFlagsOverride(cbLoop, eventId, contextJson, contextJsonSize, flags, rankingResponse, apiStatus); + } + + return CBLoopChooseRankWithFlagsNative(cbLoop, eventId, contextJson, contextJsonSize, flags, rankingResponse, apiStatus); + } + + [DllImport("rlnetnative", EntryPoint = "CBLoopReportActionTaken")] + private static extern int CBLoopReportActionTakenNative(IntPtr cbLoop, IntPtr eventId, IntPtr apiStatus); + + internal static Func CBLoopReportActionTakenOverride { get; set; } + + public static int CBLoopReportActionTaken(IntPtr cbLoop, IntPtr eventId, IntPtr apiStatus) + { + if (CBLoopReportActionTakenOverride != null) + { + return CBLoopReportActionTakenOverride(cbLoop, eventId, apiStatus); + } + + return CBLoopReportActionTakenNative(cbLoop, eventId, apiStatus); + } + + [DllImport("rlnetnative", EntryPoint = "CBLoopReportActionMultiIdTaken")] + private static extern int CBLoopReportActionTakenMultiIdNative(IntPtr cbLoop, IntPtr primaryId, IntPtr secondaryId, IntPtr apiStatus); + + internal static Func CBLoopReportActionTakenMultiIdOverride { get; set; } + + public static int CBLoopReportActionMultiIdTaken(IntPtr cbLoop, IntPtr primaryId, IntPtr secondaryId, IntPtr apiStatus) + { + if (CBLoopReportActionTakenMultiIdOverride != null) + { + return CBLoopReportActionTakenMultiIdOverride(cbLoop, primaryId, secondaryId, apiStatus); + } + + return CBLoopReportActionTakenMultiIdNative(cbLoop, primaryId, secondaryId, apiStatus); + } + + [DllImport("rlnetnative", EntryPoint = "CBLoopReportOutcomeF")] + private static extern int CBLoopReportOutcomeFNative(IntPtr cbLoop, IntPtr eventId, float outcome, IntPtr apiStatus); + + internal static Func CBLoopReportOutcomeFOverride { get; set; } + + public static int CBLoopReportOutcomeF(IntPtr cbLoop, IntPtr eventId, float outcome, IntPtr apiStatus) + { + if (CBLoopReportOutcomeFOverride != null) + { + return CBLoopReportOutcomeFOverride(cbLoop, eventId, outcome, apiStatus); + } + + return CBLoopReportOutcomeFNative(cbLoop, eventId, outcome, apiStatus); + } + + [DllImport("rlnetnative", EntryPoint = "CBLoopReportOutcomeJson")] + private static extern int CBLoopReportOutcomeJsonNative(IntPtr cbLoop, IntPtr eventId, IntPtr outcomeJson, IntPtr apiStatus); + + internal static Func CBLoopReportOutcomeJsonOverride { get; set; } + + public static int CBLoopReportOutcomeJson(IntPtr cbLoop, IntPtr eventId, IntPtr outcomeJson, IntPtr apiStatus) + { + if (CBLoopReportOutcomeJsonOverride != null) + { + return CBLoopReportOutcomeJsonOverride(cbLoop, eventId, outcomeJson, apiStatus); + } + + return CBLoopReportOutcomeJsonNative(cbLoop, eventId, outcomeJson, apiStatus); + } + + [DllImport("rlnetnative")] + public static extern int CBLoopRefreshModel(IntPtr cbLoop, IntPtr apiStatus); + + [DllImport("rlnetnative")] + public static extern void CBLoopSetCallback(IntPtr cbLoop, [MarshalAs(UnmanagedType.FunctionPtr)] managed_background_error_callback_t callback = null); + + [DllImport("rlnetnative")] + public static extern void CBLoopSetTrace(IntPtr cbLoop, [MarshalAs(UnmanagedType.FunctionPtr)] managed_trace_callback_t callback = null); + } + } + + public sealed class CBLoop : NativeObject + { + private readonly NativeMethods.managed_background_error_callback_t managedErrorCallback; + private readonly NativeMethods.managed_trace_callback_t managedTraceCallback; + + private static New BindConstructorArguments(Configuration config, FactoryContext factoryContext) + { + return new New(() => + { + factoryContext = factoryContext ?? new FactoryContext(); + IntPtr result = NativeMethods.CreateCBLoop(config.DangerousGetHandle(), factoryContext.DangerousGetHandle()); + + // These references do not live on the heap in this delegate, and could disappear during the invocation + // of CreateCBLoop. Thus, we need to ensure GC knows not to release them until after that call + // returns. + GC.KeepAlive(config); + GC.KeepAlive(factoryContext); + + return result; + }); + + } + + public CBLoop(Configuration config) : this(config, null) + { } + + public CBLoop(Configuration config, FactoryContext factoryContext) : base(BindConstructorArguments(config, factoryContext), new Delete(NativeMethods.DeleteCBLoop)) + { + this.managedErrorCallback = new NativeMethods.managed_background_error_callback_t(this.WrapStatusAndRaiseBackgroundError); + + // DangerousGetHandle here is trivially safe, because .Dispose() cannot be called before the object is + // constructed. + NativeMethods.CBLoopSetCallback(this.DangerousGetHandle(), this.managedErrorCallback); + + this.managedTraceCallback = new NativeMethods.managed_trace_callback_t(this.SendTrace); + } + + private static void CheckJsonString(string json) + { + if (String.IsNullOrWhiteSpace(json)) + { + throw new ArgumentException("Input json is empty", "json"); + } + } + + unsafe private static int CBLoopChooseRank(IntPtr cbLoop, string eventId, string contextJson, IntPtr rankingResponse, IntPtr apiStatus) + { + CheckJsonString(contextJson); + + fixed (byte* contextJsonUtf8Bytes = NativeMethods.StringEncoding.GetBytes(contextJson)) + { + int contextJsonSize = NativeMethods.StringEncoding.GetByteCount(contextJson); + IntPtr contextJsonUtf8Ptr = new IntPtr(contextJsonUtf8Bytes); + + // It is important to pass null on faithfully here, because we rely on this to switch between auto-generate + // eventId and use supplied eventId at the rl.net.native layer. + if (eventId == null) + { + return NativeMethods.CBLoopChooseRank(cbLoop, IntPtr.Zero, contextJsonUtf8Ptr, contextJsonSize, rankingResponse, apiStatus); + } + + fixed (byte* eventIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(eventId)) + { + return NativeMethods.CBLoopChooseRank(cbLoop, new IntPtr(eventIdUtf8Bytes), contextJsonUtf8Ptr, contextJsonSize, rankingResponse, apiStatus); + } + } + } + + // TODO: Should we reduce the rl.net.native interface to only have one of these? + unsafe private static int CBLoopChooseRankWithFlags(IntPtr cbLoop, string eventId, string contextJson, uint flags, IntPtr rankingResponse, IntPtr apiStatus) + { + CheckJsonString(contextJson); + + fixed (byte* contextJsonUtf8Bytes = NativeMethods.StringEncoding.GetBytes(contextJson)) + { + int contextJsonSize = NativeMethods.StringEncoding.GetByteCount(contextJson); + IntPtr contextJsonUtf8Ptr = new IntPtr(contextJsonUtf8Bytes); + + // It is important to pass null on faithfully here, because we rely on this to switch between auto-generate + // eventId and use supplied eventId at the rl.net.native layer. + if (eventId == null) + { + return NativeMethods.CBLoopChooseRankWithFlags(cbLoop, IntPtr.Zero, contextJsonUtf8Ptr, contextJsonSize, flags, rankingResponse, apiStatus); + } + + fixed (byte* eventIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(eventId)) + { + return NativeMethods.CBLoopChooseRankWithFlags(cbLoop, new IntPtr(eventIdUtf8Bytes), contextJsonUtf8Ptr, contextJsonSize, flags, rankingResponse, apiStatus); + } + } + } + + unsafe private static int CBLoopReportActionTaken(IntPtr cbLoop, string eventId, IntPtr apiStatus) + { + if (eventId == null) + { + throw new ArgumentNullException("eventId"); + } + + fixed (byte* eventIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(eventId)) + { + return NativeMethods.CBLoopReportActionTaken(cbLoop, new IntPtr(eventIdUtf8Bytes), apiStatus); + } + } + + unsafe private static int CBLoopReportActionMultiIdTaken(IntPtr cbLoop, string primaryId, string secondaryId, IntPtr apiStatus) + { + if (primaryId == null) + { + throw new ArgumentNullException("primaryId"); + } + + if (secondaryId == null) + { + throw new ArgumentNullException("secondaryId"); + } + + fixed (byte* episodeIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(primaryId)) + fixed (byte* eventIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(secondaryId)) + { + return NativeMethods.CBLoopReportActionMultiIdTaken(cbLoop, new IntPtr(episodeIdUtf8Bytes), new IntPtr(eventIdUtf8Bytes), apiStatus); + } + } + + unsafe private static int CBLoopReportOutcomeF(IntPtr cbLoop, string eventId, float outcome, IntPtr apiStatus) + { + if (eventId == null) + { + throw new ArgumentNullException("eventId"); + } + + fixed (byte* eventIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(eventId)) + { + return NativeMethods.CBLoopReportOutcomeF(cbLoop, new IntPtr(eventIdUtf8Bytes), outcome, apiStatus); + } + } + + unsafe private static int CBLoopReportOutcomeJson(IntPtr cbLoop, string eventId, string outcomeJson, IntPtr apiStatus) + { + if (eventId == null) + { + throw new ArgumentNullException("eventId"); + } + + CheckJsonString(outcomeJson); + + fixed (byte* eventIdUtf8Bytes = NativeMethods.StringEncoding.GetBytes(eventId)) + fixed (byte* outcomeJsonUtf8Bytes = NativeMethods.StringEncoding.GetBytes(outcomeJson)) + { + return NativeMethods.CBLoopReportOutcomeJson(cbLoop, new IntPtr(eventIdUtf8Bytes), new IntPtr(outcomeJsonUtf8Bytes), apiStatus); + } + } + + private void WrapStatusAndRaiseBackgroundError(IntPtr apiStatusHandle) + { + using (ApiStatus status = new ApiStatus(apiStatusHandle)) + { + EventHandler targetEventLocal = this.BackgroundErrorInternal; + if (targetEventLocal != null) + { + targetEventLocal.Invoke(this, status); + } + else + { + // This comes strictly from the background thread - so simply throwing here has + // the right semantics with respect to AppDomain.UnhandledException. Unfortunately, + // that seems to bring down the process, if there is nothing Managed under the native + // stack this will cause an application-level unhandled native exception, and will + // likely terminate the application. So new up a thread, and throw from it. + // See https://stackoverflow.com/questions/42298126/raising-exception-on-managed-and-unmanaged-callback-chain-with-p-invoke + + // IMPORTANT: This is safe solely because the status string is marshaled into the + // exception message on construction (in other words, before control returns to the + // unmanaged call-stack - the Dispose() is a no-op because in this case NativeObject does + // not own the unmanaged pointer, but we use it to remove itself from the finalizer queue) + RLException e = new RLException(status); + new System.Threading.Thread(() => throw e).Start(); + } + } + } + + private void SendTrace(int logLevel, IntPtr msgUtf8Ptr) + { + string msg = NativeMethods.StringMarshallingFunc(msgUtf8Ptr); + + this.OnTraceLoggerEventInternal?.Invoke(this, new TraceLogEventArgs((RLLogLevel)logLevel, msg)); + } + + public bool TryInit(ApiStatus apiStatus = null) + { + int result = NativeMethods.CBLoopInit(this.DangerousGetHandle(), apiStatus.ToNativeHandleOrNullptrDangerous()); + + GC.KeepAlive(this); + return result == NativeMethods.SuccessStatus; + } + + public void Init() + { + using (ApiStatus apiStatus = new ApiStatus()) + if (!this.TryInit(apiStatus)) + { + throw new RLException(apiStatus); + } + } + + public bool TryChooseRank(string eventId, string contextJson, out RankingResponse response, ApiStatus apiStatus = null) + { + response = new RankingResponse(); + return this.TryChooseRank(eventId, contextJson, response, apiStatus); + } + + public bool TryChooseRank(string eventId, string contextJson, RankingResponse response, ApiStatus apiStatus = null) + { + int result = CBLoopChooseRank(this.DangerousGetHandle(), eventId, contextJson, response.DangerousGetHandle(), apiStatus.ToNativeHandleOrNullptrDangerous()); + + GC.KeepAlive(this); + return result == NativeMethods.SuccessStatus; + } + + public RankingResponse ChooseRank(string eventId, string contextJson) + { + RankingResponse result = new RankingResponse(); + + using (ApiStatus apiStatus = new ApiStatus()) + if (!this.TryChooseRank(eventId, contextJson, result, apiStatus)) + { + throw new RLException(apiStatus); + } + + return result; + } + + public bool TryChooseRank(string eventId, string contextJson, ActionFlags flags, out RankingResponse response, ApiStatus apiStatus = null) + { + response = new RankingResponse(); + return this.TryChooseRank(eventId, contextJson, flags, response, apiStatus); + } + + public bool TryChooseRank(string eventId, string contextJson, ActionFlags flags, RankingResponse response, ApiStatus apiStatus = null) + { + int result = CBLoopChooseRankWithFlags(this.DangerousGetHandle(), eventId, contextJson, (uint)flags, response.DangerousGetHandle(), apiStatus.ToNativeHandleOrNullptrDangerous()); + + GC.KeepAlive(this); + return result == NativeMethods.SuccessStatus; + } + + public RankingResponse ChooseRank(string eventId, string contextJson, ActionFlags flags) + { + RankingResponse result = new RankingResponse(); + + using (ApiStatus apiStatus = new ApiStatus()) + if (!this.TryChooseRank(eventId, contextJson, flags, result, apiStatus)) + { + throw new RLException(apiStatus); + } + + return result; + } + + [Obsolete("Use TryQueueActionTakenEvent instead.")] + public bool TryReportActionTaken(string eventId, ApiStatus apiStatus = null) + => this.TryQueueActionTakenEvent(eventId, apiStatus); + + public bool TryQueueActionTakenEvent(string eventId, ApiStatus apiStatus = null) + { + int result = CBLoopReportActionTaken(this.DangerousGetHandle(), eventId, apiStatus.ToNativeHandleOrNullptrDangerous()); + + GC.KeepAlive(apiStatus); + GC.KeepAlive(this); + return result == NativeMethods.SuccessStatus; + } + + public bool TryQueueActionTakenEvent(string primaryId, string secondaryId, ApiStatus apiStatus = null) + { + int result = CBLoopReportActionMultiIdTaken(this.DangerousGetHandle(), primaryId, secondaryId, apiStatus.ToNativeHandleOrNullptrDangerous()); + + GC.KeepAlive(apiStatus); + GC.KeepAlive(this); + return result == NativeMethods.SuccessStatus; + } + + [Obsolete("Use QueueActionTakenEvent instead.")] + public void ReportActionTaken(string eventId) + => this.QueueActionTakenEvent(eventId); + + public void QueueActionTakenEvent(string eventId) + { + using (ApiStatus apiStatus = new ApiStatus()) + if (!this.TryQueueActionTakenEvent(eventId, apiStatus)) + { + throw new RLException(apiStatus); + } + } + + public void QueueActionTakenEvent(string primaryId, string secondaryId) + { + using (ApiStatus apiStatus = new ApiStatus()) + if (!this.TryQueueActionTakenEvent(primaryId, secondaryId, apiStatus)) + { + throw new RLException(apiStatus); + } + } + + [Obsolete("Use TryQueueOutcomeEvent instead.")] + public bool TryReportOutcome(string eventId, float outcome, ApiStatus apiStatus = null) + => this.TryQueueOutcomeEvent(eventId, outcome, apiStatus); + + public bool TryQueueOutcomeEvent(string eventId, float outcome, ApiStatus apiStatus = null) + { + int result = CBLoopReportOutcomeF(this.DangerousGetHandle(), eventId, outcome, apiStatus.ToNativeHandleOrNullptrDangerous()); + + GC.KeepAlive(apiStatus); + GC.KeepAlive(this); + return result == NativeMethods.SuccessStatus; + } + + [Obsolete("Use QueueOutcomeReport instead.")] + public void ReportOutcome(string eventId, float outcome) + => this.QueueOutcomeEvent(eventId, outcome); + + public void QueueOutcomeEvent(string eventId, float outcome) + { + using (ApiStatus apiStatus = new ApiStatus()) + if (!this.TryQueueOutcomeEvent(eventId, outcome, apiStatus)) + { + throw new RLException(apiStatus); + } + } + + [Obsolete("Use TryQueueOutcomeEvent instead.")] + public bool TryReportOutcome(string eventId, string outcomeJson, ApiStatus apiStatus = null) + => this.TryQueueOutcomeEvent(eventId, outcomeJson, apiStatus); + + public bool TryQueueOutcomeEvent(string eventId, string outcomeJson, ApiStatus apiStatus = null) + { + int result = CBLoopReportOutcomeJson(this.DangerousGetHandle(), eventId, outcomeJson, apiStatus.ToNativeHandleOrNullptrDangerous()); + + GC.KeepAlive(apiStatus); + GC.KeepAlive(this); + return result == NativeMethods.SuccessStatus; + } + + [Obsolete("Use QueueOutcomeEvent instead.")] + public void ReportOutcome(string eventId, string outcomeJson) + => this.QueueOutcomeEvent(eventId, outcomeJson); + + public void QueueOutcomeEvent(string eventId, string outcomeJson) + { + using (ApiStatus apiStatus = new ApiStatus()) + if (!this.TryQueueOutcomeEvent(eventId, outcomeJson, apiStatus)) + { + throw new RLException(apiStatus); + } + } + + public void RefreshModel() + { + using (ApiStatus apiStatus = new ApiStatus()) + if (!this.TryRefreshModel(apiStatus)) + { + throw new RLException(apiStatus); + } + } + + public bool TryRefreshModel(ApiStatus apiStatus = null) + { + int result = NativeMethods.CBLoopRefreshModel(this.DangerousGetHandle(), apiStatus.ToNativeHandleOrNullptrDangerous()); + + GC.KeepAlive(apiStatus); + GC.KeepAlive(this); + return result == NativeMethods.SuccessStatus; + } + + private event EventHandler BackgroundErrorInternal; + + // This event is thread-safe, because we do not hook/unhook the event in user-scheduleable code anymore. + public event EventHandler BackgroundError + { + add + { + this.BackgroundErrorInternal += value; + } + remove + { + this.BackgroundErrorInternal -= value; + } + } + + private event EventHandler OnTraceLoggerEventInternal; + + // TODO: + /// + /// Add/remove here is not thread safe. + /// + public event EventHandler TraceLoggerEvent + { + add + { + if (this.OnTraceLoggerEventInternal == null) + { + NativeMethods.CBLoopSetTrace(this.DangerousGetHandle(), this.managedTraceCallback); + GC.KeepAlive(this); + } + + this.OnTraceLoggerEventInternal += value; + } + remove + { + this.OnTraceLoggerEventInternal -= value; + + if (this.OnTraceLoggerEventInternal == null) + { + NativeMethods.CBLoopSetTrace(this.DangerousGetHandle(), null); + GC.KeepAlive(this); + } + } + } + } +} \ No newline at end of file diff --git a/bindings/cs/rl.net/CMakeLists.txt b/bindings/cs/rl.net/CMakeLists.txt index d69c5d265..df357eb71 100644 --- a/bindings/cs/rl.net/CMakeLists.txt +++ b/bindings/cs/rl.net/CMakeLists.txt @@ -9,6 +9,8 @@ set(RL_NET_SOURCES ActionFlags.cs ApiStatus.cs AsyncSender.cs + BaseLoop.cs + CBLoop.cs Configuration.cs ContinuousActionResponse.cs DecisionResponse.cs diff --git a/bindings/cs/rl.net/LiveModel.cs b/bindings/cs/rl.net/LiveModel.cs index af931b10d..59f5b3b41 100644 --- a/bindings/cs/rl.net/LiveModel.cs +++ b/bindings/cs/rl.net/LiveModel.cs @@ -337,13 +337,9 @@ public static int LiveModelReportOutcomeSlotStringIdJson(IntPtr liveModel, IntPt [DllImport("rlnetnative")] public static extern int LiveModelRefreshModel(IntPtr liveModel, IntPtr apiStatus); - public delegate void managed_background_error_callback_t(IntPtr apiStatus); - [DllImport("rlnetnative")] public static extern void LiveModelSetCallback(IntPtr liveModel, [MarshalAs(UnmanagedType.FunctionPtr)] managed_background_error_callback_t callback = null); - public delegate void managed_trace_callback_t(int logLevel, IntPtr msgUtf8Ptr); - [DllImport("rlnetnative")] public static extern void LiveModelSetTrace(IntPtr liveModel, [MarshalAs(UnmanagedType.FunctionPtr)] managed_trace_callback_t callback = null); }