diff --git a/bindings/cs/rl.net.native/binding_static_model.cc b/bindings/cs/rl.net.native/binding_static_model.cc index 87013953d..e8ed08bd2 100644 --- a/bindings/cs/rl.net.native/binding_static_model.cc +++ b/bindings/cs/rl.net.native/binding_static_model.cc @@ -1,8 +1,5 @@ #include "binding_static_model.h" -#include "err_constants.h" -#include "model_mgmt.h" - using namespace rl_net_native; using namespace reinforcement_learning; @@ -11,16 +8,5 @@ binding_static_model::binding_static_model(const char* vw_model, const size_t le int binding_static_model::get_data( model_transport::model_data& data, reinforcement_learning::api_status* status) { - if (this->vw_model == nullptr || this->len == 0) - { - return reinforcement_learning::error_code::static_model_load_error; - } - - char* buffer = data.alloc(this->len); - if (buffer == nullptr) { return reinforcement_learning::error_code::static_model_load_error; } - - memcpy(buffer, this->vw_model, this->len); - data.data_sz(this->len); - - return reinforcement_learning::error_code::success; + return data.set_data(vw_model, len); } diff --git a/bindings/cs/rl.net.native/binding_static_model.h b/bindings/cs/rl.net.native/binding_static_model.h index 6aab860aa..c24a3a4dd 100644 --- a/bindings/cs/rl.net.native/binding_static_model.h +++ b/bindings/cs/rl.net.native/binding_static_model.h @@ -1,5 +1,6 @@ #pragma once +#include "err_constants.h" #include "model_mgmt.h" #include diff --git a/bindings/cs/rl.net.native/rl.net.factory_context.cc b/bindings/cs/rl.net.native/rl.net.factory_context.cc index 0373e2426..021b87f47 100644 --- a/bindings/cs/rl.net.native/rl.net.factory_context.cc +++ b/bindings/cs/rl.net.native/rl.net.factory_context.cc @@ -26,19 +26,16 @@ void cleanup_data_transport_factory(data_transport_factory_t* data_transport_fac } } -API factory_context_t* CreateFactoryContextWithStaticModel(const char* weights, const size_t len) +API factory_context_t* CreateFactoryContextWithStaticModel(const char* vw_model, const size_t len) { using namespace reinforcement_learning::model_management; auto context = CreateFactoryContext(); - auto data_transport_factory_fn = [weights, len](std::unique_ptr& retval, - const utility::configuration& configuration, i_trace* trace_logger, - api_status* status) -> int - { - char* weightsCopy = new char[len]; - std::memcpy(weightsCopy, weights, len); - retval.reset(new rl_net_native::binding_static_model(weightsCopy, len)); - return error_code::success; + auto data_transport_factory_fn = [vw_model, len](std::unique_ptr& retval, const utility::configuration& configuration, i_trace* trace_logger, api_status* status) -> int { + char* vw_model_copy = new char[len]; + std::memcpy(vw_model_copy, vw_model, len); + retval.reset(new rl_net_native::binding_static_model(vw_model_copy, len)); + return error_code::success; }; data_transport_factory_t* data_transport_factory = diff --git a/bindings/cs/rl.net.native/rl.net.factory_context.h b/bindings/cs/rl.net.native/rl.net.factory_context.h index 7ef4a9a31..15c97dd4a 100644 --- a/bindings/cs/rl.net.native/rl.net.factory_context.h +++ b/bindings/cs/rl.net.native/rl.net.factory_context.h @@ -20,7 +20,7 @@ typedef struct factory_context extern "C" { API factory_context_t* CreateFactoryContext(); - API factory_context_t* CreateFactoryContextWithStaticModel(const char* weights, const size_t len); + API factory_context_t* CreateFactoryContextWithStaticModel(const char* vw_model, const size_t len); API void DeleteFactoryContext(factory_context_t* context); API void SetFactoryContextBindingSenderFactory( factory_context_t* context, rl_net_native::sender_create_fn create_fn, rl_net_native::sender_vtable_t vtable); diff --git a/bindings/cs/rl.net/FactoryContext.cs b/bindings/cs/rl.net/FactoryContext.cs index 977ac9d05..16533f873 100644 --- a/bindings/cs/rl.net/FactoryContext.cs +++ b/bindings/cs/rl.net/FactoryContext.cs @@ -13,7 +13,7 @@ public sealed class FactoryContext : NativeObject private static extern IntPtr CreateFactoryContext(); [DllImport("rlnetnative")] - private static extern IntPtr CreateFactoryContextWithStaticModel(IntPtr weights, int length); + private static extern IntPtr CreateFactoryContextWithStaticModel(IntPtr vw_model, int len); [DllImport("rlnetnative")] private static extern void DeleteFactoryContext(IntPtr context); @@ -21,19 +21,17 @@ public sealed class FactoryContext : NativeObject [DllImport("rlnetnative")] private static extern IntPtr SetFactoryContextBindingSenderFactory(IntPtr context, sender_create_fn create_Fn, sender_vtable vtable); - public List Weights { get; private set; } - public FactoryContext() : base(new New(CreateFactoryContext), new Delete(DeleteFactoryContext)) { } - public FactoryContext(IEnumerable weightsEnumerable) : base( + public FactoryContext(IEnumerable vwModelEnumerable) : base( new New(() => { - var weightsArray = weightsEnumerable.ToArray(); - GCHandle handle = GCHandle.Alloc(weightsArray, GCHandleType.Pinned); + var vwModelArray = vwModelEnumerable.ToArray(); + GCHandle handle = GCHandle.Alloc(vwModelArray, GCHandleType.Pinned); try { IntPtr ptr = handle.AddrOfPinnedObject(); - return CreateFactoryContextWithStaticModel(ptr, weightsArray.Length); + return CreateFactoryContextWithStaticModel(ptr, vwModelArray.Length); } finally { if (handle.IsAllocated) diff --git a/include/model_mgmt.h b/include/model_mgmt.h index 4b381dba7..4f0cf10d4 100644 --- a/include/model_mgmt.h +++ b/include/model_mgmt.h @@ -1,4 +1,5 @@ #pragma once +#include "err_constants.h" #include "multistep.h" #include @@ -34,6 +35,9 @@ class model_data char* alloc(size_t desired); void free(); + // Set data + int set_data(const char* vw_model, size_t len); + model_data() = default; ~model_data() = default; diff --git a/rlclientlib/model_mgmt/model_mgmt.cc b/rlclientlib/model_mgmt/model_mgmt.cc index a4d416694..a02128d89 100644 --- a/rlclientlib/model_mgmt/model_mgmt.cc +++ b/rlclientlib/model_mgmt/model_mgmt.cc @@ -25,5 +25,19 @@ char* model_data::alloc(const size_t desired) void model_data::free() { _data.clear(); } +int model_data::set_data(const char* vw_model, size_t len) +{ + if (vw_model == nullptr || len == 0) + { + return reinforcement_learning::error_code::static_model_load_error; + } + + char* buffer = this->alloc(len); + if (buffer == nullptr) { return reinforcement_learning::error_code::static_model_load_error; } + + memcpy(buffer, vw_model, len); + this->data_sz(len); +} + } // namespace model_management } // namespace reinforcement_learning