Skip to content

Commit

Permalink
feat: Initial Support for FlatBuffer contexts in .NET bindings
Browse files Browse the repository at this point in the history
* Adds support for fb inputs for CB loops
* Adds helper for generating fb inputs
* Cleans up ActionFlags APIs a little
  • Loading branch information
lokitoth committed Mar 1, 2024
1 parent b829909 commit 1c1fab7
Show file tree
Hide file tree
Showing 13 changed files with 620 additions and 204 deletions.
1 change: 1 addition & 0 deletions bindings/cs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ include(FindDotnet)
add_subdirectory(rl.net.native)
add_subdirectory(rl.net)
add_subdirectory(rl.net.cli)
add_subdirectory(rl.net.flatbuffer)
add_subdirectory(rl.net.cli.test)
2 changes: 1 addition & 1 deletion bindings/cs/rl.net.cli.test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ else()
add_custom_target(rl.net.cli.test ALL
COMMAND ${DOTNET_COMMAND} build ${CMAKE_CURRENT_SOURCE_DIR} -o $<TARGET_FILE_DIR:rlnetnative> -v m --nologo --no-dependencies /clp:NoSummary
COMMENT Building rl.net.cli.test
DEPENDS rl.net.cli
DEPENDS rl.net.cli rl.net.flatbuffers
SOURCES ${RL_NET_CLI_TEST_SOURCES})
endif()

