From f053fbba4148ceb043503b5069f1addeb4695022 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Thu, 1 Feb 2024 13:50:00 -0500 Subject: [PATCH] static model loader --- bindings/cs/rl.net.native/CMakeLists.txt | 2 ++ bindings/cs/rl.net.native/binding_static_model.cc | 12 ++++++++++++ bindings/cs/rl.net.native/binding_static_model.h | 12 ++++++++++++ bindings/cs/rl.net.native/rl.net.factory_context.h | 2 ++ bindings/cs/rl.net/FactoryContext.cs | 7 ++++++- bindings/cs/rl.net/LiveModel.cs | 1 + 6 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 bindings/cs/rl.net.native/binding_static_model.cc create mode 100644 bindings/cs/rl.net.native/binding_static_model.h diff --git a/bindings/cs/rl.net.native/CMakeLists.txt b/bindings/cs/rl.net.native/CMakeLists.txt index 6f0d3327f..c7dba0dd5 100644 --- a/bindings/cs/rl.net.native/CMakeLists.txt +++ b/bindings/cs/rl.net.native/CMakeLists.txt @@ -1,5 +1,6 @@ set(rl_net_native_SOURCES binding_sender.cc + binding_static_model.cc binding_tracer.cc rl.net.api_status.cc rl.net.buffer.cc @@ -22,6 +23,7 @@ set(rl_net_native_SOURCES set(rl_net_native_HEADERS binding_sender.h + binding_static_model.h binding_tracer.h rl.net.api_status.h rl.net.loop_context.h diff --git a/bindings/cs/rl.net.native/binding_static_model.cc b/bindings/cs/rl.net.native/binding_static_model.cc new file mode 100644 index 000000000..bb89374be --- /dev/null +++ b/bindings/cs/rl.net.native/binding_static_model.cc @@ -0,0 +1,12 @@ +#include "binding_static_model.h" + +using namespace rl_net_native; + +binding_static_model::binding_static_model(const std::vector& model_weights) : weights(model_weights) { +} + +int binding_static_model::get_data(reinforcement_learning::model_management::model_data& data, reinforcement_learning::api_status* status) { + data.set_data(weights.data(), weights.size()); + + return error_code::success; +} diff --git a/bindings/cs/rl.net.native/binding_static_model.h b/bindings/cs/rl.net.native/binding_static_model.h new file mode 100644 index 000000000..1c6df23ea --- /dev/null +++ b/bindings/cs/rl.net.native/binding_static_model.h @@ -0,0 +1,12 @@ +#include + +namespace rl_net_native { +class binding_static_model : public reinforcement_learning::model_management::i_data_transport { +public: + binding_static_model(const std::vector& model_weights); + int get_data(reinforcement_learning::model_management::model_data& data, reinforcement_learning::api_status* status = nullptr) override; + +private: + std::vector weights; +}; +} \ No newline at end of file diff --git a/bindings/cs/rl.net.native/rl.net.factory_context.h b/bindings/cs/rl.net.native/rl.net.factory_context.h index e53564d47..8fc0b3ec1 100644 --- a/bindings/cs/rl.net.native/rl.net.factory_context.h +++ b/bindings/cs/rl.net.native/rl.net.factory_context.h @@ -1,5 +1,6 @@ #pragma once +#include #include "binding_sender.h" #include "factory_resolver.h" #include "rl.net.native.h" @@ -17,6 +18,7 @@ extern "C" { API factory_context_t* CreateFactoryContext(); API void DeleteFactoryContext(factory_context_t* context); + //API void SetFactoryContextModelWeights(factory_context_t* context, const std::byte* weights, size_t size); API void SetFactoryContextBindingSenderFactory( factory_context_t* context, rl_net_native::sender_create_fn create_fn, rl_net_native::sender_vtable_t vtable); diff --git a/bindings/cs/rl.net/FactoryContext.cs b/bindings/cs/rl.net/FactoryContext.cs index 59fd6239d..dc547c057 100644 --- a/bindings/cs/rl.net/FactoryContext.cs +++ b/bindings/cs/rl.net/FactoryContext.cs @@ -1,6 +1,7 @@ using System; using System.Threading; using System.Runtime.InteropServices; +using System.Collections.Generic; using Rl.Net.Native; @@ -16,9 +17,13 @@ public sealed class FactoryContext : NativeObject [DllImport("rlnetnative")] private static extern IntPtr SetFactoryContextBindingSenderFactory(IntPtr context, sender_create_fn create_Fn, sender_vtable vtable); - public FactoryContext() : base(new New(CreateFactoryContext), new Delete(DeleteFactoryContext)) + public List Weights { get; private set; } + + public FactoryContext(List weights = null) : base(new New(CreateFactoryContext), new Delete(DeleteFactoryContext)) { + Weights = weights ?? new List(); } + private GCHandleLifetime registeredSenderCreateHandle; diff --git a/bindings/cs/rl.net/LiveModel.cs b/bindings/cs/rl.net/LiveModel.cs index 59f5b3b41..2e829d011 100644 --- a/bindings/cs/rl.net/LiveModel.cs +++ b/bindings/cs/rl.net/LiveModel.cs @@ -1,5 +1,6 @@ using System; using System.Runtime.InteropServices; +using System.Collections.Generic; using Rl.Net.Native;