Expand Down
28 changes: 14 additions & 14 deletions bindings/cs/rl.net.cli.test/UnicodeTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ void CleanupPInvokeOverrides()
{
NativeMethods.ConfigurationGetOverride = null;
NativeMethods.ConfigurationSetOverride = null;
NativeMethods.LiveModelChooseRankOverride = null;
//NativeMethods.LiveModelChooseRankOverride = null;
NativeMethods.LiveModelChooseRankWithFlagsOverride = null;
NativeMethods.LiveModelReportActionTakenOverride = null;
NativeMethods.LiveModelReportOutcomeFOverride = null;
Expand Down Expand Up @@ -374,8 +374,8 @@ public void Test_LiveModel_RequestEpisodicE2E()

private void Run_LiveModelChooseRank_Test(LiveModel liveModel, string eventId, string contextJson)
{
NativeMethods.LiveModelChooseRankOverride =
(IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, IntPtr rankingResponse, IntPtr apiStatus) =>
NativeMethods.LiveModelChooseRankWithFlagsOverride =
(IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr apiStatus) =>
{
string eventIdMarshalledBack = NativeMethods.StringMarshallingFunc(eventIdPtr);
Assert.AreEqual(eventId, eventIdMarshalledBack, "Marshalling eventId does not work properly in LiveModelChooseRank");
Expand Down Expand Up @@ -417,11 +417,11 @@ public void Test_LiveModel_ChooseRank()

private void Run_LiveModelRequestDecision_Test(LiveModel liveModel, string contextJson)
{
NativeMethods.LiveModelRequestDecisionOverride =
(IntPtr liveModelPtr, IntPtr contextJsonPtr, int contextJsonSize, IntPtr rankingResponse, IntPtr ApiStatus) =>
NativeMethods.LiveModelRequestDecisionWithFlagsOverride =
(IntPtr liveModelPtr, IntPtr contextJsonPtr, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr ApiStatus) =>
{
string contextJsonMarshalledBack = NativeMethods.StringMarshallingFunc(contextJsonPtr);
Assert.AreEqual(contextJson, contextJsonMarshalledBack, "Marshalling contextJson does not work properly in LiveModelRequestDecision");
Assert.AreEqual(contextJson, contextJsonMarshalledBack, "Marshalling contextJson does not work properly in LiveModelRequestDecisionWithFlags");
return NativeMethods.SuccessStatus;
};
Expand Down Expand Up @@ -454,11 +454,11 @@ public void Test_LiveModel_RequestDecision()

private void Run_LiveModelRequestMultiSlotDetailed_Test(LiveModel liveModel, string contextJson, string eventId)
{
NativeMethods.LiveModelRequestMultiSlotDecisionDetailedOverride =
(IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, IntPtr rankingResponse, IntPtr ApiStatus) =>
NativeMethods.LiveModelRequestMultiSlotDecisionDetailedWithFlagsOverride =
(IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr ApiStatus) =>
{
string contextJsonMarshalledBack = NativeMethods.StringMarshallingFunc(contextJsonPtr);
Assert.AreEqual(contextJson, contextJsonMarshalledBack, "Marshalling contextJson does not work properly in LiveModelRequestMultiSlotDecisionDetailed");
Assert.AreEqual(contextJson, contextJsonMarshalledBack, "Marshalling contextJson does not work properly in LiveModelRequestDecisionDetailedWithFlags");
return NativeMethods.SuccessStatus;
};
Expand Down Expand Up @@ -491,11 +491,11 @@ public void Test_LiveModel_RequestMultiSlotDecisionDetailed()

private void Run_LiveModelRequestMultiSlot_Test(LiveModel liveModel, string contextJson, string eventId)
{
NativeMethods.LiveModelRequestMultiSlotDecisionOverride =
(IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, IntPtr rankingResponse, IntPtr ApiStatus) =>
NativeMethods.LiveModelRequestMultiSlotDecisionWithFlagsOverride =
(IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr ApiStatus) =>
{
string contextJsonMarshalledBack = NativeMethods.StringMarshallingFunc(contextJsonPtr);
Assert.AreEqual(contextJson, contextJsonMarshalledBack, "Marshalling contextJson does not work properly in LiveModelRequestMultiSlotDecision");
Assert.AreEqual(contextJson, contextJsonMarshalledBack, "Marshalling contextJson does not work properly in LiveModelRequestDecisionWithFlags");
return NativeMethods.SuccessStatus;
};
Expand Down Expand Up @@ -528,8 +528,8 @@ public void Test_LiveModel_RequestMultiSlotDecision()

private void Run_LiveModelRequestContinuousAction_Test(LiveModel liveModel, string contextJson)
{
NativeMethods.LiveModelRequestContinuousActionOverride =
(IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, IntPtr continuousActionResponse, IntPtr ApiStatus) =>
NativeMethods.LiveModelRequestContinuousActionWithFlagsOverride =
(IntPtr liveModelPtr, IntPtr eventIdPtr, IntPtr contextJsonPtr, int contextJsonSize, uint flags, IntPtr continuousActionResponse, IntPtr ApiStatus) =>
{
string contextJsonMarshalledBack = NativeMethods.StringMarshallingFunc(contextJsonPtr);
Assert.AreEqual(contextJson, contextJsonMarshalledBack, "Marshalling contextJson does not work properly in LiveModelRequestContinuousAction");
Expand Down
1 change: 1 addition & 0 deletions bindings/cs/rl.net.cli.test/rl.net.cli.test.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
<ItemGroup>
<ProjectReference Include="..\rl.net.cli\rl.net.cli.csproj" />
<ProjectReference Include="..\rl.net\rl.net.csproj" />
<ProjectReference Include="..\rl.net.flatbuffer\rl.net.flatbuffer.csproj" />
</ItemGroup>

</Project>
2 changes: 2 additions & 0 deletions bindings/cs/rl.net.flatbuffer/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
generated/
AnyCPU/
44 changes: 44 additions & 0 deletions bindings/cs/rl.net.flatbuffer/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_LIST_DIR}/../cmake/Modules/")

# First try to find the config version. Newer, used by vcpkg etc
find_package(Flatbuffers CONFIG)
if(TARGET flatbuffers::flatbuffers AND TARGET flatbuffers::flatc)
get_property(flatc_location TARGET flatbuffers::flatc PROPERTY LOCATION)
message(STATUS "Found Flatbuffers with CONFIG, flatc located at: ${flatc_location}")
else()
# Fallback to the old version
find_package(Flatbuffers MODULE REQUIRED)
set(flatc_location ${FLATBUFFERS_FLATC_EXECUTABLE})
message(STATUS "Found Flatbuffers with MODULE, flatc located at: ${flatc_location}")
endif()

include(FlatbufferUtils)

set(RL_NET_VWFB_SCHEMA_FILES
"${CMAKE_SOURCE_DIR}/ext_libs/vowpal_wabbit/vowpalwabbit/fb_parser/schema/example.fbs" )

set(RL_NET_FLATBUFFERS_SOURCES
VWExampleBuilder.cs
)

add_flatbuffer_schema(
TARGET rl.net.flatbuffers_generated
SCHEMAS ${RL_NET_VWFB_SCHEMA_FILES}
OUTPUT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/generated/vwfb
GENERATED_FILES_VAR RL_NET_FLATBUFFERS_SOURCES
FLATC_LANGUAGE CSharp
FLATC_EXE ${flatc_location}
)

# find all the files

if (rlclientlib_DOTNET_USE_MSPROJECT)
include_external_msproject(rl.net.flatbuffers ${CMAKE_CURRENT_SOURCE_DIR}/rl.net.flatbuffers.csproj)
else()
add_custom_target(rl.net.flatbuffers
COMMAND ${DOTNET_COMMAND} build ${CMAKE_CURRENT_SOURCE_DIR} -o $<TARGET_FILE_DIR:rlnetnative> -v n --nologo --no-dependencies /clp:NoSummary --configuration "$<$<CONFIG:Debug>:Debug>$<$<CONFIG:Release>:Release>$<$<CONFIG:RelWithDebInfo>:Release>"
COMMENT Building rl.net.flatbuffers
SOURCES ${RL_NET_FLATBUFFERS_SOURCES})
endif()

add_dependencies(rl.net.flatbuffers rl.net.flatbuffers_generated)
117 changes: 117 additions & 0 deletions bindings/cs/rl.net.flatbuffer/VWExampleBuilder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@

using Google.FlatBuffers;
using System.Collections.Generic;
using vwfb = VW.parsers.flatbuffer;

namespace Rl.Net.Flatbuffers
{
public interface ICBExampleBuilder
{
bool BuildingShared { get; }
int BuildingAction { get; }

void PushNamespace(string name = "");

void AddFeature(string name, float value = 1.0f);
void AddFeature(string name, string value);

int PushAction();

byte[] FinishExample();
}

public class VW_CBExampleBuilder : ICBExampleBuilder
{
struct namespace_prototype
{
public string name;

public List<StringOffset> feature_names;
public List<float> feature_values;

public Offset<vwfb.Namespace> Build(FlatBufferBuilder builder)
{
var nameOffset = builder.CreateString(this.name);

var feature_names_array = vwfb.Namespace.CreateFeatureNamesVector(builder, this.feature_names.ToArray());
var feature_indicies_array = vwfb.Namespace.CreateFeatureHashesVector(builder, new ulong[0]);
var feature_values_array = vwfb.Namespace.CreateFeatureValuesVector(builder, this.feature_values.ToArray());

// The presence of the name will get VW to compute a hash for us
return vwfb.Namespace.CreateNamespace(builder, nameOffset, (byte)this.name[0], 0, feature_names_array, feature_values_array, feature_indicies_array);
}
}

private List<Offset<vwfb.Example>> built_example_offsets;
private List<namespace_prototype> namespaces = new List<namespace_prototype>();

private FlatBufferBuilder builder;

public VW_CBExampleBuilder()
{
this.builder = new FlatBufferBuilder(64); // todo - figure out right min size?
}

public void PushNamespace(string name = "")
{
this.namespaces.Add(new namespace_prototype()
{
name = name,
feature_names = new List<StringOffset>(),
feature_values = new List<float>()
});
}

public void AddFeature(string name, string value)
{
this.AddFeature(name + "_" + value);
}

public void AddFeature(string name, float value = 1.0f)
{
if (this.namespaces.Count == 0)
{
// TODO: is this the right hash?
this.PushNamespace("");
}

StringOffset nameOffset = this.builder.CreateString(name);
this.namespaces[this.namespaces.Count - 1].feature_names.Add(nameOffset);
this.namespaces[this.namespaces.Count - 1].feature_values.Add(value);
}

public int PushAction()
{
this.CollectExample();

return this.built_example_offsets.Count - 1;
}

private void CollectExample()
{
Offset<vwfb.Namespace>[] namespaceOffsets = new Offset<vwfb.Namespace>[this.namespaces.Count];
for (int i = 0; i < this.namespaces.Count; i++)
{
namespaceOffsets[i] = this.namespaces[i].Build(this.builder);
}

var namespacesVector = vwfb.Example.CreateNamespacesVector(this.builder, namespaceOffsets);
var result = vwfb.Example.CreateExample(this.builder, namespacesVector);
}

public byte[] FinishExample()
{
this.CollectExample();

var multi_ex = vwfb.MultiExample.CreateMultiExample(this.builder, vwfb.MultiExample.CreateExamplesVector(this.builder, this.built_example_offsets.ToArray()));
var root = vwfb.ExampleRoot.CreateExampleRoot(this.builder, vwfb.ExampleType.MultiExample, multi_ex.Value);

this.builder.FinishSizePrefixed(root.Value);

return this.builder.SizedByteArray();
}

public bool BuildingShared { get => this.built_example_offsets.Count == 0; }
public int BuildingAction { get => this.built_example_offsets.Count - 1; }
}
}
58 changes: 58 additions & 0 deletions bindings/cs/rl.net.flatbuffer/rl.net.flatbuffer.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<BinaryOutputBase Condition="'$(BinaryOutputBase.Trim())'==''">$(SolutionDir).</BinaryOutputBase>
<OutputPath>$(BinaryOutputBase.Trim())\$(Platform)\$(Configuration)</OutputPath>
<Platforms>AnyCPU;x64</Platforms>
<PlatformTarget>x64</PlatformTarget>
<AppendTargetFrameworkToOutputPath>false</AppendTargetFrameworkToOutputPath>
<AssemblyName>rl.net.flatbuffer</AssemblyName>
<RootNamespace>Rl.Net.Flatbuffers</RootNamespace>
<HighEntropyVA>true</HighEntropyVA>
<SignAssembly>True</SignAssembly>
<AssemblyOriginatorKeyFile>$(MSBuildProjectDir)..\..\..\ext_libs\vowpal_wabbit\cs\vw_key.snk</AssemblyOriginatorKeyFile>
</PropertyGroup>

<Import Project="$(PackagingIntegration)" Condition="Exists('$(PackagingIntegration)')" />

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DebugSymbols>true</DebugSymbols>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="System.Memory" Version="4.5.4" />
<PackageReference Include="Google.FlatBuffers" Version="23.1.21" />
</ItemGroup>

<Import Project="..\common\codegen\TextTemplate.targets" />

<ItemGroup>
<TextTransformParameter Include="SNRequired" Value="true" />
<TextTransformParameter Include="PublicKey" Value="0024000004800000940000000602000000240000525341310004000001000100515aa9bda65291811af92b381378bd271aff3a9e177bac69ff0e85874952fd82c0fbcb53f4e968181d07418481ee2be97522d44c324aa5c683dafaa449fe66ddc65e1d9b3c0600c8820bd2be6401c6888ea88864ef0b6ae5bfbf450aa1f548568d638913d82954195947e394c225cca2cd2f8132d525c2fdc0c57835b87200aa" />
<TextTemplate Include="..\common\codegen\InternalsVisibleToTest.tt"></TextTemplate>
</ItemGroup>

<Target Name="OutputVars" BeforeTargets="Build">
<Message Importance="high" Text="INFO: PackagingIntegration = $(PackagingIntegration)" />
<Message Importance="high" Text="INFO: SignAssembly = $(SignAssembly)" />
<Message Importance="high" Text="INFO: DelaySign = $(DelaySign)" />
<Message Importance="high" Text="INFO: KeyFile = $(KeyFile)" />
<Message Importance="high" Text="INFO: AssemblyOriginatorKeyFile = $(AssemblyOriginatorKeyFile)" />
<Message Importance="high" Text="INFO: TextTransformerParams = @(TextTransformParameter)" />
</Target>

</Project>
Loading

0 comments on commit 1c1fab7

Please sign in to comment.